1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "SampleDriverAidlUtils.h"
18 
19 #include <aidl/android/hardware/common/NativeHandle.h>
20 #include <android/binder_auto_utils.h>
21 #include <android/binder_ibinder.h>
22 #include <android/binder_manager.h>
23 #include <nnapi/Validation.h>
24 #include <nnapi/hal/aidl/Conversions.h>
25 #include <nnapi/hal/aidl/Utils.h>
26 #include <utils/NativeHandle.h>
27 
28 #include <memory>
29 #include <string>
30 #include <thread>
31 #include <utility>
32 #include <vector>
33 
34 #include "SampleDriverAidl.h"
35 #include "android/binder_process.h"
36 
37 namespace android {
38 namespace nn {
39 namespace sample_driver_aidl {
40 
run(const std::shared_ptr<aidl_hal::BnDevice> & device,const std::string & name)41 int run(const std::shared_ptr<aidl_hal::BnDevice>& device, const std::string& name) {
42     constexpr size_t kNumberOfThreads = 4;
43     ABinderProcess_setThreadPoolMaxThreadCount(kNumberOfThreads);
44 
45     const std::string fqName = std::string(SampleDriver::descriptor) + "/" + name;
46     const binder_status_t status =
47             AServiceManager_addService(device->asBinder().get(), fqName.c_str());
48     if (status != STATUS_OK) {
49         LOG(ERROR) << "Could not register service " << name;
50         return 1;
51     }
52 
53     ABinderProcess_joinThreadPool();
54     LOG(ERROR) << "Service exited!";
55     return 1;
56 }
57 
notify(const std::shared_ptr<aidl_hal::IPreparedModelCallback> & callback,const aidl_hal::ErrorStatus & status,const std::shared_ptr<aidl_hal::IPreparedModel> & preparedModel)58 void notify(const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback,
59             const aidl_hal::ErrorStatus& status,
60             const std::shared_ptr<aidl_hal::IPreparedModel>& preparedModel) {
61     const auto ret = callback->notify(status, preparedModel);
62     if (!ret.isOk()) {
63         LOG(ERROR) << "Error when calling IPreparedModelCallback::notify: " << ret.getDescription()
64                    << " " << ret.getMessage();
65     }
66 }
67 
toAStatus(aidl_hal::ErrorStatus errorStatus)68 ndk::ScopedAStatus toAStatus(aidl_hal::ErrorStatus errorStatus) {
69     if (errorStatus == aidl_hal::ErrorStatus::NONE) {
70         return ndk::ScopedAStatus::ok();
71     }
72     return ndk::ScopedAStatus::fromServiceSpecificError(static_cast<int32_t>(errorStatus));
73 }
74 
toAStatus(aidl_hal::ErrorStatus errorStatus,const std::string & errorMessage)75 ndk::ScopedAStatus toAStatus(aidl_hal::ErrorStatus errorStatus, const std::string& errorMessage) {
76     if (errorStatus == aidl_hal::ErrorStatus::NONE) {
77         return ndk::ScopedAStatus::ok();
78     }
79     return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
80             static_cast<int32_t>(errorStatus), errorMessage.c_str());
81 }
82 
prepareModelBase(aidl_hal::Model && model,const SampleDriver * driver,aidl_hal::ExecutionPreference preference,aidl_hal::Priority priority,int64_t halDeadline,const std::shared_ptr<aidl_hal::IPreparedModelCallback> & callback,bool isFullModelSupported)83 ndk::ScopedAStatus prepareModelBase(
84         aidl_hal::Model&& model, const SampleDriver* driver,
85         aidl_hal::ExecutionPreference preference, aidl_hal::Priority priority, int64_t halDeadline,
86         const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback,
87         bool isFullModelSupported) {
88     const uid_t userId = AIBinder_getCallingUid();
89     if (callback.get() == nullptr) {
90         LOG(ERROR) << "invalid callback passed to prepareModelBase";
91         return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT,
92                          "invalid callback passed to prepareModelBase");
93     }
94     const auto canonicalModel = convert(model);
95     if (!canonicalModel.has_value()) {
96         VLOG(DRIVER) << "invalid model passed to prepareModelBase";
97         notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
98         return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT,
99                          "invalid model passed to prepareModelBase");
100     }
101     if (VLOG_IS_ON(DRIVER)) {
102         VLOG(DRIVER) << "prepareModelBase";
103         logModelToInfo(canonicalModel.value());
104     }
105     if (!aidl_hal::utils::valid(preference)) {
106         const std::string log_message =
107                 "invalid execution preference passed to prepareModelBase: " + toString(preference);
108         VLOG(DRIVER) << log_message;
109         notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
110         return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT, log_message);
111     }
112     if (!aidl_hal::utils::valid(priority)) {
113         const std::string log_message =
114                 "invalid priority passed to prepareModelBase: " + toString(priority);
115         VLOG(DRIVER) << log_message;
116         notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
117         return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT, log_message);
118     }
119 
120     if (!isFullModelSupported) {
121         VLOG(DRIVER) << "model is not fully supported";
122         notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
123         return ndk::ScopedAStatus::ok();
124     }
125 
126     if (halDeadline < -1) {
127         notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
128         return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT,
129                          "Invalid deadline: " + toString(halDeadline));
130     }
131     const auto deadline = makeDeadline(halDeadline);
132     if (hasDeadlinePassed(deadline)) {
133         notify(callback, aidl_hal::ErrorStatus::MISSED_DEADLINE_PERSISTENT, nullptr);
134         return ndk::ScopedAStatus::ok();
135     }
136 
137     // asynchronously prepare the model from a new, detached thread
138     std::thread(
139             [driver, preference, userId, priority, callback](aidl_hal::Model&& model) {
140                 std::shared_ptr<SamplePreparedModel> preparedModel =
141                         ndk::SharedRefBase::make<SamplePreparedModel>(std::move(model), driver,
142                                                                       preference, userId, priority);
143                 if (!preparedModel->initialize()) {
144                     notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
145                     return;
146                 }
147                 notify(callback, aidl_hal::ErrorStatus::NONE, preparedModel);
148             },
149             std::move(model))
150             .detach();
151 
152     return ndk::ScopedAStatus::ok();
153 }
154 
155 }  // namespace sample_driver_aidl
156 }  // namespace nn
157 }  // namespace android
158