1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "GenerateProposals.h"
20 
21 #include <algorithm>
22 #include <cfloat>
23 #include <cmath>
24 #include <functional>
25 #include <numeric>
26 #include <utility>
27 #include <vector>
28 
29 #include "OperationResolver.h"
30 #include "OperationsExecutionUtils.h"
31 #include "Tracing.h"
32 
33 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
34 #include "CpuOperationUtils.h"
35 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
36 
37 namespace android {
38 namespace nn {
39 namespace bbox_ops {
40 
41 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
42 namespace {
43 
// Axis-aligned box described by two opposite corners: (x1, y1) and (x2, y2).
struct BoxEncodingCorner {
    float x1;
    float y1;
    float x2;
    float y2;
};
// Axis-aligned box described by its width/height and its center point (x, y).
struct BoxEncodingCenter {
    float w;
    float h;
    float x;
    float y;
};
// Converts a center/size encoding to a corner encoding.
BoxEncodingCorner toBoxEncodingCorner(const BoxEncodingCenter& ctr) {
    const float halfWidth = ctr.w / 2;
    const float halfHeight = ctr.h / 2;
    BoxEncodingCorner corner{};
    corner.x1 = ctr.x - halfWidth;
    corner.y1 = ctr.y - halfHeight;
    corner.x2 = ctr.x + halfWidth;
    corner.y2 = ctr.y + halfHeight;
    return corner;
}
// Converts a corner encoding to a center/size encoding.
BoxEncodingCenter toBoxEncodingCenter(const BoxEncodingCorner& cnr) {
    BoxEncodingCenter center{};
    center.w = cnr.x2 - cnr.x1;
    center.h = cnr.y2 - cnr.y1;
    center.x = (cnr.x1 + cnr.x2) / 2;
    center.y = (cnr.y1 + cnr.y2) / 2;
    return center;
}
62 
bboxTransformFloat32(const float * roiData,const Shape & roiShape,const float * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape &,const float * imageInfoData,const Shape & imageInfoDataShape,float * outputData,const Shape &)63 inline bool bboxTransformFloat32(const float* roiData, const Shape& roiShape,
64                                  const float* bboxDeltasData, const Shape& bboxDeltasShape,
65                                  const int32_t* batchesData, const Shape& /*batchesShape*/,
66                                  const float* imageInfoData, const Shape& imageInfoDataShape,
67                                  float* outputData, const Shape& /*outputShape*/) {
68     const uint32_t roiLength = 4;
69     const uint32_t imageLength = 2;
70 
71     uint32_t numClasses = getSizeOfDimension(bboxDeltasShape, 1) / roiLength;
72     uint32_t numBatches = getSizeOfDimension(imageInfoDataShape, 0);
73 
74     const float* roiDataEnd = roiData + getNumberOfElements(roiShape);
75     const float* deltas = bboxDeltasData;
76     float* outPtr = outputData;
77     uint32_t roiIndex = 0;
78     for (const float* roiBase = roiData; roiBase < roiDataEnd; roiBase += roiLength, roiIndex++) {
79         uint32_t batchIndex = batchesData[roiIndex];
80         // Check for malformed data
81         // 1. Invalid batch id
82         // 2. Invalid region: x2 < x1 || y2 < y1
83         NN_RET_CHECK_GE(batchIndex, 0u);
84         NN_RET_CHECK_LT(batchIndex, numBatches);
85         NN_RET_CHECK_LE(roiBase[0], roiBase[2]);
86         NN_RET_CHECK_LE(roiBase[1], roiBase[3]);
87 
88         const float* imageInfoBase = imageInfoData + batchIndex * imageLength;
89         float imageHeight = imageInfoBase[0];
90         float imageWidth = imageInfoBase[1];
91         auto roiBefore = toBoxEncodingCenter(
92                 {.x1 = roiBase[0], .y1 = roiBase[1], .x2 = roiBase[2], .y2 = roiBase[3]});
93         for (uint32_t i = 0; i < numClasses; i++) {
94             auto roiAfter = toBoxEncodingCorner({.w = std::exp(deltas[2]) * roiBefore.w,
95                                                  .h = std::exp(deltas[3]) * roiBefore.h,
96                                                  .x = roiBefore.x + deltas[0] * roiBefore.w,
97                                                  .y = roiBefore.y + deltas[1] * roiBefore.h});
98             BoxEncodingCorner cliped = {.x1 = std::min(std::max(roiAfter.x1, 0.0f), imageWidth),
99                                         .y1 = std::min(std::max(roiAfter.y1, 0.0f), imageHeight),
100                                         .x2 = std::min(std::max(roiAfter.x2, 0.0f), imageWidth),
101                                         .y2 = std::min(std::max(roiAfter.y2, 0.0f), imageHeight)};
102             outPtr[0] = cliped.x1;
103             outPtr[1] = cliped.y1;
104             outPtr[2] = cliped.x2;
105             outPtr[3] = cliped.y2;
106             deltas += roiLength;
107             outPtr += roiLength;
108         }
109     }
110     return true;
111 }
112 
bboxTransformFloat16(const _Float16 * roiData,const Shape & roiShape,const _Float16 * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape & batchesShape,const _Float16 * imageInfoData,const Shape & imageInfoDataShape,_Float16 * outputData,const Shape & outputShape)113 inline bool bboxTransformFloat16(const _Float16* roiData, const Shape& roiShape,
114                                  const _Float16* bboxDeltasData, const Shape& bboxDeltasShape,
115                                  const int32_t* batchesData, const Shape& batchesShape,
116                                  const _Float16* imageInfoData, const Shape& imageInfoDataShape,
117                                  _Float16* outputData, const Shape& outputShape) {
118     std::vector<float> roi_float32(getNumberOfElements(roiShape));
119     convertFloat16ToFloat32(roiData, &roi_float32);
120     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
121     convertFloat16ToFloat32(bboxDeltasData, &delta_float32);
122     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoDataShape));
123     convertFloat16ToFloat32(imageInfoData, &imageInfo_float32);
124     std::vector<float> output_float32(getNumberOfElements(outputShape));
125     NN_RET_CHECK(bboxTransformFloat32(roi_float32.data(), roiShape, delta_float32.data(),
126                                       bboxDeltasShape, batchesData, batchesShape,
127                                       imageInfo_float32.data(), imageInfoDataShape,
128                                       output_float32.data(), outputShape));
129     convertFloat32ToFloat16(output_float32, outputData);
130     return true;
131 }
132 
bboxTransformQuant(const uint16_t * roiData,const Shape & roiShape,const uint8_t * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape & batchesShape,const uint16_t * imageInfoData,const Shape & imageInfoDataShape,uint16_t * outputData,const Shape & outputShape)133 inline bool bboxTransformQuant(const uint16_t* roiData, const Shape& roiShape,
134                                const uint8_t* bboxDeltasData, const Shape& bboxDeltasShape,
135                                const int32_t* batchesData, const Shape& batchesShape,
136                                const uint16_t* imageInfoData, const Shape& imageInfoDataShape,
137                                uint16_t* outputData, const Shape& outputShape) {
138     std::vector<float> roi_float32(getNumberOfElements(roiShape));
139     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
140     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
141     convertQuantToFloat32(bboxDeltasData, bboxDeltasShape.scale, bboxDeltasShape.offset,
142                           &delta_float32);
143     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoDataShape));
144     convertQuantToFloat32(imageInfoData, imageInfoDataShape.scale, imageInfoDataShape.offset,
145                           &imageInfo_float32);
146     std::vector<float> output_float32(getNumberOfElements(outputShape));
147     NN_RET_CHECK(bboxTransformFloat32(roi_float32.data(), roiShape, delta_float32.data(),
148                                       bboxDeltasShape, batchesData, batchesShape,
149                                       imageInfo_float32.data(), imageInfoDataShape,
150                                       output_float32.data(), outputShape));
151     convertFloat32ToQuant(output_float32, outputShape.scale, outputShape.offset, outputData);
152     return true;
153 }
154 
bboxTransformQuant(const uint16_t * roiData,const Shape & roiShape,const int8_t * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape & batchesShape,const uint16_t * imageInfoData,const Shape & imageInfoDataShape,uint16_t * outputData,const Shape & outputShape)155 inline bool bboxTransformQuant(const uint16_t* roiData, const Shape& roiShape,
156                                const int8_t* bboxDeltasData, const Shape& bboxDeltasShape,
157                                const int32_t* batchesData, const Shape& batchesShape,
158                                const uint16_t* imageInfoData, const Shape& imageInfoDataShape,
159                                uint16_t* outputData, const Shape& outputShape) {
160     std::vector<float> roi_float32(getNumberOfElements(roiShape));
161     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
162     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
163     convertQuantToFloat32<int8_t>(bboxDeltasData, bboxDeltasShape.scale, bboxDeltasShape.offset,
164                                   &delta_float32);
165     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoDataShape));
166     convertQuantToFloat32(imageInfoData, imageInfoDataShape.scale, imageInfoDataShape.offset,
167                           &imageInfo_float32);
168     std::vector<float> output_float32(getNumberOfElements(outputShape));
169     NN_RET_CHECK(bboxTransformFloat32(roi_float32.data(), roiShape, delta_float32.data(),
170                                       bboxDeltasShape, batchesData, batchesShape,
171                                       imageInfo_float32.data(), imageInfoDataShape,
172                                       output_float32.data(), outputShape));
173     convertFloat32ToQuant(output_float32, outputShape.scale, outputShape.offset, outputData);
174     return true;
175 }
176 
// Returns the intersection-over-union of two axis-aligned boxes, each given as
// (x1, y1, x2, y2).
//
// If the union area is not positive (e.g. both boxes have zero area), returns 0
// instead of evaluating 0/0. Otherwise NaN would flow into the NMS comparisons and
// soft-NMS score kernels.
float getIoUAxisAligned(const float* roi1, const float* roi2) {
    const float area1 = (roi1[2] - roi1[0]) * (roi1[3] - roi1[1]);
    const float area2 = (roi2[2] - roi2[0]) * (roi2[3] - roi2[1]);
    const float x1 = std::max(roi1[0], roi2[0]);
    const float x2 = std::min(roi1[2], roi2[2]);
    const float y1 = std::max(roi1[1], roi2[1]);
    const float y2 = std::min(roi1[3], roi2[3]);
    const float w = std::max(x2 - x1, 0.0f);
    const float h = std::max(y2 - y1, 0.0f);
    const float areaIntersect = w * h;
    const float areaUnion = area1 + area2 - areaIntersect;
    if (areaUnion <= 0.0f) {
        return 0.0f;
    }
    return areaIntersect / areaUnion;
}
191 
192 }  // namespace
193 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
194 
195 namespace axis_aligned_bbox_transform {
196 
197 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
prepare(IOperationExecutionContext * context)198 bool prepare(IOperationExecutionContext* context) {
199     Shape roiShape = context->getInputShape(kRoiTensor);
200     Shape bboxDeltasShape = context->getInputShape(kDeltaTensor);
201     Shape batchesShape = context->getInputShape(kBatchesTensor);
202     Shape imageInfoShape = context->getInputShape(kImageInfoTensor);
203     Shape outputShape = context->getOutputShape(kOutputTensor);
204 
205     NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2u);
206     NN_RET_CHECK_EQ(getNumberOfDimensions(bboxDeltasShape), 2u);
207     NN_RET_CHECK_EQ(getNumberOfDimensions(batchesShape), 1u);
208     NN_RET_CHECK_EQ(getNumberOfDimensions(imageInfoShape), 2u);
209 
210     // Only numRois can be zero.
211     const uint32_t kRoiDim = 4;
212     uint32_t numRois = getSizeOfDimension(roiShape, 0);
213     uint32_t numClasses = getSizeOfDimension(bboxDeltasShape, 1) / kRoiDim;
214     uint32_t numBatches = getSizeOfDimension(imageInfoShape, 0);
215     NN_RET_CHECK_GT(numClasses, 0u);
216     NN_RET_CHECK_GT(numBatches, 0u);
217     NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), kRoiDim);
218     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, 0), numRois);
219     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, 1), kRoiDim * numClasses);
220     NN_RET_CHECK_EQ(getSizeOfDimension(batchesShape, 0), numRois);
221     NN_RET_CHECK_EQ(getSizeOfDimension(imageInfoShape, 1), 2u);
222 
223     if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
224         NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
225         NN_RET_CHECK_EQ(roiShape.offset, 0);
226         NN_RET_CHECK_EQ(imageInfoShape.scale, 0.125f);
227         NN_RET_CHECK_EQ(imageInfoShape.offset, 0);
228     }
229 
230     outputShape.type = roiShape.type;
231     outputShape.dimensions = {numRois, numClasses * kRoiDim};
232     outputShape.scale = 0.f;
233     outputShape.offset = 0;
234     if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
235         outputShape.scale = 0.125f;
236     }
237     NN_RET_CHECK(context->setOutputShape(kOutputTensor, outputShape));
238     return true;
239 }
240 
execute(IOperationExecutionContext * context)241 bool execute(IOperationExecutionContext* context) {
242     NNTRACE_TRANS("axisAlignedBBoxTransform");
243     // Bypass execution in the case of zero-sized input.
244     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
245     switch (context->getInputType(kRoiTensor)) {
246         case OperandType::TENSOR_FLOAT16: {
247             return bboxTransformFloat16(context->getInputBuffer<_Float16>(kRoiTensor),
248                                         context->getInputShape(kRoiTensor),
249                                         context->getInputBuffer<_Float16>(kDeltaTensor),
250                                         context->getInputShape(kDeltaTensor),
251                                         context->getInputBuffer<int32_t>(kBatchesTensor),
252                                         context->getInputShape(kBatchesTensor),
253                                         context->getInputBuffer<_Float16>(kImageInfoTensor),
254                                         context->getInputShape(kImageInfoTensor),
255                                         context->getOutputBuffer<_Float16>(kOutputTensor),
256                                         context->getOutputShape(kOutputTensor));
257         }
258         case OperandType::TENSOR_FLOAT32: {
259             return bboxTransformFloat32(context->getInputBuffer<float>(kRoiTensor),
260                                         context->getInputShape(kRoiTensor),
261                                         context->getInputBuffer<float>(kDeltaTensor),
262                                         context->getInputShape(kDeltaTensor),
263                                         context->getInputBuffer<int32_t>(kBatchesTensor),
264                                         context->getInputShape(kBatchesTensor),
265                                         context->getInputBuffer<float>(kImageInfoTensor),
266                                         context->getInputShape(kImageInfoTensor),
267                                         context->getOutputBuffer<float>(kOutputTensor),
268                                         context->getOutputShape(kOutputTensor));
269         }
270         case OperandType::TENSOR_QUANT16_ASYMM: {
271             if (context->getInputType(kDeltaTensor) == OperandType::TENSOR_QUANT8_ASYMM) {
272                 return bboxTransformQuant(context->getInputBuffer<uint16_t>(kRoiTensor),
273                                           context->getInputShape(kRoiTensor),
274                                           context->getInputBuffer<uint8_t>(kDeltaTensor),
275                                           context->getInputShape(kDeltaTensor),
276                                           context->getInputBuffer<int32_t>(kBatchesTensor),
277                                           context->getInputShape(kBatchesTensor),
278                                           context->getInputBuffer<uint16_t>(kImageInfoTensor),
279                                           context->getInputShape(kImageInfoTensor),
280                                           context->getOutputBuffer<uint16_t>(kOutputTensor),
281                                           context->getOutputShape(kOutputTensor));
282             } else {
283                 return bboxTransformQuant(context->getInputBuffer<uint16_t>(kRoiTensor),
284                                           context->getInputShape(kRoiTensor),
285                                           context->getInputBuffer<int8_t>(kDeltaTensor),
286                                           context->getInputShape(kDeltaTensor),
287                                           context->getInputBuffer<int32_t>(kBatchesTensor),
288                                           context->getInputShape(kBatchesTensor),
289                                           context->getInputBuffer<uint16_t>(kImageInfoTensor),
290                                           context->getInputShape(kImageInfoTensor),
291                                           context->getOutputBuffer<uint16_t>(kOutputTensor),
292                                           context->getOutputShape(kOutputTensor));
293             }
294         }
295         default:
296             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
297     }
298 }
299 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
300 
301 }  // namespace axis_aligned_bbox_transform
302 
303 namespace box_with_nms_limit {
304 
305 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
306 namespace {
307 
308 // TODO(xusongw): Reduce code duplication with hard/soft nms path.
309 
310 // Inplace hard NMS within range [select, select + selectLength).
hardNmsSingleClass(const float * scoresData,float iouThreshold,int32_t maxNumDetections,std::function<const float * (uint32_t)> getRoiBase,uint32_t * select,uint32_t selectLength)311 uint32_t* hardNmsSingleClass(const float* scoresData, float iouThreshold, int32_t maxNumDetections,
312                              std::function<const float*(uint32_t)> getRoiBase, uint32_t* select,
313                              uint32_t selectLength) {
314     uint32_t *selectStart = select, *selectEnd = select + selectLength, numDetections = 0;
315     if (maxNumDetections < 0) {
316         maxNumDetections = selectLength;
317     }
318     while (selectStart < selectEnd && numDetections < static_cast<uint32_t>(maxNumDetections)) {
319         // find max score and swap to the front
320         auto& maxScore = *std::max_element(selectStart, selectEnd,
321                                            [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
322                                                return scoresData[lhs] < scoresData[rhs];
323                                            });
324         std::swap(maxScore, *selectStart);
325 
326         // Calculate IoU of the rest, swap to the end (disgard) if needed.
327         for (uint32_t* i = selectStart + 1; i < selectEnd; i++) {
328             float iou = getIoUAxisAligned(getRoiBase(*i), getRoiBase(*selectStart));
329             if (iou >= iouThreshold) {
330                 std::swap(*i--, *(--selectEnd));
331             }
332         }
333         selectStart++;
334         numDetections++;
335     }
336     return selectStart;
337 }
338 
hardNmsMultiClass(const float * scoresData,uint32_t numClasses,uint32_t numRois,float scoreThreshold,float iouThreshold,int32_t maxNumDetections,int32_t maxNumDetectionsPerClass,std::function<const float * (uint32_t)> getRoiBase,std::vector<uint32_t> * select)339 void hardNmsMultiClass(const float* scoresData, uint32_t numClasses, uint32_t numRois,
340                        float scoreThreshold, float iouThreshold, int32_t maxNumDetections,
341                        int32_t maxNumDetectionsPerClass,
342                        std::function<const float*(uint32_t)> getRoiBase,
343                        std::vector<uint32_t>* select) {
344     // Exclude class 0 (background)
345     for (uint32_t c = 1; c < numClasses; c++) {
346         uint32_t size = select->size();
347         for (uint32_t b = 0; b < numRois; b++) {
348             const uint32_t index = b * numClasses + c;
349             const float score = scoresData[index];
350             if (score > scoreThreshold) {
351                 select->push_back(index);
352             }
353         }
354         uint32_t* selectStart = select->data() + size;
355         uint32_t selectLength = select->size() - size;
356         uint32_t* selectEnd = hardNmsSingleClass(scoresData, iouThreshold, maxNumDetectionsPerClass,
357                                                  getRoiBase, selectStart, selectLength);
358         select->resize(selectEnd - select->data());
359     }
360 
361     // Take top maxNumDetections.
362     std::sort(select->begin(), select->end(),
363               [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
364                   return scoresData[lhs] > scoresData[rhs];
365               });
366     if (maxNumDetections < 0 || select->size() <= static_cast<size_t>(maxNumDetections)) {
367         return;
368     }
369     select->resize(maxNumDetections);
370 }
371 
372 // Inplace soft NMS within range [select, select + selectLength).
373 using SoftNmsKernel = std::function<float(float)>;
softNmsSingleClass(float * scoresData,float scoreThreshold,int32_t maxNumDetections,std::function<const float * (uint32_t)> getRoiBase,SoftNmsKernel kernel,uint32_t * select,uint32_t selectLength)374 uint32_t* softNmsSingleClass(float* scoresData, float scoreThreshold, int32_t maxNumDetections,
375                              std::function<const float*(uint32_t)> getRoiBase, SoftNmsKernel kernel,
376                              uint32_t* select, uint32_t selectLength) {
377     uint32_t *selectStart = select, *selectEnd = select + selectLength, numDetections = 0;
378     if (maxNumDetections < 0) {
379         maxNumDetections = selectLength;
380     }
381     while (selectStart < selectEnd && numDetections < static_cast<uint32_t>(maxNumDetections)) {
382         // find max score and swap to the front
383         auto& maxScore = *std::max_element(selectStart, selectEnd,
384                                            [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
385                                                return scoresData[lhs] < scoresData[rhs];
386                                            });
387         std::swap(maxScore, *selectStart);
388 
389         // Calculate IoU of the rest, swap to the end (disgard) if needed.
390         for (uint32_t* i = selectStart + 1; i < selectEnd; i++) {
391             float iou = getIoUAxisAligned(getRoiBase(*i), getRoiBase(*selectStart));
392             scoresData[*i] *= kernel(iou);
393             if (scoresData[*i] < scoreThreshold) {
394                 std::swap(*i--, *(--selectEnd));
395             }
396         }
397         selectStart++;
398         numDetections++;
399     }
400     return selectStart;
401 }
402 
softNmsMultiClass(float * scoresData,uint32_t numClasses,uint32_t numRois,float scoreThreshold,float nmsScoreThreshold,int32_t maxNumDetections,int32_t maxNumDetectionsPerClass,std::function<const float * (uint32_t)> getRoiBase,SoftNmsKernel kernel,std::vector<uint32_t> * select)403 void softNmsMultiClass(float* scoresData, uint32_t numClasses, uint32_t numRois,
404                        float scoreThreshold, float nmsScoreThreshold, int32_t maxNumDetections,
405                        int32_t maxNumDetectionsPerClass,
406                        std::function<const float*(uint32_t)> getRoiBase, SoftNmsKernel kernel,
407                        std::vector<uint32_t>* select) {
408     // Exclude class 0 (background)
409     for (uint32_t c = 1; c < numClasses; c++) {
410         uint32_t size = select->size();
411         for (uint32_t b = 0; b < numRois; b++) {
412             const uint32_t index = b * numClasses + c;
413             const float score = scoresData[index];
414             if (score > scoreThreshold) {
415                 select->push_back(index);
416             }
417         }
418         uint32_t* selectStart = select->data() + size;
419         uint32_t selectLength = select->size() - size;
420         uint32_t* selectEnd =
421                 softNmsSingleClass(scoresData, nmsScoreThreshold, maxNumDetectionsPerClass,
422                                    getRoiBase, kernel, selectStart, selectLength);
423         select->resize(selectEnd - select->data());
424     }
425 
426     // Take top maxNumDetections.
427     std::sort(select->begin(), select->end(),
428               [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
429                   return scoresData[lhs] > scoresData[rhs];
430               });
431     if (maxNumDetections < 0 || select->size() <= static_cast<size_t>(maxNumDetections)) {
432         return;
433     }
434     select->resize(maxNumDetections);
435 }
436 
boxWithNmsLimitFloat32Compute(float * scoresData,const Shape & scoresShape,const float * roiData,const Shape &,const int32_t * batchesData,const Shape &,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,std::vector<uint32_t> * batchSplitIn,std::vector<uint32_t> * batchSplitOut,std::vector<uint32_t> * selected)437 bool boxWithNmsLimitFloat32Compute(float* scoresData, const Shape& scoresShape,
438                                    const float* roiData, const Shape& /*roiShape*/,
439                                    const int32_t* batchesData, const Shape& /*batchesShape*/,
440                                    float scoreThreshold, int32_t maxNumDetections,
441                                    int32_t softNmsKernel, float iouThreshold, float sigma,
442                                    float nmsScoreThreshold, std::vector<uint32_t>* batchSplitIn,
443                                    std::vector<uint32_t>* batchSplitOut,
444                                    std::vector<uint32_t>* selected) {
445     SoftNmsKernel kernel = nullptr;
446     if (softNmsKernel == 0) {
447         kernel = [&iouThreshold](float iou) { return iou < iouThreshold ? 1.0f : 0.0f; };
448     } else if (softNmsKernel == 1) {
449         kernel = [&iouThreshold](float iou) { return iou < iouThreshold ? 1.0f : 1.0f - iou; };
450     } else if (softNmsKernel == 2) {
451         kernel = [&sigma](float iou) { return std::exp(-1.0f * iou * iou / sigma); };
452     } else {
453         NN_RET_CHECK_FAIL() << "Unsupported soft NMS kernel " << softNmsKernel;
454     }
455 
456     const uint32_t kRoiDim = 4;
457     uint32_t numRois = getSizeOfDimension(scoresShape, 0);
458     uint32_t numClasses = getSizeOfDimension(scoresShape, 1);
459 
460     // We assume boxes of the same batch are grouped together.
461     std::vector<uint32_t> batch;
462     int32_t ind = -1;
463     for (uint32_t i = 0; i < numRois; i++) {
464         if (batchesData[i] == ind) {
465             (batchSplitIn->back())++;
466         } else {
467             ind = batchesData[i];
468             batchSplitIn->push_back(1);
469         }
470     }
471 
472     float* scoresBase = scoresData;
473     const float* roiBase = roiData;
474     selected->clear();
475     for (uint32_t b = 0; b < batchSplitIn->size(); b++) {
476         for (uint32_t i = 0; i < batchSplitIn->at(b); i++) {
477             const float* roi = roiBase + i * kRoiDim;
478             // Check for malformed data: invalid region: x2 < x1 || y2 < y1
479             NN_RET_CHECK_LE(roi[0], roi[2]);
480             NN_RET_CHECK_LE(roi[1], roi[3]);
481         }
482         std::vector<uint32_t> result;
483         softNmsMultiClass(
484                 scoresBase, numClasses, batchSplitIn->at(b), scoreThreshold, nmsScoreThreshold,
485                 maxNumDetections, maxNumDetections,
486                 [&roiBase](uint32_t ind) { return roiBase + ind * kRoiDim; }, kernel, &result);
487         // Sort again by class.
488         std::sort(result.begin(), result.end(),
489                   [&scoresBase, numClasses](const uint32_t& lhs, const uint32_t& rhs) {
490                       uint32_t lhsClass = lhs % numClasses, rhsClass = rhs % numClasses;
491                       return lhsClass == rhsClass ? scoresBase[lhs] > scoresBase[rhs]
492                                                   : lhsClass < rhsClass;
493                   });
494         selected->insert(selected->end(), result.begin(), result.end());
495         batchSplitOut->push_back(result.size());
496         scoresBase += batchSplitIn->at(b) * numClasses;
497         roiBase += batchSplitIn->at(b) * numClasses * kRoiDim;
498     }
499     return true;
500 }
501 
502 template <typename T>
castTo(float val,const Shape &)503 T castTo(float val, const Shape&) {
504     return val;
505 }
506 template <>
castTo(float val,const Shape & shape)507 uint8_t castTo(float val, const Shape& shape) {
508     return saturateCast<uint8_t>(std::round(val / shape.scale + shape.offset));
509 }
510 
511 template <>
castTo(float val,const Shape & shape)512 int8_t castTo(float val, const Shape& shape) {
513     return saturateCast<int8_t>(std::round(val / shape.scale + shape.offset));
514 }
515 
516 template <typename T_Score, typename T_Roi>
boxWithNmsLimitWriteOutput(const std::vector<uint32_t> & selected,const std::vector<uint32_t> & batchSplitIn,const std::vector<uint32_t> & batchSplitOut,const std::vector<float> & scores,IOperationExecutionContext * context)517 bool boxWithNmsLimitWriteOutput(const std::vector<uint32_t>& selected,
518                                 const std::vector<uint32_t>& batchSplitIn,
519                                 const std::vector<uint32_t>& batchSplitOut,
520                                 const std::vector<float>& scores,
521                                 IOperationExecutionContext* context) {
522     const uint32_t kRoiDim = 4;
523     Shape scoresShape = context->getInputShape(kScoreTensor);
524     uint32_t numClasses = getSizeOfDimension(scoresShape, 1);
525 
526     // Set output dimensions.
527     uint32_t numOutRois = selected.size();
528     if (numOutRois == 0) return true;
529     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
530     scoresOutShape.dimensions = {numOutRois};
531     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
532 
533     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
534     roiOutShape.dimensions = {numOutRois, 4};
535     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
536 
537     Shape classesOutShape = context->getOutputShape(kOutputClassTensor);
538     classesOutShape.dimensions = {numOutRois};
539     NN_RET_CHECK(context->setOutputShape(kOutputClassTensor, classesOutShape));
540 
541     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
542     batchesOutShape.dimensions = {numOutRois};
543     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
544 
545     // Write outputs.
546     const float* scoresBase = scores.data();
547     const T_Roi* roiBase = context->getInputBuffer<T_Roi>(kRoiTensor);
548     const int32_t* batchesInPtr = context->getInputBuffer<int32_t>(kBatchesTensor);
549     T_Score* scoresOutPtr = context->getOutputBuffer<T_Score>(kOutputScoreTensor);
550     T_Roi* roiOutPtr = context->getOutputBuffer<T_Roi>(kOutputRoiTensor);
551     int32_t* classesOutPtr = context->getOutputBuffer<int32_t>(kOutputClassTensor);
552     int32_t* batchesOutPtr = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
553     uint32_t i = 0;
554     for (uint32_t b = 0; b < batchSplitOut.size(); b++) {
555         for (uint32_t j = 0; j < batchSplitOut[b]; j++) {
556             uint32_t index = selected[i++];
557             *scoresOutPtr++ = castTo<T_Score>(scoresBase[index], scoresOutShape);
558             memcpy(roiOutPtr, roiBase + index * kRoiDim, kRoiDim * sizeof(T_Roi));
559             roiOutPtr += kRoiDim;
560             *classesOutPtr++ = index % numClasses;
561             *batchesOutPtr++ = *batchesInPtr;
562         }
563         scoresBase += batchSplitIn[b] * numClasses;
564         roiBase += batchSplitIn[b] * numClasses * kRoiDim;
565         batchesInPtr += batchSplitIn[b];
566     }
567     return true;
568 }
569 
boxWithNmsLimitFloat32(const float * scoresData,const Shape & scoresShape,const float * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,float *,Shape,float *,Shape,int32_t *,Shape,int32_t *,const Shape &,IOperationExecutionContext * context)570 bool boxWithNmsLimitFloat32(const float* scoresData, const Shape& scoresShape, const float* roiData,
571                             const Shape& roiShape, const int32_t* batchesData,
572                             const Shape& batchesShape, float scoreThreshold,
573                             int32_t maxNumDetections, int32_t softNmsKernel, float iouThreshold,
574                             float sigma, float nmsScoreThreshold, float* /*scoresOutData*/,
575                             Shape /*scoresOutShape*/, float* /*roiOutData*/, Shape /*roiOutShape*/,
576                             int32_t* /*classesOutData*/, Shape /*classesOutShape*/,
577                             int32_t* /*batchesOutData*/, const Shape& /*batchSplitOutShape*/,
578                             IOperationExecutionContext* context) {
579     NNTRACE_TRANS("boxWithNmsLimit");
580     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
581     for (uint32_t i = 0; i < scores_float32.size(); i++) {
582         scores_float32[i] = scoresData[i];
583     }
584     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
585     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
586             scores_float32.data(), scoresShape, roiData, roiShape, batchesData, batchesShape,
587             scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma, nmsScoreThreshold,
588             &batchSplitIn, &batchSplitOut, &selected));
589     return boxWithNmsLimitWriteOutput<float, float>(selected, batchSplitIn, batchSplitOut,
590                                                     scores_float32, context);
591 }
592 
boxWithNmsLimitFloat16(const _Float16 * scoresData,const Shape & scoresShape,const _Float16 * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,_Float16 scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,_Float16 iouThreshold,_Float16 sigma,_Float16 nmsScoreThreshold,_Float16 *,const Shape &,_Float16 *,const Shape &,int32_t *,const Shape &,int32_t *,const Shape &,IOperationExecutionContext * context)593 bool boxWithNmsLimitFloat16(const _Float16* scoresData, const Shape& scoresShape,
594                             const _Float16* roiData, const Shape& roiShape,
595                             const int32_t* batchesData, const Shape& batchesShape,
596                             _Float16 scoreThreshold, int32_t maxNumDetections,
597                             int32_t softNmsKernel, _Float16 iouThreshold, _Float16 sigma,
598                             _Float16 nmsScoreThreshold, _Float16* /*scoresOutData*/,
599                             const Shape& /*scoresOutShape*/, _Float16* /*roiOutData*/,
600                             const Shape& /*roiOutShape*/, int32_t* /*classesOutData*/,
601                             const Shape& /*classesOutShape*/, int32_t* /*batchesOutData*/,
602                             const Shape& /*batchSplitOutShape*/,
603                             IOperationExecutionContext* context) {
604     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
605     convertFloat16ToFloat32(scoresData, &scores_float32);
606     std::vector<float> roi_float32(getNumberOfElements(roiShape));
607     convertFloat16ToFloat32(roiData, &roi_float32);
608     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
609     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
610             scores_float32.data(), scoresShape, roi_float32.data(), roiShape, batchesData,
611             batchesShape, scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma,
612             nmsScoreThreshold, &batchSplitIn, &batchSplitOut, &selected));
613     return boxWithNmsLimitWriteOutput<_Float16, _Float16>(selected, batchSplitIn, batchSplitOut,
614                                                           scores_float32, context);
615 }
616 
boxWithNmsLimitQuant(const uint8_t * scoresData,const Shape & scoresShape,const uint16_t * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,uint8_t *,const Shape &,uint16_t *,const Shape &,int32_t *,const Shape &,int32_t *,const Shape &,IOperationExecutionContext * context)617 bool boxWithNmsLimitQuant(const uint8_t* scoresData, const Shape& scoresShape,
618                           const uint16_t* roiData, const Shape& roiShape,
619                           const int32_t* batchesData, const Shape& batchesShape,
620                           float scoreThreshold, int32_t maxNumDetections, int32_t softNmsKernel,
621                           float iouThreshold, float sigma, float nmsScoreThreshold,
622                           uint8_t* /*scoresOutData*/, const Shape& /*scoresOutShape*/,
623                           uint16_t* /*roiOutData*/, const Shape& /*roiOutShape*/,
624                           int32_t* /*classesOutData*/, const Shape& /*classesOutShape*/,
625                           int32_t* /*batchesOutData*/, const Shape& /*batchSplitOutShape*/,
626                           IOperationExecutionContext* context) {
627     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
628     convertQuantToFloat32(scoresData, scoresShape.scale, scoresShape.offset, &scores_float32);
629     std::vector<float> roi_float32(getNumberOfElements(roiShape));
630     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
631     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
632     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
633             scores_float32.data(), scoresShape, roi_float32.data(), roiShape, batchesData,
634             batchesShape, scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma,
635             nmsScoreThreshold, &batchSplitIn, &batchSplitOut, &selected));
636     return boxWithNmsLimitWriteOutput<uint8_t, uint16_t>(selected, batchSplitIn, batchSplitOut,
637                                                          scores_float32, context);
638 }
639 
boxWithNmsLimitQuant(const int8_t * scoresData,const Shape & scoresShape,const uint16_t * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,int8_t *,const Shape &,uint16_t *,const Shape &,int32_t *,const Shape &,int32_t *,const Shape &,IOperationExecutionContext * context)640 bool boxWithNmsLimitQuant(const int8_t* scoresData, const Shape& scoresShape,
641                           const uint16_t* roiData, const Shape& roiShape,
642                           const int32_t* batchesData, const Shape& batchesShape,
643                           float scoreThreshold, int32_t maxNumDetections, int32_t softNmsKernel,
644                           float iouThreshold, float sigma, float nmsScoreThreshold,
645                           int8_t* /*scoresOutData*/, const Shape& /*scoresOutShape*/,
646                           uint16_t* /*roiOutData*/, const Shape& /*roiOutShape*/,
647                           int32_t* /*classesOutData*/, const Shape& /*classesOutShape*/,
648                           int32_t* /*batchesOutData*/, const Shape& /*batchSplitOutShape*/,
649                           IOperationExecutionContext* context) {
650     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
651     convertQuantToFloat32<int8_t>(scoresData, scoresShape.scale, scoresShape.offset,
652                                   &scores_float32);
653     std::vector<float> roi_float32(getNumberOfElements(roiShape));
654     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
655     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
656     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
657             scores_float32.data(), scoresShape, roi_float32.data(), roiShape, batchesData,
658             batchesShape, scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma,
659             nmsScoreThreshold, &batchSplitIn, &batchSplitOut, &selected));
660     return boxWithNmsLimitWriteOutput<int8_t, uint16_t>(selected, batchSplitIn, batchSplitOut,
661                                                         scores_float32, context);
662 }
663 
664 }  // namespace
665 
prepare(IOperationExecutionContext * context)666 bool prepare(IOperationExecutionContext* context) {
667     Shape scoreShape = context->getInputShape(kScoreTensor);
668     Shape roiShape = context->getInputShape(kRoiTensor);
669     Shape batchesShape = context->getInputShape(kBatchesTensor);
670     Shape outputScoreShape = context->getOutputShape(kOutputScoreTensor);
671     Shape outputRoiShape = context->getOutputShape(kOutputRoiTensor);
672     Shape outputClassShape = context->getOutputShape(kOutputClassTensor);
673     Shape outputBatchSplitShape = context->getOutputShape(kOutputBatchesTensor);
674 
675     NN_RET_CHECK(getNumberOfDimensions(scoreShape) == 2);
676     NN_RET_CHECK(getNumberOfDimensions(roiShape) == 2);
677     NN_RET_CHECK(getNumberOfDimensions(batchesShape) == 1);
678 
679     // Only numRois can be zero.
680     const uint32_t kRoiDim = 4;
681     uint32_t numRois = getSizeOfDimension(scoreShape, 0);
682     uint32_t numClasses = getSizeOfDimension(scoreShape, 1);
683     NN_RET_CHECK(getSizeOfDimension(roiShape, 0) == numRois);
684     NN_RET_CHECK(getSizeOfDimension(roiShape, 1) == kRoiDim * numClasses);
685     NN_RET_CHECK(getSizeOfDimension(batchesShape, 0) == numRois);
686     NN_RET_CHECK_GT(numClasses, 1u);
687 
688     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM ||
689         scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
690         NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
691         NN_RET_CHECK_EQ(roiShape.offset, 0);
692     }
693 
694     outputScoreShape.type = scoreShape.type;
695     outputScoreShape.dimensions = {0};
696     outputScoreShape.scale = scoreShape.scale;
697     outputScoreShape.offset = scoreShape.offset;
698     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScoreShape));
699 
700     outputRoiShape.type = roiShape.type;
701     outputRoiShape.dimensions = {0, 4};
702     outputRoiShape.scale = 0.f;
703     outputRoiShape.offset = 0;
704     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM ||
705         scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
706         outputRoiShape.scale = 0.125f;
707     }
708     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, outputRoiShape));
709 
710     outputClassShape.type = OperandType::TENSOR_INT32;
711     outputClassShape.dimensions = {0};
712     NN_RET_CHECK(context->setOutputShape(kOutputClassTensor, outputClassShape));
713 
714     outputBatchSplitShape.type = batchesShape.type;
715     outputBatchSplitShape.dimensions = {0};
716     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, outputBatchSplitShape));
717     return true;
718 }
719 
// Dispatches BOX_WITH_NMS_LIMIT to the implementation matching the score tensor's operand type.
//
// With zero input rois it returns true without touching the outputs, so the zero-sized output
// shapes set by prepare() stay in effect. Scalars are read as _Float16 for TENSOR_FLOAT16 and as
// float for every other type. Rois are read as _Float16, float, or uint16_t (both quantized
// variants). Any other score type fails.
bool execute(IOperationExecutionContext* context) {
    NNTRACE_TRANS("boxWithNMSLimit");
    // Bypass execution in the case of zero numRois.
    if (getSizeOfDimension(context->getInputShape(kScoreTensor), 0) == 0) return true;
    switch (context->getInputType(kScoreTensor)) {
        case OperandType::TENSOR_FLOAT16: {
            // Half-precision scores and rois; threshold/sigma scalars are also half precision.
            return boxWithNmsLimitFloat16(
                    context->getInputBuffer<_Float16>(kScoreTensor),
                    context->getInputShape(kScoreTensor),
                    context->getInputBuffer<_Float16>(kRoiTensor),
                    context->getInputShape(kRoiTensor),
                    context->getInputBuffer<int32_t>(kBatchesTensor),
                    context->getInputShape(kBatchesTensor),
                    context->getInputValue<_Float16>(kScoreThresholdScalar),
                    context->getInputValue<int32_t>(kMaxNumDetectionScalar),
                    context->getInputValue<int32_t>(kNmsKernelScalar),
                    context->getInputValue<_Float16>(kIoUThresholdScalar),
                    context->getInputValue<_Float16>(kSigmaScalar),
                    context->getInputValue<_Float16>(kNmsScoreThresholdScalar),
                    context->getOutputBuffer<_Float16>(kOutputScoreTensor),
                    context->getOutputShape(kOutputScoreTensor),
                    context->getOutputBuffer<_Float16>(kOutputRoiTensor),
                    context->getOutputShape(kOutputRoiTensor),
                    context->getOutputBuffer<int32_t>(kOutputClassTensor),
                    context->getOutputShape(kOutputClassTensor),
                    context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
                    context->getOutputShape(kOutputBatchesTensor), context);
        }
        case OperandType::TENSOR_FLOAT32: {
            // Single-precision scores, rois and scalars.
            return boxWithNmsLimitFloat32(context->getInputBuffer<float>(kScoreTensor),
                                          context->getInputShape(kScoreTensor),
                                          context->getInputBuffer<float>(kRoiTensor),
                                          context->getInputShape(kRoiTensor),
                                          context->getInputBuffer<int32_t>(kBatchesTensor),
                                          context->getInputShape(kBatchesTensor),
                                          context->getInputValue<float>(kScoreThresholdScalar),
                                          context->getInputValue<int32_t>(kMaxNumDetectionScalar),
                                          context->getInputValue<int32_t>(kNmsKernelScalar),
                                          context->getInputValue<float>(kIoUThresholdScalar),
                                          context->getInputValue<float>(kSigmaScalar),
                                          context->getInputValue<float>(kNmsScoreThresholdScalar),
                                          context->getOutputBuffer<float>(kOutputScoreTensor),
                                          context->getOutputShape(kOutputScoreTensor),
                                          context->getOutputBuffer<float>(kOutputRoiTensor),
                                          context->getOutputShape(kOutputRoiTensor),
                                          context->getOutputBuffer<int32_t>(kOutputClassTensor),
                                          context->getOutputShape(kOutputClassTensor),
                                          context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
                                          context->getOutputShape(kOutputBatchesTensor), context);
        }
        case OperandType::TENSOR_QUANT8_ASYMM: {
            // uint8 scores with uint16 rois; scalars are float.
            return boxWithNmsLimitQuant(context->getInputBuffer<uint8_t>(kScoreTensor),
                                        context->getInputShape(kScoreTensor),
                                        context->getInputBuffer<uint16_t>(kRoiTensor),
                                        context->getInputShape(kRoiTensor),
                                        context->getInputBuffer<int32_t>(kBatchesTensor),
                                        context->getInputShape(kBatchesTensor),
                                        context->getInputValue<float>(kScoreThresholdScalar),
                                        context->getInputValue<int32_t>(kMaxNumDetectionScalar),
                                        context->getInputValue<int32_t>(kNmsKernelScalar),
                                        context->getInputValue<float>(kIoUThresholdScalar),
                                        context->getInputValue<float>(kSigmaScalar),
                                        context->getInputValue<float>(kNmsScoreThresholdScalar),
                                        context->getOutputBuffer<uint8_t>(kOutputScoreTensor),
                                        context->getOutputShape(kOutputScoreTensor),
                                        context->getOutputBuffer<uint16_t>(kOutputRoiTensor),
                                        context->getOutputShape(kOutputRoiTensor),
                                        context->getOutputBuffer<int32_t>(kOutputClassTensor),
                                        context->getOutputShape(kOutputClassTensor),
                                        context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
                                        context->getOutputShape(kOutputBatchesTensor), context);
        }
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: {
            // int8 scores with uint16 rois; scalars are float.
            return boxWithNmsLimitQuant(context->getInputBuffer<int8_t>(kScoreTensor),
                                        context->getInputShape(kScoreTensor),
                                        context->getInputBuffer<uint16_t>(kRoiTensor),
                                        context->getInputShape(kRoiTensor),
                                        context->getInputBuffer<int32_t>(kBatchesTensor),
                                        context->getInputShape(kBatchesTensor),
                                        context->getInputValue<float>(kScoreThresholdScalar),
                                        context->getInputValue<int32_t>(kMaxNumDetectionScalar),
                                        context->getInputValue<int32_t>(kNmsKernelScalar),
                                        context->getInputValue<float>(kIoUThresholdScalar),
                                        context->getInputValue<float>(kSigmaScalar),
                                        context->getInputValue<float>(kNmsScoreThresholdScalar),
                                        context->getOutputBuffer<int8_t>(kOutputScoreTensor),
                                        context->getOutputShape(kOutputScoreTensor),
                                        context->getOutputBuffer<uint16_t>(kOutputRoiTensor),
                                        context->getOutputShape(kOutputRoiTensor),
                                        context->getOutputBuffer<int32_t>(kOutputClassTensor),
                                        context->getOutputShape(kOutputClassTensor),
                                        context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
                                        context->getOutputShape(kOutputBatchesTensor), context);
        }
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}
818 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
819 
820 }  // namespace box_with_nms_limit
821 
822 namespace generate_proposals {
823 
824 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
825 namespace {
826 
// Keeps, in their original order, only the roi indices in `select` whose box is strictly wider and
// taller than `minSize` and whose center lies inside the image. Boxes are read as
// (x1, y1, x2, y2) at roiBase + index * 4. The center's y is compared against imageInfoBase[0]
// and its x against imageInfoBase[1] (presumably image height and width).
void filterBoxes(const float* roiBase, const float* imageInfoBase, float minSize,
                 std::vector<uint32_t>* select) {
    constexpr uint32_t kRoiDim = 4;
    const auto shouldDrop = [&](uint32_t roiIndex) {
        const float* box = roiBase + roiIndex * kRoiDim;
        const float boxWidth = box[2] - box[0];
        const float boxHeight = box[3] - box[1];
        const float centerX = box[0] + boxWidth / 2.0f;
        const float centerY = box[1] + boxHeight / 2.0f;
        const bool keep = boxWidth > minSize && boxHeight > minSize &&
                          centerX < imageInfoBase[1] && centerY < imageInfoBase[0];
        return !keep;
    };
    // remove_if is stable for the retained elements, so the surviving order is preserved.
    select->erase(std::remove_if(select->begin(), select->end(), shouldDrop), select->end());
}
845 
// Runs GENERATE_PROPOSALS on NHWC float32 inputs and appends each batch's surviving proposals
// (boxes, scores, batch index) to the three output vectors, which are cleared first.
//
// scoresShape is read as [numBatches, height, width, numAnchors]. anchorsData supplies numAnchors
// boxes of 4 floats that are reused at every grid cell, shifted by (w * widthStride,
// h * heightStride). Per batch, the function does the following:
//   1. Apply that batch's bbox deltas (height * width * numAnchors * 4 floats) to the shifted
//      anchors via bboxTransformFloat32.
//   2. If 0 < preNmsTopN < candidate count, keep only the preNmsTopN highest-scoring candidates.
//   3. Drop candidates rejected by filterBoxes.
//   4. Run hard NMS, capped at postNmsTopN.
// The image info advances by imageInfoShape[1] floats per batch, and only its first two values
// are used. Returns false if the bbox transform step fails.
bool generateProposalsNhwcFloat32Compute(const float* scoresData, const Shape& scoresShape,
                                         const float* bboxDeltasData, const Shape& bboxDeltasShape,
                                         const float* anchorsData, const Shape& anchorsShape,
                                         const float* imageInfoData, const Shape& imageInfoShape,
                                         float heightStride, float widthStride, int32_t preNmsTopN,
                                         int32_t postNmsTopN, float iouThreshold, float minSize,
                                         std::vector<float>* scoresOutData,
                                         std::vector<float>* roiOutData,
                                         std::vector<int32_t>* batchesOutData) {
    const uint32_t kRoiDim = 4;
    uint32_t numBatches = getSizeOfDimension(scoresShape, 0);
    uint32_t height = getSizeOfDimension(scoresShape, 1);
    uint32_t width = getSizeOfDimension(scoresShape, 2);
    uint32_t numAnchors = getSizeOfDimension(scoresShape, 3);
    uint32_t imageInfoLength = getSizeOfDimension(imageInfoShape, 1);

    // Number of candidate boxes per batch: one per anchor per grid cell.
    uint32_t batchSize = height * width * numAnchors;
    uint32_t roiBufferSize = batchSize * kRoiDim;
    std::vector<float> roiBuffer(roiBufferSize);
    std::vector<float> roiTransformedBuffer(roiBufferSize);
    scoresOutData->clear();
    roiOutData->clear();
    batchesOutData->clear();

    // Compute the roi region for each anchor.
    // The shifted anchors are identical for every batch, so this is done once up front.
    float* roiBase = roiBuffer.data();
    for (uint32_t h = 0; h < height; h++) {
        float hShift = h * heightStride;
        for (uint32_t w = 0; w < width; w++) {
            const float* anchorsBase = anchorsData;
            float wShift = w * widthStride;
            for (uint32_t a = 0; a < numAnchors; a++, roiBase += kRoiDim, anchorsBase += kRoiDim) {
                roiBase[0] = anchorsBase[0] + wShift;
                roiBase[1] = anchorsBase[1] + hShift;
                roiBase[2] = anchorsBase[2] + wShift;
                roiBase[3] = anchorsBase[3] + hShift;
            }
        }
    }

    const float* scoresBase = scoresData;
    const float* bboxDeltasBase = bboxDeltasData;
    const float* imageInfoBase = imageInfoData;
    // Need to fake some data to satisfy bboxTransform.
    // Each batch is presented as a single image holding batchSize rois, all mapped to batch 0.
    Shape tempRoiShape = anchorsShape;
    tempRoiShape.dimensions = {batchSize, kRoiDim};
    Shape tempBBoxDeltasShape = bboxDeltasShape;
    tempBBoxDeltasShape.dimensions = {batchSize, kRoiDim};
    std::vector<int32_t> tempBatchSplitData(batchSize, 0);
    Shape tempbatchSplitShape = {.dimensions = {batchSize}};
    Shape tempImageInfoShape = imageInfoShape;
    tempImageInfoShape.dimensions = {1, imageInfoLength};

    for (uint32_t b = 0; b < numBatches; b++) {
        // Apply bboxDeltas to anchor locations.
        // Only the first two image-info values of the current batch are handed to the transform.
        float tempImageInfo[] = {imageInfoBase[0], imageInfoBase[1]};
        if (!bboxTransformFloat32(roiBuffer.data(), tempRoiShape, bboxDeltasBase,
                                  tempBBoxDeltasShape, tempBatchSplitData.data(),
                                  tempbatchSplitShape, tempImageInfo, tempImageInfoShape,
                                  roiTransformedBuffer.data(), tempRoiShape)) {
            LOG(ERROR) << "BBoxTransform step failed in GENERATE_PROPOSALS op.";
            return false;
        }

        // Find the top preNmsTopN scores.
        // A non-positive preNmsTopN, or one not smaller than batchSize, keeps every candidate.
        std::vector<uint32_t> select(batchSize);
        std::iota(select.begin(), select.end(), 0);
        if (preNmsTopN > 0 && static_cast<size_t>(preNmsTopN) < select.size()) {
            std::sort(select.begin(), select.end(),
                      [&scoresBase](const uint32_t lhs, const uint32_t rhs) {
                          return scoresBase[lhs] > scoresBase[rhs];
                      });
            select.resize(preNmsTopN);
        }

        // Filter boxes: discard regions whose width or height is not greater than minSize, or
        // whose center lies outside the image (see filterBoxes).
        filterBoxes(roiTransformedBuffer.data(), imageInfoBase, minSize, &select);

        // Apply hard NMS.
        // hardNmsSingleClass appears to compact `select` in place and return its new end.
        uint32_t* selectEnd = box_with_nms_limit::hardNmsSingleClass(
                scoresBase, iouThreshold, postNmsTopN,
                [&roiTransformedBuffer](uint32_t ind) {
                    return roiTransformedBuffer.data() + ind * kRoiDim;
                },
                select.data(), select.size());
        uint32_t selectSize = selectEnd - select.data();
        select.resize(selectSize);

        // Write output.
        for (auto i : select) {
            roiOutData->insert(roiOutData->end(), roiTransformedBuffer.begin() + i * kRoiDim,
                               roiTransformedBuffer.begin() + (i + 1) * kRoiDim);
            scoresOutData->push_back(scoresBase[i]);
            batchesOutData->push_back(b);
        }
        // Advance the per-batch input cursors.
        scoresBase += batchSize;
        bboxDeltasBase += roiBufferSize;
        imageInfoBase += imageInfoLength;
    }
    return true;
}
947 
generateProposalsFloat32Compute(const float * scoresData,const Shape & scoresShape,const float * bboxDeltasData,const Shape & bboxDeltasShape,const float * anchorsData,const Shape & anchorsShape,const float * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,std::vector<float> * scoresOutData,std::vector<float> * roiOutData,std::vector<int32_t> * batchesOutData)948 bool generateProposalsFloat32Compute(const float* scoresData, const Shape& scoresShape,
949                                      const float* bboxDeltasData, const Shape& bboxDeltasShape,
950                                      const float* anchorsData, const Shape& anchorsShape,
951                                      const float* imageInfoData, const Shape& imageInfoShape,
952                                      float heightStride, float widthStride, int32_t preNmsTopN,
953                                      int32_t postNmsTopN, float iouThreshold, float minSize,
954                                      bool useNchw, std::vector<float>* scoresOutData,
955                                      std::vector<float>* roiOutData,
956                                      std::vector<int32_t>* batchesOutData) {
957     InputWithLayout<float> score_nhwc(useNchw), delta_nhwc(useNchw);
958     NN_RET_CHECK(score_nhwc.initialize(scoresData, scoresShape));
959     NN_RET_CHECK(delta_nhwc.initialize(bboxDeltasData, bboxDeltasShape));
960     return generateProposalsNhwcFloat32Compute(
961             score_nhwc.getNhwcBuffer(), score_nhwc.getNhwcShape(), delta_nhwc.getNhwcBuffer(),
962             delta_nhwc.getNhwcShape(), anchorsData, anchorsShape, imageInfoData, imageInfoShape,
963             heightStride, widthStride, preNmsTopN, postNmsTopN, iouThreshold, minSize,
964             scoresOutData, roiOutData, batchesOutData);
965 }
966 
generateProposalsFloat32(const float * scoresData,const Shape & scoresShape,const float * bboxDeltasData,const Shape & bboxDeltasShape,const float * anchorsData,const Shape & anchorsShape,const float * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,IOperationExecutionContext * context)967 bool generateProposalsFloat32(const float* scoresData, const Shape& scoresShape,
968                               const float* bboxDeltasData, const Shape& bboxDeltasShape,
969                               const float* anchorsData, const Shape& anchorsShape,
970                               const float* imageInfoData, const Shape& imageInfoShape,
971                               float heightStride, float widthStride, int32_t preNmsTopN,
972                               int32_t postNmsTopN, float iouThreshold, float minSize, bool useNchw,
973                               IOperationExecutionContext* context) {
974     std::vector<float> scoresOut_float32, roiOut_float32;
975     std::vector<int32_t> batchesOut;
976     NN_RET_CHECK(generateProposalsFloat32Compute(
977             scoresData, scoresShape, bboxDeltasData, bboxDeltasShape, anchorsData, anchorsShape,
978             imageInfoData, imageInfoShape, heightStride, widthStride, preNmsTopN, postNmsTopN,
979             iouThreshold, minSize, useNchw, &scoresOut_float32, &roiOut_float32, &batchesOut));
980 
981     // Set output dimensions.
982     uint32_t numOutRois = scoresOut_float32.size();
983     if (numOutRois == 0) return true;
984     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
985     scoresOutShape.dimensions = {numOutRois};
986     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
987     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
988     roiOutShape.dimensions = {numOutRois, 4};
989     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
990     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
991     batchesOutShape.dimensions = {numOutRois};
992     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
993 
994     // Write outputs.
995     float* scoresOutData = context->getOutputBuffer<float>(kOutputScoreTensor);
996     for (uint32_t i = 0; i < scoresOut_float32.size(); i++) {
997         scoresOutData[i] = scoresOut_float32[i];
998     }
999     float* roiOutData = context->getOutputBuffer<float>(kOutputRoiTensor);
1000     for (uint32_t i = 0; i < roiOut_float32.size(); i++) {
1001         roiOutData[i] = roiOut_float32[i];
1002     }
1003     int32_t* batchesOutData = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
1004     for (uint32_t i = 0; i < batchesOut.size(); i++) {
1005         batchesOutData[i] = batchesOut[i];
1006     }
1007     return true;
1008 }
1009 
generateProposalsFloat16(const _Float16 * scoresData,const Shape & scoresShape,const _Float16 * bboxDeltasData,const Shape & bboxDeltasShape,const _Float16 * anchorsData,const Shape & anchorsShape,const _Float16 * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,IOperationExecutionContext * context)1010 bool generateProposalsFloat16(const _Float16* scoresData, const Shape& scoresShape,
1011                               const _Float16* bboxDeltasData, const Shape& bboxDeltasShape,
1012                               const _Float16* anchorsData, const Shape& anchorsShape,
1013                               const _Float16* imageInfoData, const Shape& imageInfoShape,
1014                               float heightStride, float widthStride, int32_t preNmsTopN,
1015                               int32_t postNmsTopN, float iouThreshold, float minSize, bool useNchw,
1016                               IOperationExecutionContext* context) {
1017     std::vector<float> score_float32(getNumberOfElements(scoresShape));
1018     convertFloat16ToFloat32(scoresData, &score_float32);
1019     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
1020     convertFloat16ToFloat32(bboxDeltasData, &delta_float32);
1021     std::vector<float> anchors_float32(getNumberOfElements(anchorsShape));
1022     convertFloat16ToFloat32(anchorsData, &anchors_float32);
1023     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoShape));
1024     convertFloat16ToFloat32(imageInfoData, &imageInfo_float32);
1025     std::vector<float> scoresOut_float32, roiOut_float32;
1026     std::vector<int32_t> batchesOut;
1027     NN_RET_CHECK(generateProposalsFloat32Compute(
1028             score_float32.data(), scoresShape, delta_float32.data(), bboxDeltasShape,
1029             anchors_float32.data(), anchorsShape, imageInfo_float32.data(), imageInfoShape,
1030             heightStride, widthStride, preNmsTopN, postNmsTopN, iouThreshold, minSize, useNchw,
1031             &scoresOut_float32, &roiOut_float32, &batchesOut));
1032 
1033     // Set output dimensions.
1034     uint32_t numOutRois = scoresOut_float32.size();
1035     if (numOutRois == 0) return true;
1036     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
1037     scoresOutShape.dimensions = {numOutRois};
1038     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
1039     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
1040     roiOutShape.dimensions = {numOutRois, 4};
1041     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
1042     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
1043     batchesOutShape.dimensions = {numOutRois};
1044     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
1045 
1046     // Write outputs.
1047     _Float16* scoresOutData = context->getOutputBuffer<_Float16>(kOutputScoreTensor);
1048     convertFloat32ToFloat16(scoresOut_float32, scoresOutData);
1049     _Float16* roiOutData = context->getOutputBuffer<_Float16>(kOutputRoiTensor);
1050     convertFloat32ToFloat16(roiOut_float32, roiOutData);
1051     int32_t* batchesOutData = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
1052     for (uint32_t i = 0; i < batchesOut.size(); i++) {
1053         batchesOutData[i] = batchesOut[i];
1054     }
1055     return true;
1056 }
1057 
1058 template <typename T_8QInput>
generateProposalsQuant(const T_8QInput * scoresData,const Shape & scoresShape,const T_8QInput * bboxDeltasData,const Shape & bboxDeltasShape,const int16_t * anchorsData,const Shape & anchorsShape,const uint16_t * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,IOperationExecutionContext * context)1059 bool generateProposalsQuant(const T_8QInput* scoresData, const Shape& scoresShape,
1060                             const T_8QInput* bboxDeltasData, const Shape& bboxDeltasShape,
1061                             const int16_t* anchorsData, const Shape& anchorsShape,
1062                             const uint16_t* imageInfoData, const Shape& imageInfoShape,
1063                             float heightStride, float widthStride, int32_t preNmsTopN,
1064                             int32_t postNmsTopN, float iouThreshold, float minSize, bool useNchw,
1065                             IOperationExecutionContext* context) {
1066     std::vector<float> score_float32(getNumberOfElements(scoresShape));
1067     convertQuantToFloat32<T_8QInput>(scoresData, scoresShape.scale, scoresShape.offset,
1068                                      &score_float32);
1069     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
1070     convertQuantToFloat32<T_8QInput>(bboxDeltasData, bboxDeltasShape.scale, bboxDeltasShape.offset,
1071                                      &delta_float32);
1072     std::vector<float> anchors_float32(getNumberOfElements(anchorsShape));
1073     convertQuantToFloat32(anchorsData, anchorsShape.scale, anchorsShape.offset, &anchors_float32);
1074     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoShape));
1075     convertQuantToFloat32(imageInfoData, imageInfoShape.scale, imageInfoShape.offset,
1076                           &imageInfo_float32);
1077     std::vector<float> scoresOut_float32, roiOut_float32;
1078     std::vector<int32_t> batchesOut;
1079     NN_RET_CHECK(generateProposalsFloat32Compute(
1080             score_float32.data(), scoresShape, delta_float32.data(), bboxDeltasShape,
1081             anchors_float32.data(), anchorsShape, imageInfo_float32.data(), imageInfoShape,
1082             heightStride, widthStride, preNmsTopN, postNmsTopN, iouThreshold, minSize, useNchw,
1083             &scoresOut_float32, &roiOut_float32, &batchesOut));
1084 
1085     // Set output dimensions.
1086     uint32_t numOutRois = scoresOut_float32.size();
1087     if (numOutRois == 0) return true;
1088     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
1089     scoresOutShape.dimensions = {numOutRois};
1090     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
1091     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
1092     roiOutShape.dimensions = {numOutRois, 4};
1093     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
1094     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
1095     batchesOutShape.dimensions = {numOutRois};
1096     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
1097 
1098     // Write outputs.
1099     T_8QInput* scoresOutData = context->getOutputBuffer<T_8QInput>(kOutputScoreTensor);
1100     convertFloat32ToQuant<T_8QInput>(scoresOut_float32, scoresOutShape.scale, scoresOutShape.offset,
1101                                      scoresOutData);
1102     uint16_t* roiOutData = context->getOutputBuffer<uint16_t>(kOutputRoiTensor);
1103     convertFloat32ToQuant(roiOut_float32, roiOutShape.scale, roiOutShape.offset, roiOutData);
1104     int32_t* batchesOutData = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
1105     for (uint32_t i = 0; i < batchesOut.size(); i++) {
1106         batchesOutData[i] = batchesOut[i];
1107     }
1108     return true;
1109 }
1110 
1111 }  // namespace
1112 
prepare(IOperationExecutionContext * context)1113 bool prepare(IOperationExecutionContext* context) {
1114     bool useNchw = context->getInputValue<bool>(kLayoutScalar);
1115     Shape scoreShape = context->getInputShape(kScoreTensor);
1116     Shape bboxDeltasShape = context->getInputShape(kDeltaTensor);
1117     Shape anchorsShape = context->getInputShape(kAnchorTensor);
1118     Shape imageInfoDataShape = context->getInputShape(kImageInfoTensor);
1119     Shape outputScoreShape = context->getOutputShape(kOutputScoreTensor);
1120     Shape outputRoiShape = context->getOutputShape(kOutputRoiTensor);
1121     Shape outputBatchSplitShape = context->getOutputShape(kOutputBatchesTensor);
1122 
1123     NN_RET_CHECK_EQ(getNumberOfDimensions(scoreShape), 4u);
1124     NN_RET_CHECK_EQ(getNumberOfDimensions(bboxDeltasShape), 4u);
1125     NN_RET_CHECK_EQ(getNumberOfDimensions(anchorsShape), 2u);
1126     NN_RET_CHECK_EQ(getNumberOfDimensions(imageInfoDataShape), 2u);
1127 
1128     const uint32_t kRoiDim = 4;
1129     uint32_t numBatches = getSizeOfDimension(scoreShape, 0);
1130     uint32_t height = getSizeOfDimension(scoreShape, useNchw ? 2 : 1);
1131     uint32_t width = getSizeOfDimension(scoreShape, useNchw ? 3 : 2);
1132     uint32_t numAnchors = getSizeOfDimension(scoreShape, useNchw ? 1 : 3);
1133 
1134     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, 0), numBatches);
1135     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, useNchw ? 2 : 1), height);
1136     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, useNchw ? 3 : 2), width);
1137     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, useNchw ? 1 : 3), numAnchors * kRoiDim);
1138     NN_RET_CHECK_EQ(getSizeOfDimension(imageInfoDataShape, 0), numBatches);
1139     NN_RET_CHECK_EQ(getSizeOfDimension(imageInfoDataShape, 1), 2u);
1140     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 0), numAnchors);
1141     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 1), kRoiDim);
1142 
1143     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
1144         NN_RET_CHECK_EQ(anchorsShape.scale, 0.125f);
1145         NN_RET_CHECK_EQ(imageInfoDataShape.scale, 0.125f);
1146         NN_RET_CHECK_EQ(imageInfoDataShape.offset, 0);
1147     }
1148 
1149     outputScoreShape.type = scoreShape.type;
1150     outputScoreShape.dimensions = {0};
1151     outputScoreShape.scale = scoreShape.scale;
1152     outputScoreShape.offset = scoreShape.offset;
1153     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScoreShape));
1154 
1155     outputRoiShape.dimensions = {0, 4};
1156     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
1157         outputRoiShape.scale = 0.125f;
1158         outputRoiShape.offset = 0;
1159     }
1160     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, outputRoiShape));
1161 
1162     outputBatchSplitShape.dimensions = {0};
1163     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, outputBatchSplitShape));
1164     return true;
1165 }
1166 
execute(IOperationExecutionContext * context)1167 bool execute(IOperationExecutionContext* context) {
1168     NNTRACE_TRANS("generateProposals");
1169     switch (context->getInputType(kScoreTensor)) {
1170         case OperandType::TENSOR_FLOAT16: {
1171             return generateProposalsFloat16(context->getInputBuffer<_Float16>(kScoreTensor),
1172                                             context->getInputShape(kScoreTensor),
1173                                             context->getInputBuffer<_Float16>(kDeltaTensor),
1174                                             context->getInputShape(kDeltaTensor),
1175                                             context->getInputBuffer<_Float16>(kAnchorTensor),
1176                                             context->getInputShape(kAnchorTensor),
1177                                             context->getInputBuffer<_Float16>(kImageInfoTensor),
1178                                             context->getInputShape(kImageInfoTensor),
1179                                             context->getInputValue<_Float16>(kHeightStrideSalar),
1180                                             context->getInputValue<_Float16>(kWidthStrideScalar),
1181                                             context->getInputValue<int32_t>(kPreNmsMaxScalar),
1182                                             context->getInputValue<int32_t>(kPostNmsMaxScalar),
1183                                             context->getInputValue<_Float16>(kIoUThresholdScalar),
1184                                             context->getInputValue<_Float16>(kMinSizeScalar),
1185                                             context->getInputValue<bool>(kLayoutScalar), context);
1186         }
1187         case OperandType::TENSOR_FLOAT32: {
1188             return generateProposalsFloat32(context->getInputBuffer<float>(kScoreTensor),
1189                                             context->getInputShape(kScoreTensor),
1190                                             context->getInputBuffer<float>(kDeltaTensor),
1191                                             context->getInputShape(kDeltaTensor),
1192                                             context->getInputBuffer<float>(kAnchorTensor),
1193                                             context->getInputShape(kAnchorTensor),
1194                                             context->getInputBuffer<float>(kImageInfoTensor),
1195                                             context->getInputShape(kImageInfoTensor),
1196                                             context->getInputValue<float>(kHeightStrideSalar),
1197                                             context->getInputValue<float>(kWidthStrideScalar),
1198                                             context->getInputValue<int32_t>(kPreNmsMaxScalar),
1199                                             context->getInputValue<int32_t>(kPostNmsMaxScalar),
1200                                             context->getInputValue<float>(kIoUThresholdScalar),
1201                                             context->getInputValue<float>(kMinSizeScalar),
1202                                             context->getInputValue<bool>(kLayoutScalar), context);
1203         }
1204         case OperandType::TENSOR_QUANT8_ASYMM: {
1205             return generateProposalsQuant(context->getInputBuffer<uint8_t>(kScoreTensor),
1206                                           context->getInputShape(kScoreTensor),
1207                                           context->getInputBuffer<uint8_t>(kDeltaTensor),
1208                                           context->getInputShape(kDeltaTensor),
1209                                           context->getInputBuffer<int16_t>(kAnchorTensor),
1210                                           context->getInputShape(kAnchorTensor),
1211                                           context->getInputBuffer<uint16_t>(kImageInfoTensor),
1212                                           context->getInputShape(kImageInfoTensor),
1213                                           context->getInputValue<float>(kHeightStrideSalar),
1214                                           context->getInputValue<float>(kWidthStrideScalar),
1215                                           context->getInputValue<int32_t>(kPreNmsMaxScalar),
1216                                           context->getInputValue<int32_t>(kPostNmsMaxScalar),
1217                                           context->getInputValue<float>(kIoUThresholdScalar),
1218                                           context->getInputValue<float>(kMinSizeScalar),
1219                                           context->getInputValue<bool>(kLayoutScalar), context);
1220         }
1221         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: {
1222             return generateProposalsQuant(context->getInputBuffer<int8_t>(kScoreTensor),
1223                                           context->getInputShape(kScoreTensor),
1224                                           context->getInputBuffer<int8_t>(kDeltaTensor),
1225                                           context->getInputShape(kDeltaTensor),
1226                                           context->getInputBuffer<int16_t>(kAnchorTensor),
1227                                           context->getInputShape(kAnchorTensor),
1228                                           context->getInputBuffer<uint16_t>(kImageInfoTensor),
1229                                           context->getInputShape(kImageInfoTensor),
1230                                           context->getInputValue<float>(kHeightStrideSalar),
1231                                           context->getInputValue<float>(kWidthStrideScalar),
1232                                           context->getInputValue<int32_t>(kPreNmsMaxScalar),
1233                                           context->getInputValue<int32_t>(kPostNmsMaxScalar),
1234                                           context->getInputValue<float>(kIoUThresholdScalar),
1235                                           context->getInputValue<float>(kMinSizeScalar),
1236                                           context->getInputValue<bool>(kLayoutScalar), context);
1237         }
1238         default:
1239             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1240     }
1241 }
1242 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
1243 
1244 }  // namespace generate_proposals
1245 
1246 namespace detection_postprocess {
1247 
1248 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
1249 namespace {
1250 
detectionPostprocessFloat32(const float * scoreData,const Shape & scoreShape,const float * deltaData,const Shape & deltaShape,const float * anchorData,const Shape &,float scaleY,float scaleX,float scaleH,float scaleW,bool useRegularNms,int32_t maxNumDetections,int32_t maxClassesPerDetection,int32_t maxNumDetectionsPerClass,float iouThreshold,float scoreThreshold,bool isBGInLabel,float * scoreOutData,const Shape & scoreOutShape,float * roiOutData,const Shape & roiOutShape,int32_t * classOutData,const Shape & classOutShape,int32_t * detectionOutData,const Shape & detectionOutShape)1251 bool detectionPostprocessFloat32(const float* scoreData, const Shape& scoreShape,
1252                                  const float* deltaData, const Shape& deltaShape,
1253                                  const float* anchorData, const Shape& /*anchorShape*/,
1254                                  float scaleY, float scaleX, float scaleH, float scaleW,
1255                                  bool useRegularNms, int32_t maxNumDetections,
1256                                  int32_t maxClassesPerDetection, int32_t maxNumDetectionsPerClass,
1257                                  float iouThreshold, float scoreThreshold, bool isBGInLabel,
1258                                  float* scoreOutData, const Shape& scoreOutShape, float* roiOutData,
1259                                  const Shape& roiOutShape, int32_t* classOutData,
1260                                  const Shape& classOutShape, int32_t* detectionOutData,
1261                                  const Shape& detectionOutShape) {
1262     const uint32_t kRoiDim = 4;
1263     uint32_t numBatches = getSizeOfDimension(scoreShape, 0);
1264     uint32_t numAnchors = getSizeOfDimension(scoreShape, 1);
1265     uint32_t numClasses = getSizeOfDimension(scoreShape, 2);
1266     uint32_t lengthBoxEncoding = getSizeOfDimension(deltaShape, 2);
1267     uint32_t numOutDetection = getSizeOfDimension(scoreOutShape, 1);
1268 
1269     memset(scoreOutData, 0, getNumberOfElements(scoreOutShape) * sizeof(float));
1270     memset(roiOutData, 0, getNumberOfElements(roiOutShape) * sizeof(float));
1271     memset(classOutData, 0, getNumberOfElements(classOutShape) * sizeof(int32_t));
1272     memset(detectionOutData, 0, getNumberOfElements(detectionOutShape) * sizeof(int32_t));
1273 
1274     const float* scoreBase = scoreData;
1275     const float* deltaBase = deltaData;
1276     float* scoreOutBase = scoreOutData;
1277     float* roiOutBase = roiOutData;
1278     int32_t* classOutBase = classOutData;
1279     std::vector<float> roiBuffer(numAnchors * kRoiDim);
1280     std::vector<float> scoreBuffer(numAnchors);
1281     for (uint32_t b = 0; b < numBatches; b++) {
1282         const float* anchorBase = anchorData;
1283         for (uint32_t a = 0; a < numAnchors; a++) {
1284             float yCtr = anchorBase[0] + anchorBase[2] * deltaBase[0] / scaleY;
1285             float xCtr = anchorBase[1] + anchorBase[3] * deltaBase[1] / scaleX;
1286             float hHalf = anchorBase[2] * std::exp(deltaBase[2] / scaleH) * 0.5f;
1287             float wHalf = anchorBase[3] * std::exp(deltaBase[3] / scaleW) * 0.5f;
1288             roiBuffer[a * kRoiDim] = yCtr - hHalf;
1289             roiBuffer[a * kRoiDim + 1] = xCtr - wHalf;
1290             roiBuffer[a * kRoiDim + 2] = yCtr + hHalf;
1291             roiBuffer[a * kRoiDim + 3] = xCtr + wHalf;
1292             anchorBase += kRoiDim;
1293             deltaBase += lengthBoxEncoding;
1294         }
1295 
1296         if (useRegularNms) {
1297             std::vector<uint32_t> select;
1298             box_with_nms_limit::hardNmsMultiClass(
1299                     scoreBase, numClasses, numAnchors, scoreThreshold, iouThreshold,
1300                     maxNumDetections, maxNumDetectionsPerClass,
1301                     [&roiBuffer, numClasses](uint32_t ind) {
1302                         return roiBuffer.data() + (ind / numClasses) * kRoiDim;
1303                     },
1304                     &select);
1305             for (uint32_t i = 0; i < select.size(); i++) {
1306                 uint32_t ind = select[i];
1307                 scoreOutBase[i] = scoreBase[ind];
1308                 memcpy(roiOutBase + i * kRoiDim, &roiBuffer[(ind / numClasses) * kRoiDim],
1309                        kRoiDim * sizeof(float));
1310                 classOutBase[i] = (ind % numClasses) - (isBGInLabel ? 0 : 1);
1311             }
1312             *detectionOutData++ = select.size();
1313         } else {
1314             uint32_t numOutClasses = std::min<uint32_t>(numClasses - 1, maxClassesPerDetection);
1315             std::vector<float> maxScores(numAnchors);
1316             for (uint32_t a = 0; a < numAnchors; a++) {
1317                 maxScores[a] = *std::max_element(scoreBase + a * numClasses + 1,
1318                                                  scoreBase + (a + 1) * numClasses);
1319             }
1320             std::vector<uint32_t> select;
1321             for (uint32_t a = 0; a < numAnchors; a++) {
1322                 if (maxScores[a] > scoreThreshold) {
1323                     select.push_back(a);
1324                 }
1325             }
1326             uint32_t* selectEnd = box_with_nms_limit::hardNmsSingleClass(
1327                     maxScores.data(), iouThreshold, maxNumDetections,
1328                     [&roiBuffer](uint32_t ind) { return roiBuffer.data() + ind * kRoiDim; },
1329                     select.data(), select.size());
1330             select.resize(selectEnd - select.data());
1331             float* scoreOutPtr = scoreOutBase;
1332             float* roiOutPtr = roiOutBase;
1333             int32_t* classOutPtr = classOutBase;
1334             for (auto i : select) {
1335                 const float* score = scoreBase + i * numClasses;
1336                 std::vector<uint32_t> scoreInds(numClasses - 1);
1337                 std::iota(scoreInds.begin(), scoreInds.end(), 1);
1338                 std::sort(scoreInds.begin(), scoreInds.end(),
1339                           [&score](const uint32_t lhs, const uint32_t rhs) {
1340                               return score[lhs] > score[rhs];
1341                           });
1342                 for (uint32_t c = 0; c < numOutClasses; c++) {
1343                     *scoreOutPtr++ = score[scoreInds[c]];
1344                     memcpy(roiOutPtr, &roiBuffer[i * kRoiDim], kRoiDim * sizeof(float));
1345                     roiOutPtr += kRoiDim;
1346                     *classOutPtr++ = scoreInds[c] - (isBGInLabel ? 0 : 1);
1347                 }
1348             }
1349             *detectionOutData++ = select.size() * numOutClasses;
1350         }
1351         scoreBase += numAnchors * numClasses;
1352         scoreOutBase += numOutDetection;
1353         roiOutBase += numOutDetection * kRoiDim;
1354         classOutBase += numOutDetection;
1355     }
1356     return true;
1357 }
1358 
detectionPostprocessFloat16(const _Float16 * scoreData,const Shape & scoreShape,const _Float16 * deltaData,const Shape & deltaShape,const _Float16 * anchorData,const Shape & anchorShape,float scaleY,float scaleX,float scaleH,float scaleW,bool useRegularNms,int32_t maxNumDetections,int32_t maxClassesPerDetection,int32_t maxNumDetectionsPerClass,float iouThreshold,float scoreThreshold,bool isBGInLabel,_Float16 * scoreOutData,const Shape & scoreOutShape,_Float16 * roiOutData,const Shape & roiOutShape,int32_t * classOutData,const Shape & classOutShape,int32_t * detectionOutData,const Shape & detectionOutShape)1359 bool detectionPostprocessFloat16(
1360         const _Float16* scoreData, const Shape& scoreShape, const _Float16* deltaData,
1361         const Shape& deltaShape, const _Float16* anchorData, const Shape& anchorShape, float scaleY,
1362         float scaleX, float scaleH, float scaleW, bool useRegularNms, int32_t maxNumDetections,
1363         int32_t maxClassesPerDetection, int32_t maxNumDetectionsPerClass, float iouThreshold,
1364         float scoreThreshold, bool isBGInLabel, _Float16* scoreOutData, const Shape& scoreOutShape,
1365         _Float16* roiOutData, const Shape& roiOutShape, int32_t* classOutData,
1366         const Shape& classOutShape, int32_t* detectionOutData, const Shape& detectionOutShape) {
1367     std::vector<float> scores_float32(getNumberOfElements(scoreShape));
1368     convertFloat16ToFloat32(scoreData, &scores_float32);
1369     std::vector<float> delta_float32(getNumberOfElements(deltaShape));
1370     convertFloat16ToFloat32(deltaData, &delta_float32);
1371     std::vector<float> anchor_float32(getNumberOfElements(anchorShape));
1372     convertFloat16ToFloat32(anchorData, &anchor_float32);
1373     std::vector<float> outputScore_float32(getNumberOfElements(scoreOutShape));
1374     std::vector<float> outputRoi_float32(getNumberOfElements(roiOutShape));
1375     NN_RET_CHECK(detectionPostprocessFloat32(
1376             scores_float32.data(), scoreShape, delta_float32.data(), deltaShape,
1377             anchor_float32.data(), anchorShape, scaleY, scaleX, scaleH, scaleW, useRegularNms,
1378             maxNumDetections, maxClassesPerDetection, maxNumDetectionsPerClass, iouThreshold,
1379             scoreThreshold, isBGInLabel, outputScore_float32.data(), scoreOutShape,
1380             outputRoi_float32.data(), roiOutShape, classOutData, classOutShape, detectionOutData,
1381             detectionOutShape));
1382     convertFloat32ToFloat16(outputScore_float32, scoreOutData);
1383     convertFloat32ToFloat16(outputRoi_float32, roiOutData);
1384     return true;
1385 }
1386 
1387 }  // namespace
1388 
prepare(IOperationExecutionContext * context)1389 bool prepare(IOperationExecutionContext* context) {
1390     Shape scoreShape = context->getInputShape(kScoreTensor);
1391     Shape deltasShape = context->getInputShape(kDeltaTensor);
1392     Shape anchorsShape = context->getInputShape(kAnchorTensor);
1393     Shape outputScoreShape = context->getOutputShape(kOutputScoreTensor);
1394     Shape outputRoiShape = context->getOutputShape(kOutputRoiTensor);
1395     Shape outputClassShape = context->getOutputShape(kOutputClassTensor);
1396     Shape outputDetectionShape = context->getOutputShape(kOutputDetectionTensor);
1397 
1398     NN_RET_CHECK_EQ(getNumberOfDimensions(scoreShape), 3u);
1399     NN_RET_CHECK_EQ(getNumberOfDimensions(deltasShape), 3u);
1400     NN_RET_CHECK_EQ(getNumberOfDimensions(anchorsShape), 2u);
1401 
1402     const uint32_t kRoiDim = 4;
1403     uint32_t numBatches = getSizeOfDimension(scoreShape, 0);
1404     uint32_t numAnchors = getSizeOfDimension(scoreShape, 1);
1405     uint32_t numClasses = getSizeOfDimension(scoreShape, 2);
1406     uint32_t lengthBoxEncoding = getSizeOfDimension(deltasShape, 2);
1407     uint32_t maxNumDetections = context->getInputValue<int32_t>(kMaxNumDetectionScalar);
1408     uint32_t maxClassesPerDetection =
1409             context->getInputValue<int32_t>(kMaxClassesPerDetectionScalar);
1410     uint32_t numOutDetections = maxNumDetections;
1411 
1412     NN_RET_CHECK_EQ(getSizeOfDimension(deltasShape, 0), numBatches);
1413     NN_RET_CHECK_EQ(getSizeOfDimension(deltasShape, 1), numAnchors);
1414     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 0), numAnchors);
1415     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 1), kRoiDim);
1416 
1417     if (scoreShape.type == OperandType::TENSOR_FLOAT32) {
1418         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleYScalar), 0);
1419         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleXScalar), 0);
1420         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleHScalar), 0);
1421         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleWScalar), 0);
1422         NN_RET_CHECK_GE(context->getInputValue<float>(kScoreThresholdScalar), 0);
1423         NN_RET_CHECK_GE(context->getInputValue<float>(kIoUThresholdScalar), 0);
1424     } else if (scoreShape.type == OperandType::TENSOR_FLOAT16) {
1425         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleYScalar) > 0);
1426         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleXScalar) > 0);
1427         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleHScalar) > 0);
1428         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleWScalar) > 0);
1429         NN_RET_CHECK(context->getInputValue<_Float16>(kScoreThresholdScalar) >= 0);
1430         NN_RET_CHECK(context->getInputValue<_Float16>(kIoUThresholdScalar) >= 0);
1431     } else {
1432         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1433     }
1434     NN_RET_CHECK_GT(numClasses, 1u);
1435     NN_RET_CHECK_GE(lengthBoxEncoding, 4u);
1436     NN_RET_CHECK_GT(maxNumDetections, 0u);
1437     if (context->getInputValue<bool>(kUseRegularNmsScalar)) {
1438         NN_RET_CHECK_GT(context->getInputValue<int32_t>(kMaxNumDetectionPerClassScalar), 0);
1439     } else {
1440         NN_RET_CHECK_GT(maxClassesPerDetection, 0u);
1441         numOutDetections *= maxClassesPerDetection;
1442     }
1443 
1444     outputScoreShape.type = scoreShape.type;
1445     outputScoreShape.dimensions = {numBatches, numOutDetections};
1446     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScoreShape));
1447 
1448     outputRoiShape.type = anchorsShape.type;
1449     outputRoiShape.dimensions = {numBatches, numOutDetections, 4};
1450     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, outputRoiShape));
1451 
1452     outputClassShape.type = OperandType::TENSOR_INT32;
1453     outputClassShape.dimensions = {numBatches, numOutDetections};
1454     NN_RET_CHECK(context->setOutputShape(kOutputClassTensor, outputClassShape));
1455 
1456     outputDetectionShape.type = OperandType::TENSOR_INT32;
1457     outputDetectionShape.dimensions = {numBatches};
1458     NN_RET_CHECK(context->setOutputShape(kOutputDetectionTensor, outputDetectionShape));
1459     return true;
1460 }
1461 
execute(IOperationExecutionContext * context)1462 bool execute(IOperationExecutionContext* context) {
1463     NNTRACE_TRANS("detectionPostProcess");
1464     switch (context->getInputType(kScoreTensor)) {
1465         case OperandType::TENSOR_FLOAT16: {
1466             return detectionPostprocessFloat16(
1467                     context->getInputBuffer<_Float16>(kScoreTensor),
1468                     context->getInputShape(kScoreTensor),
1469                     context->getInputBuffer<_Float16>(kDeltaTensor),
1470                     context->getInputShape(kDeltaTensor),
1471                     context->getInputBuffer<_Float16>(kAnchorTensor),
1472                     context->getInputShape(kAnchorTensor),
1473                     context->getInputValue<_Float16>(kScaleYScalar),
1474                     context->getInputValue<_Float16>(kScaleXScalar),
1475                     context->getInputValue<_Float16>(kScaleHScalar),
1476                     context->getInputValue<_Float16>(kScaleWScalar),
1477                     context->getInputValue<bool>(kUseRegularNmsScalar),
1478                     context->getInputValue<int32_t>(kMaxNumDetectionScalar),
1479                     context->getInputValue<int32_t>(kMaxClassesPerDetectionScalar),
1480                     context->getInputValue<int32_t>(kMaxNumDetectionPerClassScalar),
1481                     context->getInputValue<_Float16>(kIoUThresholdScalar),
1482                     context->getInputValue<_Float16>(kScoreThresholdScalar),
1483                     context->getInputValue<bool>(kIsBGInLabelScalar),
1484                     context->getOutputBuffer<_Float16>(kOutputScoreTensor),
1485                     context->getOutputShape(kOutputScoreTensor),
1486                     context->getOutputBuffer<_Float16>(kOutputRoiTensor),
1487                     context->getOutputShape(kOutputRoiTensor),
1488                     context->getOutputBuffer<int32_t>(kOutputClassTensor),
1489                     context->getOutputShape(kOutputClassTensor),
1490                     context->getOutputBuffer<int32_t>(kOutputDetectionTensor),
1491                     context->getOutputShape(kOutputDetectionTensor));
1492         }
1493         case OperandType::TENSOR_FLOAT32: {
1494             return detectionPostprocessFloat32(
1495                     context->getInputBuffer<float>(kScoreTensor),
1496                     context->getInputShape(kScoreTensor),
1497                     context->getInputBuffer<float>(kDeltaTensor),
1498                     context->getInputShape(kDeltaTensor),
1499                     context->getInputBuffer<float>(kAnchorTensor),
1500                     context->getInputShape(kAnchorTensor),
1501                     context->getInputValue<float>(kScaleYScalar),
1502                     context->getInputValue<float>(kScaleXScalar),
1503                     context->getInputValue<float>(kScaleHScalar),
1504                     context->getInputValue<float>(kScaleWScalar),
1505                     context->getInputValue<bool>(kUseRegularNmsScalar),
1506                     context->getInputValue<int32_t>(kMaxNumDetectionScalar),
1507                     context->getInputValue<int32_t>(kMaxClassesPerDetectionScalar),
1508                     context->getInputValue<int32_t>(kMaxNumDetectionPerClassScalar),
1509                     context->getInputValue<float>(kIoUThresholdScalar),
1510                     context->getInputValue<float>(kScoreThresholdScalar),
1511                     context->getInputValue<bool>(kIsBGInLabelScalar),
1512                     context->getOutputBuffer<float>(kOutputScoreTensor),
1513                     context->getOutputShape(kOutputScoreTensor),
1514                     context->getOutputBuffer<float>(kOutputRoiTensor),
1515                     context->getOutputShape(kOutputRoiTensor),
1516                     context->getOutputBuffer<int32_t>(kOutputClassTensor),
1517                     context->getOutputShape(kOutputClassTensor),
1518                     context->getOutputBuffer<int32_t>(kOutputDetectionTensor),
1519                     context->getOutputShape(kOutputDetectionTensor));
1520         }
1521         default:
1522             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1523     }
1524 }
1525 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
1526 
1527 }  // namespace detection_postprocess
1528 
1529 }  // namespace bbox_ops
1530 
// Register the operations implemented in bbox_ops. Each invocation pairs an
// operation type with the prepare and execute functions from the matching
// bbox_ops sub-namespace. The macro name indicates the default validation path
// is used.
//
// AXIS_ALIGNED_BBOX_TRANSFORM and BOX_WITH_NMS_LIMIT also set
// `.allowZeroSizedInput = true`. This presumably lets them accept inputs with a
// zero-sized dimension; NOTE(review): confirm against the macro's definition.
// GENERATE_PROPOSALS and DETECTION_POSTPROCESSING leave that flag at its default.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(AXIS_ALIGNED_BBOX_TRANSFORM,
                                         bbox_ops::axis_aligned_bbox_transform::prepare,
                                         bbox_ops::axis_aligned_bbox_transform::execute,
                                         .allowZeroSizedInput = true);

NN_REGISTER_OPERATION_DEFAULT_VALIDATION(BOX_WITH_NMS_LIMIT, bbox_ops::box_with_nms_limit::prepare,
                                         bbox_ops::box_with_nms_limit::execute,
                                         .allowZeroSizedInput = true);

NN_REGISTER_OPERATION_DEFAULT_VALIDATION(GENERATE_PROPOSALS, bbox_ops::generate_proposals::prepare,
                                         bbox_ops::generate_proposals::execute);

NN_REGISTER_OPERATION_DEFAULT_VALIDATION(DETECTION_POSTPROCESSING,
                                         bbox_ops::detection_postprocess::prepare,
                                         bbox_ops::detection_postprocess::execute);
1546 }  // namespace nn
1547 }  // namespace android
1548