/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "LSTM.h"

#include <tensorflow/lite/kernels/internal/tensor_utils.h>

#include <vector>

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "LegacyUtils.h"
#include "OperationsExecutionUtils.h"
#include "Tracing.h"
#include "nnapi/Types.h"

namespace android {
namespace nn {

namespace {

template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

template <typename T>
inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
    return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
}

}  // anonymous namespace

LSTMCell::LSTMCell(const Operation& operation, RunTimeOperandInfo* operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    input_to_input_weights_ =
            GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
    input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrent_to_input_weights_ =
            GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
    recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);  // optional
    cell_to_forget_weights_ =
            GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
    cell_to_output_weights_ =
            GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional

    input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
    output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
    projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional

    output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
    cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);

    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    params_.activation = static_cast<ActivationFn>(getScalarDataWithDefault<int32_t>(
            activationOperand, TfLiteFusedActivation::kTfLiteActNone));

    const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarDataWithDefault<float>(cellClipOperand, 0.0f);
        params_.proj_clip = getScalarDataWithDefault<float>(projClipOperand, 0.0f);
    } else {
        params_.cell_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(cellClipOperand, 0.0f));
        params_.proj_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(projClipOperand, 0.0f));
    }

    // We determine the LSTM version from the number of inputs to the op: an
    // LSTM from HAL version 1.0 has 23 inputs, while one from version 1.2 has 27.
    if (operation.inputs.size() == 27) {
        input_layer_norm_weights_ =
                GetInput(operation, operands, kInputLayerNormWeightsTensor);  // optional
        forget_layer_norm_weights_ =
                GetInput(operation, operands, kForgetLayerNormWeightsTensor);  // optional
        cell_layer_norm_weights_ =
                GetInput(operation, operands, kCellLayerNormWeightsTensor);  // optional
        output_layer_norm_weights_ =
                GetInput(operation, operands, kOutputLayerNormWeightsTensor);  // optional
    } else {
        // For LSTM from HAL v1.0, assign operands with no values.
        static RunTimeOperandInfo no_value;
        no_value.lifetime = Operand::LifeTime::NO_VALUE;

        input_layer_norm_weights_ = &no_value;
        forget_layer_norm_weights_ = &no_value;
        cell_layer_norm_weights_ = &no_value;
        output_layer_norm_weights_ = &no_value;
    }

    output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
    cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);

    scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}

// static
bool LSTMCell::CheckInputTensorDimensions(
        const RunTimeOperandInfo* /*input_*/, const RunTimeOperandInfo* input_to_input_weights,
        const RunTimeOperandInfo* input_to_forget_weights,
        const RunTimeOperandInfo* input_to_cell_weights,
        const RunTimeOperandInfo* /*input_to_output_weights*/,
        const RunTimeOperandInfo* recurrent_to_input_weights,
        const RunTimeOperandInfo* recurrent_to_forget_weights,
        const RunTimeOperandInfo* recurrent_to_cell_weights,
        const RunTimeOperandInfo* /*recurrent_to_output_weights*/,
        const RunTimeOperandInfo* cell_to_input_weights,
        const RunTimeOperandInfo* cell_to_forget_weights,
        const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
        const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
        const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
        const RunTimeOperandInfo* projection_bias,
        const RunTimeOperandInfo* input_layer_norm_weights,
        const RunTimeOperandInfo* forget_layer_norm_weights,
        const RunTimeOperandInfo* cell_layer_norm_weights,
        const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
        uint32_t n_cell, LSTMParams* params) {
    // Making sure clipping parameters have valid values.
    // == 0 means no clipping
    //  > 0 means clipping
    NN_CHECK(params->cell_clip >= 0);
    NN_CHECK(params->proj_clip >= 0);

    if (!IsNullInput(input_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2u);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
    }

    NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2u);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);

    NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2u);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);

    if (!IsNullInput(recurrent_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2u);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
    }

    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2u);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);

    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2u);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifg_weights_all_or_none =
            (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
            (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
    NN_CHECK(cifg_weights_all_or_none);

    if (!IsNullInput(cell_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_forget_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_output_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
    }

    // Making sure the peephole weights are there all or none.
    params->use_cifg = IsNullInput(input_to_input_weights);
    const bool peephole_weights_all_or_none =
            ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
             !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
            (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
             IsNullInput(cell_to_output_weights));
    NN_CHECK(peephole_weights_all_or_none);

    // Since we have already checked that the weights are either all present or
    // all absent, checking the existence of only one of them suffices.
    params->use_peephole = !IsNullInput(cell_to_output_weights);
    // Checking the output layer norm weights instead of the input ones, because
    // the input layer norm weights can be omitted when CIFG LSTM is used.
    params->use_layer_norm = !IsNullInput(output_layer_norm_weights);

    params->use_projection_weight = (projection_weights->lifetime != Operand::LifeTime::NO_VALUE);
    params->use_projection_bias = (projection_bias->lifetime != Operand::LifeTime::NO_VALUE);

    // Make sure the input gate bias is present only when not a CIFG-LSTM.
    if (params->use_cifg) {
        NN_CHECK(IsNullInput(input_gate_bias));
    } else {
        NN_CHECK_EQ(NumDimensions(input_gate_bias), 1u);
        NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
    }

    NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1u);
    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(cell_bias), 1u);
    NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(output_gate_bias), 1u);
    NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);

    if (!IsNullInput(projection_weights)) {
        NN_CHECK_EQ(NumDimensions(projection_weights), 2u);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
    }

    if (!IsNullInput(projection_bias)) {
        NN_CHECK_EQ(NumDimensions(projection_bias), 1u);
        NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
    }

    // Making sure the projection tensors are consistent:
    // 1) If projection weight is not present, then projection bias should not be
    // present.
    // 2) If projection weight is present, then projection bias is optional.
    // TODO: make sure this is correct.
    const bool projection_tensors_consistent =
            (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
    NN_CHECK(projection_tensors_consistent);

    if (!IsNullInput(input_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(forget_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(cell_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(output_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
    }

    if (params->use_cifg) {
        NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
                << "input_layer_norm_weights are provided while CIFG is used";
        const bool layer_norm_weights_all_or_none_cifg =
                (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
                 IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
    } else {
        const bool layer_norm_weights_all_or_none =
                (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
                 IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(input_layer_norm_weights) &&
                 !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none);
    }

    return true;
}

bool LSTMCell::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
                       Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
                       Shape* outputShape) {
    // Check we have all the inputs and outputs we need.
    NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
             NumInputsWithValues(operation, operands) <= 27);
    constexpr int requiredInputs[] = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
                << "required input " << requiredInput << " is omitted";
    }
    NN_CHECK_EQ(NumOutputs(operation), 4);

    // Check that the scalar operands' buffers are large enough.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
    const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(cellClipOperand.length >= sizeof(float));
        NN_RET_CHECK(projClipOperand.length >= sizeof(float));
    } else {
        NN_RET_CHECK(cellClipOperand.length >= sizeof(_Float16));
        NN_RET_CHECK(projClipOperand.length >= sizeof(_Float16));
    }

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors.
    NN_CHECK(NumDimensions(input_) > 1);
    const uint32_t n_batch = SizeOfDimension(input_, 0);
    const uint32_t n_input = SizeOfDimension(input_, 1);

    const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2u);
    NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);

    NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2u);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
    const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);

    // Check that the input tensor dimensions match each other.
    if (!CheckInputTensorDimensions(
                input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
                input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
                recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
                cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
                forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
                projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
                cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
                &params_)) {
        return false;
    }

    // Resize the output and output_state tensors.
    const Shape& inputShape = input_->shape();

    outputShape->type = inputShape.type;
    outputShape->dimensions = {n_batch, n_output};
    outputShape->offset = inputShape.offset;
    outputShape->scale = inputShape.scale;

    outputStateShape->type = inputShape.type;
    outputStateShape->dimensions = {n_batch, n_output};
    outputStateShape->offset = inputShape.offset;
    outputStateShape->scale = inputShape.scale;

    cellStateShape->type = inputShape.type;
    cellStateShape->dimensions = {n_batch, n_cell};
    cellStateShape->offset = inputShape.offset;
    cellStateShape->scale = inputShape.scale;

    if (params_.use_cifg) {
        // Reserving space for Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 3};
    } else {
        // Reserving space for Input, Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 4};
    }
    scratchShape->type = inputShape.type;
    scratchShape->offset = inputShape.offset;
    scratchShape->scale = inputShape.scale;

    return true;
}

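// A rough usage sketch, for illustration only; the exact call sites live in the
// CPU executor rather than here. The cell is constructed from the operation and
// its operands, Prepare() infers the scratch/state/output shapes, and evaluation
// then runs one of the static LSTMEval* helpers below on the operand buffers:
//
//   LSTMCell cell(operation, operands);
//   Shape scratchShape, outputStateShape, cellStateShape, outputShape;
//   if (!cell.Prepare(operation, operands, &scratchShape, &outputStateShape,
//                     &cellStateShape, &outputShape)) {
//       return false;
//   }
//   // ...resize the four outputs to the shapes above, then call
//   // LSTMEvalFloat32 (or LSTMEvalFloat16) with the corresponding buffers.
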
// static
bool LSTMCell::LSTMEvalFloat32(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
        bool timeMajor, bool forwardSequence) {
    NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");

    const uint32_t inputRank = getNumberOfDimensions(input_shape);
    NN_CHECK(inputRank == 2 || inputRank == 3);

    const uint32_t maxTime =
            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
                                                : getSizeOfDimension(input_shape, 0);
    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);

    Shape batchInputShape = input_shape;
    batchInputShape.dimensions = {batchSize, inputSize};
    const uint32_t batchInputSize = batchSize * inputSize;
    const uint32_t batchOutputSize = batchSize * outputSize;

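    // Layout note: a rank-3 input is either time-major, [maxTime, batchSize, inputSize],
    // or batch-major, [batchSize, maxTime, inputSize]. Batch-major inputs (and the
    // auxiliary input, if present) are transposed to time-major below so that each
    // step can consume a contiguous [batchSize, inputSize] slab; the output is
    // transposed back at the end.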
    std::vector<float> transposedInput;
    const bool hasAuxInput = (aux_input_buffer != nullptr);
    std::vector<float> transposedAuxInput;
    std::vector<float> transposedOutput;
    Shape transposedInputShape;
    Shape transposedOutputShape;
    if (!timeMajor) {
        transposedInput.resize(maxTime * batchInputSize);
        transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
        if (hasAuxInput) {
            transposedAuxInput.resize(maxTime * batchInputSize);
            transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
                                               transposedAuxInput.data());
        }
        transposeFirstTwoDimensions(input_shape, &transposedInputShape);
        transposedOutput.resize(maxTime * batchOutputSize);
        transposedOutputShape = transposedInputShape;
        transposedOutputShape.dimensions[2] = outputSize;
    }
    const float* inputData = timeMajor ? input_buffer : transposedInput.data();
    const float* auxInputData =
            hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
    float* outputData = timeMajor ? output_buffer : transposedOutput.data();

    std::vector<float> outputStateInCurrentTimeStep(
            output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
    std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
                                                  cell_state_in_buffer + batchSize * numCells);
    const float* inputCurrentTimeStep =
            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    const float* auxInputCurrentTimeStep =
            hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
                        : nullptr;
    float* outputCurrentTimeStep =
            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
    const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);

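    // The sequence is processed one time step at a time: for a forward pass the
    // cursors start at t = 0 and advance by one batch-sized slab per step; for a
    // backward pass (forwardSequence == false) they start at the last step and
    // walk towards t = 0. After every step the freshly written output/cell state
    // becomes the input state of the next step.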
    for (uint32_t t = 0; t < maxTime; ++t) {
        LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
                 input_to_forget_weights_buffer, input_to_cell_weights_buffer,
                 input_to_output_weights_buffer, input_to_output_weights_shape,
                 recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
                 recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
                 recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
                 cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
                 auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
                 aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
                 aux_input_to_output_weights_buffer, input_gate_bias_buffer,
                 forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
                 projection_weights_buffer, projection_bias_buffer,
                 outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
                 input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
                 cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
                 output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
                 scratch_buffer_buffer);
        inputCurrentTimeStep += batchInputDelta;
        if (hasAuxInput) {
            auxInputCurrentTimeStep += batchInputDelta;
        }
        outputCurrentTimeStep += batchOutputDelta;
        outputStateInCurrentTimeStep.assign(output_state_out_buffer,
                                            output_state_out_buffer + batchSize * outputSize);
        cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
                                          cell_state_out_buffer + batchSize * numCells);
    }

    if (!timeMajor) {
        transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
                                           output_buffer);
    }

    return true;
}

// static
bool LSTMCell::LSTMEvalFloat16(
        const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
        const _Float16* input_to_input_weights_buffer,
        const _Float16* input_to_forget_weights_buffer,
        const _Float16* input_to_cell_weights_buffer,
        const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
        const _Float16* recurrent_to_input_weights_buffer,
        const _Float16* recurrent_to_forget_weights_buffer,
        const _Float16* recurrent_to_cell_weights_buffer,
        const _Float16* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape,
        const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
        const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
        const _Float16* aux_input_to_input_weights_buffer,
        const _Float16* aux_input_to_forget_weights_buffer,
        const _Float16* aux_input_to_cell_weights_buffer,
        const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
        const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
        const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
        const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
        const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
        const _Float16* forget_layer_norm_weights_buffer,
        const _Float16* cell_layer_norm_weights_buffer,
        const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
        _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
        bool timeMajor, bool forwardSequence) {
    NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");

    const uint32_t inputRank = getNumberOfDimensions(input_shape);
    NN_CHECK(inputRank == 2 || inputRank == 3);

    const uint32_t maxTime =
            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
                                                : getSizeOfDimension(input_shape, 0);
    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);

    Shape batchInputShape = input_shape;
    batchInputShape.dimensions = {batchSize, inputSize};
    const uint32_t batchInputSize = batchSize * inputSize;
    const uint32_t batchOutputSize = batchSize * outputSize;

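    // There is no half-precision kernel: every fp16 tensor is converted to fp32,
    // the computation below runs entirely in fp32 with the same per-step logic as
    // LSTMEvalFloat32, and the state/output/scratch results are converted back to
    // fp16 at the end. Optional tensors are only converted when a buffer is present.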
    std::vector<float> input_float32(maxTime * batchInputSize);
    convertFloat16ToFloat32(input_buffer, &input_float32);
    std::vector<float> input_to_input_weights_float32(numCells * inputSize);
    if (input_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32);
    }
    std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
    std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
    std::vector<float> input_to_output_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);

    std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
    if (recurrent_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(recurrent_to_input_weights_buffer,
                                &recurrent_to_input_weights_float32);
    }
    std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
                            &recurrent_to_forget_weights_float32);
    std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
    std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
                            &recurrent_to_output_weights_float32);

    std::vector<float> cell_to_input_weights_float32(numCells);
    if (cell_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
    }
    std::vector<float> cell_to_forget_weights_float32(numCells);
    if (cell_to_forget_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
    }
    std::vector<float> cell_to_output_weights_float32(numCells);
    if (cell_to_output_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
    }

    std::vector<float> aux_input_float32(maxTime * batchInputSize);
    if (aux_input_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
    }
    std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
    if (aux_input_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
                                &aux_input_to_input_weights_float32);
    }
    std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
    if (aux_input_to_forget_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
                                &aux_input_to_forget_weights_float32);
    }
    std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
    if (aux_input_to_cell_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
                                &aux_input_to_cell_weights_float32);
    }
    std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
    if (aux_input_to_output_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
                                &aux_input_to_output_weights_float32);
    }

    std::vector<float> input_gate_bias_float32(numCells);
    if (input_gate_bias_buffer != nullptr) {
        convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
    }
    std::vector<float> forget_gate_bias_float32(numCells);
    convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
    std::vector<float> cell_bias_float32(numCells);
    convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
    std::vector<float> output_gate_bias_float32(numCells);
    convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);

    std::vector<float> projection_weights_float32(numCells * outputSize);
    if (projection_weights_buffer != nullptr) {
        convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32);
    }
    std::vector<float> projection_bias_float32(outputSize);
    if (projection_bias_buffer != nullptr) {
        convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32);
    }

    std::vector<float> input_layer_norm_weights_float32(numCells);
    if (input_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(input_layer_norm_weights_buffer, &input_layer_norm_weights_float32);
    }
    std::vector<float> forget_layer_norm_weights_float32(numCells);
    if (forget_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(forget_layer_norm_weights_buffer,
                                &forget_layer_norm_weights_float32);
    }
    std::vector<float> cell_layer_norm_weights_float32(numCells);
    if (cell_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
    }
    std::vector<float> output_layer_norm_weights_float32(numCells);
    if (output_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(output_layer_norm_weights_buffer,
                                &output_layer_norm_weights_float32);
    }

    std::vector<float> output_state_out_float32(batchOutputSize);
    convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
    std::vector<float> cell_state_out_float32(batchSize * numCells);
    convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);

    std::vector<float> output_float32(maxTime * batchOutputSize);
    convertFloat16ToFloat32(output_buffer, &output_float32);
    std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
                                                              : 4 * batchSize * numCells);
    convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);

    std::vector<float> transposedInput;
    const bool hasAuxInput = (aux_input_buffer != nullptr);
    std::vector<float> transposedAuxInput;
    std::vector<float> transposedOutput;
    Shape transposedInputShape;
    Shape transposedOutputShape;
    if (!timeMajor) {
        transposedInput.resize(maxTime * batchInputSize);
        transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
                                           transposedInput.data());
        if (hasAuxInput) {
            transposedAuxInput.resize(maxTime * batchInputSize);
            transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
                                               transposedAuxInput.data());
        }
        transposeFirstTwoDimensions(input_shape, &transposedInputShape);
        transposedOutput.resize(maxTime * batchOutputSize);
        transposedOutputShape = transposedInputShape;
        transposedOutputShape.dimensions[2] = outputSize;
    }
    const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
    const float* auxInputData =
            hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
                        : nullptr;
    float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();

    std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
    convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
    std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
    convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);

    const float* inputCurrentTimeStep =
            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    const float* auxInputCurrentTimeStep =
            hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
                        : nullptr;
    float* outputCurrentTimeStep =
            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
    const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);

    for (uint32_t t = 0; t < maxTime; ++t) {
        LSTMStep(params, inputCurrentTimeStep, batchInputShape,
                 input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
                 input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
                 input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
                 recurrent_to_forget_weights_float32.data(),
                 recurrent_to_cell_weights_float32.data(),
                 recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
                 cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
                 cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
                 aux_input_to_input_weights_float32.data(),
                 aux_input_to_forget_weights_float32.data(),
                 aux_input_to_cell_weights_float32.data(),
                 aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
                 forget_gate_bias_float32.data(), cell_bias_float32.data(),
                 output_gate_bias_float32.data(), projection_weights_float32.data(),
                 projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
                 cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(),
                 forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(),
                 output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
                 cell_state_out_float32.data(), outputCurrentTimeStep,
                 scratch_buffer_float32.data());
        inputCurrentTimeStep += batchInputDelta;
        if (hasAuxInput) {
            auxInputCurrentTimeStep += batchInputDelta;
        }
        outputCurrentTimeStep += batchOutputDelta;
        outputStateInCurrentTimeStep = output_state_out_float32;
        cellStateInCurrentTimeStep = cell_state_out_float32;
    }

    if (!timeMajor) {
        transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
                                           output_float32.data());
    }

    convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
    convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
    convertFloat32ToFloat16(output_float32, output_buffer);
    convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
    return true;
}

// static
bool LSTMCell::LSTMStep(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
    NNTRACE_COMP("LSTMCell::LSTMStep");

    const uint32_t n_batch = input_shape.dimensions[0];
    const uint32_t n_input = input_shape.dimensions[1];
    // n_cell and n_output will be the same size when there is no projection.
    const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
    const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
    const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;

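    // Per-step recurrence implemented below (a standard LSTM cell with optional
    // CIFG, peephole, layer norm and projection). With x = input, h = output
    // state, c = cell state, sigma = logistic sigmoid and g = the configured
    // activation:
    //
    //   i = sigma(W_ix * x + W_ih * h + w_ic (.) c + b_i)   (skipped under CIFG)
    //   f = sigma(W_fx * x + W_fh * h + w_fc (.) c + b_f)
    //   g' = g(W_cx * x + W_ch * h + b_c)
    //   c_out = f (.) c + i (.) g'   (i replaced by (1 - f) under CIFG, then
    //                                 clipped to +/- cell_clip if set)
    //   o = sigma(W_ox * x + W_oh * h + w_oc (.) c_out + b_o)
    //   h_out = o (.) g(c_out), optionally projected and clipped
    //
    // When layer normalization is enabled, each gate pre-activation is
    // mean/stddev-normalized and scaled by the layer norm weights before the
    // bias is added; auxiliary-input contributions, when present, are simply
    // accumulated into the same gate pre-activations.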
    // Index the scratch buffer pointers into the single global scratch buffer.
    float* input_gate_scratch = nullptr;
    float* cell_scratch = nullptr;
    float* forget_gate_scratch = nullptr;
    float* output_gate_scratch = nullptr;
    if (params.use_cifg) {
        cell_scratch = scratch_buffer_buffer;
        forget_gate_scratch = cell_scratch + n_cell * n_batch;
        output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
    } else {
        input_gate_scratch = scratch_buffer_buffer;
        cell_scratch = input_gate_scratch + n_cell * n_batch;
        forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
        output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
    }

811     if (!params.use_layer_norm) {
812         // Initialize scratch buffers with bias.
813         if (!params.use_cifg) {
814             tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
815                                                           input_gate_scratch);
816         }
817         tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
818                                                       forget_gate_scratch);
819         tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
820                                                       cell_scratch);
821         tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
822                                                       output_gate_scratch);
823     } else {
824         // Initialize scratch buffers with zeroes.
825         if (!params.use_cifg) {
826             std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
827         }
828         std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
829         std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
830         std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
831     }
832 
833     // For each batch and cell: compute input_weight * input.
834     if (!params.use_cifg) {
835         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_input_weights_buffer,
836                                                                   n_cell, n_input, input_buffer,
837                                                                   n_batch, input_gate_scratch);
838     }
839     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_forget_weights_buffer,
840                                                               n_cell, n_input, input_buffer,
841                                                               n_batch, forget_gate_scratch);
842     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
843             input_to_cell_weights_buffer, n_cell, n_input, input_buffer, n_batch, cell_scratch);
844     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_output_weights_buffer,
845                                                               n_cell, n_input, input_buffer,
846                                                               n_batch, output_gate_scratch);
847 
848     // If auxiliary input is available then compute aux_input_weight * aux_input
849     if (aux_input_buffer != nullptr) {
850         if (!params.use_cifg) {
851             tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
852                     aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
853                     n_batch, input_gate_scratch);
854         }
855 
856         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
857                 aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
858                 forget_gate_scratch);
859         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
860                 aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
861                 cell_scratch);
862         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
863                 aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
864                 output_gate_scratch);
865     }
866 
867     // For each batch and cell: compute recurrent_weight * output_state.
868     if (!params.use_cifg) {
869         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
870                 recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
871                 n_batch, input_gate_scratch);
872     }
873     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
874             recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
875             forget_gate_scratch);
876     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
877             recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
878             cell_scratch);
879     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
880             recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
881             output_gate_scratch);

    // For each batch and cell: update input gate.
    if (!params.use_cifg) {
        if (params.use_peephole) {
            tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                    cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
                    input_gate_scratch);
        }
        if (params.use_layer_norm) {
            tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
                                                          n_cell, n_batch);
            tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
                                                                n_cell, input_gate_scratch, n_batch,
                                                                input_gate_scratch);
            tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
                                                       input_gate_scratch);
        }
        tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                                   input_gate_scratch);
    }
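
    // Input gate (skipped entirely under CIFG, where i_t is derived from the
    // forget gate later):
    //   i_t = sigmoid(a_i),  a_i = pre_i + w_ci * c_{t-1} if peephole is used.
    // With layer normalization, a_i is mean/stddev-normalized, scaled by the
    // input layer-norm weights and offset by the input gate bias first.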

    // For each batch and cell: update forget gate.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
                                                                      n_cell, cell_state_in_buffer,
                                                                      n_batch, forget_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
                                                      n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
                                                            n_cell, forget_gate_scratch, n_batch,
                                                            forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
                                                   forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                               forget_gate_scratch);
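
    // Forget gate, same recipe:
    //   f_t = sigmoid(a_f),  a_f = pre_f + w_cf * c_{t-1} if peephole is used,
    // optionally layer-normalized, scaled by the forget layer-norm weights and
    // offset by the forget gate bias before the sigmoid.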

    // For each batch and cell: update the cell.
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
                                                            cell_scratch, n_batch, cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
    }
    tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
                                                   n_batch * n_cell, cell_state_out_buffer);
    tflite::tensor_utils::ApplyActivationToVector(
            cell_scratch, n_batch * n_cell, static_cast<TfLiteFusedActivation>(params.activation),
            cell_scratch);
    if (params.use_cifg) {
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    }
    if (params.cell_clip > 0.0) {
        tflite::tensor_utils::CwiseClipping(cell_state_out_buffer, n_batch * n_cell,
                                            params.cell_clip);
    }
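
    // New cell state (all products element-wise):
    //   c_t = f_t * c_{t-1} + i_t * g(a_c)
    // where g is the fused activation selected by params.activation and, under
    // CIFG, i_t = 1 - f_t. If cell_clip > 0 the result is clamped to
    // [-cell_clip, cell_clip].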

    // For each batch and cell: update the output gate.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
                                                                      n_cell, cell_state_out_buffer,
                                                                      n_batch, output_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
                                                      n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
                                                            n_cell, output_gate_scratch, n_batch,
                                                            output_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
                                                   output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
    tflite::tensor_utils::ApplyActivationToVector(
            cell_state_out_buffer, n_batch * n_cell,
            static_cast<TfLiteFusedActivation>(params.activation), cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                   n_batch * n_cell, output_gate_scratch);
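
    // Output gate and pre-projection output (element-wise products):
    //   o_t = sigmoid(a_o),  a_o = pre_o + w_co * c_t if peephole is used,
    //   h'_t = o_t * g(c_t)
    // with the same optional layer-norm/bias step before the sigmoid; h'_t is
    // left in output_gate_scratch for the projection step below.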

    // For each batch: update the projection and output_state.
    if (params.use_projection_weight) {
        if (params.use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
                                                          output_buffer);
        } else {
            std::fill_n(output_buffer, n_batch * n_output, 0.0f);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
                output_buffer);
        if (params.proj_clip > 0.0) {
            tflite::tensor_utils::CwiseClipping(output_buffer, n_batch * n_output,
                                                params.proj_clip);
        }
    } else {
        std::copy_n(output_gate_scratch, n_batch * n_output, output_buffer);
    }
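
    // Final output:
    //   with projection:    h_t = W_proj * h'_t + b_proj, clamped to
    //                       [-proj_clip, proj_clip] when proj_clip > 0
    //   without projection: h_t = h'_t
    // h_t is also copied to output_state_out below so that the next time step
    // sees it as h_{t-1}.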
    std::copy_n(output_buffer, n_batch * n_output, output_state_out_buffer);
    return true;
}

bool LSTMCell::Eval() {
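    // Dispatch on the input tensor's data type; the remaining tensors are
    // expected to share it. The FLOAT16 path fetches optional tensors through
    // GetOptionalBuffer so that omitted operands are forwarded as nullptr.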
    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
                            GetBuffer<const float>(input_to_input_weights_),
                            GetBuffer<const float>(input_to_forget_weights_),
                            GetBuffer<const float>(input_to_cell_weights_),
                            GetBuffer<const float>(input_to_output_weights_),
                            input_to_output_weights_->shape(),
                            GetBuffer<const float>(recurrent_to_input_weights_),
                            GetBuffer<const float>(recurrent_to_forget_weights_),
                            GetBuffer<const float>(recurrent_to_cell_weights_),
                            GetBuffer<const float>(recurrent_to_output_weights_),
                            recurrent_to_output_weights_->shape(),
                            GetBuffer<const float>(cell_to_input_weights_),
                            GetBuffer<const float>(cell_to_forget_weights_),
                            GetBuffer<const float>(cell_to_output_weights_),
                            /*aux_input_buffer=*/nullptr,
                            /*aux_input_to_input_weights_buffer=*/nullptr,
                            /*aux_input_to_forget_weights_buffer=*/nullptr,
                            /*aux_input_to_cell_weights_buffer=*/nullptr,
                            /*aux_input_to_output_weights_buffer=*/nullptr,
                            GetBuffer<const float>(input_gate_bias_),
                            GetBuffer<const float>(forget_gate_bias_),
                            GetBuffer<const float>(cell_bias_),
                            GetBuffer<const float>(output_gate_bias_),
                            GetBuffer<const float>(projection_weights_),
                            GetBuffer<const float>(projection_bias_),
                            GetBuffer<const float>(output_state_in_),
                            GetBuffer<const float>(cell_state_in_),
                            GetBuffer<const float>(input_layer_norm_weights_),
                            GetBuffer<const float>(forget_layer_norm_weights_),
                            GetBuffer<const float>(cell_layer_norm_weights_),
                            GetBuffer<const float>(output_layer_norm_weights_),
                            GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
                            GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
                            GetOptionalBuffer<const _Float16>(input_to_input_weights_),
                            GetBuffer<const _Float16>(input_to_forget_weights_),
                            GetBuffer<const _Float16>(input_to_cell_weights_),
                            GetBuffer<const _Float16>(input_to_output_weights_),
                            input_to_output_weights_->shape(),
                            GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
                            GetBuffer<const _Float16>(recurrent_to_forget_weights_),
                            GetBuffer<const _Float16>(recurrent_to_cell_weights_),
                            GetBuffer<const _Float16>(recurrent_to_output_weights_),
                            recurrent_to_output_weights_->shape(),
                            GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
                            GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
                            GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
                            /*aux_input_buffer=*/nullptr,
                            /*aux_input_to_input_weights_buffer=*/nullptr,
                            /*aux_input_to_forget_weights_buffer=*/nullptr,
                            /*aux_input_to_cell_weights_buffer=*/nullptr,
                            /*aux_input_to_output_weights_buffer=*/nullptr,
                            GetOptionalBuffer<const _Float16>(input_gate_bias_),
                            GetBuffer<const _Float16>(forget_gate_bias_),
                            GetBuffer<const _Float16>(cell_bias_),
                            GetBuffer<const _Float16>(output_gate_bias_),
                            GetOptionalBuffer<const _Float16>(projection_weights_),
                            GetOptionalBuffer<const _Float16>(projection_bias_),
                            GetBuffer<const _Float16>(output_state_in_),
                            GetBuffer<const _Float16>(cell_state_in_),
                            GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
                            GetBuffer<_Float16>(output_state_out_),
                            GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
                            GetBuffer<_Float16>(scratch_buffer_));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}

}  // namespace nn
}  // namespace android