/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import static java.util.concurrent.TimeUnit.MILLISECONDS;

import android.content.Context;
import android.os.Trace;
import android.util.Log;
import android.util.Pair;

import com.android.nn.benchmark.core.TestModels.TestModelEntry;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
/** Processor is a helper thread for running the work without blocking the UI thread. */
public class Processor implements Runnable {

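    /**
     * Receives progress and completion events for a benchmark run. Both methods are invoked
     * from the thread executing {@link #run()}, not from the UI thread.
     */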
    public interface Callback {
        void onBenchmarkFinish(boolean ok);

        void onStatusUpdate(int testNumber, int numTests, String modelName);
    }

    protected static final String TAG = "NN_BENCHMARK";
    private Context mContext;

    private final AtomicBoolean mRun = new AtomicBoolean(true);

    volatile boolean mHasBeenStarted = false;
    // You cannot restart a thread, so the completion flag is final.
    private final CountDownLatch mCompleted = new CountDownLatch(1);
    private NNTestBase mTest;
    private int[] mTestList;
    private BenchmarkResult[] mTestResults;

    private Processor.Callback mCallback;

    private TfLiteBackend mBackend;
    private boolean mMmapModel;
    private boolean mCompleteInputSet;
    private boolean mToggleLong;
    private boolean mTogglePause;
    private String mAcceleratorName;
    private boolean mIgnoreUnsupportedModels;
    private boolean mRunModelCompilationOnly;
    // Max number of benchmark iterations to do in the run method.
    // A value less than or equal to 0 means unlimited.
    private int mMaxRunIterations;

    private boolean mBenchmarkCompilationCaching;
    private float mCompilationBenchmarkWarmupTimeSeconds;
    private float mCompilationBenchmarkRunTimeSeconds;
    private int mCompilationBenchmarkMaxIterations;

    // Used to avoid accessing the Instrumentation Arguments when the crash tests are spawning
    // a separate process.
    private String mModelFilterRegex;

    private boolean mUseNnApiSupportLibrary;
    private boolean mExtractNnApiSupportLibrary;
    private String mNnApiSupportLibraryVendor;

    public Processor(Context context, Processor.Callback callback, int[] testList) {
        mContext = context;
        mCallback = callback;
        mTestList = testList;
        if (mTestList != null) {
            mTestResults = new BenchmarkResult[mTestList.length];
        }
        mAcceleratorName = null;
        mIgnoreUnsupportedModels = false;
        mRunModelCompilationOnly = false;
        mMaxRunIterations = 0;
        mBenchmarkCompilationCaching = false;
        mBackend = TfLiteBackend.CPU;
        mModelFilterRegex = null;
        mUseNnApiSupportLibrary = false;
        mExtractNnApiSupportLibrary = false;
        mNnApiSupportLibraryVendor = "";
    }
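
    // Illustrative usage from a host component (a sketch only, not code taken from this project;
    // the callback bodies, test indices, and timeout value below are assumptions):
    //
    //   Processor processor = new Processor(context, new Processor.Callback() {
    //       @Override
    //       public void onBenchmarkFinish(boolean ok) { /* report the overall result */ }
    //
    //       @Override
    //       public void onStatusUpdate(int testNumber, int numTests, String modelName) {
    //           /* update progress, e.g. "testNumber / numTests: modelName" */
    //       }
    //   }, new int[]{0, 1, 2});
    //   processor.setTfLiteBackend(TfLiteBackend.NNAPI);
    //   new Thread(processor).start();
    //   // ... later, stop the run and wait up to a minute for it to wind down:
    //   processor.exitWithTimeout(60 * 1000L);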

    public void setUseNNApi(boolean useNNApi) {
        setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
    }

    public void setTfLiteBackend(TfLiteBackend backend) {
        mBackend = backend;
    }

    public void setCompleteInputSet(boolean completeInputSet) {
        mCompleteInputSet = completeInputSet;
    }

    public void setToggleLong(boolean toggleLong) {
        mToggleLong = toggleLong;
    }

    public void setTogglePause(boolean togglePause) {
        mTogglePause = togglePause;
    }

    public void setNnApiAcceleratorName(String acceleratorName) {
        mAcceleratorName = acceleratorName;
    }

    public void setIgnoreUnsupportedModels(boolean value) {
        mIgnoreUnsupportedModels = value;
    }

    public void setRunModelCompilationOnly(boolean value) {
        mRunModelCompilationOnly = value;
    }

    public void setMmapModel(boolean value) {
        mMmapModel = value;
    }

    public void setMaxRunIterations(int value) {
        mMaxRunIterations = value;
    }

    public void setModelFilterRegex(String value) {
        this.mModelFilterRegex = value;
    }

    public void setUseNnApiSupportLibrary(boolean value) { mUseNnApiSupportLibrary = value; }
    public void setExtractNnApiSupportLibrary(boolean value) { mExtractNnApiSupportLibrary = value; }
    public void setNnApiSupportLibraryVendor(String value) { mNnApiSupportLibraryVendor = value; }

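    /**
     * Enables the compilation caching benchmark and sets its warm-up time, minimum run time,
     * and maximum iteration count. When enabled, the compilation results are attached to the
     * inference {@link BenchmarkResult} (see {@code runCompilationBenchmarkLoop}).
     */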
    public void enableCompilationCachingBenchmarks(
            float warmupTimeSeconds, float runTimeSeconds, int maxIterations) {
        mBenchmarkCompilationCaching = true;
        mCompilationBenchmarkWarmupTimeSeconds = warmupTimeSeconds;
        mCompilationBenchmarkRunTimeSeconds = runTimeSeconds;
        mCompilationBenchmarkMaxIterations = maxIterations;
    }

    public BenchmarkResult getInstrumentationResult(
            TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
            throws IOException, BenchmarkException {
        return getInstrumentationResult(t, warmupTimeSeconds, runTimeSeconds, false);
    }

    // Method to retrieve benchmark results for instrumentation tests.
    // Returns null if the processor is configured to run compilation only.
    public BenchmarkResult getInstrumentationResult(
            TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds,
            boolean sampleResults)
            throws IOException, BenchmarkException {
        mTest = changeTest(mTest, t);
        mTest.setSampleResult(sampleResults);
        try {
            BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark(
                    warmupTimeSeconds,
                    runTimeSeconds);
            return result;
        } finally {
            mTest.destroy();
            mTest = null;
        }
    }

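    /**
     * Checks whether the given model can be run through the NNAPI delegate on the named
     * accelerator by attempting to set the model up. Returns false if setup fails with an I/O
     * error or if NNAPI rejects the model with ANEURALNETWORKS_BAD_DATA (i.e. the accelerator
     * does not support all of the model's operations); any other delegation failure is rethrown.
     */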
    public static boolean isTestModelSupportedByAccelerator(Context context,
            TestModels.TestModelEntry testModelEntry, String acceleratorName)
            throws NnApiDelegationFailure {
        try (NNTestBase tb = testModelEntry.createNNTestBase(TfLiteBackend.NNAPI,
                /*enableIntermediateTensorsDump=*/false,
                /*mmapModel=*/ false,
                NNTestBase.shouldUseNnApiSupportLibrary(),
                NNTestBase.shouldExtractNnApiSupportLibrary(),
                NNTestBase.getNnApiSupportLibraryVendor()
            )) {
            tb.setNNApiDeviceName(acceleratorName);
            return tb.setupModel(context);
        } catch (IOException e) {
            Log.w(TAG,
                    String.format("Error trying to check support for model %s on accelerator %s",
                            testModelEntry.mModelName, acceleratorName), e);
            return false;
        } catch (NnApiDelegationFailure nnApiDelegationFailure) {
            if (nnApiDelegationFailure.getNnApiErrno() == 4 /*ANEURALNETWORKS_BAD_DATA*/) {
                // Compilation will fail with ANEURALNETWORKS_BAD_DATA if the device does not
                // support all of the operations in the model.
                return false;
            }

            throw nnApiDelegationFailure;
        }
    }

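    /**
     * Tears down the previous test (if any), then creates and sets up an {@link NNTestBase} for
     * the given model entry using the currently configured backend, accelerator, and model
     * options. Throws {@link UnsupportedModelException} if the model cannot be initialised.
     */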
    private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)
            throws IOException, UnsupportedModelException, NnApiDelegationFailure {
        if (oldTestBase != null) {
            // Make sure we don't leak memory.
            oldTestBase.destroy();
        }
        NNTestBase tb = t.createNNTestBase(mBackend, /*enableIntermediateTensorsDump=*/false,
            mMmapModel, mUseNnApiSupportLibrary, mExtractNnApiSupportLibrary, mNnApiSupportLibraryVendor);
        if (mBackend == TfLiteBackend.NNAPI) {
            tb.setNNApiDeviceName(mAcceleratorName);
        }
        if (!tb.setupModel(mContext)) {
            throw new UnsupportedModelException("Cannot initialise model");
        }
        return tb;
    }

    // Run one loop of kernels for up to the specified maximum time; if maxTime is not positive,
    // a single inference is run instead.
    // The function returns a BenchmarkResult with the average time in ms for the test run.
    private BenchmarkResult runBenchmarkLoop(float maxTime, boolean completeInputSet)
            throws IOException {
        try {
            // Run the kernel
            Pair<List<InferenceInOutSequence>, List<InferenceResult>> results;
            if (maxTime > 0.f) {
                if (completeInputSet) {
                    results = mTest.runBenchmarkCompleteInputSet(1, maxTime);
                } else {
                    results = mTest.runBenchmark(maxTime);
                }
            } else {
                results = mTest.runInferenceOnce();
            }
            return BenchmarkResult.fromInferenceResults(
                    mTest.getTestInfo(),
                    mBackend.toString(),
                    results.first,
                    results.second,
                    mTest.getEvaluator());
        } catch (BenchmarkException e) {
            return new BenchmarkResult(e.getMessage());
        }
    }

    // Run one loop of compilations for at least the specified minimum time.
    // The function will set the compilation results into the provided benchmark result object.
    private void runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime,
            int maxIterations, BenchmarkResult benchmarkResult) throws IOException {
        try {
            CompilationBenchmarkResult result =
                    mTest.runCompilationBenchmark(warmupMinTime, runMinTime, maxIterations);
            benchmarkResult.setCompilationBenchmarkResult(result);
        } catch (BenchmarkException e) {
            benchmarkResult.setBenchmarkError(e.getMessage());
        }
    }

    public BenchmarkResult[] getTestResults() {
        return mTestResults;
    }

    // Get a benchmark result for a specific test.
    private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds)
            throws IOException {
        try {
            mTest.checkSdkVersion();
        } catch (UnsupportedSdkException e) {
            BenchmarkResult r = new BenchmarkResult(e.getMessage());
            Log.w(TAG, "Unsupported SDK for test: " + r.toString());
            return r;
        }

        // We run a short bit of work before starting the actual test; this is to let any power
        // management do its job and respond.
        // For NNAPI systrace usage documentation, see
        // frameworks/ml/nn/common/include/Tracing.h.
        try {
            final String traceName = "[NN_LA_PWU]runBenchmarkLoop";
            Trace.beginSection(traceName);
            runBenchmarkLoop(warmupTimeSeconds, false);
        } finally {
            Trace.endSection();
        }

        // Run the actual benchmark
        BenchmarkResult r;
        try {
            final String traceName = "[NN_LA_PBM]runBenchmarkLoop";
            Trace.beginSection(traceName);
            r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet);
        } finally {
            Trace.endSection();
        }

        // Compilation benchmark
        if (mBenchmarkCompilationCaching) {
            runCompilationBenchmarkLoop(mCompilationBenchmarkWarmupTimeSeconds,
                    mCompilationBenchmarkRunTimeSeconds, mCompilationBenchmarkMaxIterations, r);
        }

        return r;
    }

    @Override
    public void run() {
        mHasBeenStarted = true;
        Log.d(TAG, "Processor starting");
        boolean success = true;
        int benchmarkIterationsCount = 0;
        try {
            while (mRun.get()) {
                if (mMaxRunIterations > 0 && benchmarkIterationsCount >= mMaxRunIterations) {
                    break;
                }
                benchmarkIterationsCount++;
                try {
                    benchmarkAllModels();
                } catch (IOException | BenchmarkException e) {
                    Log.e(TAG, "Exception during benchmark run", e);
                    success = false;
                    break;
                } catch (Throwable e) {
                    Log.e(TAG, "Error during execution", e);
                    throw e;
                }
            }
            Log.d(TAG, "Processor completed work");
            mCallback.onBenchmarkFinish(success);
        } finally {
            if (mTest != null) {
                // Make sure we don't leak memory.
                mTest.destroy();
                mTest = null;
            }
            mCompleted.countDown();
        }
    }

    private void benchmarkAllModels() throws IOException, BenchmarkException {
        final List<TestModelEntry> modelsList = TestModels.modelsList(mModelFilterRegex);
        // Loop over the tests we want to benchmark.
        for (int ct = 0; ct < mTestList.length; ct++) {
            if (!mRun.get()) {
                Log.v(TAG, String.format("Asked to stop execution at model #%d", ct));
                break;
            }
            // For reproducibility we wait a short time so that any sporadic work triggered by
            // the user touching the screen to launch the test has passed, and so that things
            // can settle between test changes.
            try {
                Thread.sleep(250);
            } catch (InterruptedException ignored) {
                Thread.currentThread().interrupt();
                break;
            }

            TestModels.TestModelEntry testModel =
                    modelsList.get(mTestList[ct]);

            int testNumber = ct + 1;
            mCallback.onStatusUpdate(testNumber, mTestList.length,
                    testModel.toString());

            // Select the next test.
            try {
                mTest = changeTest(mTest, testModel);
            } catch (UnsupportedModelException e) {
                if (mIgnoreUnsupportedModels) {
                    Log.d(TAG, String.format(
                            "Cannot initialise test %d: '%s' on accelerator %s, skipping", ct,
                            testModel.mTestName, mAcceleratorName));
                    // Move on to the next model rather than benchmarking the stale test.
                    continue;
                } else {
                    Log.e(TAG,
                            String.format("Cannot initialise test %d: '%s' on accelerator %s.", ct,
                                    testModel.mTestName, mAcceleratorName), e);
                    throw e;
                }
            }

            // If the user selected the "long pause" option, wait.
            if (mTogglePause) {
                for (int i = 0; (i < 100) && mRun.get(); i++) {
                    try {
                        Thread.sleep(100);
                    } catch (InterruptedException ignored) {
                        Thread.currentThread().interrupt();
                        break;
                    }
                }
            }

            if (mRunModelCompilationOnly) {
                mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName,
                        mBackend.toString(),
                        Collections.emptyList(),
                        Collections.emptyList(), null);
            } else {
                // Run the test
                float warmupTime = 0.3f;
                float runTime = 1.f;
                if (mToggleLong) {
                    warmupTime = 2.f;
                    runTime = 10.f;
                }
                mTestResults[ct] = getBenchmark(warmupTime, runTime);
            }
        }
    }

    public void exit() {
        exitWithTimeout(-1L);
    }

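    /**
     * Signals the run loop to stop and, if the processor thread was started, waits for it to
     * finish: up to {@code timeoutMs} milliseconds when the timeout is positive, or indefinitely
     * otherwise. Any test still held by the processor is destroyed afterwards.
     */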
    public void exitWithTimeout(long timeoutMs) {
        mRun.set(false);

        if (mHasBeenStarted) {
            Log.d(TAG, String.format("Terminating, timeout is %d ms", timeoutMs));
            try {
                if (timeoutMs > 0) {
                    boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS);
                    if (!hasCompleted) {
                        Log.w(TAG, "Exiting before execution actually completed");
                    }
                } else {
                    mCompleted.await();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                Log.w(TAG, "Interrupted while waiting for Processor to complete", e);
            }
        }

        Log.d(TAG, "Done, cleaning up");

        if (mTest != null) {
            mTest.destroy();
            mTest = null;
        }
    }
}