1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "RNN.h"
20 
21 #include <vector>
22 
23 #include "CpuExecutor.h"
24 #include "CpuOperationUtils.h"
25 #include "Tracing.h"
26 
27 namespace android {
28 namespace nn {
29 
RNN(const Operation & operation,RunTimeOperandInfo * operands)30 RNN::RNN(const Operation& operation, RunTimeOperandInfo* operands) {
31     NNTRACE_TRANS("RNN::RNN");
32     input_ = GetInput(operation, operands, kInputTensor);
33     weights_ = GetInput(operation, operands, kWeightsTensor);
34     recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
35     hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
36     bias_ = GetInput(operation, operands, kBiasTensor);
37 
38     activation_ = static_cast<ActivationFn>(
39             getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
40 
41     hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
42     output_ = GetOutput(operation, operands, kOutputTensor);
43 }
44 
Prepare(const Operation & operation,RunTimeOperandInfo * operands,Shape * hiddenStateShape,Shape * outputShape)45 bool RNN::Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* hiddenStateShape,
46                   Shape* outputShape) {
47     NNTRACE_TRANS("RNN::Prepare");
48     // Check we have all the inputs and outputs we need.
49     const int num_inputs = NumInputsWithValues(operation, operands);
50     NN_CHECK(num_inputs == 6);
51     NN_CHECK_EQ(NumOutputs(operation), 2);
52 
53     const RunTimeOperandInfo* input = GetInput(operation, operands, kInputTensor);
54     const RunTimeOperandInfo* input_weights = GetInput(operation, operands, kWeightsTensor);
55     const RunTimeOperandInfo* recurrent_weights =
56             GetInput(operation, operands, kRecurrentWeightsTensor);
57     const RunTimeOperandInfo* bias = GetInput(operation, operands, kBiasTensor);
58 
59     // Check all the parameters of tensor match within themselves and match the
60     // input configuration.
61     const uint32_t batch_size = SizeOfDimension(input, 0);
62     const uint32_t num_units = SizeOfDimension(input_weights, 0);
63     NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
64     NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
65     NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
66     NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
67 
68     const Shape& inputShape = input->shape();
69 
70     // Resize state.
71     hiddenStateShape->type = inputShape.type;
72     hiddenStateShape->dimensions = {batch_size, num_units};
73 
74     // Resize output.
75     outputShape->type = inputShape.type;
76     outputShape->dimensions = {batch_size, num_units};
77 
78     return true;
79 }
80 
Eval()81 bool RNN::Eval() {
82     switch (input_->type) {
83         case OperandType::TENSOR_FLOAT16: {
84             RNNStep<_Float16>(reinterpret_cast<_Float16*>(input_->buffer), input_->shape(),
85                               reinterpret_cast<_Float16*>(hidden_state_in_->buffer),
86                               reinterpret_cast<_Float16*>(bias_->buffer),
87                               reinterpret_cast<_Float16*>(weights_->buffer), weights_->shape(),
88                               reinterpret_cast<_Float16*>(recurrent_weights_->buffer),
89                               recurrent_weights_->shape(), activation_,
90                               reinterpret_cast<_Float16*>(output_->buffer));
91             memcpy(hidden_state_out_->buffer, output_->buffer,
92                    sizeof(_Float16) * getNumberOfElements(output_->shape()));
93             break;
94         }
95         case OperandType::TENSOR_FLOAT32: {
96             RNNStep<float>(reinterpret_cast<float*>(input_->buffer), input_->shape(),
97                            reinterpret_cast<float*>(hidden_state_in_->buffer),
98                            reinterpret_cast<float*>(bias_->buffer),
99                            reinterpret_cast<float*>(weights_->buffer), weights_->shape(),
100                            reinterpret_cast<float*>(recurrent_weights_->buffer),
101                            recurrent_weights_->shape(), activation_,
102                            reinterpret_cast<float*>(output_->buffer));
103             memcpy(hidden_state_out_->buffer, output_->buffer,
104                    sizeof(float) * getNumberOfElements(output_->shape()));
105             break;
106         }
107         default: {
108             LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
109             return false;
110         }
111     }
112     return true;
113 }
114 
115 template <typename T>
RNNStep(const T * inputData,const Shape & inputShape,const T * hiddenStateInputData,const T * biasData,const T * weightsData,const Shape & weightsShape,const T * recurrentWeightsData,const Shape & recurrentWeightsShape,const int32_t activation,T * outputData)116 bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* hiddenStateInputData,
117                   const T* biasData, const T* weightsData, const Shape& weightsShape,
118                   const T* recurrentWeightsData, const Shape& recurrentWeightsShape,
119                   const int32_t activation, T* outputData) {
120     NNTRACE_COMP("RNN::Eval");
121 
122     Shape dummyShape;
123     uint32_t numUnits = weightsShape.dimensions[0];
124     return RNNStep<T>(inputData, inputShape, /*auxInputData=*/nullptr, /*auxInputShape=*/dummyShape,
125                       hiddenStateInputData, biasData, weightsData, weightsShape,
126                       /*auxWeightsData=*/nullptr, /*auxWeightsShape=*/dummyShape,
127                       recurrentWeightsData, recurrentWeightsShape, activation,
128                       /*outputBatchStride=*/numUnits, /*outputBatchOffset=*/0, outputData);
129 }
130 
131 // A more general version of the RNNStep function.
132 // Auxiliary input is treated as if it was concatenated to a regular input and
133 // the result was multiplied by the weights matrix which was also concatenated
134 // with auxiliary weights.
135 template <typename T>
RNNStep(const T * inputData,const Shape & inputShape,const T * auxInputData,const Shape & auxInputShape,const T * hiddenStateInputData,const T * biasData,const T * weightsData,const Shape & weightsShape,const T * auxWeightsData,const Shape & auxWeightsShape,const T * recurrentWeightsData,const Shape & recurrentWeightsShape,const int32_t activation,const uint32_t outputBatchStride,const uint32_t outputBatchOffset,T * outputData,T * hiddenStateOutput)136 bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* auxInputData,
137                   const Shape& auxInputShape, const T* hiddenStateInputData, const T* biasData,
138                   const T* weightsData, const Shape& weightsShape, const T* auxWeightsData,
139                   const Shape& auxWeightsShape, const T* recurrentWeightsData,
140                   const Shape& recurrentWeightsShape, const int32_t activation,
141                   const uint32_t outputBatchStride, const uint32_t outputBatchOffset, T* outputData,
142                   T* hiddenStateOutput) {
143     NNTRACE_COMP("RNN::Eval");
144 
145     const uint32_t batch_size = inputShape.dimensions[0];
146     const uint32_t num_units = weightsShape.dimensions[0];
147     const uint32_t input_size = inputShape.dimensions[1];
148     const uint32_t input_weights_stride = weightsShape.dimensions[1];
149     const uint32_t recurrent_weights_stride = recurrentWeightsShape.dimensions[1];
150 
151     uint32_t aux_input_size = 0;
152     uint32_t aux_input_weights_stride = 0;
153     bool hasAuxInput = (auxInputData != nullptr);
154     if (hasAuxInput) {
155         aux_input_size = auxInputShape.dimensions[1];
156         aux_input_weights_stride = auxWeightsShape.dimensions[1];
157     }
158 
159     // For each batch
160     for (uint32_t b = 0; b < batch_size; b++) {
161         // Initialize the pointer to input, output and bias.
162         const T* input_ptr_batch = inputData + b * input_size;
163         const T* hidden_state_in_ptr_batch = hiddenStateInputData + b * num_units;
164         const T* aux_input_ptr_batch = nullptr;
165         if (hasAuxInput) {
166             aux_input_ptr_batch = auxInputData + b * aux_input_size;
167         }
168         T* output_ptr_batch = outputData + b * outputBatchStride + outputBatchOffset;
169 
170         // Initialize input_weights and recurrent_weights.
171         const T* input_weights_ptr = weightsData;
172         const T* recurrent_weights_ptr = recurrentWeightsData;
173         const T* aux_input_weights_ptr = nullptr;
174         if (hasAuxInput) {
175             aux_input_weights_ptr = auxWeightsData;
176         }
177 
178         // Output = bias
179         for (uint32_t o = 0; o < num_units; o++) {
180             output_ptr_batch[o] = biasData[o];
181         }
182 
183         // Output += input * input_weights
184         for (uint32_t o = 0; o < num_units; o++) {
185             for (uint32_t i = 0; i < input_size; i++) {
186                 output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
187             }
188             input_weights_ptr += input_weights_stride;
189         }
190 
191         if (hasAuxInput) {
192             // Output += aux_input * aux_input_weights
193             for (uint32_t o = 0; o < num_units; o++) {
194                 for (uint32_t i = 0; i < input_size; i++) {
195                     output_ptr_batch[o] += aux_input_ptr_batch[i] * aux_input_weights_ptr[i];
196                 }
197                 aux_input_weights_ptr += aux_input_weights_stride;
198             }
199         }
200 
201         // Output += recurrent_weights * hidden_state
202         for (uint32_t o = 0; o < num_units; o++) {
203             for (uint32_t h = 0; h < num_units; h++) {
204                 output_ptr_batch[o] += hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
205             }
206             recurrent_weights_ptr += recurrent_weights_stride;
207         }
208 
209         // Output = activation(Output)
210         for (uint32_t o = 0; o < num_units; o++) {
211             output_ptr_batch[o] =
212                     (ActivationFunctor(static_cast<ActivationFn>(activation)))(output_ptr_batch[o]);
213             if (hiddenStateOutput != nullptr) {
214                 *hiddenStateOutput = output_ptr_batch[o];
215                 ++hiddenStateOutput;
216             }
217         }
218     }
219 
220     return true;
221 }
222 
// Explicit instantiations of both RNNStep overloads for the element types the
// op supports (_Float16 and float), so the template definitions above can live
// in this translation unit while other files (e.g. the bidirectional RNN
// kernels) link against them.
template bool RNN::RNNStep<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                     const _Float16* hiddenStateInputData, const _Float16* biasData,
                                     const _Float16* weightsData, const Shape& weightsShape,
                                     const _Float16* recurrentWeightsData,
                                     const Shape& recurrentWeightsShape, int32_t activation,
                                     _Float16* outputData);
