1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import static java.util.concurrent.TimeUnit.MILLISECONDS; 20 21 import android.content.Context; 22 import android.os.Trace; 23 import android.util.Log; 24 import android.util.Pair; 25 26 import com.android.nn.benchmark.core.TestModels.TestModelEntry; 27 import java.io.IOException; 28 import java.util.Collections; 29 import java.util.List; 30 import java.util.concurrent.CountDownLatch; 31 import java.util.concurrent.atomic.AtomicBoolean; 32 33 /** Processor is a helper thread for running the work without blocking the UI thread. */ 34 public class Processor implements Runnable { 35 36 37 public interface Callback { onBenchmarkFinish(boolean ok)38 void onBenchmarkFinish(boolean ok); 39 onStatusUpdate(int testNumber, int numTests, String modelName)40 void onStatusUpdate(int testNumber, int numTests, String modelName); 41 } 42 43 protected static final String TAG = "NN_BENCHMARK"; 44 private Context mContext; 45 46 private final AtomicBoolean mRun = new AtomicBoolean(true); 47 48 volatile boolean mHasBeenStarted = false; 49 // You cannot restart a thread, so the completion flag is final 50 private final CountDownLatch mCompleted = new CountDownLatch(1); 51 private NNTestBase mTest; 52 private int mTestList[]; 53 private BenchmarkResult mTestResults[]; 54 55 private Processor.Callback mCallback; 56 57 private TfLiteBackend mBackend; 58 private boolean mMmapModel; 59 private boolean mCompleteInputSet; 60 private boolean mToggleLong; 61 private boolean mTogglePause; 62 private String mAcceleratorName; 63 private boolean mIgnoreUnsupportedModels; 64 private boolean mRunModelCompilationOnly; 65 // Max number of benchmark iterations to do in run method. 66 // Less or equal to 0 means unlimited 67 private int mMaxRunIterations; 68 69 private boolean mBenchmarkCompilationCaching; 70 private float mCompilationBenchmarkWarmupTimeSeconds; 71 private float mCompilationBenchmarkRunTimeSeconds; 72 private int mCompilationBenchmarkMaxIterations; 73 74 // Used to avoid accessing the Instrumentation Arguments when the crash tests are spawning 75 // a separate process. 76 private String mModelFilterRegex; 77 78 private boolean mUseNnApiSupportLibrary; 79 private boolean mExtractNnApiSupportLibrary; 80 private String mNnApiSupportLibraryVendor; 81 Processor(Context context, Processor.Callback callback, int[] testList)82 public Processor(Context context, Processor.Callback callback, int[] testList) { 83 mContext = context; 84 mCallback = callback; 85 mTestList = testList; 86 if (mTestList != null) { 87 mTestResults = new BenchmarkResult[mTestList.length]; 88 } 89 mAcceleratorName = null; 90 mIgnoreUnsupportedModels = false; 91 mRunModelCompilationOnly = false; 92 mMaxRunIterations = 0; 93 mBenchmarkCompilationCaching = false; 94 mBackend = TfLiteBackend.CPU; 95 mModelFilterRegex = null; 96 mUseNnApiSupportLibrary = false; 97 mExtractNnApiSupportLibrary = false; 98 mNnApiSupportLibraryVendor = ""; 99 } 100 setUseNNApi(boolean useNNApi)101 public void setUseNNApi(boolean useNNApi) { 102 setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU); 103 } 104 setTfLiteBackend(TfLiteBackend backend)105 public void setTfLiteBackend(TfLiteBackend backend) { 106 mBackend = backend; 107 } 108 setCompleteInputSet(boolean completeInputSet)109 public void setCompleteInputSet(boolean completeInputSet) { 110 mCompleteInputSet = completeInputSet; 111 } 112 setToggleLong(boolean toggleLong)113 public void setToggleLong(boolean toggleLong) { 114 mToggleLong = toggleLong; 115 } 116 setTogglePause(boolean togglePause)117 public void setTogglePause(boolean togglePause) { 118 mTogglePause = togglePause; 119 } 120 setNnApiAcceleratorName(String acceleratorName)121 public void setNnApiAcceleratorName(String acceleratorName) { 122 mAcceleratorName = acceleratorName; 123 } 124 setIgnoreUnsupportedModels(boolean value)125 public void setIgnoreUnsupportedModels(boolean value) { 126 mIgnoreUnsupportedModels = value; 127 } 128 setRunModelCompilationOnly(boolean value)129 public void setRunModelCompilationOnly(boolean value) { 130 mRunModelCompilationOnly = value; 131 } 132 setMmapModel(boolean value)133 public void setMmapModel(boolean value) { 134 mMmapModel = value; 135 } 136 setMaxRunIterations(int value)137 public void setMaxRunIterations(int value) { 138 mMaxRunIterations = value; 139 } 140 setModelFilterRegex(String value)141 public void setModelFilterRegex(String value) { 142 this.mModelFilterRegex = value; 143 } 144 setUseNnApiSupportLibrary(boolean value)145 public void setUseNnApiSupportLibrary(boolean value) { mUseNnApiSupportLibrary = value; } setExtractNnApiSupportLibrary(boolean value)146 public void setExtractNnApiSupportLibrary(boolean value) { mExtractNnApiSupportLibrary = value; } setNnApiSupportLibraryVendor(String value)147 public void setNnApiSupportLibraryVendor(String value) { mNnApiSupportLibraryVendor = value; } 148 enableCompilationCachingBenchmarks( float warmupTimeSeconds, float runTimeSeconds, int maxIterations)149 public void enableCompilationCachingBenchmarks( 150 float warmupTimeSeconds, float runTimeSeconds, int maxIterations) { 151 mBenchmarkCompilationCaching = true; 152 mCompilationBenchmarkWarmupTimeSeconds = warmupTimeSeconds; 153 mCompilationBenchmarkRunTimeSeconds = runTimeSeconds; 154 mCompilationBenchmarkMaxIterations = maxIterations; 155 } 156 getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)157 public BenchmarkResult getInstrumentationResult( 158 TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds) 159 throws IOException, BenchmarkException { 160 return getInstrumentationResult(t, warmupTimeSeconds, runTimeSeconds, false); 161 } 162 163 // Method to retrieve benchmark results for instrumentation tests. 164 // Returns null if the processor is configured to run compilation only getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults)165 public BenchmarkResult getInstrumentationResult( 166 TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds, 167 boolean sampleResults) 168 throws IOException, BenchmarkException { 169 mTest = changeTest(mTest, t); 170 mTest.setSampleResult(sampleResults); 171 try { 172 BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark( 173 warmupTimeSeconds, 174 runTimeSeconds); 175 return result; 176 } finally { 177 mTest.destroy(); 178 mTest = null; 179 } 180 } 181 isTestModelSupportedByAccelerator(Context context, TestModels.TestModelEntry testModelEntry, String acceleratorName)182 public static boolean isTestModelSupportedByAccelerator(Context context, 183 TestModels.TestModelEntry testModelEntry, String acceleratorName) 184 throws NnApiDelegationFailure { 185 try (NNTestBase tb = testModelEntry.createNNTestBase(TfLiteBackend.NNAPI, 186 /*enableIntermediateTensorsDump=*/false, 187 /*mmapModel=*/ false, 188 NNTestBase.shouldUseNnApiSupportLibrary(), 189 NNTestBase.shouldExtractNnApiSupportLibrary(), 190 NNTestBase.getNnApiSupportLibraryVendor() 191 )) { 192 tb.setNNApiDeviceName(acceleratorName); 193 return tb.setupModel(context); 194 } catch (IOException e) { 195 Log.w(TAG, 196 String.format("Error trying to check support for model %s on accelerator %s", 197 testModelEntry.mModelName, acceleratorName), e); 198 return false; 199 } catch (NnApiDelegationFailure nnApiDelegationFailure) { 200 if (nnApiDelegationFailure.getNnApiErrno() == 4 /*ANEURALNETWORKS_BAD_DATA*/) { 201 // Compilation will fail with ANEURALNETWORKS_BAD_DATA if the device is not 202 // supporting all operation in the model 203 return false; 204 } 205 206 throw nnApiDelegationFailure; 207 } 208 } 209 changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)210 private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t) 211 throws IOException, UnsupportedModelException, NnApiDelegationFailure { 212 if (oldTestBase != null) { 213 // Make sure we don't leak memory. 214 oldTestBase.destroy(); 215 } 216 NNTestBase tb = t.createNNTestBase(mBackend, /*enableIntermediateTensorsDump=*/false, 217 mMmapModel, mUseNnApiSupportLibrary, mExtractNnApiSupportLibrary, mNnApiSupportLibraryVendor); 218 if (mBackend == TfLiteBackend.NNAPI) { 219 tb.setNNApiDeviceName(mAcceleratorName); 220 } 221 if (!tb.setupModel(mContext)) { 222 throw new UnsupportedModelException("Cannot initialise model"); 223 } 224 return tb; 225 } 226 227 // Run one loop of kernels for at most the specified minimum time. 228 // The function returns the average time in ms for the test run runBenchmarkLoop(float maxTime, boolean completeInputSet)229 private BenchmarkResult runBenchmarkLoop(float maxTime, boolean completeInputSet) 230 throws IOException { 231 try { 232 // Run the kernel 233 Pair<List<InferenceInOutSequence>, List<InferenceResult>> results; 234 if (maxTime > 0.f) { 235 if (completeInputSet) { 236 results = mTest.runBenchmarkCompleteInputSet(1, maxTime); 237 } else { 238 results = mTest.runBenchmark(maxTime); 239 } 240 } else { 241 results = mTest.runInferenceOnce(); 242 } 243 return BenchmarkResult.fromInferenceResults( 244 mTest.getTestInfo(), 245 mBackend.toString(), 246 results.first, 247 results.second, 248 mTest.getEvaluator()); 249 } catch (BenchmarkException e) { 250 return new BenchmarkResult(e.getMessage()); 251 } 252 } 253 254 // Run one loop of compilations for at least the specified minimum time. 255 // The function will set the compilation results into the provided benchmark result object. runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime, int maxIterations, BenchmarkResult benchmarkResult)256 private void runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime, 257 int maxIterations, BenchmarkResult benchmarkResult) throws IOException { 258 try { 259 CompilationBenchmarkResult result = 260 mTest.runCompilationBenchmark(warmupMinTime, runMinTime, maxIterations); 261 benchmarkResult.setCompilationBenchmarkResult(result); 262 } catch (BenchmarkException e) { 263 benchmarkResult.setBenchmarkError(e.getMessage()); 264 } 265 } 266 getTestResults()267 public BenchmarkResult[] getTestResults() { 268 return mTestResults; 269 } 270 271 // Get a benchmark result for a specific test getBenchmark(float warmupTimeSeconds, float runTimeSeconds)272 private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds) 273 throws IOException { 274 try { 275 mTest.checkSdkVersion(); 276 } catch (UnsupportedSdkException e) { 277 BenchmarkResult r = new BenchmarkResult(e.getMessage()); 278 Log.w(TAG, "Unsupported SDK for test: " + r.toString()); 279 return r; 280 } 281 282 // We run a short bit of work before starting the actual test 283 // this is to let any power management do its job and respond. 284 // For NNAPI systrace usage documentation, see 285 // frameworks/ml/nn/common/include/Tracing.h. 286 try { 287 final String traceName = "[NN_LA_PWU]runBenchmarkLoop"; 288 Trace.beginSection(traceName); 289 runBenchmarkLoop(warmupTimeSeconds, false); 290 } finally { 291 Trace.endSection(); 292 } 293 294 // Run the actual benchmark 295 BenchmarkResult r; 296 try { 297 final String traceName = "[NN_LA_PBM]runBenchmarkLoop"; 298 Trace.beginSection(traceName); 299 r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet); 300 } finally { 301 Trace.endSection(); 302 } 303 304 // Compilation benchmark 305 if (mBenchmarkCompilationCaching) { 306 runCompilationBenchmarkLoop(mCompilationBenchmarkWarmupTimeSeconds, 307 mCompilationBenchmarkRunTimeSeconds, mCompilationBenchmarkMaxIterations, r); 308 } 309 310 return r; 311 } 312 313 @Override run()314 public void run() { 315 mHasBeenStarted = true; 316 Log.d(TAG, "Processor starting"); 317 boolean success = true; 318 int benchmarkIterationsCount = 0; 319 try { 320 while (mRun.get()) { 321 if (mMaxRunIterations > 0 && benchmarkIterationsCount >= mMaxRunIterations) { 322 break; 323 } 324 benchmarkIterationsCount++; 325 try { 326 benchmarkAllModels(); 327 } catch (IOException | BenchmarkException e) { 328 Log.e(TAG, "Exception during benchmark run", e); 329 success = false; 330 break; 331 } catch (Throwable e) { 332 Log.e(TAG, "Error during execution", e); 333 throw e; 334 } 335 } 336 Log.d(TAG, "Processor completed work"); 337 mCallback.onBenchmarkFinish(success); 338 } finally { 339 if (mTest != null) { 340 // Make sure we don't leak memory. 341 mTest.destroy(); 342 mTest = null; 343 } 344 mCompleted.countDown(); 345 } 346 } 347 benchmarkAllModels()348 private void benchmarkAllModels() throws IOException, BenchmarkException { 349 final List<TestModelEntry> modelsList = TestModels.modelsList(mModelFilterRegex); 350 // Loop over the tests we want to benchmark 351 for (int ct = 0; ct < mTestList.length; ct++) { 352 if (!mRun.get()) { 353 Log.v(TAG, String.format("Asked to stop execution at model #%d", ct)); 354 break; 355 } 356 // For reproducibility we wait a short time for any sporadic work 357 // created by the user touching the screen to launch the test to pass. 358 // Also allows for things to settle after the test changes. 359 try { 360 Thread.sleep(250); 361 } catch (InterruptedException ignored) { 362 Thread.currentThread().interrupt(); 363 break; 364 } 365 366 TestModels.TestModelEntry testModel = 367 modelsList.get(mTestList[ct]); 368 369 int testNumber = ct + 1; 370 mCallback.onStatusUpdate(testNumber, mTestList.length, 371 testModel.toString()); 372 373 // Select the next test 374 try { 375 mTest = changeTest(mTest, testModel); 376 } catch (UnsupportedModelException e) { 377 if (mIgnoreUnsupportedModels) { 378 Log.d(TAG, String.format( 379 "Cannot initialise test %d: '%s' on accelerator %s, skipping", ct, 380 testModel.mTestName, mAcceleratorName)); 381 } else { 382 Log.e(TAG, 383 String.format("Cannot initialise test %d: '%s' on accelerator %s.", ct, 384 testModel.mTestName, mAcceleratorName), e); 385 throw e; 386 } 387 } 388 389 // If the user selected the "long pause" option, wait 390 if (mTogglePause) { 391 for (int i = 0; (i < 100) && mRun.get(); i++) { 392 try { 393 Thread.sleep(100); 394 } catch (InterruptedException ignored) { 395 Thread.currentThread().interrupt(); 396 break; 397 } 398 } 399 } 400 401 if (mRunModelCompilationOnly) { 402 mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName, 403 mBackend.toString(), 404 Collections.emptyList(), 405 Collections.emptyList(), null); 406 } else { 407 // Run the test 408 float warmupTime = 0.3f; 409 float runTime = 1.f; 410 if (mToggleLong) { 411 warmupTime = 2.f; 412 runTime = 10.f; 413 } 414 mTestResults[ct] = getBenchmark(warmupTime, runTime); 415 } 416 } 417 } 418 exit()419 public void exit() { 420 exitWithTimeout(-1l); 421 } 422 exitWithTimeout(long timeoutMs)423 public void exitWithTimeout(long timeoutMs) { 424 mRun.set(false); 425 426 if (mHasBeenStarted) { 427 Log.d(TAG, String.format("Terminating, timeout is %d ms", timeoutMs)); 428 try { 429 if (timeoutMs > 0) { 430 boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS); 431 if (!hasCompleted) { 432 Log.w(TAG, "Exiting before execution actually completed"); 433 } 434 } else { 435 mCompleted.await(); 436 } 437 } catch (InterruptedException e) { 438 Thread.currentThread().interrupt(); 439 Log.w(TAG, "Interrupted while waiting for Processor to complete", e); 440 } 441 } 442 443 Log.d(TAG, "Done, cleaning up"); 444 445 if (mTest != null) { 446 mTest.destroy(); 447 mTest = null; 448 } 449 } 450 } 451