1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "BidirectionalSequenceLSTM.h"
20
21 #include <tensorflow/lite/kernels/internal/tensor_utils.h>
22
23 #include <algorithm>
24 #include <vector>
25
26 #include "CpuExecutor.h"
27 #include "CpuOperationUtils.h"
28 #include "OperationsExecutionUtils.h"
29 #include "Tracing.h"
30
31 namespace android {
32 namespace nn {
33
34 namespace {
35
36 template <typename T>
GetBuffer(RunTimeOperandInfo * operand)37 inline T* GetBuffer(RunTimeOperandInfo* operand) {
38 return reinterpret_cast<T*>(operand->buffer);
39 }
40
41 template <typename T>
GetBuffer(const RunTimeOperandInfo * operand)42 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
43 return reinterpret_cast<const T*>(operand->buffer);
44 }
45
46 template <typename T>
GetOptionalBuffer(const RunTimeOperandInfo * operand)47 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
48 return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
49 }
50
// How the optional auxiliary input tensor is wired into the two networks.
enum class LinkingMode {
    NO_LINKING,        // no auxiliary input and no auxiliary weights
    PARALLEL_LINKING,  // aux input becomes the backward network's regular input
    CROSS_LINKING,     // aux input is multiplied by dedicated auxiliary weights
};
56
getLinkingMode(bool hasAuxInput,bool hasAuxWeights,LinkingMode * linkingMode)57 bool getLinkingMode(bool hasAuxInput, bool hasAuxWeights, LinkingMode* linkingMode) {
58 // Three possible configurations for three possible linking modes:
59 // 1) NO_LINKING -- no auxiliary tensors at all
60 // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular
61 // input to the backward network, so the auxiliary weights are omitted.
62 // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by
63 // auxiliary weights.
64 if (!hasAuxInput && !hasAuxWeights) {
65 *linkingMode = LinkingMode::NO_LINKING;
66 } else if (hasAuxInput && !hasAuxWeights) {
67 *linkingMode = LinkingMode::PARALLEL_LINKING;
68 } else if (hasAuxInput && hasAuxWeights) {
69 *linkingMode = LinkingMode::CROSS_LINKING;
70 } else {
71 NN_RET_CHECK_FAIL()
72 << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
73 }
74
75 return true;
76 }
77
78 } // anonymous namespace
79
// Caches pointers to every operand of the operation (input, forward/backward
// weights, biases, states, outputs) and decodes the scalar parameters into
// params_. Operands marked "optional" may be omitted by the caller; they are
// still fetched here and null-checked at use sites.
BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
                                                     RunTimeOperandInfo* operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    // Forward cell: input-to-gate weights.
    fw_input_to_input_weights_ =
            GetInput(operation, operands, kFwInputToInputWeightsTensor);  // optional
    fw_input_to_forget_weights_ = GetInput(operation, operands, kFwInputToForgetWeightsTensor);
    fw_input_to_cell_weights_ = GetInput(operation, operands, kFwInputToCellWeightsTensor);
    fw_input_to_output_weights_ = GetInput(operation, operands, kFwInputToOutputWeightsTensor);

    // Forward cell: recurrent (output-to-gate) weights.
    fw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kFwRecurrentToInputWeightsTensor);  // optional
    fw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kFwRecurrentToForgetWeightsTensor);
    fw_recurrent_to_cell_weights_ = GetInput(operation, operands, kFwRecurrentToCellWeightsTensor);
    fw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kFwRecurrentToOutputWeightsTensor);

    // Forward cell: peephole weights (all optional).
    fw_cell_to_input_weights_ =
            GetInput(operation, operands, kFwCellToInputWeightsTensor);  // optional
    fw_cell_to_forget_weights_ =
            GetInput(operation, operands, kFwCellToForgetWeightsTensor);  // optional
    fw_cell_to_output_weights_ =
            GetInput(operation, operands, kFwCellToOutputWeightsTensor);  // optional

    // Forward cell: gate biases.
    fw_input_gate_bias_ = GetInput(operation, operands, kFwInputGateBiasTensor);
    fw_forget_gate_bias_ = GetInput(operation, operands, kFwForgetGateBiasTensor);
    fw_cell_bias_ = GetInput(operation, operands, kFwCellGateBiasTensor);
    fw_output_gate_bias_ = GetInput(operation, operands, kFwOutputGateBiasTensor);

    fw_projection_weights_ = GetInput(operation, operands, kFwProjectionWeightsTensor);  // optional
    fw_projection_bias_ = GetInput(operation, operands, kFwProjectionBiasTensor);        // optional

    // Forward cell: initial activation and cell state.
    fw_activation_state_ = GetInput(operation, operands, kFwInputActivationStateTensor);
    fw_cell_state_ = GetInput(operation, operands, kFwInputCellStateTensor);

    // Backward cell: same layout as the forward cell above.
    bw_input_to_input_weights_ =
            GetInput(operation, operands, kBwInputToInputWeightsTensor);  // optional
    bw_input_to_forget_weights_ = GetInput(operation, operands, kBwInputToForgetWeightsTensor);
    bw_input_to_cell_weights_ = GetInput(operation, operands, kBwInputToCellWeightsTensor);
    bw_input_to_output_weights_ = GetInput(operation, operands, kBwInputToOutputWeightsTensor);

    bw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kBwRecurrentToInputWeightsTensor);  // optional
    bw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kBwRecurrentToForgetWeightsTensor);
    bw_recurrent_to_cell_weights_ = GetInput(operation, operands, kBwRecurrentToCellWeightsTensor);
    bw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kBwRecurrentToOutputWeightsTensor);

    bw_cell_to_input_weights_ =
            GetInput(operation, operands, kBwCellToInputWeightsTensor);  // optional
    bw_cell_to_forget_weights_ =
            GetInput(operation, operands, kBwCellToForgetWeightsTensor);  // optional
    bw_cell_to_output_weights_ =
            GetInput(operation, operands, kBwCellToOutputWeightsTensor);  // optional

    bw_input_gate_bias_ = GetInput(operation, operands, kBwInputGateBiasTensor);
    bw_forget_gate_bias_ = GetInput(operation, operands, kBwForgetGateBiasTensor);
    bw_cell_bias_ = GetInput(operation, operands, kBwCellGateBiasTensor);
    bw_output_gate_bias_ = GetInput(operation, operands, kBwOutputGateBiasTensor);

    bw_projection_weights_ = GetInput(operation, operands, kBwProjectionWeightsTensor);  // optional
    bw_projection_bias_ = GetInput(operation, operands, kBwProjectionBiasTensor);        // optional

    bw_activation_state_ = GetInput(operation, operands, kBwInputActivationStateTensor);
    bw_cell_state_ = GetInput(operation, operands, kBwInputCellStateTensor);

    // Auxiliary input and the per-direction auxiliary weights (see LinkingMode).
    aux_input_ = GetInput(operation, operands, kAuxInputTensor);
    fw_aux_input_to_input_weights_ = GetInput(operation, operands, kFwAuxInputToInputWeightsTensor);
    fw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kFwAuxInputToForgetWeightsTensor);
    fw_aux_input_to_cell_weights_ = GetInput(operation, operands, kFwAuxInputToCellWeightsTensor);
    fw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kFwAuxInputToOutputWeightsTensor);
    bw_aux_input_to_input_weights_ = GetInput(operation, operands, kBwAuxInputToInputWeightsTensor);
    bw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kBwAuxInputToForgetWeightsTensor);
    bw_aux_input_to_cell_weights_ = GetInput(operation, operands, kBwAuxInputToCellWeightsTensor);
    bw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kBwAuxInputToOutputWeightsTensor);

    // Layer-normalization weights (presence drives params_.use_layer_norm below).
    fw_input_layer_norm_weights_ = GetInput(operation, operands, kFwInputLayerNormWeightsTensor);
    fw_forget_layer_norm_weights_ = GetInput(operation, operands, kFwForgetLayerNormWeightsTensor);
    fw_cell_layer_norm_weights_ = GetInput(operation, operands, kFwCellLayerNormWeightsTensor);
    fw_output_layer_norm_weights_ = GetInput(operation, operands, kFwOutputLayerNormWeightsTensor);
    bw_input_layer_norm_weights_ = GetInput(operation, operands, kBwInputLayerNormWeightsTensor);
    bw_forget_layer_norm_weights_ = GetInput(operation, operands, kBwForgetLayerNormWeightsTensor);
    bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
    bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);

    // Scalar parameters. Clip values are read at the precision matching the
    // input tensor's type, then widened to float for params_.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    params_.activation = static_cast<ActivationFn>(getScalarDataWithDefault<int32_t>(
            activationOperand, TfLiteFusedActivation::kTfLiteActNone));
    const auto& clipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projOperand = *GetInput(operation, operands, kProjClipParam);
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarDataWithDefault<float>(clipOperand, 0.0f);
        params_.proj_clip = getScalarDataWithDefault<float>(projOperand, 0.0f);
    } else {
        params_.cell_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(clipOperand, 0.0f));
        params_.proj_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(projOperand, 0.0f));
    }
    const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam);
    params_.merge_outputs = getScalarDataWithDefault<bool>(mergeOutputsOperand, false);
    const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam);
    params_.time_major = getScalarDataWithDefault<bool>(timeMajorOperand, false);
    // Layer norm is considered enabled if the fw input layer-norm weights exist.
    params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);

    fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
    // With merged outputs there is no separate backward output tensor.
    if (!params_.merge_outputs) {
        bw_output_ = GetOutput(operation, operands, kBwOutputTensor);
    }

    // 5 or 6 outputs means the caller also wants the final activation/cell
    // states returned; output indices shift down by one when outputs merge.
    params_.output_state = (operation.outputs.size() == 5 || operation.outputs.size() == 6);
    if (params_.output_state) {
        uint32_t delta = params_.merge_outputs ? 1 : 0;
        fw_output_activation_state_ =
                GetOutput(operation, operands, kFwOutputActivationStateTensor - delta);
        fw_output_cell_state_ = GetOutput(operation, operands, kFwOutputCellStateTensor - delta);
        bw_output_activation_state_ =
                GetOutput(operation, operands, kBwOutputActivationStateTensor - delta);
        bw_output_cell_state_ = GetOutput(operation, operands, kBwOutputCellStateTensor - delta);
    }
}
207
// Validates all operand shapes/types and computes the output shapes.
// Returns false (without modifying outputs meaningfully) if any check fails.
// Also fills in fw_scratch_shape_/bw_scratch_shape_ used later by Eval().
bool BidirectionalSequenceLSTM::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
                                        Shape* fwOutputShape, Shape* bwOutputShape,
                                        Shape* fwOutputActivationState, Shape* fwOutputCellState,
                                        Shape* bwOutputActivationState, Shape* bwOutputCellState) {
    // Check we have all the inputs and outputs we need.
    constexpr int requiredInputs[] = {
            kInputTensor,
            kFwInputToForgetWeightsTensor,
            kFwInputToCellWeightsTensor,
            kFwInputToOutputWeightsTensor,
            kFwRecurrentToForgetWeightsTensor,
            kFwRecurrentToCellWeightsTensor,
            kFwRecurrentToOutputWeightsTensor,
            kFwForgetGateBiasTensor,
            kFwCellGateBiasTensor,
            kFwOutputGateBiasTensor,
            kBwInputToForgetWeightsTensor,
            kBwInputToCellWeightsTensor,
            kBwInputToOutputWeightsTensor,
            kBwRecurrentToForgetWeightsTensor,
            kBwRecurrentToCellWeightsTensor,
            kBwRecurrentToOutputWeightsTensor,
            kBwForgetGateBiasTensor,
            kBwCellGateBiasTensor,
            kBwOutputGateBiasTensor,
            kFwInputActivationStateTensor,
            kFwInputCellStateTensor,
            kBwInputActivationStateTensor,
            kBwInputCellStateTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kMergeOutputsParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
                << "required input " << requiredInput << " is omitted";
    }

    // Check that the scalar operands' buffers are large enough.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
    const auto& cellOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projOperand = *GetInput(operation, operands, kProjClipParam);
    // Clip scalars are stored at the precision matching the input tensor type.
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(cellOperand.length >= sizeof(float));
        NN_RET_CHECK(projOperand.length >= sizeof(float));
    } else {
        NN_RET_CHECK(cellOperand.length >= sizeof(_Float16));
        NN_RET_CHECK(projOperand.length >= sizeof(_Float16));
    }
    const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam);
    NN_RET_CHECK(mergeOutputsOperand.length >= sizeof(bool));
    const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam);
    NN_RET_CHECK(timeMajorOperand.length >= sizeof(bool));

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors. Input is rank 3; time/batch axis order depends on
    // time_major.
    NN_CHECK(NumDimensions(input_) == 3);
    const uint32_t max_time = SizeOfDimension(input_, params_.time_major ? 0 : 1);
    const uint32_t n_batch = SizeOfDimension(input_, params_.time_major ? 1 : 0);
    const uint32_t n_fw_input = SizeOfDimension(input_, 2);

    // Weight matrices are [n_cell, n_input] / [n_cell, n_output].
    const uint32_t n_fw_cell = SizeOfDimension(fw_input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(fw_input_to_output_weights_), 2u);
    NN_CHECK_EQ(SizeOfDimension(fw_input_to_output_weights_, 1), n_fw_input);

    NN_CHECK_EQ(NumDimensions(fw_recurrent_to_output_weights_), 2u);
    NN_CHECK_EQ(SizeOfDimension(fw_recurrent_to_output_weights_, 0), n_fw_cell);
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);

    const uint32_t n_bw_cell = SizeOfDimension(bw_input_to_output_weights_, 0);

    NN_CHECK_EQ(NumDimensions(bw_recurrent_to_output_weights_), 2u);
    NN_CHECK_EQ(SizeOfDimension(bw_recurrent_to_output_weights_, 0), n_bw_cell);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);

    // Check that input tensor dimensions matches with each other.
    // This also sets params_.use_cifg based on the fw input-gate weights.
    if (!LSTMCell::CheckInputTensorDimensions(
                input_, fw_input_to_input_weights_, fw_input_to_forget_weights_,
                fw_input_to_cell_weights_, fw_input_to_output_weights_,
                fw_recurrent_to_input_weights_, fw_recurrent_to_forget_weights_,
                fw_recurrent_to_cell_weights_, fw_recurrent_to_output_weights_,
                fw_cell_to_input_weights_, fw_cell_to_forget_weights_, fw_cell_to_output_weights_,
                fw_input_gate_bias_, fw_forget_gate_bias_, fw_cell_bias_, fw_output_gate_bias_,
                fw_projection_weights_, fw_projection_bias_, fw_input_layer_norm_weights_,
                fw_forget_layer_norm_weights_, fw_cell_layer_norm_weights_,
                fw_output_layer_norm_weights_, n_fw_input, n_fw_output, n_fw_cell, &params_)) {
        return false;
    }

    // CIFG couples input and forget gates, so aux input-gate weights must be
    // absent.
    if (params_.use_cifg) {
        NN_RET_CHECK(IsNullInput(fw_aux_input_to_input_weights_) &&
                     IsNullInput(bw_aux_input_to_input_weights_));
    }

    // Aux weights must be provided either all together or not at all (the
    // input-gate aux weights are exempt under CIFG).
    const bool aux_fw_weights_all_or_none =
            ((params_.use_cifg || !IsNullInput(fw_aux_input_to_input_weights_)) &&
             !IsNullInput(fw_aux_input_to_forget_weights_) &&
             !IsNullInput(fw_aux_input_to_cell_weights_) &&
             !IsNullInput(fw_aux_input_to_output_weights_)) ||
            (IsNullInput(fw_aux_input_to_input_weights_) &&
             IsNullInput(fw_aux_input_to_forget_weights_) &&
             IsNullInput(fw_aux_input_to_cell_weights_) &&
             IsNullInput(fw_aux_input_to_output_weights_));
    const bool aux_bw_weights_all_or_none =
            ((params_.use_cifg || !IsNullInput(bw_aux_input_to_input_weights_)) &&
             !IsNullInput(bw_aux_input_to_forget_weights_) &&
             !IsNullInput(bw_aux_input_to_cell_weights_) &&
             !IsNullInput(bw_aux_input_to_output_weights_)) ||
            (IsNullInput(bw_aux_input_to_input_weights_) &&
             IsNullInput(bw_aux_input_to_forget_weights_) &&
             IsNullInput(bw_aux_input_to_cell_weights_) &&
             IsNullInput(bw_aux_input_to_output_weights_));

    NN_RET_CHECK(aux_fw_weights_all_or_none && aux_bw_weights_all_or_none);
    const bool has_aux_input = !IsNullInput(aux_input_);
    const bool has_fw_aux_weights = !IsNullInput(fw_aux_input_to_forget_weights_);
    const bool has_bw_aux_weights = !IsNullInput(bw_aux_input_to_forget_weights_);

    // Both directions must agree on whether aux weights are present.
    NN_RET_CHECK(has_fw_aux_weights == has_bw_aux_weights);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(has_aux_input, has_fw_aux_weights, &linkingMode));

    if (has_aux_input) {
        // Check that aux_input has the same dimensions (except last) as the input.
        NN_CHECK_EQ(aux_input_->shape().dimensions[0], input_->shape().dimensions[0]);
        NN_CHECK_EQ(aux_input_->shape().dimensions[1], input_->shape().dimensions[1]);
    }

    if (has_fw_aux_weights) {
        // NOTE(review): aux input depth is taken from input_'s last dimension,
        // which assumes the aux input depth equals the regular input depth --
        // verify against the op specification.
        uint32_t n_aux_input = SizeOfDimension(input_, 2);

        // Check forward auxiliary input shapes
        {
            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_input_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_input_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_input_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_forget_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_forget_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_forget_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_cell_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_cell_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_cell_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_output_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_output_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_output_weights_, 1), n_aux_input);
        }

        // Check backward auxiliary input shapes
        {
            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_input_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_input_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_input_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_forget_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_forget_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_forget_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_cell_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_cell_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_cell_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_output_weights_), 2u);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_output_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_output_weights_, 1), n_aux_input);
        }
    }

    // Forward output: same type/quantization as input; last dim holds both
    // directions' outputs when merge_outputs is set.
    const Shape& inputShape = input_->shape();
    fwOutputShape->type = inputShape.type;
    fwOutputShape->offset = inputShape.offset;
    fwOutputShape->scale = inputShape.scale;
    fwOutputShape->dimensions.resize(3);
    fwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
    fwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
    fwOutputShape->dimensions[2] = params_.merge_outputs ? n_fw_output + n_bw_output : n_fw_output;

    // In PARALLEL_LINKING mode the backward network consumes aux_input_ as its
    // regular input, so validate its dimensions against that tensor.
    const RunTimeOperandInfo* bw_input =
            linkingMode == LinkingMode::PARALLEL_LINKING ? aux_input_ : input_;
    const uint32_t n_bw_input = SizeOfDimension(bw_input, 2);
    // Check that input tensor dimensions matches with each other.
    if (!LSTMCell::CheckInputTensorDimensions(
                bw_input, bw_input_to_input_weights_, bw_input_to_forget_weights_,
                bw_input_to_cell_weights_, bw_input_to_output_weights_,
                bw_recurrent_to_input_weights_, bw_recurrent_to_forget_weights_,
                bw_recurrent_to_cell_weights_, bw_recurrent_to_output_weights_,
                bw_cell_to_input_weights_, bw_cell_to_forget_weights_, bw_cell_to_output_weights_,
                bw_input_gate_bias_, bw_forget_gate_bias_, bw_cell_bias_, bw_output_gate_bias_,
                bw_projection_weights_, bw_projection_bias_, bw_input_layer_norm_weights_,
                bw_forget_layer_norm_weights_, bw_cell_layer_norm_weights_,
                bw_output_layer_norm_weights_, n_bw_input, n_bw_output, n_bw_cell, &params_)) {
        return false;
    }

    if (!params_.merge_outputs) {
        bwOutputShape->type = inputShape.type;
        bwOutputShape->offset = inputShape.offset;
        bwOutputShape->scale = inputShape.scale;
        bwOutputShape->dimensions.resize(3);
        bwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
        bwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
        bwOutputShape->dimensions[2] = n_bw_output;
    }

    // Output states mirror the shapes of the corresponding input states.
    if (params_.output_state) {
        *fwOutputActivationState = fw_activation_state_->shape();
        *fwOutputCellState = fw_cell_state_->shape();
        *bwOutputActivationState = bw_activation_state_->shape();
        *bwOutputCellState = bw_cell_state_->shape();
    }

    // Scratch holds per-gate buffers: 3 gates under CIFG, otherwise 4.
    if (params_.use_cifg) {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 3};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 3};
    } else {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 4};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 4};
    }
    fw_scratch_shape_.type = bw_scratch_shape_.type = inputShape.type;
    fw_scratch_shape_.offset = bw_scratch_shape_.offset = inputShape.offset;
    fw_scratch_shape_.scale = bw_scratch_shape_.scale = inputShape.scale;

    return true;
}
438
// Runs the forward LSTM over the input, then the backward LSTM (over input_ or
// aux_input_ depending on the linking mode), and, if merge_outputs is set,
// interleaves both results along the last dimension into fw_output_.
// Dispatches on the input tensor's element type (FLOAT32 / FLOAT16).
bool BidirectionalSequenceLSTM::Eval() {
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
    std::vector<uint32_t> fw_output_dims = input_->shape().dimensions;
    fw_output_dims[2] = n_fw_output;
    std::vector<uint32_t> bw_output_dims = fw_output_dims;
    bw_output_dims[2] = n_bw_output;
    const uint32_t n_fw_output_elements = fw_output_dims[0] * fw_output_dims[1] * fw_output_dims[2];
    const uint32_t n_output_elements =
            fw_output_dims[0] * fw_output_dims[1] * (fw_output_dims[2] + bw_output_dims[2]);

    const bool has_aux_input = !IsNullInput(aux_input_);
    const bool has_aux_weights = !IsNullInput(fw_aux_input_to_forget_weights_);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(has_aux_input, has_aux_weights, &linkingMode));

    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            // In PARALLEL_LINKING mode the aux input *is* the backward
            // network's regular input, and no aux path is used.
            const float* bwInput = GetBuffer<const float>(input_);
            Shape bwInputShape = input_->shape();
            const float* auxInput = GetOptionalBuffer<const float>(aux_input_);
            if (linkingMode == LinkingMode::PARALLEL_LINKING) {
                bwInput = GetBuffer<const float>(aux_input_);
                bwInputShape = aux_input_->shape();
                auxInput = nullptr;
            }

            // Output states: write directly into the caller's tensors when
            // requested, otherwise into throwaway local buffers.
            float* fw_output_activation_state_buffer = nullptr;
            float* fw_output_cell_state_buffer = nullptr;
            std::vector<float> fw_output_activation_state;
            std::vector<float> fw_output_cell_state;
            if (params_.output_state) {
                fw_output_activation_state_buffer = GetBuffer<float>(fw_output_activation_state_);
                fw_output_cell_state_buffer = GetBuffer<float>(fw_output_cell_state_);
            } else {
                fw_output_activation_state.resize(
                        getNumberOfElements(fw_activation_state_->shape()));
                fw_output_cell_state.resize(getNumberOfElements(fw_cell_state_->shape()));

                fw_output_activation_state_buffer = fw_output_activation_state.data();
                fw_output_cell_state_buffer = fw_output_cell_state.data();
            }
            std::vector<float> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            // NOTE(review): optional tensors (input-to-input, cell-to-*,
            // projection, input gate bias) use GetBuffer here while the
            // FLOAT16 branch uses GetOptionalBuffer; this relies on omitted
            // operands having a null buffer -- verify.
            LSTMCell::LSTMEvalFloat32(
                    params_, GetBuffer<const float>(input_), input_->shape(),
                    GetBuffer<const float>(fw_input_to_input_weights_),
                    GetBuffer<const float>(fw_input_to_forget_weights_),
                    GetBuffer<const float>(fw_input_to_cell_weights_),
                    GetBuffer<const float>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_recurrent_to_input_weights_),
                    GetBuffer<const float>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_cell_to_input_weights_),
                    GetBuffer<const float>(fw_cell_to_forget_weights_),
                    GetBuffer<const float>(fw_cell_to_output_weights_), auxInput,
                    GetOptionalBuffer<const float>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_output_weights_),
                    GetBuffer<const float>(fw_input_gate_bias_),
                    GetBuffer<const float>(fw_forget_gate_bias_),
                    GetBuffer<const float>(fw_cell_bias_),
                    GetBuffer<const float>(fw_output_gate_bias_),
                    GetBuffer<const float>(fw_projection_weights_),
                    GetBuffer<const float>(fw_projection_bias_),
                    GetBuffer<const float>(fw_activation_state_),
                    GetBuffer<const float>(fw_cell_state_),
                    GetOptionalBuffer<const float>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_output_layer_norm_weights_),
                    fw_output_activation_state_buffer, fw_output_cell_state_buffer,
                    GetBuffer<float>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);

            float* bw_output_activation_state_buffer;
            float* bw_output_cell_state_buffer;
            std::vector<float> bw_output_activation_state;
            std::vector<float> bw_output_cell_state;
            if (params_.output_state) {
                bw_output_activation_state_buffer = GetBuffer<float>(bw_output_activation_state_);
                bw_output_cell_state_buffer = GetBuffer<float>(bw_output_cell_state_);
            } else {
                bw_output_activation_state.resize(
                        getNumberOfElements(bw_activation_state_->shape()));
                bw_output_cell_state.resize(getNumberOfElements(bw_cell_state_->shape()));

                bw_output_activation_state_buffer = bw_output_activation_state.data();
                bw_output_cell_state_buffer = bw_output_cell_state.data();
            }
            std::vector<float> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat32(
                    params_, bwInput, bwInputShape,
                    GetBuffer<const float>(bw_input_to_input_weights_),
                    GetBuffer<const float>(bw_input_to_forget_weights_),
                    GetBuffer<const float>(bw_input_to_cell_weights_),
                    GetBuffer<const float>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_recurrent_to_input_weights_),
                    GetBuffer<const float>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_cell_to_input_weights_),
                    GetBuffer<const float>(bw_cell_to_forget_weights_),
                    GetBuffer<const float>(bw_cell_to_output_weights_), auxInput,
                    GetOptionalBuffer<const float>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_output_weights_),
                    GetBuffer<const float>(bw_input_gate_bias_),
                    GetBuffer<const float>(bw_forget_gate_bias_),
                    GetBuffer<const float>(bw_cell_bias_),
                    GetBuffer<const float>(bw_output_gate_bias_),
                    GetBuffer<const float>(bw_projection_weights_),
                    GetBuffer<const float>(bw_projection_bias_),
                    GetBuffer<const float>(bw_activation_state_),
                    GetBuffer<const float>(bw_cell_state_),
                    GetOptionalBuffer<const float>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_output_layer_norm_weights_),
                    bw_output_activation_state_buffer, bw_output_cell_state_buffer,
                    // When merging, the backward result is staged into the tail
                    // of the fw_output_ buffer (sized for both directions).
                    params_.merge_outputs ? GetBuffer<float>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<float>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                // Interleave fw and bw results along the last dimension, then
                // copy back over fw_output_ (in-place merge is not possible).
                std::vector<float> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<float>(fw_output_), fw_output_dims,
                                    GetBuffer<float>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<float>(fw_output_));
            }
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Mirrors the FLOAT32 branch above, with _Float16 buffers and
            // GetOptionalBuffer for all optional tensors.
            const _Float16* bwInput = GetBuffer<const _Float16>(input_);
            Shape bwInputShape = input_->shape();
            const _Float16* auxInput = GetOptionalBuffer<const _Float16>(aux_input_);
            if (linkingMode == LinkingMode::PARALLEL_LINKING) {
                bwInput = GetBuffer<const _Float16>(aux_input_);
                bwInputShape = aux_input_->shape();
                auxInput = nullptr;
            }

            _Float16* fw_output_activation_state_buffer;
            _Float16* fw_output_cell_state_buffer;
            std::vector<_Float16> fw_output_activation_state;
            std::vector<_Float16> fw_output_cell_state;
            if (params_.output_state) {
                fw_output_activation_state_buffer =
                        GetBuffer<_Float16>(fw_output_activation_state_);
                fw_output_cell_state_buffer = GetBuffer<_Float16>(fw_output_cell_state_);
            } else {
                fw_output_activation_state.resize(
                        getNumberOfElements(fw_activation_state_->shape()));
                fw_output_cell_state.resize(getNumberOfElements(fw_cell_state_->shape()));

                fw_output_activation_state_buffer = fw_output_activation_state.data();
                fw_output_cell_state_buffer = fw_output_cell_state.data();
            }
            std::vector<_Float16> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            LSTMCell::LSTMEvalFloat16(
                    params_, GetBuffer<const _Float16>(input_), input_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_input_to_input_weights_),
                    GetBuffer<const _Float16>(fw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_output_weights_), auxInput,
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(fw_input_gate_bias_),
                    GetBuffer<const _Float16>(fw_forget_gate_bias_),
                    GetBuffer<const _Float16>(fw_cell_bias_),
                    GetBuffer<const _Float16>(fw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(fw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(fw_projection_bias_),
                    GetBuffer<const _Float16>(fw_activation_state_),
                    GetBuffer<const _Float16>(fw_cell_state_),
                    GetOptionalBuffer<const _Float16>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_output_layer_norm_weights_),
                    fw_output_activation_state_buffer, fw_output_cell_state_buffer,
                    GetBuffer<_Float16>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);

            _Float16* bw_output_activation_state_buffer;
            _Float16* bw_output_cell_state_buffer;
            std::vector<_Float16> bw_output_activation_state;
            std::vector<_Float16> bw_output_cell_state;
            if (params_.output_state) {
                bw_output_activation_state_buffer =
                        GetBuffer<_Float16>(bw_output_activation_state_);
                bw_output_cell_state_buffer = GetBuffer<_Float16>(bw_output_cell_state_);
            } else {
                bw_output_activation_state.resize(
                        getNumberOfElements(bw_activation_state_->shape()));
                bw_output_cell_state.resize(getNumberOfElements(bw_cell_state_->shape()));

                bw_output_activation_state_buffer = bw_output_activation_state.data();
                bw_output_cell_state_buffer = bw_output_cell_state.data();
            }
            std::vector<_Float16> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat16(
                    params_, bwInput, bwInputShape,
                    GetOptionalBuffer<const _Float16>(bw_input_to_input_weights_),
                    GetBuffer<const _Float16>(bw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_output_weights_), auxInput,
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(bw_input_gate_bias_),
                    GetBuffer<const _Float16>(bw_forget_gate_bias_),
                    GetBuffer<const _Float16>(bw_cell_bias_),
                    GetBuffer<const _Float16>(bw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(bw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(bw_projection_bias_),
                    GetBuffer<const _Float16>(bw_activation_state_),
                    GetBuffer<const _Float16>(bw_cell_state_),
                    GetOptionalBuffer<const _Float16>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_output_layer_norm_weights_),
                    bw_output_activation_state_buffer, bw_output_cell_state_buffer,
                    params_.merge_outputs ? GetBuffer<_Float16>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<_Float16>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                std::vector<_Float16> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<_Float16>(fw_output_), fw_output_dims,
                                    GetBuffer<_Float16>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<_Float16>(fw_output_));
            }
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}
712
713 } // namespace nn
714 } // namespace android
715