1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "Burst.h"
18
19 #include <android-base/logging.h>
20 #include <android-base/thread_annotations.h>
21 #include <android/binder_auto_utils.h>
22 #include <nnapi/IBurst.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/Types.h>
25 #include <nnapi/Validation.h>
26 #include <nnapi/hal/aidl/Conversions.h>
27 #include <nnapi/hal/aidl/Utils.h>
28
29 #include <algorithm>
30 #include <chrono>
31 #include <memory>
32 #include <mutex>
33 #include <unordered_map>
34 #include <utility>
35 #include <variant>
36
37 namespace aidl::android::hardware::neuralnetworks::adapter {
38 namespace {
39
40 using Value = Burst::ThreadSafeMemoryCache::Value;
41
42 template <typename Type>
convertInput(const Type & object)43 auto convertInput(const Type& object) -> decltype(nn::convert(std::declval<Type>())) {
44 auto result = nn::convert(object);
45 if (!result.has_value()) {
46 result.error().code = nn::ErrorStatus::INVALID_ARGUMENT;
47 }
48 return result;
49 }
50
makeDuration(int64_t durationNs)51 nn::Duration makeDuration(int64_t durationNs) {
52 return nn::Duration(std::chrono::nanoseconds(durationNs));
53 }
54
makeOptionalDuration(int64_t durationNs)55 nn::GeneralResult<nn::OptionalDuration> makeOptionalDuration(int64_t durationNs) {
56 if (durationNs < -1) {
57 return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid duration " << durationNs;
58 }
59 return durationNs < 0 ? nn::OptionalDuration{} : makeDuration(durationNs);
60 }
61
makeOptionalTimePoint(int64_t durationNs)62 nn::GeneralResult<nn::OptionalTimePoint> makeOptionalTimePoint(int64_t durationNs) {
63 if (durationNs < -1) {
64 return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid time point " << durationNs;
65 }
66 return durationNs < 0 ? nn::OptionalTimePoint{} : nn::TimePoint(makeDuration(durationNs));
67 }
68
ensureAllMemoriesAreCached(nn::Request * request,const std::vector<int64_t> & memoryIdentifierTokens,const nn::IBurst & burst,const Burst::ThreadSafeMemoryCache & cache)69 std::vector<nn::IBurst::OptionalCacheHold> ensureAllMemoriesAreCached(
70 nn::Request* request, const std::vector<int64_t>& memoryIdentifierTokens,
71 const nn::IBurst& burst, const Burst::ThreadSafeMemoryCache& cache) {
72 std::vector<nn::IBurst::OptionalCacheHold> holds;
73 holds.reserve(memoryIdentifierTokens.size());
74
75 for (size_t i = 0; i < memoryIdentifierTokens.size(); ++i) {
76 const auto& pool = request->pools[i];
77 const auto token = memoryIdentifierTokens[i];
78 constexpr int64_t kNoToken = -1;
79 if (token == kNoToken || !std::holds_alternative<nn::SharedMemory>(pool)) {
80 continue;
81 }
82
83 const auto& memory = std::get<nn::SharedMemory>(pool);
84 auto [storedMemory, hold] = cache.add(token, memory, burst);
85
86 request->pools[i] = std::move(storedMemory);
87 holds.push_back(std::move(hold));
88 }
89
90 return holds;
91 }
92
/**
 * Validates and converts the AIDL execution arguments, runs the request on `burst`, and
 * converts the outcome into an AIDL ExecutionResult.
 *
 * Argument checks, each reported as INVALID_ARGUMENT:
 *  - `request.pools` and `memoryIdentifierTokens` must have the same size.
 *  - Every memory identifier token must be >= -1.
 *  - `deadlineNs` and `loopTimeoutDurationNs` must be >= -1 (-1 produces an empty optional).
 *  - `request`, `hints` and `extensionNameToPrefix` must convert successfully.
 *
 * If the execution fails with OUTPUT_INSUFFICIENT_SIZE, this is reported as a value, not
 * an error: an ExecutionResult with `outputSufficientSize == false`, the output shapes
 * taken from the error, and both timing fields set to -1. Any other execution error is
 * propagated to the caller.
 *
 * A failure while converting the output shapes or timing to their AIDL form is also
 * propagated as an error rather than being unwrapped unchecked.
 */
nn::ExecutionResult<ExecutionResult> executeSynchronously(
        const nn::IBurst& burst, const Burst::ThreadSafeMemoryCache& cache, const Request& request,
        const std::vector<int64_t>& memoryIdentifierTokens, bool measureTiming, int64_t deadlineNs,
        int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
    if (request.pools.size() != memoryIdentifierTokens.size()) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "request.pools.size() != memoryIdentifierTokens.size()";
    }
    if (!std::all_of(memoryIdentifierTokens.begin(), memoryIdentifierTokens.end(),
                     [](int64_t token) { return token >= -1; })) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid memoryIdentifierTokens";
    }

    auto nnRequest = NN_TRY(convertInput(request));
    const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
    const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
    const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
    auto nnHints = NN_TRY(convertInput(hints));
    auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));

    // The holds are kept in scope until this function returns, i.e. across the execution.
    const auto hold = ensureAllMemoriesAreCached(&nnRequest, memoryIdentifierTokens, burst, cache);

    const auto result = burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration,
                                      nnHints, nnExtensionNameToPrefix);

    if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        const auto& [message, code, outputShapes] = result.error();
        // Propagate a conversion failure instead of calling .value() on an error result.
        auto aidlOutputShapes = NN_TRY(utils::convert(outputShapes));
        return ExecutionResult{.outputSufficientSize = false,
                               .outputShapes = std::move(aidlOutputShapes),
                               .timing = {.timeInDriverNs = -1, .timeOnDeviceNs = -1}};
    }

    const auto& [outputShapes, timing] = NN_TRY(result);
    // Same as above: conversion failures become errors for the caller.
    auto aidlOutputShapes = NN_TRY(utils::convert(outputShapes));
    auto aidlTiming = NN_TRY(utils::convert(timing));
    return ExecutionResult{.outputSufficientSize = true,
                           .outputShapes = std::move(aidlOutputShapes),
                           .timing = std::move(aidlTiming)};
}
131
132 } // namespace
133
add(int64_t token,const nn::SharedMemory & memory,const nn::IBurst & burst) const134 Value Burst::ThreadSafeMemoryCache::add(int64_t token, const nn::SharedMemory& memory,
135 const nn::IBurst& burst) const {
136 std::lock_guard guard(mMutex);
137 if (const auto it = mCache.find(token); it != mCache.end()) {
138 return it->second;
139 }
140 auto hold = burst.cacheMemory(memory);
141 auto [it, _] = mCache.emplace(token, std::make_pair(memory, std::move(hold)));
142 return it->second;
143 }
144
remove(int64_t token) const145 void Burst::ThreadSafeMemoryCache::remove(int64_t token) const {
146 std::lock_guard guard(mMutex);
147 mCache.erase(token);
148 }
149
Burst(nn::SharedBurst burst)150 Burst::Burst(nn::SharedBurst burst) : kBurst(std::move(burst)) {
151 CHECK(kBurst != nullptr);
152 }
153
executeSynchronously(const Request & request,const std::vector<int64_t> & memoryIdentifierTokens,bool measureTiming,int64_t deadlineNs,int64_t loopTimeoutDurationNs,ExecutionResult * executionResult)154 ndk::ScopedAStatus Burst::executeSynchronously(const Request& request,
155 const std::vector<int64_t>& memoryIdentifierTokens,
156 bool measureTiming, int64_t deadlineNs,
157 int64_t loopTimeoutDurationNs,
158 ExecutionResult* executionResult) {
159 auto result =
160 adapter::executeSynchronously(*kBurst, kMemoryCache, request, memoryIdentifierTokens,
161 measureTiming, deadlineNs, loopTimeoutDurationNs, {}, {});
162 if (!result.has_value()) {
163 auto [message, code, _] = std::move(result).error();
164 const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
165 return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
166 static_cast<int32_t>(aidlCode), message.c_str());
167 }
168 *executionResult = std::move(result).value();
169 return ndk::ScopedAStatus::ok();
170 }
171
executeSynchronouslyWithConfig(const Request & request,const std::vector<int64_t> & memoryIdentifierTokens,const ExecutionConfig & config,int64_t deadlineNs,ExecutionResult * executionResult)172 ndk::ScopedAStatus Burst::executeSynchronouslyWithConfig(
173 const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
174 const ExecutionConfig& config, int64_t deadlineNs, ExecutionResult* executionResult) {
175 auto result = adapter::executeSynchronously(
176 *kBurst, kMemoryCache, request, memoryIdentifierTokens, config.measureTiming,
177 deadlineNs, config.loopTimeoutDurationNs, config.executionHints,
178 config.extensionNameToPrefix);
179 if (!result.has_value()) {
180 auto [message, code, _] = std::move(result).error();
181 const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
182 return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
183 static_cast<int32_t>(aidlCode), message.c_str());
184 }
185 *executionResult = std::move(result).value();
186 return ndk::ScopedAStatus::ok();
187 }
188
releaseMemoryResource(int64_t memoryIdentifierToken)189 ndk::ScopedAStatus Burst::releaseMemoryResource(int64_t memoryIdentifierToken) {
190 if (memoryIdentifierToken < -1) {
191 return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
192 static_cast<int32_t>(ErrorStatus::INVALID_ARGUMENT),
193 "Invalid memoryIdentifierToken");
194 }
195 kMemoryCache.remove(memoryIdentifierToken);
196 return ndk::ScopedAStatus::ok();
197 }
198
199 } // namespace aidl::android::hardware::neuralnetworks::adapter
200