1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "ChannelShuffle.h"
20 
21 #include "OperationResolver.h"
22 #include "OperationsExecutionUtils.h"
23 #include "Tracing.h"
24 
25 namespace android {
26 namespace nn {
27 namespace channel_shuffle {
28 
29 template <typename T>
eval(const T * inputData,const Shape & inputShape,int32_t numGroups,int32_t axis,T * outputData)30 inline bool eval(const T* inputData, const Shape& inputShape, int32_t numGroups, int32_t axis,
31                  T* outputData) {
32     const uint32_t outerSize = getNumberOfElements(inputShape, 0, axis);
33     const uint32_t axisSize = getSizeOfDimension(inputShape, axis);
34     const uint32_t innerSize =
35             getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape));
36     const uint32_t groupSize = axisSize / numGroups;
37     for (uint32_t outer = 0; outer < outerSize; ++outer) {
38         for (uint32_t inner = 0; inner < innerSize; ++inner) {
39             const T* inputBase = inputData + outer * axisSize * innerSize + inner;
40             T* outputBase = outputData + outer * axisSize * innerSize + inner;
41             for (uint32_t i = 0; i < groupSize; i++) {
42                 for (uint32_t j = 0; j < static_cast<uint32_t>(numGroups);
43                      j++, outputBase += innerSize) {
44                     *outputBase = inputBase[innerSize * (i + j * groupSize)];
45                 }
46             }
47         }
48     }
49     return true;
50 }
51 
prepare(IOperationExecutionContext * context)52 bool prepare(IOperationExecutionContext* context) {
53     Shape input = context->getInputShape(kInputTensor);
54     int32_t numGroups = context->getInputValue<int32_t>(kNumGroups);
55     int32_t axis = context->getInputValue<int32_t>(kInputAxis);
56     NN_RET_CHECK(handleNegativeAxis(input, &axis));
57     NN_RET_CHECK(numGroups > 0);
58     NN_RET_CHECK(getSizeOfDimension(input, axis) % numGroups == 0);
59     return context->setOutputShape(kOutputTensor, input);
60 }
61 
execute(IOperationExecutionContext * context)62 bool execute(IOperationExecutionContext* context) {
63     int32_t numGroups = context->getInputValue<int32_t>(kNumGroups);
64     int32_t axis = context->getInputValue<int32_t>(kInputAxis);
65     NN_RET_CHECK(handleNegativeAxis(context->getInputShape(kInputTensor), &axis));
66     switch (context->getInputType(kInputTensor)) {
67         case OperandType::TENSOR_FLOAT16:
68             return eval(context->getInputBuffer<_Float16>(kInputTensor),
69                         context->getInputShape(kInputTensor), numGroups, axis,
70                         context->getOutputBuffer<_Float16>(kOutputTensor));
71         case OperandType::TENSOR_FLOAT32:
72             return eval(context->getInputBuffer<float>(kInputTensor),
73                         context->getInputShape(kInputTensor), numGroups, axis,
74                         context->getOutputBuffer<float>(kOutputTensor));
75         case OperandType::TENSOR_QUANT8_ASYMM:
76             return eval(context->getInputBuffer<uint8_t>(kInputTensor),
77                         context->getInputShape(kInputTensor), numGroups, axis,
78                         context->getOutputBuffer<uint8_t>(kOutputTensor));
79         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
80             return eval(context->getInputBuffer<int8_t>(kInputTensor),
81                         context->getInputShape(kInputTensor), numGroups, axis,
82                         context->getOutputBuffer<int8_t>(kOutputTensor));
83         default:
84             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
85     }
86 }
87 
88 }  // namespace channel_shuffle
89 
// Registers the CHANNEL_SHUFFLE operation with the default validation. It wires in
// channel_shuffle::prepare (parameter checks and output shape) and channel_shuffle::execute
// (type dispatch to eval<T>), both defined above. The registry this macro targets is
// presumably the one declared in OperationResolver.h; confirm against the macro definition.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(CHANNEL_SHUFFLE, channel_shuffle::prepare,
                                         channel_shuffle::execute);
92 
93 }  // namespace nn
94 }  // namespace android
95