/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Shape validation (prepare) and CPU execution (execute) for the
// UNIDIRECTIONAL_SEQUENCE_LSTM operation, plus its registration.
//
// NOTE(review): this copy of the file appears to have lost text between angle
// brackets: two #include directives below have no target, and several template
// argument lists are missing (compare the TENSOR_FLOAT16 branch of execute(),
// which still spells out <_Float16>, with the TENSOR_FLOAT32 branch).
// TODO confirm against the upstream source before building.

#define LOG_TAG "Operations"

#include "UnidirectionalSequenceLSTM.h"

#include  // NOTE(review): header name missing in this copy -- TODO confirm

#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "OperationsExecutionUtils.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include  // NOTE(review): header name missing in this copy -- TODO confirm

#include "LSTM.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace unidirectional_sequence_lstm {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

// Returns true when the input at index `tensor` has a non-null buffer. The
// rest of this file uses it as the "optional tensor is present" test.
inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
    return context->getInputBuffer(tensor) != nullptr;
}

// Returns the value of the kTimeMajorParam input.
// NOTE(review): the template argument of getInputValue looks stripped here;
// presumably <bool> given the return type -- TODO confirm.
inline bool isTimeMajor(IOperationExecutionContext* context) {
    return context->getInputValue(kTimeMajorParam);
}

// Builds an LSTMParams from the operation's inputs: the activation and the two
// clip values are read from scalar inputs, and each use_* flag is derived from
// whether one representative optional tensor is present.
// NOTE(review): the template parameter list and the static_cast/getInputValue
// template arguments look stripped; execute() calls this as
// getLSTMParams<_Float16>(context), so a single type parameter is expected --
// TODO confirm.
template
inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
    LSTMParams params;
    params.activation = static_cast(context->getInputValue(kActivationParam));
    params.cell_clip = static_cast(context->getInputValue(kCellClipParam));
    params.proj_clip = static_cast(context->getInputValue(kProjClipParam));
    // CIFG is assumed when the input-to-input weights are absent.
    params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    // Peephole / layer norm are keyed off the output-gate tensors only;
    // prepare() checks that the remaining tensors of each group agree.
    params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
    params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
    params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
    params.use_projection_bias = hasTensor(context, kProjectionBiasTensor);
    return params;
}

}  // namespace

// Validates the operation's inputs and sets the output shapes.
//
// Checks performed, in order: required inputs are not omitted; the input
// tensor has rank 3; every weight, bias, state and layer-norm tensor that is
// present has the expected rank and dimensions; and the optional tensor groups
// (CIFG, peephole, layer norm) are consistently present or absent.
//
// Returns the result of setting the main output shape when all checks pass.
// NOTE(review): the NN_RET_CHECK* macros presumably make this function return
// false on failure -- their definition is not visible in this file.
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    // The input must be rank 3. Time and batch are dimensions 0 and 1 when
    // time-major and swapped otherwise; the last dimension is the input size.
    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3u) << "Invalid input tensor rank: " << inputRank;

    [[maybe_unused]] const uint32_t maxTime =
            getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1);
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

    // numCells is taken from dimension 0 of the input-to-output weights, which
    // must be [numCells, inputSize].
    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    // outputSize is taken from dimension 1 of the recurrent-to-output weights,
    // which must be [numCells, outputSize].
    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    // Input-to-gate weights: each must be [numCells, inputSize]. The input
    // gate's weights are optional and only checked when present.
    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    // Recurrent-to-gate weights: each must be [numCells, outputSize]. Again
    // the input gate's weights are optional.
    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    // Peephole (cell-to-gate) weights: each present tensor must be [numCells].
    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Making sure the peephole weights are there all or none. With CIFG the
    // cell-to-input weights are not required for the "all" case, but they
    // must still be absent for the "none" case.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    // Gate biases: each must be [numCells]. The input gate bias is required
    // without CIFG and must be absent with CIFG.
    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    // Optional projection: weights must be [outputSize, numCells] and the bias
    // must be [outputSize]. The two are checked independently of each other.
    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    // Incoming state tensors: output state is [batchSize, outputSize] and cell
    // state is [batchSize, numCells].
    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);

    // Optional layer-norm weights: each present tensor must be [numCells].
    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }

    // Layer-norm weights must be all present or all absent. With CIFG the
    // input layer-norm weights must be absent and only the other three are
    // compared; without CIFG all four are compared.
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    // The main output keeps the input's first two dimensions and replaces the
    // last one (index 2, valid because the rank was checked above) with
    // outputSize.
    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;

    // When the operation is used with kNumOutputsWithState outputs, the two
    // state outputs must be provided; they are shaped [batchSize, outputSize]
    // and [batchSize, numCells], starting from a copy of the matching input
    // state shape.
    if (context->getNumOutputs() == kNumOutputsWithState) {
        NN_RET_CHECK(!context->isOmittedOutput(kOutputStateOutTensor));
        NN_RET_CHECK(!context->isOmittedOutput(kCellStateOutTensor));

        Shape outputStateOutTensor = context->getInputShape(kOutputStateInTensor);
        outputStateOutTensor.dimensions.resize(2);
        outputStateOutTensor.dimensions[0] = batchSize;
        outputStateOutTensor.dimensions[1] = outputSize;
        NN_RET_CHECK(context->setOutputShape(kOutputStateOutTensor, outputStateOutTensor));

        Shape cellStateOutTensor = context->getInputShape(kCellStateInTensor);
        cellStateOutTensor.dimensions.resize(2);
        cellStateOutTensor.dimensions[0] = batchSize;
        cellStateOutTensor.dimensions[1] = numCells;
        NN_RET_CHECK(context->setOutputShape(kCellStateOutTensor, cellStateOutTensor));
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}

