/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "UnidirectionalSequenceRNN.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "OperationResolver.h"
#include "RNN.h"
#include "nnapi/TypeUtils.h"

namespace android {
namespace nn {
namespace unidirectional_sequence_rnn {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

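// Swaps the first two dimensions of a rank-3 tensor, i.e. converts between the
// [batch, time, inputSize] and [time, batch, inputSize] layouts.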
template <typename T>
void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
    const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
    const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    for (uint32_t f = 0; f < firstDimSize; ++f) {
        for (uint32_t s = 0; s < secondDimSize; ++s) {
            for (uint32_t i = 0; i < inputSize; ++i) {
                const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
                const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
                output[outputIndex] = input[inputIndex];
            }
        }
    }
}

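// Runs the full input sequence through the RNN for a single element type
// (float or _Float16), writing the per-step results to the output tensor.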
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);
    const T* weights = context->getInputBuffer<T>(kWeightsTensor);
    Shape weightsShape = context->getInputShape(kWeightsTensor);
    const T* recurrentWeights = context->getInputBuffer<T>(kRecurrentWeightsTensor);
    Shape recurrentWeightsShape = context->getInputShape(kRecurrentWeightsTensor);
    const T* bias = context->getInputBuffer<T>(kBiasTensor);
    const T* hiddenState = context->getInputBuffer<T>(kHiddenStateTensor);
    int32_t activation = context->getInputValue<int32_t>(kActivationParam);

    T* output = context->getOutputBuffer<T>(kOutputTensor);
    Shape outputShape = context->getOutputShape(kOutputTensor);

    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    // If the input tensors are not in time-major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> outputTransposed;
    if (!timeMajor) {
        // Convert input and output to time-major format.
        inputTransposed.resize(getNumberOfElements(inputShape));
        outputTransposed.resize(getNumberOfElements(outputShape));
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        input = inputTransposed.data();
        output = outputTransposed.data();
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        std::swap(outputShape.dimensions[0], outputShape.dimensions[1]);
    }

    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    const uint32_t numUnits = getSizeOfDimension(weightsShape, 0);

    // A shape at a fixed step (removed time dimension).
    Shape fixedTimeInputShape = inputShape;
    fixedTimeInputShape.dimensions.resize(2);
    fixedTimeInputShape.dimensions[0] = inputShape.dimensions[1];
    fixedTimeInputShape.dimensions[1] = inputShape.dimensions[2];

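    // Unrolled loop over time: each iteration applies one RNN step to the whole
    // batch. The output of step t becomes the hidden state for step t + 1, and
    // the output pointer advances so the sequence is written out contiguously
    // in time-major order.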
    for (uint32_t i = 0; i < maxTime; ++i) {
        RNN::RNNStep<T>(input, fixedTimeInputShape, hiddenState, bias, weights, weightsShape,
                        recurrentWeights, recurrentWeightsShape, activation, output);
        input += batchSize * inputSize;
        hiddenState = output;
        output += batchSize * numUnits;
    }

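    // Transpose the result back to batch-major layout if the caller supplied
    // batch-major tensors.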
    if (!timeMajor) {
        transposeFirstTwoDims(outputTransposed.data(), outputShape,
                              context->getOutputBuffer<T>(kOutputTensor));
    }

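    // If the optional state output is present, it receives the hidden state
    // produced by the last time step.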
    if (context->getNumOutputs() == kNumOutputsWithState) {
        // We checked that the state output is not omitted during preparation.
        T* stateOutput = context->getOutputBuffer<T>(kStateOutputTensor);
        std::copy(hiddenState, hiddenState + batchSize * numUnits, stateOutput);
    }
    return true;
}

}  // namespace

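// Validates operand ranks and dimensions and deduces the output shape(s) from
// the input, weights, and time-major flag.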
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape recurrentWeights = context->getInputShape(kRecurrentWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);
    Shape hiddenState = context->getInputShape(kHiddenStateTensor);

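    // The time-major flag selects which of the first two input dimensions is
    // time and which is batch.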
    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    NN_RET_CHECK(timeMajor == 0 || timeMajor == 1);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t numUnits = getSizeOfDimension(weights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

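    // Check the rank of every tensor operand.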
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(hiddenState), 2u);

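    // Check that the operand dimensions are consistent with each other.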
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(weights, 1));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(bias, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(hiddenState, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(hiddenState, 1));

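    // The main output is [maxTime, batchSize, numUnits] in time-major mode and
    // [batchSize, maxTime, numUnits] otherwise.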
    Shape output = context->getOutputShape(kOutputTensor);
    output.dimensions.resize(3);
    output.dimensions[0] = timeMajor ? maxTime : batchSize;
    output.dimensions[1] = timeMajor ? batchSize : maxTime;
    output.dimensions[2] = numUnits;

    if (context->getNumOutputs() == kNumOutputsWithState) {
        NN_RET_CHECK(!context->isOmittedOutput(kStateOutputTensor));
        Shape outputStateShape = context->getInputShape(kHiddenStateTensor);
        outputStateShape.dimensions.resize(2);
        outputStateShape.dimensions[0] = batchSize;
        outputStateShape.dimensions[1] = numUnits;
        NN_RET_CHECK(context->setOutputShape(kStateOutputTensor, outputStateShape));
    }

    return context->setOutputShape(kOutputTensor, output);
}

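// Dispatches to the typed implementation matching the input tensor's element type.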
bool execute(IOperationExecutionContext* context) {
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        executeTyped<_Float16>(context);
    } else {
        executeTyped<float>(context);
    }
    return true;
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace unidirectional_sequence_rnn

NN_REGISTER_OPERATION_DEFAULT_VALIDATION(UNIDIRECTIONAL_SEQUENCE_RNN,
                                         unidirectional_sequence_rnn::prepare,
                                         unidirectional_sequence_rnn::execute);

}  // namespace nn
}  // namespace android