1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
import android.util.Log;

import androidx.test.InstrumentationRegistry;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
25 
26 /** Information about available benchmarking models */
27 public class TestModels {
28     /** Entry for a single benchmarking model */
29     public static class TestModelEntry {
30         /** Unique model name, used to find benchmark data */
31         public final String mModelName;
32 
33         /** Expected inference performance in seconds */
34         public final float mBaselineSec;
35 
36         /** Shape of input data */
37         public final int[] mInputShape;
38 
39         /** File pair asset input/output pairs */
40         public final InferenceInOutSequence.FromAssets[] mInOutAssets;
41 
42         /** Dataset inputs */
43         public final InferenceInOutSequence.FromDataset[] mInOutDatasets;
44 
45         /** Readable name for test output */
46         public final String mTestName;
47 
48         /** Name of model file, so that the same file can be reused */
49         public final String mModelFile;
50 
51         /** The evaluator to use for validating the results. */
52         public final EvaluatorConfig mEvaluator;
53 
54         /** Min SDK version that the model can run on. */
55         public final int mMinSdkVersion;
56 
57         /* Number of bytes per input data entry  */
58         public final int mInDataSize;
59 
TestModelEntry(String modelName, float baselineSec, int[] inputShape, InferenceInOutSequence.FromAssets[] inOutAssets, InferenceInOutSequence.FromDataset[] inOutDatasets, String testName, String modelFile, EvaluatorConfig evaluator, int minSdkVersion, int inDataSize)60         public TestModelEntry(String modelName, float baselineSec, int[] inputShape,
61                 InferenceInOutSequence.FromAssets[] inOutAssets,
62                 InferenceInOutSequence.FromDataset[] inOutDatasets, String testName,
63                 String modelFile,
64                 EvaluatorConfig evaluator, int minSdkVersion, int inDataSize) {
65             mModelName = modelName;
66             mBaselineSec = baselineSec;
67             mInputShape = inputShape;
68             mInOutAssets = inOutAssets;
69             mInOutDatasets = inOutDatasets;
70             mTestName = testName;
71             mModelFile = modelFile;
72             mEvaluator = evaluator;
73             mMinSdkVersion = minSdkVersion;
74             mInDataSize = inDataSize;
75         }
76 
77         // Used by VTS tests.
createNNTestBase()78         public NNTestBase createNNTestBase() {
79             return new NNTestBase(mModelName, mModelFile, mInputShape, mInOutAssets, mInOutDatasets,
80                     mEvaluator, mMinSdkVersion);
81         }
82 
createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump)83         public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump) {
84             return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump, /*mmapModel=*/false,
85                 /*useNnApiSl=*/false, /*extractNnApiSl=*/false, /*nnApiSlVendor=*/"");
86         }
87 
88         // Used by CTS tests.
createNNTestBase(boolean useNNAPI, boolean enableIntermediateTensorsDump)89         public NNTestBase createNNTestBase(boolean useNNAPI, boolean enableIntermediateTensorsDump) {
90             TfLiteBackend tfLiteBackend = useNNAPI ? TfLiteBackend.NNAPI : TfLiteBackend.CPU;
91             return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump,
92                 /*mmapModel=*/false, /*useNnApiSl=*/false, /*extractNnApiSl=*/false, /*nnApiSlVendor=*/"");
93         }
94 
createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump, boolean mmapModel, boolean useNnApiSl, boolean extractNnApiSl, String nnApiSlVendor)95         public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump,
96                 boolean mmapModel, boolean useNnApiSl, boolean extractNnApiSl, String nnApiSlVendor) {
97             NNTestBase test = createNNTestBase();
98             test.setTfLiteBackend(tfLiteBackend);
99             test.enableIntermediateTensorsDump(enableIntermediateTensorsDump);
100             test.setMmapModel(mmapModel);
101             test.setUseNnApiSupportLibrary(useNnApiSl);
102             test.setExtractNnApiSupportLibrary(extractNnApiSl);
103             test.setNnApiSupportLibraryVendor(nnApiSlVendor);
104             return test;
105         }
106 
toString()107         public String toString() {
108             return mModelName;
109         }
110 
getTestName()111         public String getTestName() {
112             return mTestName;
113         }
114 
115 
withDisabledEvaluation()116         public TestModelEntry withDisabledEvaluation() {
117             return new TestModelEntry(mModelName, mBaselineSec, mInputShape, mInOutAssets,
118                     mInOutDatasets, mTestName, mModelFile,
119                     null, // Disable evaluation.
120                     mMinSdkVersion, mInDataSize);
121         }
122     }
123 
124     static private final List<TestModelEntry> sTestModelEntryList = new ArrayList<>();
125     static private final AtomicReference<List<TestModelEntry>> frozenEntries =
126             new AtomicReference<>(null);
127 
128 
129     /** Add new benchmark model. */
registerModel(TestModelEntry model)130     static public void registerModel(TestModelEntry model) {
131         if (frozenEntries.get() != null) {
132             throw new IllegalStateException("Can't register new models after its list is frozen");
133         }
134         sTestModelEntryList.add(model);
135     }
136 
isListFrozen()137     public static boolean isListFrozen() {
138         return frozenEntries.get() != null;
139     }
140 
141     static final String MODEL_FILTER_PROPERTY = "nnBenchmarkModelFilter";
142 
getModelFilterRegex()143     public static String getModelFilterRegex() {
144         // All instrumentation arguments are passed as String so I have to convert the value here.
145         return InstrumentationRegistry.getArguments().getString(MODEL_FILTER_PROPERTY, "");
146     }
147 
148     /**
149      * Returns the list of models eventually by a user specified instrumentation filter regex.
150      */
modelsList()151     static public List<TestModelEntry> modelsList() {
152         return modelsList(getModelFilterRegex());
153     }
154 
155     /**
156      * Returns the list of models eventually by a user specified instrumentation filter.
157      */
modelsList(String modelFilterRegex)158     static public List<TestModelEntry> modelsList(String modelFilterRegex) {
159         if (modelFilterRegex == null || modelFilterRegex.isEmpty()) {
160             Log.i("NN_BENCHMARK", "No model filter, returning all models");
161             return fullModelsList();
162         }
163         Log.i("NN_BENCHMARK", "Filtering model with filter " + modelFilterRegex);
164         List<TestModelEntry> result = fullModelsList().stream()
165                 .filter( modelEntry ->
166                     modelEntry.mModelName.matches(modelFilterRegex)
167                 )
168                 .collect(Collectors.toList());
169 
170         Log.i("NN_BENCHMARK", "Returning models: " + result);
171 
172         return result;
173     }
174 
175     /**
176      * Fetch list of test models.
177      *
178      * If this method was called at least once, then it's impossible to register new models.
179      */
fullModelsList()180     static public List<TestModelEntry> fullModelsList() {
181         frozenEntries.compareAndSet(null, sTestModelEntryList);
182         return frozenEntries.get();
183     }
184 
185     /** Fetch model by its name. */
getModelByName(String name)186     static public TestModelEntry getModelByName(String name) {
187         for (TestModelEntry testModelEntry : modelsList()) {
188             if (testModelEntry.mModelName.equals(name)) {
189                 return testModelEntry;
190             }
191         }
192         throw new IllegalArgumentException("Unknown TestModelEntry named " + name);
193     }
194 
195 }
196