// Runs the LSTM over the input sequence by forwarding every input and output
// buffer to LSTMCell::LSTMEvalFloat32 or LSTMCell::LSTMEvalFloat16, chosen by
// the type of the input tensor.
//
// The final output/cell state is written to the state output tensors when the
// operation has kNumOutputsWithState outputs, and to local temporary buffers
// otherwise.
//
// Returns false (after logging) for an unsupported input type, true otherwise.
// NOTE(review): whatever LSTMEvalFloat32/LSTMEvalFloat16 return is not
// inspected here -- confirm that is intended.
bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    // The scratch buffer holds 3 cell-state-sized regions with CIFG and 4
    // without.
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;
    const bool useStateOutTensors = (context->getNumOutputs() == kNumOutputsWithState);

    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            // NOTE(review): the template arguments in this branch look
            // stripped; presumably <float>, mirroring <_Float16> in the
            // branch below -- TODO confirm.
            // Initialize empty vectors and resize below only if needed
            std::vector outputStateOutBuffer;
            std::vector cellStateOutBuffer;
            float* outputStateOut;
            float* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer(kCellStateOutTensor);
            } else {
                // No state outputs were requested: give the kernel somewhere
                // to write the final state.
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector scratchBuffer(scratchSize);
            // Arguments are positional; absent optional tensors are passed
            // through as whatever getInputBuffer returns for them.
            LSTMCell::LSTMEvalFloat32(
                    getLSTMParams(context),
                    // Input and its shape.
                    context->getInputBuffer(kInputTensor),
                    context->getInputShape(kInputTensor),
                    // Input-to-gate weights and their shape.
                    context->getInputBuffer(kInputToInputWeightsTensor),
                    context->getInputBuffer(kInputToForgetWeightsTensor),
                    context->getInputBuffer(kInputToCellWeightsTensor),
                    context->getInputBuffer(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    // Recurrent-to-gate weights and their shape.
                    context->getInputBuffer(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    // Peephole weights.
                    context->getInputBuffer(kCellToInputWeightsTensor),
                    context->getInputBuffer(kCellToForgetWeightsTensor),
                    context->getInputBuffer(kCellToOutputWeightsTensor),
                    // This operation never supplies an auxiliary input.
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    // Gate biases.
                    context->getInputBuffer(kInputGateBiasTensor),
                    context->getInputBuffer(kForgetGateBiasTensor),
                    context->getInputBuffer(kCellGateBiasTensor),
                    context->getInputBuffer(kOutputGateBiasTensor),
                    // Projection.
                    context->getInputBuffer(kProjectionWeightsTensor),
                    context->getInputBuffer(kProjectionBiasTensor),
                    // Incoming state.
                    context->getInputBuffer(kOutputStateInTensor),
                    context->getInputBuffer(kCellStateInTensor),
                    // Layer-norm weights.
                    context->getInputBuffer(kInputLayerNormWeightsTensor),
                    context->getInputBuffer(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer(kCellLayerNormWeightsTensor),
                    context->getInputBuffer(kOutputLayerNormWeightsTensor),
                    // Outgoing state, main output, scratch space and layout.
                    outputStateOut, cellStateOut, context->getOutputBuffer(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Initialize empty vectors and resize below only if needed
            std::vector<_Float16> outputStateOutBuffer;
            std::vector<_Float16> cellStateOutBuffer;
            _Float16* outputStateOut;
            _Float16* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<_Float16>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<_Float16>(kCellStateOutTensor);
            } else {
                // No state outputs were requested: give the kernel somewhere
                // to write the final state.
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<_Float16> scratchBuffer(scratchSize);
            // Same positional argument order as the float32 call above.
            LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    // Input and its shape.
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    // Input-to-gate weights and their shape.
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    // Recurrent-to-gate weights and their shape.
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    // Peephole weights.
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    // This operation never supplies an auxiliary input.
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    // Gate biases.
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    // Projection.
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    // Incoming state.
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    // Layer-norm weights.
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    // Outgoing state, main output, scratch space and layout.
                    outputStateOut, cellStateOut, context->getOutputBuffer<_Float16>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        default: {
            // NOTE(review): the static_cast target type looks stripped here --
            // TODO confirm.
            LOG(ERROR) << "Unsupported data type: " << static_cast(inputType);
            return false;
        }
    }
    return true;
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace unidirectional_sequence_lstm

// Registers prepare/execute for UNIDIRECTIONAL_SEQUENCE_LSTM. Omitted operands
// are allowed because many of the operation's tensors are optional.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(UNIDIRECTIONAL_SEQUENCE_LSTM,
                                         unidirectional_sequence_lstm::prepare,
                                         unidirectional_sequence_lstm::execute,
                                         .allowOmittedOperand = true);

}  // namespace nn
}  // namespace android