1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "RNN.h"
20
21 #include <vector>
22
23 #include "CpuExecutor.h"
24 #include "CpuOperationUtils.h"
25 #include "Tracing.h"
26
27 namespace android {
28 namespace nn {
29
RNN(const Operation & operation,RunTimeOperandInfo * operands)30 RNN::RNN(const Operation& operation, RunTimeOperandInfo* operands) {
31 NNTRACE_TRANS("RNN::RNN");
32 input_ = GetInput(operation, operands, kInputTensor);
33 weights_ = GetInput(operation, operands, kWeightsTensor);
34 recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
35 hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
36 bias_ = GetInput(operation, operands, kBiasTensor);
37
38 activation_ = static_cast<ActivationFn>(
39 getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
40
41 hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
42 output_ = GetOutput(operation, operands, kOutputTensor);
43 }
44
Prepare(const Operation & operation,RunTimeOperandInfo * operands,Shape * hiddenStateShape,Shape * outputShape)45 bool RNN::Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* hiddenStateShape,
46 Shape* outputShape) {
47 NNTRACE_TRANS("RNN::Prepare");
48 // Check we have all the inputs and outputs we need.
49 const int num_inputs = NumInputsWithValues(operation, operands);
50 NN_CHECK(num_inputs == 6);
51 NN_CHECK_EQ(NumOutputs(operation), 2);
52
53 const RunTimeOperandInfo* input = GetInput(operation, operands, kInputTensor);
54 const RunTimeOperandInfo* input_weights = GetInput(operation, operands, kWeightsTensor);
55 const RunTimeOperandInfo* recurrent_weights =
56 GetInput(operation, operands, kRecurrentWeightsTensor);
57 const RunTimeOperandInfo* bias = GetInput(operation, operands, kBiasTensor);
58
59 // Check all the parameters of tensor match within themselves and match the
60 // input configuration.
61 const uint32_t batch_size = SizeOfDimension(input, 0);
62 const uint32_t num_units = SizeOfDimension(input_weights, 0);
63 NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
64 NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
65 NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
66 NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
67
68 const Shape& inputShape = input->shape();
69
70 // Resize state.
71 hiddenStateShape->type = inputShape.type;
72 hiddenStateShape->dimensions = {batch_size, num_units};
73
74 // Resize output.
75 outputShape->type = inputShape.type;
76 outputShape->dimensions = {batch_size, num_units};
77
78 return true;
79 }
80
Eval()81 bool RNN::Eval() {
82 switch (input_->type) {
83 case OperandType::TENSOR_FLOAT16: {
84 RNNStep<_Float16>(reinterpret_cast<_Float16*>(input_->buffer), input_->shape(),
85 reinterpret_cast<_Float16*>(hidden_state_in_->buffer),
86 reinterpret_cast<_Float16*>(bias_->buffer),
87 reinterpret_cast<_Float16*>(weights_->buffer), weights_->shape(),
88 reinterpret_cast<_Float16*>(recurrent_weights_->buffer),
89 recurrent_weights_->shape(), activation_,
90 reinterpret_cast<_Float16*>(output_->buffer));
91 memcpy(hidden_state_out_->buffer, output_->buffer,
92 sizeof(_Float16) * getNumberOfElements(output_->shape()));
93 break;
94 }
95 case OperandType::TENSOR_FLOAT32: {
96 RNNStep<float>(reinterpret_cast<float*>(input_->buffer), input_->shape(),
97 reinterpret_cast<float*>(hidden_state_in_->buffer),
98 reinterpret_cast<float*>(bias_->buffer),
99 reinterpret_cast<float*>(weights_->buffer), weights_->shape(),
100 reinterpret_cast<float*>(recurrent_weights_->buffer),
101 recurrent_weights_->shape(), activation_,
102 reinterpret_cast<float*>(output_->buffer));
103 memcpy(hidden_state_out_->buffer, output_->buffer,
104 sizeof(float) * getNumberOfElements(output_->shape()));
105 break;
106 }
107 default: {
108 LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
109 return false;
110 }
111 }
112 return true;
113 }
114
115 template <typename T>
RNNStep(const T * inputData,const Shape & inputShape,const T * hiddenStateInputData,const T * biasData,const T * weightsData,const Shape & weightsShape,const T * recurrentWeightsData,const Shape & recurrentWeightsShape,const int32_t activation,T * outputData)116 bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* hiddenStateInputData,
117 const T* biasData, const T* weightsData, const Shape& weightsShape,
118 const T* recurrentWeightsData, const Shape& recurrentWeightsShape,
119 const int32_t activation, T* outputData) {
120 NNTRACE_COMP("RNN::Eval");
121
122 Shape dummyShape;
123 uint32_t numUnits = weightsShape.dimensions[0];
124 return RNNStep<T>(inputData, inputShape, /*auxInputData=*/nullptr, /*auxInputShape=*/dummyShape,
125 hiddenStateInputData, biasData, weightsData, weightsShape,
126 /*auxWeightsData=*/nullptr, /*auxWeightsShape=*/dummyShape,
127 recurrentWeightsData, recurrentWeightsShape, activation,
128 /*outputBatchStride=*/numUnits, /*outputBatchOffset=*/0, outputData);
129 }
130
131 // A more general version of the RNNStep function.
132 // Auxiliary input is treated as if it was concatenated to a regular input and
133 // the result was multiplied by the weights matrix which was also concatenated
134 // with auxiliary weights.
135 template <typename T>
RNNStep(const T * inputData,const Shape & inputShape,const T * auxInputData,const Shape & auxInputShape,const T * hiddenStateInputData,const T * biasData,const T * weightsData,const Shape & weightsShape,const T * auxWeightsData,const Shape & auxWeightsShape,const T * recurrentWeightsData,const Shape & recurrentWeightsShape,const int32_t activation,const uint32_t outputBatchStride,const uint32_t outputBatchOffset,T * outputData,T * hiddenStateOutput)136 bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* auxInputData,
137 const Shape& auxInputShape, const T* hiddenStateInputData, const T* biasData,
138 const T* weightsData, const Shape& weightsShape, const T* auxWeightsData,
139 const Shape& auxWeightsShape, const T* recurrentWeightsData,
140 const Shape& recurrentWeightsShape, const int32_t activation,
141 const uint32_t outputBatchStride, const uint32_t outputBatchOffset, T* outputData,
142 T* hiddenStateOutput) {
143 NNTRACE_COMP("RNN::Eval");
144
145 const uint32_t batch_size = inputShape.dimensions[0];
146 const uint32_t num_units = weightsShape.dimensions[0];
147 const uint32_t input_size = inputShape.dimensions[1];
148 const uint32_t input_weights_stride = weightsShape.dimensions[1];
149 const uint32_t recurrent_weights_stride = recurrentWeightsShape.dimensions[1];
150
151 uint32_t aux_input_size = 0;
152 uint32_t aux_input_weights_stride = 0;
153 bool hasAuxInput = (auxInputData != nullptr);
154 if (hasAuxInput) {
155 aux_input_size = auxInputShape.dimensions[1];
156 aux_input_weights_stride = auxWeightsShape.dimensions[1];
157 }
158
159 // For each batch
160 for (uint32_t b = 0; b < batch_size; b++) {
161 // Initialize the pointer to input, output and bias.
162 const T* input_ptr_batch = inputData + b * input_size;
163 const T* hidden_state_in_ptr_batch = hiddenStateInputData + b * num_units;
164 const T* aux_input_ptr_batch = nullptr;
165 if (hasAuxInput) {
166 aux_input_ptr_batch = auxInputData + b * aux_input_size;
167 }
168 T* output_ptr_batch = outputData + b * outputBatchStride + outputBatchOffset;
169
170 // Initialize input_weights and recurrent_weights.
171 const T* input_weights_ptr = weightsData;
172 const T* recurrent_weights_ptr = recurrentWeightsData;
173 const T* aux_input_weights_ptr = nullptr;
174 if (hasAuxInput) {
175 aux_input_weights_ptr = auxWeightsData;
176 }
177
178 // Output = bias
179 for (uint32_t o = 0; o < num_units; o++) {
180 output_ptr_batch[o] = biasData[o];
181 }
182
183 // Output += input * input_weights
184 for (uint32_t o = 0; o < num_units; o++) {
185 for (uint32_t i = 0; i < input_size; i++) {
186 output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
187 }
188 input_weights_ptr += input_weights_stride;
189 }
190
191 if (hasAuxInput) {
192 // Output += aux_input * aux_input_weights
193 for (uint32_t o = 0; o < num_units; o++) {
194 for (uint32_t i = 0; i < input_size; i++) {
195 output_ptr_batch[o] += aux_input_ptr_batch[i] * aux_input_weights_ptr[i];
196 }
197 aux_input_weights_ptr += aux_input_weights_stride;
198 }
199 }
200
201 // Output += recurrent_weights * hidden_state
202 for (uint32_t o = 0; o < num_units; o++) {
203 for (uint32_t h = 0; h < num_units; h++) {
204 output_ptr_batch[o] += hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
205 }
206 recurrent_weights_ptr += recurrent_weights_stride;
207 }
208
209 // Output = activation(Output)
210 for (uint32_t o = 0; o < num_units; o++) {
211 output_ptr_batch[o] =
212 (ActivationFunctor(static_cast<ActivationFn>(activation)))(output_ptr_batch[o]);
213 if (hiddenStateOutput != nullptr) {
214 *hiddenStateOutput = output_ptr_batch[o];
215 ++hiddenStateOutput;
216 }
217 }
218 }
219
220 return true;
221 }
222
223 template bool RNN::RNNStep<_Float16>(const _Float16* inputData, const Shape& inputShape,
224 const _Float16* hiddenStateInputData, const _Float16* biasData,
225 const _Float16* weightsData, const Shape& weightsShape,
226 const _Float16* recurrentWeightsData,
227 const Shape& recurrentWeightsShape, int32_t activation,
228 _Float16* outputData);
229 template bool RNN::RNNStep<_Float16>(const _Float16* inputData, const Shape& inputShape,
230 const _Float16* auxInputData, const Shape& auxInputShape,
231 const _Float16* hiddenStateInputData, const _Float16* biasData,
232 const _Float16* weightsData, const Shape& weightsShape,
233 const _Float16* auxWeightsData, const Shape& auxWeightsShape,
234 const _Float16* recurrentWeightsData,
235 const Shape& recurrentWeightsShape, const int32_t activation,
236 const uint32_t outputBatchStride,
237 const uint32_t outputBatchOffset, _Float16* outputData,
238 _Float16* hiddenStateOutput);
239 template bool RNN::RNNStep<float>(const float* inputData, const Shape& inputShape,
240 const float* hiddenStateInputData, const float* biasData,
241 const float* weightsData, const Shape& weightsShape,
242 const float* recurrentWeightsData,
243 const Shape& recurrentWeightsShape, int32_t activation,
244 float* outputData);
245 template bool RNN::RNNStep<float>(const float* inputData, const Shape& inputShape,
246 const float* auxInputData, const Shape& auxInputShape,
247 const float* hiddenStateInputData, const float* biasData,
248 const float* weightsData, const Shape& weightsShape,
249 const float* auxWeightsData, const Shape& auxWeightsShape,
250 const float* recurrentWeightsData,
251 const Shape& recurrentWeightsShape, int32_t activation,
252 uint32_t outputBatchStride, uint32_t outputBatchStep,
253 float* outputData, float* hiddenStateOutput);
254
255 } // namespace nn
256 } // namespace android
257