1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.crashtest.core.test;
18 
19 import static java.util.concurrent.TimeUnit.MILLISECONDS;
20 
21 import android.annotation.SuppressLint;
22 import android.content.Context;
23 import android.content.Intent;
24 import android.util.Log;
25 
26 import com.android.nn.benchmark.core.Processor;
27 import com.android.nn.crashtest.core.CrashTest;
28 import com.android.nn.crashtest.core.CrashTestCoordinator.CrashTestIntentInitializer;
29 import com.android.nn.benchmark.core.TfLiteBackend;
30 
31 import java.time.Duration;
32 import java.util.ArrayList;
33 import java.util.Collections;
34 import java.util.HashSet;
35 import java.util.List;
36 import java.util.Optional;
37 import java.util.Set;
38 import java.util.concurrent.CountDownLatch;
39 import java.util.concurrent.ExecutionException;
40 import java.util.concurrent.ExecutorService;
41 import java.util.concurrent.Executors;
42 import java.util.concurrent.Future;
43 
44 public class RunModelsInParallel implements CrashTest {
45 
46     private static final String MODELS = "models";
47     private static final String DURATION = "duration";
48     private static final String THREADS = "thread_counts";
49     private static final String TEST_NAME = "test_name";
50     private static final String ACCELERATOR_NAME = "accelerator_name";
51     private static final String IGNORE_UNSUPPORTED_MODELS = "ignore_unsupported_models";
52     private static final String RUN_MODEL_COMPILATION_ONLY = "run_model_compilation_only";
53     private static final String MEMORY_MAP_MODEL = "memory_map_model";
54     private static final String MODEL_FILTER = "model_filter";
55     private static final String USE_NNAPI_SL = "use_nnapi_sl";
56     private static final String EXTRACT_NNAPI_SL = "extract_nnapi_sl";
57 
58     private final Set<Processor> activeTests = new HashSet<>();
59     private final List<Boolean> mTestCompletionResults = Collections.synchronizedList(
60             new ArrayList<>());
61     private long mTestDurationMillis = 0;
62     private int mThreadCount = 0;
63     private int[] mTestList = new int[0];
64     private String mTestName;
65     private String mAcceleratorName;
66     private boolean mIgnoreUnsupportedModels;
67     private Context mContext;
68     private boolean mRunModelCompilationOnly;
69     private ExecutorService mExecutorService = null;
70     private CountDownLatch mParallelTestComplete;
71     private ProgressListener mProgressListener;
72     private boolean mMmapModel;
73     private boolean mUseNnapiSl;
74     private boolean mExtractNnapiSl;
75 
intentInitializer(int[] models, int threadCount, Duration duration, String testName, String acceleratorName, boolean ignoreUnsupportedModels, boolean runModelCompilationOnly, boolean mmapModel, String modelFilter, boolean useNnapiSl, boolean extractNnapiSl)76     static public CrashTestIntentInitializer intentInitializer(int[] models, int threadCount,
77         Duration duration, String testName, String acceleratorName,
78         boolean ignoreUnsupportedModels,
79         boolean runModelCompilationOnly, boolean mmapModel, String modelFilter, boolean useNnapiSl,
80         boolean extractNnapiSl) {
81         return intent -> {
82             intent.putExtra(MODELS, models);
83             intent.putExtra(DURATION, duration.toMillis());
84             intent.putExtra(THREADS, threadCount);
85             intent.putExtra(TEST_NAME, testName);
86             intent.putExtra(ACCELERATOR_NAME, acceleratorName);
87             intent.putExtra(IGNORE_UNSUPPORTED_MODELS, ignoreUnsupportedModels);
88             intent.putExtra(RUN_MODEL_COMPILATION_ONLY, runModelCompilationOnly);
89             intent.putExtra(MEMORY_MAP_MODEL, mmapModel);
90             intent.putExtra(MODEL_FILTER, modelFilter);
91             intent.putExtra(USE_NNAPI_SL, useNnapiSl);
92             intent.putExtra(EXTRACT_NNAPI_SL, extractNnapiSl);
93         };
94     }
95 
96     @Override
init(Context context, Intent configParams, Optional<ProgressListener> progressListener)97     public void init(Context context, Intent configParams,
98             Optional<ProgressListener> progressListener) {
99         mTestList = configParams.getIntArrayExtra(MODELS);
100         mThreadCount = configParams.getIntExtra(THREADS, 10);
101         mTestDurationMillis = configParams.getLongExtra(DURATION, 1000 * 60 * 10);
102         mTestName = configParams.getStringExtra(TEST_NAME);
103         mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME);
104         mIgnoreUnsupportedModels = mAcceleratorName != null && configParams.getBooleanExtra(
105                 IGNORE_UNSUPPORTED_MODELS, false);
106         mRunModelCompilationOnly = configParams.getBooleanExtra(RUN_MODEL_COMPILATION_ONLY, false);
107         mMmapModel = configParams.getBooleanExtra(MEMORY_MAP_MODEL, false);
108         mUseNnapiSl = configParams.getBooleanExtra(USE_NNAPI_SL, false);
109         mExtractNnapiSl = configParams.getBooleanExtra(EXTRACT_NNAPI_SL, false);
110         mContext = context;
111         mProgressListener = progressListener.orElseGet(() -> (Optional<String> message) -> {
112             Log.v(CrashTest.TAG, message.orElse("."));
113         });
114         mExecutorService = Executors.newFixedThreadPool(mThreadCount);
115         mTestCompletionResults.clear();
116     }
117 
118     @Override
call()119     public Optional<String> call() {
120         mParallelTestComplete = new CountDownLatch(mThreadCount);
121         for (int i = 0; i < mThreadCount; i++) {
122             Processor testProcessor = createSubTestRunner(mTestList, i);
123 
124             activeTests.add(testProcessor);
125             mExecutorService.submit(testProcessor);
126         }
127 
128         return completedSuccessfully();
129     }
130 
createSubTestRunner(final int[] testList, final int testIndex)131     private Processor createSubTestRunner(final int[] testList, final int testIndex) {
132         final Processor result = new Processor(mContext, new Processor.Callback() {
133             @SuppressLint("DefaultLocale")
134             @Override
135             public void onBenchmarkFinish(boolean ok) {
136                 notifyProgress("Test '%s': Benchmark #%d completed %s", mTestName, testIndex,
137                         ok ? "successfully" : "with failure");
138                 mTestCompletionResults.add(ok);
139                 mParallelTestComplete.countDown();
140             }
141 
142             @Override
143             public void onStatusUpdate(int testNumber, int numTests, String modelName) {
144             }
145         }, testList);
146         result.setTfLiteBackend(TfLiteBackend.NNAPI);
147         result.setCompleteInputSet(false);
148         result.setNnApiAcceleratorName(mAcceleratorName);
149         result.setIgnoreUnsupportedModels(mIgnoreUnsupportedModels);
150         result.setRunModelCompilationOnly(mRunModelCompilationOnly);
151         result.setMmapModel(mMmapModel);
152         result.setUseNnApiSupportLibrary(mUseNnapiSl);
153         result.setExtractNnApiSupportLibrary(mExtractNnapiSl);
154         return result;
155     }
156 
endTests()157     private void endTests() {
158         ExecutorService terminatorsThreadPool = Executors.newFixedThreadPool(activeTests.size());
159         List<Future<?>> terminationCommands = new ArrayList<>();
160         for (final Processor test : activeTests) {
161             // Exit will block until the thread is completed
162             terminationCommands.add(terminatorsThreadPool.submit(
163                     () -> test.exitWithTimeout(Duration.ofSeconds(20).toMillis())));
164         }
165         terminationCommands.forEach(terminationCommand -> {
166             try {
167                 terminationCommand.get();
168             } catch (ExecutionException e) {
169                 Log.w(TAG, "Failure while waiting for completion of tests", e);
170             } catch (InterruptedException e) {
171                 Thread.interrupted();
172             }
173         });
174     }
175 
176     @SuppressLint("DefaultLocale")
notifyProgress(String messageFormat, Object... args)177     void notifyProgress(String messageFormat, Object... args) {
178         mProgressListener.testProgress(Optional.of(String.format(messageFormat, args)));
179     }
180 
181     // This method blocks until the tests complete and returns true if all tests completed
182     // successfully
183     @SuppressLint("DefaultLocale")
completedSuccessfully()184     private Optional<String> completedSuccessfully() {
185         try {
186             boolean testsEnded = mParallelTestComplete.await(mTestDurationMillis, MILLISECONDS);
187             if (!testsEnded) {
188                 Log.i(TAG,
189                         String.format(
190                                 "Test '%s': Tests are not completed (they might have been "
191                                         + "designed to run "
192                                         + "indefinitely. Forcing termination.", mTestName));
193                 endTests();
194             }
195         } catch (InterruptedException ignored) {
196             Thread.currentThread().interrupt();
197         }
198 
199         final long failedTestCount = mTestCompletionResults.stream().filter(
200                 testResult -> !testResult).count();
201         if (failedTestCount > 0) {
202             String failureMsg = String.format("Test '%s': %d out of %d test failed", mTestName,
203                     failedTestCount,
204                     mTestCompletionResults.size());
205             Log.w(CrashTest.TAG, failureMsg);
206             return failure(failureMsg);
207         } else {
208             Log.i(CrashTest.TAG,
209                     String.format("Test '%s': Test completed successfully", mTestName));
210             return success();
211         }
212     }
213 }
214