1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "Fill.h"
20 
21 #include "OperationResolver.h"
22 #include "OperationsExecutionUtils.h"
23 
24 namespace android {
25 namespace nn {
26 namespace fill_op {
27 namespace {
28 
29 template <typename T>
executeTyped(IOperationExecutionContext * context)30 bool executeTyped(IOperationExecutionContext* context) {
31     T* output = context->getOutputBuffer<T>(kOutputTensor);
32     const int numElements = getNumberOfElements(context->getOutputShape(kOutputTensor));
33     const T value = context->getInputValue<T>(kValueScalar);
34     for (int i = 0; i < numElements; ++i) {
35         output[i] = value;
36     }
37     return true;
38 }
39 
40 }  // namespace
41 
prepare(IOperationExecutionContext * context)42 bool prepare(IOperationExecutionContext* context) {
43     Shape dimsShape = context->getInputShape(kDimsTensor);
44     NN_RET_CHECK_EQ(getNumberOfDimensions(dimsShape), 1u);
45 
46     Shape outputShape = context->getOutputShape(kOutputTensor);
47     outputShape.dimensions.resize(dimsShape.dimensions[0]);
48     const int32_t* dims = context->getInputBuffer<int32_t>(kDimsTensor);
49     for (uint32_t i = 0; i < dimsShape.dimensions[0]; ++i) {
50         outputShape.dimensions[i] = dims[i];
51     }
52     return context->setOutputShape(kOutputTensor, outputShape);
53 }
54 
execute(IOperationExecutionContext * context)55 bool execute(IOperationExecutionContext* context) {
56     switch (context->getInputType(kValueScalar)) {
57         case OperandType::FLOAT16:
58             return executeTyped<_Float16>(context);
59         case OperandType::FLOAT32:
60             return executeTyped<float>(context);
61         case OperandType::INT32:
62             return executeTyped<int32_t>(context);
63         default:
64             NN_RET_CHECK_FAIL() << "Unsupported value type for fill op.";
65     }
66 }
67 
68 }  // namespace fill_op
69 
// Registers the FILL operation: fill_op::prepare derives the output shape from the dims
// tensor, and fill_op::execute writes the scalar value into every output element.
// NOTE(review): the DEFAULT_VALIDATION variant presumably plugs in a generic/default
// validation step instead of an op-specific validate function — confirm against
// OperationResolver.h.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(FILL, fill_op::prepare, fill_op::execute);
71 
72 }  // namespace nn
73 }  // namespace android
74