1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "Operations" 18 19 #include "Fill.h" 20 21 #include "OperationResolver.h" 22 #include "OperationsExecutionUtils.h" 23 24 namespace android { 25 namespace nn { 26 namespace fill_op { 27 namespace { 28 29 template <typename T> executeTyped(IOperationExecutionContext * context)30bool executeTyped(IOperationExecutionContext* context) { 31 T* output = context->getOutputBuffer<T>(kOutputTensor); 32 const int numElements = getNumberOfElements(context->getOutputShape(kOutputTensor)); 33 const T value = context->getInputValue<T>(kValueScalar); 34 for (int i = 0; i < numElements; ++i) { 35 output[i] = value; 36 } 37 return true; 38 } 39 40 } // namespace 41 prepare(IOperationExecutionContext * context)42bool prepare(IOperationExecutionContext* context) { 43 Shape dimsShape = context->getInputShape(kDimsTensor); 44 NN_RET_CHECK_EQ(getNumberOfDimensions(dimsShape), 1u); 45 46 Shape outputShape = context->getOutputShape(kOutputTensor); 47 outputShape.dimensions.resize(dimsShape.dimensions[0]); 48 const int32_t* dims = context->getInputBuffer<int32_t>(kDimsTensor); 49 for (uint32_t i = 0; i < dimsShape.dimensions[0]; ++i) { 50 outputShape.dimensions[i] = dims[i]; 51 } 52 return context->setOutputShape(kOutputTensor, outputShape); 53 } 54 execute(IOperationExecutionContext * context)55bool execute(IOperationExecutionContext* context) { 56 switch (context->getInputType(kValueScalar)) { 57 case OperandType::FLOAT16: 58 return executeTyped<_Float16>(context); 59 case OperandType::FLOAT32: 60 return executeTyped<float>(context); 61 case OperandType::INT32: 62 return executeTyped<int32_t>(context); 63 default: 64 NN_RET_CHECK_FAIL() << "Unsupported value type for fill op."; 65 } 66 } 67 68 } // namespace fill_op 69 70 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(FILL, fill_op::prepare, fill_op::execute); 71 72 } // namespace nn 73 } // namespace android 74