1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "MockExecution.h"
18 #include "MockFencedExecutionCallback.h"
19
20 #include <aidl/android/hardware/neuralnetworks/IFencedExecutionCallback.h>
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include <nnapi/IExecution.h>
24 #include <nnapi/TypeUtils.h>
25 #include <nnapi/Types.h>
26 #include <nnapi/hal/aidl/Execution.h>
27
28 #include <functional>
29 #include <memory>
30
31 namespace aidl::android::hardware::neuralnetworks::utils {
32 namespace {
33
34 using ::testing::_;
35 using ::testing::DoAll;
36 using ::testing::Invoke;
37 using ::testing::InvokeWithoutArgs;
38 using ::testing::SetArgPointee;
39
40 const std::shared_ptr<IExecution> kInvalidExecution;
41 constexpr auto kNoTiming = Timing{.timeOnDeviceNs = -1, .timeInDriverNs = -1};
42
__anon4245b6ff0202null43 constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
44
__anon4245b6ff0302null45 constexpr auto makeGeneralFailure = [] {
46 return ndk::ScopedAStatus::fromServiceSpecificError(
47 static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
48 };
__anon4245b6ff0402null49 constexpr auto makeGeneralTransportFailure = [] {
50 return ndk::ScopedAStatus::fromStatus(STATUS_NO_MEMORY);
51 };
__anon4245b6ff0502null52 constexpr auto makeDeadObjectFailure = [] {
53 return ndk::ScopedAStatus::fromStatus(STATUS_DEAD_OBJECT);
54 };
55
makeFencedExecutionResult(const std::shared_ptr<MockFencedExecutionCallback> & callback)56 auto makeFencedExecutionResult(const std::shared_ptr<MockFencedExecutionCallback>& callback) {
57 return [callback](const std::vector<ndk::ScopedFileDescriptor>& /*waitFor*/,
58 int64_t /*deadline*/, int64_t /*duration*/,
59 FencedExecutionResult* fencedExecutionResult) {
60 *fencedExecutionResult = FencedExecutionResult{.callback = callback,
61 .syncFence = ndk::ScopedFileDescriptor(-1)};
62 return ndk::ScopedAStatus::ok();
63 };
64 }
65
66 } // namespace
67
TEST(ExecutionTest, invalidExecution) {
    // Wrapping a null IExecution handle must not yield an Execution object.
    const auto created = Execution::create(kInvalidExecution, {});

    // The factory reports the rejection as GENERAL_FAILURE.
    ASSERT_FALSE(created.has_value());
    const auto& error = created.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
}
76
TEST(ExecutionTest, executeSync) {
    // Arrange: the mocked driver fills in a successful result with no timing information.
    const auto mock = MockExecution::create();
    const auto execution = Execution::create(mock, {}).value();
    const auto driverResult = ExecutionResult{
            .outputSufficientSize = true,
            .outputShapes = {},
            .timing = kNoTiming,
    };
    EXPECT_CALL(*mock, executeSynchronously(_, _))
            .Times(1)
            .WillOnce(DoAll(SetArgPointee<1>(driverResult), InvokeWithoutArgs(makeStatusOk)));

    // Act.
    const auto computed = execution->compute({});

    // Assert: success is propagated; on failure, report the error code and message.
    EXPECT_TRUE(computed.has_value())
            << "Failed with " << computed.error().code << ": " << computed.error().message;
}
98
TEST(ExecutionTest, executeSyncError) {
    // setup test: the driver answers the synchronous execution with a service-specific
    // GENERAL_FAILURE status.
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    // makeGeneralFailure takes no arguments, so it is wrapped in InvokeWithoutArgs, matching the
    // other error-path tests in this file. Invoke is meant for callables that accept the mocked
    // method's arguments.
    EXPECT_CALL(*mockExecution, executeSynchronously(_, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    // run test
    const auto result = execution->compute({});

    // verify result: the service-specific error surfaces as nn::ErrorStatus::GENERAL_FAILURE
    ASSERT_FALSE(result.has_value());
    EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
114
TEST(ExecutionTest, executeSyncTransportFailure) {
    // The synchronous call fails at the transport level with STATUS_NO_MEMORY.
    const auto mock = MockExecution::create();
    const auto execution = Execution::create(mock, {}).value();
    EXPECT_CALL(*mock, executeSynchronously(_, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    const auto computed = execution->compute({});

    // This transport error is expected to be reported as GENERAL_FAILURE.
    ASSERT_FALSE(computed.has_value());
    const auto& error = computed.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
}
130
TEST(ExecutionTest, executeSyncDeadObject) {
    // The synchronous call fails at the transport level with STATUS_DEAD_OBJECT.
    const auto mock = MockExecution::create();
    const auto execution = Execution::create(mock, {}).value();
    EXPECT_CALL(*mock, executeSynchronously(_, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    const auto computed = execution->compute({});

    // A dead remote object is expected to keep its own DEAD_OBJECT code.
    ASSERT_FALSE(computed.has_value());
    const auto& error = computed.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
}
146
TEST(ExecutionTest, executeFenced) {
    // setup call: executeFenced succeeds and hands back a mock callback whose
    // getExecutionInfo reports no timing and ErrorStatus::NONE.
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    const auto mockCallback = MockFencedExecutionCallback::create();
    // makeStatusOk takes no arguments, so it is wrapped in InvokeWithoutArgs (as in the
    // executeSync test) rather than Invoke.
    EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
            .Times(1)
            .WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
                            SetArgPointee<2>(ErrorStatus::NONE), InvokeWithoutArgs(makeStatusOk)));
    EXPECT_CALL(*mockExecution, executeFenced(_, _, _, _))
            .Times(1)
            .WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));

    // run test
    const auto result = execution->computeFenced({}, {}, {});

    // verify result: the -1 fence descriptor from the driver is expected to yield a signaled
    // fence, and a callback must be returned
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    const auto& [syncFence, callback] = result.value();
    EXPECT_EQ(syncFence.syncWait({}), nn::SyncFence::FenceState::SIGNALED);
    ASSERT_NE(callback, nullptr);

    // get results from callback
    const auto callbackResult = callback();
    ASSERT_TRUE(callbackResult.has_value()) << "Failed with " << callbackResult.error().code << ": "
                                            << callbackResult.error().message;
}
175
TEST(ExecutionTest, executeFencedCallbackError) {
    // setup call: executeFenced succeeds, but the returned callback's getExecutionInfo reports
    // ErrorStatus::GENERAL_FAILURE through its third out-parameter.
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    const auto mockCallback = MockFencedExecutionCallback::create();
    // DoAll(...) is already a complete action, so it is passed to WillOnce directly instead of
    // being wrapped in a redundant Invoke(). The zero-argument makeStatusOk is wrapped in
    // InvokeWithoutArgs, as elsewhere in this file.
    EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
            .Times(1)
            .WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
                            SetArgPointee<2>(ErrorStatus::GENERAL_FAILURE),
                            InvokeWithoutArgs(makeStatusOk)));
    EXPECT_CALL(*mockExecution, executeFenced(_, _, _, _))
            .Times(1)
            .WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));

    // run test
    const auto result = execution->computeFenced({}, {}, {});

    // verify result: the fenced launch itself succeeds and the fence is not left active
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    const auto& [syncFence, callback] = result.value();
    EXPECT_NE(syncFence.syncWait({}), nn::SyncFence::FenceState::ACTIVE);
    ASSERT_NE(callback, nullptr);

    // verify callback failure: the driver-reported status surfaces when querying the callback
    const auto callbackResult = callback();
    ASSERT_FALSE(callbackResult.has_value());
    EXPECT_EQ(callbackResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
205
TEST(ExecutionTest, executeFencedError) {
    // The driver rejects the fenced execution with a service-specific GENERAL_FAILURE.
    const auto mock = MockExecution::create();
    const auto execution = Execution::create(mock, {}).value();
    EXPECT_CALL(*mock, executeFenced(_, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    const auto fenced = execution->computeFenced({}, {}, {});

    // No fence/callback pair is produced; the error is reported as GENERAL_FAILURE.
    ASSERT_FALSE(fenced.has_value());
    const auto& error = fenced.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
}
221
TEST(ExecutionTest, executeFencedTransportFailure) {
    // The fenced call fails at the transport level with STATUS_NO_MEMORY.
    const auto mock = MockExecution::create();
    const auto execution = Execution::create(mock, {}).value();
    EXPECT_CALL(*mock, executeFenced(_, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    const auto fenced = execution->computeFenced({}, {}, {});

    // This transport error is expected to be reported as GENERAL_FAILURE.
    ASSERT_FALSE(fenced.has_value());
    const auto& error = fenced.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
}
237
TEST(ExecutionTest, executeFencedDeadObject) {
    // The fenced call fails at the transport level with STATUS_DEAD_OBJECT.
    const auto mock = MockExecution::create();
    const auto execution = Execution::create(mock, {}).value();
    EXPECT_CALL(*mock, executeFenced(_, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    const auto fenced = execution->computeFenced({}, {}, {});

    // A dead remote object is expected to keep its own DEAD_OBJECT code.
    ASSERT_FALSE(fenced.has_value());
    const auto& error = fenced.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
}
253
254 } // namespace aidl::android::hardware::neuralnetworks::utils
255