/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.os.Build;
import android.os.ParcelFileDescriptor;
import android.os.ParcelFileDescriptor.AutoCloseInputStream;
import android.system.ErrnoException;
import android.system.Os;
import android.util.Log;
import android.util.Pair;
import android.widget.TextView;
import androidx.test.InstrumentationRegistry;
import com.android.nn.benchmark.core.sl.ArmSupportLibraryDriverHandler;
import com.android.nn.benchmark.core.sl.MediaTekSupportLibraryDriverHandler;
import com.android.nn.benchmark.core.sl.QualcommSupportLibraryDriverHandler;
import com.android.nn.benchmark.core.sl.SupportLibraryDriverHandler;
import dalvik.system.BaseDexClassLoader;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.function.Supplier;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
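/**
 * Base class driving TFLite model benchmarks through the nnbenchmark JNI library.
 *
 * A minimal usage sketch (the model name, input shape and assets below are hypothetical;
 * real instances are normally built from the benchmark's model list):
 *
 * <pre>{@code
 * try (NNTestBase test = new NNTestBase("mobilenet", "mobilenet", new int[]{1, 224, 224, 3},
 *         inputOutputAssets, null, null, 27)) {
 *     test.useNNApi();
 *     if (test.setupModel(context)) {
 *         test.runBenchmark(60.0f);  // run inferences for up to 60 seconds
 *     }
 * }
 * }</pre>
 */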
public class NNTestBase implements AutoCloseable {
    protected static final String TAG = "NN_TESTBASE";

    // Used to load the 'nnbenchmark_jni' library on application startup.
    static {
        System.loadLibrary("nnbenchmark_jni");
    }

    // Does the device have any NNAPI accelerator?
    // We only consider real accelerators, not 'nnapi-reference'.
    public static native boolean hasAccelerator();

    /**
     * Fills resultList with the names of the available NNAPI accelerators.
     *
     * @return false if any error occurred, true otherwise.
     */
    public static native boolean getAcceleratorNames(List<String> resultList);

    public static native boolean hasNnApiDevice(String nnApiDeviceName);

    private synchronized native long initModel(
            String modelFileName,
            int tfliteBackend,
            boolean enableIntermediateTensorsDump,
            String nnApiDeviceName,
            boolean mmapModel,
            String nnApiCacheDir,
            long nnApiLibHandle) throws NnApiDelegationFailure;

    private synchronized native void destroyModel(long modelHandle);

    private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);

    private synchronized native boolean runBenchmark(long modelHandle,
            List<InferenceInOutSequence> inOutList,
            List<InferenceResult> resultList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags);

    private synchronized native CompilationBenchmarkResult runCompilationBenchmark(
            long modelHandle, int maxNumIterations, float warmupTimeoutSec, float runTimeoutSec,
            boolean useNnapiSl);

    private synchronized native void dumpAllLayers(
            long modelHandle,
            String dumpPath,
            List<InferenceInOutSequence> inOutList);

    public static List<String> availableAcceleratorNames() {
        List<String> availableAccelerators = new ArrayList<>();
        if (NNTestBase.getAcceleratorNames(availableAccelerators)) {
            return availableAccelerators.stream().filter(
                    acceleratorName -> !acceleratorName.equalsIgnoreCase(
                            "nnapi-reference")).collect(Collectors.toList());
        } else {
            Log.e(TAG, "Unable to retrieve accelerator names!");
            return Collections.emptyList();
        }
    }

    /** Discard inference output in inference results. */
    public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
    /**
     * Do not expect golden outputs with inference inputs.
     *
     * Useful in cases where there are no straightforward golden output values
     * for the benchmark. This will also skip calculating basic (golden
     * output based) error metrics.
     */
    public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;

    /** Collect only 1 benchmark result out of every 10. */
    public static final int FLAG_SAMPLE_BENCHMARK_RESULTS = 1 << 2;
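    // The flags above form a bitmask and can be OR-ed together, as getDefaultFlags()
    // does below, e.g. FLAG_DISCARD_INFERENCE_OUTPUT | FLAG_SAMPLE_BENCHMARK_RESULTS.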

    protected Context mContext;
    protected TextView mText;
    private final String mModelName;
    private final String mModelFile;
    private long mModelHandle;
    private final int[] mInputShape;
    private final InferenceInOutSequence.FromAssets[] mInputOutputAssets;
    private final InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
    private final EvaluatorConfig mEvaluatorConfig;
    private EvaluatorInterface mEvaluator;
    private boolean mHasGoldenOutputs;
    private TfLiteBackend mTfLiteBackend;
    private boolean mEnableIntermediateTensorsDump = false;
    private final int mMinSdkVersion;
    private Optional<String> mNNApiDeviceName = Optional.empty();
    private boolean mMmapModel = false;
    // Path where the current model has been stored for execution.
    private String mTemporaryModelFilePath;
    private boolean mSampleResults;

    // If set to true, the test will look for the NNAPI SL binaries in the app resources,
    // copy them into the app cache dir and configure the TfLite test to load NNAPI
    // from that library.
    private boolean mUseNnApiSupportLibrary = false;
    private boolean mExtractNnApiSupportLibrary = false;
    private String mNnApiSupportLibraryVendor = "";

    static final String USE_NNAPI_SL_PROPERTY = "useNnApiSupportLibrary";
    static final String EXTRACT_NNAPI_SL_PROPERTY = "extractNnApiSupportLibrary";
    static final String NNAPI_SL_VENDOR = "nnApiSupportLibraryVendor";
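    // The recognized NNAPI_SL_VENDOR values are "qc", "arm" and "mtk"; see the vendor
    // handler map in setupModel().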

    private static boolean getBooleanTestParameter(String key, boolean defaultValue) {
        // All instrumentation arguments are passed as Strings, so the value has to be
        // converted here.
        return Boolean.parseBoolean(
                InstrumentationRegistry.getArguments().getString(key, "" + defaultValue));
    }
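
    // These test properties are read from the instrumentation arguments; a sketch of
    // setting them from the command line (the exact instrumentation component is
    // project-specific):
    //   adb shell am instrument -e useNnApiSupportLibrary true \
    //       -e nnApiSupportLibraryVendor qc <package>/<runner>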

    public static boolean shouldUseNnApiSupportLibrary() {
        return getBooleanTestParameter(USE_NNAPI_SL_PROPERTY, false);
    }

    public static boolean shouldExtractNnApiSupportLibrary() {
        return getBooleanTestParameter(EXTRACT_NNAPI_SL_PROPERTY, false);
    }

    public static String getNnApiSupportLibraryVendor() {
        return InstrumentationRegistry.getArguments().getString(NNAPI_SL_VENDOR);
    }

    public NNTestBase(String modelName, String modelFile, int[] inputShape,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets,
            EvaluatorConfig evaluator, int minSdkVersion) {
        if (inputOutputAssets == null && inputOutputDatasets == null) {
            throw new IllegalArgumentException(
                    "Neither inputOutputAssets nor inputOutputDatasets given - no inputs");
        }
        if (inputOutputAssets != null && inputOutputDatasets != null) {
            throw new IllegalArgumentException(
                    "Both inputOutputAssets and inputOutputDatasets given. Only one is"
                            + " supported at a time.");
        }
        mModelName = modelName;
        mModelFile = modelFile;
        mInputShape = inputShape;
        mInputOutputAssets = inputOutputAssets;
        mInputOutputDatasets = inputOutputDatasets;
        mModelHandle = 0;
        mEvaluatorConfig = evaluator;
        mMinSdkVersion = minSdkVersion;
        mSampleResults = false;
    }

    public void setTfLiteBackend(TfLiteBackend tfLiteBackend) {
        mTfLiteBackend = tfLiteBackend;
    }

    public void enableIntermediateTensorsDump() {
        enableIntermediateTensorsDump(true);
    }

    public void enableIntermediateTensorsDump(boolean value) {
        mEnableIntermediateTensorsDump = value;
    }

    public void useNNApi() {
        setTfLiteBackend(TfLiteBackend.NNAPI);
    }

    public void setUseNnApiSupportLibrary(boolean value) {
        mUseNnApiSupportLibrary = value;
    }

    public void setExtractNnApiSupportLibrary(boolean value) {
        mExtractNnApiSupportLibrary = value;
    }

    public void setNnApiSupportLibraryVendor(String value) {
        mNnApiSupportLibraryVendor = value;
    }

    public void setNNApiDeviceName(String value) {
        if (mTfLiteBackend != TfLiteBackend.NNAPI) {
            Log.e(TAG, "Setting device name has no effect when not using NNAPI");
        }
        mNNApiDeviceName = Optional.ofNullable(value);
    }

    public void setMmapModel(boolean value) {
        mMmapModel = value;
    }

    public final boolean setupModel(Context ipcxt) throws IOException, NnApiDelegationFailure {
        mContext = ipcxt;
        long nnApiLibHandle = 0;
        if (mUseNnApiSupportLibrary) {
            HashMap<String, Supplier<SupportLibraryDriverHandler>> vendors = new HashMap<>();
            vendors.put("qc", QualcommSupportLibraryDriverHandler::new);
            vendors.put("arm", ArmSupportLibraryDriverHandler::new);
            vendors.put("mtk", MediaTekSupportLibraryDriverHandler::new);
            Supplier<SupportLibraryDriverHandler> vendor = vendors.get(mNnApiSupportLibraryVendor);
            if (vendor == null) {
                throw new NnApiDelegationFailure(String.format(
                        "NNAPI SL vendor '%s' is invalid, expected one of %s.",
                        mNnApiSupportLibraryVendor, vendors.keySet()));
            }
            SupportLibraryDriverHandler slHandler = vendor.get();
            nnApiLibHandle = slHandler.getOrLoadNnApiSlHandle(mContext,
                    mExtractNnApiSupportLibrary);
            if (nnApiLibHandle == 0) {
                final String errorMsg = String.format(
                        "Unable to find NNAPI SL entry point '%s' in embedded libraries path.",
                        SupportLibraryDriverHandler.NNAPI_SL_LIB_NAME);
                Log.e(TAG, errorMsg);
                throw new NnApiDelegationFailure(errorMsg);
            }
        }
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
        }
        mTemporaryModelFilePath = copyAssetToFile();
        String nnApiCacheDir = mContext.getCodeCacheDir().toString();
        mModelHandle = initModel(
                mTemporaryModelFilePath, mTfLiteBackend.ordinal(), mEnableIntermediateTensorsDump,
                mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir, nnApiLibHandle);
        if (mModelHandle == 0) {
            Log.e(TAG, "Failed to init the model");
            return false;
        }
        if (!resizeInputTensors(mModelHandle, mInputShape)) {
            return false;
        }

        if (mEvaluatorConfig != null) {
            mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
        }
        return true;
    }

    public String getTestInfo() {
        return mModelName;
    }

    public EvaluatorInterface getEvaluator() {
        return mEvaluator;
    }

    public void checkSdkVersion() throws UnsupportedSdkException {
        if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
            throw new UnsupportedSdkException("SDK version not supported. Minimum required: "
                    + mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
        }
    }

    private void deleteOrWarn(String path) {
        if (!new File(path).delete()) {
            Log.w(TAG, String.format(
                    "Unable to delete file '%s'. This might cause the device to run out of space.",
                    path));
        }
    }
    private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList =
                getInputOutputAssets(mContext, mInputOutputAssets, mInputOutputDatasets);

        Boolean lastGolden = null;
        for (InferenceInOutSequence sequence : inOutList) {
            mHasGoldenOutputs = sequence.hasGoldenOutput();
            if (lastGolden == null) {
                lastGolden = mHasGoldenOutputs;
            } else {
                if (lastGolden != mHasGoldenOutputs) {
                    throw new IllegalArgumentException(
                            "Some inputs for " + mModelName + " have outputs while some don't.");
                }
            }
        }
        return inOutList;
    }

    public static List<InferenceInOutSequence> getInputOutputAssets(Context context,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets) throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList = new ArrayList<>();
        if (inputOutputAssets != null) {
            for (InferenceInOutSequence.FromAssets ioAsset : inputOutputAssets) {
                inOutList.add(ioAsset.readAssets(context.getAssets()));
            }
        }
        if (inputOutputDatasets != null) {
            for (InferenceInOutSequence.FromDataset dataset : inputOutputDatasets) {
                inOutList.addAll(dataset.readDataset(context.getAssets(), context.getCacheDir()));
            }
        }

        return inOutList;
    }

    public int getDefaultFlags() {
        int flags = 0;
        if (!mHasGoldenOutputs) {
            flags = flags | FLAG_IGNORE_GOLDEN_OUTPUT;
        }
        if (mEvaluator == null) {
            flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
        }
        // For very long tests we will collect only a sample of the results.
        if (mSampleResults) {
            flags = flags | FLAG_SAMPLE_BENCHMARK_RESULTS;
        }
        return flags;
    }

    public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
            throws IOException {
        if (!dumpDir.exists() || !dumpDir.isDirectory()) {
            throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
        }
        if (!mEnableIntermediateTensorsDump) {
            throw new IllegalStateException("mEnableIntermediateTensorsDump is set to false,"
                    + " impossible to proceed");
        }

        List<InferenceInOutSequence> ios = getInputOutputAssets();
        dumpAllLayers(mModelHandle, dumpDir.toString(),
                ios.subList(inputAssetIndex, inputAssetSize));
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
            throws IOException, BenchmarkException {
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        return runBenchmark(ios, 1, Float.MAX_VALUE, flags);
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
            throws IOException, BenchmarkException {
        // Run as many inferences as possible before the timeout.
        int flags = getDefaultFlags();
        return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
    }

    /** Run through the whole input set (once or multiple times). */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
            int minInferences,
            float timeoutSec)
            throws IOException, BenchmarkException {
        int flags = getDefaultFlags();
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int setInferences = 0;
        for (InferenceInOutSequence iosSeq : ios) {
            setInferences += iosSeq.size();
        }
        int setRepeat = (minInferences + setInferences - 1) / setInferences; // ceil.
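        // e.g. minInferences = 100 over a set of 30 inferences gives setRepeat =
        // ceil(100 / 30) = 4, and thus 4 * 30 = 120 expected results.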
        int totalSequenceInferencesCount = ios.size() * setRepeat;
        int expectedResults = setInferences * setRepeat;

        Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
                        flags);
        if (result.second.size() != expectedResults) {
            // We reached a timeout or failed to evaluate the whole set for some other
            // reason; abort.
            @SuppressLint("DefaultLocale")
            final String errorMsg = String.format(
                    "Failed to evaluate complete input set, in %f seconds expected: %d, received:"
                            + " %d",
                    timeoutSec, expectedResults, result.second.size());
            Log.w(TAG, errorMsg);
            throw new IllegalStateException(errorMsg);
        }
        return result;
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
            List<InferenceInOutSequence> inOutList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags)
            throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        List<InferenceResult> resultList = new ArrayList<>();
        if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
                timeoutSec, flags)) {
            throw new BenchmarkException("Failed to run benchmark");
        }
        return new Pair<>(inOutList, resultList);
    }

    public CompilationBenchmarkResult runCompilationBenchmark(float warmupTimeoutSec,
            float runTimeoutSec, int maxIterations) throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        CompilationBenchmarkResult result = runCompilationBenchmark(
                mModelHandle, maxIterations, warmupTimeoutSec, runTimeoutSec,
                shouldUseNnApiSupportLibrary());
        if (result == null) {
            throw new BenchmarkException("Failed to run compilation benchmark");
        }
        return result;
    }

    public void destroy() {
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
            mTemporaryModelFilePath = null;
        }
    }

    private final Random mRandom = new Random(System.currentTimeMillis());

    // We need to copy the model to the cache dir, so that TFLite can load it directly.
    private String copyAssetToFile() throws IOException {
        @SuppressLint("DefaultLocale")
        String outFileName =
                String.format("%s/%s-%d-%d.tflite", mContext.getCacheDir().getAbsolutePath(),
                        mModelFile,
                        Thread.currentThread().getId(), mRandom.nextInt(10000));

        copyAssetToFile(mContext, mModelFile + ".tflite", outFileName);
        return outFileName;
    }

    public static boolean copyModelToFile(Context context, String modelFileName, File targetFile)
            throws IOException {
        if (!targetFile.exists() && !targetFile.createNewFile()) {
            Log.w(TAG, String.format("Unable to create file %s", targetFile.getAbsolutePath()));
            return false;
        }
        NNTestBase.copyAssetToFile(context, modelFileName, targetFile.getAbsolutePath());
        return true;
    }

    public static void copyAssetToFile(Context context, String modelAssetName, String targetPath)
            throws IOException {
        AssetManager assetManager = context.getAssets();
        try {
            File outFile = new File(targetPath);

            try (InputStream in = assetManager.open(modelAssetName);
                 FileOutputStream out = new FileOutputStream(outFile)) {
                copyFull(in, out);
            }
        } catch (IOException e) {
            Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
            throw e;
        }
    }

    public static void copyFull(InputStream in, OutputStream out) throws IOException {
        byte[] byteBuffer = new byte[1024];
        int readBytes = -1;
        while ((readBytes = in.read(byteBuffer)) != -1) {
            out.write(byteBuffer, 0, readBytes);
        }
    }
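    // Note: on newer platform/Java levels, java.io.InputStream#transferTo provides the
    // same behavior; the manual loop above keeps compatibility with older SDKs.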

    @Override
    public void close() {
        destroy();
    }

    public void setSampleResult(boolean sampleResults) {
        this.mSampleResults = sampleResults;
    }
}