/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#include "SharedMemory.h"

#include <android-base/logging.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

#include "Result.h"
#include "TypeUtils.h"
#include "Types.h"
31
32 namespace android::nn {
33 namespace {
34
35 bool hasNoPointerData(const Operand& operand);
36 bool hasNoPointerData(const Model::Subgraph& subgraph);
37 bool hasNoPointerData(const Request::Argument& argument);
38
39 template <typename Type>
hasNoPointerData(const std::vector<Type> & objects)40 bool hasNoPointerData(const std::vector<Type>& objects) {
41 return std::all_of(objects.begin(), objects.end(),
42 [](const auto& object) { return hasNoPointerData(object); });
43 }
44
hasNoPointerData(const DataLocation & location)45 bool hasNoPointerData(const DataLocation& location) {
46 return std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer);
47 }
48
hasNoPointerData(const Operand & operand)49 bool hasNoPointerData(const Operand& operand) {
50 return hasNoPointerData(operand.location);
51 }
52
hasNoPointerData(const Model::Subgraph & subgraph)53 bool hasNoPointerData(const Model::Subgraph& subgraph) {
54 return hasNoPointerData(subgraph.operands);
55 }
56
hasNoPointerData(const Request::Argument & argument)57 bool hasNoPointerData(const Request::Argument& argument) {
58 return hasNoPointerData(argument.location);
59 }
60
copyPointersToSharedMemory(Operand * operand,ConstantMemoryBuilder * memoryBuilder)61 void copyPointersToSharedMemory(Operand* operand, ConstantMemoryBuilder* memoryBuilder) {
62 CHECK(operand != nullptr);
63 CHECK(memoryBuilder != nullptr);
64
65 if (operand->lifetime != Operand::LifeTime::POINTER) {
66 return;
67 }
68
69 const void* data = std::visit([](auto ptr) { return static_cast<const void*>(ptr); },
70 operand->location.pointer);
71 CHECK(data != nullptr);
72 operand->lifetime = Operand::LifeTime::CONSTANT_REFERENCE;
73 operand->location = memoryBuilder->append(data, operand->location.length);
74 }
75
copyPointersToSharedMemory(Model::Subgraph * subgraph,ConstantMemoryBuilder * memoryBuilder)76 void copyPointersToSharedMemory(Model::Subgraph* subgraph, ConstantMemoryBuilder* memoryBuilder) {
77 CHECK(subgraph != nullptr);
78 std::for_each(subgraph->operands.begin(), subgraph->operands.end(),
79 [memoryBuilder](auto& operand) {
80 copyPointersToSharedMemory(&operand, memoryBuilder);
81 });
82 }
83
84 } // anonymous namespace
85
MutableMemoryBuilder(uint32_t poolIndex)86 MutableMemoryBuilder::MutableMemoryBuilder(uint32_t poolIndex) : mPoolIndex(poolIndex) {}
87
append(size_t length,size_t alignment,size_t padding)88 DataLocation MutableMemoryBuilder::append(size_t length, size_t alignment, size_t padding) {
89 CHECK_GT(length, 0u);
90 mSize = roundUp(mSize, alignment);
91 const size_t offset = mSize;
92 const size_t paddedLength = roundUp(length, padding);
93 CHECK_LE(offset, std::numeric_limits<uint32_t>::max());
94 CHECK_LE(paddedLength, std::numeric_limits<uint32_t>::max());
95 mSize += paddedLength;
96 return {.poolIndex = mPoolIndex,
97 .offset = static_cast<uint32_t>(offset),
98 .length = static_cast<uint32_t>(length),
99 .padding = static_cast<uint32_t>(paddedLength - length)};
100 }
101
empty() const102 bool MutableMemoryBuilder::empty() const {
103 return mSize == 0;
104 }
105
finish()106 GeneralResult<SharedMemory> MutableMemoryBuilder::finish() {
107 return createSharedMemory(mSize);
108 }
109
ConstantMemoryBuilder(uint32_t poolIndex)110 ConstantMemoryBuilder::ConstantMemoryBuilder(uint32_t poolIndex) : mBuilder(poolIndex) {}
111
append(const void * data,size_t length)112 DataLocation ConstantMemoryBuilder::append(const void* data, size_t length) {
113 const auto location = mBuilder.append(length);
114 CHECK_EQ(location.length, length);
115 mSlices.push_back({.data = data, .length = length, .offset = location.offset});
116 return location;
117 }
118
empty() const119 bool ConstantMemoryBuilder::empty() const {
120 return mBuilder.empty();
121 }
122
finish()123 GeneralResult<SharedMemory> ConstantMemoryBuilder::finish() {
124 // Allocate the memory.
125 auto memory = NN_TRY(mBuilder.finish());
126
127 // Map the memory.
128 const auto [pointer, size, context] = NN_TRY(map(memory););
129
130 // Get mutable pointer.
131 uint8_t* mutablePointer = static_cast<uint8_t*>(std::get<void*>(pointer));
132
133 // Copy data to the memory pool.
134 std::for_each(mSlices.begin(), mSlices.end(), [mutablePointer](const auto& slice) {
135 std::memcpy(mutablePointer + slice.offset, slice.data, slice.length);
136 });
137
138 return memory;
139 }
140
hasNoPointerData(const Model & model)141 bool hasNoPointerData(const Model& model) {
142 return hasNoPointerData(model.main) && hasNoPointerData(model.referenced);
143 }
144
hasNoPointerData(const Request & request)145 bool hasNoPointerData(const Request& request) {
146 return hasNoPointerData(request.inputs) && hasNoPointerData(request.outputs);
147 }
148
flushDataFromPointerToShared(const Model * model,std::optional<Model> * maybeModelInSharedOut)149 GeneralResult<std::reference_wrapper<const Model>> flushDataFromPointerToShared(
150 const Model* model, std::optional<Model>* maybeModelInSharedOut) {
151 CHECK(model != nullptr);
152 CHECK(maybeModelInSharedOut != nullptr);
153
154 if (hasNoPointerData(*model)) {
155 return *model;
156 }
157
158 // Make a copy of the model in order to make modifications. The modified model is returned to
159 // the caller through `maybeModelInSharedOut` if the function succeeds.
160 Model modelInShared = *model;
161
162 ConstantMemoryBuilder memoryBuilder(modelInShared.pools.size());
163 copyPointersToSharedMemory(&modelInShared.main, &memoryBuilder);
164 std::for_each(modelInShared.referenced.begin(), modelInShared.referenced.end(),
165 [&memoryBuilder](auto& subgraph) {
166 copyPointersToSharedMemory(&subgraph, &memoryBuilder);
167 });
168
169 if (!memoryBuilder.empty()) {
170 auto memory = NN_TRY(memoryBuilder.finish());
171 modelInShared.pools.push_back(std::move(memory));
172 }
173
174 *maybeModelInSharedOut = modelInShared;
175 return **maybeModelInSharedOut;
176 }
177
178 template <>
flush() const179 void InputRelocationTracker::flush() const {
180 // Copy from pointers to shared memory.
181 uint8_t* memoryPtr = static_cast<uint8_t*>(std::get<void*>(kMapping.pointer));
182 for (const auto& [data, length, offset] : kRelocationInfos) {
183 std::memcpy(memoryPtr + offset, data, length);
184 }
185 }
186
187 template <>
flush() const188 void OutputRelocationTracker::flush() const {
189 // Copy from shared memory to pointers.
190 const uint8_t* memoryPtr = static_cast<const uint8_t*>(
191 std::visit([](auto ptr) { return static_cast<const void*>(ptr); }, kMapping.pointer));
192 for (const auto& [data, length, offset] : kRelocationInfos) {
193 std::memcpy(data, memoryPtr + offset, length);
194 }
195 }
196
convertRequestFromPointerToShared(const Request * request,uint32_t alignment,uint32_t padding,std::optional<Request> * maybeRequestInSharedOut,RequestRelocation * relocationOut)197 GeneralResult<std::reference_wrapper<const Request>> convertRequestFromPointerToShared(
198 const Request* request, uint32_t alignment, uint32_t padding,
199 std::optional<Request>* maybeRequestInSharedOut, RequestRelocation* relocationOut) {
200 CHECK(request != nullptr);
201 CHECK(maybeRequestInSharedOut != nullptr);
202 CHECK(relocationOut != nullptr);
203
204 if (hasNoPointerData(*request)) {
205 return *request;
206 }
207
208 // Make a copy of the request in order to make modifications. The modified request is returned
209 // to the caller through `maybeRequestInSharedOut` if the function succeeds.
210 Request requestInShared = *request;
211
212 RequestRelocation relocation;
213
214 // Change input pointers to shared memory.
215 MutableMemoryBuilder inputBuilder(requestInShared.pools.size());
216 std::vector<InputRelocationInfo> inputRelocationInfos;
217 for (auto& input : requestInShared.inputs) {
218 const auto& location = input.location;
219 if (input.lifetime != Request::Argument::LifeTime::POINTER) {
220 continue;
221 }
222
223 input.lifetime = Request::Argument::LifeTime::POOL;
224 const void* data = std::visit([](auto ptr) { return static_cast<const void*>(ptr); },
225 location.pointer);
226 CHECK(data != nullptr);
227 input.location = inputBuilder.append(location.length, alignment, padding);
228 inputRelocationInfos.push_back({data, input.location.length, input.location.offset});
229 }
230
231 // Allocate input memory.
232 if (!inputBuilder.empty()) {
233 auto memory = NN_TRY(inputBuilder.finish());
234 requestInShared.pools.push_back(memory);
235 relocation.input = NN_TRY(
236 InputRelocationTracker::create(std::move(inputRelocationInfos), std::move(memory)));
237 }
238
239 // Change output pointers to shared memory.
240 MutableMemoryBuilder outputBuilder(requestInShared.pools.size());
241 std::vector<OutputRelocationInfo> outputRelocationInfos;
242 for (auto& output : requestInShared.outputs) {
243 const auto& location = output.location;
244 if (output.lifetime != Request::Argument::LifeTime::POINTER) {
245 continue;
246 }
247
248 output.lifetime = Request::Argument::LifeTime::POOL;
249 void* data = std::get<void*>(location.pointer);
250 CHECK(data != nullptr);
251 output.location = outputBuilder.append(location.length, alignment, padding);
252 outputRelocationInfos.push_back({data, output.location.length, output.location.offset});
253 }
254
255 // Allocate output memory.
256 if (!outputBuilder.empty()) {
257 auto memory = NN_TRY(outputBuilder.finish());
258 requestInShared.pools.push_back(memory);
259 relocation.output = NN_TRY(OutputRelocationTracker::create(std::move(outputRelocationInfos),
260 std::move(memory)));
261 }
262
263 *maybeRequestInSharedOut = requestInShared;
264 *relocationOut = std::move(relocation);
265 return **maybeRequestInSharedOut;
266 }
267
268 } // namespace android::nn
269