1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.app; 18 19 import android.content.Context; 20 import android.util.Log; 21 22 import androidx.test.InstrumentationRegistry; 23 24 import com.android.nn.benchmark.core.BenchmarkException; 25 import com.android.nn.benchmark.core.BenchmarkResult; 26 import com.android.nn.benchmark.core.NNTestBase; 27 import com.android.nn.benchmark.core.NnApiDelegationFailure; 28 import com.android.nn.benchmark.core.Processor; 29 import com.android.nn.benchmark.core.TestModels; 30 import com.android.nn.benchmark.core.TfLiteBackend; 31 32 import java.io.IOException; 33 import java.util.ArrayList; 34 import java.util.Arrays; 35 import java.util.Collections; 36 import java.util.List; 37 import java.util.Optional; 38 import java.util.concurrent.Callable; 39 import java.util.concurrent.atomic.AtomicBoolean; 40 import java.util.stream.Collectors; 41 42 public interface AcceleratorSpecificTestSupport { 43 String TAG = "AcceleratorTest"; 44 findTestModelRunningOnAccelerator( Context context, String acceleratorName)45 static Optional<TestModels.TestModelEntry> findTestModelRunningOnAccelerator( 46 Context context, String acceleratorName) throws NnApiDelegationFailure { 47 for (TestModels.TestModelEntry model : TestModels.modelsList()) { 48 if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) { 49 return Optional.of(model); 50 } 51 } 52 return Optional.empty(); 53 } 54 findAllTestModelsRunningOnAccelerator( Context context, String acceleratorName)55 static List<TestModels.TestModelEntry> findAllTestModelsRunningOnAccelerator( 56 Context context, String acceleratorName) throws NnApiDelegationFailure { 57 List<TestModels.TestModelEntry> result = new ArrayList<>(); 58 for (TestModels.TestModelEntry model : TestModels.modelsList()) { 59 if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) { 60 result.add(model); 61 } 62 } 63 return result; 64 } 65 ramdomInRange(long min, long max)66 default long ramdomInRange(long min, long max) { 67 return min + (long) (Math.random() * (max - min)); 68 } 69 getTestParameter(String key, String defaultValue)70 static String getTestParameter(String key, String defaultValue) { 71 return InstrumentationRegistry.getArguments().getString(key, defaultValue); 72 } 73 getBooleanTestParameter(String key, boolean defaultValue)74 static boolean getBooleanTestParameter(String key, boolean defaultValue) { 75 // All instrumentation arguments are passed as String so I have to convert the value here. 76 return Boolean.parseBoolean( 77 InstrumentationRegistry.getArguments().getString(key, "" + defaultValue)); 78 } 79 80 static final String ACCELERATOR_FILTER_PROPERTY = "nnCrashtestDeviceFilter"; 81 static final String INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY = 82 "nnCrashtestIncludeNnapiReference"; 83 getTargetAcceleratorNames()84 static List<String> getTargetAcceleratorNames() { 85 List<String> accelerators = new ArrayList<>(); 86 String acceleratorFilter = getTestParameter(ACCELERATOR_FILTER_PROPERTY, ".+"); 87 accelerators.addAll(NNTestBase.availableAcceleratorNames().stream().filter( 88 name -> name.matches(acceleratorFilter)).collect( 89 Collectors.toList())); 90 if (getBooleanTestParameter(INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY, false)) { 91 accelerators.add(null); // running tests with no specified target accelerator too 92 } 93 return accelerators; 94 } 95 96 // This method returns an empty list if no accelerator name has been specified. getOptionalTargetAcceleratorNames()97 static List<String> getOptionalTargetAcceleratorNames() { 98 List<String> accelerators = new ArrayList<>(); 99 String acceleratorFilter = getTestParameter(ACCELERATOR_FILTER_PROPERTY, ""); 100 if (acceleratorFilter.isEmpty()) { 101 return Collections.emptyList(); 102 } 103 accelerators.addAll(NNTestBase.availableAcceleratorNames().stream().filter( 104 name -> name.matches(acceleratorFilter)).collect( 105 Collectors.toList())); 106 if (getBooleanTestParameter(INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY, false)) { 107 accelerators.add(null); // running tests with no specified target accelerator too 108 } 109 return accelerators; 110 } 111 perAcceleratorTestConfig(List<Object[]> testConfig)112 static List<Object[]> perAcceleratorTestConfig(List<Object[]> testConfig) { 113 return testConfig.stream() 114 .flatMap(currConfigurationParams -> getTargetAcceleratorNames().stream().map( 115 accelerator -> { 116 Object[] result = 117 Arrays.copyOf(currConfigurationParams, 118 currConfigurationParams.length + 1); 119 result[currConfigurationParams.length] = accelerator; 120 return result; 121 })) 122 .collect(Collectors.toList()); 123 } 124 125 // Generates a per-accelerator list of test configurations if an accelerator filter has been 126 // specified. Will return the origin list with an extra `null` parameter for the accelerator 127 // name if not. maybeAddAcceleratorsToTestConfig(List<Object[]> testConfig)128 static List<Object[]> maybeAddAcceleratorsToTestConfig(List<Object[]> testConfig) { 129 return testConfig.stream() 130 .flatMap(currConfigurationParams -> { 131 List<String> accelerators = getOptionalTargetAcceleratorNames(); 132 if (accelerators.isEmpty()) { 133 accelerators = Collections.singletonList((String)null); 134 } 135 return accelerators.stream().map( 136 accelerator -> { 137 Object[] result = 138 Arrays.copyOf(currConfigurationParams, 139 currConfigurationParams.length + 1); 140 result[currConfigurationParams.length] = accelerator; 141 return result; 142 }); 143 }) 144 .collect(Collectors.toList()); 145 } 146 147 class DriverLivenessChecker implements Callable<Boolean> { 148 final Processor mProcessor; 149 private final AtomicBoolean mRun = new AtomicBoolean(true); 150 private final TestModels.TestModelEntry mTestModelEntry; 151 152 public DriverLivenessChecker(Context context, String acceleratorName, 153 TestModels.TestModelEntry testModelEntry) { 154 mProcessor = new Processor(context, 155 new Processor.Callback() { 156 @Override 157 public void onBenchmarkFinish(boolean ok) { 158 } 159 160 @Override 161 public void onStatusUpdate(int testNumber, int numTests, String modelName) { 162 } 163 }, new int[0]); 164 mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI); 165 mProcessor.setCompleteInputSet(false); 166 mProcessor.setNnApiAcceleratorName(acceleratorName); 167 mProcessor.setUseNnApiSupportLibrary(NNTestBase.shouldUseNnApiSupportLibrary()); 168 mProcessor.setExtractNnApiSupportLibrary(NNTestBase.shouldExtractNnApiSupportLibrary()); 169 mProcessor.setNnApiSupportLibraryVendor(NNTestBase.getNnApiSupportLibraryVendor()); 170 mTestModelEntry = testModelEntry; 171 } 172 173 public void stop() { 174 mRun.set(false); 175 } 176 177 @Override 178 public Boolean call() throws Exception { 179 while (mRun.get()) { 180 try { 181 BenchmarkResult modelExecutionResult = mProcessor.getInstrumentationResult( 182 mTestModelEntry, 0, 3); 183 if (modelExecutionResult.hasBenchmarkError()) { 184 Log.e(TAG, String.format("Benchmark failed with message %s", 185 modelExecutionResult.getBenchmarkError())); 186 return false; 187 } 188 } catch (IOException | BenchmarkException e) { 189 Log.e(TAG, String.format("Error running model %s", mTestModelEntry.mModelName), e); 190 return false; 191 } 192 } 193 194 return true; 195 } 196 } 197 } 198