1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_LSTM_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_LSTM_H
19 
20 #include <algorithm>
21 #include <cmath>
22 #include <vector>
23 
24 #include "ActivationFunctor.h"
25 #include "OperationsValidationUtils.h"
26 #include "nnapi/Types.h"
27 
28 namespace android {
29 namespace nn {
30 
31 struct LSTMParams {
32     ActivationFn activation;
33     float cell_clip;
34     float proj_clip;
35     bool use_cifg;
36     bool use_peephole;
37     bool use_layer_norm;
38     bool use_projection_weight;
39     bool use_projection_bias;
40     bool merge_outputs;
41     bool time_major;
42     bool output_state;
43 };
44 
45 struct RunTimeOperandInfo;
46 struct Shape;
47 
48 class LSTMCell {
49    public:
50     LSTMCell(const Operation& operation, RunTimeOperandInfo* operands);
51 
52     bool Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* scratchShape,
53                  Shape* outputStateShape, Shape* cellStateShape, Shape* outputShape);
54     bool Eval();
55 
56     // Input Tensors of size {n_batch, n_input}
57     static constexpr int kInputTensor = 0;
58 
59     // Input weight tensors of size: {n_cell, n_input}
60     static constexpr int kInputToInputWeightsTensor = 1;  // Optional
61     static constexpr int kInputToForgetWeightsTensor = 2;
62     static constexpr int kInputToCellWeightsTensor = 3;
63     static constexpr int kInputToOutputWeightsTensor = 4;
64 
65     // Recurrent weight tensors of size {n_cell, n_output}
66     static constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
67     static constexpr int kRecurrentToForgetWeightsTensor = 6;
68     static constexpr int kRecurrentToCellWeightsTensor = 7;
69     static constexpr int kRecurrentToOutputWeightsTensor = 8;
70 
71     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
72     static constexpr int kCellToInputWeightsTensor = 9;    // Optional
73     static constexpr int kCellToForgetWeightsTensor = 10;  // Optional
74     static constexpr int kCellToOutputWeightsTensor = 11;  // Optional
75 
76     // Gates bias tensors of size {n_cell}
77     static constexpr int kInputGateBiasTensor = 12;  // Optional
78     static constexpr int kForgetGateBiasTensor = 13;
79     static constexpr int kCellGateBiasTensor = 14;
80     static constexpr int kOutputGateBiasTensor = 15;
81 
82     // Projection weight tensor of size {n_output, n_cell}
83     static constexpr int kProjectionWeightsTensor = 16;  // Optional
84     // Projection bias tensor of size {n_output}
85     static constexpr int kProjectionBiasTensor = 17;  // Optional
86 
87     static constexpr int kOutputStateInTensor = 18;
88     static constexpr int kCellStateInTensor = 19;
89 
90     static constexpr int kActivationParam = 20;
91     static constexpr int kCellClipParam = 21;
92     static constexpr int kProjClipParam = 22;
93 
94     // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
95     static constexpr int kInputLayerNormWeightsTensor = 23;
96     static constexpr int kForgetLayerNormWeightsTensor = 24;
97     static constexpr int kCellLayerNormWeightsTensor = 25;
98     static constexpr int kOutputLayerNormWeightsTensor = 26;
99 
100     // Output tensors.
101     static constexpr int kScratchBufferTensor = 0;
102     static constexpr int kOutputStateOutTensor = 1;
103     static constexpr int kCellStateOutTensor = 2;
104     static constexpr int kOutputTensor = 3;
105 
106     static bool LSTMEvalFloat32(
107             const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
108             const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
109             const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
110             const Shape& input_to_output_weights_shape,
111             const float* recurrent_to_input_weights_buffer,
112             const float* recurrent_to_forget_weights_buffer,
113             const float* recurrent_to_cell_weights_buffer,
114             const float* recurrent_to_output_weights_buffer,
115             const Shape& recurrent_to_output_weights_shape,
116             const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
117             const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
118             const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
119             const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
120             const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
121             const float* cell_bias_buffer, const float* output_gate_bias_buffer,
122             const float* projection_weights_buffer, const float* projection_bias_buffer,
123             const float* output_state_in_buffer, const float* cell_state_in_buffer,
124             const float* input_layer_norm_weights_buffer,
125             const float* forget_layer_norm_weights_buffer,
126             const float* cell_layer_norm_weights_buffer,
127             const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
128             float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
129             bool timeMajor = true, bool forwardSequence = true);
130 
131     static bool LSTMEvalFloat16(
132             const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
133             const _Float16* input_to_input_weights_buffer,
134             const _Float16* input_to_forget_weights_buffer,
135             const _Float16* input_to_cell_weights_buffer,
136             const _Float16* input_to_output_weights_buffer,
137             const Shape& input_to_output_weights_shape,
138             const _Float16* recurrent_to_input_weights_buffer,
139             const _Float16* recurrent_to_forget_weights_buffer,
140             const _Float16* recurrent_to_cell_weights_buffer,
141             const _Float16* recurrent_to_output_weights_buffer,
142             const Shape& recurrent_to_output_weights_shape,
143             const _Float16* cell_to_input_weights_buffer,
144             const _Float16* cell_to_forget_weights_buffer,
145             const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
146             const _Float16* aux_input_to_input_weights, const _Float16* aux_input_to_forget_weights,
147             const _Float16* aux_input_to_cell_weights, const _Float16* aux_input_to_output_weights,
148             const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer,
149             const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer,
150             const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer,
151             const _Float16* output_state_in_buffer, const _Float16* cell_state_in_buffer,
152             const _Float16* input_layer_norm_weights_buffer,
153             const _Float16* forget_layer_norm_weights_buffer,
154             const _Float16* cell_layer_norm_weights_buffer,
155             const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
156             _Float16* cell_state_out_buffer, _Float16* output_buffer,
157             _Float16* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true);
158 
159     static bool LSTMStep(
160             const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
161             const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
162             const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
163             const Shape& input_to_output_weights_shape,
164             const float* recurrent_to_input_weights_buffer,
165             const float* recurrent_to_forget_weights_buffer,
166             const float* recurrent_to_cell_weights_buffer,
167             const float* recurrent_to_output_weights_buffer,
168             const Shape& recurrent_to_output_weights_shape,
169             const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
170             const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
171             const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
172             const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
173             const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
174             const float* cell_bias_buffer, const float* output_gate_bias_buffer,
175             const float* projection_weights_buffer, const float* projection_bias_buffer,
176             const float* output_state_in_buffer, const float* cell_state_in_buffer,
177             const float* input_layer_norm_weights_buffer,
178             const float* forget_layer_norm_weights_buffer,
179             const float* cell_layer_norm_weights_buffer,
180             const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
181             float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer);
182 
183     static bool CheckInputTensorDimensions(
184             const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
185             const RunTimeOperandInfo* input_to_forget_weights,
186             const RunTimeOperandInfo* input_to_cell_weights,
187             const RunTimeOperandInfo* input_to_output_weights,
188             const RunTimeOperandInfo* recurrent_to_input_weights,
189             const RunTimeOperandInfo* recurrent_to_forget_weights,
190             const RunTimeOperandInfo* recurrent_to_cell_weights,
191             const RunTimeOperandInfo* recurrent_to_output_weights,
192             const RunTimeOperandInfo* cell_to_input_weights,
193             const RunTimeOperandInfo* cell_to_forget_weights,
194             const RunTimeOperandInfo* cell_to_output_weights,
195             const RunTimeOperandInfo* input_gate_bias, const RunTimeOperandInfo* forget_gate_bias,
196             const RunTimeOperandInfo* cell_bias, const RunTimeOperandInfo* output_gate_bias,
197             const RunTimeOperandInfo* projection_weights, const RunTimeOperandInfo* projection_bias,
198             const RunTimeOperandInfo* input_layer_norm_weights,
199             const RunTimeOperandInfo* forget_layer_norm_weights,
200             const RunTimeOperandInfo* cell_layer_norm_weights,
201             const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input,
202             uint32_t n_output, uint32_t n_cell, LSTMParams* params);
203 
204    private:
205     LSTMParams params_;
206     const RunTimeOperandInfo* input_;
207 
208     const RunTimeOperandInfo* input_to_input_weights_;
209     const RunTimeOperandInfo* input_to_forget_weights_;
210     const RunTimeOperandInfo* input_to_cell_weights_;
211     const RunTimeOperandInfo* input_to_output_weights_;
212 
213     const RunTimeOperandInfo* recurrent_to_input_weights_;
214     const RunTimeOperandInfo* recurrent_to_forget_weights_;
215     const RunTimeOperandInfo* recurrent_to_cell_weights_;
216     const RunTimeOperandInfo* recurrent_to_output_weights_;
217 
218     const RunTimeOperandInfo* cell_to_input_weights_;
219     const RunTimeOperandInfo* cell_to_forget_weights_;
220     const RunTimeOperandInfo* cell_to_output_weights_;
221 
222     const RunTimeOperandInfo* input_gate_bias_;
223     const RunTimeOperandInfo* forget_gate_bias_;
224     const RunTimeOperandInfo* cell_bias_;
225     const RunTimeOperandInfo* output_gate_bias_;
226 
227     const RunTimeOperandInfo* projection_weights_;
228     const RunTimeOperandInfo* projection_bias_;
229 
230     const RunTimeOperandInfo* output_state_in_;
231     const RunTimeOperandInfo* cell_state_in_;
232 
233     const RunTimeOperandInfo* input_layer_norm_weights_;
234     const RunTimeOperandInfo* forget_layer_norm_weights_;
235     const RunTimeOperandInfo* cell_layer_norm_weights_;
236     const RunTimeOperandInfo* output_layer_norm_weights_;
237 
238     RunTimeOperandInfo* output_state_out_;
239     RunTimeOperandInfo* cell_state_out_;
240     RunTimeOperandInfo* output_;
241 
242     RunTimeOperandInfo* scratch_buffer_;
243 };
244 
245 }  // namespace nn
246 }  // namespace android
247 
248 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_LSTM_H
249