1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "Slice.h"
20 
21 #include <vector>
22 
23 #include "IndexedShapeWrapper.h"
24 #include "OperationResolver.h"
25 
26 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
27 #include "CpuOperationUtils.h"
28 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
29 
30 namespace android {
31 namespace nn {
32 namespace slice {
33 
34 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
35 namespace {
36 
// Element-wise sum: stores a[i] + b[i] into (*res)[i] for every slot of *res.
// The current length of *res decides how many elements are combined; a and b
// must hold at least that many elements (their reads are not bounds-checked).
template <typename T>
void addVectors(const std::vector<T>& a, const std::vector<T>& b, std::vector<T>* res) {
    auto lhs = a.cbegin();
    auto rhs = b.cbegin();
    for (T& sum : *res) {
        sum = *lhs + *rhs;
        ++lhs;
        ++rhs;
    }
}
43 
44 template <typename T>
evalGeneric(const T * inputData,const Shape & inputShape,const int32_t * beginData,const Shape & beginShape,const int32_t *,const Shape &,T * outputData,const Shape & outputShape)45 bool evalGeneric(const T* inputData, const Shape& inputShape, const int32_t* beginData,
46                  const Shape& beginShape, const int32_t* /*sizeData*/, const Shape& /*sizeShape*/,
47                  T* outputData, const Shape& outputShape) {
48     [[maybe_unused]] const int outputSize = getNumberOfElements(outputShape);
49     const IndexedShapeWrapper indexedOutput = IndexedShapeWrapper(outputShape);
50     const IndexedShapeWrapper indexedInput = IndexedShapeWrapper(inputShape);
51     std::vector<uint32_t> outputIndex(getNumberOfDimensions(outputShape), 0);
52     std::vector<uint32_t> beginIndex(getSizeOfDimension(beginShape, 0));
53     std::vector<uint32_t> inputIndex(getNumberOfDimensions(inputShape));
54 
55     for (size_t i = 0; i < beginIndex.size(); ++i) {
56         beginIndex[i] = static_cast<uint32_t>(beginData[i]);
57     }
58 
59     bool lastIndex = false;
60     uint32_t outputOffset;
61     uint32_t inputOffset;
62 
63     do {
64         addVectors(outputIndex, beginIndex, &inputIndex);
65 
66         NN_RET_CHECK(indexedOutput.indexToFlatIndex(outputIndex, &outputOffset));
67         NN_RET_CHECK(indexedInput.indexToFlatIndex(inputIndex, &inputOffset));
68 
69         outputData[outputOffset] = inputData[inputOffset];
70         NN_RET_CHECK(indexedOutput.nextIndexInplace(&outputIndex, &lastIndex));
71     } while (!lastIndex);
72     return true;
73 }
74 
75 }  // namespace
76 
prepare(IOperationExecutionContext * context)77 bool prepare(IOperationExecutionContext* context) {
78     const Shape& inputShape = context->getInputShape(kInputTensor);
79     const uint32_t n_dims = getNumberOfDimensions(inputShape);
80     NN_RET_CHECK(n_dims > 0);
81 
82     const Shape& beginShape = context->getInputShape(kBeginTensor);
83     NN_RET_CHECK_EQ(getNumberOfDimensions(beginShape), 1u);
84     NN_RET_CHECK_EQ(getSizeOfDimension(beginShape, 0), n_dims);
85 
86     const Shape& sizeShape = context->getInputShape(kSizeTensor);
87     NN_RET_CHECK_EQ(getNumberOfDimensions(sizeShape), 1u);
88     NN_RET_CHECK_EQ(getSizeOfDimension(sizeShape, 0), n_dims);
89 
90     const int32_t* beginData = context->getInputBuffer<int32_t>(kBeginTensor);
91     const int32_t* sizeData = context->getInputBuffer<int32_t>(kSizeTensor);
92 
93     Shape outputShape = context->getOutputShape(kOutputTensor);
94     outputShape.dimensions.resize(n_dims);
95     for (uint32_t i = 0; i < n_dims; ++i) {
96         const int32_t sliceBegin = beginData[i];
97         int32_t sliceSize = sizeData[i];
98         if (sliceSize == -1) {
99             sliceSize = getSizeOfDimension(inputShape, i) - sliceBegin;
100         }
101         NN_RET_CHECK_LE(static_cast<uint32_t>(beginData[i]), getSizeOfDimension(inputShape, i));
102         NN_RET_CHECK_GE(sliceSize, 0);
103         NN_RET_CHECK_LE(static_cast<uint32_t>(sliceBegin + sliceSize),
104                         getSizeOfDimension(inputShape, i));
105         outputShape.dimensions[i] = sliceSize;
106     }
107     return context->setOutputShape(kOutputTensor, outputShape);
108 }
109 
execute(IOperationExecutionContext * context)110 bool execute(IOperationExecutionContext* context) {
111     // Bypass execution in the case of zero-sized input.
112     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
113     switch (context->getInputType(kInputTensor)) {
114         case OperandType::TENSOR_FLOAT16:
115             return evalGeneric(context->getInputBuffer<_Float16>(kInputTensor),
116                                context->getInputShape(kInputTensor),
117                                context->getInputBuffer<int32_t>(kBeginTensor),
118                                context->getInputShape(kBeginTensor),
119                                context->getInputBuffer<int32_t>(kSizeTensor),
120                                context->getInputShape(kSizeTensor),
121                                context->getOutputBuffer<_Float16>(kOutputTensor),
122                                context->getOutputShape(kOutputTensor));
123         case OperandType::TENSOR_FLOAT32:
124             return evalGeneric(context->getInputBuffer<float>(kInputTensor),
125                                context->getInputShape(kInputTensor),
126                                context->getInputBuffer<int32_t>(kBeginTensor),
127                                context->getInputShape(kBeginTensor),
128                                context->getInputBuffer<int32_t>(kSizeTensor),
129                                context->getInputShape(kSizeTensor),
130                                context->getOutputBuffer<float>(kOutputTensor),
131                                context->getOutputShape(kOutputTensor));
132         case OperandType::TENSOR_INT32:
133             return evalGeneric(context->getInputBuffer<int32_t>(kInputTensor),
134                                context->getInputShape(kInputTensor),
135                                context->getInputBuffer<int32_t>(kBeginTensor),
136                                context->getInputShape(kBeginTensor),
137                                context->getInputBuffer<int32_t>(kSizeTensor),
138                                context->getInputShape(kSizeTensor),
139                                context->getOutputBuffer<int32_t>(kOutputTensor),
140                                context->getOutputShape(kOutputTensor));
141         case OperandType::TENSOR_QUANT8_ASYMM:
142             return evalGeneric(context->getInputBuffer<uint8_t>(kInputTensor),
143                                context->getInputShape(kInputTensor),
144                                context->getInputBuffer<int32_t>(kBeginTensor),
145                                context->getInputShape(kBeginTensor),
146                                context->getInputBuffer<int32_t>(kSizeTensor),
147                                context->getInputShape(kSizeTensor),
148                                context->getOutputBuffer<uint8_t>(kOutputTensor),
149                                context->getOutputShape(kOutputTensor));
150         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
151             return evalGeneric(context->getInputBuffer<int8_t>(kInputTensor),
152                                context->getInputShape(kInputTensor),
153                                context->getInputBuffer<int32_t>(kBeginTensor),
154                                context->getInputShape(kBeginTensor),
155                                context->getInputBuffer<int32_t>(kSizeTensor),
156                                context->getInputShape(kSizeTensor),
157                                context->getOutputBuffer<int8_t>(kOutputTensor),
158                                context->getOutputShape(kOutputTensor));
159         default:
160             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
161     }
162 }
163 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
164 
165 }  // namespace slice
166 
// Registers SLICE with the default operand validation, wiring slice::prepare
// (operand checks + output shape) and slice::execute (element copy).
// allowZeroSizedInput presumably lets zero-sized inputs reach prepare/execute;
// execute() returns early when the resulting output has no elements.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(SLICE, slice::prepare, slice::execute,
                                         .allowZeroSizedInput = true);
169 
170 }  // namespace nn
171 }  // namespace android
172