1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "SharedMemory.h"

#include <android-base/logging.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

#include "Result.h"
#include "TypeUtils.h"
#include "Types.h"
31 
32 namespace android::nn {
33 namespace {
34 
35 bool hasNoPointerData(const Operand& operand);
36 bool hasNoPointerData(const Model::Subgraph& subgraph);
37 bool hasNoPointerData(const Request::Argument& argument);
38 
// A vector is pointer-free exactly when every one of its elements is; an empty vector trivially
// qualifies. Stops at the first element that carries pointer data.
template <typename Type>
bool hasNoPointerData(const std::vector<Type>& objects) {
    for (const auto& object : objects) {
        if (!hasNoPointerData(object)) {
            return false;
        }
    }
    return true;
}
44 
hasNoPointerData(const DataLocation & location)45 bool hasNoPointerData(const DataLocation& location) {
46     return std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer);
47 }
48 
hasNoPointerData(const Operand & operand)49 bool hasNoPointerData(const Operand& operand) {
50     return hasNoPointerData(operand.location);
51 }
52 
hasNoPointerData(const Model::Subgraph & subgraph)53 bool hasNoPointerData(const Model::Subgraph& subgraph) {
54     return hasNoPointerData(subgraph.operands);
55 }
56 
hasNoPointerData(const Request::Argument & argument)57 bool hasNoPointerData(const Request::Argument& argument) {
58     return hasNoPointerData(argument.location);
59 }
60 
copyPointersToSharedMemory(Operand * operand,ConstantMemoryBuilder * memoryBuilder)61 void copyPointersToSharedMemory(Operand* operand, ConstantMemoryBuilder* memoryBuilder) {
62     CHECK(operand != nullptr);
63     CHECK(memoryBuilder != nullptr);
64 
65     if (operand->lifetime != Operand::LifeTime::POINTER) {
66         return;
67     }
68 
69     const void* data = std::visit([](auto ptr) { return static_cast<const void*>(ptr); },
70                                   operand->location.pointer);
71     CHECK(data != nullptr);
72     operand->lifetime = Operand::LifeTime::CONSTANT_REFERENCE;
73     operand->location = memoryBuilder->append(data, operand->location.length);
74 }
75 
copyPointersToSharedMemory(Model::Subgraph * subgraph,ConstantMemoryBuilder * memoryBuilder)76 void copyPointersToSharedMemory(Model::Subgraph* subgraph, ConstantMemoryBuilder* memoryBuilder) {
77     CHECK(subgraph != nullptr);
78     std::for_each(subgraph->operands.begin(), subgraph->operands.end(),
79                   [memoryBuilder](auto& operand) {
80                       copyPointersToSharedMemory(&operand, memoryBuilder);
81                   });
82 }
83 
84 }  // anonymous namespace
85 
MutableMemoryBuilder(uint32_t poolIndex)86 MutableMemoryBuilder::MutableMemoryBuilder(uint32_t poolIndex) : mPoolIndex(poolIndex) {}
87 
append(size_t length,size_t alignment,size_t padding)88 DataLocation MutableMemoryBuilder::append(size_t length, size_t alignment, size_t padding) {
89     CHECK_GT(length, 0u);
90     mSize = roundUp(mSize, alignment);
91     const size_t offset = mSize;
92     const size_t paddedLength = roundUp(length, padding);
93     CHECK_LE(offset, std::numeric_limits<uint32_t>::max());
94     CHECK_LE(paddedLength, std::numeric_limits<uint32_t>::max());
95     mSize += paddedLength;
96     return {.poolIndex = mPoolIndex,
97             .offset = static_cast<uint32_t>(offset),
98             .length = static_cast<uint32_t>(length),
99             .padding = static_cast<uint32_t>(paddedLength - length)};
100 }
101 
empty() const102 bool MutableMemoryBuilder::empty() const {
103     return mSize == 0;
104 }
105 
finish()106 GeneralResult<SharedMemory> MutableMemoryBuilder::finish() {
107     return createSharedMemory(mSize);
108 }
109 
ConstantMemoryBuilder(uint32_t poolIndex)110 ConstantMemoryBuilder::ConstantMemoryBuilder(uint32_t poolIndex) : mBuilder(poolIndex) {}
111 
append(const void * data,size_t length)112 DataLocation ConstantMemoryBuilder::append(const void* data, size_t length) {
113     const auto location = mBuilder.append(length);
114     CHECK_EQ(location.length, length);
115     mSlices.push_back({.data = data, .length = length, .offset = location.offset});
116     return location;
117 }
118 
empty() const119 bool ConstantMemoryBuilder::empty() const {
120     return mBuilder.empty();
121 }
122 
finish()123 GeneralResult<SharedMemory> ConstantMemoryBuilder::finish() {
124     // Allocate the memory.
125     auto memory = NN_TRY(mBuilder.finish());
126 
127     // Map the memory.
128     const auto [pointer, size, context] = NN_TRY(map(memory););
129 
130     // Get mutable pointer.
131     uint8_t* mutablePointer = static_cast<uint8_t*>(std::get<void*>(pointer));
132 
133     // Copy data to the memory pool.
134     std::for_each(mSlices.begin(), mSlices.end(), [mutablePointer](const auto& slice) {
135         std::memcpy(mutablePointer + slice.offset, slice.data, slice.length);
136     });
137 
138     return memory;
139 }
140 
hasNoPointerData(const Model & model)141 bool hasNoPointerData(const Model& model) {
142     return hasNoPointerData(model.main) && hasNoPointerData(model.referenced);
143 }
144 
hasNoPointerData(const Request & request)145 bool hasNoPointerData(const Request& request) {
146     return hasNoPointerData(request.inputs) && hasNoPointerData(request.outputs);
147 }
148 
flushDataFromPointerToShared(const Model * model,std::optional<Model> * maybeModelInSharedOut)149 GeneralResult<std::reference_wrapper<const Model>> flushDataFromPointerToShared(
150         const Model* model, std::optional<Model>* maybeModelInSharedOut) {
151     CHECK(model != nullptr);
152     CHECK(maybeModelInSharedOut != nullptr);
153 
154     if (hasNoPointerData(*model)) {
155         return *model;
156     }
157 
158     // Make a copy of the model in order to make modifications. The modified model is returned to
159     // the caller through `maybeModelInSharedOut` if the function succeeds.
160     Model modelInShared = *model;
161 
162     ConstantMemoryBuilder memoryBuilder(modelInShared.pools.size());
163     copyPointersToSharedMemory(&modelInShared.main, &memoryBuilder);
164     std::for_each(modelInShared.referenced.begin(), modelInShared.referenced.end(),
165                   [&memoryBuilder](auto& subgraph) {
166                       copyPointersToSharedMemory(&subgraph, &memoryBuilder);
167                   });
168 
169     if (!memoryBuilder.empty()) {
170         auto memory = NN_TRY(memoryBuilder.finish());
171         modelInShared.pools.push_back(std::move(memory));
172     }
173 
174     *maybeModelInSharedOut = modelInShared;
175     return **maybeModelInSharedOut;
176 }
177 
178 template <>
flush() const179 void InputRelocationTracker::flush() const {
180     // Copy from pointers to shared memory.
181     uint8_t* memoryPtr = static_cast<uint8_t*>(std::get<void*>(kMapping.pointer));
182     for (const auto& [data, length, offset] : kRelocationInfos) {
183         std::memcpy(memoryPtr + offset, data, length);
184     }
185 }
186 
187 template <>
flush() const188 void OutputRelocationTracker::flush() const {
189     // Copy from shared memory to pointers.
190     const uint8_t* memoryPtr = static_cast<const uint8_t*>(
191             std::visit([](auto ptr) { return static_cast<const void*>(ptr); }, kMapping.pointer));
192     for (const auto& [data, length, offset] : kRelocationInfos) {
193         std::memcpy(data, memoryPtr + offset, length);
194     }
195 }
196 
convertRequestFromPointerToShared(const Request * request,uint32_t alignment,uint32_t padding,std::optional<Request> * maybeRequestInSharedOut,RequestRelocation * relocationOut)197 GeneralResult<std::reference_wrapper<const Request>> convertRequestFromPointerToShared(
198         const Request* request, uint32_t alignment, uint32_t padding,
199         std::optional<Request>* maybeRequestInSharedOut, RequestRelocation* relocationOut) {
200     CHECK(request != nullptr);
201     CHECK(maybeRequestInSharedOut != nullptr);
202     CHECK(relocationOut != nullptr);
203 
204     if (hasNoPointerData(*request)) {
205         return *request;
206     }
207 
208     // Make a copy of the request in order to make modifications. The modified request is returned
209     // to the caller through `maybeRequestInSharedOut` if the function succeeds.
210     Request requestInShared = *request;
211 
212     RequestRelocation relocation;
213 
214     // Change input pointers to shared memory.
215     MutableMemoryBuilder inputBuilder(requestInShared.pools.size());
216     std::vector<InputRelocationInfo> inputRelocationInfos;
217     for (auto& input : requestInShared.inputs) {
218         const auto& location = input.location;
219         if (input.lifetime != Request::Argument::LifeTime::POINTER) {
220             continue;
221         }
222 
223         input.lifetime = Request::Argument::LifeTime::POOL;
224         const void* data = std::visit([](auto ptr) { return static_cast<const void*>(ptr); },
225                                       location.pointer);
226         CHECK(data != nullptr);
227         input.location = inputBuilder.append(location.length, alignment, padding);
228         inputRelocationInfos.push_back({data, input.location.length, input.location.offset});
229     }
230 
231     // Allocate input memory.
232     if (!inputBuilder.empty()) {
233         auto memory = NN_TRY(inputBuilder.finish());
234         requestInShared.pools.push_back(memory);
235         relocation.input = NN_TRY(
236                 InputRelocationTracker::create(std::move(inputRelocationInfos), std::move(memory)));
237     }
238 
239     // Change output pointers to shared memory.
240     MutableMemoryBuilder outputBuilder(requestInShared.pools.size());
241     std::vector<OutputRelocationInfo> outputRelocationInfos;
242     for (auto& output : requestInShared.outputs) {
243         const auto& location = output.location;
244         if (output.lifetime != Request::Argument::LifeTime::POINTER) {
245             continue;
246         }
247 
248         output.lifetime = Request::Argument::LifeTime::POOL;
249         void* data = std::get<void*>(location.pointer);
250         CHECK(data != nullptr);
251         output.location = outputBuilder.append(location.length, alignment, padding);
252         outputRelocationInfos.push_back({data, output.location.length, output.location.offset});
253     }
254 
255     // Allocate output memory.
256     if (!outputBuilder.empty()) {
257         auto memory = NN_TRY(outputBuilder.finish());
258         requestInShared.pools.push_back(memory);
259         relocation.output = NN_TRY(OutputRelocationTracker::create(std::move(outputRelocationInfos),
260                                                                    std::move(memory)));
261     }
262 
263     *maybeRequestInSharedOut = requestInShared;
264     *relocationOut = std::move(relocation);
265     return **maybeRequestInSharedOut;
266 }
267 
268 }  // namespace android::nn
269