/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "QuantizedLSTM.h"

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#include <public/gemmlowp.h>
#include <tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h>
#pragma clang diagnostic pop

#include <algorithm>
#include <vector>

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "Tracing.h"

namespace android {
namespace nn {

namespace {

template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

using tflite::Dims;

// The function below is taken from the TF Lite implementation in order to decouple
// the NN API from the TF Lite dependency. The original function, with a description
// of its parameters and types, can be found at this link:
// https://github.com/tensorflow/tensorflow/blob/0d697e5fc4c05c699eea0764364104ea500ccc68/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h#L1926
//
// clang-format off
template <int StateIntegerBits>
void quantizedLstmStep(const uint8_t* input_data_uint8, const Dims<4>& input_dims,
                       const uint8_t* prev_activ_data_uint8,
                       const Dims<4>& prev_activ_dims, const uint8_t* weights_data_uint8,
                       const Dims<4>& weights_dims, const int32_t* bias_data_int32,
                       const Dims<4>& bias_dims, const int16_t* prevCellState_data_int16,
                       const Dims<4>& prevCellState_dims, int16_t* output_state_data_int16,
                       const Dims<4>& output_state_dims, uint8_t* output_activ_data_uint8,
                       const Dims<4>& output_activ_dims, uint8_t* concat_temp_data_uint8,
                       const Dims<4>& concat_temp_dims, int16_t* activ_temp_data_int16,
                       const Dims<4>& activ_temp_dims, int32_t weights_zero_point,
                       int32_t accum_multiplier, int accum_shift) {
  // Gather dimensions information, and perform consistency checks.
  const int outer_size =
      MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prevCellState_dims,
                              output_state_dims, output_activ_dims);
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
  const int input_depth = ArraySize(input_dims, 0);
  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
                  1);
  const int intern_activ_depth =
      MatchingArraySize(weights_dims, 1, bias_dims, 0);
  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingArraySize(prevCellState_dims, 0, prev_activ_dims, 0,
                        output_state_dims, 0, output_activ_dims, 0);
  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
  const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
  const int fc_output_depth =
      MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
  const int fc_accum_depth = ArraySize(weights_dims, 0);
  TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);

  // Depth-concatenate prev_activ and input data together.
  uint8_t const* concat_input_arrays_data[2] = {input_data_uint8,
                                                prev_activ_data_uint8};
  Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
  tflite::reference_ops::Concatenation<tflite::FusedActivationFunctionType::kNone, uint8_t>(
      0, concat_input_arrays_data, concat_input_arrays_dims, 2,
      concat_temp_data_uint8, concat_temp_dims);

  // Implementation of the fully connected node inside the LSTM cell.
  // The operands are 8-bit integers, the accumulators are internally 32-bit
  // integers, and the output is 16-bit fixed-point with 3 integer bits so
  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
  // is explained in the function comment above.
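  // (With 3 integer bits, the remaining 12 bits of each int16 are fractional, so a
  // raw value v represents v * 2^-12; e.g. a raw value of 4096 represents 1.0.)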
  for (int b = 0; b < fc_batches; ++b) {
    for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
      // Internal accumulation.
      // Initialize accumulator with the bias-value.
      int32_t accum = bias_data_int32[out_c];
      // Accumulation loop.
      for (int d = 0; d < fc_accum_depth; ++d) {
        int16_t input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
        int16_t weights_val =
            weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
        accum += input_val * weights_val;
      }
      // Down-scale the final int32 accumulator to the scale used by our
      // (16-bit, using 3 integer bits) fixed-point format. The quantized
      // multiplier and shift here have been pre-computed offline
      // (e.g. by toco).
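      // MultiplyByQuantizedMultiplier(x, m, s) computes approximately
      // x * m * 2^(s - 31), i.e. a multiplication by the real-valued multiplier that
      // QuantizeMultiplier decomposed into (accum_multiplier, accum_shift) in eval().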
      accum =
          tflite::MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
      // Saturate, cast to int16, and store to the temporary activations array.
      accum = std::max(-32768, std::min(32767, accum));
      activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
    }
  }

  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
  // and muls, all done in 16-bit fixed-point.
  for (int b = 0; b < outer_size; ++b) {
    for (int c = 0; c < output_depth; ++c) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See comment
      // on the StateIntegerBits template parameter above.
      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
      F0 input_modulation_gate_output =
          gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation =
          input_gate_output * input_modulation_gate_output;
      FS prevCellState = FS::FromRaw(prevCellState_data_int16[b * output_depth + c]);
      FS prevCellState_times_forget_state = forget_gate_output * prevCellState;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prevCellState_times_forget_state);
      // Implementation of last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
      // code size, and we already used above a Tanh on an input with 3 integer
      // bits, and per the table in the above function comment there is no
      // significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      output_state_data_int16[b * output_depth + c] = new_state.raw();
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory.
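      // Dividing the raw F0 value (15 fractional bits) by 2^8 leaves 7 fractional
      // bits, i.e. a value quantized with scale 1/128; adding the zero point of 128
      // below matches the uint8 quantization (scale 1/128, zero point 128) that
      // prepare() checks on the prevOutput tensor, which this output feeds back into.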
      int16_t rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int16_t clamped_output_activ =
          std::max<int16_t>(-128, std::min<int16_t>(127, rescaled_output_activ));
      output_activ_data_uint8[b * output_depth + c] =
          128 + clamped_output_activ;
    }
  }
}
// clang-format on

// The function assigns a 2D matrix to a submatrix of the weights at given row
// and column offsets.
void assignWeightsSubmatrix(const RunTimeOperandInfo* submatrix, const int32_t offset_row,
                            const int32_t offset_column, const std::vector<uint32_t>& weightsDims,
                            uint8_t* weights) {
    const uint8_t* submatrixValues = GetBuffer<uint8_t>(submatrix);
    const std::vector<uint32_t> submatrixDims = submatrix->shape().dimensions;
    for (uint32_t i = 0; i < submatrixDims[0] * submatrixDims[1]; ++i) {
        const uint32_t row = i / submatrixDims[1];
        const uint32_t column = i % submatrixDims[1];
        weights[(row + offset_row) * weightsDims[1] + column + offset_column] = submatrixValues[i];
    }
}

}  // namespace

QuantizedLSTMCell::QuantizedLSTMCell(const Operation& operation, RunTimeOperandInfo* operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    inputToInputWeights_ = GetInput(operation, operands, kInputToInputWeightsTensor);
    inputToForgetWeights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    inputToCellWeights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    inputToOutputWeights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrentToInputWeights_ = GetInput(operation, operands, kRecurrentToInputWeightsTensor);
    recurrentToForgetWeights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrentToCellWeights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrentToOutputWeights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    inputGateBias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forgetGateBias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cellGateBias_ = GetInput(operation, operands, kCellGateBiasTensor);
    outputGateBias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    prevCellState_ = GetInput(operation, operands, kPrevCellStateTensor);
    prevOutput_ = GetInput(operation, operands, kPrevOutputTensor);

    cellStateOut_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);
}

bool QuantizedLSTMCell::prepare(const Operation& operation, RunTimeOperandInfo* operands,
                                Shape* cellStateOutShape, Shape* outputShape) {
    auto input = GetInput(operation, operands, kInputTensor);
    NN_RET_CHECK_EQ(NumDimensions(input), 2u);
    NN_RET_CHECK_EQ(input->scale, 1. / 128.0);
    NN_RET_CHECK_EQ(input->zeroPoint, 128);
    const uint32_t numBatches = SizeOfDimension(input, 0);
    const uint32_t inputSize = SizeOfDimension(input, 1);

    auto prevOutput = GetInput(operation, operands, kPrevOutputTensor);
    NN_RET_CHECK_EQ(NumDimensions(prevOutput), 2u);
    NN_RET_CHECK_EQ(SizeOfDimension(prevOutput, 0), numBatches);
    NN_RET_CHECK_EQ(prevOutput->scale, 1. / 128.0);
    NN_RET_CHECK_EQ(prevOutput->zeroPoint, 128);
    const uint32_t outputSize = SizeOfDimension(prevOutput, 1);

    auto inputToInputWeights = GetInput(operation, operands, kInputToInputWeightsTensor);
    const float weightsScale = inputToInputWeights->scale;
    NN_RET_CHECK(weightsScale != 0);
    const float weightsZeroPoint = inputToInputWeights->zeroPoint;

    auto checkWeightsShape = [&](const RunTimeOperandInfo* weights, uint32_t columns) -> bool {
        NN_RET_CHECK_EQ(NumDimensions(weights), 2u);
        NN_RET_CHECK_EQ(SizeOfDimension(weights, 0), outputSize);
        NN_RET_CHECK_EQ(SizeOfDimension(weights, 1), columns);
        NN_RET_CHECK_EQ(weights->scale, weightsScale);
        NN_RET_CHECK_EQ(weights->zeroPoint, weightsZeroPoint);
        return true;
    };

    auto inputToForgetWeights = GetInput(operation, operands, kInputToForgetWeightsTensor);
    auto inputToCellWeights = GetInput(operation, operands, kInputToCellWeightsTensor);
    auto inputToOutputWeights = GetInput(operation, operands, kInputToOutputWeightsTensor);
    NN_RET_CHECK(checkWeightsShape(inputToInputWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToForgetWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToCellWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToOutputWeights, inputSize));

    auto recurrentToInputWeights = GetInput(operation, operands, kRecurrentToInputWeightsTensor);
    auto recurrentToForgetWeights = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    auto recurrentToCellWeights = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    auto recurrentToOutputWeights = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK(checkWeightsShape(recurrentToInputWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToForgetWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToCellWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToOutputWeights, outputSize));

    auto inputGateBias = GetInput(operation, operands, kInputGateBiasTensor);
    const float biasScale = inputGateBias->scale;
    NN_RET_CHECK_EQ(biasScale, weightsScale / 128.0);
    const float biasZeroPoint = inputGateBias->zeroPoint;
    NN_RET_CHECK_EQ(biasZeroPoint, 0);

    auto checkBiasShape = [&](const RunTimeOperandInfo* bias) -> bool {
        NN_RET_CHECK_EQ(NumDimensions(bias), 1u);
        NN_RET_CHECK_EQ(SizeOfDimension(bias, 0), outputSize);
        NN_RET_CHECK_EQ(bias->scale, biasScale);
        NN_RET_CHECK_EQ(bias->zeroPoint, biasZeroPoint);
        return true;
    };

    auto forgetGateBias = GetInput(operation, operands, kForgetGateBiasTensor);
    auto cellGateBias = GetInput(operation, operands, kCellGateBiasTensor);
    auto outputGateBias = GetInput(operation, operands, kOutputGateBiasTensor);
    NN_RET_CHECK(checkBiasShape(inputGateBias));
    NN_RET_CHECK(checkBiasShape(forgetGateBias));
    NN_RET_CHECK(checkBiasShape(cellGateBias));
    NN_RET_CHECK(checkBiasShape(outputGateBias));

    auto prevCellState = GetInput(operation, operands, kPrevCellStateTensor);
    NN_CHECK_EQ(NumDimensions(prevCellState), 2u);
    NN_CHECK_EQ(SizeOfDimension(prevCellState, 0), numBatches);
    NN_CHECK_EQ(SizeOfDimension(prevCellState, 1), outputSize);
    NN_CHECK_EQ(prevCellState->zeroPoint, 0);
    // Cell state range for quantized LSTM is a function of StateIntegerBits and
    // can be calculated as:
    // [-2^StateIntegerBits, 2^StateIntegerBits * 32767/32768].
    // Therefore, for a fixed StateIntegerBits parameter, cell state scale is
    // equal to 2^StateIntegerBits * 2^(-15) = 2^(StateIntegerBits - 15) and
    // therefore:
    // StateIntegerBits = log2(cell state scale) + 15
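    // For example, a cell state scale of 2^-11 gives stateScaleLog2Rounded == -11 and
    // StateIntegerBits == 15 - 11 == 4, which is the only configuration accepted below.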
    int stateScaleLog2Rounded;
    NN_CHECK(tflite::CheckedLog2(prevCellState->scale, &stateScaleLog2Rounded));
    const int stateIntegerBits = 15 + stateScaleLog2Rounded;
    // We only support StateIntegerBits == 4
    NN_CHECK(stateIntegerBits == 4);

    *cellStateOutShape = prevCellState->shape();
    *outputShape = prevOutput->shape();
    return true;
}

// The function concatenates 8 input weight matrices into one. The resulting matrix
// has a shape [4 * outputSize, outputSize + inputSize]. The matrix is
// constructed as follows:
// +-----------------------------------+
// | recurrentToInput  | inputToInput  |
// |-------------------+---------------|
// | recurrentToCell   | inputToCell   |
// |-------------------+---------------|
// | recurrentToForget | inputToForget |
// |-------------------+---------------|
// | recurrentToOutput | inputToOutput |
// +-----------------------------------+
void QuantizedLSTMCell::concatenateWeights(const std::vector<uint32_t>& weightsDims,
                                           uint8_t* weights) {
    const int outputSize = SizeOfDimension(inputToInputWeights_, 0);

    assignWeightsSubmatrix(inputToInputWeights_, 0 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToCellWeights_, 1 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToForgetWeights_, 2 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToOutputWeights_, 3 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToInputWeights_, 0 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToCellWeights_, 1 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToForgetWeights_, 2 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToOutputWeights_, 3 * outputSize, 0, weightsDims, weights);
}

// The function concatenates four bias vectors of shape [outputSize] into one
// vector of shape [4 * outputSize].
void QuantizedLSTMCell::concatenateBiases(uint32_t outputSize, int32_t* bias) {
    memcpy(bias + 0 * outputSize, GetBuffer<int32_t>(inputGateBias_), sizeof(int32_t) * outputSize);
    memcpy(bias + 1 * outputSize, GetBuffer<int32_t>(cellGateBias_), sizeof(int32_t) * outputSize);
    memcpy(bias + 2 * outputSize, GetBuffer<int32_t>(forgetGateBias_),
           sizeof(int32_t) * outputSize);
    memcpy(bias + 3 * outputSize, GetBuffer<int32_t>(outputGateBias_),
           sizeof(int32_t) * outputSize);
}

bool QuantizedLSTMCell::eval() {
    NNTRACE_COMP("QuantizedLSTM::eval");

    Shape weightsShape;
    weightsShape.dimensions = {4 * SizeOfDimension(prevOutput_, 1),
                               SizeOfDimension(input_, 1) + SizeOfDimension(prevOutput_, 1)};
    std::vector<uint8_t> weights(getNumberOfElements(weightsShape));
    concatenateWeights(weightsShape.dimensions, weights.data());

    Shape biasShape;
    biasShape.dimensions = {getSizeOfDimension(weightsShape, 0)};
    std::vector<int32_t> bias(getNumberOfElements(biasShape));
    concatenateBiases(SizeOfDimension(prevOutput_, 1), bias.data());

    Shape concatTempShape;
    concatTempShape.dimensions = {SizeOfDimension(input_, 0), getSizeOfDimension(weightsShape, 1)};

    Shape activationTempShape;
    activationTempShape.dimensions = {SizeOfDimension(input_, 0),
                                      getSizeOfDimension(weightsShape, 0)};

    std::vector<uint8_t> concatTemp(getNumberOfElements(concatTempShape));
    std::vector<int16_t> activationTemp(getNumberOfElements(activationTempShape));

    // From https://arxiv.org/pdf/1712.05877, for a fully-connected layer, the
    // accumulator multiplier is equal to:
    // (input scale) * (weights scale) / (fully-connected output scale)
    // In our case the fully-connected output scale is fixed and equal to
    // 2^(-12) (see the LSTMCell definition in TF Lite for more details on that).
    // But the bias scale is set to (input scale) * (weights scale) (also from the
    // paper), so we can multiply it by the inverse of the fc-output scale to get
    // the multiplier value:
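    // For example, with the input scale of 1/128 enforced in prepare() and a weights
    // scale w, the bias scale is w / 128 and realAccumMultiplier == 4096 * w / 128 == 32 * w.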
    double realAccumMultiplier = 4096 * inputGateBias_->scale;
    int32_t accumMultiplier;
    int accumShift;
    tflite::QuantizeMultiplier(realAccumMultiplier, &accumMultiplier, &accumShift);
    quantizedLstmStep<4>(
            // Inputs.
            GetBuffer<const uint8_t>(input_), convertShapeToDims(input_->shape()),
            GetBuffer<const uint8_t>(prevOutput_), convertShapeToDims(prevOutput_->shape()),
            weights.data(), convertShapeToDims(weightsShape), bias.data(),
            convertShapeToDims(biasShape), GetBuffer<const int16_t>(prevCellState_),
            convertShapeToDims(prevCellState_->shape()),
            // Outputs.
            GetBuffer<int16_t>(cellStateOut_), convertShapeToDims(cellStateOut_->shape()),
            GetBuffer<uint8_t>(output_), convertShapeToDims(output_->shape()), concatTemp.data(),
            convertShapeToDims(concatTempShape), activationTemp.data(),
            convertShapeToDims(activationTempShape), inputToInputWeights_->zeroPoint,
            accumMultiplier, accumShift);
    return true;
}

}  // namespace nn
}  // namespace android