/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <vector>

#include "NeuralNetworksWrapper.h"
#include "RNN.h"

namespace android {
namespace nn {
namespace wrapper {

using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;

namespace {

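// Builds one FloatNear matcher per reference value so that an output vector
// can be compared elementwise within |max_abs_error|.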
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-5) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

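// Input sequence data, consumed input_size() values at a time by the test
// below (the same chunk is fed to both batches).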
static float rnn_input[] = {
        0.23689353,  0.285385,     0.037029743, -0.19858193,  -0.27569133,  0.43773448,
        0.60379338,  0.35562468,   -0.69424844, -0.93421471,  -0.87287879,  0.37144363,
        -0.62476718, 0.23791671,   0.40060222,  0.1356622,    -0.99774903,  -0.98858172,
        -0.38952237, -0.47685933,  0.31073618,  0.71511042,   -0.63767755,  -0.31729108,
        0.33468103,  0.75801885,   0.30660987,  -0.37354088,  0.77002847,   -0.62747043,
        -0.68572164, 0.0069220066, 0.65791464,  0.35130811,   0.80834007,   -0.61777675,
        -0.21095741, 0.41213346,   0.73784804,  0.094794154,  0.47791874,   0.86496925,
        -0.53376222, 0.85315156,   0.10288584,  0.86684,      -0.011186242, 0.10513687,
        0.87825835,  0.59929144,   0.62827742,  0.18899453,   0.31440187,   0.99059987,
        0.87170351,  -0.35091716,  0.74861872,  0.17831337,   0.2755419,    0.51864719,
        0.55084288,  0.58982027,   -0.47443086, 0.20875752,   -0.058871567, -0.66609079,
        0.59098077,  0.73017097,   0.74604273,  0.32882881,   -0.17503482,  0.22396147,
        0.19379807,  0.29120302,   0.077113032, -0.70331609,  0.15804303,   -0.93407321,
        0.40182066,  0.036301374,  0.66521823,  0.0300982,    -0.7747041,   -0.02038002,
        0.020698071, -0.90300065,  0.62870288,  -0.23068321,  0.27531278,   -0.095755219,
        -0.712036,   -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,  0.43519354,
        0.14744234,  0.62589407,   0.1653645,   -0.10651493,  -0.045277178, 0.99032974,
        -0.88255352, -0.85147917,  0.28153265,  0.19455957,   -0.55479527,  -0.56042433,
        0.26048636,  0.84702539,   0.47587705,  -0.074295521, -0.12287641,  0.70117295,
        0.90532446,  0.89782166,   0.79817224,  0.53402734,   -0.33286154,  0.073485017,
        -0.56172788, -0.044897556, 0.89964068,  -0.067662835, 0.76863563,   0.93455386,
        -0.6324693,  -0.083922029};

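// Expected activations, one block of num_units() values per time step.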
static float rnn_golden_output[] = {
        0.496726,   0,        0.965996,  0,         0.0584254, 0,          0,         0.12315,
        0,          0,        0.612266,  0.456601,  0,         0.52286,    1.16099,   0.0291232,

        0,          0,        0.524901,  0,         0,         0,          0,         1.02116,
        0,          1.35762,  0,         0.356909,  0.436415,  0.0355727,  0,         0,

        0,          0,        0,         0.262335,  0,         0,          0,         1.33992,
        0,          2.9739,   0,         0,         1.31914,   2.66147,    0,         0,

        0.942568,   0,        0,         0,         0.025507,  0,          0,         0,
        0.321429,   0.569141, 1.25274,   1.57719,   0.8158,    1.21805,    0.586239,  0.25427,

        1.04436,    0,        0.630725,  0,         0.133801,  0.210693,   0.363026,  0,
        0.533426,   0,        1.25926,   0.722707,  0,         1.22031,    1.30117,   0.495867,

        0.222187,   0,        0.72725,   0,         0.767003,  0,          0,         0.147835,
        0,          0,        0,         0.608758,  0.469394,  0.00720298, 0.927537,  0,

        0.856974,   0.424257, 0,         0,         0.937329,  0,          0,         0,
        0.476425,   0,        0.566017,  0.418462,  0.141911,  0.996214,   1.13063,   0,

        0.967899,   0,        0,         0,         0.0831304, 0,          0,         1.00378,
        0,          0,        0,         1.44818,   1.01768,   0.943891,   0.502745,  0,

        0.940135,   0,        0,         0,         0,         0,          0,         2.13243,
        0,          0.71208,  0.123918,  1.53907,   1.30225,   1.59644,    0.70222,   0,

        0.804329,   0,        0.430576,  0,         0.505872,  0.509603,   0.343448,  0,
        0.107756,   0.614544, 1.44549,   1.52311,   0.0454298, 0.300267,   0.562784,  0.395095,

        0.228154,   0,        0.675323,  0,         1.70536,   0.766217,   0,         0,
        0,          0.735363, 0.0759267, 1.91017,   0.941888,  0,          0,         0,

        0,          0,        1.5909,    0,         0,         0,          0,         0.5755,
        0,          0.184687, 0,         1.56296,   0.625285,  0,          0,         0,

        0,          0,        0.0857888, 0,         0,         0,          0,         0.488383,
        0.252786,   0,        0,         0,         1.02817,   1.85665,    0,         0,

        0.00981836, 0,        1.06371,   0,         0,         0,          0,         0,
        0,          0.290445, 0.316406,  0,         0.304161,  1.25079,    0.0707152, 0,

        0.986264,   0.309201, 0,         0,         0,         0,          0,         1.64896,
        0.346248,   0,        0.918175,  0.78884,   0.524981,  1.92076,    2.07013,   0.333244,

        0.415153,   0.210318, 0,         0,         0,         0,          0,         2.02616,
        0,          0.728256, 0.84183,   0.0907453, 0.628881,  3.58099,    1.49974,   0};

}  // anonymous namespace

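// For all inputs, weights, and the incoming hidden state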
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(Weights)                              \
    ACTION(RecurrentWeights)                     \
    ACTION(Bias)                                 \
    ACTION(HiddenStateIn)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(HiddenStateOut)             \
    ACTION(Output)

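// Test harness that builds a model containing a single ANEURALNETWORKS_RNN
// operation and keeps host-side buffers for each of its tensor operands.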
class BasicRNNOpModel {
   public:
    BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
        : batches_(batches), units_(units), input_size_(size), activation_(kActivationRelu) {
        std::vector<uint32_t> inputs;

        OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
        inputs.push_back(model_.addOperand(&InputTy));
        OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
        inputs.push_back(model_.addOperand(&WeightTy));
        OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
        inputs.push_back(model_.addOperand(&RecurrentWeightTy));
        OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
        inputs.push_back(model_.addOperand(&BiasTy));
        OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
        inputs.push_back(model_.addOperand(&HiddenStateTy));
        OperandType ActionParamTy(Type::INT32, {});
        inputs.push_back(model_.addOperand(&ActionParamTy));

        std::vector<uint32_t> outputs;

        outputs.push_back(model_.addOperand(&HiddenStateTy));
        OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
        outputs.push_back(model_.addOperand(&OutputTy));

        Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
        HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
        HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
        Output_.insert(Output_.end(), batches_ * units_, 0.f);

        model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        model_.finish();
    }

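// Defines SetInput(), SetWeights(), etc., each appending values to the
// corresponding buffer.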
#define DefineSetter(X) \
    void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);

#undef DefineSetter

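    // Copies [begin, end) into the input buffer, starting at |offset|.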
    void SetInput(int offset, float* begin, float* end) {
        for (; begin != end; begin++, offset++) {
            Input_[offset] = *begin;
        }
    }

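    // Zeroes both hidden state buffers so the next sequence starts from scratch.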
    void ResetHiddenState() {
        std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
        std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
    }

    const std::vector<float>& GetOutput() const { return Output_; }

    uint32_t input_size() const { return input_size_; }
    uint32_t num_units() const { return units_; }
    uint32_t num_batches() const { return batches_; }

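    // Runs one time step: the previous step's hidden state output is swapped in
    // as this step's hidden state input, then the model is compiled and executed.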
    void Invoke() {
        ASSERT_TRUE(model_.isValid());

        HiddenStateIn_.swap(HiddenStateOut_);

        Compilation compilation(&model_);
        compilation.finish();
        Execution execution(&compilation);
#define SetInputOrWeight(X)                                                                    \
    ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
              Result::NO_ERROR);

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X)                                                                            \
    ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
              Result::NO_ERROR);

        FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

        ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_, sizeof(activation_)),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

   private:
    Model model_;

    const uint32_t batches_;
    const uint32_t units_;
    const uint32_t input_size_;

    const int activation_;

#define DefineTensor(X) std::vector<float> X##_;

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

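// Black-box test of the RNN op: a 2-batch, 16-unit, 8-input model is driven
// step by step with identical data in both batches, and each step's output is
// compared against the golden reference values above.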
TEST(RNNOpTest, BlackBoxTest) {
    BasicRNNOpModel rnn(2, 16, 8);
    rnn.SetWeights(
            {0.461459,  0.153381,    0.529743,   -0.00371218, 0.676267,    -0.211346, 0.317493,
             0.969689,  -0.343251,   0.186423,   0.398151,    0.152399,    0.448504,  0.317662,
             0.523556,  -0.323514,   0.480877,   0.333113,    -0.757714,   -0.674487, -0.643585,
             0.217766,  -0.0251462,  0.79512,    -0.595574,   -0.422444,   0.371572,  -0.452178,
             -0.556069, -0.482188,   -0.685456,  -0.727851,   0.841829,    0.551535,  -0.232336,
             0.729158,  -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,  -0.423241,
             0.548547,  -0.0152023,  -0.757482,  -0.85491,    0.251331,    -0.989183, 0.306261,
             -0.340716, 0.886103,    -0.0726757, -0.723523,   -0.784303,   0.0354295, 0.566564,
             -0.485469, -0.620498,   0.832546,   0.697884,    -0.279115,   0.294415,  -0.584313,
             0.548772,  0.0648819,   0.968726,   0.723834,    -0.0080452,  -0.350386, -0.272803,
             0.115121,  -0.412644,   -0.824713,  -0.992843,   -0.592904,   -0.417893, 0.863791,
             -0.423461, -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,  -0.639158,
             0.816969,  -0.337228,   0.659878,   0.73107,     0.754768,    -0.337042, 0.0960841,
             0.368357,  0.244191,    -0.817703,  -0.211223,   0.442012,    0.37225,   -0.623598,
             -0.405423, 0.455101,    0.673656,   -0.145345,   -0.511346,   -0.901675, -0.81252,
             -0.127006, 0.809865,    -0.721884,  0.636255,    0.868989,    -0.347973, -0.10179,
             -0.777449, 0.917274,    0.819286,   0.206218,    -0.00785118, 0.167141,  0.45872,
             0.972934,  -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057, -0.469077,
             0.277308,  0.415818});

    rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
                 -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, 0.37197268,
                 0.61957061, 0.3956964, -0.37609905});

    // 16x16 recurrent weights: 0.1 on the diagonal, 0 everywhere else.
    rnn.SetRecurrentWeights(
            {0.1, 0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0.1, 0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0.1, 0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0.1, 0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0.1, 0,   0,   0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0.1, 0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0.1, 0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0.1, 0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0, 0.1, 0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0.1, 0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0.1,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0,   0,   0.1, 0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0,   0,   0.1, 0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0,   0,   0.1, 0,   0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0,  0,   0.1, 0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0, 0,   0, 0,   0, 0,   0.1});

    rnn.ResetHiddenState();
    const int input_sequence_size =
            sizeof(rnn_input) / sizeof(float) / (rnn.input_size() * rnn.num_batches());

    for (int i = 0; i < input_sequence_size; i++) {
        // Feed the same time step to both batches.
        float* batch_start = rnn_input + i * rnn.input_size();
        float* batch_end = batch_start + rnn.input_size();
        rnn.SetInput(0, batch_start, batch_end);
        rnn.SetInput(rnn.input_size(), batch_start, batch_end);

        rnn.Invoke();

        // Both batches are expected to produce the same per-step golden values.
        float* golden_start = rnn_golden_output + i * rnn.num_units();
        float* golden_end = golden_start + rnn.num_units();
        std::vector<float> expected;
        expected.insert(expected.end(), golden_start, golden_end);
        expected.insert(expected.end(), golden_start, golden_end);

        EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
    }
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android