/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the STRIDED_SLICE operation.

#define LOG_TAG "Operations"

#include "StridedSlice.h"

#include <vector>

#include "OperationResolver.h"
#include "Operations.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#include <tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h>
#pragma clang diagnostic pop

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace strided_slice {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

template <typename T>
bool compute(const T* inputData, const Shape& inputShape, const int32_t* beginData,
             const int32_t* endData, const int32_t* stridesData, int32_t beginMask, int32_t endMask,
             int32_t shrinkAxisMask, T* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("stridedSlice");
    // This Op only supports 1-4D cases and since we use the reference 4D
    // implementation, the 1-3D tensors are mapped to 4D.
    const int kMaxDim = 4;

    std::vector<int> starts;
    std::vector<int> stops;
    std::vector<int> strides;

    int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(inputShape));
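    // The legacy TFLite reference kernel expects begin/end/stride values in reversed
    // (innermost-first) dimension order, so copy them from the last dimension to the first.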
    for (int32_t idx = numInputDims - 1; idx >= 0; --idx) {
        starts.emplace_back(beginData[idx]);
        stops.emplace_back(endData[idx]);
        strides.emplace_back(stridesData[idx]);
    }

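    // Pad the remaining dimensions up to 4D; the padded dimensions have size 1, so a
    // full [0, 1) slice with stride 1 leaves them untouched.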
    for (int i = numInputDims; i < kMaxDim; i++) {
        starts.emplace_back(0);
        stops.emplace_back(1);
        strides.emplace_back(1);
    }

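    // Reverse the mask bits to match the reversed dimension order used above.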
    beginMask = ReverseMaskBits(beginMask, numInputDims);
    endMask = ReverseMaskBits(endMask, numInputDims);
    shrinkAxisMask = ReverseMaskBits(shrinkAxisMask, numInputDims);

    tflite::reference_ops::StridedSlice(inputData, convertShapeToDims(inputShape), beginMask,
                                        endMask, shrinkAxisMask, starts, stops, strides, outputData,
                                        convertShapeToDims(outputShape));

    return true;
}

template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    return compute<T>(
            context->getInputBuffer<T>(kInputTensor), context->getInputShape(kInputTensor),
            context->getInputBuffer<int32_t>(kBeginTensor),
            context->getInputBuffer<int32_t>(kEndTensor),
            context->getInputBuffer<int32_t>(kStridesTensor),
            context->getInputValue<int32_t>(kBeginMask), context->getInputValue<int32_t>(kEndMask),
            context->getInputValue<int32_t>(kShrinkAxisMask),
            context->getOutputBuffer<T>(kOutputTensor), context->getOutputShape(kOutputTensor));
}

}  // namespace

bool prepare(IOperationExecutionContext* context) {
    // StridedSlice op only supports 1D-4D input arrays.
    const Shape& inputShape = context->getInputShape(kInputTensor);
    uint32_t numInputDims = getNumberOfDimensions(inputShape);
    NN_OPS_CHECK(numInputDims <= 4);

    const Shape& beginShape = context->getInputShape(kBeginTensor);
    const Shape& endShape = context->getInputShape(kEndTensor);
    const Shape& stridesShape = context->getInputShape(kStridesTensor);

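    // begin, end, and strides must be 1-D INT32 tensors with one element per input dimension.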
    NN_OPS_CHECK(getNumberOfDimensions(beginShape) == 1);
    NN_OPS_CHECK(getNumberOfDimensions(endShape) == 1);
    NN_OPS_CHECK(getNumberOfDimensions(stridesShape) == 1);

    NN_OPS_CHECK(getSizeOfDimension(beginShape, 0) == numInputDims);
    NN_OPS_CHECK(getSizeOfDimension(endShape, 0) == numInputDims);
    NN_OPS_CHECK(getSizeOfDimension(stridesShape, 0) == numInputDims);

    NN_OPS_CHECK(beginShape.type == OperandType::TENSOR_INT32);
    NN_OPS_CHECK(endShape.type == OperandType::TENSOR_INT32);
    NN_OPS_CHECK(stridesShape.type == OperandType::TENSOR_INT32);

    const int32_t* beginData = context->getInputBuffer<int32_t>(kBeginTensor);
    const int32_t* endData = context->getInputBuffer<int32_t>(kEndTensor);
    const int32_t* stridesData = context->getInputBuffer<int32_t>(kStridesTensor);

    const int32_t beginMask = context->getInputValue<int32_t>(kBeginMask);
    const int32_t endMask = context->getInputValue<int32_t>(kEndMask);
    const int32_t shrinkAxisMask = context->getInputValue<int32_t>(kShrinkAxisMask);

    // Determine size of output tensor and map indices
    std::vector<uint32_t> outDims;
    for (int32_t idx = 0; idx < static_cast<int32_t>(numInputDims); idx++) {
        int32_t dim = static_cast<int32_t>(getSizeOfDimension(inputShape, idx));
        int32_t stride = stridesData[idx];
        // stride value has to be non-zero
        NN_OPS_CHECK(stride != 0);
        bool positiveStride = stride > 0;

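        // If the corresponding mask bit is set, slice from the beginning (or to the end) of the
        // dimension in the direction of the stride; otherwise map the provided (possibly
        // negative) index into a valid position for this dimension.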
        int32_t begin = beginMask & (1 << idx) ? positiveStride ? 0 : dim - 1
                                               : ClampedIndex(beginData[idx], dim, positiveStride);
        int32_t end = endMask & (1 << idx) ? positiveStride ? dim : -1
                                           : ClampedIndex(endData[idx], dim, positiveStride);

        // This is valid for both positive and negative strides
        int32_t outDim = ceil((end - begin) / static_cast<float>(stride));
        outDim = outDim < 0 ? 0 : static_cast<uint32_t>(outDim);
        if (!(shrinkAxisMask & (1 << idx))) {
            outDims.push_back(outDim);
        } else {
            // Only positive stride is allowed on non-range indexing (i.e. shrinkMask is set).
            NN_RET_CHECK_GT(stride, 0) << "index = " << idx;
            NN_RET_CHECK_EQ(outDim, 1) << "index = " << idx;
        }
    }

    // Handle the case when all dimensions are removed
    if (outDims.empty()) {
        outDims.push_back(1);
    }

    Shape outputShape = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK(SetShape(inputShape, &outputShape));
    outputShape.dimensions = outDims;
    return context->setOutputShape(kOutputTensor, outputShape);
}

bool execute(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return executeTyped<_Float16>(context);
        case OperandType::TENSOR_FLOAT32:
            return executeTyped<float>(context);
        case OperandType::TENSOR_QUANT8_ASYMM:
            return executeTyped<uint8_t>(context);
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return executeTyped<int8_t>(context);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for STRIDED_SLICE op.";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace strided_slice

NN_REGISTER_OPERATION_DEFAULT_VALIDATION(STRIDED_SLICE, strided_slice::prepare,
                                         strided_slice::execute);

}  // namespace nn
}  // namespace android