1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import java.util.ArrayList; 20 import java.util.List; 21 import java.util.concurrent.atomic.AtomicReference; 22 import androidx.test.InstrumentationRegistry; 23 import java.util.stream.Collectors; 24 import android.util.Log; 25 26 /** Information about available benchmarking models */ 27 public class TestModels { 28 /** Entry for a single benchmarking model */ 29 public static class TestModelEntry { 30 /** Unique model name, used to find benchmark data */ 31 public final String mModelName; 32 33 /** Expected inference performance in seconds */ 34 public final float mBaselineSec; 35 36 /** Shape of input data */ 37 public final int[] mInputShape; 38 39 /** File pair asset input/output pairs */ 40 public final InferenceInOutSequence.FromAssets[] mInOutAssets; 41 42 /** Dataset inputs */ 43 public final InferenceInOutSequence.FromDataset[] mInOutDatasets; 44 45 /** Readable name for test output */ 46 public final String mTestName; 47 48 /** Name of model file, so that the same file can be reused */ 49 public final String mModelFile; 50 51 /** The evaluator to use for validating the results. */ 52 public final EvaluatorConfig mEvaluator; 53 54 /** Min SDK version that the model can run on. */ 55 public final int mMinSdkVersion; 56 57 /* Number of bytes per input data entry */ 58 public final int mInDataSize; 59 TestModelEntry(String modelName, float baselineSec, int[] inputShape, InferenceInOutSequence.FromAssets[] inOutAssets, InferenceInOutSequence.FromDataset[] inOutDatasets, String testName, String modelFile, EvaluatorConfig evaluator, int minSdkVersion, int inDataSize)60 public TestModelEntry(String modelName, float baselineSec, int[] inputShape, 61 InferenceInOutSequence.FromAssets[] inOutAssets, 62 InferenceInOutSequence.FromDataset[] inOutDatasets, String testName, 63 String modelFile, 64 EvaluatorConfig evaluator, int minSdkVersion, int inDataSize) { 65 mModelName = modelName; 66 mBaselineSec = baselineSec; 67 mInputShape = inputShape; 68 mInOutAssets = inOutAssets; 69 mInOutDatasets = inOutDatasets; 70 mTestName = testName; 71 mModelFile = modelFile; 72 mEvaluator = evaluator; 73 mMinSdkVersion = minSdkVersion; 74 mInDataSize = inDataSize; 75 } 76 77 // Used by VTS tests. createNNTestBase()78 public NNTestBase createNNTestBase() { 79 return new NNTestBase(mModelName, mModelFile, mInputShape, mInOutAssets, mInOutDatasets, 80 mEvaluator, mMinSdkVersion); 81 } 82 createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump)83 public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump) { 84 return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump, /*mmapModel=*/false, 85 /*useNnApiSl=*/false, /*extractNnApiSl=*/false, /*nnApiSlVendor=*/""); 86 } 87 88 // Used by CTS tests. createNNTestBase(boolean useNNAPI, boolean enableIntermediateTensorsDump)89 public NNTestBase createNNTestBase(boolean useNNAPI, boolean enableIntermediateTensorsDump) { 90 TfLiteBackend tfLiteBackend = useNNAPI ? TfLiteBackend.NNAPI : TfLiteBackend.CPU; 91 return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump, 92 /*mmapModel=*/false, /*useNnApiSl=*/false, /*extractNnApiSl=*/false, /*nnApiSlVendor=*/""); 93 } 94 createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump, boolean mmapModel, boolean useNnApiSl, boolean extractNnApiSl, String nnApiSlVendor)95 public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump, 96 boolean mmapModel, boolean useNnApiSl, boolean extractNnApiSl, String nnApiSlVendor) { 97 NNTestBase test = createNNTestBase(); 98 test.setTfLiteBackend(tfLiteBackend); 99 test.enableIntermediateTensorsDump(enableIntermediateTensorsDump); 100 test.setMmapModel(mmapModel); 101 test.setUseNnApiSupportLibrary(useNnApiSl); 102 test.setExtractNnApiSupportLibrary(extractNnApiSl); 103 test.setNnApiSupportLibraryVendor(nnApiSlVendor); 104 return test; 105 } 106 toString()107 public String toString() { 108 return mModelName; 109 } 110 getTestName()111 public String getTestName() { 112 return mTestName; 113 } 114 115 withDisabledEvaluation()116 public TestModelEntry withDisabledEvaluation() { 117 return new TestModelEntry(mModelName, mBaselineSec, mInputShape, mInOutAssets, 118 mInOutDatasets, mTestName, mModelFile, 119 null, // Disable evaluation. 120 mMinSdkVersion, mInDataSize); 121 } 122 } 123 124 static private final List<TestModelEntry> sTestModelEntryList = new ArrayList<>(); 125 static private final AtomicReference<List<TestModelEntry>> frozenEntries = 126 new AtomicReference<>(null); 127 128 129 /** Add new benchmark model. */ registerModel(TestModelEntry model)130 static public void registerModel(TestModelEntry model) { 131 if (frozenEntries.get() != null) { 132 throw new IllegalStateException("Can't register new models after its list is frozen"); 133 } 134 sTestModelEntryList.add(model); 135 } 136 isListFrozen()137 public static boolean isListFrozen() { 138 return frozenEntries.get() != null; 139 } 140 141 static final String MODEL_FILTER_PROPERTY = "nnBenchmarkModelFilter"; 142 getModelFilterRegex()143 public static String getModelFilterRegex() { 144 // All instrumentation arguments are passed as String so I have to convert the value here. 145 return InstrumentationRegistry.getArguments().getString(MODEL_FILTER_PROPERTY, ""); 146 } 147 148 /** 149 * Returns the list of models eventually by a user specified instrumentation filter regex. 150 */ modelsList()151 static public List<TestModelEntry> modelsList() { 152 return modelsList(getModelFilterRegex()); 153 } 154 155 /** 156 * Returns the list of models eventually by a user specified instrumentation filter. 157 */ modelsList(String modelFilterRegex)158 static public List<TestModelEntry> modelsList(String modelFilterRegex) { 159 if (modelFilterRegex == null || modelFilterRegex.isEmpty()) { 160 Log.i("NN_BENCHMARK", "No model filter, returning all models"); 161 return fullModelsList(); 162 } 163 Log.i("NN_BENCHMARK", "Filtering model with filter " + modelFilterRegex); 164 List<TestModelEntry> result = fullModelsList().stream() 165 .filter( modelEntry -> 166 modelEntry.mModelName.matches(modelFilterRegex) 167 ) 168 .collect(Collectors.toList()); 169 170 Log.i("NN_BENCHMARK", "Returning models: " + result); 171 172 return result; 173 } 174 175 /** 176 * Fetch list of test models. 177 * 178 * If this method was called at least once, then it's impossible to register new models. 179 */ fullModelsList()180 static public List<TestModelEntry> fullModelsList() { 181 frozenEntries.compareAndSet(null, sTestModelEntryList); 182 return frozenEntries.get(); 183 } 184 185 /** Fetch model by its name. */ getModelByName(String name)186 static public TestModelEntry getModelByName(String name) { 187 for (TestModelEntry testModelEntry : modelsList()) { 188 if (testModelEntry.mModelName.equals(name)) { 189 return testModelEntry; 190 } 191 } 192 throw new IllegalArgumentException("Unknown TestModelEntry named " + name); 193 } 194 195 } 196