1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_LSTM_H 18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_LSTM_H 19 20 #include <algorithm> 21 #include <cmath> 22 #include <vector> 23 24 #include "ActivationFunctor.h" 25 #include "OperationsValidationUtils.h" 26 #include "nnapi/Types.h" 27 28 namespace android { 29 namespace nn { 30 31 struct LSTMParams { 32 ActivationFn activation; 33 float cell_clip; 34 float proj_clip; 35 bool use_cifg; 36 bool use_peephole; 37 bool use_layer_norm; 38 bool use_projection_weight; 39 bool use_projection_bias; 40 bool merge_outputs; 41 bool time_major; 42 bool output_state; 43 }; 44 45 struct RunTimeOperandInfo; 46 struct Shape; 47 48 class LSTMCell { 49 public: 50 LSTMCell(const Operation& operation, RunTimeOperandInfo* operands); 51 52 bool Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* scratchShape, 53 Shape* outputStateShape, Shape* cellStateShape, Shape* outputShape); 54 bool Eval(); 55 56 // Input Tensors of size {n_batch, n_input} 57 static constexpr int kInputTensor = 0; 58 59 // Input weight tensors of size: {n_cell, n_input} 60 static constexpr int kInputToInputWeightsTensor = 1; // Optional 61 static constexpr int kInputToForgetWeightsTensor = 2; 62 static constexpr int kInputToCellWeightsTensor = 3; 63 static constexpr int kInputToOutputWeightsTensor = 4; 64 65 // Recurrent weight tensors of size {n_cell, n_output} 66 static constexpr int kRecurrentToInputWeightsTensor = 5; // Optional 67 static constexpr int kRecurrentToForgetWeightsTensor = 6; 68 static constexpr int kRecurrentToCellWeightsTensor = 7; 69 static constexpr int kRecurrentToOutputWeightsTensor = 8; 70 71 // Peephole weights tensors of size {n_cell}, representing a diagonal matrix. 72 static constexpr int kCellToInputWeightsTensor = 9; // Optional 73 static constexpr int kCellToForgetWeightsTensor = 10; // Optional 74 static constexpr int kCellToOutputWeightsTensor = 11; // Optional 75 76 // Gates bias tensors of size {n_cell} 77 static constexpr int kInputGateBiasTensor = 12; // Optional 78 static constexpr int kForgetGateBiasTensor = 13; 79 static constexpr int kCellGateBiasTensor = 14; 80 static constexpr int kOutputGateBiasTensor = 15; 81 82 // Projection weight tensor of size {n_output, n_cell} 83 static constexpr int kProjectionWeightsTensor = 16; // Optional 84 // Projection bias tensor of size {n_output} 85 static constexpr int kProjectionBiasTensor = 17; // Optional 86 87 static constexpr int kOutputStateInTensor = 18; 88 static constexpr int kCellStateInTensor = 19; 89 90 static constexpr int kActivationParam = 20; 91 static constexpr int kCellClipParam = 21; 92 static constexpr int kProjClipParam = 22; 93 94 // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix. 95 static constexpr int kInputLayerNormWeightsTensor = 23; 96 static constexpr int kForgetLayerNormWeightsTensor = 24; 97 static constexpr int kCellLayerNormWeightsTensor = 25; 98 static constexpr int kOutputLayerNormWeightsTensor = 26; 99 100 // Output tensors. 101 static constexpr int kScratchBufferTensor = 0; 102 static constexpr int kOutputStateOutTensor = 1; 103 static constexpr int kCellStateOutTensor = 2; 104 static constexpr int kOutputTensor = 3; 105 106 static bool LSTMEvalFloat32( 107 const LSTMParams& params, const float* input_buffer, const Shape& input_shape, 108 const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer, 109 const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer, 110 const Shape& input_to_output_weights_shape, 111 const float* recurrent_to_input_weights_buffer, 112 const float* recurrent_to_forget_weights_buffer, 113 const float* recurrent_to_cell_weights_buffer, 114 const float* recurrent_to_output_weights_buffer, 115 const Shape& recurrent_to_output_weights_shape, 116 const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer, 117 const float* cell_to_output_weights_buffer, const float* aux_input_buffer, 118 const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights, 119 const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights, 120 const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer, 121 const float* cell_bias_buffer, const float* output_gate_bias_buffer, 122 const float* projection_weights_buffer, const float* projection_bias_buffer, 123 const float* output_state_in_buffer, const float* cell_state_in_buffer, 124 const float* input_layer_norm_weights_buffer, 125 const float* forget_layer_norm_weights_buffer, 126 const float* cell_layer_norm_weights_buffer, 127 const float* output_layer_norm_weights_buffer, float* output_state_out_buffer, 128 float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer, 129 bool timeMajor = true, bool forwardSequence = true); 130 131 static bool LSTMEvalFloat16( 132 const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape, 133 const _Float16* input_to_input_weights_buffer, 134 const _Float16* input_to_forget_weights_buffer, 135 const _Float16* input_to_cell_weights_buffer, 136 const _Float16* input_to_output_weights_buffer, 137 const Shape& input_to_output_weights_shape, 138 const _Float16* recurrent_to_input_weights_buffer, 139 const _Float16* recurrent_to_forget_weights_buffer, 140 const _Float16* recurrent_to_cell_weights_buffer, 141 const _Float16* recurrent_to_output_weights_buffer, 142 const Shape& recurrent_to_output_weights_shape, 143 const _Float16* cell_to_input_weights_buffer, 144 const _Float16* cell_to_forget_weights_buffer, 145 const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer, 146 const _Float16* aux_input_to_input_weights, const _Float16* aux_input_to_forget_weights, 147 const _Float16* aux_input_to_cell_weights, const _Float16* aux_input_to_output_weights, 148 const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer, 149 const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer, 150 const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer, 151 const _Float16* output_state_in_buffer, const _Float16* cell_state_in_buffer, 152 const _Float16* input_layer_norm_weights_buffer, 153 const _Float16* forget_layer_norm_weights_buffer, 154 const _Float16* cell_layer_norm_weights_buffer, 155 const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer, 156 _Float16* cell_state_out_buffer, _Float16* output_buffer, 157 _Float16* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true); 158 159 static bool LSTMStep( 160 const LSTMParams& params, const float* input_buffer, const Shape& input_shape, 161 const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer, 162 const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer, 163 const Shape& input_to_output_weights_shape, 164 const float* recurrent_to_input_weights_buffer, 165 const float* recurrent_to_forget_weights_buffer, 166 const float* recurrent_to_cell_weights_buffer, 167 const float* recurrent_to_output_weights_buffer, 168 const Shape& recurrent_to_output_weights_shape, 169 const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer, 170 const float* cell_to_output_weights_buffer, const float* aux_input_buffer, 171 const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights, 172 const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights, 173 const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer, 174 const float* cell_bias_buffer, const float* output_gate_bias_buffer, 175 const float* projection_weights_buffer, const float* projection_bias_buffer, 176 const float* output_state_in_buffer, const float* cell_state_in_buffer, 177 const float* input_layer_norm_weights_buffer, 178 const float* forget_layer_norm_weights_buffer, 179 const float* cell_layer_norm_weights_buffer, 180 const float* output_layer_norm_weights_buffer, float* output_state_out_buffer, 181 float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer); 182 183 static bool CheckInputTensorDimensions( 184 const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights, 185 const RunTimeOperandInfo* input_to_forget_weights, 186 const RunTimeOperandInfo* input_to_cell_weights, 187 const RunTimeOperandInfo* input_to_output_weights, 188 const RunTimeOperandInfo* recurrent_to_input_weights, 189 const RunTimeOperandInfo* recurrent_to_forget_weights, 190 const RunTimeOperandInfo* recurrent_to_cell_weights, 191 const RunTimeOperandInfo* recurrent_to_output_weights, 192 const RunTimeOperandInfo* cell_to_input_weights, 193 const RunTimeOperandInfo* cell_to_forget_weights, 194 const RunTimeOperandInfo* cell_to_output_weights, 195 const RunTimeOperandInfo* input_gate_bias, const RunTimeOperandInfo* forget_gate_bias, 196 const RunTimeOperandInfo* cell_bias, const RunTimeOperandInfo* output_gate_bias, 197 const RunTimeOperandInfo* projection_weights, const RunTimeOperandInfo* projection_bias, 198 const RunTimeOperandInfo* input_layer_norm_weights, 199 const RunTimeOperandInfo* forget_layer_norm_weights, 200 const RunTimeOperandInfo* cell_layer_norm_weights, 201 const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, 202 uint32_t n_output, uint32_t n_cell, LSTMParams* params); 203 204 private: 205 LSTMParams params_; 206 const RunTimeOperandInfo* input_; 207 208 const RunTimeOperandInfo* input_to_input_weights_; 209 const RunTimeOperandInfo* input_to_forget_weights_; 210 const RunTimeOperandInfo* input_to_cell_weights_; 211 const RunTimeOperandInfo* input_to_output_weights_; 212 213 const RunTimeOperandInfo* recurrent_to_input_weights_; 214 const RunTimeOperandInfo* recurrent_to_forget_weights_; 215 const RunTimeOperandInfo* recurrent_to_cell_weights_; 216 const RunTimeOperandInfo* recurrent_to_output_weights_; 217 218 const RunTimeOperandInfo* cell_to_input_weights_; 219 const RunTimeOperandInfo* cell_to_forget_weights_; 220 const RunTimeOperandInfo* cell_to_output_weights_; 221 222 const RunTimeOperandInfo* input_gate_bias_; 223 const RunTimeOperandInfo* forget_gate_bias_; 224 const RunTimeOperandInfo* cell_bias_; 225 const RunTimeOperandInfo* output_gate_bias_; 226 227 const RunTimeOperandInfo* projection_weights_; 228 const RunTimeOperandInfo* projection_bias_; 229 230 const RunTimeOperandInfo* output_state_in_; 231 const RunTimeOperandInfo* cell_state_in_; 232 233 const RunTimeOperandInfo* input_layer_norm_weights_; 234 const RunTimeOperandInfo* forget_layer_norm_weights_; 235 const RunTimeOperandInfo* cell_layer_norm_weights_; 236 const RunTimeOperandInfo* output_layer_norm_weights_; 237 238 RunTimeOperandInfo* output_state_out_; 239 RunTimeOperandInfo* cell_state_out_; 240 RunTimeOperandInfo* output_; 241 242 RunTimeOperandInfo* scratch_buffer_; 243 }; 244 245 } // namespace nn 246 } // namespace android 247 248 #endif // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_LSTM_H 249