1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.app;
18 
19 import android.content.Context;
20 import android.util.Log;
21 
22 import androidx.test.InstrumentationRegistry;
23 
24 import com.android.nn.benchmark.core.BenchmarkException;
25 import com.android.nn.benchmark.core.BenchmarkResult;
26 import com.android.nn.benchmark.core.NNTestBase;
27 import com.android.nn.benchmark.core.NnApiDelegationFailure;
28 import com.android.nn.benchmark.core.Processor;
29 import com.android.nn.benchmark.core.TestModels;
30 import com.android.nn.benchmark.core.TfLiteBackend;
31 
32 import java.io.IOException;
33 import java.util.ArrayList;
34 import java.util.Arrays;
35 import java.util.Collections;
36 import java.util.List;
37 import java.util.Optional;
38 import java.util.concurrent.Callable;
39 import java.util.concurrent.atomic.AtomicBoolean;
40 import java.util.stream.Collectors;
41 
42 public interface AcceleratorSpecificTestSupport {
43     String TAG = "AcceleratorTest";
44 
findTestModelRunningOnAccelerator( Context context, String acceleratorName)45     static Optional<TestModels.TestModelEntry> findTestModelRunningOnAccelerator(
46             Context context, String acceleratorName) throws NnApiDelegationFailure {
47         for (TestModels.TestModelEntry model : TestModels.modelsList()) {
48             if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) {
49                 return Optional.of(model);
50             }
51         }
52         return Optional.empty();
53     }
54 
findAllTestModelsRunningOnAccelerator( Context context, String acceleratorName)55     static List<TestModels.TestModelEntry> findAllTestModelsRunningOnAccelerator(
56             Context context, String acceleratorName) throws NnApiDelegationFailure {
57         List<TestModels.TestModelEntry> result = new ArrayList<>();
58         for (TestModels.TestModelEntry model : TestModels.modelsList()) {
59             if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) {
60                 result.add(model);
61             }
62         }
63         return result;
64     }
65 
ramdomInRange(long min, long max)66     default long ramdomInRange(long min, long max) {
67         return min + (long) (Math.random() * (max - min));
68     }
69 
getTestParameter(String key, String defaultValue)70     static String getTestParameter(String key, String defaultValue) {
71         return InstrumentationRegistry.getArguments().getString(key, defaultValue);
72     }
73 
getBooleanTestParameter(String key, boolean defaultValue)74     static boolean getBooleanTestParameter(String key, boolean defaultValue) {
75         // All instrumentation arguments are passed as String so I have to convert the value here.
76         return Boolean.parseBoolean(
77                 InstrumentationRegistry.getArguments().getString(key, "" + defaultValue));
78     }
79 
80     static final String ACCELERATOR_FILTER_PROPERTY = "nnCrashtestDeviceFilter";
81     static final String INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY =
82             "nnCrashtestIncludeNnapiReference";
83 
getTargetAcceleratorNames()84     static List<String> getTargetAcceleratorNames() {
85         List<String> accelerators = new ArrayList<>();
86         String acceleratorFilter = getTestParameter(ACCELERATOR_FILTER_PROPERTY, ".+");
87         accelerators.addAll(NNTestBase.availableAcceleratorNames().stream().filter(
88                 name -> name.matches(acceleratorFilter)).collect(
89                 Collectors.toList()));
90         if (getBooleanTestParameter(INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY, false)) {
91             accelerators.add(null); // running tests with no specified target accelerator too
92         }
93         return accelerators;
94     }
95 
96     // This method returns an empty list if no accelerator name has been specified.
getOptionalTargetAcceleratorNames()97     static List<String> getOptionalTargetAcceleratorNames() {
98         List<String> accelerators = new ArrayList<>();
99         String acceleratorFilter = getTestParameter(ACCELERATOR_FILTER_PROPERTY, "");
100         if (acceleratorFilter.isEmpty()) {
101             return Collections.emptyList();
102         }
103         accelerators.addAll(NNTestBase.availableAcceleratorNames().stream().filter(
104             name -> name.matches(acceleratorFilter)).collect(
105             Collectors.toList()));
106         if (getBooleanTestParameter(INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY, false)) {
107             accelerators.add(null); // running tests with no specified target accelerator too
108         }
109         return accelerators;
110     }
111 
perAcceleratorTestConfig(List<Object[]> testConfig)112     static List<Object[]> perAcceleratorTestConfig(List<Object[]> testConfig) {
113         return testConfig.stream()
114                 .flatMap(currConfigurationParams -> getTargetAcceleratorNames().stream().map(
115                         accelerator -> {
116                             Object[] result =
117                                     Arrays.copyOf(currConfigurationParams,
118                                             currConfigurationParams.length + 1);
119                             result[currConfigurationParams.length] = accelerator;
120                             return result;
121                         }))
122                 .collect(Collectors.toList());
123     }
124 
125     // Generates a per-accelerator list of test configurations if an accelerator filter has been
126     // specified. Will return the origin list with an extra `null` parameter for the accelerator
127     // name if not.
maybeAddAcceleratorsToTestConfig(List<Object[]> testConfig)128     static List<Object[]> maybeAddAcceleratorsToTestConfig(List<Object[]> testConfig) {
129         return testConfig.stream()
130             .flatMap(currConfigurationParams -> {
131                 List<String> accelerators = getOptionalTargetAcceleratorNames();
132                 if (accelerators.isEmpty()) {
133                     accelerators = Collections.singletonList((String)null);
134                 }
135                 return accelerators.stream().map(
136                     accelerator -> {
137                         Object[] result =
138                             Arrays.copyOf(currConfigurationParams,
139                                 currConfigurationParams.length + 1);
140                         result[currConfigurationParams.length] = accelerator;
141                         return result;
142                     });
143             })
144             .collect(Collectors.toList());
145     }
146 
147     class DriverLivenessChecker implements Callable<Boolean> {
148         final Processor mProcessor;
149         private final AtomicBoolean mRun = new AtomicBoolean(true);
150         private final TestModels.TestModelEntry mTestModelEntry;
151 
152         public DriverLivenessChecker(Context context, String acceleratorName,
153                 TestModels.TestModelEntry testModelEntry) {
154             mProcessor = new Processor(context,
155                     new Processor.Callback() {
156                         @Override
157                         public void onBenchmarkFinish(boolean ok) {
158                         }
159 
160                         @Override
161                         public void onStatusUpdate(int testNumber, int numTests, String modelName) {
162                         }
163                     }, new int[0]);
164             mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI);
165             mProcessor.setCompleteInputSet(false);
166             mProcessor.setNnApiAcceleratorName(acceleratorName);
167             mProcessor.setUseNnApiSupportLibrary(NNTestBase.shouldUseNnApiSupportLibrary());
168             mProcessor.setExtractNnApiSupportLibrary(NNTestBase.shouldExtractNnApiSupportLibrary());
169             mProcessor.setNnApiSupportLibraryVendor(NNTestBase.getNnApiSupportLibraryVendor());
170             mTestModelEntry = testModelEntry;
171         }
172 
173         public void stop() {
174             mRun.set(false);
175         }
176 
177         @Override
178         public Boolean call() throws Exception {
179             while (mRun.get()) {
180                 try {
181                     BenchmarkResult modelExecutionResult = mProcessor.getInstrumentationResult(
182                             mTestModelEntry, 0, 3);
183                     if (modelExecutionResult.hasBenchmarkError()) {
184                         Log.e(TAG, String.format("Benchmark failed with message %s",
185                                 modelExecutionResult.getBenchmarkError()));
186                         return false;
187                     }
188                 } catch (IOException | BenchmarkException e) {
189                     Log.e(TAG, String.format("Error running model %s", mTestModelEntry.mModelName), e);
190                     return false;
191                 }
192             }
193 
194             return true;
195         }
196     }
197 }
198