1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "MockExecution.h"
18 #include "MockFencedExecutionCallback.h"
19 
20 #include <aidl/android/hardware/neuralnetworks/IFencedExecutionCallback.h>
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include <nnapi/IExecution.h>
24 #include <nnapi/TypeUtils.h>
25 #include <nnapi/Types.h>
26 #include <nnapi/hal/aidl/Execution.h>
27 
28 #include <functional>
29 #include <memory>
30 
31 namespace aidl::android::hardware::neuralnetworks::utils {
32 namespace {
33 
34 using ::testing::_;
35 using ::testing::DoAll;
36 using ::testing::Invoke;
37 using ::testing::InvokeWithoutArgs;
38 using ::testing::SetArgPointee;
39 
40 const std::shared_ptr<IExecution> kInvalidExecution;
41 constexpr auto kNoTiming = Timing{.timeOnDeviceNs = -1, .timeInDriverNs = -1};
42 
__anon4245b6ff0202null43 constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
44 
__anon4245b6ff0302null45 constexpr auto makeGeneralFailure = [] {
46     return ndk::ScopedAStatus::fromServiceSpecificError(
47             static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
48 };
__anon4245b6ff0402null49 constexpr auto makeGeneralTransportFailure = [] {
50     return ndk::ScopedAStatus::fromStatus(STATUS_NO_MEMORY);
51 };
__anon4245b6ff0502null52 constexpr auto makeDeadObjectFailure = [] {
53     return ndk::ScopedAStatus::fromStatus(STATUS_DEAD_OBJECT);
54 };
55 
makeFencedExecutionResult(const std::shared_ptr<MockFencedExecutionCallback> & callback)56 auto makeFencedExecutionResult(const std::shared_ptr<MockFencedExecutionCallback>& callback) {
57     return [callback](const std::vector<ndk::ScopedFileDescriptor>& /*waitFor*/,
58                       int64_t /*deadline*/, int64_t /*duration*/,
59                       FencedExecutionResult* fencedExecutionResult) {
60         *fencedExecutionResult = FencedExecutionResult{.callback = callback,
61                                                        .syncFence = ndk::ScopedFileDescriptor(-1)};
62         return ndk::ScopedAStatus::ok();
63     };
64 }
65 
66 }  // namespace
67 
// Execution::create() given a null IExecution is expected to fail with GENERAL_FAILURE.
TEST(ExecutionTest, invalidExecution) {
    // run test
    const auto created = Execution::create(kInvalidExecution, {});

    // verify result: creation failed, and the failure carries the expected code
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
76 
// compute() is expected to succeed when the mocked executeSynchronously() fills in a
// result and returns an OK status.
TEST(ExecutionTest, executeSync) {
    // setup call
    const ExecutionResult driverResult{
            .outputSufficientSize = true,
            .outputShapes = {},
            .timing = kNoTiming,
    };
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    // Argument 1 is the output parameter; write the canned result there, then report OK.
    EXPECT_CALL(*mockExecution, executeSynchronously(_, _))
            .Times(1)
            .WillOnce(DoAll(SetArgPointee<1>(driverResult), InvokeWithoutArgs(makeStatusOk)));

    // run test
    const auto computed = execution->compute({});

    // verify result
    EXPECT_TRUE(computed.has_value())
            << "Failed with " << computed.error().code << ": " << computed.error().message;
}
98 
// compute() is expected to fail with GENERAL_FAILURE when the mocked
// executeSynchronously() returns a service-specific GENERAL_FAILURE status.
TEST(ExecutionTest, executeSyncError) {
    // setup test
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    // makeGeneralFailure takes no parameters, so it is wrapped with InvokeWithoutArgs (as the
    // other failure tests in this file do) instead of relying on Invoke dropping the arguments.
    EXPECT_CALL(*mockExecution, executeSynchronously(_, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    // run test
    const auto result = execution->compute({});

    // verify result
    ASSERT_FALSE(result.has_value());
    EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
114 
// compute() is expected to fail with GENERAL_FAILURE when the mocked
// executeSynchronously() returns the STATUS_NO_MEMORY transport status.
TEST(ExecutionTest, executeSyncTransportFailure) {
    // setup test
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    EXPECT_CALL(*mockExecution, executeSynchronously(_, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // run test
    const auto computed = execution->compute({});

    // verify result: the call failed and the error code is the generic one
    ASSERT_FALSE(computed.has_value());
    EXPECT_EQ(computed.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
130 
// compute() is expected to fail with DEAD_OBJECT when the mocked
// executeSynchronously() returns the STATUS_DEAD_OBJECT transport status.
TEST(ExecutionTest, executeSyncDeadObject) {
    // setup test
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    EXPECT_CALL(*mockExecution, executeSynchronously(_, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // run test
    const auto computed = execution->compute({});

    // verify result: the call failed and the dead-object condition is preserved
    ASSERT_FALSE(computed.has_value());
    EXPECT_EQ(computed.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
146 
// computeFenced() is expected to succeed, return a signaled sync fence and a non-null
// callback, and that callback is expected to succeed when the mocked
// getExecutionInfo() reports ErrorStatus::NONE.
TEST(ExecutionTest, executeFenced) {
    // setup call
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    const auto mockCallback = MockFencedExecutionCallback::create();
    // Fill all three output arguments, then return OK. makeStatusOk takes no parameters, so it
    // is wrapped with InvokeWithoutArgs, matching the executeSync test.
    EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
            .Times(1)
            .WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
                            SetArgPointee<2>(ErrorStatus::NONE),
                            InvokeWithoutArgs(makeStatusOk)));
    // The fenced result hands mockCallback back to the code under test.
    EXPECT_CALL(*mockExecution, executeFenced(_, _, _, _))
            .Times(1)
            .WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));

    // run test
    const auto result = execution->computeFenced({}, {}, {});

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    const auto& [syncFence, callback] = result.value();
    EXPECT_EQ(syncFence.syncWait({}), nn::SyncFence::FenceState::SIGNALED);
    ASSERT_NE(callback, nullptr);

    // get results from callback
    const auto callbackResult = callback();
    ASSERT_TRUE(callbackResult.has_value()) << "Failed with " << callbackResult.error().code << ": "
                                            << callbackResult.error().message;
}
175 
// computeFenced() itself is expected to succeed, but the returned callback is expected to
// fail with GENERAL_FAILURE when the mocked getExecutionInfo() reports that error status.
TEST(ExecutionTest, executeFencedCallbackError) {
    // setup call
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    const auto mockCallback = MockFencedExecutionCallback::create();
    // DoAll already yields a complete action, so it is passed to WillOnce directly instead of
    // being wrapped in Invoke. The binder call returns OK; the failure travels in the
    // ErrorStatus output argument (index 2).
    EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
            .Times(1)
            .WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
                            SetArgPointee<2>(ErrorStatus::GENERAL_FAILURE),
                            InvokeWithoutArgs(makeStatusOk)));
    EXPECT_CALL(*mockExecution, executeFenced(_, _, _, _))
            .Times(1)
            .WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));

    // run test
    const auto result = execution->computeFenced({}, {}, {});

    // verify result
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    const auto& [syncFence, callback] = result.value();
    // Only requires that the fence is no longer pending; any non-ACTIVE state is accepted here.
    EXPECT_NE(syncFence.syncWait({}), nn::SyncFence::FenceState::ACTIVE);
    ASSERT_NE(callback, nullptr);

    // verify callback failure
    const auto callbackResult = callback();
    ASSERT_FALSE(callbackResult.has_value());
    EXPECT_EQ(callbackResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
205 
// computeFenced() is expected to fail with GENERAL_FAILURE when the mocked
// executeFenced() returns a service-specific GENERAL_FAILURE status.
TEST(ExecutionTest, executeFencedError) {
    // setup test
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    EXPECT_CALL(*mockExecution, executeFenced(_, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    // run test
    const auto fenced = execution->computeFenced({}, {}, {});

    // verify result: the call failed and the driver's error code is reported
    ASSERT_FALSE(fenced.has_value());
    EXPECT_EQ(fenced.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
221 
// computeFenced() is expected to fail with GENERAL_FAILURE when the mocked
// executeFenced() returns the STATUS_NO_MEMORY transport status.
TEST(ExecutionTest, executeFencedTransportFailure) {
    // setup test
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    EXPECT_CALL(*mockExecution, executeFenced(_, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // run test
    const auto fenced = execution->computeFenced({}, {}, {});

    // verify result: the call failed and the error code is the generic one
    ASSERT_FALSE(fenced.has_value());
    EXPECT_EQ(fenced.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
237 
// computeFenced() is expected to fail with DEAD_OBJECT when the mocked
// executeFenced() returns the STATUS_DEAD_OBJECT transport status.
TEST(ExecutionTest, executeFencedDeadObject) {
    // setup test
    const auto mockExecution = MockExecution::create();
    const auto execution = Execution::create(mockExecution, {}).value();
    EXPECT_CALL(*mockExecution, executeFenced(_, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // run test
    const auto fenced = execution->computeFenced({}, {}, {});

    // verify result: the call failed and the dead-object condition is preserved
    ASSERT_FALSE(fenced.has_value());
    EXPECT_EQ(fenced.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
253 
254 }  // namespace aidl::android::hardware::neuralnetworks::utils
255