/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ShimPreparedModel.h"

#include <aidl/android/hardware/neuralnetworks/BnBurst.h>
#include <aidl/android/hardware/neuralnetworks/BnExecution.h>
#include <aidl/android/hardware/neuralnetworks/BnFencedExecutionCallback.h>
#include <aidl/android/hardware/neuralnetworks/ErrorStatus.h>
#include <aidl/android/hardware/neuralnetworks/OutputShape.h>
#include <aidl/android/hardware/neuralnetworks/RequestMemoryPool.h>
#include <android-base/chrono_utils.h>
#include <android-base/logging.h>
#include <android-base/scopeguard.h>
#include <android/binder_auto_utils.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/hal/aidl/Conversions.h>
#include <nnapi/hal/aidl/Utils.h>

#include <algorithm>
#include <chrono>
#include <limits>
#include <memory>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ShimConverter.h"
#include "ShimUtils.h"

namespace aidl::android::hardware::neuralnetworks {

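// Translates an AIDL Request into calls on an SL Execution: imports the request memory pools,
// binds each input/output operand (honoring dynamic dimensions and omitted operands), and applies
// the measurement, deadline, loop-timeout, and extension-attribute settings. Returns
// ErrorStatus::NONE on success, or the first error encountered.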
ErrorStatus ShimPreparedModel::parseInputs(
        const Request& request, bool measure, int64_t deadlineNs, int64_t loopTimeoutDurationNs,
        ::android::nn::sl_wrapper::Execution* execution,
        std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>>* requestMemoryPools,
        const std::vector<TokenValuePair>& executionHints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
    for (const auto& requestPool : request.pools) {
        switch (requestPool.getTag()) {
            case RequestMemoryPool::pool: {
                const auto& memoryPool = requestPool.get<RequestMemoryPool::pool>();
                std::shared_ptr<::android::nn::sl_wrapper::Memory> mem =
                        convertFromHAL(mNnapi.get(), memoryPool);
                if (!mem) {
                    LOG(ERROR) << "Failed to convert request HAL memory pools into SL memory";
                    return ErrorStatus::INVALID_ARGUMENT;
                }

                requestMemoryPools->push_back(mem);
                break;
            }
            case RequestMemoryPool::token: {
                int token = requestPool.get<RequestMemoryPool::token>();

                auto memory = mBufferTracker->get(static_cast<uint32_t>(token));
                if (memory == nullptr) {
                    return ErrorStatus::INVALID_ARGUMENT;
                }

                requestMemoryPools->push_back(memory);
                break;
            }
        }
    }

    // enable input and output padding
    const auto enablePaddingResult = execution->enableInputAndOutputPadding(true);
    if (enablePaddingResult != Result::NO_ERROR) {
        return convertResultToErrorStatus(enablePaddingResult);
    }

    const auto& model = mMainAndReferencedModels[0];

    if (request.inputs.size() > model.getInputs().size()) {
        return ErrorStatus::INVALID_ARGUMENT;
    }

    // set inputs
    for (int i = 0; i < request.inputs.size(); ++i) {
        const auto& input = request.inputs[i];
        ::android::nn::wrapper::OperandType operandType = model.getOperands()[model.getInputs()[i]];
        if (!input.hasNoValue) {
            if (input.dimensions.size() > 0) {
                operandType.updateDimensions(::android::nn::toUnsigned(input.dimensions).value());
            }
            auto result = execution->setInputFromMemory(
                    i, requestMemoryPools->at(input.location.poolIndex).get(),
                    input.location.offset, input.location.length, &operandType.operandType);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        } else {
            auto result = execution->setInput(i, nullptr, 0);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    if (request.outputs.size() > model.getOutputs().size()) {
        return ErrorStatus::INVALID_ARGUMENT;
    }
    // set outputs
    for (int i = 0; i < request.outputs.size(); ++i) {
        const auto& output = request.outputs[i];
        ::android::nn::wrapper::OperandType operandType =
                model.getOperands()[model.getOutputs()[i]];

        if (!output.hasNoValue) {
            if (output.dimensions.size() > 0) {
                operandType.updateDimensions(::android::nn::toUnsigned(output.dimensions).value());
            }
            auto result = execution->setOutputFromMemory(
                    i, requestMemoryPools->at(output.location.poolIndex).get(),
                    output.location.offset, output.location.length, &operandType.operandType);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        } else {
            auto result = execution->setOutput(i, nullptr, 0);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    if (measure) {
        execution->setMeasureTiming(true);
    }

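    // An absolute deadline (nanoseconds on the boot clock) is converted into a relative timeout
    // for the SL execution; a deadline that has already passed fails fast with
    // MISSED_DEADLINE_TRANSIENT.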
    if (deadlineNs > -1) {
        std::chrono::time_point<::android::base::boot_clock> deadlinePoint(
                std::chrono::nanoseconds{deadlineNs});
        const auto currentTime = ::android::base::boot_clock::now();
        const auto timeoutDuration = std::chrono::nanoseconds(deadlinePoint - currentTime);
        if (timeoutDuration <= std::chrono::nanoseconds::zero()) {
            return ErrorStatus::MISSED_DEADLINE_TRANSIENT;
        } else {
            auto result = execution->setTimeout(std::max<uint64_t>(1, timeoutDuration.count()));
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    if (loopTimeoutDurationNs > 0) {
        execution->setLoopTimeout(loopTimeoutDurationNs);
    }

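    // Each execution hint token encodes an extension prefix in its high bits and an attribute
    // code within that extension in its low bits; map the prefix back to the extension name
    // before forwarding the attribute to the SL execution.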
    if (!executionHints.empty() || !extensionNameToPrefix.empty()) {
        std::unordered_map<uint16_t, std::string> prefixToName;
        for (const auto [name, prefix] : extensionNameToPrefix) {
            prefixToName.emplace(prefix, name);
        }

        for (const auto& [token, value] : executionHints) {
            const auto uToken = static_cast<uint32_t>(token);
            const auto prefix = ::android::nn::getExtensionPrefix(uToken);
            const auto attributeCodeWithinExtension = ::android::nn::getTypeWithinExtension(uToken);

            const auto it = prefixToName.find(prefix);
            if (it == prefixToName.end()) {
                return ErrorStatus::INVALID_ARGUMENT;
            }
            const std::string& extensionName = it->second;

            const auto result = execution->addExtensionAttribute(
                    extensionName, attributeCodeWithinExtension, value);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    return ErrorStatus::NONE;
}

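// Callback returned from fenced executions. It keeps the SL execution and its memory pools alive
// until the signaling event fires, then reports the execution status and (when measurement was
// requested) the launched and fenced timings.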
class ShimFencedExecutionCallback : public BnFencedExecutionCallback {
   public:
    ShimFencedExecutionCallback(
            std::shared_ptr<::android::nn::sl_wrapper::Execution> execution, Event e,
            std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> memoryPools,
            bool measureTiming)
        : mMemoryPools(std::move(memoryPools)),
          mExecution(std::move(execution)),
          mEvent(std::move(e)),
          mMeasureTiming(measureTiming) {}

    ndk::ScopedAStatus getExecutionInfo(Timing* timingLaunched, Timing* timingFenced,
                                        ErrorStatus* errorStatus) override {
        auto status = mEvent.wait();
        *errorStatus = convertResultToErrorStatus(status);

        if (mMeasureTiming) {
            uint64_t duration;
            constexpr int64_t int64cap = std::numeric_limits<int64_t>::max();
            // Special value used for "no measurements"
            constexpr uint64_t uint64cap = std::numeric_limits<uint64_t>::max();
            auto result = mExecution->getDuration(Duration::ON_HARDWARE, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingLaunched->timeOnDeviceNs = (duration == uint64cap) ? -1
                                             : (duration > int64cap)
                                                     ? int64cap
                                                     : static_cast<int64_t>(duration);

            result = mExecution->getDuration(Duration::IN_DRIVER, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingLaunched->timeInDriverNs = (duration == uint64cap) ? -1
                                             : (duration > int64cap)
                                                     ? int64cap
                                                     : static_cast<int64_t>(duration);

            result = mExecution->getDuration(Duration::FENCED_ON_HARDWARE, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingFenced->timeOnDeviceNs = (duration == uint64cap) ? -1
                                           : (duration > int64cap) ? int64cap
                                                                   : static_cast<int64_t>(duration);

            result = mExecution->getDuration(Duration::FENCED_IN_DRIVER, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingFenced->timeInDriverNs = (duration == uint64cap) ? -1
                                           : (duration > int64cap) ? int64cap
                                                                   : static_cast<int64_t>(duration);
        } else {
            timingFenced->timeOnDeviceNs = -1;
            timingFenced->timeInDriverNs = -1;
            timingLaunched->timeOnDeviceNs = -1;
            timingLaunched->timeInDriverNs = -1;
        }

        return ndk::ScopedAStatus::ok();
    }

   private:
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> mMemoryPools;
    std::shared_ptr<::android::nn::sl_wrapper::Execution> mExecution;
    ::android::nn::wrapper::Event mEvent;
    bool mMeasureTiming;
};

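// Converts the caller-provided sync fence FDs into ANeuralNetworksEvent dependencies, launches
// the execution with startComputeWithDependencies(), and packages the resulting sync fence plus a
// ShimFencedExecutionCallback into the FencedExecutionResult.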
static ndk::ScopedAStatus executeFencedInternal(
        const std::shared_ptr<const NnApiSupportLibrary>& nnapi,
        const std::shared_ptr<::android::nn::sl_wrapper::Execution>& execution,
        std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools,
        const std::vector<ndk::ScopedFileDescriptor>& waitFor, int64_t durationNs,
        bool measureTiming, FencedExecutionResult* fencedExecutionResult) {
    CHECK(execution != nullptr);
    CHECK(fencedExecutionResult != nullptr);

    std::vector<const ANeuralNetworksEvent*> deps(waitFor.size());
    auto createResult = Result::NO_ERROR;
    std::transform(waitFor.begin(), waitFor.end(), deps.begin(),
                   [&](const ::ndk::ScopedFileDescriptor& e) {
                       ANeuralNetworksEvent* r = nullptr;
                       if (createResult == Result::NO_ERROR) {
                           createResult = static_cast<Result>(
                                   nnapi->getFL5()->ANeuralNetworksEvent_createFromSyncFenceFd(
                                           e.get(), &r));
                       }
                       return r;
                   });

    const auto guard = ::android::base::make_scope_guard([nnapi, deps] {
        for (auto& dep : deps) {
            if (dep != nullptr) {
                nnapi->getFL5()->ANeuralNetworksEvent_free(const_cast<ANeuralNetworksEvent*>(dep));
            }
        }
    });

    SLW2SAS_RETURN_IF_ERROR(createResult);

    Event e(nnapi.get());
    auto result = execution->startComputeWithDependencies(deps, durationNs, &e);
    SLW2SAS_RETURN_IF_ERROR(result);

    int syncFence = -1;
    fencedExecutionResult->syncFence = ndk::ScopedFileDescriptor(
            (e.getSyncFenceFd(&syncFence) == Result::NO_ERROR) ? syncFence : -1);
    fencedExecutionResult->callback = ndk::SharedRefBase::make<ShimFencedExecutionCallback>(
            execution, std::move(e), requestMemoryPools, measureTiming);

    return ndk::ScopedAStatus::ok();
}

::ndk::ScopedAStatus ShimPreparedModel::executeFencedCommon(
        const Request& request, const std::vector<::ndk::ScopedFileDescriptor>& waitFor,
        bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
        const std::vector<TokenValuePair>& executionHints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
        FencedExecutionResult* fencedExecutionResult) {
    CHECK(fencedExecutionResult != nullptr);

    if (deadlineNs < -1) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }
    auto execution =
            std::make_shared<::android::nn::sl_wrapper::Execution>(mNnapi.get(), &mCompilation);
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools;
    auto errorStatus =
            parseInputs(request, measureTiming, deadlineNs, loopTimeoutDurationNs, execution.get(),
                        &requestMemoryPools, executionHints, extensionNameToPrefix);
    if (errorStatus != ErrorStatus::NONE) {
        return toAStatus(errorStatus);
    }
    return executeFencedInternal(mNnapi, execution, std::move(requestMemoryPools), waitFor,
                                 durationNs, measureTiming, fencedExecutionResult);
}

::ndk::ScopedAStatus ShimPreparedModel::executeFenced(
        const ::aidl::android::hardware::neuralnetworks::Request& request,
        const std::vector<::ndk::ScopedFileDescriptor>& waitFor, bool measureTiming,
        int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
        FencedExecutionResult* fencedExecutionResult) {
    return executeFencedCommon(request, waitFor, measureTiming, deadlineNs, loopTimeoutDurationNs,
                               durationNs, /*executionHints=*/{}, /*extensionNameToPrefix=*/{},
                               fencedExecutionResult);
}

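// Runs a fully configured SL execution synchronously, collects per-output shapes (flagging
// insufficiently sized outputs), and attaches timing information when measurement was requested.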
static ndk::ScopedAStatus executeSynchronouslyInternal(
        const std::shared_ptr<::android::nn::sl_wrapper::Execution>& execution, bool measureTiming,
        int numOutputs, ExecutionResult* executionResult) {
    CHECK(execution != nullptr);
    CHECK(executionResult != nullptr);

    auto result = execution->compute();
    auto errorStatus = convertResultToErrorStatus(result);

    std::vector<OutputShape> outputShapes;
    outputShapes.reserve(numOutputs);
    bool sufficientSize = true;
    for (int i = 0; i < numOutputs; ++i) {
        OutputShape outputShape;
        std::vector<uint32_t> outputDims;
        auto result = execution->getOutputOperandDimensions(i, &outputDims);
        if (result == Result::NO_ERROR) {
            outputShape.isSufficient = true;
            outputShape.dimensions.assign(outputDims.begin(), outputDims.end());
        } else if (result == Result::OUTPUT_INSUFFICIENT_SIZE) {
            sufficientSize = false;
            outputShape.isSufficient = false;
            outputShape.dimensions.assign(outputDims.begin(), outputDims.end());
        } else {
            if (errorStatus == ErrorStatus::NONE) {
                errorStatus = ErrorStatus::GENERAL_FAILURE;
            }
        }
        outputShapes.push_back(std::move(outputShape));
    }

    int64_t timeOnDeviceNs = -1;
    int64_t timeInDriverNs = -1;
    if (measureTiming && errorStatus == ErrorStatus::NONE) {
        uint64_t duration;
        constexpr int64_t int64cap = std::numeric_limits<int64_t>::max();
        // Special value used for "no measurements"
        constexpr uint64_t uint64cap = std::numeric_limits<uint64_t>::max();
        auto result = execution->getDuration(Duration::ON_HARDWARE, &duration);
        SLW2SAS_RETURN_IF_ERROR(result);
        timeOnDeviceNs = (duration == uint64cap) ? -1
                         : (duration > int64cap) ? int64cap
                                                 : static_cast<int64_t>(duration);

        result = execution->getDuration(Duration::IN_DRIVER, &duration);
        SLW2SAS_RETURN_IF_ERROR(result);
        timeInDriverNs = (duration == uint64cap) ? -1
                         : (duration > int64cap) ? int64cap
                                                 : static_cast<int64_t>(duration);
    }

    *executionResult =
            ExecutionResult{sufficientSize,
                            std::move(outputShapes),
                            {.timeOnDeviceNs = timeOnDeviceNs, .timeInDriverNs = timeInDriverNs}};
    if (errorStatus == ErrorStatus::NONE || errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        return ndk::ScopedAStatus::ok();
    }
    return toAStatus(errorStatus);
}

::ndk::ScopedAStatus ShimPreparedModel::executeSynchronouslyCommon(
        const Request& request, bool measureTiming, int64_t deadlineNs,
        int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& executionHints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
        ExecutionResult* executionResult) {
    CHECK(executionResult != nullptr);

    if (deadlineNs < -1) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    auto execution =
            std::make_shared<::android::nn::sl_wrapper::Execution>(mNnapi.get(), &mCompilation);
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools;
    auto errorStatus =
            parseInputs(request, measureTiming, deadlineNs, loopTimeoutDurationNs, execution.get(),
                        &requestMemoryPools, executionHints, extensionNameToPrefix);
    if (errorStatus != ErrorStatus::NONE) {
        return toAStatus(errorStatus);
    }
    return executeSynchronouslyInternal(execution, measureTiming, request.outputs.size(),
                                        executionResult);
}

::ndk::ScopedAStatus ShimPreparedModel::executeSynchronously(
        const Request& request, bool measureTiming, int64_t deadlineNs,
        int64_t loopTimeoutDurationNs,
        ::aidl::android::hardware::neuralnetworks::ExecutionResult* executionResult) {
    return executeSynchronouslyCommon(request, measureTiming, deadlineNs, loopTimeoutDurationNs,
                                      /*executionHints=*/{}, /*extensionNameToPrefix=*/{},
                                      executionResult);
}

::ndk::ScopedAStatus ShimPreparedModel::executeSynchronouslyWithConfig(
        const Request& request, const ExecutionConfig& config, int64_t deadlineNs,
        ExecutionResult* executionResult) {
    return executeSynchronouslyCommon(request, config.measureTiming, deadlineNs,
                                      config.loopTimeoutDurationNs, config.executionHints,
                                      config.extensionNameToPrefix, executionResult);
}

::ndk::ScopedAStatus ShimPreparedModel::executeFencedWithConfig(
        const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
        const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs,
        FencedExecutionResult* executionResult) {
    return executeFencedCommon(request, waitFor, config.measureTiming, deadlineNs,
                               config.loopTimeoutDurationNs, durationNs, config.executionHints,
                               config.extensionNameToPrefix, executionResult);
}

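// Burst object that serializes executions: each call forwards to the owning ShimPreparedModel,
// and the memory identifier tokens are validated but not otherwise used (see TODO below).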
// TODO(183397380): make it use ANNBurst object
class ShimBurst : public BnBurst {
   public:
    // Precondition: preparedModel != nullptr
    explicit ShimBurst(std::shared_ptr<ShimPreparedModel> preparedModel);

    ndk::ScopedAStatus executeSynchronously(const Request& request,
                                            const std::vector<int64_t>& memoryIdentifierTokens,
                                            bool measureTiming, int64_t deadlineNs,
                                            int64_t loopTimeoutDurationNs,
                                            ExecutionResult* executionResult) override;
    ndk::ScopedAStatus executeSynchronouslyWithConfig(
            const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
            const ExecutionConfig& config, int64_t deadlineNs,
            ExecutionResult* executionResult) override;
    ndk::ScopedAStatus releaseMemoryResource(int64_t memoryIdentifierToken) override;

   protected:
    std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
    const std::shared_ptr<ShimPreparedModel> kPreparedModel;
};

ndk::ScopedAStatus ShimPreparedModel::configureExecutionBurst(std::shared_ptr<IBurst>* burst) {
    std::shared_ptr<ShimPreparedModel> self = this->template ref<ShimPreparedModel>();
    *burst = ndk::SharedRefBase::make<ShimBurst>(std::move(self));
    return ndk::ScopedAStatus::ok();
}

ShimBurst::ShimBurst(std::shared_ptr<ShimPreparedModel> preparedModel)
    : kPreparedModel(std::move(preparedModel)) {
    CHECK(kPreparedModel != nullptr);
}

ndk::ScopedAStatus ShimBurst::executeSynchronously(
        const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
        bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs,
        ExecutionResult* executionResult) {
    if (request.pools.size() != memoryIdentifierTokens.size()) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT,
                         "request.pools.size() != memoryIdentifierTokens.size()");
    }
    if (!std::all_of(memoryIdentifierTokens.begin(), memoryIdentifierTokens.end(),
                     [](int64_t token) { return token >= -1; })) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierTokens");
    }

    // Ensure at most one execution is in flight at a time.
    const bool executionAlreadyInFlight = mExecutionInFlight.test_and_set();
    if (executionAlreadyInFlight) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Burst object supports at most one execution at a time");
    }
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return kPreparedModel->executeSynchronously(request, measureTiming, deadlineNs,
                                                loopTimeoutDurationNs, executionResult);
}

ndk::ScopedAStatus ShimBurst::executeSynchronouslyWithConfig(
        const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
        const ExecutionConfig& config, int64_t deadlineNs, ExecutionResult* executionResult) {
    if (request.pools.size() != memoryIdentifierTokens.size()) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT,
                         "request.pools.size() != memoryIdentifierTokens.size()");
    }
    if (!std::all_of(memoryIdentifierTokens.begin(), memoryIdentifierTokens.end(),
                     [](int64_t token) { return token >= -1; })) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierTokens");
    }

    // Ensure at most one execution is in flight at a time.
    const bool executionAlreadyInFlight = mExecutionInFlight.test_and_set();
    if (executionAlreadyInFlight) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Burst object supports at most one execution at a time");
    }
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return kPreparedModel->executeSynchronouslyWithConfig(request, config, deadlineNs,
                                                          executionResult);
}

ndk::ScopedAStatus ShimBurst::releaseMemoryResource(int64_t memoryIdentifierToken) {
    if (memoryIdentifierToken < -1) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierToken");
    }
    return ndk::ScopedAStatus::ok();
}

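// Reusable execution object: the SL execution is configured once from the request (via
// parseInputs) and marked reusable, after which executeSynchronously()/executeFenced() may be
// invoked repeatedly, one call at a time.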
class ShimExecution : public BnExecution {
   public:
    explicit ShimExecution(
            std::shared_ptr<const NnApiSupportLibrary> nnapi,
            std::shared_ptr<::android::nn::sl_wrapper::Execution> execution,
            std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools,
            bool measureTiming, int numberOfOutputs);

    ndk::ScopedAStatus executeSynchronously(int64_t deadlineNs,
                                            ExecutionResult* executionResult) override;
    ndk::ScopedAStatus executeFenced(const std::vector<ndk::ScopedFileDescriptor>& waitFor,
                                     int64_t deadlineNs, int64_t durationNs,
                                     FencedExecutionResult* fencedExecutionResult) override;

   protected:
    std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
    std::shared_ptr<const NnApiSupportLibrary> mNnapi;
    std::shared_ptr<::android::nn::sl_wrapper::Execution> mExecution;
    const std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> kRequestMemoryPools;
    const bool kMeasureTiming;
    const int kNumberOfOutputs;
};

ndk::ScopedAStatus ShimPreparedModel::createReusableExecution(
        const Request& request, const ExecutionConfig& config,
        std::shared_ptr<IExecution>* execution) {
    auto wrapperExecution =
            std::make_shared<::android::nn::sl_wrapper::Execution>(mNnapi.get(), &mCompilation);
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools;
    auto errorStatus =
            parseInputs(request, config.measureTiming, kNoDeadline, config.loopTimeoutDurationNs,
                        wrapperExecution.get(), &requestMemoryPools, config.executionHints,
                        config.extensionNameToPrefix);
    if (errorStatus != ErrorStatus::NONE) {
        return toAStatus(errorStatus);
    }
    auto result = wrapperExecution->setReusable(true);
    SLW2SAS_RETURN_IF_ERROR(result);

    *execution = ndk::SharedRefBase::make<ShimExecution>(
            mNnapi, std::move(wrapperExecution), std::move(requestMemoryPools),
            config.measureTiming, request.outputs.size());
    return ndk::ScopedAStatus::ok();
}

ShimExecution::ShimExecution(
        std::shared_ptr<const NnApiSupportLibrary> nnapi,
        std::shared_ptr<::android::nn::sl_wrapper::Execution> execution,
        std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools,
        bool measureTiming, int numberOfOutputs)
    : mNnapi(std::move(nnapi)),
      mExecution(std::move(execution)),
      kRequestMemoryPools(std::move(requestMemoryPools)),
      kMeasureTiming(measureTiming),
      kNumberOfOutputs(numberOfOutputs) {}

ndk::ScopedAStatus ShimExecution::executeSynchronously(int64_t deadlineNs,
                                                       ExecutionResult* executionResult) {
    if (deadlineNs < -1) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    // Ensure at most one execution is in flight at a time.
    const bool executionAlreadyInFlight = mExecutionInFlight.test_and_set();
    if (executionAlreadyInFlight) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Execution object supports at most one execution at a time");
    }
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return executeSynchronouslyInternal(mExecution, kMeasureTiming, kNumberOfOutputs,
                                        executionResult);
}

ndk::ScopedAStatus ShimExecution::executeFenced(
        const std::vector<ndk::ScopedFileDescriptor>& waitFor, int64_t deadlineNs,
        int64_t durationNs, FencedExecutionResult* fencedExecutionResult) {
    if (deadlineNs < -1) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    // Ensure at most one execution is in flight at a time.
    const bool executionAlreadyInFlight = mExecutionInFlight.test_and_set();
    if (executionAlreadyInFlight) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Execution object supports at most one execution at a time");
    }
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return executeFencedInternal(mNnapi, mExecution, kRequestMemoryPools, waitFor, durationNs,
                                 kMeasureTiming, fencedExecutionResult);
}

}  // namespace aidl::android::hardware::neuralnetworks