1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ExecutionBurstController"
18 
19 #include "ExecutionBurstController.h"
20 
21 #include <android-base/logging.h>
22 
23 #include <algorithm>
24 #include <cstring>
25 #include <functional>
26 #include <limits>
27 #include <memory>
28 #include <string>
29 #include <thread>
30 #include <tuple>
31 #include <utility>
32 #include <vector>
33 
34 #include "HalInterfaces.h"
35 #include "Tracing.h"
36 #include "Utils.h"
37 
38 namespace android::nn {
39 namespace {
40 
41 using V1_2::FmqRequestDatum;
42 using V1_2::FmqResultDatum;
43 using V1_2::IBurstCallback;
44 using V1_2::IBurstContext;
45 using FmqRequestDescriptor = hardware::MQDescriptorSync<FmqRequestDatum>;
46 using FmqResultDescriptor = hardware::MQDescriptorSync<FmqResultDatum>;
47 
48 constexpr V1_2::Timing kNoTiming12 = {std::numeric_limits<uint64_t>::max(),
49                                       std::numeric_limits<uint64_t>::max()};
50 
51 class BurstContextDeathHandler : public hardware::hidl_death_recipient {
52    public:
53     using Callback = std::function<void()>;
54 
BurstContextDeathHandler(const Callback & onDeathCallback)55     BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
56         CHECK(onDeathCallback != nullptr);
57     }
58 
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)59     void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
60         LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
61         mOnDeathCallback();
62     }
63 
64    private:
65     const Callback mOnDeathCallback;
66 };
67 
68 }  // anonymous namespace
69 
70 // serialize a request into a packet
serialize(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<int32_t> & slots)71 std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, V1_2::MeasureTiming measure,
72                                        const std::vector<int32_t>& slots) {
73     // count how many elements need to be sent for a request
74     size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
75     for (const auto& input : request.inputs) {
76         count += input.dimensions.size();
77     }
78     for (const auto& output : request.outputs) {
79         count += output.dimensions.size();
80     }
81 
82     // create buffer to temporarily store elements
83     std::vector<FmqRequestDatum> data;
84     data.reserve(count);
85 
86     // package packetInfo
87     {
88         FmqRequestDatum datum;
89         datum.packetInformation(
90                 {/*.packetSize=*/static_cast<uint32_t>(count),
91                  /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
92                  /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
93                  /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
94         data.push_back(datum);
95     }
96 
97     // package input data
98     for (const auto& input : request.inputs) {
99         // package operand information
100         FmqRequestDatum datum;
101         datum.inputOperandInformation(
102                 {/*.hasNoValue=*/input.hasNoValue,
103                  /*.location=*/input.location,
104                  /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
105         data.push_back(datum);
106 
107         // package operand dimensions
108         for (uint32_t dimension : input.dimensions) {
109             FmqRequestDatum datum;
110             datum.inputOperandDimensionValue(dimension);
111             data.push_back(datum);
112         }
113     }
114 
115     // package output data
116     for (const auto& output : request.outputs) {
117         // package operand information
118         FmqRequestDatum datum;
119         datum.outputOperandInformation(
120                 {/*.hasNoValue=*/output.hasNoValue,
121                  /*.location=*/output.location,
122                  /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
123         data.push_back(datum);
124 
125         // package operand dimensions
126         for (uint32_t dimension : output.dimensions) {
127             FmqRequestDatum datum;
128             datum.outputOperandDimensionValue(dimension);
129             data.push_back(datum);
130         }
131     }
132 
133     // package pool identifier
134     for (int32_t slot : slots) {
135         FmqRequestDatum datum;
136         datum.poolIdentifier(slot);
137         data.push_back(datum);
138     }
139 
140     // package measureTiming
141     {
142         FmqRequestDatum datum;
143         datum.measureTiming(measure);
144         data.push_back(datum);
145     }
146 
147     // return packet
148     return data;
149 }
150 
151 // deserialize a packet into the result
152 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
deserialize(const std::vector<FmqResultDatum> & data)153 deserialize(const std::vector<FmqResultDatum>& data) {
154     using discriminator = FmqResultDatum::hidl_discriminator;
155 
156     std::vector<V1_2::OutputShape> outputShapes;
157     size_t index = 0;
158 
159     // validate packet information
160     if (index >= data.size() ||
161         data.at(index).getDiscriminator() != discriminator::packetInformation) {
162         LOG(ERROR) << "FMQ Result packet ill-formed";
163         return std::nullopt;
164     }
165 
166     // unpackage packet information
167     const FmqResultDatum::PacketInformation& packetInfo = data.at(index).packetInformation();
168     index++;
169     const uint32_t packetSize = packetInfo.packetSize;
170     const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
171     const uint32_t numberOfOperands = packetInfo.numberOfOperands;
172 
173     // verify packet size
174     if (data.size() != packetSize) {
175         LOG(ERROR) << "FMQ Result packet ill-formed";
176         return std::nullopt;
177     }
178 
179     // unpackage operands
180     for (size_t operand = 0; operand < numberOfOperands; ++operand) {
181         // validate operand information
182         if (index >= data.size() ||
183             data.at(index).getDiscriminator() != discriminator::operandInformation) {
184             LOG(ERROR) << "FMQ Result packet ill-formed";
185             return std::nullopt;
186         }
187 
188         // unpackage operand information
189         const FmqResultDatum::OperandInformation& operandInfo = data.at(index).operandInformation();
190         index++;
191         const bool isSufficient = operandInfo.isSufficient;
192         const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
193 
194         // unpackage operand dimensions
195         std::vector<uint32_t> dimensions;
196         dimensions.reserve(numberOfDimensions);
197         for (size_t i = 0; i < numberOfDimensions; ++i) {
198             // validate dimension
199             if (index >= data.size() ||
200                 data.at(index).getDiscriminator() != discriminator::operandDimensionValue) {
201                 LOG(ERROR) << "FMQ Result packet ill-formed";
202                 return std::nullopt;
203             }
204 
205             // unpackage dimension
206             const uint32_t dimension = data.at(index).operandDimensionValue();
207             index++;
208 
209             // store result
210             dimensions.push_back(dimension);
211         }
212 
213         // store result
214         outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
215     }
216 
217     // validate execution timing
218     if (index >= data.size() ||
219         data.at(index).getDiscriminator() != discriminator::executionTiming) {
220         LOG(ERROR) << "FMQ Result packet ill-formed";
221         return std::nullopt;
222     }
223 
224     // unpackage execution timing
225     const V1_2::Timing timing = data.at(index).executionTiming();
226     index++;
227 
228     // validate packet information
229     if (index != packetSize) {
230         LOG(ERROR) << "FMQ Result packet ill-formed";
231         return std::nullopt;
232     }
233 
234     // return result
235     return std::make_tuple(errorStatus, std::move(outputShapes), timing);
236 }
237 
legacyConvertResultCodeToErrorStatus(int resultCode)238 V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode) {
239     return convertToV1_0(convertResultCodeToErrorStatus(resultCode));
240 }
241 
242 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
create(size_t channelLength,std::chrono::microseconds pollingTimeWindow)243 ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) {
244     std::unique_ptr<FmqResultChannel> fmqResultChannel =
245             std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/true);
246     if (!fmqResultChannel->isValid()) {
247         LOG(ERROR) << "Unable to create ResultChannelReceiver";
248         return {nullptr, nullptr};
249     }
250 
251     const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
252     return std::make_pair(
253             std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), pollingTimeWindow),
254             descriptor);
255 }
256 
ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,std::chrono::microseconds pollingTimeWindow)257 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
258                                              std::chrono::microseconds pollingTimeWindow)
259     : mFmqResultChannel(std::move(fmqResultChannel)), kPollingTimeWindow(pollingTimeWindow) {}
260 
261 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
getBlocking()262 ResultChannelReceiver::getBlocking() {
263     const auto packet = getPacketBlocking();
264     if (!packet) {
265         return std::nullopt;
266     }
267 
268     return deserialize(*packet);
269 }
270 
invalidate()271 void ResultChannelReceiver::invalidate() {
272     mValid = false;
273 
274     // force unblock
275     // ExecutionBurstController waits on a result packet after sending a
276     // request. If the driver containing ExecutionBurstServer crashes, the
277     // controller may be waiting on the futex. This force unblock wakes up any
278     // thread waiting on the futex.
279     // TODO: look for a different/better way to signal/notify the futex to
280     // wake up any thread waiting on it
281     FmqResultDatum datum;
282     datum.packetInformation({/*.packetSize=*/0,
283                              /*.errorStatus=*/V1_0::ErrorStatus::GENERAL_FAILURE,
284                              /*.numberOfOperands=*/0});
285     mFmqResultChannel->writeBlocking(&datum, 1);
286 }
287 
getPacketBlocking()288 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
289     if (!mValid) {
290         return std::nullopt;
291     }
292 
293     // First spend time polling if results are available in FMQ instead of
294     // waiting on the futex. Polling is more responsive (yielding lower
295     // latencies), but can take up more power, so only poll for a limited period
296     // of time.
297 
298     auto& getCurrentTime = std::chrono::high_resolution_clock::now;
299     const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
300 
301     while (getCurrentTime() < timeToStopPolling) {
302         // if class is being torn down, immediately return
303         if (!mValid.load(std::memory_order_relaxed)) {
304             return std::nullopt;
305         }
306 
307         // Check if data is available. If it is, immediately retrieve it and
308         // return.
309         const size_t available = mFmqResultChannel->availableToRead();
310         if (available > 0) {
311             std::vector<FmqResultDatum> packet(available);
312             const bool success = mFmqResultChannel->read(packet.data(), available);
313             if (!success) {
314                 LOG(ERROR) << "Error receiving packet";
315                 return std::nullopt;
316             }
317             return std::make_optional(std::move(packet));
318         }
319 
320         std::this_thread::yield();
321     }
322 
323     // If we get to this point, we either stopped polling because it was taking
324     // too long or polling was not allowed. Instead, perform a blocking call
325     // which uses a futex to save power.
326 
327     // wait for result packet and read first element of result packet
328     FmqResultDatum datum;
329     bool success = mFmqResultChannel->readBlocking(&datum, 1);
330 
331     // retrieve remaining elements
332     // NOTE: all of the data is already available at this point, so there's no
333     // need to do a blocking wait to wait for more data. This is known because
334     // in FMQ, all writes are published (made available) atomically. Currently,
335     // the producer always publishes the entire packet in one function call, so
336     // if the first element of the packet is available, the remaining elements
337     // are also available.
338     const size_t count = mFmqResultChannel->availableToRead();
339     std::vector<FmqResultDatum> packet(count + 1);
340     std::memcpy(&packet.front(), &datum, sizeof(datum));
341     success &= mFmqResultChannel->read(packet.data() + 1, count);
342 
343     if (!mValid) {
344         return std::nullopt;
345     }
346 
347     // ensure packet was successfully received
348     if (!success) {
349         LOG(ERROR) << "Error receiving packet";
350         return std::nullopt;
351     }
352 
353     return std::make_optional(std::move(packet));
354 }
355 
356 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
create(size_t channelLength)357 RequestChannelSender::create(size_t channelLength) {
358     std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
359             std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/true);
360     if (!fmqRequestChannel->isValid()) {
361         LOG(ERROR) << "Unable to create RequestChannelSender";
362         return {nullptr, nullptr};
363     }
364 
365     const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
366     return std::make_pair(std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel)),
367                           descriptor);
368 }
369 
RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)370 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)
371     : mFmqRequestChannel(std::move(fmqRequestChannel)) {}
372 
send(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<int32_t> & slots)373 bool RequestChannelSender::send(const V1_0::Request& request, V1_2::MeasureTiming measure,
374                                 const std::vector<int32_t>& slots) {
375     const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
376     return sendPacket(serialized);
377 }
378 
sendPacket(const std::vector<FmqRequestDatum> & packet)379 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
380     if (!mValid) {
381         return false;
382     }
383 
384     if (packet.size() > mFmqRequestChannel->availableToWrite()) {
385         LOG(ERROR)
386                 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
387         return false;
388     }
389 
390     // Always send the packet with "blocking" because this signals the futex and
391     // unblocks the consumer if it is waiting on the futex.
392     return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
393 }
394 
invalidate()395 void RequestChannelSender::invalidate() {
396     mValid = false;
397 }
398 
// Looks up the memories cached under |slots| and reports them through |cb|.
//
// If every requested slot refers to a valid cached memory, |cb| is invoked
// with ErrorStatus::NONE and the memories in the same order as |slots|.
// Otherwise |cb| is invoked with ErrorStatus::INVALID_ARGUMENT and an empty
// list. The cache is accessed while holding mMutex.
hardware::Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
        const hardware::hidl_vec<int32_t>& slots, getMemories_cb cb) {
    std::lock_guard<std::mutex> guard(mMutex);

    // Fetch the memory of every requested slot. A slot outside of the cache
    // yields a default-constructed hidl_memory.
    hardware::hidl_vec<hardware::hidl_memory> memories(slots.size());
    for (size_t i = 0; i < slots.size(); ++i) {
        const int32_t slot = slots[i];
        const bool inRange = slot >= 0 && static_cast<size_t>(slot) < mMemoryCache.size();
        memories[i] = inRange ? mMemoryCache[slot] : hardware::hidl_memory{};
    }

    // A single invalid memory fails the whole request.
    const bool anyInvalid =
            std::any_of(memories.begin(), memories.end(),
                        [](const hardware::hidl_memory& memory) { return !memory.valid(); });
    if (anyInvalid) {
        cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
        return hardware::Void();
    }

    cb(V1_0::ErrorStatus::NONE, std::move(memories));
    return hardware::Void();
}
423 
getSlots(const hardware::hidl_vec<hardware::hidl_memory> & memories,const std::vector<intptr_t> & keys)424 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
425         const hardware::hidl_vec<hardware::hidl_memory>& memories,
426         const std::vector<intptr_t>& keys) {
427     std::lock_guard<std::mutex> guard(mMutex);
428 
429     // retrieve (or bind) all slots corresponding to memories
430     std::vector<int32_t> slots;
431     slots.reserve(memories.size());
432     for (size_t i = 0; i < memories.size(); ++i) {
433         slots.push_back(getSlotLocked(memories[i], keys[i]));
434     }
435     return slots;
436 }
437 
freeMemory(intptr_t key)438 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
439         intptr_t key) {
440     std::lock_guard<std::mutex> guard(mMutex);
441 
442     auto iter = mMemoryIdToSlot.find(key);
443     if (iter == mMemoryIdToSlot.end()) {
444         return {false, 0};
445     }
446     const int32_t slot = iter->second;
447     mMemoryIdToSlot.erase(key);
448     mMemoryCache[slot] = {};
449     mFreeSlots.push(slot);
450     return {true, slot};
451 }
452 
getSlotLocked(const hardware::hidl_memory & memory,intptr_t key)453 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(
454         const hardware::hidl_memory& memory, intptr_t key) {
455     auto iter = mMemoryIdToSlot.find(key);
456     if (iter == mMemoryIdToSlot.end()) {
457         const int32_t slot = allocateSlotLocked();
458         mMemoryIdToSlot[key] = slot;
459         mMemoryCache[slot] = memory;
460         return slot;
461     } else {
462         const int32_t slot = iter->second;
463         return slot;
464     }
465 }
466 
allocateSlotLocked()467 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
468     constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
469 
470     // if there is a free slot, use it
471     if (mFreeSlots.size() > 0) {
472         const int32_t slot = mFreeSlots.top();
473         mFreeSlots.pop();
474         return slot;
475     }
476 
477     // otherwise use a slot for the first time
478     CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
479     const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
480     mMemoryCache.emplace_back();
481 
482     return slot;
483 }
484 
create(const sp<V1_2::IPreparedModel> & preparedModel,std::chrono::microseconds pollingTimeWindow)485 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
486         const sp<V1_2::IPreparedModel>& preparedModel,
487         std::chrono::microseconds pollingTimeWindow) {
488     // check inputs
489     if (preparedModel == nullptr) {
490         LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
491         return nullptr;
492     }
493 
494     // create callback object
495     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
496 
497     // create FMQ objects
498     auto [requestChannelSenderTemp, requestChannelDescriptor] =
499             RequestChannelSender::create(kExecutionBurstChannelLength);
500     auto [resultChannelReceiverTemp, resultChannelDescriptor] =
501             ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow);
502     std::shared_ptr<RequestChannelSender> requestChannelSender =
503             std::move(requestChannelSenderTemp);
504     std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
505             std::move(resultChannelReceiverTemp);
506 
507     // check FMQ objects
508     if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
509         !resultChannelDescriptor) {
510         LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
511         return nullptr;
512     }
513 
514     // configure burst
515     V1_0::ErrorStatus errorStatus;
516     sp<IBurstContext> burstContext;
517     const hardware::Return<void> ret = preparedModel->configureExecutionBurst(
518             callback, *requestChannelDescriptor, *resultChannelDescriptor,
519             [&errorStatus, &burstContext](V1_0::ErrorStatus status,
520                                           const sp<IBurstContext>& context) {
521                 errorStatus = status;
522                 burstContext = context;
523             });
524 
525     // check burst
526     if (!ret.isOk()) {
527         LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
528                    << ret.description();
529         return nullptr;
530     }
531     if (errorStatus != V1_0::ErrorStatus::NONE) {
532         LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
533                    << toString(errorStatus);
534         return nullptr;
535     }
536     if (burstContext == nullptr) {
537         LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
538         return nullptr;
539     }
540 
541     // create death handler object
542     BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
543                                                           resultChannelReceiver] {
544         requestChannelSender->invalidate();
545         resultChannelReceiver->invalidate();
546     };
547     const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
548 
549     // linkToDeath registers a callback that will be invoked on service death to
550     // proactively handle service crashes. If the linkToDeath call fails,
551     // asynchronous calls are susceptible to hangs if the service crashes before
552     // providing the response.
553     const hardware::Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
554     if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
555         LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
556                       "for the IBurstContext object.";
557         return nullptr;
558     }
559 
560     // make and return controller
561     return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
562                                                       burstContext, callback, deathHandler);
563 }
564 
ExecutionBurstController(const std::shared_ptr<RequestChannelSender> & requestChannelSender,const std::shared_ptr<ResultChannelReceiver> & resultChannelReceiver,const sp<IBurstContext> & burstContext,const sp<ExecutionBurstCallback> & callback,const sp<hardware::hidl_death_recipient> & deathHandler)565 ExecutionBurstController::ExecutionBurstController(
566         const std::shared_ptr<RequestChannelSender>& requestChannelSender,
567         const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
568         const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
569         const sp<hardware::hidl_death_recipient>& deathHandler)
570     : mRequestChannelSender(requestChannelSender),
571       mResultChannelReceiver(resultChannelReceiver),
572       mBurstContext(burstContext),
573       mMemoryCache(callback),
574       mDeathHandler(deathHandler) {}
575 
~ExecutionBurstController()576 ExecutionBurstController::~ExecutionBurstController() {
577     // It is safe to ignore any errors resulting from this unlinkToDeath call
578     // because the ExecutionBurstController object is already being destroyed
579     // and its underlying IBurstContext object is no longer being used by the NN
580     // runtime.
581     if (mDeathHandler) {
582         mBurstContext->unlinkToDeath(mDeathHandler).isOk();
583     }
584 }
585 
getExecutionResult(V1_0::ErrorStatus status,std::vector<V1_2::OutputShape> outputShapes,V1_2::Timing timing,bool fallback)586 static std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool> getExecutionResult(
587         V1_0::ErrorStatus status, std::vector<V1_2::OutputShape> outputShapes, V1_2::Timing timing,
588         bool fallback) {
589     auto [n, checkedOutputShapes, checkedTiming] =
590             getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing);
591     return {n, convertToV1_2(checkedOutputShapes), convertToV1_2(checkedTiming), fallback};
592 }
593 
594 std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool>
compute(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<intptr_t> & memoryIds)595 ExecutionBurstController::compute(const V1_0::Request& request, V1_2::MeasureTiming measure,
596                                   const std::vector<intptr_t>& memoryIds) {
597     // This is the first point when we know an execution is occurring, so begin
598     // to collect systraces. Note that the first point we can begin collecting
599     // systraces in ExecutionBurstServer is when the RequestChannelReceiver
600     // realizes there is data in the FMQ, so ExecutionBurstServer collects
601     // systraces at different points in the code.
602     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
603 
604     std::lock_guard<std::mutex> guard(mMutex);
605 
606     // send request packet
607     const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
608     const bool success = mRequestChannelSender->send(request, measure, slots);
609     if (!success) {
610         LOG(ERROR) << "Error sending FMQ packet";
611         // only use fallback execution path if the packet could not be sent
612         return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
613                                   /*fallback=*/true);
614     }
615 
616     // get result packet
617     const auto result = mResultChannelReceiver->getBlocking();
618     if (!result) {
619         LOG(ERROR) << "Error retrieving FMQ packet";
620         // only use fallback execution path if the packet could not be sent
621         return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
622                                   /*fallback=*/false);
623     }
624 
625     // unpack results and return (only use fallback execution path if the
626     // packet could not be sent)
627     auto [status, outputShapes, timing] = std::move(*result);
628     return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false);
629 }
630 
freeMemory(intptr_t key)631 void ExecutionBurstController::freeMemory(intptr_t key) {
632     std::lock_guard<std::mutex> guard(mMutex);
633 
634     bool valid;
635     int32_t slot;
636     std::tie(valid, slot) = mMemoryCache->freeMemory(key);
637     if (valid) {
638         mBurstContext->freeMemory(slot).isOk();
639     }
640 }
641 
642 }  // namespace android::nn
643