/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ShimPreparedModel.h"

#include <aidl/android/hardware/neuralnetworks/BnBurst.h>
#include <aidl/android/hardware/neuralnetworks/BnExecution.h>
#include <aidl/android/hardware/neuralnetworks/BnFencedExecutionCallback.h>
#include <aidl/android/hardware/neuralnetworks/ErrorStatus.h>
#include <aidl/android/hardware/neuralnetworks/OutputShape.h>
#include <aidl/android/hardware/neuralnetworks/RequestMemoryPool.h>
#include <android-base/chrono_utils.h>
#include <android-base/logging.h>
#include <android-base/scopeguard.h>
#include <android/binder_auto_utils.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/hal/aidl/Conversions.h>
#include <nnapi/hal/aidl/Utils.h>

#include <algorithm>
#include <chrono>
#include <limits>
#include <memory>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ShimConverter.h"
#include "ShimUtils.h"

namespace aidl::android::hardware::neuralnetworks {

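// Translates an AIDL Request into calls on an SL Execution object: imports the request memory
// pools, binds inputs and outputs, and applies measurement, deadline, loop-timeout, and
// extension-hint settings.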
ErrorStatus ShimPreparedModel::parseInputs(
        const Request& request, bool measure, int64_t deadlineNs, int64_t loopTimeoutDurationNs,
        ::android::nn::sl_wrapper::Execution* execution,
        std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>>* requestMemoryPools,
        const std::vector<TokenValuePair>& executionHints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
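    // Import each request pool: shared memory pools are converted into SL memory objects, while
    // token pools are resolved through the prepared model's buffer tracker.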
    for (const auto& requestPool : request.pools) {
        switch (requestPool.getTag()) {
            case RequestMemoryPool::pool: {
                const auto& memoryPool = requestPool.get<RequestMemoryPool::pool>();
                std::shared_ptr<::android::nn::sl_wrapper::Memory> mem =
                        convertFromHAL(mNnapi.get(), memoryPool);
                if (!mem) {
                    LOG(ERROR) << "Failed to convert request HAL memory pools into SL memory";
                    return ErrorStatus::INVALID_ARGUMENT;
                }

                requestMemoryPools->push_back(mem);
                break;
            }
            case RequestMemoryPool::token: {
                int token = requestPool.get<RequestMemoryPool::token>();

                auto memory = mBufferTracker->get(static_cast<uint32_t>(token));
                if (memory == nullptr) {
                    return ErrorStatus::INVALID_ARGUMENT;
                }

                requestMemoryPools->push_back(memory);
                break;
            }
        }
    }

    // enable input and output padding
    const auto enablePaddingResult = execution->enableInputAndOutputPadding(true);
    if (enablePaddingResult != Result::NO_ERROR) {
        return convertResultToErrorStatus(enablePaddingResult);
    }

    const auto& model = mMainAndReferencedModels[0];

    if (request.inputs.size() > model.getInputs().size()) {
        return ErrorStatus::INVALID_ARGUMENT;
    }

    // set inputs
    for (int i = 0; i < request.inputs.size(); ++i) {
        const auto& input = request.inputs[i];
        ::android::nn::wrapper::OperandType operandType = model.getOperands()[model.getInputs()[i]];
        if (!input.hasNoValue) {
            if (input.dimensions.size() > 0) {
                operandType.updateDimensions(::android::nn::toUnsigned(input.dimensions).value());
            }
            auto result = execution->setInputFromMemory(
                    i, requestMemoryPools->at(input.location.poolIndex).get(),
                    input.location.offset, input.location.length, &operandType.operandType);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        } else {
            auto result = execution->setInput(i, nullptr, 0);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    if (request.outputs.size() > model.getOutputs().size()) {
        return ErrorStatus::INVALID_ARGUMENT;
    }
    // set outputs
    for (int i = 0; i < request.outputs.size(); ++i) {
        const auto& output = request.outputs[i];
        ::android::nn::wrapper::OperandType operandType =
                model.getOperands()[model.getOutputs()[i]];

        if (!output.hasNoValue) {
            if (output.dimensions.size() > 0) {
                operandType.updateDimensions(::android::nn::toUnsigned(output.dimensions).value());
            }
            auto result = execution->setOutputFromMemory(
                    i, requestMemoryPools->at(output.location.poolIndex).get(),
                    output.location.offset, output.location.length, &operandType.operandType);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        } else {
            auto result = execution->setOutput(i, nullptr, 0);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    if (measure) {
        execution->setMeasureTiming(true);
    }

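    // Convert the absolute boot_clock deadline into a relative timeout for the SL execution; a
    // deadline that has already passed is reported as MISSED_DEADLINE_TRANSIENT.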
    if (deadlineNs > -1) {
        std::chrono::time_point<::android::base::boot_clock> deadlinePoint(
                std::chrono::nanoseconds{deadlineNs});
        const auto currentTime = ::android::base::boot_clock::now();
        const auto timeoutDuration = std::chrono::nanoseconds(deadlinePoint - currentTime);
        if (timeoutDuration <= std::chrono::nanoseconds::zero()) {
            return ErrorStatus::MISSED_DEADLINE_TRANSIENT;
        } else {
            auto result = execution->setTimeout(std::max<uint64_t>(1, timeoutDuration.count()));
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    if (loopTimeoutDurationNs > 0) {
        execution->setLoopTimeout(loopTimeoutDurationNs);
    }

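    // Forward execution hints as extension attributes. Each hint token encodes an extension
    // prefix plus an attribute code within that extension; the prefix is mapped back to the
    // extension name supplied by the caller.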
    if (!executionHints.empty() || !extensionNameToPrefix.empty()) {
        std::unordered_map<uint16_t, std::string> prefixToName;
        for (const auto [name, prefix] : extensionNameToPrefix) {
            prefixToName.emplace(prefix, name);
        }

        for (const auto& [token, value] : executionHints) {
            const auto uToken = static_cast<uint32_t>(token);
            const auto prefix = ::android::nn::getExtensionPrefix(uToken);
            const auto attributeCodeWithinExtension = ::android::nn::getTypeWithinExtension(uToken);

            const auto it = prefixToName.find(prefix);
            if (it == prefixToName.end()) {
                return ErrorStatus::INVALID_ARGUMENT;
            }
            const std::string& extensionName = it->second;

            const auto result = execution->addExtensionAttribute(
                    extensionName, attributeCodeWithinExtension, value);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    return ErrorStatus::NONE;
}

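// Callback returned to clients of fenced executions. It waits on the SL completion event and,
// when requested, reports launched and fenced timing information. The request memory pools are
// kept alive for as long as the callback exists.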
class ShimFencedExecutionCallback : public BnFencedExecutionCallback {
   public:
    ShimFencedExecutionCallback(
            std::shared_ptr<::android::nn::sl_wrapper::Execution> execution, Event e,
            std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> memoryPools,
            bool measureTiming)
        : mMemoryPools(std::move(memoryPools)),
          mExecution(std::move(execution)),
          mEvent(std::move(e)),
          mMeasureTiming(measureTiming) {}

    ndk::ScopedAStatus getExecutionInfo(Timing* timingLaunched, Timing* timingFenced,
                                        ErrorStatus* errorStatus) override {
        auto status = mEvent.wait();
        *errorStatus = convertResultToErrorStatus(status);

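        // Query per-stage durations. UINT64_MAX means "no measurement" and maps to -1 in the
        // AIDL Timing struct; values larger than INT64_MAX are clamped.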
        if (mMeasureTiming) {
            uint64_t duration;
            constexpr int64_t int64cap = std::numeric_limits<int64_t>::max();
            // Special value used for "no measurements"
            constexpr uint64_t uint64cap = std::numeric_limits<uint64_t>::max();
            auto result = mExecution->getDuration(Duration::ON_HARDWARE, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingLaunched->timeOnDeviceNs = (duration == uint64cap) ? -1
                                             : (duration > int64cap)
                                                     ? int64cap
                                                     : static_cast<int64_t>(duration);

            result = mExecution->getDuration(Duration::IN_DRIVER, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingLaunched->timeInDriverNs = (duration == uint64cap) ? -1
                                             : (duration > int64cap)
                                                     ? int64cap
                                                     : static_cast<int64_t>(duration);

            result = mExecution->getDuration(Duration::FENCED_ON_HARDWARE, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingFenced->timeOnDeviceNs = (duration == uint64cap) ? -1
                                           : (duration > int64cap) ? int64cap
                                                                   : static_cast<int64_t>(duration);

            result = mExecution->getDuration(Duration::FENCED_IN_DRIVER, &duration);
            SLW2SAS_RETURN_IF_ERROR(result);
            timingFenced->timeInDriverNs = (duration == uint64cap) ? -1
                                           : (duration > int64cap) ? int64cap
                                                                   : static_cast<int64_t>(duration);
        } else {
            timingFenced->timeOnDeviceNs = -1;
            timingFenced->timeInDriverNs = -1;
            timingLaunched->timeOnDeviceNs = -1;
            timingLaunched->timeInDriverNs = -1;
        }

        return ndk::ScopedAStatus::ok();
    }

   private:
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> mMemoryPools;
    std::shared_ptr<::android::nn::sl_wrapper::Execution> mExecution;
    ::android::nn::wrapper::Event mEvent;
    bool mMeasureTiming;
};

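// Starts a fenced computation on an already configured execution: the incoming sync fences are
// turned into event dependencies, the computation is launched with those dependencies, and the
// resulting sync fence plus callback are returned to the caller.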
static ndk::ScopedAStatus executeFencedInternal(
        const std::shared_ptr<const NnApiSupportLibrary>& nnapi,
        const std::shared_ptr<::android::nn::sl_wrapper::Execution>& execution,
        std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools,
        const std::vector<ndk::ScopedFileDescriptor>& waitFor, int64_t durationNs,
        bool measureTiming, FencedExecutionResult* fencedExecutionResult) {
    CHECK(execution != nullptr);
    CHECK(fencedExecutionResult != nullptr);

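    // Wrap each incoming sync fence fd in an ANeuralNetworksEvent dependency; the first creation
    // failure is recorded in createResult and the remaining fds are skipped. The scope guard
    // below frees every event that was created.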
    std::vector<const ANeuralNetworksEvent*> deps(waitFor.size());
    auto createResult = Result::NO_ERROR;
    std::transform(waitFor.begin(), waitFor.end(), deps.begin(),
                   [&](const ::ndk::ScopedFileDescriptor& e) {
                       ANeuralNetworksEvent* r = nullptr;
                       if (createResult == Result::NO_ERROR) {
                           createResult = static_cast<Result>(
                                   nnapi->getFL5()->ANeuralNetworksEvent_createFromSyncFenceFd(
                                           e.get(), &r));
                       }
                       return r;
                   });

    const auto guard = ::android::base::make_scope_guard([nnapi, deps] {
        for (auto& dep : deps) {
            if (dep != nullptr) {
                nnapi->getFL5()->ANeuralNetworksEvent_free(const_cast<ANeuralNetworksEvent*>(dep));
            }
        }
    });

    SLW2SAS_RETURN_IF_ERROR(createResult);

    Event e(nnapi.get());
    auto result = execution->startComputeWithDependencies(deps, durationNs, &e);
    SLW2SAS_RETURN_IF_ERROR(result);

    int syncFence = -1;
    fencedExecutionResult->syncFence = ndk::ScopedFileDescriptor(
            (e.getSyncFenceFd(&syncFence) == Result::NO_ERROR) ? syncFence : -1);
    fencedExecutionResult->callback = ndk::SharedRefBase::make<ShimFencedExecutionCallback>(
            execution, std::move(e), requestMemoryPools, measureTiming);

    return ndk::ScopedAStatus::ok();
}

::ndk::ScopedAStatus ShimPreparedModel::executeFencedCommon(
        const Request& request, const std::vector<::ndk::ScopedFileDescriptor>& waitFor,
        bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
        const std::vector<TokenValuePair>& executionHints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
        FencedExecutionResult* fencedExecutionResult) {
    CHECK(fencedExecutionResult != nullptr);

    if (deadlineNs < -1) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }
    auto execution =
            std::make_shared<::android::nn::sl_wrapper::Execution>(mNnapi.get(), &mCompilation);
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools;
    auto errorStatus =
            parseInputs(request, measureTiming, deadlineNs, loopTimeoutDurationNs, execution.get(),
                        &requestMemoryPools, executionHints, extensionNameToPrefix);
    if (errorStatus != ErrorStatus::NONE) {
        return toAStatus(errorStatus);
    }
    return executeFencedInternal(mNnapi, execution, std::move(requestMemoryPools), waitFor,
                                 durationNs, measureTiming, fencedExecutionResult);
}

::ndk::ScopedAStatus ShimPreparedModel::executeFenced(
        const ::aidl::android::hardware::neuralnetworks::Request& request,
        const std::vector<::ndk::ScopedFileDescriptor>& waitFor, bool measureTiming,
        int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
        FencedExecutionResult* fencedExecutionResult) {
    return executeFencedCommon(request, waitFor, measureTiming, deadlineNs, loopTimeoutDurationNs,
                               durationNs, /*executionHints=*/{}, /*extensionNameToPrefix=*/{},
                               fencedExecutionResult);
}

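// Runs a configured execution synchronously and assembles the AIDL ExecutionResult, including
// per-output shapes and optional timing.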
static ndk::ScopedAStatus executeSynchronouslyInternal(
        const std::shared_ptr<::android::nn::sl_wrapper::Execution>& execution, bool measureTiming,
        int numOutputs, ExecutionResult* executionResult) {
    CHECK(execution != nullptr);
    CHECK(executionResult != nullptr);

    auto result = execution->compute();
    auto errorStatus = convertResultToErrorStatus(result);

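    // Collect the actual output dimensions so the client can resize its buffers when the driver
    // reports OUTPUT_INSUFFICIENT_SIZE.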
    std::vector<OutputShape> outputShapes;
    outputShapes.reserve(numOutputs);
    bool sufficientSize = true;
    for (int i = 0; i < numOutputs; ++i) {
        OutputShape outputShape;
        std::vector<uint32_t> outputDims;
        auto result = execution->getOutputOperandDimensions(i, &outputDims);
        if (result == Result::NO_ERROR) {
            outputShape.isSufficient = true;
            outputShape.dimensions.assign(outputDims.begin(), outputDims.end());
        } else if (result == Result::OUTPUT_INSUFFICIENT_SIZE) {
            sufficientSize = false;
            outputShape.isSufficient = false;
            outputShape.dimensions.assign(outputDims.begin(), outputDims.end());
        } else {
            if (errorStatus == ErrorStatus::NONE) {
                errorStatus = ErrorStatus::GENERAL_FAILURE;
            }
        }
        outputShapes.push_back(std::move(outputShape));
    }

    int64_t timeOnDeviceNs = -1;
    int64_t timeInDriverNs = -1;
    if (measureTiming && errorStatus == ErrorStatus::NONE) {
        uint64_t duration;
        constexpr int64_t int64cap = std::numeric_limits<int64_t>::max();
        // Special value used for "no measurements"
        constexpr uint64_t uint64cap = std::numeric_limits<uint64_t>::max();
        auto result = execution->getDuration(Duration::ON_HARDWARE, &duration);
        SLW2SAS_RETURN_IF_ERROR(result);
        timeOnDeviceNs = (duration == uint64cap) ? -1
                         : (duration > int64cap) ? int64cap
                                                 : static_cast<int64_t>(duration);

        result = execution->getDuration(Duration::IN_DRIVER, &duration);
        SLW2SAS_RETURN_IF_ERROR(result);
        timeInDriverNs = (duration == uint64cap) ? -1
                         : (duration > int64cap) ? int64cap
                                                 : static_cast<int64_t>(duration);
    }

    *executionResult =
            ExecutionResult{sufficientSize,
                            std::move(outputShapes),
                            {.timeOnDeviceNs = timeOnDeviceNs, .timeInDriverNs = timeInDriverNs}};
    if (errorStatus == ErrorStatus::NONE || errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        return ndk::ScopedAStatus::ok();
    }
    return toAStatus(errorStatus);
}

::ndk::ScopedAStatus ShimPreparedModel::executeSynchronouslyCommon(
        const Request& request, bool measureTiming, int64_t deadlineNs,
        int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& executionHints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
        ExecutionResult* executionResult) {
    CHECK(executionResult != nullptr);

    if (deadlineNs < -1) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    auto execution =
            std::make_shared<::android::nn::sl_wrapper::Execution>(mNnapi.get(), &mCompilation);
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools;
    auto errorStatus =
            parseInputs(request, measureTiming, deadlineNs, loopTimeoutDurationNs, execution.get(),
                        &requestMemoryPools, executionHints, extensionNameToPrefix);
    if (errorStatus != ErrorStatus::NONE) {
        return toAStatus(errorStatus);
    }
    return executeSynchronouslyInternal(execution, measureTiming, request.outputs.size(),
                                        executionResult);
}

::ndk::ScopedAStatus ShimPreparedModel::executeSynchronously(
        const Request& request, bool measureTiming, int64_t deadlineNs,
        int64_t loopTimeoutDurationNs,
        ::aidl::android::hardware::neuralnetworks::ExecutionResult* executionResult) {
    return executeSynchronouslyCommon(request, measureTiming, deadlineNs, loopTimeoutDurationNs,
                                      /*executionHints=*/{}, /*extensionNameToPrefix=*/{},
                                      executionResult);
}

::ndk::ScopedAStatus ShimPreparedModel::executeSynchronouslyWithConfig(
        const Request& request, const ExecutionConfig& config, int64_t deadlineNs,
        ExecutionResult* executionResult) {
    return executeSynchronouslyCommon(request, config.measureTiming, deadlineNs,
                                      config.loopTimeoutDurationNs, config.executionHints,
                                      config.extensionNameToPrefix, executionResult);
}

::ndk::ScopedAStatus ShimPreparedModel::executeFencedWithConfig(
        const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
        const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs,
        FencedExecutionResult* executionResult) {
    return executeFencedCommon(request, waitFor, config.measureTiming, deadlineNs,
                               config.loopTimeoutDurationNs, durationNs, config.executionHints,
                               config.extensionNameToPrefix, executionResult);
}

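// Burst implementation that forwards to the owning prepared model and enforces that at most one
// execution is in flight at a time.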
// TODO(183397380): make it use ANNBurst object
class ShimBurst : public BnBurst {
   public:
    // Precondition: preparedModel != nullptr
    explicit ShimBurst(std::shared_ptr<ShimPreparedModel> preparedModel);

    ndk::ScopedAStatus executeSynchronously(const Request& request,
                                            const std::vector<int64_t>& memoryIdentifierTokens,
                                            bool measureTiming, int64_t deadlineNs,
                                            int64_t loopTimeoutDurationNs,
                                            ExecutionResult* executionResult) override;
    ndk::ScopedAStatus executeSynchronouslyWithConfig(
            const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
            const ExecutionConfig& config, int64_t deadlineNs,
            ExecutionResult* executionResult) override;
    ndk::ScopedAStatus releaseMemoryResource(int64_t memoryIdentifierToken) override;

   protected:
    std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
    const std::shared_ptr<ShimPreparedModel> kPreparedModel;
};

ndk::ScopedAStatus ShimPreparedModel::configureExecutionBurst(std::shared_ptr<IBurst>* burst) {
    std::shared_ptr<ShimPreparedModel> self = this->template ref<ShimPreparedModel>();
    *burst = ndk::SharedRefBase::make<ShimBurst>(std::move(self));
    return ndk::ScopedAStatus::ok();
}

ShimBurst::ShimBurst(std::shared_ptr<ShimPreparedModel> preparedModel)
    : kPreparedModel(std::move(preparedModel)) {
    CHECK(kPreparedModel != nullptr);
}

ndk::ScopedAStatus ShimBurst::executeSynchronously(
        const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
        bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs,
        ExecutionResult* executionResult) {
    if (request.pools.size() != memoryIdentifierTokens.size()) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT,
                         "request.pools.size() != memoryIdentifierTokens.size()");
    }
    if (!std::all_of(memoryIdentifierTokens.begin(), memoryIdentifierTokens.end(),
                     [](int64_t token) { return token >= -1; })) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierTokens");
    }

    // Ensure at most one execution is in flight at a time.
    const bool executionAlreadyInFlight = mExecutionInFlight.test_and_set();
    if (executionAlreadyInFlight) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Burst object supports at most one execution at a time");
    }
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return kPreparedModel->executeSynchronously(request, measureTiming, deadlineNs,
                                                loopTimeoutDurationNs, executionResult);
}

ndk::ScopedAStatus ShimBurst::executeSynchronouslyWithConfig(
        const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
        const ExecutionConfig& config, int64_t deadlineNs, ExecutionResult* executionResult) {
    if (request.pools.size() != memoryIdentifierTokens.size()) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT,
                         "request.pools.size() != memoryIdentifierTokens.size()");
    }
    if (!std::all_of(memoryIdentifierTokens.begin(), memoryIdentifierTokens.end(),
                     [](int64_t token) { return token >= -1; })) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierTokens");
    }

    // Ensure at most one execution is in flight at a time.
    const bool executionAlreadyInFlight = mExecutionInFlight.test_and_set();
    if (executionAlreadyInFlight) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Burst object supports at most one execution at a time");
    }
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return kPreparedModel->executeSynchronouslyWithConfig(request, config, deadlineNs,
                                                          executionResult);
}

ndk::ScopedAStatus ShimBurst::releaseMemoryResource(int64_t memoryIdentifierToken) {
    if (memoryIdentifierToken < -1) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierToken");
    }
    return ndk::ScopedAStatus::ok();
}

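// Reusable execution created by createReusableExecution: inputs and outputs are bound once via
// parseInputs, after which the same SL execution can be run repeatedly, synchronously or fenced,
// one run at a time.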
class ShimExecution : public BnExecution {
   public:
    explicit ShimExecution(
            std::shared_ptr<const NnApiSupportLibrary> nnapi,
            std::shared_ptr<::android::nn::sl_wrapper::Execution> execution,
            std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools,
            bool measureTiming, int numberOfOutputs);

    ndk::ScopedAStatus executeSynchronously(int64_t deadlineNs,
                                            ExecutionResult* executionResult) override;
    ndk::ScopedAStatus executeFenced(const std::vector<ndk::ScopedFileDescriptor>& waitFor,
                                     int64_t deadlineNs, int64_t durationNs,
                                     FencedExecutionResult* fencedExecutionResult) override;

   protected:
    std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
    std::shared_ptr<const NnApiSupportLibrary> mNnapi;
    std::shared_ptr<::android::nn::sl_wrapper::Execution> mExecution;
    const std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> kRequestMemoryPools;
    const bool kMeasureTiming;
    const int kNumberOfOutputs;
};

ndk::ScopedAStatus ShimPreparedModel::createReusableExecution(
        const Request& request, const ExecutionConfig& config,
        std::shared_ptr<IExecution>* execution) {
    auto wrapperExecution =
            std::make_shared<::android::nn::sl_wrapper::Execution>(mNnapi.get(), &mCompilation);
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools;
    auto errorStatus =
            parseInputs(request, config.measureTiming, kNoDeadline, config.loopTimeoutDurationNs,
                        wrapperExecution.get(), &requestMemoryPools, config.executionHints,
                        config.extensionNameToPrefix);
    if (errorStatus != ErrorStatus::NONE) {
        return toAStatus(errorStatus);
    }
    auto result = wrapperExecution->setReusable(true);
    SLW2SAS_RETURN_IF_ERROR(result);

    *execution = ndk::SharedRefBase::make<ShimExecution>(
            mNnapi, std::move(wrapperExecution), std::move(requestMemoryPools),
            config.measureTiming, request.outputs.size());
    return ndk::ScopedAStatus::ok();
}

ShimExecution::ShimExecution(
        std::shared_ptr<const NnApiSupportLibrary> nnapi,
        std::shared_ptr<::android::nn::sl_wrapper::Execution> execution,
        std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools,
        bool measureTiming, int numberOfOutputs)
    : mNnapi(std::move(nnapi)),
      mExecution(std::move(execution)),
      kRequestMemoryPools(std::move(requestMemoryPools)),
      kMeasureTiming(measureTiming),
      kNumberOfOutputs(numberOfOutputs) {}

ndk::ScopedAStatus ShimExecution::executeSynchronously(int64_t deadlineNs,
                                                       ExecutionResult* executionResult) {
    if (deadlineNs < -1) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    // Ensure at most one execution is in flight at a time.
    const bool executionAlreadyInFlight = mExecutionInFlight.test_and_set();
    if (executionAlreadyInFlight) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Execution object supports at most one execution at a time");
    }
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return executeSynchronouslyInternal(mExecution, kMeasureTiming, kNumberOfOutputs,
                                        executionResult);
}

ndk::ScopedAStatus ShimExecution::executeFenced(
        const std::vector<ndk::ScopedFileDescriptor>& waitFor, int64_t deadlineNs,
        int64_t durationNs, FencedExecutionResult* fencedExecutionResult) {
    if (deadlineNs < -1) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    // Ensure at most one execution is in flight at a time.
    const bool executionAlreadyInFlight = mExecutionInFlight.test_and_set();
    if (executionAlreadyInFlight) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Execution object supports at most one execution at a time");
    }
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return executeFencedInternal(mNnapi, mExecution, kRequestMemoryPools, waitFor, durationNs,
                                 kMeasureTiming, fencedExecutionResult);
}

}  // namespace aidl::android::hardware::neuralnetworks