1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "BidirectionalSequenceLSTM.h"
20 
21 #include <tensorflow/lite/kernels/internal/tensor_utils.h>
22 
23 #include <algorithm>
24 #include <vector>
25 
26 #include "CpuExecutor.h"
27 #include "CpuOperationUtils.h"
28 #include "OperationsExecutionUtils.h"
29 #include "Tracing.h"
30 
31 namespace android {
32 namespace nn {
33 
34 namespace {
35 
36 template <typename T>
GetBuffer(RunTimeOperandInfo * operand)37 inline T* GetBuffer(RunTimeOperandInfo* operand) {
38     return reinterpret_cast<T*>(operand->buffer);
39 }
40 
41 template <typename T>
GetBuffer(const RunTimeOperandInfo * operand)42 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
43     return reinterpret_cast<const T*>(operand->buffer);
44 }
45 
46 template <typename T>
GetOptionalBuffer(const RunTimeOperandInfo * operand)47 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
48     return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
49 }
50 
// How the optional auxiliary input is consumed by the backward network; see
// the configuration table in getLinkingMode() below.
enum class LinkingMode {
    NO_LINKING,
    PARALLEL_LINKING,
    CROSS_LINKING,
};
56 
getLinkingMode(bool hasAuxInput,bool hasAuxWeights,LinkingMode * linkingMode)57 bool getLinkingMode(bool hasAuxInput, bool hasAuxWeights, LinkingMode* linkingMode) {
58     // Three possible configurations for three possible linking modes:
59     // 1) NO_LINKING -- no auxiliary tensors at all
60     // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular
61     //    input to the backward network, so the auxiliary weights are omitted.
62     // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by
63     //    auxiliary weights.
64     if (!hasAuxInput && !hasAuxWeights) {
65         *linkingMode = LinkingMode::NO_LINKING;
66     } else if (hasAuxInput && !hasAuxWeights) {
67         *linkingMode = LinkingMode::PARALLEL_LINKING;
68     } else if (hasAuxInput && hasAuxWeights) {
69         *linkingMode = LinkingMode::CROSS_LINKING;
70     } else {
71         NN_RET_CHECK_FAIL()
72                 << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
73     }
74 
75     return true;
76 }
77 
78 }  // anonymous namespace
79 
// Caches pointers to every operand of a BIDIRECTIONAL_SEQUENCE_LSTM operation
// and decodes its scalar parameters into params_. Operands marked "optional"
// may be omitted by the model; presence is validated later in Prepare().
BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
                                                     RunTimeOperandInfo* operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    // Forward-direction input-to-gate weights.
    fw_input_to_input_weights_ =
            GetInput(operation, operands, kFwInputToInputWeightsTensor);  // optional
    fw_input_to_forget_weights_ = GetInput(operation, operands, kFwInputToForgetWeightsTensor);
    fw_input_to_cell_weights_ = GetInput(operation, operands, kFwInputToCellWeightsTensor);
    fw_input_to_output_weights_ = GetInput(operation, operands, kFwInputToOutputWeightsTensor);

    // Forward-direction recurrent (output-to-gate) weights.
    fw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kFwRecurrentToInputWeightsTensor);  // optional
    fw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kFwRecurrentToForgetWeightsTensor);
    fw_recurrent_to_cell_weights_ = GetInput(operation, operands, kFwRecurrentToCellWeightsTensor);
    fw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kFwRecurrentToOutputWeightsTensor);

    // Forward-direction peephole weights (all optional).
    fw_cell_to_input_weights_ =
            GetInput(operation, operands, kFwCellToInputWeightsTensor);  // optional
    fw_cell_to_forget_weights_ =
            GetInput(operation, operands, kFwCellToForgetWeightsTensor);  // optional
    fw_cell_to_output_weights_ =
            GetInput(operation, operands, kFwCellToOutputWeightsTensor);  // optional

    // Forward-direction gate biases.
    fw_input_gate_bias_ = GetInput(operation, operands, kFwInputGateBiasTensor);
    fw_forget_gate_bias_ = GetInput(operation, operands, kFwForgetGateBiasTensor);
    fw_cell_bias_ = GetInput(operation, operands, kFwCellGateBiasTensor);
    fw_output_gate_bias_ = GetInput(operation, operands, kFwOutputGateBiasTensor);

    fw_projection_weights_ = GetInput(operation, operands, kFwProjectionWeightsTensor);  // optional
    fw_projection_bias_ = GetInput(operation, operands, kFwProjectionBiasTensor);        // optional

    // Forward-direction initial states.
    fw_activation_state_ = GetInput(operation, operands, kFwInputActivationStateTensor);
    fw_cell_state_ = GetInput(operation, operands, kFwInputCellStateTensor);

    // Backward-direction input-to-gate weights.
    bw_input_to_input_weights_ =
            GetInput(operation, operands, kBwInputToInputWeightsTensor);  // optional
    bw_input_to_forget_weights_ = GetInput(operation, operands, kBwInputToForgetWeightsTensor);
    bw_input_to_cell_weights_ = GetInput(operation, operands, kBwInputToCellWeightsTensor);
    bw_input_to_output_weights_ = GetInput(operation, operands, kBwInputToOutputWeightsTensor);

    // Backward-direction recurrent (output-to-gate) weights.
    bw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kBwRecurrentToInputWeightsTensor);  // optional
    bw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kBwRecurrentToForgetWeightsTensor);
    bw_recurrent_to_cell_weights_ = GetInput(operation, operands, kBwRecurrentToCellWeightsTensor);
    bw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kBwRecurrentToOutputWeightsTensor);

    // Backward-direction peephole weights (all optional).
    bw_cell_to_input_weights_ =
            GetInput(operation, operands, kBwCellToInputWeightsTensor);  // optional
    bw_cell_to_forget_weights_ =
            GetInput(operation, operands, kBwCellToForgetWeightsTensor);  // optional
    bw_cell_to_output_weights_ =
            GetInput(operation, operands, kBwCellToOutputWeightsTensor);  // optional

    // Backward-direction gate biases.
    bw_input_gate_bias_ = GetInput(operation, operands, kBwInputGateBiasTensor);
    bw_forget_gate_bias_ = GetInput(operation, operands, kBwForgetGateBiasTensor);
    bw_cell_bias_ = GetInput(operation, operands, kBwCellGateBiasTensor);
    bw_output_gate_bias_ = GetInput(operation, operands, kBwOutputGateBiasTensor);

    bw_projection_weights_ = GetInput(operation, operands, kBwProjectionWeightsTensor);  // optional
    bw_projection_bias_ = GetInput(operation, operands, kBwProjectionBiasTensor);        // optional

    // Backward-direction initial states.
    bw_activation_state_ = GetInput(operation, operands, kBwInputActivationStateTensor);
    bw_cell_state_ = GetInput(operation, operands, kBwInputCellStateTensor);

    // Auxiliary input and the weights that consume it (see LinkingMode).
    aux_input_ = GetInput(operation, operands, kAuxInputTensor);
    fw_aux_input_to_input_weights_ = GetInput(operation, operands, kFwAuxInputToInputWeightsTensor);
    fw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kFwAuxInputToForgetWeightsTensor);
    fw_aux_input_to_cell_weights_ = GetInput(operation, operands, kFwAuxInputToCellWeightsTensor);
    fw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kFwAuxInputToOutputWeightsTensor);
    bw_aux_input_to_input_weights_ = GetInput(operation, operands, kBwAuxInputToInputWeightsTensor);
    bw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kBwAuxInputToForgetWeightsTensor);
    bw_aux_input_to_cell_weights_ = GetInput(operation, operands, kBwAuxInputToCellWeightsTensor);
    bw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kBwAuxInputToOutputWeightsTensor);

    // Layer-normalization weights for both directions.
    fw_input_layer_norm_weights_ = GetInput(operation, operands, kFwInputLayerNormWeightsTensor);
    fw_forget_layer_norm_weights_ = GetInput(operation, operands, kFwForgetLayerNormWeightsTensor);
    fw_cell_layer_norm_weights_ = GetInput(operation, operands, kFwCellLayerNormWeightsTensor);
    fw_output_layer_norm_weights_ = GetInput(operation, operands, kFwOutputLayerNormWeightsTensor);
    bw_input_layer_norm_weights_ = GetInput(operation, operands, kBwInputLayerNormWeightsTensor);
    bw_forget_layer_norm_weights_ = GetInput(operation, operands, kBwForgetLayerNormWeightsTensor);
    bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
    bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);

    // Decode scalar parameters, falling back to defaults when omitted.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    params_.activation = static_cast<ActivationFn>(getScalarDataWithDefault<int32_t>(
            activationOperand, TfLiteFusedActivation::kTfLiteActNone));
    const auto& clipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projOperand = *GetInput(operation, operands, kProjClipParam);
    // Clip scalars use the same float type as the input tensor (fp32 or fp16).
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarDataWithDefault<float>(clipOperand, 0.0f);
        params_.proj_clip = getScalarDataWithDefault<float>(projOperand, 0.0f);
    } else {
        params_.cell_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(clipOperand, 0.0f));
        params_.proj_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(projOperand, 0.0f));
    }
    const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam);
    params_.merge_outputs = getScalarDataWithDefault<bool>(mergeOutputsOperand, false);
    const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam);
    params_.time_major = getScalarDataWithDefault<bool>(timeMajorOperand, false);
    // Layer norm is considered enabled iff the fw input layer-norm weights are
    // present; Prepare() validates the rest of the layer-norm tensors.
    params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);

    fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
    if (!params_.merge_outputs) {
        bw_output_ = GetOutput(operation, operands, kBwOutputTensor);
    }

    // 5 or 6 outputs means the caller also wants the final activation/cell
    // states (6 when fw/bw outputs are separate, 5 when merged).
    params_.output_state = (operation.outputs.size() == 5 || operation.outputs.size() == 6);
    if (params_.output_state) {
        // With merged outputs the bw output tensor is absent, so the state
        // output indices shift down by one.
        uint32_t delta = params_.merge_outputs ? 1 : 0;
        fw_output_activation_state_ =
                GetOutput(operation, operands, kFwOutputActivationStateTensor - delta);
        fw_output_cell_state_ = GetOutput(operation, operands, kFwOutputCellStateTensor - delta);
        bw_output_activation_state_ =
                GetOutput(operation, operands, kBwOutputActivationStateTensor - delta);
        bw_output_cell_state_ = GetOutput(operation, operands, kBwOutputCellStateTensor - delta);
    }
}
207 
// Validates operand presence, ranks and dimensions for both LSTM directions
// and computes the shapes of every output tensor. Returns false (after
// logging) if the operation is malformed.
//
// NOTE(review): this body mixes NN_CHECK/NN_CHECK_EQ with NN_RET_CHECK/
// NN_RET_CHECK_EQ; presumably both families log-and-return-false here, but
// confirm against their macro definitions before relying on that.
bool BidirectionalSequenceLSTM::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
                                        Shape* fwOutputShape, Shape* bwOutputShape,
                                        Shape* fwOutputActivationState, Shape* fwOutputCellState,
                                        Shape* bwOutputActivationState, Shape* bwOutputCellState) {
    // Check we have all the inputs and outputs we need.
    constexpr int requiredInputs[] = {
            kInputTensor,
            kFwInputToForgetWeightsTensor,
            kFwInputToCellWeightsTensor,
            kFwInputToOutputWeightsTensor,
            kFwRecurrentToForgetWeightsTensor,
            kFwRecurrentToCellWeightsTensor,
            kFwRecurrentToOutputWeightsTensor,
            kFwForgetGateBiasTensor,
            kFwCellGateBiasTensor,
            kFwOutputGateBiasTensor,
            kBwInputToForgetWeightsTensor,
            kBwInputToCellWeightsTensor,
            kBwInputToOutputWeightsTensor,
            kBwRecurrentToForgetWeightsTensor,
            kBwRecurrentToCellWeightsTensor,
            kBwRecurrentToOutputWeightsTensor,
            kBwForgetGateBiasTensor,
            kBwCellGateBiasTensor,
            kBwOutputGateBiasTensor,
            kFwInputActivationStateTensor,
            kFwInputCellStateTensor,
            kBwInputActivationStateTensor,
            kBwInputCellStateTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kMergeOutputsParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
                << "required input " << requiredInput << " is omitted";
    }

    // Check that the scalar operands' buffers are large enough.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
    const auto& cellOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projOperand = *GetInput(operation, operands, kProjClipParam);
    // Clip scalars use the same float type as the input tensor, so the
    // required buffer size differs between fp32 and fp16 models.
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(cellOperand.length >= sizeof(float));
        NN_RET_CHECK(projOperand.length >= sizeof(float));
    } else {
        NN_RET_CHECK(cellOperand.length >= sizeof(_Float16));
        NN_RET_CHECK(projOperand.length >= sizeof(_Float16));
    }
    const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam);
    NN_RET_CHECK(mergeOutputsOperand.length >= sizeof(bool));
    const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam);
    NN_RET_CHECK(timeMajorOperand.length >= sizeof(bool));

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors. Input is rank-3: [time, batch, feature] when time-major,
    // [batch, time, feature] otherwise.
    NN_CHECK(NumDimensions(input_) == 3);
    const uint32_t max_time = SizeOfDimension(input_, params_.time_major ? 0 : 1);
    const uint32_t n_batch = SizeOfDimension(input_, params_.time_major ? 1 : 0);
    const uint32_t n_fw_input = SizeOfDimension(input_, 2);

    // Weight matrices are [n_cell, n_input] / [n_cell, n_output].
    const uint32_t n_fw_cell = SizeOfDimension(fw_input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(fw_input_to_output_weights_), 2u);
    NN_CHECK_EQ(SizeOfDimension(fw_input_to_output_weights_, 1), n_fw_input);

    NN_CHECK_EQ(NumDimensions(fw_recurrent_to_output_weights_), 2u);
    NN_CHECK_EQ(SizeOfDimension(fw_recurrent_to_output_weights_, 0), n_fw_cell);
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);

    const uint32_t n_bw_cell = SizeOfDimension(bw_input_to_output_weights_, 0);

    NN_CHECK_EQ(NumDimensions(bw_recurrent_to_output_weights_), 2u);
    NN_CHECK_EQ(SizeOfDimension(bw_recurrent_to_output_weights_, 0), n_bw_cell);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);

    // Check that input tensor dimensions matches with each other.
    // This also derives params_.use_cifg (no input gate) from the forward
    // weights, which is relied on below.
    if (!LSTMCell::CheckInputTensorDimensions(
                input_, fw_input_to_input_weights_, fw_input_to_forget_weights_,
                fw_input_to_cell_weights_, fw_input_to_output_weights_,
                fw_recurrent_to_input_weights_, fw_recurrent_to_forget_weights_,
                fw_recurrent_to_cell_weights_, fw_recurrent_to_output_weights_,
                fw_cell_to_input_weights_, fw_cell_to_forget_weights_, fw_cell_to_output_weights_,
                fw_input_gate_bias_, fw_forget_gate_bias_, fw_cell_bias_, fw_output_gate_bias_,
                fw_projection_weights_, fw_projection_bias_, fw_input_layer_norm_weights_,
                fw_forget_layer_norm_weights_, fw_cell_layer_norm_weights_,
                fw_output_layer_norm_weights_, n_fw_input, n_fw_output, n_fw_cell, &params_)) {
        return false;
    }

    // CIFG couples the input gate to the forget gate, so input-gate aux
    // weights must be absent.
    if (params_.use_cifg) {
        NN_RET_CHECK(IsNullInput(fw_aux_input_to_input_weights_) &&
                     IsNullInput(bw_aux_input_to_input_weights_));
    }

    // Aux input-to-gate weights must be all present or all absent per
    // direction (the input-gate one is exempt under CIFG).
    const bool aux_fw_weights_all_or_none =
            ((params_.use_cifg || !IsNullInput(fw_aux_input_to_input_weights_)) &&
             !IsNullInput(fw_aux_input_to_forget_weights_) &&
             !IsNullInput(fw_aux_input_to_cell_weights_) &&
             !IsNullInput(fw_aux_input_to_output_weights_)) ||
            (IsNullInput(fw_aux_input_to_input_weights_) &&
             IsNullInput(fw_aux_input_to_forget_weights_) &&
             IsNullInput(fw_aux_input_to_cell_weights_) &&
             IsNullInput(fw_aux_input_to_output_weights_));
    const bool aux_bw_weights_all_or_none =
            ((params_.use_cifg || !IsNullInput(bw_aux_input_to_input_weights_)) &&
             !IsNullInput(bw_aux_input_to_forget_weights_) &&
             !IsNullInput(bw_aux_input_to_cell_weights_) &&
             !IsNullInput(bw_aux_input_to_output_weights_)) ||
            (IsNullInput(bw_aux_input_to_input_weights_) &&
             IsNullInput(bw_aux_input_to_forget_weights_) &&
             IsNullInput(bw_aux_input_to_cell_weights_) &&
             IsNullInput(bw_aux_input_to_output_weights_));

    NN_RET_CHECK(aux_fw_weights_all_or_none && aux_bw_weights_all_or_none);
    const bool has_aux_input = !IsNullInput(aux_input_);
    const bool has_fw_aux_weights = !IsNullInput(fw_aux_input_to_forget_weights_);
    const bool has_bw_aux_weights = !IsNullInput(bw_aux_input_to_forget_weights_);

    // Both directions must agree on whether aux weights are used.
    NN_RET_CHECK(has_fw_aux_weights == has_bw_aux_weights);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(has_aux_input, has_fw_aux_weights, &linkingMode));

    if (has_aux_input) {
        // Check that aux_input has the same dimensions (except last) as the input.
        NN_CHECK_EQ(aux_input_->shape().dimensions[0], input_->shape().dimensions[0]);
        NN_CHECK_EQ(aux_input_->shape().dimensions[1], input_->shape().dimensions[1]);
    }

    if (has_fw_aux_weights) {
        // NOTE(review): aux weight columns are checked against the primary
        // input's feature size, not aux_input_'s last dimension — presumably
        // CROSS_LINKING requires them to match; confirm against the op spec.
        uint32_t n_aux_input = SizeOfDimension(input_, 2);

        // Check forward auxiliary input shapes
        {
            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_input_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_input_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_input_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_forget_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_forget_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_forget_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_cell_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_cell_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_cell_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_output_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_output_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_output_weights_, 1), n_aux_input);
        }

        // Check backward auxiliary input shapes
        {
            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_input_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_input_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_input_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_forget_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_forget_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_forget_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_cell_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_cell_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_cell_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_output_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_output_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_output_weights_, 1), n_aux_input);
        }
    }

    // Forward output: same layout as input; last dim holds both directions
    // when outputs are merged.
    const Shape& inputShape = input_->shape();
    fwOutputShape->type = inputShape.type;
    fwOutputShape->offset = inputShape.offset;
    fwOutputShape->scale = inputShape.scale;
    fwOutputShape->dimensions.resize(3);
    fwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
    fwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
    fwOutputShape->dimensions[2] = params_.merge_outputs ? n_fw_output + n_bw_output : n_fw_output;

    // In PARALLEL_LINKING the backward network consumes the aux input instead
    // of the primary input.
    const RunTimeOperandInfo* bw_input =
            linkingMode == LinkingMode::PARALLEL_LINKING ? aux_input_ : input_;
    const uint32_t n_bw_input = SizeOfDimension(bw_input, 2);
    // Check that input tensor dimensions matches with each other.
    if (!LSTMCell::CheckInputTensorDimensions(
                bw_input, bw_input_to_input_weights_, bw_input_to_forget_weights_,
                bw_input_to_cell_weights_, bw_input_to_output_weights_,
                bw_recurrent_to_input_weights_, bw_recurrent_to_forget_weights_,
                bw_recurrent_to_cell_weights_, bw_recurrent_to_output_weights_,
                bw_cell_to_input_weights_, bw_cell_to_forget_weights_, bw_cell_to_output_weights_,
                bw_input_gate_bias_, bw_forget_gate_bias_, bw_cell_bias_, bw_output_gate_bias_,
                bw_projection_weights_, bw_projection_bias_, bw_input_layer_norm_weights_,
                bw_forget_layer_norm_weights_, bw_cell_layer_norm_weights_,
                bw_output_layer_norm_weights_, n_bw_input, n_bw_output, n_bw_cell, &params_)) {
        return false;
    }

    // A separate backward output tensor only exists when outputs aren't merged.
    if (!params_.merge_outputs) {
        bwOutputShape->type = inputShape.type;
        bwOutputShape->offset = inputShape.offset;
        bwOutputShape->scale = inputShape.scale;
        bwOutputShape->dimensions.resize(3);
        bwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
        bwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
        bwOutputShape->dimensions[2] = n_bw_output;
    }

    // State outputs mirror the corresponding state inputs.
    if (params_.output_state) {
        *fwOutputActivationState = fw_activation_state_->shape();
        *fwOutputCellState = fw_cell_state_->shape();
        *bwOutputActivationState = bw_activation_state_->shape();
        *bwOutputCellState = bw_cell_state_->shape();
    }

    // Scratch holds one buffer per gate: CIFG omits the input gate, so only
    // 3 gate buffers are needed instead of 4.
    if (params_.use_cifg) {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 3};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 3};
    } else {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 4};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 4};
    }
    fw_scratch_shape_.type = bw_scratch_shape_.type = inputShape.type;
    fw_scratch_shape_.offset = bw_scratch_shape_.offset = inputShape.offset;
    fw_scratch_shape_.scale = bw_scratch_shape_.scale = inputShape.scale;

    return true;
}
438 
Eval()439 bool BidirectionalSequenceLSTM::Eval() {
440     const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
441     const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
442     std::vector<uint32_t> fw_output_dims = input_->shape().dimensions;
443     fw_output_dims[2] = n_fw_output;
444     std::vector<uint32_t> bw_output_dims = fw_output_dims;
445     bw_output_dims[2] = n_bw_output;
446     const uint32_t n_fw_output_elements = fw_output_dims[0] * fw_output_dims[1] * fw_output_dims[2];
447     const uint32_t n_output_elements =
448             fw_output_dims[0] * fw_output_dims[1] * (fw_output_dims[2] + bw_output_dims[2]);
449 
450     const bool has_aux_input = !IsNullInput(aux_input_);
451     const bool has_aux_weights = !IsNullInput(fw_aux_input_to_forget_weights_);
452 
453     LinkingMode linkingMode;
454     NN_RET_CHECK(getLinkingMode(has_aux_input, has_aux_weights, &linkingMode));
455 
456     switch (input_->type) {
457         case OperandType::TENSOR_FLOAT32: {
458             const float* bwInput = GetBuffer<const float>(input_);
459             Shape bwInputShape = input_->shape();
460             const float* auxInput = GetOptionalBuffer<const float>(aux_input_);
461             if (linkingMode == LinkingMode::PARALLEL_LINKING) {
462                 bwInput = GetBuffer<const float>(aux_input_);
463                 bwInputShape = aux_input_->shape();
464                 auxInput = nullptr;
465             }
466 
467             float* fw_output_activation_state_buffer = nullptr;
468             float* fw_output_cell_state_buffer = nullptr;
469             std::vector<float> fw_output_activation_state;
470             std::vector<float> fw_output_cell_state;
471             if (params_.output_state) {
472                 fw_output_activation_state_buffer = GetBuffer<float>(fw_output_activation_state_);
473                 fw_output_cell_state_buffer = GetBuffer<float>(fw_output_cell_state_);
474             } else {
475                 fw_output_activation_state.resize(
476                         getNumberOfElements(fw_activation_state_->shape()));
477                 fw_output_cell_state.resize(getNumberOfElements(fw_cell_state_->shape()));
478 
479                 fw_output_activation_state_buffer = fw_output_activation_state.data();
480                 fw_output_cell_state_buffer = fw_output_cell_state.data();
481             }
482             std::vector<float> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
483             const bool kForwardSequence = true;
484             LSTMCell::LSTMEvalFloat32(
485                     params_, GetBuffer<const float>(input_), input_->shape(),
486                     GetBuffer<const float>(fw_input_to_input_weights_),
487                     GetBuffer<const float>(fw_input_to_forget_weights_),
488                     GetBuffer<const float>(fw_input_to_cell_weights_),
489                     GetBuffer<const float>(fw_input_to_output_weights_),
490                     fw_input_to_output_weights_->shape(),
491                     GetBuffer<const float>(fw_recurrent_to_input_weights_),
492                     GetBuffer<const float>(fw_recurrent_to_forget_weights_),
493                     GetBuffer<const float>(fw_recurrent_to_cell_weights_),
494                     GetBuffer<const float>(fw_recurrent_to_output_weights_),
495                     fw_recurrent_to_output_weights_->shape(),
496                     GetBuffer<const float>(fw_cell_to_input_weights_),
497                     GetBuffer<const float>(fw_cell_to_forget_weights_),
498                     GetBuffer<const float>(fw_cell_to_output_weights_), auxInput,
499                     GetOptionalBuffer<const float>(fw_aux_input_to_input_weights_),
500                     GetOptionalBuffer<const float>(fw_aux_input_to_forget_weights_),
501                     GetOptionalBuffer<const float>(fw_aux_input_to_cell_weights_),
502                     GetOptionalBuffer<const float>(fw_aux_input_to_output_weights_),
503                     GetBuffer<const float>(fw_input_gate_bias_),
504                     GetBuffer<const float>(fw_forget_gate_bias_),
505                     GetBuffer<const float>(fw_cell_bias_),
506                     GetBuffer<const float>(fw_output_gate_bias_),
507                     GetBuffer<const float>(fw_projection_weights_),
508                     GetBuffer<const float>(fw_projection_bias_),
509                     GetBuffer<const float>(fw_activation_state_),
510                     GetBuffer<const float>(fw_cell_state_),
511                     GetOptionalBuffer<const float>(fw_input_layer_norm_weights_),
512                     GetOptionalBuffer<const float>(fw_forget_layer_norm_weights_),
513                     GetOptionalBuffer<const float>(fw_cell_layer_norm_weights_),
514                     GetOptionalBuffer<const float>(fw_output_layer_norm_weights_),
515                     fw_output_activation_state_buffer, fw_output_cell_state_buffer,
516                     GetBuffer<float>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
517                     kForwardSequence);
518 
519             float* bw_output_activation_state_buffer;
520             float* bw_output_cell_state_buffer;
521             std::vector<float> bw_output_activation_state;
522             std::vector<float> bw_output_cell_state;
523             if (params_.output_state) {
524                 bw_output_activation_state_buffer = GetBuffer<float>(bw_output_activation_state_);
525                 bw_output_cell_state_buffer = GetBuffer<float>(bw_output_cell_state_);
526             } else {
527                 bw_output_activation_state.resize(
528                         getNumberOfElements(bw_activation_state_->shape()));
529                 bw_output_cell_state.resize(getNumberOfElements(bw_cell_state_->shape()));
530 
531                 bw_output_activation_state_buffer = bw_output_activation_state.data();
532                 bw_output_cell_state_buffer = bw_output_cell_state.data();
533             }
534             std::vector<float> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
535             const bool kBackwardSequence = false;
536             LSTMCell::LSTMEvalFloat32(
537                     params_, bwInput, bwInputShape,
538                     GetBuffer<const float>(bw_input_to_input_weights_),
539                     GetBuffer<const float>(bw_input_to_forget_weights_),
540                     GetBuffer<const float>(bw_input_to_cell_weights_),
541                     GetBuffer<const float>(bw_input_to_output_weights_),
542                     bw_input_to_output_weights_->shape(),
543                     GetBuffer<const float>(bw_recurrent_to_input_weights_),
544                     GetBuffer<const float>(bw_recurrent_to_forget_weights_),
545                     GetBuffer<const float>(bw_recurrent_to_cell_weights_),
546                     GetBuffer<const float>(bw_recurrent_to_output_weights_),
547                     bw_recurrent_to_output_weights_->shape(),
548                     GetBuffer<const float>(bw_cell_to_input_weights_),
549                     GetBuffer<const float>(bw_cell_to_forget_weights_),
550                     GetBuffer<const float>(bw_cell_to_output_weights_), auxInput,
551                     GetOptionalBuffer<const float>(bw_aux_input_to_input_weights_),
552                     GetOptionalBuffer<const float>(bw_aux_input_to_forget_weights_),
553                     GetOptionalBuffer<const float>(bw_aux_input_to_cell_weights_),
554                     GetOptionalBuffer<const float>(bw_aux_input_to_output_weights_),
555                     GetBuffer<const float>(bw_input_gate_bias_),
556                     GetBuffer<const float>(bw_forget_gate_bias_),
557                     GetBuffer<const float>(bw_cell_bias_),
558                     GetBuffer<const float>(bw_output_gate_bias_),
559                     GetBuffer<const float>(bw_projection_weights_),
560                     GetBuffer<const float>(bw_projection_bias_),
561                     GetBuffer<const float>(bw_activation_state_),
562                     GetBuffer<const float>(bw_cell_state_),
563                     GetOptionalBuffer<const float>(bw_input_layer_norm_weights_),
564                     GetOptionalBuffer<const float>(bw_forget_layer_norm_weights_),
565                     GetOptionalBuffer<const float>(bw_cell_layer_norm_weights_),
566                     GetOptionalBuffer<const float>(bw_output_layer_norm_weights_),
567                     bw_output_activation_state_buffer, bw_output_cell_state_buffer,
568                     params_.merge_outputs ? GetBuffer<float>(fw_output_) + n_fw_output_elements
569                                           : GetBuffer<float>(bw_output_),
570                     bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
571             if (params_.merge_outputs) {
572                 std::vector<float> temp(n_output_elements);
573                 mergeThirdDimension(GetBuffer<float>(fw_output_), fw_output_dims,
574                                     GetBuffer<float>(fw_output_) + n_fw_output_elements,
575                                     bw_output_dims, temp.data());
576                 std::copy(temp.data(), temp.data() + n_output_elements,
577                           GetBuffer<float>(fw_output_));
578             }
579         } break;
580         case OperandType::TENSOR_FLOAT16: {
581             const _Float16* bwInput = GetBuffer<const _Float16>(input_);
582             Shape bwInputShape = input_->shape();
583             const _Float16* auxInput = GetOptionalBuffer<const _Float16>(aux_input_);
584             if (linkingMode == LinkingMode::PARALLEL_LINKING) {
585                 bwInput = GetBuffer<const _Float16>(aux_input_);
586                 bwInputShape = aux_input_->shape();
587                 auxInput = nullptr;
588             }
589 
590             _Float16* fw_output_activation_state_buffer;
591             _Float16* fw_output_cell_state_buffer;
592             std::vector<_Float16> fw_output_activation_state;
593             std::vector<_Float16> fw_output_cell_state;
594             if (params_.output_state) {
595                 fw_output_activation_state_buffer =
596                         GetBuffer<_Float16>(fw_output_activation_state_);
597                 fw_output_cell_state_buffer = GetBuffer<_Float16>(fw_output_cell_state_);
598             } else {
599                 fw_output_activation_state.resize(
600                         getNumberOfElements(fw_activation_state_->shape()));
601                 fw_output_cell_state.resize(getNumberOfElements(fw_cell_state_->shape()));
602 
603                 fw_output_activation_state_buffer = fw_output_activation_state.data();
604                 fw_output_cell_state_buffer = fw_output_cell_state.data();
605             }
606             std::vector<_Float16> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
607             const bool kForwardSequence = true;
608             LSTMCell::LSTMEvalFloat16(
609                     params_, GetBuffer<const _Float16>(input_), input_->shape(),
610                     GetOptionalBuffer<const _Float16>(fw_input_to_input_weights_),
611                     GetBuffer<const _Float16>(fw_input_to_forget_weights_),
612                     GetBuffer<const _Float16>(fw_input_to_cell_weights_),
613                     GetBuffer<const _Float16>(fw_input_to_output_weights_),
614                     fw_input_to_output_weights_->shape(),
615                     GetOptionalBuffer<const _Float16>(fw_recurrent_to_input_weights_),
616                     GetBuffer<const _Float16>(fw_recurrent_to_forget_weights_),
617                     GetBuffer<const _Float16>(fw_recurrent_to_cell_weights_),
618                     GetBuffer<const _Float16>(fw_recurrent_to_output_weights_),
619                     fw_recurrent_to_output_weights_->shape(),
620                     GetOptionalBuffer<const _Float16>(fw_cell_to_input_weights_),
621                     GetOptionalBuffer<const _Float16>(fw_cell_to_forget_weights_),
622                     GetOptionalBuffer<const _Float16>(fw_cell_to_output_weights_), auxInput,
623                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_input_weights_),
624                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_forget_weights_),
625                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_cell_weights_),
626                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_output_weights_),
627                     GetOptionalBuffer<const _Float16>(fw_input_gate_bias_),
628                     GetBuffer<const _Float16>(fw_forget_gate_bias_),
629                     GetBuffer<const _Float16>(fw_cell_bias_),
630                     GetBuffer<const _Float16>(fw_output_gate_bias_),
631                     GetOptionalBuffer<const _Float16>(fw_projection_weights_),
632                     GetOptionalBuffer<const _Float16>(fw_projection_bias_),
633                     GetBuffer<const _Float16>(fw_activation_state_),
634                     GetBuffer<const _Float16>(fw_cell_state_),
635                     GetOptionalBuffer<const _Float16>(fw_input_layer_norm_weights_),
636                     GetOptionalBuffer<const _Float16>(fw_forget_layer_norm_weights_),
637                     GetOptionalBuffer<const _Float16>(fw_cell_layer_norm_weights_),
638                     GetOptionalBuffer<const _Float16>(fw_output_layer_norm_weights_),
639                     fw_output_activation_state_buffer, fw_output_cell_state_buffer,
640                     GetBuffer<_Float16>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
641                     kForwardSequence);
642 
643             _Float16* bw_output_activation_state_buffer;
644             _Float16* bw_output_cell_state_buffer;
645             std::vector<_Float16> bw_output_activation_state;
646             std::vector<_Float16> bw_output_cell_state;
647             if (params_.output_state) {
648                 bw_output_activation_state_buffer =
649                         GetBuffer<_Float16>(bw_output_activation_state_);
650                 bw_output_cell_state_buffer = GetBuffer<_Float16>(bw_output_cell_state_);
651             } else {
652                 bw_output_activation_state.resize(
653                         getNumberOfElements(bw_activation_state_->shape()));
654                 bw_output_cell_state.resize(getNumberOfElements(bw_cell_state_->shape()));
655 
656                 bw_output_activation_state_buffer = bw_output_activation_state.data();
657                 bw_output_cell_state_buffer = bw_output_cell_state.data();
658             }
659             std::vector<_Float16> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
660             const bool kBackwardSequence = false;
661             LSTMCell::LSTMEvalFloat16(
662                     params_, bwInput, bwInputShape,
663                     GetOptionalBuffer<const _Float16>(bw_input_to_input_weights_),
664                     GetBuffer<const _Float16>(bw_input_to_forget_weights_),
665                     GetBuffer<const _Float16>(bw_input_to_cell_weights_),
666                     GetBuffer<const _Float16>(bw_input_to_output_weights_),
667                     bw_input_to_output_weights_->shape(),
668                     GetOptionalBuffer<const _Float16>(bw_recurrent_to_input_weights_),
669                     GetBuffer<const _Float16>(bw_recurrent_to_forget_weights_),
670                     GetBuffer<const _Float16>(bw_recurrent_to_cell_weights_),
671                     GetBuffer<const _Float16>(bw_recurrent_to_output_weights_),
672                     bw_recurrent_to_output_weights_->shape(),
673                     GetOptionalBuffer<const _Float16>(bw_cell_to_input_weights_),
674                     GetOptionalBuffer<const _Float16>(bw_cell_to_forget_weights_),
675                     GetOptionalBuffer<const _Float16>(bw_cell_to_output_weights_), auxInput,
676                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_input_weights_),
677                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_forget_weights_),
678                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_cell_weights_),
679                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_output_weights_),
680                     GetOptionalBuffer<const _Float16>(bw_input_gate_bias_),
681                     GetBuffer<const _Float16>(bw_forget_gate_bias_),
682                     GetBuffer<const _Float16>(bw_cell_bias_),
683                     GetBuffer<const _Float16>(bw_output_gate_bias_),
684                     GetOptionalBuffer<const _Float16>(bw_projection_weights_),
685                     GetOptionalBuffer<const _Float16>(bw_projection_bias_),
686                     GetBuffer<const _Float16>(bw_activation_state_),
687                     GetBuffer<const _Float16>(bw_cell_state_),
688                     GetOptionalBuffer<const _Float16>(bw_input_layer_norm_weights_),
689                     GetOptionalBuffer<const _Float16>(bw_forget_layer_norm_weights_),
690                     GetOptionalBuffer<const _Float16>(bw_cell_layer_norm_weights_),
691                     GetOptionalBuffer<const _Float16>(bw_output_layer_norm_weights_),
692                     bw_output_activation_state_buffer, bw_output_cell_state_buffer,
693                     params_.merge_outputs ? GetBuffer<_Float16>(fw_output_) + n_fw_output_elements
694                                           : GetBuffer<_Float16>(bw_output_),
695                     bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
696             if (params_.merge_outputs) {
697                 std::vector<_Float16> temp(n_output_elements);
698                 mergeThirdDimension(GetBuffer<_Float16>(fw_output_), fw_output_dims,
699                                     GetBuffer<_Float16>(fw_output_) + n_fw_output_elements,
700                                     bw_output_dims, temp.data());
701                 std::copy(temp.data(), temp.data() + n_output_elements,
702                           GetBuffer<_Float16>(fw_output_));
703             }
704         } break;
705         default: {
706             LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
707             return false;
708         }
709     }
710     return true;
711 }
712 
713 }  // namespace nn
714 }  // namespace android
715