template bool RNN::RNNStep<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                     const _Float16* auxInputData, const Shape& auxInputShape,
                                     const _Float16* hiddenStateInputData, const _Float16* biasData,
                                     const _Float16* weightsData, const Shape& weightsShape,
                                     const _Float16* auxWeightsData, const Shape& auxWeightsShape,
                                     const _Float16* recurrentWeightsData,
                                     const Shape& recurrentWeightsShape, const int32_t activation,
                                     const uint32_t outputBatchStride,
                                     const uint32_t outputBatchOffset, _Float16* outputData,
                                     _Float16* hiddenStateOutput);
template bool RNN::RNNStep<float>(const float* inputData, const Shape& inputShape,
                                  const float* hiddenStateInputData, const float* biasData,
                                  const float* weightsData, const Shape& weightsShape,
                                  const float* recurrentWeightsData,
                                  const Shape& recurrentWeightsShape, int32_t activation,
                                  float* outputData);
template bool RNN::RNNStep<float>(const float* inputData, const Shape& inputShape,
                                  const float* auxInputData, const Shape& auxInputShape,
                                  const float* hiddenStateInputData, const float* biasData,
                                  const float* weightsData, const Shape& weightsShape,
                                  const float* auxWeightsData, const Shape& auxWeightsShape,
                                  const float* recurrentWeightsData,
                                  const Shape& recurrentWeightsShape, int32_t activation,
                                  uint32_t outputBatchStride, uint32_t outputBatchStep,
                                  float* outputData, float* hiddenStateOutput);
254 
255 }  // namespace nn
256 }  // namespace android
257