/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.os.Build;
import android.os.ParcelFileDescriptor;
import android.os.ParcelFileDescriptor.AutoCloseInputStream;
import android.system.ErrnoException;
import android.system.Os;
import android.util.Log;
import android.util.Pair;
import android.widget.TextView;

import androidx.test.InstrumentationRegistry;

import com.android.nn.benchmark.core.sl.ArmSupportLibraryDriverHandler;
import com.android.nn.benchmark.core.sl.MediaTekSupportLibraryDriverHandler;
import com.android.nn.benchmark.core.sl.QualcommSupportLibraryDriverHandler;
import com.android.nn.benchmark.core.sl.SupportLibraryDriverHandler;

import dalvik.system.BaseDexClassLoader;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.function.Supplier;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;

/**
 * Base harness for running a single TFLite benchmark model, optionally delegated to NNAPI
 * (either the platform NNAPI or a vendor NNAPI Support Library).
 *
 * <p>Lifecycle: construct, optionally configure (backend, device name, SL vendor, mmap),
 * then {@link #setupModel(Context)}, run one of the {@code runBenchmark*} /
 * {@code runInferenceOnce} / {@code runCompilationBenchmark} entry points, and finally
 * {@link #destroy()} (or use try-with-resources via {@link AutoCloseable}).
 *
 * <p>The heavy lifting happens in native code ({@code nnbenchmark_jni}); this class owns the
 * native model handle and the temporary copy of the model file in the app cache dir.
 */
public class NNTestBase implements AutoCloseable {
    protected static final String TAG = "NN_TESTBASE";

    // Used to load the 'native-lib' library on application startup.
    static {
        System.loadLibrary("nnbenchmark_jni");
    }

    // Does the device has any NNAPI accelerator?
    // We only consider a real device, not 'nnapi-reference'.
    public static native boolean hasAccelerator();

    /**
     * Fills resultList with the name of the available NNAPI accelerators
     *
     * @return False if any error occurred, true otherwise
     */
    public static native boolean getAcceleratorNames(List<String> resultList);

    /** Returns true if an NNAPI device with the given name is present on this device. */
    public static native boolean hasNnApiDevice(String nnApiDeviceName);

    /**
     * Creates the native model instance.
     *
     * @return an opaque native handle, or 0 on failure.
     * @throws NnApiDelegationFailure if the NNAPI delegate could not be applied.
     */
    private synchronized native long initModel(
            String modelFileName,
            int tfliteBackend,
            boolean enableIntermediateTensorsDump,
            String nnApiDeviceName,
            boolean mmapModel,
            String nnApiCacheDir,
            long nnApiLibHandle) throws NnApiDelegationFailure;

    /** Releases the native model identified by {@code modelHandle}. */
    private synchronized native void destroyModel(long modelHandle);

