1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "LSTM.h"
20
21 #include <tensorflow/lite/kernels/internal/tensor_utils.h>
22
23 #include <vector>
24
25 #include "CpuExecutor.h"
26 #include "CpuOperationUtils.h"
27 #include "LegacyUtils.h"
28 #include "OperationsExecutionUtils.h"
29 #include "Tracing.h"
30 #include "nnapi/Types.h"
31
32 namespace android {
33 namespace nn {
34
35 namespace {
36
37 template <typename T>
GetBuffer(RunTimeOperandInfo * operand)38 inline T* GetBuffer(RunTimeOperandInfo* operand) {
39 return reinterpret_cast<T*>(operand->buffer);
40 }
41
42 template <typename T>
GetBuffer(const RunTimeOperandInfo * operand)43 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
44 return reinterpret_cast<const T*>(operand->buffer);
45 }
46
47 template <typename T>
GetOptionalBuffer(const RunTimeOperandInfo * operand)48 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
49 return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
50 }
51
52 } // anonymous namespace
53
LSTMCell(const Operation & operation,RunTimeOperandInfo * operands)54 LSTMCell::LSTMCell(const Operation& operation, RunTimeOperandInfo* operands) {
55 input_ = GetInput(operation, operands, kInputTensor);
56
57 input_to_input_weights_ =
58 GetInput(operation, operands, kInputToInputWeightsTensor); // optional
59 input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
60 input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
61 input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);
62
63 recurrent_to_input_weights_ =
64 GetInput(operation, operands, kRecurrentToInputWeightsTensor); // optional
65 recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
66 recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
67 recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);
68
69 cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor); // optional
70 cell_to_forget_weights_ =
71 GetInput(operation, operands, kCellToForgetWeightsTensor); // optional
72 cell_to_output_weights_ =
73 GetInput(operation, operands, kCellToOutputWeightsTensor); // optional
74
75 input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
76 forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
77 cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
78 output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);
79
80 projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor); // optional
81 projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor); // optional
82
83 output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
84 cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
85
86 const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
87 params_.activation = static_cast<ActivationFn>(getScalarDataWithDefault<int32_t>(
88 activationOperand, TfLiteFusedActivation::kTfLiteActNone));
89
90 const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
91 const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
92 if (input_->type == OperandType::TENSOR_FLOAT32) {
93 params_.cell_clip = getScalarDataWithDefault<float>(cellClipOperand, 0.0f);
94 params_.proj_clip = getScalarDataWithDefault<float>(projClipOperand, 0.0f);
95 } else {
96 params_.cell_clip =
97 static_cast<float>(getScalarDataWithDefault<_Float16>(cellClipOperand, 0.0f));
98 params_.proj_clip =
99 static_cast<float>(getScalarDataWithDefault<_Float16>(projClipOperand, 0.0f));
100 }
101
102 // We check the version of LSTM by checking the number of the inputs to the
103 // op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
104 if (operation.inputs.size() == 27) {
105 input_layer_norm_weights_ =
106 GetInput(operation, operands, kInputLayerNormWeightsTensor); // optional
107 forget_layer_norm_weights_ =
108 GetInput(operation, operands, kForgetLayerNormWeightsTensor); // optional
109 cell_layer_norm_weights_ =
110 GetInput(operation, operands, kCellLayerNormWeightsTensor); // optional
111 output_layer_norm_weights_ =
112 GetInput(operation, operands, kOutputLayerNormWeightsTensor); // optional
113 } else {
114 // For LSTM from HAL v1.0 assign operands with no values
115 static RunTimeOperandInfo no_value;
116 no_value.lifetime = Operand::LifeTime::NO_VALUE;
117
118 input_layer_norm_weights_ = &no_value;
119 forget_layer_norm_weights_ = &no_value;
120 cell_layer_norm_weights_ = &no_value;
121 output_layer_norm_weights_ = &no_value;
122 }
123
124 output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
125 cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
126 output_ = GetOutput(operation, operands, kOutputTensor);
127
128 scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
129 }
130
131 // static
CheckInputTensorDimensions(const RunTimeOperandInfo *,const RunTimeOperandInfo * input_to_input_weights,const RunTimeOperandInfo * input_to_forget_weights,const RunTimeOperandInfo * input_to_cell_weights,const RunTimeOperandInfo *,const RunTimeOperandInfo * recurrent_to_input_weights,const RunTimeOperandInfo * recurrent_to_forget_weights,const RunTimeOperandInfo * recurrent_to_cell_weights,const RunTimeOperandInfo *,const RunTimeOperandInfo * cell_to_input_weights,const RunTimeOperandInfo * cell_to_forget_weights,const RunTimeOperandInfo * cell_to_output_weights,const RunTimeOperandInfo * input_gate_bias,const RunTimeOperandInfo * forget_gate_bias,const RunTimeOperandInfo * cell_bias,const RunTimeOperandInfo * output_gate_bias,const RunTimeOperandInfo * projection_weights,const RunTimeOperandInfo * projection_bias,const RunTimeOperandInfo * input_layer_norm_weights,const RunTimeOperandInfo * forget_layer_norm_weights,const RunTimeOperandInfo * cell_layer_norm_weights,const RunTimeOperandInfo * output_layer_norm_weights,uint32_t n_input,uint32_t n_output,uint32_t n_cell,LSTMParams * params)132 bool LSTMCell::CheckInputTensorDimensions(
133 const RunTimeOperandInfo* /*input_*/, const RunTimeOperandInfo* input_to_input_weights,
134 const RunTimeOperandInfo* input_to_forget_weights,
135 const RunTimeOperandInfo* input_to_cell_weights,
136 const RunTimeOperandInfo* /*input_to_output_weights*/,
137 const RunTimeOperandInfo* recurrent_to_input_weights,
138 const RunTimeOperandInfo* recurrent_to_forget_weights,
139 const RunTimeOperandInfo* recurrent_to_cell_weights,
140 const RunTimeOperandInfo* /*recurrent_to_output_weights*/,
141 const RunTimeOperandInfo* cell_to_input_weights,
142 const RunTimeOperandInfo* cell_to_forget_weights,
143 const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
144 const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
145 const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
146 const RunTimeOperandInfo* projection_bias,
147 const RunTimeOperandInfo* input_layer_norm_weights,
148 const RunTimeOperandInfo* forget_layer_norm_weights,
149 const RunTimeOperandInfo* cell_layer_norm_weights,
150 const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
151 uint32_t n_cell, LSTMParams* params) {
152 // Making sure clipping parameters have valid values.
153 // == 0 means no clipping
154 // > 0 means clipping
155 NN_CHECK(params->cell_clip >= 0);
156 NN_CHECK(params->proj_clip >= 0);
157
158 if (!IsNullInput(input_to_input_weights)) {
159 NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2u);
160 NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
161 NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
162 }
163
164 NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2u);
165 NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
166 NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);
167
168 NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2u);
169 NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
170 NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);
171
172 if (!IsNullInput(recurrent_to_input_weights)) {
173 NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2u);
174 NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
175 NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
176 }
177
178 NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2u);
179 NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
180 NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);
181
182 NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2u);
183 NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
184 NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);
185
186 // We make sure the input-gate's parameters are either both present (regular
187 // LSTM) or not at all (CIFG-LSTM).
188 const bool cifg_weights_all_or_none =
189 (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
190 (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
191 NN_CHECK(cifg_weights_all_or_none);
192
193 if (!IsNullInput(cell_to_input_weights)) {
194 NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1u);
195 NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
196 }
197
198 if (!IsNullInput(cell_to_forget_weights)) {
199 NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1u);
200 NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
201 }
202
203 if (!IsNullInput(cell_to_output_weights)) {
204 NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1u);
205 NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
206 }
207
208 // Making sure the peephole weights are there all or none.
209 params->use_cifg = IsNullInput(input_to_input_weights);
210 const bool peephole_weights_all_or_none =
211 ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
212 !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
213 (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
214 IsNullInput(cell_to_output_weights));
215 NN_CHECK(peephole_weights_all_or_none);
216
217 // Since we have already checked that weights are all there or none, we can
218 // check the existence of only one to the get the condition.
219 params->use_peephole = !IsNullInput(cell_to_output_weights);
220 // Checking output instead of input layer norm weights because input can be
221 // omitted ones can be omited in case CIFG LSTM is used.
222 params->use_layer_norm = !IsNullInput(output_layer_norm_weights);
223
224 params->use_projection_weight = (projection_weights->lifetime != Operand::LifeTime::NO_VALUE);
225 params->use_projection_bias = (projection_bias->lifetime != Operand::LifeTime::NO_VALUE);
226
227 // Make sure the input gate bias is present only when not a CIFG-LSTM.
228 if (params->use_cifg) {
229 NN_CHECK(IsNullInput(input_gate_bias));
230 } else {
231 NN_CHECK_EQ(NumDimensions(input_gate_bias), 1u);
232 NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
233 }
234
235 NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1u);
236 NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);
237
238 NN_CHECK_EQ(NumDimensions(cell_bias), 1u);
239 NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);
240
241 NN_CHECK_EQ(NumDimensions(output_gate_bias), 1u);
242 NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);
243
244 if (!IsNullInput(projection_weights)) {
245 NN_CHECK_EQ(NumDimensions(projection_weights), 2u);
246 NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
247 NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
248 }
249
250 if (!IsNullInput(projection_bias)) {
251 NN_CHECK_EQ(NumDimensions(projection_bias), 1u);
252 NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
253 }
254
255 // Making sure the projection tensors are consistent:
256 // 1) If projection weight is not present, then projection bias should not be
257 // present.
258 // 2) If projection weight is present, then projection bias is optional.
259 // TODO: make sure this is correct.
260 const bool projecton_tensors_consistent =
261 (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
262 NN_CHECK(projecton_tensors_consistent == true);
263
264 if (!IsNullInput(input_layer_norm_weights)) {
265 NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1u);
266 NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
267 }
268 if (!IsNullInput(forget_layer_norm_weights)) {
269 NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1u);
270 NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
271 }
272 if (!IsNullInput(cell_layer_norm_weights)) {
273 NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1u);
274 NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
275 }
276 if (!IsNullInput(output_layer_norm_weights)) {
277 NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1u);
278 NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
279 }
280
281 if (params->use_cifg) {
282 NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
283 << "input_layer_norm_weights are provided while CIFG is used";
284 const bool layer_norm_weights_all_or_none_cifg =
285 (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
286 IsNullInput(output_layer_norm_weights)) ||
287 (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
288 !IsNullInput(output_layer_norm_weights));
289 NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
290 } else {
291 const bool layer_norm_weights_all_or_none =
292 (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
293 IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
294 (!IsNullInput(input_layer_norm_weights) &&
295 !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
296 !IsNullInput(output_layer_norm_weights));
297 NN_RET_CHECK(layer_norm_weights_all_or_none);
298 }
299
300 return true;
301 }
302
Prepare(const Operation & operation,RunTimeOperandInfo * operands,Shape * scratchShape,Shape * outputStateShape,Shape * cellStateShape,Shape * outputShape)303 bool LSTMCell::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
304 Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
305 Shape* outputShape) {
306 // Check we have all the inputs and outputs we need.
307 NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
308 NumInputsWithValues(operation, operands) <= 27);
309 constexpr int requiredInputs[] = {
310 kInputTensor,
311 kInputToForgetWeightsTensor,
312 kInputToCellWeightsTensor,
313 kInputToOutputWeightsTensor,
314 kRecurrentToForgetWeightsTensor,
315 kRecurrentToCellWeightsTensor,
316 kRecurrentToOutputWeightsTensor,
317 kForgetGateBiasTensor,
318 kCellGateBiasTensor,
319 kOutputGateBiasTensor,
320 kOutputStateInTensor,
321 kCellStateInTensor,
322 kActivationParam,
323 kCellClipParam,
324 kProjClipParam,
325 };
326 for (const int requiredInput : requiredInputs) {
327 NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
328 << "required input " << requiredInput << " is omitted";
329 }
330 NN_CHECK_EQ(NumOutputs(operation), 4);
331
332 // Check that the scalar operands' buffers are large enough.
333 const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
334 NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
335 const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
336 const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
337 if (input_->type == OperandType::TENSOR_FLOAT32) {
338 NN_RET_CHECK(cellClipOperand.length >= sizeof(float));
339 NN_RET_CHECK(projClipOperand.length >= sizeof(float));
340 } else {
341 NN_RET_CHECK(cellClipOperand.length >= sizeof(_Float16));
342 NN_RET_CHECK(projClipOperand.length >= sizeof(_Float16));
343 }
344
345 // Inferring batch size, number of outputs and number of cells from the
346 // input tensors.
347 NN_CHECK(NumDimensions(input_) > 1);
348 const uint32_t n_batch = SizeOfDimension(input_, 0);
349 const uint32_t n_input = SizeOfDimension(input_, 1);
350
351 const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
352 NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2u);
353 NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);
354
355 NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2u);
356 NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
357 const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);
358
359 // Check that input tensor dimensions matches with each other.
360 if (!CheckInputTensorDimensions(
361 input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
362 input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
363 recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
364 cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
365 forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
366 projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
367 cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
368 ¶ms_)) {
369 return false;
370 }
371
372 // Resize the output and output_state tensors.
373 const Shape& inputShape = input_->shape();
374
375 outputShape->type = inputShape.type;
376 outputShape->dimensions = {n_batch, n_output};
377 outputShape->offset = inputShape.offset;
378 outputShape->scale = inputShape.scale;
379
380 outputStateShape->type = inputShape.type;
381 outputStateShape->dimensions = {n_batch, n_output};
382 outputStateShape->offset = inputShape.offset;
383 outputStateShape->scale = inputShape.scale;
384
385 cellStateShape->type = inputShape.type;
386 cellStateShape->dimensions = {n_batch, n_cell};
387 cellStateShape->offset = inputShape.offset;
388 cellStateShape->scale = inputShape.scale;
389
390 if (params_.use_cifg) {
391 // Reserving space for Cell, Forget, Output gates
392 scratchShape->dimensions = {n_batch, n_cell * 3};
393 } else {
394 // Reserving space for Input, Cell, Forget, Output gates
395 scratchShape->dimensions = {n_batch, n_cell * 4};
396 }
397 scratchShape->type = inputShape.type;
398 scratchShape->offset = inputShape.offset;
399 scratchShape->scale = inputShape.scale;
400
401 return true;
402 }
403
404 // static
LSTMEvalFloat32(const LSTMParams & params,const float * input_buffer,const Shape & input_shape,const float * input_to_input_weights_buffer,const float * input_to_forget_weights_buffer,const float * input_to_cell_weights_buffer,const float * input_to_output_weights_buffer,const Shape & input_to_output_weights_shape,const float * recurrent_to_input_weights_buffer,const float * recurrent_to_forget_weights_buffer,const float * recurrent_to_cell_weights_buffer,const float * recurrent_to_output_weights_buffer,const Shape & recurrent_to_output_weights_shape,const float * cell_to_input_weights_buffer,const float * cell_to_forget_weights_buffer,const float * cell_to_output_weights_buffer,const float * aux_input_buffer,const float * aux_input_to_input_weights_buffer,const float * aux_input_to_forget_weights_buffer,const float * aux_input_to_cell_weights_buffer,const float * aux_input_to_output_weights_buffer,const float * input_gate_bias_buffer,const float * forget_gate_bias_buffer,const float * cell_bias_buffer,const float * output_gate_bias_buffer,const float * projection_weights_buffer,const float * projection_bias_buffer,const float * output_state_in_buffer,const float * cell_state_in_buffer,const float * input_layer_norm_weights_buffer,const float * forget_layer_norm_weights_buffer,const float * cell_layer_norm_weights_buffer,const float * output_layer_norm_weights_buffer,float * output_state_out_buffer,float * cell_state_out_buffer,float * output_buffer,float * scratch_buffer_buffer,bool timeMajor,bool forwardSequence)405 bool LSTMCell::LSTMEvalFloat32(
406 const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
407 const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
408 const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
409 const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
410 const float* recurrent_to_forget_weights_buffer,
411 const float* recurrent_to_cell_weights_buffer,
412 const float* recurrent_to_output_weights_buffer,
413 const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
414 const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
415 const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
416 const float* aux_input_to_forget_weights_buffer,
417 const float* aux_input_to_cell_weights_buffer,
418 const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
419 const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
420 const float* output_gate_bias_buffer, const float* projection_weights_buffer,
421 const float* projection_bias_buffer, const float* output_state_in_buffer,
422 const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
423 const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
424 const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
425 float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
426 bool timeMajor, bool forwardSequence) {
427 NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");
428
429 const uint32_t inputRank = getNumberOfDimensions(input_shape);
430 NN_CHECK(inputRank == 2 || inputRank == 3);
431
432 const uint32_t maxTime =
433 (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
434 const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
435 : getSizeOfDimension(input_shape, 0);
436 const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
437 const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
438 const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
439
440 Shape batchInputShape = input_shape;
441 batchInputShape.dimensions = {batchSize, inputSize};
442 const uint32_t batchInputSize = batchSize * inputSize;
443 const uint32_t batchOutputSize = batchSize * outputSize;
444
445 std::vector<float> transposedInput;
446 const bool hasAuxInput = (aux_input_buffer != nullptr);
447 std::vector<float> transposedAuxInput;
448 std::vector<float> transposedOutput;
449 Shape transposedInputShape;
450 Shape transposedOutputShape;
451 if (!timeMajor) {
452 transposedInput.resize(maxTime * batchInputSize);
453 transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
454 if (hasAuxInput) {
455 transposedAuxInput.resize(maxTime * batchInputSize);
456 transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
457 transposedAuxInput.data());
458 }
459 transposeFirstTwoDimensions(input_shape, &transposedInputShape);
460 transposedOutput.resize(maxTime * batchOutputSize);
461 transposedOutputShape = transposedInputShape;
462 transposedOutputShape.dimensions[2] = outputSize;
463 }
464 const float* inputData = timeMajor ? input_buffer : transposedInput.data();
465 const float* auxInputData =
466 hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
467 float* outputData = timeMajor ? output_buffer : transposedOutput.data();
468
469 std::vector<float> outputStateInCurrentTimeStep(
470 output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
471 std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
472 cell_state_in_buffer + batchSize * numCells);
473 const float* inputCurrentTimeStep =
474 inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
475 const float* auxInputCurrentTimeStep =
476 hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
477 : nullptr;
478 float* outputCurrentTimeStep =
479 outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
480 const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
481 const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);
482
483 for (uint32_t t = 0; t < maxTime; ++t) {
484 LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
485 input_to_forget_weights_buffer, input_to_cell_weights_buffer,
486 input_to_output_weights_buffer, input_to_output_weights_shape,
487 recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
488 recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
489 recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
490 cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
491 auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
492 aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
493 aux_input_to_output_weights_buffer, input_gate_bias_buffer,
494 forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
495 projection_weights_buffer, projection_bias_buffer,
496 outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
497 input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
498 cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
499 output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
500 scratch_buffer_buffer);
501 inputCurrentTimeStep += batchInputDelta;
502 if (hasAuxInput) {
503 auxInputCurrentTimeStep += batchInputDelta;
504 }
505 outputCurrentTimeStep += batchOutputDelta;
506 outputStateInCurrentTimeStep.assign(output_state_out_buffer,
507 output_state_out_buffer + batchSize * outputSize);
508 cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
509 cell_state_out_buffer + batchSize * numCells);
510 }
511
512 if (!timeMajor) {
513 transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
514 output_buffer);
515 }
516
517 return true;
518 }
519
520 // static
LSTMEvalFloat16(const LSTMParams & params,const _Float16 * input_buffer,const Shape & input_shape,const _Float16 * input_to_input_weights_buffer,const _Float16 * input_to_forget_weights_buffer,const _Float16 * input_to_cell_weights_buffer,const _Float16 * input_to_output_weights_buffer,const Shape & input_to_output_weights_shape,const _Float16 * recurrent_to_input_weights_buffer,const _Float16 * recurrent_to_forget_weights_buffer,const _Float16 * recurrent_to_cell_weights_buffer,const _Float16 * recurrent_to_output_weights_buffer,const Shape & recurrent_to_output_weights_shape,const _Float16 * cell_to_input_weights_buffer,const _Float16 * cell_to_forget_weights_buffer,const _Float16 * cell_to_output_weights_buffer,const _Float16 * aux_input_buffer,const _Float16 * aux_input_to_input_weights_buffer,const _Float16 * aux_input_to_forget_weights_buffer,const _Float16 * aux_input_to_cell_weights_buffer,const _Float16 * aux_input_to_output_weights_buffer,const _Float16 * input_gate_bias_buffer,const _Float16 * forget_gate_bias_buffer,const _Float16 * cell_bias_buffer,const _Float16 * output_gate_bias_buffer,const _Float16 * projection_weights_buffer,const _Float16 * projection_bias_buffer,const _Float16 * output_state_in_buffer,const _Float16 * cell_state_in_buffer,const _Float16 * input_layer_norm_weights_buffer,const _Float16 * forget_layer_norm_weights_buffer,const _Float16 * cell_layer_norm_weights_buffer,const _Float16 * output_layer_norm_weights_buffer,_Float16 * output_state_out_buffer,_Float16 * cell_state_out_buffer,_Float16 * output_buffer,_Float16 * scratch_buffer_buffer,bool timeMajor,bool forwardSequence)521 bool LSTMCell::LSTMEvalFloat16(
522 const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
523 const _Float16* input_to_input_weights_buffer,
524 const _Float16* input_to_forget_weights_buffer,
525 const _Float16* input_to_cell_weights_buffer,
526 const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
527 const _Float16* recurrent_to_input_weights_buffer,
528 const _Float16* recurrent_to_forget_weights_buffer,
529 const _Float16* recurrent_to_cell_weights_buffer,
530 const _Float16* recurrent_to_output_weights_buffer,
531 const Shape& recurrent_to_output_weights_shape,
532 const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
533 const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
534 const _Float16* aux_input_to_input_weights_buffer,
535 const _Float16* aux_input_to_forget_weights_buffer,
536 const _Float16* aux_input_to_cell_weights_buffer,
537 const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
538 const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
539 const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
540 const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
541 const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
542 const _Float16* forget_layer_norm_weights_buffer,
543 const _Float16* cell_layer_norm_weights_buffer,
544 const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
545 _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
546 bool timeMajor, bool forwardSequence) {
547 NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");
548
549 const uint32_t inputRank = getNumberOfDimensions(input_shape);
550 NN_CHECK(inputRank == 2 || inputRank == 3);
551
552 const uint32_t maxTime =
553 (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
554 const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
555 : getSizeOfDimension(input_shape, 0);
556 const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
557 const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
558 const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
559
560 Shape batchInputShape = input_shape;
561 batchInputShape.dimensions = {batchSize, inputSize};
562 const uint32_t batchInputSize = batchSize * inputSize;
563 const uint32_t batchOutputSize = batchSize * outputSize;
564
565 std::vector<float> input_float32(maxTime * batchInputSize);
566 convertFloat16ToFloat32(input_buffer, &input_float32);
567 std::vector<float> input_to_input_weights_float32(numCells * inputSize);
568 if (input_to_input_weights_buffer != nullptr) {
569 convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32);
570 }
571 std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
572 convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
573 std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
574 convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
575 std::vector<float> input_to_output_weights_float32(numCells * inputSize);
576 convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);
577
578 std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
579 if (recurrent_to_input_weights_buffer != nullptr) {
580 convertFloat16ToFloat32(recurrent_to_input_weights_buffer,
581 &recurrent_to_input_weights_float32);
582 }
583 std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
584 convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
585 &recurrent_to_forget_weights_float32);
586 std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
587 convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
588 std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
589 convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
590 &recurrent_to_output_weights_float32);
591
592 std::vector<float> cell_to_input_weights_float32(numCells);
593 if (cell_to_input_weights_buffer != nullptr) {
594 convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
595 }
596 std::vector<float> cell_to_forget_weights_float32(numCells);
597 if (cell_to_forget_weights_buffer != nullptr) {
598 convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
599 }
600 std::vector<float> cell_to_output_weights_float32(numCells);
601 if (cell_to_output_weights_buffer != nullptr) {
602 convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
603 }
604
605 std::vector<float> aux_input_float32(maxTime * batchInputSize);
606 if (aux_input_buffer != nullptr) {
607 convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
608 }
609 std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
610 if (aux_input_to_input_weights_buffer != nullptr) {
611 convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
612 &aux_input_to_input_weights_float32);
613 }
614 std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
615 if (aux_input_to_forget_weights_buffer != nullptr) {
616 convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
617 &aux_input_to_forget_weights_float32);
618 }
619 std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
620 if (aux_input_to_cell_weights_buffer != nullptr) {
621 convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
622 &aux_input_to_cell_weights_float32);
623 }
624 std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
625 if (aux_input_to_output_weights_buffer != nullptr) {
626 convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
627 &aux_input_to_output_weights_float32);
628 }
629
630 std::vector<float> input_gate_bias_float32(numCells);
631 if (input_gate_bias_buffer != nullptr) {
632 convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
633 }
634 std::vector<float> forget_gate_bias_float32(numCells);
635 convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
636 std::vector<float> cell_bias_float32(numCells);
637 convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
638 std::vector<float> output_gate_bias_float32(numCells);
639 convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);
640
641 std::vector<float> projection_weights_float32(numCells * outputSize);
642 if (projection_weights_buffer != nullptr) {
643 convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32);
644 }
645 std::vector<float> projection_bias_float32(outputSize);
646 if (projection_bias_buffer != nullptr) {
647 convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32);
648 }
649
650 std::vector<float> input_layer_norm_weights_float32(numCells);
651 if (input_layer_norm_weights_buffer != nullptr) {
652 convertFloat16ToFloat32(input_layer_norm_weights_buffer, &input_layer_norm_weights_float32);
653 }
654 std::vector<float> forget_layer_norm_weights_float32(numCells);
655 if (forget_layer_norm_weights_buffer != nullptr) {
656 convertFloat16ToFloat32(forget_layer_norm_weights_buffer,
657 &forget_layer_norm_weights_float32);
658 }
659 std::vector<float> cell_layer_norm_weights_float32(numCells);
660 if (cell_layer_norm_weights_buffer != nullptr) {
661 convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
662 }
663 std::vector<float> output_layer_norm_weights_float32(numCells);
664 if (output_layer_norm_weights_buffer != nullptr) {
665 convertFloat16ToFloat32(output_layer_norm_weights_buffer,
666 &output_layer_norm_weights_float32);
667 }
668
669 std::vector<float> output_state_out_float32(batchOutputSize);
670 convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
671 std::vector<float> cell_state_out_float32(batchSize * numCells);
672 convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);
673
674 std::vector<float> output_float32(maxTime * batchOutputSize);
675 convertFloat16ToFloat32(output_buffer, &output_float32);
676 std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
677 : 4 * batchSize * numCells);
678 convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);
679
680 std::vector<float> transposedInput;
681 const bool hasAuxInput = (aux_input_buffer != nullptr);
682 std::vector<float> transposedAuxInput;
683 std::vector<float> transposedOutput;
684 Shape transposedInputShape;
685 Shape transposedOutputShape;
686 if (!timeMajor) {
687 transposedInput.resize(maxTime * batchInputSize);
688 transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
689 transposedInput.data());
690 if (hasAuxInput) {
691 transposedAuxInput.resize(maxTime * batchInputSize);
692 transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
693 transposedAuxInput.data());
694 }
695 transposeFirstTwoDimensions(input_shape, &transposedInputShape);
696 transposedOutput.resize(maxTime * batchOutputSize);
697 transposedOutputShape = transposedInputShape;
698 transposedOutputShape.dimensions[2] = outputSize;
699 }
700 const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
701 const float* auxInputData =
702 hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
703 : nullptr;
704 float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();
705
706 std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
707 convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
708 std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
709 convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);
710
711 const float* inputCurrentTimeStep =
712 inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
713 const float* auxInputCurrentTimeStep =
714 hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
715 : nullptr;
716 float* outputCurrentTimeStep =
717 outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
718 const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
719 const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);
720
721 for (uint32_t t = 0; t < maxTime; ++t) {
722 LSTMStep(params, inputCurrentTimeStep, batchInputShape,
723 input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
724 input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
725 input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
726 recurrent_to_forget_weights_float32.data(),
727 recurrent_to_cell_weights_float32.data(),
728 recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
729 cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
730 cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
731 aux_input_to_input_weights_float32.data(),
732 aux_input_to_forget_weights_float32.data(),
733 aux_input_to_cell_weights_float32.data(),
734 aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
735 forget_gate_bias_float32.data(), cell_bias_float32.data(),
736 output_gate_bias_float32.data(), projection_weights_float32.data(),
737 projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
738 cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(),
739 forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(),
740 output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
741 cell_state_out_float32.data(), outputCurrentTimeStep,
742 scratch_buffer_float32.data());
743 inputCurrentTimeStep += batchInputDelta;
744 if (hasAuxInput) {
745 auxInputCurrentTimeStep += batchInputDelta;
746 }
747 outputCurrentTimeStep += batchOutputDelta;
748 outputStateInCurrentTimeStep = output_state_out_float32;
749 cellStateInCurrentTimeStep = cell_state_out_float32;
750 }
751
752 if (!timeMajor) {
753 transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
754 output_float32.data());
755 }
756
757 convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
758 convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
759 convertFloat32ToFloat16(output_float32, output_buffer);
760 convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
761 return true;
762 }
763
764 // static
LSTMStep(const LSTMParams & params,const float * input_buffer,const Shape & input_shape,const float * input_to_input_weights_buffer,const float * input_to_forget_weights_buffer,const float * input_to_cell_weights_buffer,const float * input_to_output_weights_buffer,const Shape & input_to_output_weights_shape,const float * recurrent_to_input_weights_buffer,const float * recurrent_to_forget_weights_buffer,const float * recurrent_to_cell_weights_buffer,const float * recurrent_to_output_weights_buffer,const Shape & recurrent_to_output_weights_shape,const float * cell_to_input_weights_buffer,const float * cell_to_forget_weights_buffer,const float * cell_to_output_weights_buffer,const float * aux_input_buffer,const float * aux_input_to_input_weights_buffer,const float * aux_input_to_forget_weights_buffer,const float * aux_input_to_cell_weights_buffer,const float * aux_input_to_output_weights_buffer,const float * input_gate_bias_buffer,const float * forget_gate_bias_buffer,const float * cell_bias_buffer,const float * output_gate_bias_buffer,const float * projection_weights_buffer,const float * projection_bias_buffer,const float * output_state_in_buffer,const float * cell_state_in_buffer,const float * input_layer_norm_weights_buffer,const float * forget_layer_norm_weights_buffer,const float * cell_layer_norm_weights_buffer,const float * output_layer_norm_weights_buffer,float * output_state_out_buffer,float * cell_state_out_buffer,float * output_buffer,float * scratch_buffer_buffer)765 bool LSTMCell::LSTMStep(
766 const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
767 const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
768 const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
769 const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
770 const float* recurrent_to_forget_weights_buffer,
771 const float* recurrent_to_cell_weights_buffer,
772 const float* recurrent_to_output_weights_buffer,
773 const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
774 const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
775 const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
776 const float* aux_input_to_forget_weights_buffer,
777 const float* aux_input_to_cell_weights_buffer,
778 const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
779 const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
780 const float* output_gate_bias_buffer, const float* projection_weights_buffer,
781 const float* projection_bias_buffer, const float* output_state_in_buffer,
782 const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
783 const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
784 const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
785 float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
786 NNTRACE_COMP("LSTMCell::LSTMStep");
787
788 const uint32_t n_batch = input_shape.dimensions[0];
789 const uint32_t n_input = input_shape.dimensions[1];
790 // n_cell and n_output will be the same size when there is no projection.
791 const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
792 const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
793 const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;
794
795 // Index the scratch buffers pointers to the global scratch buffer.
796 float* input_gate_scratch = nullptr;
797 float* cell_scratch = nullptr;
798 float* forget_gate_scratch = nullptr;
799 float* output_gate_scratch = nullptr;
800 if (params.use_cifg) {
801 cell_scratch = scratch_buffer_buffer;
802 forget_gate_scratch = cell_scratch + n_cell * n_batch;
803 output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
804 } else {
805 input_gate_scratch = scratch_buffer_buffer;
806 cell_scratch = input_gate_scratch + n_cell * n_batch;
807 forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
808 output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
809 }
810
811 if (!params.use_layer_norm) {
812 // Initialize scratch buffers with bias.
813 if (!params.use_cifg) {
814 tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
815 input_gate_scratch);
816 }
817 tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
818 forget_gate_scratch);
819 tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
820 cell_scratch);
821 tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
822 output_gate_scratch);
823 } else {
824 // Initialize scratch buffers with zeroes.
825 if (!params.use_cifg) {
826 std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
827 }
828 std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
829 std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
830 std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
831 }
832
833 // For each batch and cell: compute input_weight * input.
834 if (!params.use_cifg) {
835 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_input_weights_buffer,
836 n_cell, n_input, input_buffer,
837 n_batch, input_gate_scratch);
838 }
839 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_forget_weights_buffer,
840 n_cell, n_input, input_buffer,
841 n_batch, forget_gate_scratch);
842 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
843 input_to_cell_weights_buffer, n_cell, n_input, input_buffer, n_batch, cell_scratch);
844 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_output_weights_buffer,
845 n_cell, n_input, input_buffer,
846 n_batch, output_gate_scratch);
847
848 // If auxiliary input is available then compute aux_input_weight * aux_input
849 if (aux_input_buffer != nullptr) {
850 if (!params.use_cifg) {
851 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
852 aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
853 n_batch, input_gate_scratch);
854 }
855
856 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
857 aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
858 forget_gate_scratch);
859 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
860 aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
861 cell_scratch);
862 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
863 aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
864 output_gate_scratch);
865 }
866
867 // For each batch and cell: compute recurrent_weight * output_state.
868 if (!params.use_cifg) {
869 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
870 recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
871 n_batch, input_gate_scratch);
872 }
873 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
874 recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
875 forget_gate_scratch);
876 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
877 recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
878 cell_scratch);
879 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
880 recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
881 output_gate_scratch);
882
883 // For each batch and cell: update input gate.
884 if (!params.use_cifg) {
885 if (params.use_peephole) {
886 tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
887 cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
888 input_gate_scratch);
889 }
890 if (params.use_layer_norm) {
891 tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
892 n_cell, n_batch);
893 tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
894 n_cell, input_gate_scratch, n_batch,
895 input_gate_scratch);
896 tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
897 input_gate_scratch);
898 }
899 tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
900 input_gate_scratch);
901 }
902
903 // For each batch and cell: update forget gate.
904 if (params.use_peephole) {
905 tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
906 n_cell, cell_state_in_buffer,
907 n_batch, forget_gate_scratch);
908 }
909 if (params.use_layer_norm) {
910 tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
911 n_cell, n_batch);
912 tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
913 n_cell, forget_gate_scratch, n_batch,
914 forget_gate_scratch);
915 tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
916 forget_gate_scratch);
917 }
918 tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
919 forget_gate_scratch);
920
921 // For each batch and cell: update the cell.
922 if (params.use_layer_norm) {
923 tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch);
924 tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
925 cell_scratch, n_batch, cell_scratch);
926 tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
927 }
928 tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
929 n_batch * n_cell, cell_state_out_buffer);
930 tflite::tensor_utils::ApplyActivationToVector(
931 cell_scratch, n_batch * n_cell, static_cast<TfLiteFusedActivation>(params.activation),
932 cell_scratch);
933 if (params.use_cifg) {
934 tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
935 forget_gate_scratch);
936 tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
937 cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
938 } else {
939 tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
940 cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
941 }
942 if (params.cell_clip > 0.0) {
943 tflite::tensor_utils::CwiseClipping(cell_state_out_buffer, n_batch * n_cell,
944 params.cell_clip);
945 }
946
947 // For each batch and cell: update the output gate.
948 if (params.use_peephole) {
949 tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
950 n_cell, cell_state_out_buffer,
951 n_batch, output_gate_scratch);
952 }
953 if (params.use_layer_norm) {
954 tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
955 n_cell, n_batch);
956 tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
957 n_cell, output_gate_scratch, n_batch,
958 output_gate_scratch);
959 tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
960 output_gate_scratch);
961 }
962 tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
963 output_gate_scratch);
964 tflite::tensor_utils::ApplyActivationToVector(
965 cell_state_out_buffer, n_batch * n_cell,
966 static_cast<TfLiteFusedActivation>(params.activation), cell_scratch);
967 tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
968 n_batch * n_cell, output_gate_scratch);
969
970 // For each batch: update the projection and output_state.
971 if (params.use_projection_weight) {
972 if (params.use_projection_bias) {
973 tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
974 output_buffer);
975 } else {
976 std::fill_n(output_buffer, n_batch * n_output, 0.0f);
977 }
978 tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
979 projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
980 output_buffer);
981 if (params.proj_clip > 0.0) {
982 tflite::tensor_utils::CwiseClipping(output_buffer, n_batch * n_output,
983 params.proj_clip);
984 }
985 } else {
986 std::copy_n(output_gate_scratch, n_batch * n_output, output_buffer);
987 }
988 std::copy_n(output_buffer, n_batch * n_output, output_state_out_buffer);
989 return true;
990 }
991
Eval()992 bool LSTMCell::Eval() {
993 switch (input_->type) {
994 case OperandType::TENSOR_FLOAT32: {
995 LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
996 GetBuffer<const float>(input_to_input_weights_),
997 GetBuffer<const float>(input_to_forget_weights_),
998 GetBuffer<const float>(input_to_cell_weights_),
999 GetBuffer<const float>(input_to_output_weights_),
1000 input_to_output_weights_->shape(),
1001 GetBuffer<const float>(recurrent_to_input_weights_),
1002 GetBuffer<const float>(recurrent_to_forget_weights_),
1003 GetBuffer<const float>(recurrent_to_cell_weights_),
1004 GetBuffer<const float>(recurrent_to_output_weights_),
1005 recurrent_to_output_weights_->shape(),
1006 GetBuffer<const float>(cell_to_input_weights_),
1007 GetBuffer<const float>(cell_to_forget_weights_),
1008 GetBuffer<const float>(cell_to_output_weights_),
1009 /*aux_input_buffer=*/nullptr,
1010 /*aux_input_to_input_weights_buffer=*/nullptr,
1011 /*aux_input_to_forget_weights_buffer=*/nullptr,
1012 /*aux_input_to_cell_weights_buffer=*/nullptr,
1013 /*aux_input_to_output_weights_buffer=*/nullptr,
1014 GetBuffer<const float>(input_gate_bias_),
1015 GetBuffer<const float>(forget_gate_bias_),
1016 GetBuffer<const float>(cell_bias_),
1017 GetBuffer<const float>(output_gate_bias_),
1018 GetBuffer<const float>(projection_weights_),
1019 GetBuffer<const float>(projection_bias_),
1020 GetBuffer<const float>(output_state_in_),
1021 GetBuffer<const float>(cell_state_in_),
1022 GetBuffer<const float>(input_layer_norm_weights_),
1023 GetBuffer<const float>(forget_layer_norm_weights_),
1024 GetBuffer<const float>(cell_layer_norm_weights_),
1025 GetBuffer<const float>(output_layer_norm_weights_),
1026 GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
1027 GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
1028 } break;
1029 case OperandType::TENSOR_FLOAT16: {
1030 LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
1031 GetOptionalBuffer<const _Float16>(input_to_input_weights_),
1032 GetBuffer<const _Float16>(input_to_forget_weights_),
1033 GetBuffer<const _Float16>(input_to_cell_weights_),
1034 GetBuffer<const _Float16>(input_to_output_weights_),
1035 input_to_output_weights_->shape(),
1036 GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
1037 GetBuffer<const _Float16>(recurrent_to_forget_weights_),
1038 GetBuffer<const _Float16>(recurrent_to_cell_weights_),
1039 GetBuffer<const _Float16>(recurrent_to_output_weights_),
1040 recurrent_to_output_weights_->shape(),
1041 GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
1042 GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
1043 GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
1044 /*aux_input_buffer=*/nullptr,
1045 /*aux_input_to_input_weights_buffer=*/nullptr,
1046 /*aux_input_to_forget_weights_buffer=*/nullptr,
1047 /*aux_input_to_cell_weights_buffer=*/nullptr,
1048 /*aux_input_to_output_weights_buffer=*/nullptr,
1049 GetOptionalBuffer<const _Float16>(input_gate_bias_),
1050 GetBuffer<const _Float16>(forget_gate_bias_),
1051 GetBuffer<const _Float16>(cell_bias_),
1052 GetBuffer<const _Float16>(output_gate_bias_),
1053 GetOptionalBuffer<const _Float16>(projection_weights_),
1054 GetOptionalBuffer<const _Float16>(projection_bias_),
1055 GetBuffer<const _Float16>(output_state_in_),
1056 GetBuffer<const _Float16>(cell_state_in_),
1057 GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
1058 GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
1059 GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
1060 GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
1061 GetBuffer<_Float16>(output_state_out_),
1062 GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
1063 GetBuffer<_Float16>(scratch_buffer_));
1064 } break;
1065 default: {
1066 LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
1067 return false;
1068 }
1069 }
1070 return true;
1071 }
1072
1073 } // namespace nn
1074 } // namespace android
1075