1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Burst.h"
18 
19 #include <android-base/logging.h>
20 #include <android-base/thread_annotations.h>
21 #include <android/binder_auto_utils.h>
22 #include <nnapi/IBurst.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/Types.h>
25 #include <nnapi/Validation.h>
26 #include <nnapi/hal/aidl/Conversions.h>
27 #include <nnapi/hal/aidl/Utils.h>
28 
29 #include <algorithm>
30 #include <chrono>
31 #include <memory>
32 #include <mutex>
33 #include <unordered_map>
34 #include <utility>
35 #include <variant>
36 
37 namespace aidl::android::hardware::neuralnetworks::adapter {
38 namespace {
39 
// Shorthand for the entry type stored in (and returned by) Burst::ThreadSafeMemoryCache.
using Value = Burst::ThreadSafeMemoryCache::Value;
41 
42 template <typename Type>
convertInput(const Type & object)43 auto convertInput(const Type& object) -> decltype(nn::convert(std::declval<Type>())) {
44     auto result = nn::convert(object);
45     if (!result.has_value()) {
46         result.error().code = nn::ErrorStatus::INVALID_ARGUMENT;
47     }
48     return result;
49 }
50 
makeDuration(int64_t durationNs)51 nn::Duration makeDuration(int64_t durationNs) {
52     return nn::Duration(std::chrono::nanoseconds(durationNs));
53 }
54 
makeOptionalDuration(int64_t durationNs)55 nn::GeneralResult<nn::OptionalDuration> makeOptionalDuration(int64_t durationNs) {
56     if (durationNs < -1) {
57         return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid duration " << durationNs;
58     }
59     return durationNs < 0 ? nn::OptionalDuration{} : makeDuration(durationNs);
60 }
61 
makeOptionalTimePoint(int64_t durationNs)62 nn::GeneralResult<nn::OptionalTimePoint> makeOptionalTimePoint(int64_t durationNs) {
63     if (durationNs < -1) {
64         return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid time point " << durationNs;
65     }
66     return durationNs < 0 ? nn::OptionalTimePoint{} : nn::TimePoint(makeDuration(durationNs));
67 }
68 
ensureAllMemoriesAreCached(nn::Request * request,const std::vector<int64_t> & memoryIdentifierTokens,const nn::IBurst & burst,const Burst::ThreadSafeMemoryCache & cache)69 std::vector<nn::IBurst::OptionalCacheHold> ensureAllMemoriesAreCached(
70         nn::Request* request, const std::vector<int64_t>& memoryIdentifierTokens,
71         const nn::IBurst& burst, const Burst::ThreadSafeMemoryCache& cache) {
72     std::vector<nn::IBurst::OptionalCacheHold> holds;
73     holds.reserve(memoryIdentifierTokens.size());
74 
75     for (size_t i = 0; i < memoryIdentifierTokens.size(); ++i) {
76         const auto& pool = request->pools[i];
77         const auto token = memoryIdentifierTokens[i];
78         constexpr int64_t kNoToken = -1;
79         if (token == kNoToken || !std::holds_alternative<nn::SharedMemory>(pool)) {
80             continue;
81         }
82 
83         const auto& memory = std::get<nn::SharedMemory>(pool);
84         auto [storedMemory, hold] = cache.add(token, memory, burst);
85 
86         request->pools[i] = std::move(storedMemory);
87         holds.push_back(std::move(hold));
88     }
89 
90     return holds;
91 }
92 
// Shared implementation of Burst::executeSynchronously{,WithConfig}.
//
// The function works in five stages:
//   1. Validate the AIDL arguments. memoryIdentifierTokens must match request.pools in size, and
//      every token must be >= -1.
//   2. Convert the arguments to their nn:: forms. Conversion failures become INVALID_ARGUMENT.
//   3. Substitute cached memories for tokenized pools.
//   4. Run the execution on `burst`.
//   5. Convert the result back to AIDL.
//
// An OUTPUT_INSUFFICIENT_SIZE execution error is not propagated as an error. Instead it is turned
// into a successful ExecutionResult with outputSufficientSize = false and timing set to -1.
// If the driver's output shapes or timing cannot be converted back to AIDL, that conversion error
// is returned. The previous code called `.value()` on the unchecked result instead.
nn::ExecutionResult<ExecutionResult> executeSynchronously(
        const nn::IBurst& burst, const Burst::ThreadSafeMemoryCache& cache, const Request& request,
        const std::vector<int64_t>& memoryIdentifierTokens, bool measureTiming, int64_t deadlineNs,
        int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
    if (request.pools.size() != memoryIdentifierTokens.size()) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "request.pools.size() != memoryIdentifierTokens.size()";
    }
    if (!std::all_of(memoryIdentifierTokens.begin(), memoryIdentifierTokens.end(),
                     [](int64_t token) { return token >= -1; })) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid memoryIdentifierTokens";
    }

    auto nnRequest = NN_TRY(convertInput(request));
    const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
    const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
    const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
    auto nnHints = NN_TRY(convertInput(hints));
    auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));

    // Kept alive until this function returns, i.e. across the burst.execute() call below.
    const auto hold = ensureAllMemoriesAreCached(&nnRequest, memoryIdentifierTokens, burst, cache);

    const auto result = burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration,
                                      nnHints, nnExtensionNameToPrefix);

    if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        const auto& [message, code, outputShapes] = result.error();
        auto aidlOutputShapes = NN_TRY(utils::convert(outputShapes));
        return ExecutionResult{.outputSufficientSize = false,
                               .outputShapes = std::move(aidlOutputShapes),
                               .timing = {.timeInDriverNs = -1, .timeOnDeviceNs = -1}};
    }

    const auto& [outputShapes, timing] = NN_TRY(result);
    auto aidlOutputShapes = NN_TRY(utils::convert(outputShapes));
    const auto aidlTiming = NN_TRY(utils::convert(timing));
    return ExecutionResult{.outputSufficientSize = true,
                           .outputShapes = std::move(aidlOutputShapes),
                           .timing = aidlTiming};
}
131 
132 }  // namespace
133 
add(int64_t token,const nn::SharedMemory & memory,const nn::IBurst & burst) const134 Value Burst::ThreadSafeMemoryCache::add(int64_t token, const nn::SharedMemory& memory,
135                                         const nn::IBurst& burst) const {
136     std::lock_guard guard(mMutex);
137     if (const auto it = mCache.find(token); it != mCache.end()) {
138         return it->second;
139     }
140     auto hold = burst.cacheMemory(memory);
141     auto [it, _] = mCache.emplace(token, std::make_pair(memory, std::move(hold)));
142     return it->second;
143 }
144 
remove(int64_t token) const145 void Burst::ThreadSafeMemoryCache::remove(int64_t token) const {
146     std::lock_guard guard(mMutex);
147     mCache.erase(token);
148 }
149 
Burst(nn::SharedBurst burst)150 Burst::Burst(nn::SharedBurst burst) : kBurst(std::move(burst)) {
151     CHECK(kBurst != nullptr);
152 }
153 
executeSynchronously(const Request & request,const std::vector<int64_t> & memoryIdentifierTokens,bool measureTiming,int64_t deadlineNs,int64_t loopTimeoutDurationNs,ExecutionResult * executionResult)154 ndk::ScopedAStatus Burst::executeSynchronously(const Request& request,
155                                                const std::vector<int64_t>& memoryIdentifierTokens,
156                                                bool measureTiming, int64_t deadlineNs,
157                                                int64_t loopTimeoutDurationNs,
158                                                ExecutionResult* executionResult) {
159     auto result =
160             adapter::executeSynchronously(*kBurst, kMemoryCache, request, memoryIdentifierTokens,
161                                           measureTiming, deadlineNs, loopTimeoutDurationNs, {}, {});
162     if (!result.has_value()) {
163         auto [message, code, _] = std::move(result).error();
164         const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
165         return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
166                 static_cast<int32_t>(aidlCode), message.c_str());
167     }
168     *executionResult = std::move(result).value();
169     return ndk::ScopedAStatus::ok();
170 }
171 
executeSynchronouslyWithConfig(const Request & request,const std::vector<int64_t> & memoryIdentifierTokens,const ExecutionConfig & config,int64_t deadlineNs,ExecutionResult * executionResult)172 ndk::ScopedAStatus Burst::executeSynchronouslyWithConfig(
173         const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
174         const ExecutionConfig& config, int64_t deadlineNs, ExecutionResult* executionResult) {
175     auto result = adapter::executeSynchronously(
176             *kBurst, kMemoryCache, request, memoryIdentifierTokens, config.measureTiming,
177             deadlineNs, config.loopTimeoutDurationNs, config.executionHints,
178             config.extensionNameToPrefix);
179     if (!result.has_value()) {
180         auto [message, code, _] = std::move(result).error();
181         const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
182         return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
183                 static_cast<int32_t>(aidlCode), message.c_str());
184     }
185     *executionResult = std::move(result).value();
186     return ndk::ScopedAStatus::ok();
187 }
188 
releaseMemoryResource(int64_t memoryIdentifierToken)189 ndk::ScopedAStatus Burst::releaseMemoryResource(int64_t memoryIdentifierToken) {
190     if (memoryIdentifierToken < -1) {
191         return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
192                 static_cast<int32_t>(ErrorStatus::INVALID_ARGUMENT),
193                 "Invalid memoryIdentifierToken");
194     }
195     kMemoryCache.remove(memoryIdentifierToken);
196     return ndk::ScopedAStatus::ok();
197 }
198 
199 }  // namespace aidl::android::hardware::neuralnetworks::adapter
200