/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "NeuralNetworksWrapper.h"
#include "RNN.h"
24
25 namespace android {
26 namespace nn {
27 namespace wrapper {
28
29 using ::testing::Each;
30 using ::testing::FloatNear;
31 using ::testing::Matcher;
32
33 namespace {
34
ArrayFloatNear(const std::vector<float> & values,float max_abs_error=1.e-5)35 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
36 float max_abs_error = 1.e-5) {
37 std::vector<Matcher<float>> matchers;
38 matchers.reserve(values.size());
39 for (const float& v : values) {
40 matchers.emplace_back(FloatNear(v, max_abs_error));
41 }
42 return matchers;
43 }
44
// Flattened input data for the black-box test: 128 floats. The test reads it
// in consecutive slices of input_size() floats, one slice per RNN step.
//
// Deliberately left non-const: the test takes mutable float* pointers into
// this array when handing slices to the model.
static float rnn_input[] = {
        0.23689353,  0.285385,     0.037029743, -0.19858193,  -0.27569133,  0.43773448,
        0.60379338,  0.35562468,   -0.69424844, -0.93421471,  -0.87287879,  0.37144363,
        -0.62476718, 0.23791671,   0.40060222,  0.1356622,    -0.99774903,  -0.98858172,
        -0.38952237, -0.47685933,  0.31073618,  0.71511042,   -0.63767755,  -0.31729108,
        0.33468103,  0.75801885,   0.30660987,  -0.37354088,  0.77002847,   -0.62747043,
        -0.68572164, 0.0069220066, 0.65791464,  0.35130811,   0.80834007,   -0.61777675,
        -0.21095741, 0.41213346,   0.73784804,  0.094794154,  0.47791874,   0.86496925,
        -0.53376222, 0.85315156,   0.10288584,  0.86684,      -0.011186242, 0.10513687,
        0.87825835,  0.59929144,   0.62827742,  0.18899453,   0.31440187,   0.99059987,
        0.87170351,  -0.35091716,  0.74861872,  0.17831337,   0.2755419,    0.51864719,
        0.55084288,  0.58982027,   -0.47443086, 0.20875752,   -0.058871567, -0.66609079,
        0.59098077,  0.73017097,   0.74604273,  0.32882881,   -0.17503482,  0.22396147,
        0.19379807,  0.29120302,   0.077113032, -0.70331609,  0.15804303,   -0.93407321,
        0.40182066,  0.036301374,  0.66521823,  0.0300982,    -0.7747041,   -0.02038002,
        0.020698071, -0.90300065,  0.62870288,  -0.23068321,  0.27531278,   -0.095755219,
        -0.712036,   -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,  0.43519354,
        0.14744234,  0.62589407,   0.1653645,   -0.10651493,  -0.045277178, 0.99032974,
        -0.88255352, -0.85147917,  0.28153265,  0.19455957,   -0.55479527,  -0.56042433,
        0.26048636,  0.84702539,   0.47587705,  -0.074295521, -0.12287641,  0.70117295,
        0.90532446,  0.89782166,   0.79817224,  0.53402734,   -0.33286154,  0.073485017,
        -0.56172788, -0.044897556, 0.89964068,  -0.067662835, 0.76863563,   0.93455386,
        -0.6324693,  -0.083922029};
68
// Expected RNN outputs for the black-box test: 256 floats laid out as 16
// groups of 16 values. The test reads one group of num_units() values per
// step and compares it against the output of every batch.
static float rnn_golden_output[] = {
        0.496726,   0,          0.965996,  0,         0.0584254, 0,          0,         0.12315,
        0,          0,          0.612266,  0.456601,  0,         0.52286,    1.16099,   0.0291232,

        0,          0,          0.524901,  0,         0,         0,          0,         1.02116,
        0,          1.35762,    0,         0.356909,  0.436415,  0.0355727,  0,         0,

        0,          0,          0,         0.262335,  0,         0,          0,         1.33992,
        0,          2.9739,     0,         0,         1.31914,   2.66147,    0,         0,

        0.942568,   0,          0,         0,         0.025507,  0,          0,         0,
        0.321429,   0.569141,   1.25274,   1.57719,   0.8158,    1.21805,    0.586239,  0.25427,

        1.04436,    0,          0.630725,  0,         0.133801,  0.210693,   0.363026,  0,
        0.533426,   0,          1.25926,   0.722707,  0,         1.22031,    1.30117,   0.495867,

        0.222187,   0,          0.72725,   0,         0.767003,  0,          0,         0.147835,
        0,          0,          0,         0.608758,  0.469394,  0.00720298, 0.927537,  0,

        0.856974,   0.424257,   0,         0,         0.937329,  0,          0,         0,
        0.476425,   0,          0.566017,  0.418462,  0.141911,  0.996214,   1.13063,   0,

        0.967899,   0,          0,         0,         0.0831304, 0,          0,         1.00378,
        0,          0,          0,         1.44818,   1.01768,   0.943891,   0.502745,  0,

        0.940135,   0,          0,         0,         0,         0,          0,         2.13243,
        0,          0.71208,    0.123918,  1.53907,   1.30225,   1.59644,    0.70222,   0,

        0.804329,   0,          0.430576,  0,         0.505872,  0.509603,   0.343448,  0,
        0.107756,   0.614544,   1.44549,   1.52311,   0.0454298, 0.300267,   0.562784,  0.395095,

        0.228154,   0,          0.675323,  0,         1.70536,   0.766217,   0,         0,
        0,          0.735363,   0.0759267, 1.91017,   0.941888,  0,          0,         0,

        0,          0,          1.5909,    0,         0,         0,          0,         0.5755,
        0,          0.184687,   0,         1.56296,   0.625285,  0,          0,         0,

        0,          0,          0.0857888, 0,         0,         0,          0,         0.488383,
        0.252786,   0,          0,         0,         1.02817,   1.85665,    0,         0,

        0.00981836, 0,          1.06371,   0,         0,         0,          0,         0,
        0,          0.290445,   0.316406,  0,         0.304161,  1.25079,    0.0707152, 0,

        0.986264,   0.309201,   0,         0,         0,         0,          0,         1.64896,
        0.346248,   0,          0.918175,  0.78884,   0.524981,  1.92076,    2.07013,   0.333244,

        0.415153,   0.210318,   0,         0,         0,         0,          0,         2.02616,
        0,          0.728256,   0.84183,   0.0907453, 0.628881,  3.58099,    1.49974,   0};
117
118 } // anonymous namespace
119
// X-macro listing the float tensors that are bound as inputs of the RNN
// operation. ACTION is expanded once per name, in this order: Input, Weights,
// RecurrentWeights, Bias, HiddenStateIn.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(Weights)                              \
    ACTION(RecurrentWeights)                     \
    ACTION(Bias)                                 \
    ACTION(HiddenStateIn)
126
// X-macro listing the tensors produced by the RNN operation (outputs and
// intermediate state). ACTION is expanded once per name, in this order:
// HiddenStateOut, Output.
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(HiddenStateOut)             \
    ACTION(Output)
131
132 class BasicRNNOpModel {
133 public:
BasicRNNOpModel(uint32_t batches,uint32_t units,uint32_t size)134 BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
135 : batches_(batches), units_(units), input_size_(size), activation_(kActivationRelu) {
136 std::vector<uint32_t> inputs;
137
138 OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
139 inputs.push_back(model_.addOperand(&InputTy));
140 OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
141 inputs.push_back(model_.addOperand(&WeightTy));
142 OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
143 inputs.push_back(model_.addOperand(&RecurrentWeightTy));
144 OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
145 inputs.push_back(model_.addOperand(&BiasTy));
146 OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
147 inputs.push_back(model_.addOperand(&HiddenStateTy));
148 OperandType ActionParamTy(Type::INT32, {});
149 inputs.push_back(model_.addOperand(&ActionParamTy));
150
151 std::vector<uint32_t> outputs;
152
153 outputs.push_back(model_.addOperand(&HiddenStateTy));
154 OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
155 outputs.push_back(model_.addOperand(&OutputTy));
156
157 Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
158 HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
159 HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
160 Output_.insert(Output_.end(), batches_ * units_, 0.f);
161
162 model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
163 model_.identifyInputsAndOutputs(inputs, outputs);
164
165 model_.finish();
166 }
167
168 #define DefineSetter(X) \
169 void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
170
171 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
172
173 #undef DefineSetter
174
SetInput(int offset,float * begin,float * end)175 void SetInput(int offset, float* begin, float* end) {
176 for (; begin != end; begin++, offset++) {
177 Input_[offset] = *begin;
178 }
179 }
180
ResetHiddenState()181 void ResetHiddenState() {
182 std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
183 std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
184 }
185
GetOutput() const186 const std::vector<float>& GetOutput() const { return Output_; }
187
input_size() const188 uint32_t input_size() const { return input_size_; }
num_units() const189 uint32_t num_units() const { return units_; }
num_batches() const190 uint32_t num_batches() const { return batches_; }
191
Invoke()192 void Invoke() {
193 ASSERT_TRUE(model_.isValid());
194
195 HiddenStateIn_.swap(HiddenStateOut_);
196
197 Compilation compilation(&model_);
198 compilation.finish();
199 Execution execution(&compilation);
200 #define SetInputOrWeight(X) \
201 ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
202 Result::NO_ERROR);
203
204 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
205
206 #undef SetInputOrWeight
207
208 #define SetOutput(X) \
209 ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
210 Result::NO_ERROR);
211
212 FOR_ALL_OUTPUT_TENSORS(SetOutput);
213
214 #undef SetOutput
215
216 ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_, sizeof(activation_)),
217 Result::NO_ERROR);
218
219 ASSERT_EQ(execution.compute(), Result::NO_ERROR);
220 }
221
222 private:
223 Model model_;
224
225 const uint32_t batches_;
226 const uint32_t units_;
227 const uint32_t input_size_;
228
229 const int activation_;
230
231 #define DefineTensor(X) std::vector<float> X##_;
232
233 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
234 FOR_ALL_OUTPUT_TENSORS(DefineTensor);
235
236 #undef DefineTensor
237 };
238
// Runs the RNN one step at a time over rnn_input, feeding the same input
// slice to both batches, and checks every step's output against the matching
// slice of rnn_golden_output (duplicated once per batch).
TEST(RNNOpTest, BlackBoxTest) {
    BasicRNNOpModel rnn(2, 16, 8);
    rnn.SetWeights(
            {0.461459,  0.153381,   0.529743,   -0.00371218, 0.676267,    -0.211346, 0.317493,
             0.969689,  -0.343251,  0.186423,   0.398151,    0.152399,    0.448504,  0.317662,
             0.523556,  -0.323514,  0.480877,   0.333113,    -0.757714,   -0.674487, -0.643585,
             0.217766,  -0.0251462, 0.79512,    -0.595574,   -0.422444,   0.371572,  -0.452178,
             -0.556069, -0.482188,  -0.685456,  -0.727851,   0.841829,    0.551535,  -0.232336,
             0.729158,  -0.00294906, -0.69754,  0.766073,    -0.178424,   0.369513,  -0.423241,
             0.548547,  -0.0152023, -0.757482,  -0.85491,    0.251331,    -0.989183, 0.306261,
             -0.340716, 0.886103,   -0.0726757, -0.723523,   -0.784303,   0.0354295, 0.566564,
             -0.485469, -0.620498,  0.832546,   0.697884,    -0.279115,   0.294415,  -0.584313,
             0.548772,  0.0648819,  0.968726,   0.723834,    -0.0080452,  -0.350386, -0.272803,
             0.115121,  -0.412644,  -0.824713,  -0.992843,   -0.592904,   -0.417893, 0.863791,
             -0.423461, -0.147601,  -0.770664,  -0.479006,   0.654782,    0.587314,  -0.639158,
             0.816969,  -0.337228,  0.659878,   0.73107,     0.754768,    -0.337042, 0.0960841,
             0.368357,  0.244191,   -0.817703,  -0.211223,   0.442012,    0.37225,   -0.623598,
             -0.405423, 0.455101,   0.673656,   -0.145345,   -0.511346,   -0.901675, -0.81252,
             -0.127006, 0.809865,   -0.721884,  0.636255,    0.868989,    -0.347973, -0.10179,
             -0.777449, 0.917274,   0.819286,   0.206218,    -0.00785118, 0.167141,  0.45872,
             0.972934,  -0.276798,  0.837861,   0.747958,    -0.0151566,  -0.330057, -0.469077,
             0.277308,  0.415818});

    rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
                 -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, 0.37197268,
                 0.61957061, 0.3956964, -0.37609905});

    // Recurrent weights form a units x units matrix with 0.1 on the diagonal
    // and 0 everywhere else; generating it avoids a 256-literal table.
    const size_t units = rnn.num_units();
    std::vector<float> recurrent_weights(units * units, 0.f);
    for (size_t u = 0; u < units; ++u) {
        recurrent_weights[u * units + u] = 0.1f;
    }
    rnn.SetRecurrentWeights(recurrent_weights);

    rnn.ResetHiddenState();

    const size_t input_count = sizeof(rnn_input) / sizeof(rnn_input[0]);
    const size_t golden_count = sizeof(rnn_golden_output) / sizeof(rnn_golden_output[0]);

    // NOTE(review): each step reads only input_size() floats (the slice is
    // reused for every batch), yet the step count is also divided by
    // num_batches(). With the current data this runs 8 steps and leaves the
    // second half of rnn_input / rnn_golden_output unused. Kept as-is to
    // preserve the test's existing coverage; confirm whether all 16 steps
    // were intended.
    const size_t input_sequence_size = input_count / (rnn.input_size() * rnn.num_batches());

    // Guard the raw-pointer slicing below against reading past the golden data.
    ASSERT_LE(input_sequence_size * rnn.num_units(), golden_count);

    for (size_t step = 0; step < input_sequence_size; ++step) {
        SCOPED_TRACE(::testing::Message() << "time step " << step);

        // Both batches receive the same input slice for this step.
        float* batch_start = rnn_input + step * rnn.input_size();
        float* batch_end = batch_start + rnn.input_size();
        rnn.SetInput(0, batch_start, batch_end);
        rnn.SetInput(rnn.input_size(), batch_start, batch_end);

        // Invoke() reports problems through fatal assertions, which only
        // return from Invoke() itself; stop here instead of comparing stale
        // output.
        ASSERT_NO_FATAL_FAILURE(rnn.Invoke());

        // Identical inputs per batch imply identical expected outputs, so the
        // golden slice is repeated once per batch.
        const float* golden_start = rnn_golden_output + step * rnn.num_units();
        const float* golden_end = golden_start + rnn.num_units();
        std::vector<float> expected(golden_start, golden_end);
        expected.insert(expected.end(), golden_start, golden_end);

        EXPECT_THAT(rnn.GetOutput(), ::testing::ElementsAreArray(ArrayFloatNear(expected)));
    }
}
303
304 } // namespace wrapper
305 } // namespace nn
306 } // namespace android
307