1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "UnidirectionalSequenceRNN.h"
20 
21 #include <algorithm>
22 #include <utility>
23 #include <vector>
24 
25 #include "OperationResolver.h"
26 #include "RNN.h"
27 #include "nnapi/TypeUtils.h"
28 
29 namespace android {
30 namespace nn {
31 namespace unidirectional_sequence_rnn {
32 
33 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
34 namespace {
35 
36 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)37 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
38     const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
39     const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
40     const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
41     for (uint32_t f = 0; f < firstDimSize; ++f) {
42         for (uint32_t s = 0; s < secondDimSize; ++s) {
43             for (uint32_t i = 0; i < inputSize; ++i) {
44                 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
45                 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
46                 output[outputIndex] = input[inputIndex];
47             }
48         }
49     }
50 }
51 
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    // Runs the whole unrolled RNN sequence: one RNN::RNNStep per time step,
    // feeding each step's output back in as the next step's hidden state.
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);
    const T* weights = context->getInputBuffer<T>(kWeightsTensor);
    Shape weightsShape = context->getInputShape(kWeightsTensor);
    const T* recurrentWeights = context->getInputBuffer<T>(kRecurrentWeightsTensor);
    Shape recurrentWeightsShape = context->getInputShape(kRecurrentWeightsTensor);
    const T* bias = context->getInputBuffer<T>(kBiasTensor);
    const T* hiddenState = context->getInputBuffer<T>(kHiddenStateTensor);
    int32_t activation = context->getInputValue<int32_t>(kActivationParam);

    T* output = context->getOutputBuffer<T>(kOutputTensor);
    Shape outputShape = context->getOutputShape(kOutputTensor);

    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> outputTransposed;
    if (!timeMajor) {
        // Convert input and output to time major format.
        inputTransposed.resize(getNumberOfElements(inputShape));
        outputTransposed.resize(getNumberOfElements(outputShape));
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        // From here on, `input` and `output` point at the time-major scratch
        // buffers; the real output buffer is only written at the end.
        input = inputTransposed.data();
        output = outputTransposed.data();
        // Keep the local shapes consistent with the transposed data.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        std::swap(outputShape.dimensions[0], outputShape.dimensions[1]);
    }

    // In time-major layout the input is [maxTime, batchSize, inputSize].
    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    const uint32_t numUnits = getSizeOfDimension(weightsShape, 0);

    // A shape at a fixed step (removed time dimension).
    Shape fixedTimeInputShape = inputShape;
    fixedTimeInputShape.dimensions.resize(2);
    fixedTimeInputShape.dimensions[0] = inputShape.dimensions[1];
    fixedTimeInputShape.dimensions[1] = inputShape.dimensions[2];

    for (uint32_t i = 0; i < maxTime; ++i) {
        RNN::RNNStep<T>(input, fixedTimeInputShape, hiddenState, bias, weights, weightsShape,
                        recurrentWeights, recurrentWeightsShape, activation, output);
        // Advance to the next time slice; the slice just written becomes the
        // hidden state for the next step (no separate state buffer needed).
        input += batchSize * inputSize;
        hiddenState = output;
        output += batchSize * numUnits;
    }

    if (!timeMajor) {
        // Transpose the time-major scratch result back into the caller's
        // batch-major output buffer.
        transposeFirstTwoDims(outputTransposed.data(), outputShape,
                              context->getOutputBuffer<T>(kOutputTensor));
    }

    if (context->getNumOutputs() == kNumOutputsWithState) {
        // We checked that the state output is not omitted during preparation.
        // `hiddenState` still points at the last time step's output slice.
        T* stateOutput = context->getOutputBuffer<T>(kStateOutputTensor);
        std::copy(hiddenState, hiddenState + batchSize * numUnits, stateOutput);
    }
    return true;
}
115 
116 }  // namespace
117 
// Validates operand ranks/dimensions and computes the output shape(s) for
// UNIDIRECTIONAL_SEQUENCE_RNN. Returns false (via NN_RET_CHECK) on any
// inconsistency between the input tensors.
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape recurrentWeights = context->getInputShape(kRecurrentWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);
    Shape hiddenState = context->getInputShape(kHiddenStateTensor);

    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    NN_RET_CHECK(timeMajor == 0 || timeMajor == 1);
    // Input is [maxTime, batchSize, inputSize] when time-major, otherwise
    // [batchSize, maxTime, inputSize].
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t numUnits = getSizeOfDimension(weights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    // Rank checks.
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(hiddenState), 2u);

    // Cross-tensor dimension consistency: weights are [numUnits, inputSize],
    // recurrent weights [numUnits, numUnits], bias [numUnits], hidden state
    // [batchSize, numUnits].
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(weights, 1));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(bias, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(hiddenState, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(hiddenState, 1));

    // Main output mirrors the input's time/batch ordering with the feature
    // dimension replaced by numUnits.
    Shape output = context->getOutputShape(kOutputTensor);
    output.dimensions.resize(3);
    output.dimensions[0] = timeMajor ? maxTime : batchSize;
    output.dimensions[1] = timeMajor ? batchSize : maxTime;
    output.dimensions[2] = numUnits;

    if (context->getNumOutputs() == kNumOutputsWithState) {
        // Optional second output: the final hidden state, [batchSize, numUnits].
        NN_RET_CHECK(!context->isOmittedOutput(kStateOutputTensor));
        Shape outputStateShape = context->getInputShape(kHiddenStateTensor);
        outputStateShape.dimensions.resize(2);
        outputStateShape.dimensions[0] = batchSize;
        outputStateShape.dimensions[1] = numUnits;
        NN_RET_CHECK(context->setOutputShape(kStateOutputTensor, outputStateShape));
    }

    return context->setOutputShape(kOutputTensor, output);
}
164 
execute(IOperationExecutionContext * context)165 bool execute(IOperationExecutionContext* context) {
166     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
167         executeTyped<_Float16>(context);
168     } else {
169         executeTyped<float>(context);
170     }
171     return true;
172 }
173 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
174 
175 }  // namespace unidirectional_sequence_rnn
176 
// Registers the prepare/execute entry points for this operation with the
// operation resolver, using the framework's default validation.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(UNIDIRECTIONAL_SEQUENCE_RNN,
                                         unidirectional_sequence_rnn::prepare,
                                         unidirectional_sequence_rnn::execute);
180 
181 }  // namespace nn
182 }  // namespace android
183