1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.crashtest.core.test; 18 19 import static java.util.concurrent.TimeUnit.MILLISECONDS; 20 21 import android.annotation.SuppressLint; 22 import android.content.Context; 23 import android.content.Intent; 24 import android.util.Log; 25 26 import com.android.nn.benchmark.core.Processor; 27 import com.android.nn.crashtest.core.CrashTest; 28 import com.android.nn.crashtest.core.CrashTestCoordinator.CrashTestIntentInitializer; 29 import com.android.nn.benchmark.core.TfLiteBackend; 30 31 import java.time.Duration; 32 import java.util.ArrayList; 33 import java.util.Collections; 34 import java.util.HashSet; 35 import java.util.List; 36 import java.util.Optional; 37 import java.util.Set; 38 import java.util.concurrent.CountDownLatch; 39 import java.util.concurrent.ExecutionException; 40 import java.util.concurrent.ExecutorService; 41 import java.util.concurrent.Executors; 42 import java.util.concurrent.Future; 43 44 public class RunModelsInParallel implements CrashTest { 45 46 private static final String MODELS = "models"; 47 private static final String DURATION = "duration"; 48 private static final String THREADS = "thread_counts"; 49 private static final String TEST_NAME = "test_name"; 50 private static final String ACCELERATOR_NAME = "accelerator_name"; 51 private static final String IGNORE_UNSUPPORTED_MODELS = "ignore_unsupported_models"; 52 private static final String RUN_MODEL_COMPILATION_ONLY = "run_model_compilation_only"; 53 private static final String MEMORY_MAP_MODEL = "memory_map_model"; 54 private static final String MODEL_FILTER = "model_filter"; 55 private static final String USE_NNAPI_SL = "use_nnapi_sl"; 56 private static final String EXTRACT_NNAPI_SL = "extract_nnapi_sl"; 57 58 private final Set<Processor> activeTests = new HashSet<>(); 59 private final List<Boolean> mTestCompletionResults = Collections.synchronizedList( 60 new ArrayList<>()); 61 private long mTestDurationMillis = 0; 62 private int mThreadCount = 0; 63 private int[] mTestList = new int[0]; 64 private String mTestName; 65 private String mAcceleratorName; 66 private boolean mIgnoreUnsupportedModels; 67 private Context mContext; 68 private boolean mRunModelCompilationOnly; 69 private ExecutorService mExecutorService = null; 70 private CountDownLatch mParallelTestComplete; 71 private ProgressListener mProgressListener; 72 private boolean mMmapModel; 73 private boolean mUseNnapiSl; 74 private boolean mExtractNnapiSl; 75 intentInitializer(int[] models, int threadCount, Duration duration, String testName, String acceleratorName, boolean ignoreUnsupportedModels, boolean runModelCompilationOnly, boolean mmapModel, String modelFilter, boolean useNnapiSl, boolean extractNnapiSl)76 static public CrashTestIntentInitializer intentInitializer(int[] models, int threadCount, 77 Duration duration, String testName, String acceleratorName, 78 boolean ignoreUnsupportedModels, 79 boolean runModelCompilationOnly, boolean mmapModel, String modelFilter, boolean useNnapiSl, 80 boolean extractNnapiSl) { 81 return intent -> { 82 intent.putExtra(MODELS, models); 83 intent.putExtra(DURATION, duration.toMillis()); 84 intent.putExtra(THREADS, threadCount); 85 intent.putExtra(TEST_NAME, testName); 86 intent.putExtra(ACCELERATOR_NAME, acceleratorName); 87 intent.putExtra(IGNORE_UNSUPPORTED_MODELS, ignoreUnsupportedModels); 88 intent.putExtra(RUN_MODEL_COMPILATION_ONLY, runModelCompilationOnly); 89 intent.putExtra(MEMORY_MAP_MODEL, mmapModel); 90 intent.putExtra(MODEL_FILTER, modelFilter); 91 intent.putExtra(USE_NNAPI_SL, useNnapiSl); 92 intent.putExtra(EXTRACT_NNAPI_SL, extractNnapiSl); 93 }; 94 } 95 96 @Override init(Context context, Intent configParams, Optional<ProgressListener> progressListener)97 public void init(Context context, Intent configParams, 98 Optional<ProgressListener> progressListener) { 99 mTestList = configParams.getIntArrayExtra(MODELS); 100 mThreadCount = configParams.getIntExtra(THREADS, 10); 101 mTestDurationMillis = configParams.getLongExtra(DURATION, 1000 * 60 * 10); 102 mTestName = configParams.getStringExtra(TEST_NAME); 103 mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME); 104 mIgnoreUnsupportedModels = mAcceleratorName != null && configParams.getBooleanExtra( 105 IGNORE_UNSUPPORTED_MODELS, false); 106 mRunModelCompilationOnly = configParams.getBooleanExtra(RUN_MODEL_COMPILATION_ONLY, false); 107 mMmapModel = configParams.getBooleanExtra(MEMORY_MAP_MODEL, false); 108 mUseNnapiSl = configParams.getBooleanExtra(USE_NNAPI_SL, false); 109 mExtractNnapiSl = configParams.getBooleanExtra(EXTRACT_NNAPI_SL, false); 110 mContext = context; 111 mProgressListener = progressListener.orElseGet(() -> (Optional<String> message) -> { 112 Log.v(CrashTest.TAG, message.orElse(".")); 113 }); 114 mExecutorService = Executors.newFixedThreadPool(mThreadCount); 115 mTestCompletionResults.clear(); 116 } 117 118 @Override call()119 public Optional<String> call() { 120 mParallelTestComplete = new CountDownLatch(mThreadCount); 121 for (int i = 0; i < mThreadCount; i++) { 122 Processor testProcessor = createSubTestRunner(mTestList, i); 123 124 activeTests.add(testProcessor); 125 mExecutorService.submit(testProcessor); 126 } 127 128 return completedSuccessfully(); 129 } 130 createSubTestRunner(final int[] testList, final int testIndex)131 private Processor createSubTestRunner(final int[] testList, final int testIndex) { 132 final Processor result = new Processor(mContext, new Processor.Callback() { 133 @SuppressLint("DefaultLocale") 134 @Override 135 public void onBenchmarkFinish(boolean ok) { 136 notifyProgress("Test '%s': Benchmark #%d completed %s", mTestName, testIndex, 137 ok ? "successfully" : "with failure"); 138 mTestCompletionResults.add(ok); 139 mParallelTestComplete.countDown(); 140 } 141 142 @Override 143 public void onStatusUpdate(int testNumber, int numTests, String modelName) { 144 } 145 }, testList); 146 result.setTfLiteBackend(TfLiteBackend.NNAPI); 147 result.setCompleteInputSet(false); 148 result.setNnApiAcceleratorName(mAcceleratorName); 149 result.setIgnoreUnsupportedModels(mIgnoreUnsupportedModels); 150 result.setRunModelCompilationOnly(mRunModelCompilationOnly); 151 result.setMmapModel(mMmapModel); 152 result.setUseNnApiSupportLibrary(mUseNnapiSl); 153 result.setExtractNnApiSupportLibrary(mExtractNnapiSl); 154 return result; 155 } 156 endTests()157 private void endTests() { 158 ExecutorService terminatorsThreadPool = Executors.newFixedThreadPool(activeTests.size()); 159 List<Future<?>> terminationCommands = new ArrayList<>(); 160 for (final Processor test : activeTests) { 161 // Exit will block until the thread is completed 162 terminationCommands.add(terminatorsThreadPool.submit( 163 () -> test.exitWithTimeout(Duration.ofSeconds(20).toMillis()))); 164 } 165 terminationCommands.forEach(terminationCommand -> { 166 try { 167 terminationCommand.get(); 168 } catch (ExecutionException e) { 169 Log.w(TAG, "Failure while waiting for completion of tests", e); 170 } catch (InterruptedException e) { 171 Thread.interrupted(); 172 } 173 }); 174 } 175 176 @SuppressLint("DefaultLocale") notifyProgress(String messageFormat, Object... args)177 void notifyProgress(String messageFormat, Object... args) { 178 mProgressListener.testProgress(Optional.of(String.format(messageFormat, args))); 179 } 180 181 // This method blocks until the tests complete and returns true if all tests completed 182 // successfully 183 @SuppressLint("DefaultLocale") completedSuccessfully()184 private Optional<String> completedSuccessfully() { 185 try { 186 boolean testsEnded = mParallelTestComplete.await(mTestDurationMillis, MILLISECONDS); 187 if (!testsEnded) { 188 Log.i(TAG, 189 String.format( 190 "Test '%s': Tests are not completed (they might have been " 191 + "designed to run " 192 + "indefinitely. Forcing termination.", mTestName)); 193 endTests(); 194 } 195 } catch (InterruptedException ignored) { 196 Thread.currentThread().interrupt(); 197 } 198 199 final long failedTestCount = mTestCompletionResults.stream().filter( 200 testResult -> !testResult).count(); 201 if (failedTestCount > 0) { 202 String failureMsg = String.format("Test '%s': %d out of %d test failed", mTestName, 203 failedTestCount, 204 mTestCompletionResults.size()); 205 Log.w(CrashTest.TAG, failureMsg); 206 return failure(failureMsg); 207 } else { 208 Log.i(CrashTest.TAG, 209 String.format("Test '%s': Test completed successfully", mTestName)); 210 return success(); 211 } 212 } 213 } 214