1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.thread;
18 
19 import android.annotation.Nullable;
20 import android.content.Context;
21 import android.net.thread.ActiveOperationalDataset;
22 import android.net.thread.IOperationReceiver;
23 import android.net.thread.OperationalDatasetTimestamp;
24 import android.net.thread.PendingOperationalDataset;
25 import android.net.thread.ThreadNetworkException;
26 import android.text.TextUtils;
27 
28 import com.android.internal.annotations.VisibleForTesting;
29 import com.android.modules.utils.BasicShellCommandHandler;
30 import com.android.net.module.util.HexDump;
31 
32 import java.io.PrintWriter;
33 import java.time.Duration;
34 import java.time.Instant;
35 import java.util.concurrent.CompletableFuture;
36 import java.util.concurrent.ExecutionException;
37 import java.util.concurrent.TimeUnit;
38 import java.util.concurrent.TimeoutException;
39 
40 /**
41  * Interprets and executes 'adb shell cmd thread_network <subcommand>'.
42  *
43  * <p>Subcommands which don't have an equivalent Java API now require the
44  * "android.permission.THREAD_NETWORK_TESTING" permission. For a specific subcommand, it also
45  * requires the same permissions of the equivalent Java / AIDL API.
46  *
47  * <p>To add new commands: - onCommand: Add a case "<command>" execute. Return a 0 if command
48  * executed successfully. - onHelp: add a description string.
49  */
50 public final class ThreadNetworkShellCommand extends BasicShellCommandHandler {
51     private static final Duration SET_ENABLED_TIMEOUT = Duration.ofSeconds(2);
52     private static final Duration LEAVE_TIMEOUT = Duration.ofSeconds(2);
53     private static final Duration MIGRATE_TIMEOUT = Duration.ofSeconds(2);
54     private static final Duration FORCE_STOP_TIMEOUT = Duration.ofSeconds(1);
55     private static final String PERMISSION_THREAD_NETWORK_TESTING =
56             "android.permission.THREAD_NETWORK_TESTING";
57 
58     private final Context mContext;
59     private final ThreadNetworkControllerService mControllerService;
60     private final ThreadNetworkCountryCode mCountryCode;
61 
62     @Nullable private PrintWriter mOutputWriter;
63     @Nullable private PrintWriter mErrorWriter;
64 
ThreadNetworkShellCommand( Context context, ThreadNetworkControllerService controllerService, ThreadNetworkCountryCode countryCode)65     public ThreadNetworkShellCommand(
66             Context context,
67             ThreadNetworkControllerService controllerService,
68             ThreadNetworkCountryCode countryCode) {
69         mContext = context;
70         mControllerService = controllerService;
71         mCountryCode = countryCode;
72     }
73 
74     @VisibleForTesting
setPrintWriters(PrintWriter outputWriter, PrintWriter errorWriter)75     public void setPrintWriters(PrintWriter outputWriter, PrintWriter errorWriter) {
76         mOutputWriter = outputWriter;
77         mErrorWriter = errorWriter;
78     }
79 
getOutputWriter()80     private PrintWriter getOutputWriter() {
81         return (mOutputWriter != null) ? mOutputWriter : getOutPrintWriter();
82     }
83 
getErrorWriter()84     private PrintWriter getErrorWriter() {
85         return (mErrorWriter != null) ? mErrorWriter : getErrPrintWriter();
86     }
87 
88     @Override
onHelp()89     public void onHelp() {
90         final PrintWriter pw = getOutputWriter();
91         pw.println("Thread network commands:");
92         pw.println("  help or -h");
93         pw.println("    Print this help text.");
94         pw.println("  enable");
95         pw.println("    Enables Thread radio");
96         pw.println("  disable");
97         pw.println("    Disables Thread radio");
98         pw.println("  join <active-dataset-tlvs>");
99         pw.println("    Joins a network of the given dataset");
100         pw.println("  migrate <active-dataset-tlvs> <delay-seconds>");
101         pw.println("    Migrate to the given network by a specific delay");
102         pw.println("  leave");
103         pw.println("    Leave the current network and erase datasets");
104         pw.println("  force-stop-ot-daemon enabled | disabled ");
105         pw.println("    force stop ot-daemon service");
106         pw.println("  get-country-code");
107         pw.println("    Gets country code as a two-letter string");
108         pw.println("  force-country-code enabled <two-letter code> | disabled ");
109         pw.println("    Sets country code to <two-letter code> or left for normal value");
110     }
111 
112     @Override
onCommand(String cmd)113     public int onCommand(String cmd) {
114         // Treat no command as the "help" command
115         if (TextUtils.isEmpty(cmd)) {
116             cmd = "help";
117         }
118 
119         switch (cmd) {
120             case "enable":
121                 return setThreadEnabled(true);
122             case "disable":
123                 return setThreadEnabled(false);
124             case "join":
125                 return join();
126             case "leave":
127                 return leave();
128             case "migrate":
129                 return migrate();
130             case "force-stop-ot-daemon":
131                 return forceStopOtDaemon();
132             case "force-country-code":
133                 return forceCountryCode();
134             case "get-country-code":
135                 return getCountryCode();
136             default:
137                 return handleDefaultCommands(cmd);
138         }
139     }
140 
ensureTestingPermission()141     private void ensureTestingPermission() {
142         mContext.enforceCallingOrSelfPermission(
143                 PERMISSION_THREAD_NETWORK_TESTING,
144                 "Permission " + PERMISSION_THREAD_NETWORK_TESTING + " is missing!");
145     }
146 
setThreadEnabled(boolean enabled)147     private int setThreadEnabled(boolean enabled) {
148         CompletableFuture<Void> setEnabledFuture = new CompletableFuture<>();
149         mControllerService.setEnabled(enabled, newOperationReceiver(setEnabledFuture));
150         return waitForFuture(setEnabledFuture, SET_ENABLED_TIMEOUT, getErrorWriter());
151     }
152 
join()153     private int join() {
154         byte[] datasetTlvs = HexDump.hexStringToByteArray(getNextArgRequired());
155         ActiveOperationalDataset dataset;
156         try {
157             dataset = ActiveOperationalDataset.fromThreadTlvs(datasetTlvs);
158         } catch (IllegalArgumentException e) {
159             getErrorWriter().println("Invalid dataset argument: " + e.getMessage());
160             return -1;
161         }
162         // Do not wait for join to complete because this can take 8 to 30 seconds
163         mControllerService.join(dataset, new IOperationReceiver.Default());
164         return 0;
165     }
166 
leave()167     private int leave() {
168         CompletableFuture<Void> leaveFuture = new CompletableFuture<>();
169         mControllerService.leave(newOperationReceiver(leaveFuture));
170         return waitForFuture(leaveFuture, LEAVE_TIMEOUT, getErrorWriter());
171     }
172 
migrate()173     private int migrate() {
174         byte[] datasetTlvs = HexDump.hexStringToByteArray(getNextArgRequired());
175         ActiveOperationalDataset dataset;
176         try {
177             dataset = ActiveOperationalDataset.fromThreadTlvs(datasetTlvs);
178         } catch (IllegalArgumentException e) {
179             getErrorWriter().println("Invalid dataset argument: " + e.getMessage());
180             return -1;
181         }
182 
183         int delaySeconds;
184         try {
185             delaySeconds = Integer.parseInt(getNextArgRequired());
186         } catch (NumberFormatException e) {
187             getErrorWriter().println("Invalid delay argument: " + e.getMessage());
188             return -1;
189         }
190 
191         PendingOperationalDataset pendingDataset =
192                 new PendingOperationalDataset(
193                         dataset,
194                         OperationalDatasetTimestamp.fromInstant(Instant.now()),
195                         Duration.ofSeconds(delaySeconds));
196         CompletableFuture<Void> migrateFuture = new CompletableFuture<>();
197         mControllerService.scheduleMigration(pendingDataset, newOperationReceiver(migrateFuture));
198         return waitForFuture(migrateFuture, MIGRATE_TIMEOUT, getErrorWriter());
199     }
200 
forceStopOtDaemon()201     private int forceStopOtDaemon() {
202         ensureTestingPermission();
203         final PrintWriter errorWriter = getErrorWriter();
204         boolean enabled;
205         try {
206             enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
207         } catch (IllegalArgumentException e) {
208             errorWriter.println("Invalid argument: " + e.getMessage());
209             return -1;
210         }
211 
212         CompletableFuture<Void> forceStopFuture = new CompletableFuture<>();
213         mControllerService.forceStopOtDaemonForTest(enabled, newOperationReceiver(forceStopFuture));
214         return waitForFuture(forceStopFuture, FORCE_STOP_TIMEOUT, getErrorWriter());
215     }
216 
forceCountryCode()217     private int forceCountryCode() {
218         ensureTestingPermission();
219         final PrintWriter perr = getErrorWriter();
220         boolean enabled;
221         try {
222             enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
223         } catch (IllegalArgumentException e) {
224             perr.println("Invalid argument: " + e.getMessage());
225             return -1;
226         }
227 
228         if (enabled) {
229             String countryCode = getNextArgRequired();
230             if (!ThreadNetworkCountryCode.isValidCountryCode(countryCode)) {
231                 perr.println(
232                         "Invalid argument: Country code must be a 2-letter"
233                                 + " string. But got country code "
234                                 + countryCode
235                                 + " instead");
236                 return -1;
237             }
238             mCountryCode.setOverrideCountryCode(countryCode);
239         } else {
240             mCountryCode.clearOverrideCountryCode();
241         }
242         return 0;
243     }
244 
getCountryCode()245     private int getCountryCode() {
246         ensureTestingPermission();
247         getOutputWriter().println("Thread country code = " + mCountryCode.getCountryCode());
248         return 0;
249     }
250 
newOperationReceiver(CompletableFuture<Void> future)251     private static IOperationReceiver newOperationReceiver(CompletableFuture<Void> future) {
252         return new IOperationReceiver.Stub() {
253             @Override
254             public void onSuccess() {
255                 future.complete(null);
256             }
257 
258             @Override
259             public void onError(int errorCode, String errorMessage) {
260                 future.completeExceptionally(new ThreadNetworkException(errorCode, errorMessage));
261             }
262         };
263     }
264 
265     /**
266      * Waits for the future to complete within given timeout.
267      *
268      * <p>Returns 0 if {@code future} completed successfully, or -1 if {@code future} failed to
269      * complete. When failed, error messages are printed to {@code errorWriter}.
270      */
271     private int waitForFuture(
272             CompletableFuture<Void> future, Duration timeout, PrintWriter errorWriter) {
273         try {
274             future.get(timeout.toSeconds(), TimeUnit.SECONDS);
275             return 0;
276         } catch (InterruptedException e) {
277             Thread.currentThread().interrupt();
278             errorWriter.println("Failed: " + e.getMessage());
279         } catch (ExecutionException e) {
280             errorWriter.println("Failed: " + e.getCause().getMessage());
281         } catch (TimeoutException e) {
282             errorWriter.println("Failed: command timeout for " + timeout);
283         }
284 
285         return -1;
286     }
287 
288     private static boolean argTrueOrFalse(String arg, String trueString, String falseString) {
289         if (trueString.equals(arg)) {
290             return true;
291         } else if (falseString.equals(arg)) {
292             return false;
293         } else {
294             throw new IllegalArgumentException(
295                     "Expected '"
296                             + trueString
297                             + "' or '"
298                             + falseString
299                             + "' as next arg but got '"
300                             + arg
301                             + "'");
302         }
303     }
304 
305     private boolean getNextArgRequiredTrueOrFalse(String trueString, String falseString) {
306         String nextArg = getNextArgRequired();
307         return argTrueOrFalse(nextArg, trueString, falseString);
308     }
309 }
310