1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ExecutionBurstController"
18
19 #include "ExecutionBurstController.h"
20
21 #include <android-base/logging.h>
22
23 #include <algorithm>
24 #include <cstring>
25 #include <functional>
26 #include <limits>
27 #include <memory>
28 #include <string>
29 #include <thread>
30 #include <tuple>
31 #include <utility>
32 #include <vector>
33
34 #include "HalInterfaces.h"
35 #include "Tracing.h"
36 #include "Utils.h"
37
38 namespace android::nn {
39 namespace {
40
41 using V1_2::FmqRequestDatum;
42 using V1_2::FmqResultDatum;
43 using V1_2::IBurstCallback;
44 using V1_2::IBurstContext;
45 using FmqRequestDescriptor = hardware::MQDescriptorSync<FmqRequestDatum>;
46 using FmqResultDescriptor = hardware::MQDescriptorSync<FmqResultDatum>;
47
48 constexpr V1_2::Timing kNoTiming12 = {std::numeric_limits<uint64_t>::max(),
49 std::numeric_limits<uint64_t>::max()};
50
51 class BurstContextDeathHandler : public hardware::hidl_death_recipient {
52 public:
53 using Callback = std::function<void()>;
54
BurstContextDeathHandler(const Callback & onDeathCallback)55 BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
56 CHECK(onDeathCallback != nullptr);
57 }
58
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)59 void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
60 LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
61 mOnDeathCallback();
62 }
63
64 private:
65 const Callback mOnDeathCallback;
66 };
67
68 } // anonymous namespace
69
70 // serialize a request into a packet
serialize(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<int32_t> & slots)71 std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, V1_2::MeasureTiming measure,
72 const std::vector<int32_t>& slots) {
73 // count how many elements need to be sent for a request
74 size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
75 for (const auto& input : request.inputs) {
76 count += input.dimensions.size();
77 }
78 for (const auto& output : request.outputs) {
79 count += output.dimensions.size();
80 }
81
82 // create buffer to temporarily store elements
83 std::vector<FmqRequestDatum> data;
84 data.reserve(count);
85
86 // package packetInfo
87 {
88 FmqRequestDatum datum;
89 datum.packetInformation(
90 {/*.packetSize=*/static_cast<uint32_t>(count),
91 /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
92 /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
93 /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
94 data.push_back(datum);
95 }
96
97 // package input data
98 for (const auto& input : request.inputs) {
99 // package operand information
100 FmqRequestDatum datum;
101 datum.inputOperandInformation(
102 {/*.hasNoValue=*/input.hasNoValue,
103 /*.location=*/input.location,
104 /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
105 data.push_back(datum);
106
107 // package operand dimensions
108 for (uint32_t dimension : input.dimensions) {
109 FmqRequestDatum datum;
110 datum.inputOperandDimensionValue(dimension);
111 data.push_back(datum);
112 }
113 }
114
115 // package output data
116 for (const auto& output : request.outputs) {
117 // package operand information
118 FmqRequestDatum datum;
119 datum.outputOperandInformation(
120 {/*.hasNoValue=*/output.hasNoValue,
121 /*.location=*/output.location,
122 /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
123 data.push_back(datum);
124
125 // package operand dimensions
126 for (uint32_t dimension : output.dimensions) {
127 FmqRequestDatum datum;
128 datum.outputOperandDimensionValue(dimension);
129 data.push_back(datum);
130 }
131 }
132
133 // package pool identifier
134 for (int32_t slot : slots) {
135 FmqRequestDatum datum;
136 datum.poolIdentifier(slot);
137 data.push_back(datum);
138 }
139
140 // package measureTiming
141 {
142 FmqRequestDatum datum;
143 datum.measureTiming(measure);
144 data.push_back(datum);
145 }
146
147 // return packet
148 return data;
149 }
150
151 // deserialize a packet into the result
152 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
deserialize(const std::vector<FmqResultDatum> & data)153 deserialize(const std::vector<FmqResultDatum>& data) {
154 using discriminator = FmqResultDatum::hidl_discriminator;
155
156 std::vector<V1_2::OutputShape> outputShapes;
157 size_t index = 0;
158
159 // validate packet information
160 if (index >= data.size() ||
161 data.at(index).getDiscriminator() != discriminator::packetInformation) {
162 LOG(ERROR) << "FMQ Result packet ill-formed";
163 return std::nullopt;
164 }
165
166 // unpackage packet information
167 const FmqResultDatum::PacketInformation& packetInfo = data.at(index).packetInformation();
168 index++;
169 const uint32_t packetSize = packetInfo.packetSize;
170 const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
171 const uint32_t numberOfOperands = packetInfo.numberOfOperands;
172
173 // verify packet size
174 if (data.size() != packetSize) {
175 LOG(ERROR) << "FMQ Result packet ill-formed";
176 return std::nullopt;
177 }
178
179 // unpackage operands
180 for (size_t operand = 0; operand < numberOfOperands; ++operand) {
181 // validate operand information
182 if (index >= data.size() ||
183 data.at(index).getDiscriminator() != discriminator::operandInformation) {
184 LOG(ERROR) << "FMQ Result packet ill-formed";
185 return std::nullopt;
186 }
187
188 // unpackage operand information
189 const FmqResultDatum::OperandInformation& operandInfo = data.at(index).operandInformation();
190 index++;
191 const bool isSufficient = operandInfo.isSufficient;
192 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
193
194 // unpackage operand dimensions
195 std::vector<uint32_t> dimensions;
196 dimensions.reserve(numberOfDimensions);
197 for (size_t i = 0; i < numberOfDimensions; ++i) {
198 // validate dimension
199 if (index >= data.size() ||
200 data.at(index).getDiscriminator() != discriminator::operandDimensionValue) {
201 LOG(ERROR) << "FMQ Result packet ill-formed";
202 return std::nullopt;
203 }
204
205 // unpackage dimension
206 const uint32_t dimension = data.at(index).operandDimensionValue();
207 index++;
208
209 // store result
210 dimensions.push_back(dimension);
211 }
212
213 // store result
214 outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
215 }
216
217 // validate execution timing
218 if (index >= data.size() ||
219 data.at(index).getDiscriminator() != discriminator::executionTiming) {
220 LOG(ERROR) << "FMQ Result packet ill-formed";
221 return std::nullopt;
222 }
223
224 // unpackage execution timing
225 const V1_2::Timing timing = data.at(index).executionTiming();
226 index++;
227
228 // validate packet information
229 if (index != packetSize) {
230 LOG(ERROR) << "FMQ Result packet ill-formed";
231 return std::nullopt;
232 }
233
234 // return result
235 return std::make_tuple(errorStatus, std::move(outputShapes), timing);
236 }
237
legacyConvertResultCodeToErrorStatus(int resultCode)238 V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode) {
239 return convertToV1_0(convertResultCodeToErrorStatus(resultCode));
240 }
241
242 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
create(size_t channelLength,std::chrono::microseconds pollingTimeWindow)243 ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) {
244 std::unique_ptr<FmqResultChannel> fmqResultChannel =
245 std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/true);
246 if (!fmqResultChannel->isValid()) {
247 LOG(ERROR) << "Unable to create ResultChannelReceiver";
248 return {nullptr, nullptr};
249 }
250
251 const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
252 return std::make_pair(
253 std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), pollingTimeWindow),
254 descriptor);
255 }
256
ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,std::chrono::microseconds pollingTimeWindow)257 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
258 std::chrono::microseconds pollingTimeWindow)
259 : mFmqResultChannel(std::move(fmqResultChannel)), kPollingTimeWindow(pollingTimeWindow) {}
260
261 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
getBlocking()262 ResultChannelReceiver::getBlocking() {
263 const auto packet = getPacketBlocking();
264 if (!packet) {
265 return std::nullopt;
266 }
267
268 return deserialize(*packet);
269 }
270
invalidate()271 void ResultChannelReceiver::invalidate() {
272 mValid = false;
273
274 // force unblock
275 // ExecutionBurstController waits on a result packet after sending a
276 // request. If the driver containing ExecutionBurstServer crashes, the
277 // controller may be waiting on the futex. This force unblock wakes up any
278 // thread waiting on the futex.
279 // TODO: look for a different/better way to signal/notify the futex to
280 // wake up any thread waiting on it
281 FmqResultDatum datum;
282 datum.packetInformation({/*.packetSize=*/0,
283 /*.errorStatus=*/V1_0::ErrorStatus::GENERAL_FAILURE,
284 /*.numberOfOperands=*/0});
285 mFmqResultChannel->writeBlocking(&datum, 1);
286 }
287
getPacketBlocking()288 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
289 if (!mValid) {
290 return std::nullopt;
291 }
292
293 // First spend time polling if results are available in FMQ instead of
294 // waiting on the futex. Polling is more responsive (yielding lower
295 // latencies), but can take up more power, so only poll for a limited period
296 // of time.
297
298 auto& getCurrentTime = std::chrono::high_resolution_clock::now;
299 const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
300
301 while (getCurrentTime() < timeToStopPolling) {
302 // if class is being torn down, immediately return
303 if (!mValid.load(std::memory_order_relaxed)) {
304 return std::nullopt;
305 }
306
307 // Check if data is available. If it is, immediately retrieve it and
308 // return.
309 const size_t available = mFmqResultChannel->availableToRead();
310 if (available > 0) {
311 std::vector<FmqResultDatum> packet(available);
312 const bool success = mFmqResultChannel->read(packet.data(), available);
313 if (!success) {
314 LOG(ERROR) << "Error receiving packet";
315 return std::nullopt;
316 }
317 return std::make_optional(std::move(packet));
318 }
319
320 std::this_thread::yield();
321 }
322
323 // If we get to this point, we either stopped polling because it was taking
324 // too long or polling was not allowed. Instead, perform a blocking call
325 // which uses a futex to save power.
326
327 // wait for result packet and read first element of result packet
328 FmqResultDatum datum;
329 bool success = mFmqResultChannel->readBlocking(&datum, 1);
330
331 // retrieve remaining elements
332 // NOTE: all of the data is already available at this point, so there's no
333 // need to do a blocking wait to wait for more data. This is known because
334 // in FMQ, all writes are published (made available) atomically. Currently,
335 // the producer always publishes the entire packet in one function call, so
336 // if the first element of the packet is available, the remaining elements
337 // are also available.
338 const size_t count = mFmqResultChannel->availableToRead();
339 std::vector<FmqResultDatum> packet(count + 1);
340 std::memcpy(&packet.front(), &datum, sizeof(datum));
341 success &= mFmqResultChannel->read(packet.data() + 1, count);
342
343 if (!mValid) {
344 return std::nullopt;
345 }
346
347 // ensure packet was successfully received
348 if (!success) {
349 LOG(ERROR) << "Error receiving packet";
350 return std::nullopt;
351 }
352
353 return std::make_optional(std::move(packet));
354 }
355
356 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
create(size_t channelLength)357 RequestChannelSender::create(size_t channelLength) {
358 std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
359 std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/true);
360 if (!fmqRequestChannel->isValid()) {
361 LOG(ERROR) << "Unable to create RequestChannelSender";
362 return {nullptr, nullptr};
363 }
364
365 const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
366 return std::make_pair(std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel)),
367 descriptor);
368 }
369
RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)370 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)
371 : mFmqRequestChannel(std::move(fmqRequestChannel)) {}
372
send(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<int32_t> & slots)373 bool RequestChannelSender::send(const V1_0::Request& request, V1_2::MeasureTiming measure,
374 const std::vector<int32_t>& slots) {
375 const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
376 return sendPacket(serialized);
377 }
378
sendPacket(const std::vector<FmqRequestDatum> & packet)379 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
380 if (!mValid) {
381 return false;
382 }
383
384 if (packet.size() > mFmqRequestChannel->availableToWrite()) {
385 LOG(ERROR)
386 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
387 return false;
388 }
389
390 // Always send the packet with "blocking" because this signals the futex and
391 // unblocks the consumer if it is waiting on the futex.
392 return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
393 }
394
invalidate()395 void RequestChannelSender::invalidate() {
396 mValid = false;
397 }
398
// IBurstCallback::getMemories: looks up the cached memory for each requested
// slot. Reports INVALID_ARGUMENT (with no memories) if any slot is out of range
// or maps to an invalid memory. Otherwise reports NONE with the memories in
// request order.
hardware::Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
        const hardware::hidl_vec<int32_t>& slots, getMemories_cb cb) {
    std::lock_guard<std::mutex> guard(mMutex);

    // Out-of-range slots keep the default-constructed (invalid) hidl_memory.
    hardware::hidl_vec<hardware::hidl_memory> memories(slots.size());
    for (size_t i = 0; i < slots.size(); ++i) {
        const int32_t slot = slots[i];
        const bool knownSlot = slot >= 0 && static_cast<size_t>(slot) < mMemoryCache.size();
        if (knownSlot) {
            memories[i] = mMemoryCache[slot];
        }
    }

    const bool anyInvalid =
            std::any_of(memories.begin(), memories.end(),
                        [](const hardware::hidl_memory& memory) { return !memory.valid(); });
    if (anyInvalid) {
        cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
        return hardware::Void();
    }

    cb(V1_0::ErrorStatus::NONE, std::move(memories));
    return hardware::Void();
}
423
getSlots(const hardware::hidl_vec<hardware::hidl_memory> & memories,const std::vector<intptr_t> & keys)424 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
425 const hardware::hidl_vec<hardware::hidl_memory>& memories,
426 const std::vector<intptr_t>& keys) {
427 std::lock_guard<std::mutex> guard(mMutex);
428
429 // retrieve (or bind) all slots corresponding to memories
430 std::vector<int32_t> slots;
431 slots.reserve(memories.size());
432 for (size_t i = 0; i < memories.size(); ++i) {
433 slots.push_back(getSlotLocked(memories[i], keys[i]));
434 }
435 return slots;
436 }
437
freeMemory(intptr_t key)438 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
439 intptr_t key) {
440 std::lock_guard<std::mutex> guard(mMutex);
441
442 auto iter = mMemoryIdToSlot.find(key);
443 if (iter == mMemoryIdToSlot.end()) {
444 return {false, 0};
445 }
446 const int32_t slot = iter->second;
447 mMemoryIdToSlot.erase(key);
448 mMemoryCache[slot] = {};
449 mFreeSlots.push(slot);
450 return {true, slot};
451 }
452
getSlotLocked(const hardware::hidl_memory & memory,intptr_t key)453 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(
454 const hardware::hidl_memory& memory, intptr_t key) {
455 auto iter = mMemoryIdToSlot.find(key);
456 if (iter == mMemoryIdToSlot.end()) {
457 const int32_t slot = allocateSlotLocked();
458 mMemoryIdToSlot[key] = slot;
459 mMemoryCache[slot] = memory;
460 return slot;
461 } else {
462 const int32_t slot = iter->second;
463 return slot;
464 }
465 }
466
allocateSlotLocked()467 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
468 constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
469
470 // if there is a free slot, use it
471 if (mFreeSlots.size() > 0) {
472 const int32_t slot = mFreeSlots.top();
473 mFreeSlots.pop();
474 return slot;
475 }
476
477 // otherwise use a slot for the first time
478 CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
479 const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
480 mMemoryCache.emplace_back();
481
482 return slot;
483 }
484
create(const sp<V1_2::IPreparedModel> & preparedModel,std::chrono::microseconds pollingTimeWindow)485 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
486 const sp<V1_2::IPreparedModel>& preparedModel,
487 std::chrono::microseconds pollingTimeWindow) {
488 // check inputs
489 if (preparedModel == nullptr) {
490 LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
491 return nullptr;
492 }
493
494 // create callback object
495 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
496
497 // create FMQ objects
498 auto [requestChannelSenderTemp, requestChannelDescriptor] =
499 RequestChannelSender::create(kExecutionBurstChannelLength);
500 auto [resultChannelReceiverTemp, resultChannelDescriptor] =
501 ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow);
502 std::shared_ptr<RequestChannelSender> requestChannelSender =
503 std::move(requestChannelSenderTemp);
504 std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
505 std::move(resultChannelReceiverTemp);
506
507 // check FMQ objects
508 if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
509 !resultChannelDescriptor) {
510 LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
511 return nullptr;
512 }
513
514 // configure burst
515 V1_0::ErrorStatus errorStatus;
516 sp<IBurstContext> burstContext;
517 const hardware::Return<void> ret = preparedModel->configureExecutionBurst(
518 callback, *requestChannelDescriptor, *resultChannelDescriptor,
519 [&errorStatus, &burstContext](V1_0::ErrorStatus status,
520 const sp<IBurstContext>& context) {
521 errorStatus = status;
522 burstContext = context;
523 });
524
525 // check burst
526 if (!ret.isOk()) {
527 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
528 << ret.description();
529 return nullptr;
530 }
531 if (errorStatus != V1_0::ErrorStatus::NONE) {
532 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
533 << toString(errorStatus);
534 return nullptr;
535 }
536 if (burstContext == nullptr) {
537 LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
538 return nullptr;
539 }
540
541 // create death handler object
542 BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
543 resultChannelReceiver] {
544 requestChannelSender->invalidate();
545 resultChannelReceiver->invalidate();
546 };
547 const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
548
549 // linkToDeath registers a callback that will be invoked on service death to
550 // proactively handle service crashes. If the linkToDeath call fails,
551 // asynchronous calls are susceptible to hangs if the service crashes before
552 // providing the response.
553 const hardware::Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
554 if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
555 LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
556 "for the IBurstContext object.";
557 return nullptr;
558 }
559
560 // make and return controller
561 return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
562 burstContext, callback, deathHandler);
563 }
564
ExecutionBurstController(const std::shared_ptr<RequestChannelSender> & requestChannelSender,const std::shared_ptr<ResultChannelReceiver> & resultChannelReceiver,const sp<IBurstContext> & burstContext,const sp<ExecutionBurstCallback> & callback,const sp<hardware::hidl_death_recipient> & deathHandler)565 ExecutionBurstController::ExecutionBurstController(
566 const std::shared_ptr<RequestChannelSender>& requestChannelSender,
567 const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
568 const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
569 const sp<hardware::hidl_death_recipient>& deathHandler)
570 : mRequestChannelSender(requestChannelSender),
571 mResultChannelReceiver(resultChannelReceiver),
572 mBurstContext(burstContext),
573 mMemoryCache(callback),
574 mDeathHandler(deathHandler) {}
575
~ExecutionBurstController()576 ExecutionBurstController::~ExecutionBurstController() {
577 // It is safe to ignore any errors resulting from this unlinkToDeath call
578 // because the ExecutionBurstController object is already being destroyed
579 // and its underlying IBurstContext object is no longer being used by the NN
580 // runtime.
581 if (mDeathHandler) {
582 mBurstContext->unlinkToDeath(mDeathHandler).isOk();
583 }
584 }
585
getExecutionResult(V1_0::ErrorStatus status,std::vector<V1_2::OutputShape> outputShapes,V1_2::Timing timing,bool fallback)586 static std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool> getExecutionResult(
587 V1_0::ErrorStatus status, std::vector<V1_2::OutputShape> outputShapes, V1_2::Timing timing,
588 bool fallback) {
589 auto [n, checkedOutputShapes, checkedTiming] =
590 getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing);
591 return {n, convertToV1_2(checkedOutputShapes), convertToV1_2(checkedTiming), fallback};
592 }
593
594 std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool>
compute(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<intptr_t> & memoryIds)595 ExecutionBurstController::compute(const V1_0::Request& request, V1_2::MeasureTiming measure,
596 const std::vector<intptr_t>& memoryIds) {
597 // This is the first point when we know an execution is occurring, so begin
598 // to collect systraces. Note that the first point we can begin collecting
599 // systraces in ExecutionBurstServer is when the RequestChannelReceiver
600 // realizes there is data in the FMQ, so ExecutionBurstServer collects
601 // systraces at different points in the code.
602 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
603
604 std::lock_guard<std::mutex> guard(mMutex);
605
606 // send request packet
607 const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
608 const bool success = mRequestChannelSender->send(request, measure, slots);
609 if (!success) {
610 LOG(ERROR) << "Error sending FMQ packet";
611 // only use fallback execution path if the packet could not be sent
612 return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
613 /*fallback=*/true);
614 }
615
616 // get result packet
617 const auto result = mResultChannelReceiver->getBlocking();
618 if (!result) {
619 LOG(ERROR) << "Error retrieving FMQ packet";
620 // only use fallback execution path if the packet could not be sent
621 return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
622 /*fallback=*/false);
623 }
624
625 // unpack results and return (only use fallback execution path if the
626 // packet could not be sent)
627 auto [status, outputShapes, timing] = std::move(*result);
628 return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false);
629 }
630
freeMemory(intptr_t key)631 void ExecutionBurstController::freeMemory(intptr_t key) {
632 std::lock_guard<std::mutex> guard(mMutex);
633
634 bool valid;
635 int32_t slot;
636 std::tie(valid, slot) = mMemoryCache->freeMemory(key);
637 if (valid) {
638 mBurstContext->freeMemory(slot).isOk();
639 }
640 }
641
642 } // namespace android::nn
643