1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "TopK_V2.h"
20
21 #include <algorithm>
22 #include <utility>
23 #include <vector>
24
25 #include "OperationResolver.h"
26 #include "OperationsExecutionUtils.h"
27
28 namespace android {
29 namespace nn {
30 namespace topk_v2 {
31 namespace {
32
33 template <typename T>
evalGeneric(const T * inputData,const Shape & inputShape,const int32_t k,T * valuesData,int32_t * indicesData)34 bool evalGeneric(const T* inputData, const Shape& inputShape, const int32_t k, T* valuesData,
35 int32_t* indicesData) {
36 const int rowSize = inputShape.dimensions.back();
37 const int totalSize = getNumberOfElements(inputShape);
38 std::vector<std::pair<T, int32_t>> values(rowSize);
39 T* curOutputValue = valuesData;
40 int32_t* curOutputIndex = indicesData;
41 for (int rowBegin = 0; rowBegin < totalSize; rowBegin += rowSize) {
42 for (int i = 0; i < rowSize; ++i) {
43 values[i] = std::make_pair(inputData[rowBegin + i], i);
44 }
45 std::nth_element(values.begin(), values.begin() + (rowSize - k), values.end());
46 std::sort(values.begin() + (rowSize - k), values.end());
47 std::reverse(values.begin(), values.end());
48 for (int i = 0; i < k; ++i) {
49 *curOutputValue = values[i].first;
50 *curOutputIndex = values[i].second;
51 curOutputValue++;
52 curOutputIndex++;
53 }
54 }
55 return true;
56 }
57
58 template <typename T>
executeTyped(IOperationExecutionContext * context)59 bool executeTyped(IOperationExecutionContext* context) {
60 return evalGeneric(context->getInputBuffer<T>(kInputTensor),
61 context->getInputShape(kInputTensor),
62 context->getInputValue<int32_t>(kTopKScalar),
63 context->getOutputBuffer<T>(kOutputValuesTensor),
64 context->getOutputBuffer<int32_t>(kOutputIndicesTensor));
65 }
66
67 } // namespace
68
prepare(IOperationExecutionContext * context)69 bool prepare(IOperationExecutionContext* context) {
70 const Shape inputShape = context->getInputShape(kInputTensor);
71 const int32_t k = context->getInputValue<int32_t>(kTopKScalar);
72 NN_RET_CHECK_GT(k, 0);
73 NN_RET_CHECK_LE(static_cast<uint32_t>(k), inputShape.dimensions.back());
74
75 // Copy input shape to ensure that quantization parameters for the output
76 // values are the same as for the input tensor.
77 Shape outputValuesShape = inputShape;
78 outputValuesShape.dimensions.back() = k;
79 Shape outputIndicesShape;
80 outputIndicesShape.type = OperandType::TENSOR_INT32;
81 outputIndicesShape.dimensions = inputShape.dimensions;
82 outputIndicesShape.dimensions.back() = k;
83 return context->setOutputShape(kOutputValuesTensor, outputValuesShape) &&
84 context->setOutputShape(kOutputIndicesTensor, outputIndicesShape);
85 }
86
execute(IOperationExecutionContext * context)87 bool execute(IOperationExecutionContext* context) {
88 const Shape inputShape = context->getInputShape(kInputTensor);
89 switch (inputShape.type) {
90 case OperandType::TENSOR_FLOAT16: {
91 return executeTyped<_Float16>(context);
92 } break;
93 case OperandType::TENSOR_FLOAT32: {
94 return executeTyped<float>(context);
95 } break;
96 case OperandType::TENSOR_INT32: {
97 return executeTyped<int32_t>(context);
98 } break;
99 case OperandType::TENSOR_QUANT8_ASYMM: {
100 return executeTyped<uint8_t>(context);
101 } break;
102 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: {
103 return executeTyped<int8_t>(context);
104 } break;
105 default: {
106 LOG(ERROR) << "Unsupported data type: " << inputShape.type;
107 return false;
108 }
109 }
110 }
111
112 } // namespace topk_v2
113
// Registers the TOPK_V2 operation with topk_v2::prepare and topk_v2::execute
// as its entry points.
// NOTE(review): the macro name suggests validation is delegated to a shared
// default implementation; confirm against the macro definition.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(TOPK_V2, topk_v2::prepare, topk_v2::execute);
115
116 } // namespace nn
117 } // namespace android
118