1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_ADAPTER_BURST_H
18 #define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_ADAPTER_BURST_H
19 
#include <android-base/thread_annotations.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
#include <android/hardware/neuralnetworks/1.2/IBurstCallback.h>
#include <android/hardware/neuralnetworks/1.2/IBurstContext.h>
#include <android/hardware/neuralnetworks/1.2/IPreparedModel.h>
#include <android/hardware/neuralnetworks/1.2/types.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>
#include <nnapi/IBurst.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/hal/1.0/ProtectCallback.h>
#include <nnapi/hal/1.2/BurstUtils.h>

#include <atomic>
#include <chrono>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
41 
42 namespace android::hardware::neuralnetworks::adapter {
43 
44 /**
45  * The Burst class is responsible for waiting for and deserializing a request object from a FMQ,
46  * performing the inference, and serializing the result back across another FMQ.
47  */
48 class Burst : public V1_2::IBurstContext {
49     struct PrivateConstructorTag {};
50 
51   public:
52     /**
53      * Class to cache the memory objects for a burst object.
54      *
55      * This class is thread-safe.
56      */
57     class MemoryCache {
58       public:
59         // Precondition: burstExecutor != nullptr
60         // Precondition: burstCallback != nullptr
61         MemoryCache(nn::SharedBurst burstExecutor, sp<V1_2::IBurstCallback> burstCallback);
62 
63         /**
64          * Get the cached memory objects corresponding to provided slot identifiers.
65          *
66          * If the slot entry is not present in the cache, this class will use V1_2::IBurstCallback
67          * to retrieve those entries that are not present in the cache, then cache them.
68          *
69          * @param slots Identifiers of memory objects to be retrieved.
70          * @return A vector where each element is the memory object and a ref-counted cache "hold"
71          *     object to preserve the cache entry of the IBurst object as long as the "hold" object
72          *     is alive, otherwise GeneralError. Each element of the vector corresponds to the
73          *     element of slot.
74          */
75         nn::GeneralResult<std::vector<std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>>>
76         getCacheEntries(const std::vector<int32_t>& slots);
77 
78         /**
79          * Remove an entry from the cache.
80          *
81          * @param slot Identifier of the memory object to be removed from the cache.
82          */
83         void removeCacheEntry(int32_t slot);
84 
85       private:
86         nn::GeneralResult<void> ensureCacheEntriesArePresentLocked(
87                 const std::vector<int32_t>& slots) REQUIRES(mMutex);
88         nn::GeneralResult<std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>>
89         getCacheEntryLocked(int32_t slot) REQUIRES(mMutex);
90         void addCacheEntryLocked(int32_t slot, nn::SharedMemory memory) REQUIRES(mMutex);
91 
92         std::mutex mMutex;
93         std::map<int32_t, std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>> mCache
94                 GUARDED_BY(mMutex);
95         nn::SharedBurst kBurstExecutor;
96         const sp<V1_2::IBurstCallback> kBurstCallback;
97     };
98 
99     /**
100      * Create automated context to manage FMQ-based executions.
101      *
102      * This function is intended to be used by a service to automatically:
103      * 1) Receive data from a provided FMQ
104      * 2) Execute a model with the given information
105      * 3) Send the result to the created FMQ
106      *
107      * @param callback Callback used to retrieve memories corresponding to unrecognized slots.
108      * @param requestChannel Input FMQ channel through which the client passes the request to the
109      *     service.
110      * @param resultChannel Output FMQ channel from which the client can retrieve the result of the
111      *     execution.
112      * @param burstExecutor Object which maintains a local cache of the memory pools and executes
113      *     using the cached memory pools.
114      * @param pollingTimeWindow How much time (in microseconds) the Burst is allowed to poll the FMQ
115      *     before waiting on the blocking futex. Polling may result in lower latencies at the
116      *     potential cost of more power usage.
117      * @return V1_2::IBurstContext Handle to the burst context.
118      */
119     static nn::GeneralResult<sp<Burst>> create(
120             const sp<V1_2::IBurstCallback>& callback,
121             const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
122             const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
123             nn::SharedBurst burstExecutor,
124             std::chrono::microseconds pollingTimeWindow = std::chrono::microseconds{0});
125 
126     Burst(PrivateConstructorTag tag, const sp<V1_2::IBurstCallback>& callback,
127           std::unique_ptr<V1_2::utils::RequestChannelReceiver> requestChannel,
128           std::unique_ptr<V1_2::utils::ResultChannelSender> resultChannel,
129           nn::SharedBurst burstExecutor);
130     ~Burst();
131 
132     // Used by the NN runtime to preemptively remove any stored memory. See
133     // V1_2::IBurstContext::freeMemory for more information.
134     Return<void> freeMemory(int32_t slot) override;
135 
136   private:
137     // Work loop that will continue processing execution requests until the Burst object is freed.
138     void task();
139 
140     nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> execute(
141             const V1_0::Request& requestWithoutPools, const std::vector<int32_t>& slotsOfPools,
142             V1_2::MeasureTiming measure);
143 
144     std::thread mWorker;
145     std::atomic<bool> mTeardown{false};
146     const sp<V1_2::IBurstCallback> mCallback;
147     const std::unique_ptr<V1_2::utils::RequestChannelReceiver> mRequestChannelReceiver;
148     const std::unique_ptr<V1_2::utils::ResultChannelSender> mResultChannelSender;
149     const nn::SharedBurst mBurstExecutor;
150     MemoryCache mMemoryCache;
151 };
152 
153 }  // namespace android::hardware::neuralnetworks::adapter
154 
155 #endif  // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_ADAPTER_BURST_H
156