    /** Resizes the model's input tensors; returns false on failure. */
    private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);

    /**
     * Runs inference until either {@code inferencesSeqMaxCount} sequences have run or
     * {@code timeoutSec} elapses, appending results to {@code resultList}.
     *
     * @return false if the benchmark failed to run.
     */
    private synchronized native boolean runBenchmark(long modelHandle,
            List<InferenceInOutSequence> inOutList,
            List<InferenceResult> resultList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags);

    /** Measures compilation time; returns null on failure. */
    private synchronized native CompilationBenchmarkResult runCompilationBenchmark(
            long modelHandle, int maxNumIterations, float warmupTimeoutSec, float runTimeoutSec,
            boolean useNnapiSl);

    /** Dumps all intermediate tensors for the given inputs under {@code dumpPath}. */
    private synchronized native void dumpAllLayers(
            long modelHandle,
            String dumpPath,
            List<InferenceInOutSequence> inOutList);

    /**
     * Returns the names of the real NNAPI accelerators on this device,
     * excluding 'nnapi-reference'. Empty list on error.
     */
    public static List<String> availableAcceleratorNames() {
        List<String> availableAccelerators = new ArrayList<>();
        if (NNTestBase.getAcceleratorNames(availableAccelerators)) {
            return availableAccelerators.stream().filter(
                    acceleratorName -> !acceleratorName.equalsIgnoreCase(
                            "nnapi-reference")).collect(Collectors.toList());
        } else {
            Log.e(TAG, "Unable to retrieve accelerator names!!");
            // Typed empty list instead of the raw Collections.EMPTY_LIST constant.
            return Collections.emptyList();
        }
    }

    /** Discard inference output in inference results. */
    public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
    /**
     * Do not expect golden outputs with inference inputs.
     *
     * Useful in cases where there's no straightforward golden output values
     * for the benchmark. This will also skip calculating basic (golden
     * output based) error metrics.
     */
    public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;

    /** Collect only 1 benchmark result every 10 **/
    public static final int FLAG_SAMPLE_BENCHMARK_RESULTS = 1 << 2;

    protected Context mContext;
    protected TextView mText;
    private final String mModelName;
    private final String mModelFile;
    // Opaque native handle; 0 means "no model loaded".
    private long mModelHandle;
    private final int[] mInputShape;
    private final InferenceInOutSequence.FromAssets[] mInputOutputAssets;
    private final InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
    private final EvaluatorConfig mEvaluatorConfig;
    private EvaluatorInterface mEvaluator;
    private boolean mHasGoldenOutputs;
    private TfLiteBackend mTfLiteBackend;
    private boolean mEnableIntermediateTensorsDump = false;
    private final int mMinSdkVersion;
    private Optional<String> mNNApiDeviceName = Optional.empty();
    private boolean mMmapModel = false;
    // Path where the current model has been stored for execution
    private String mTemporaryModelFilePath;
    private boolean mSampleResults;

    // If set to true the test will look for the NNAPI SL binaries in the app resources,
    // copy them into the app cache dir and configure the TfLite test to load NNAPI
    // from the library.
    private boolean mUseNnApiSupportLibrary = false;
    private boolean mExtractNnApiSupportLibrary = false;
    private String mNnApiSupportLibraryVendor = "";

    static final String USE_NNAPI_SL_PROPERTY = "useNnApiSupportLibrary";
    static final String EXTRACT_NNAPI_SL_PROPERTY = "extractNnApiSupportLibrary";
    static final String NNAPI_SL_VENDOR = "nnApiSupportLibraryVendor";

    private static boolean getBooleanTestParameter(String key, boolean defaultValue) {
        // All instrumentation arguments are passed as String so I have to convert the value here.
        return Boolean.parseBoolean(
                InstrumentationRegistry.getArguments().getString(key, "" + defaultValue));
    }

    public static boolean shouldUseNnApiSupportLibrary() {
        return getBooleanTestParameter(USE_NNAPI_SL_PROPERTY, false);
    }

    public static boolean shouldExtractNnApiSupportLibrary() {
        return getBooleanTestParameter(EXTRACT_NNAPI_SL_PROPERTY, false);
    }

    public static String getNnApiSupportLibraryVendor() {
        return InstrumentationRegistry.getArguments().getString(NNAPI_SL_VENDOR);
    }

    /**
     * @param modelName human-readable model name reported by {@link #getTestInfo()}.
     * @param modelFile asset base name; "{modelFile}.tflite" is copied out of assets at setup.
     * @param inputShape shape applied to the input tensors after model init.
     * @param inputOutputAssets input/golden-output pairs read from assets; mutually exclusive
     *        with {@code inputOutputDatasets}.
     * @param inputOutputDatasets dataset-based inputs; mutually exclusive with the above.
     * @param evaluator optional accuracy-evaluator configuration.
     * @param minSdkVersion minimum supported SDK, checked by {@link #checkSdkVersion()};
     *        {@code <= 0} disables the check.
     * @throws IllegalArgumentException if neither or both input sources are provided.
     */
    public NNTestBase(String modelName, String modelFile, int[] inputShape,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets,
            EvaluatorConfig evaluator, int minSdkVersion) {
        if (inputOutputAssets == null && inputOutputDatasets == null) {
            throw new IllegalArgumentException(
                    "Neither inputOutputAssets or inputOutputDatasets given - no inputs");
        }
        if (inputOutputAssets != null && inputOutputDatasets != null) {
            // Fixed missing space: the message used to read "Only onesupported at once."
            throw new IllegalArgumentException(
                    "Both inputOutputAssets or inputOutputDatasets given. Only one "
                            + "supported at once.");
        }
        mModelName = modelName;
        mModelFile = modelFile;
        mInputShape = inputShape;
        mInputOutputAssets = inputOutputAssets;
        mInputOutputDatasets = inputOutputDatasets;
        mModelHandle = 0;
        mEvaluatorConfig = evaluator;
        mMinSdkVersion = minSdkVersion;
        mSampleResults = false;
    }

    public void setTfLiteBackend(TfLiteBackend tfLiteBackend) {
        mTfLiteBackend = tfLiteBackend;
    }

    public void enableIntermediateTensorsDump() {
        enableIntermediateTensorsDump(true);
    }

    public void enableIntermediateTensorsDump(boolean value) {
        mEnableIntermediateTensorsDump = value;
    }

    /** Shorthand for selecting the NNAPI backend. */
    public void useNNApi() {
        setTfLiteBackend(TfLiteBackend.NNAPI);
    }

    public void setUseNnApiSupportLibrary(boolean value) { mUseNnApiSupportLibrary = value; }

    public void setExtractNnApiSupportLibrary(boolean value) {
        mExtractNnApiSupportLibrary = value;
    }

    public void setNnApiSupportLibraryVendor(String value) { mNnApiSupportLibraryVendor = value; }

    /** Sets the NNAPI device to target; only meaningful with the NNAPI backend. */
    public void setNNApiDeviceName(String value) {
        if (mTfLiteBackend != TfLiteBackend.NNAPI) {
            Log.e(TAG, "Setting device name has no effect when not using NNAPI");
        }
        mNNApiDeviceName = Optional.ofNullable(value);
    }

    public void setMmapModel(boolean value) {
        mMmapModel = value;
    }

    /**
     * Copies the model out of assets, loads it in native code and applies the configured
     * backend/delegate.
     *
     * @return false if native model init or input-tensor resize failed.
     * @throws IOException if the model asset cannot be copied to the cache dir.
     * @throws NnApiDelegationFailure if an invalid SL vendor was configured or the SL entry
     *         point could not be loaded.
     */
    public final boolean setupModel(Context ipcxt) throws IOException, NnApiDelegationFailure {
        mContext = ipcxt;
        long nnApiLibHandle = 0;
        if (mUseNnApiSupportLibrary) {
            // Known SL vendors; suppliers defer handler construction until one is selected.
            HashMap<String, Supplier<SupportLibraryDriverHandler>> vendors = new HashMap<>();
            vendors.put("qc", QualcommSupportLibraryDriverHandler::new);
            vendors.put("arm", ArmSupportLibraryDriverHandler::new);
            vendors.put("mtk", MediaTekSupportLibraryDriverHandler::new);
            Supplier<SupportLibraryDriverHandler> vendor = vendors.get(mNnApiSupportLibraryVendor);
            if (vendor == null) {
                throw new NnApiDelegationFailure(String
                        .format("NNAPI SL vendor is invalid '%s', expected one of %s.",
                                mNnApiSupportLibraryVendor, vendors.keySet().toString()));
            }
            SupportLibraryDriverHandler slHandler = vendor.get();
            nnApiLibHandle = slHandler.getOrLoadNnApiSlHandle(mContext,
                    mExtractNnApiSupportLibrary);
            if (nnApiLibHandle == 0) {
                // Build the message once instead of formatting it twice for log + exception.
                final String slErrorMsg = String.format(
                        "Unable to find NNAPI SL entry point '%s' in embedded libraries path.",
                        SupportLibraryDriverHandler.NNAPI_SL_LIB_NAME);
                Log.e(TAG, slErrorMsg);
                throw new NnApiDelegationFailure(slErrorMsg);
            }
        }
        if (mTemporaryModelFilePath != null) {
            // Re-setup: drop the previous temporary model copy first.
            deleteOrWarn(mTemporaryModelFilePath);
        }
        mTemporaryModelFilePath = copyAssetToFile();
        String nnApiCacheDir = mContext.getCodeCacheDir().toString();
        mModelHandle = initModel(
                mTemporaryModelFilePath, mTfLiteBackend.ordinal(), mEnableIntermediateTensorsDump,
                mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir, nnApiLibHandle);
        if (mModelHandle == 0) {
            Log.e(TAG, "Failed to init the model");
            return false;
        }
        if (!resizeInputTensors(mModelHandle, mInputShape)) {
            return false;
        }

        if (mEvaluatorConfig != null) {
            mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
        }
        return true;
    }

    public String getTestInfo() {
        return mModelName;
    }

    public EvaluatorInterface getEvaluator() {
        return mEvaluator;
    }

    /** @throws UnsupportedSdkException if the device SDK is below the model's minimum. */
    public void checkSdkVersion() throws UnsupportedSdkException {
        if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
            // Fixed typo in message ("Mininum" -> "Minimum").
            throw new UnsupportedSdkException("SDK version not supported. Minimum required: " +
                    mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
        }
    }

    private void deleteOrWarn(String path) {
        if (!new File(path).delete()) {
            Log.w(TAG, String.format(
                    "Unable to delete file '%s'. This might cause device to run out of space.",
                    path));
        }
    }

    /**
     * Loads this test's input (and golden output) sequences and verifies that either all
     * sequences have golden outputs or none do.
     */
    private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList =
                getInputOutputAssets(mContext, mInputOutputAssets, mInputOutputDatasets);

        Boolean lastGolden = null;
        for (InferenceInOutSequence sequence : inOutList) {
            mHasGoldenOutputs = sequence.hasGoldenOutput();
            if (lastGolden == null) {
                lastGolden = mHasGoldenOutputs;
            } else {
                if (lastGolden != mHasGoldenOutputs) {
                    throw new IllegalArgumentException(
                            "Some inputs for " + mModelName + " have outputs while some don't.");
                }
            }
        }
        return inOutList;
    }

    /** Reads inference sequences from assets and/or datasets into a single list. */
    public static List<InferenceInOutSequence> getInputOutputAssets(Context context,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets) throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList = new ArrayList<>();
        if (inputOutputAssets != null) {
            for (InferenceInOutSequence.FromAssets ioAsset : inputOutputAssets) {
                inOutList.add(ioAsset.readAssets(context.getAssets()));
            }
        }
        if (inputOutputDatasets != null) {
            for (InferenceInOutSequence.FromDataset dataset : inputOutputDatasets) {
                inOutList.addAll(dataset.readDataset(context.getAssets(), context.getCacheDir()));
            }
        }

        return inOutList;
    }

    /** Computes the benchmark flags implied by the current configuration. */
    public int getDefaultFlags() {
        int flags = 0;
        if (!mHasGoldenOutputs) {
            flags = flags | FLAG_IGNORE_GOLDEN_OUTPUT;
        }
        if (mEvaluator == null) {
            flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
        }
        // For very long tests we will collect only a sample of the results
        if (mSampleResults) {
            flags = flags | FLAG_SAMPLE_BENCHMARK_RESULTS;
        }
        return flags;
    }

    /**
     * Dumps intermediate tensors for a range of input assets into {@code dumpDir}.
     *
     * <p>NOTE(review): {@code inputAssetSize} is forwarded as {@code subList}'s exclusive
     * {@code toIndex}, i.e. it behaves as an end index, not a count — confirm against callers.
     *
     * @throws IllegalStateException if intermediate-tensor dumping was not enabled before setup.
     */
    public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
            throws IOException {
        if (!dumpDir.exists() || !dumpDir.isDirectory()) {
            throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
        }
        if (!mEnableIntermediateTensorsDump) {
            throw new IllegalStateException("mEnableIntermediateTensorsDump is " +
                    "set to false, impossible to proceed");
        }

        List<InferenceInOutSequence> ios = getInputOutputAssets();
        dumpAllLayers(mModelHandle, dumpDir.toString(),
                ios.subList(inputAssetIndex, inputAssetSize));
    }

    /** Runs the input set exactly once with no timeout. */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
            throws IOException, BenchmarkException {
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        return runBenchmark(ios, 1, Float.MAX_VALUE, flags);
    }

    /** Runs as many inferences as possible before {@code timeoutSec} elapses. */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
            throws IOException, BenchmarkException {
        // Run as many as possible before timeout.
        int flags = getDefaultFlags();
        return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
    }

    /** Run through whole input set (once or multiple times). */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
            int minInferences,
            float timeoutSec)
            throws IOException, BenchmarkException {
        int flags = getDefaultFlags();
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int setInferences = 0;
        for (InferenceInOutSequence iosSeq : ios) {
            setInferences += iosSeq.size();
        }
        if (setInferences == 0) {
            // Previously this fell through to a bare division-by-zero ArithmeticException.
            throw new IllegalStateException(
                    "Input set for " + mModelName + " contains no inferences");
        }
        int setRepeat = (minInferences + setInferences - 1) / setInferences; // ceil.
        int totalSequenceInferencesCount = ios.size() * setRepeat;
        int expectedResults = setInferences * setRepeat;

        Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
                        flags);
        if (result.second.size() != expectedResults) {
            // We reached a timeout or failed to evaluate whole set for other reason, abort.
            @SuppressLint("DefaultLocale")
            final String errorMsg = String.format(
                    "Failed to evaluate complete input set, in %f seconds expected: %d, received:"
                            + " %d",
                    timeoutSec, expectedResults, result.second.size());
            Log.w(TAG, errorMsg);
            throw new IllegalStateException(errorMsg);
        }
        return result;
    }

    /**
     * Runs the native benchmark over {@code inOutList}.
     *
     * @throws UnsupportedModelException if {@link #setupModel(Context)} did not succeed.
     * @throws BenchmarkException if the native benchmark run failed.
     */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
            List<InferenceInOutSequence> inOutList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags)
            throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        List<InferenceResult> resultList = new ArrayList<>();
        if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
                timeoutSec, flags)) {
            throw new BenchmarkException("Failed to run benchmark");
        }
        return new Pair<>(inOutList, resultList);
    }

    /**
     * Measures model compilation time.
     *
     * @throws UnsupportedModelException if {@link #setupModel(Context)} did not succeed.
     * @throws BenchmarkException if the native compilation benchmark failed.
     */
    public CompilationBenchmarkResult runCompilationBenchmark(float warmupTimeoutSec,
            float runTimeoutSec, int maxIterations) throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        CompilationBenchmarkResult result = runCompilationBenchmark(
                mModelHandle, maxIterations, warmupTimeoutSec, runTimeoutSec,
                shouldUseNnApiSupportLibrary());
        if (result == null) {
            throw new BenchmarkException("Failed to run compilation benchmark");
        }
        return result;
    }

    /** Releases the native model and deletes the temporary model copy. Safe to call twice. */
    public void destroy() {
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
            mTemporaryModelFilePath = null;
        }
    }

    private final Random mRandom = new Random(System.currentTimeMillis());

    // We need to copy it to cache dir, so that TFlite can load it directly.
    private String copyAssetToFile() throws IOException {
        // Unique-ish name per thread/run so concurrent tests don't clobber each other's copy.
        @SuppressLint("DefaultLocale")
        String outFileName =
                String.format("%s/%s-%d-%d.tflite", mContext.getCacheDir().getAbsolutePath(),
                        mModelFile,
                        Thread.currentThread().getId(), mRandom.nextInt(10000));

        copyAssetToFile(mContext, mModelFile + ".tflite", outFileName);
        return outFileName;
    }

    /**
     * Copies a model asset into {@code targetFile}, creating it if needed.
     *
     * @return false if the target file could not be created.
     */
    public static boolean copyModelToFile(Context context, String modelFileName, File targetFile)
            throws IOException {
        if (!targetFile.exists() && !targetFile.createNewFile()) {
            Log.w(TAG, String.format("Unable to create file %s", targetFile.getAbsolutePath()));
            return false;
        }
        NNTestBase.copyAssetToFile(context, modelFileName, targetFile.getAbsolutePath());
        return true;
    }

    /** Copies asset {@code modelAssetName} to {@code targetPath}, logging and rethrowing errors. */
    public static void copyAssetToFile(Context context, String modelAssetName, String targetPath)
            throws IOException {
        AssetManager assetManager = context.getAssets();
        try {
            File outFile = new File(targetPath);

            try (InputStream in = assetManager.open(modelAssetName);
                 FileOutputStream out = new FileOutputStream(outFile)) {
                copyFull(in, out);
            }
        } catch (IOException e) {
            Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
            throw e;
        }
    }

    /** Copies {@code in} to {@code out} until EOF. Neither stream is closed. */
    public static void copyFull(InputStream in, OutputStream out) throws IOException {
        byte[] byteBuffer = new byte[1024];
        int readBytes = -1;
        while ((readBytes = in.read(byteBuffer)) != -1) {
            out.write(byteBuffer, 0, readBytes);
        }
    }

    @Override
    public void close() {
        destroy();
    }

    public void setSampleResult(boolean sampleResults) {
        this.mSampleResults = sampleResults;
    }
}