1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "TransposeConv2D.h"
20 
21 #include <algorithm>
22 #include <cfloat>
23 #include <cmath>
24 #include <memory>
25 #include <vector>
26 
27 #include "OperationResolver.h"
28 #include "Tracing.h"
29 
30 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
31 #pragma clang diagnostic push
32 #pragma clang diagnostic ignored "-Wunused-parameter"
33 #include <tensorflow/lite/kernels/internal/common.h>
34 #pragma clang diagnostic pop
35 
36 #include "CpuOperationUtils.h"
37 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
38 
39 namespace android {
40 namespace nn {
41 namespace transpose_conv_2d {
42 
43 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
44 namespace {
45 
46 // If possible we will use this static buffer for the tensor.
47 constexpr size_t kStaticBufferSize = 1605632;
48 [[maybe_unused]] char static_scratch_buffer[kStaticBufferSize];
49 
50 // executionMutex is used to protect concurrent access of the static_scratch_buffer.
51 // std::mutex is safe for pthreads on Android.
52 std::mutex executionMutex;
53 
54 struct TransposeConv2dParam {
55     int32_t paddingLeft, paddingRight;
56     int32_t paddingTop, paddingBottom;
57     int32_t strideWidth, strideHeight;
58     int32_t activation;
59     bool useNchw = false;
60 
initializeandroid::nn::transpose_conv_2d::__anon564e98b60111::TransposeConv2dParam61     bool initialize(const IOperationExecutionContext* context) {
62         uint32_t inCount = context->getNumInputs();
63         int32_t paddingImplicit = 0;
64         if (inCount == 9) {
65             paddingImplicit = context->getInputValue<int32_t>(4);
66             strideWidth = context->getInputValue<int32_t>(5);
67             strideHeight = context->getInputValue<int32_t>(6);
68             activation = context->getInputValue<int32_t>(7);
69             useNchw = context->getInputValue<bool>(8);
70             Shape filterShape = context->getInputShape(kFilterTensor);
71             int32_t filterWidth = getSizeOfDimension(filterShape, 2);
72             int32_t filterHeight = getSizeOfDimension(filterShape, 1);
73             NN_RET_CHECK_EQ(getNumberOfDimensions(context->getInputShape(3)), 1u);
74             NN_RET_CHECK_EQ(getSizeOfDimension(context->getInputShape(3), 0), 4u);
75             const int32_t* outputShapeData = context->getInputBuffer<int32_t>(3);
76             int32_t outputWidth = useNchw ? outputShapeData[3] : outputShapeData[2];
77             int32_t outputHeight = useNchw ? outputShapeData[2] : outputShapeData[1];
78             calculateExplicitPaddingTransposeConv(outputWidth, strideWidth, filterWidth,
79                                                   paddingImplicit, &paddingLeft, &paddingRight);
80             calculateExplicitPaddingTransposeConv(outputHeight, strideHeight, filterHeight,
81                                                   paddingImplicit, &paddingTop, &paddingBottom);
82         } else if (inCount == 11) {
83             paddingLeft = context->getInputValue<int32_t>(3);
84             paddingRight = context->getInputValue<int32_t>(4);
85             paddingTop = context->getInputValue<int32_t>(5);
86             paddingBottom = context->getInputValue<int32_t>(6);
87             strideWidth = context->getInputValue<int32_t>(7);
88             strideHeight = context->getInputValue<int32_t>(8);
89             activation = context->getInputValue<int32_t>(9);
90             useNchw = context->getInputValue<bool>(10);
91         } else {
92             NN_RET_CHECK_FAIL() << "Unsupported input spec for operation " << kOperationName;
93         }
94         // paddingRight and paddingBottom in transpose conv may be less than 0 to resolve the
95         // ambiguous output shape issue in the case of stride > 1.
96         NN_RET_CHECK_GE(paddingLeft, 0);
97         NN_RET_CHECK_GE(paddingTop, 0);
98         NN_RET_CHECK_GT(strideWidth, 0);
99         NN_RET_CHECK_GT(strideHeight, 0);
100         NN_RET_CHECK_GE(activation, 0);
101         return true;
102     }
103 };
104 
// Unpacks the tensor geometry (batch/height/width/depth of the NHWC input,
// filter, and output shapes) and the TransposeConv2dParam fields into local
// variables for the kernels below. Requires `inputShape`, `filterShape`,
// `outputShape`, and `param` to be in scope at the expansion site.
#define ANDROID_NN_TRANSPOSE_CONV_PARAMETERS                                    \
    uint32_t numBatches = getSizeOfDimension(inputShape, 0);                    \
    uint32_t inputHeight = getSizeOfDimension(inputShape, 1);                   \
    uint32_t inputWidth = getSizeOfDimension(inputShape, 2);                    \
    uint32_t inputDepth = getSizeOfDimension(inputShape, 3);                    \
    uint32_t filterHeight = getSizeOfDimension(filterShape, 1);                 \
    uint32_t filterWidth = getSizeOfDimension(filterShape, 2);                  \
    uint32_t outputHeight = getSizeOfDimension(outputShape, 1);                 \
    uint32_t outputWidth = getSizeOfDimension(outputShape, 2);                  \
    uint32_t outputDepth = getSizeOfDimension(outputShape, 3);                  \
    int32_t paddingLeft = param.paddingLeft;                                    \
    [[maybe_unused]] int32_t paddingRight = param.paddingRight;                 \
    int32_t paddingTop = param.paddingTop;                                      \
    [[maybe_unused]] int32_t paddingBottom = param.paddingBottom;               \
    int32_t strideWidth = param.strideWidth, strideHeight = param.strideHeight; \
    int32_t activation = param.activation;
121 
transposeConvNhwc(const float * inputData,const Shape & inputShape,const float * filterData,const Shape & filterShape,const float * biasData,const Shape &,const TransposeConv2dParam & param,float * outputData,const Shape & outputShape)122 bool transposeConvNhwc(const float* inputData, const Shape& inputShape, const float* filterData,
123                        const Shape& filterShape, const float* biasData, const Shape& /*biasShape*/,
124                        const TransposeConv2dParam& param, float* outputData,
125                        const Shape& outputShape) {
126     NNTRACE_TRANS("transposeConvFloat32");
127     ANDROID_NN_TRANSPOSE_CONV_PARAMETERS
128 
129     float outputActivationMin = 0.0f, outputActivationMax = 0.0f;
130     CalculateActivationRangeFloat(activation, &outputActivationMin, &outputActivationMax);
131 
132     memset(outputData, 0, getNumberOfElements(outputShape) * sizeof(float));
133 
134     const float* inputBase = inputData;
135     float* outputBase = outputData;
136     for (uint32_t b = 0; b < numBatches; b++) {
137         for (uint32_t h = 0; h < inputHeight; h++) {
138             for (uint32_t w = 0; w < inputWidth; w++) {
139                 int32_t wOutputOrigin = static_cast<int32_t>(w) * strideWidth - paddingLeft;
140                 int32_t hOutputOrigin = static_cast<int32_t>(h) * strideHeight - paddingTop;
141 
142                 const float* filterBase = filterData;
143                 for (uint32_t k = 0; k < outputDepth; k++) {
144                     for (uint32_t i = 0; i < filterHeight; i++) {
145                         for (uint32_t j = 0; j < filterWidth; j++, filterBase += inputDepth) {
146                             int32_t hOutput = hOutputOrigin + static_cast<int32_t>(i);
147                             int32_t wOutput = wOutputOrigin + static_cast<int32_t>(j);
148                             if (hOutput >= 0 && hOutput < static_cast<int32_t>(outputHeight) &&
149                                 wOutput >= 0 && wOutput < static_cast<int32_t>(outputWidth)) {
150                                 for (uint32_t d = 0; d < inputDepth; d++) {
151                                     uint32_t outputIndex = hOutput * outputWidth * outputDepth +
152                                                            wOutput * outputDepth + k;
153                                     outputBase[outputIndex] += inputBase[d] * filterBase[d];
154                                 }
155                             }
156                         }
157                     }
158                 }
159 
160                 inputBase += inputDepth;
161             }
162         }
163         outputBase += outputHeight * outputWidth * outputDepth;
164     }
165 
166     const uint32_t outerSize = numBatches * outputHeight * outputWidth;
167     float* outPtr = outputData;
168     for (uint32_t i = 0; i < outerSize; i++) {
169         for (uint32_t d = 0; d < outputDepth; d++, outPtr++) {
170             *outPtr += biasData[d];
171             *outPtr = std::max(std::min(*outPtr, outputActivationMax), outputActivationMin);
172         }
173     }
174 
175     return true;
176 }
177 
178 template <typename T>
transposeConvNhwc(const T * inputData,const Shape & inputShape,const T * filterData,const Shape & filterShape,const int32_t * biasData,const Shape & biasShape,const TransposeConv2dParam & param,T * outputData,const Shape & outputShape)179 bool transposeConvNhwc(const T* inputData, const Shape& inputShape, const T* filterData,
180                        const Shape& filterShape, const int32_t* biasData, const Shape& biasShape,
181                        const TransposeConv2dParam& param, T* outputData, const Shape& outputShape) {
182     NNTRACE_TRANS("transposeConvQuant8");
183     ANDROID_NN_TRANSPOSE_CONV_PARAMETERS
184 
185     int32_t* tempBuffer = nullptr;
186     std::unique_ptr<int32_t[]> bufferGuard;
187     uint32_t tempBufferByteSize = getNumberOfElements(outputShape) * sizeof(int32_t);
188     if (tempBufferByteSize <= kStaticBufferSize) {
189         tempBuffer = reinterpret_cast<int32_t*>(static_scratch_buffer);
190     } else {
191         tempBuffer = new (std::nothrow) int32_t[tempBufferByteSize / sizeof(int32_t)];
192         if (tempBuffer == nullptr) {
193             LOG(ERROR) << "ConvTranspose size is too large, not enough memory";
194             return false;
195         }
196         bufferGuard.reset(tempBuffer);
197     }
198 
199     int32_t inputOffset = -inputShape.offset;
200     int32_t filterOffset = -filterShape.offset;
201     int32_t outputOffset = outputShape.offset;
202 
203     double realMultiplier = 0.0;
204     int32_t outputMultiplier = 0;
205     int32_t outputShift = 0;
206     NN_RET_CHECK(GetQuantizedConvolutionMultiplier(inputShape, filterShape, biasShape, outputShape,
207                                                    &realMultiplier));
208     int exponent;
209     NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &exponent));
210     outputShift = -exponent;
211 
212     int32_t outputActivationMin = 0, outputActivationMax = 0;
213     CalculateActivationRange<T>(activation, outputShape, &outputActivationMin,
214                                 &outputActivationMax);
215 
216     // Prevent concurrent executions that may access the scratch buffer
217     std::unique_lock<std::mutex> lock(executionMutex);
218     memset(tempBuffer, 0, tempBufferByteSize);
219 
220     const T* inputPtr = inputData;
221     int32_t* outputBase = tempBuffer;
222     for (uint32_t b = 0; b < numBatches; b++) {
223         for (uint32_t h = 0; h < inputHeight; h++) {
224             for (uint32_t w = 0; w < inputWidth; w++) {
225                 for (uint32_t d = 0; d < inputDepth; d++) {
226                     int32_t wOutputOrigin = static_cast<int32_t>(w) * strideWidth - paddingLeft;
227                     int32_t hOutputOrigin = static_cast<int32_t>(h) * strideHeight - paddingTop;
228 
229                     for (uint32_t i = 0; i < filterHeight; i++) {
230                         for (uint32_t j = 0; j < filterWidth; j++) {
231                             for (uint32_t k = 0; k < outputDepth; k++) {
232                                 int32_t hOutput = hOutputOrigin + static_cast<int32_t>(i);
233                                 int32_t wOutput = wOutputOrigin + static_cast<int32_t>(j);
234                                 if (hOutput >= 0 && hOutput < static_cast<int32_t>(outputHeight) &&
235                                     wOutput >= 0 && wOutput < static_cast<int32_t>(outputWidth)) {
236                                     uint32_t filterIndex =
237                                             k * filterHeight * filterWidth * inputDepth +
238                                             i * filterWidth * inputDepth + j * inputDepth + d;
239                                     uint32_t outputIndex = hOutput * outputWidth * outputDepth +
240                                                            wOutput * outputDepth + k;
241                                     outputBase[outputIndex] +=
242                                             (static_cast<int32_t>(*inputPtr) + inputOffset) *
243                                             (static_cast<int32_t>(filterData[filterIndex]) +
244                                              filterOffset);
245                                 }
246                             }
247                         }
248                     }
249 
250                     inputPtr++;
251                 }
252             }
253         }
254         outputBase += outputHeight * outputWidth * outputDepth;
255     }
256 
257     const uint32_t outerSize = numBatches * outputHeight * outputWidth;
258     int32_t* bufferPtr = tempBuffer;
259     T* outPtr = outputData;
260     for (uint32_t i = 0; i < outerSize; i++) {
261         for (uint32_t d = 0; d < outputDepth; d++, bufferPtr++, outPtr++) {
262             int32_t outVal = *bufferPtr + biasData[d];
263             outVal = tflite::MultiplyByQuantizedMultiplier(outVal, outputMultiplier, -outputShift);
264             outVal += outputOffset;
265             outVal = std::max(std::min(outVal, outputActivationMax), outputActivationMin);
266             *outPtr = static_cast<T>(outVal);
267         }
268     }
269 
270     return true;
271 }
272 
transposeConvNhwc(const _Float16 * inputData,const Shape & inputShape,const _Float16 * filterData,const Shape & filterShape,const _Float16 * biasData,const Shape & biasShape,const TransposeConv2dParam & param,_Float16 * outputData,const Shape & outputShape)273 bool transposeConvNhwc(const _Float16* inputData, const Shape& inputShape,
274                        const _Float16* filterData, const Shape& filterShape,
275                        const _Float16* biasData, const Shape& biasShape,
276                        const TransposeConv2dParam& param, _Float16* outputData,
277                        const Shape& outputShape) {
278     NNTRACE_TRANS("transposeConvFloat16");
279     std::vector<float> inputData_float32(getNumberOfElements(inputShape));
280     std::vector<float> filterData_float32(getNumberOfElements(filterShape));
281     std::vector<float> biasData_float32(getNumberOfElements(biasShape));
282     std::vector<float> outputData_float32(getNumberOfElements(outputShape));
283 
284     convertFloat16ToFloat32(inputData, &inputData_float32);
285     convertFloat16ToFloat32(filterData, &filterData_float32);
286     convertFloat16ToFloat32(biasData, &biasData_float32);
287 
288     transposeConvNhwc(inputData_float32.data(), inputShape, filterData_float32.data(), filterShape,
289                       biasData_float32.data(), biasShape, param, outputData_float32.data(),
290                       outputShape);
291     convertFloat32ToFloat16(outputData_float32, outputData);
292 
293     return true;
294 }
295 
296 template <typename T_Input, typename T_Filter, typename T_Bias>
transposeConv(const T_Input * inputData,const Shape & inputShape,const T_Filter * filterData,const Shape & filterShape,const T_Bias * biasData,const Shape & biasShape,const TransposeConv2dParam & param,T_Input * outputData,const Shape & outputShape)297 bool transposeConv(const T_Input* inputData, const Shape& inputShape, const T_Filter* filterData,
298                    const Shape& filterShape, const T_Bias* biasData, const Shape& biasShape,
299                    const TransposeConv2dParam& param, T_Input* outputData,
300                    const Shape& outputShape) {
301     InputWithLayout<T_Input> input(param.useNchw);
302     OutputWithLayout<T_Input> output(param.useNchw);
303     NN_RET_CHECK(input.initialize(inputData, inputShape));
304     NN_RET_CHECK(output.initialize(outputData, outputShape));
305     NN_RET_CHECK(transposeConvNhwc(input.getNhwcBuffer(), input.getNhwcShape(), filterData,
306                                    filterShape, biasData, biasShape, param, output.getNhwcBuffer(),
307                                    output.getNhwcShape()));
308     NN_RET_CHECK(output.commit());
309     return true;
310 }
311 
// Reference NHWC implementation of TRANSPOSE_CONV_2D with a per-channel
// quantized filter (TENSOR_QUANT8_SYMM_PER_CHANNEL): each output channel k has
// its own filter scale filterScales[k], so the requantization multiplier and
// shift are computed per output channel.
template <typename T>
bool transposeConvQuant8PerChannelNhwc(const T* inputData, const Shape& inputShape,
                                       const int8_t* filterData, const Shape& filterShape,
                                       const float* filterScales, const int32_t* biasData,
                                       const Shape& biasShape, const TransposeConv2dParam& param,
                                       T* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("transposeConvQuant8PerChannel");
    ANDROID_NN_TRANSPOSE_CONV_PARAMETERS

    // Accumulate in int32: use the static scratch buffer when it fits,
    // otherwise heap-allocate (owned by bufferGuard).
    int32_t* tempBuffer = nullptr;
    std::unique_ptr<int32_t[]> bufferGuard;
    uint32_t tempBufferByteSize = getNumberOfElements(outputShape) * sizeof(int32_t);
    if (tempBufferByteSize <= kStaticBufferSize) {
        tempBuffer = reinterpret_cast<int32_t*>(static_scratch_buffer);
    } else {
        tempBuffer = new (std::nothrow) int32_t[tempBufferByteSize / sizeof(int32_t)];
        if (tempBuffer == nullptr) {
            LOG(ERROR) << "ConvTranspose size is too large, not enough memory";
            return false;
        }
        bufferGuard.reset(tempBuffer);
    }

    // A symmetric per-channel filter has no zero point, so only the input and
    // output offsets are needed here.
    int32_t inputOffset = -inputShape.offset;
    int32_t outputOffset = outputShape.offset;

    // Compute a fixed-point (multiplier, shift) pair per output channel from
    // inputScale * filterScales[i] / outputScale.
    std::vector<double> realMultiplier(outputDepth, 0.0);
    std::vector<int32_t> outputMultiplier(outputDepth, 0);
    std::vector<int32_t> outputShift(outputDepth, 0);
    for (uint32_t i = 0; i < outputDepth; ++i) {
        Shape filterChannelShape = filterShape;
        filterChannelShape.scale = filterScales[i];
        Shape biasChannelShape = biasShape;
        biasChannelShape.scale = filterScales[i] * inputShape.scale;

        NN_RET_CHECK(GetQuantizedConvolutionMultiplier(
                inputShape, filterChannelShape, biasChannelShape, outputShape, &realMultiplier[i]));
        int exponent;
        NN_RET_CHECK(QuantizeMultiplier(realMultiplier[i], &outputMultiplier[i], &exponent));
        outputShift[i] = -exponent;
    }

    int32_t outputActivationMin = 0, outputActivationMax = 0;
    CalculateActivationRange<T>(activation, outputShape, &outputActivationMin,
                                &outputActivationMax);

    // Prevent concurrent executions that may access the scratch buffer
    std::unique_lock<std::mutex> lock(executionMutex);
    memset(tempBuffer, 0, tempBufferByteSize);

    // Scatter pass: accumulate each input value times the filter into the
    // overlapping int32 output window.
    const T* inputPtr = inputData;
    int32_t* outputBase = tempBuffer;
    for (uint32_t b = 0; b < numBatches; b++) {
        for (uint32_t h = 0; h < inputHeight; h++) {
            for (uint32_t w = 0; w < inputWidth; w++) {
                for (uint32_t d = 0; d < inputDepth; d++) {
                    int32_t wOutputOrigin = static_cast<int32_t>(w) * strideWidth - paddingLeft;
                    int32_t hOutputOrigin = static_cast<int32_t>(h) * strideHeight - paddingTop;

                    for (uint32_t i = 0; i < filterHeight; i++) {
                        for (uint32_t j = 0; j < filterWidth; j++) {
                            for (uint32_t k = 0; k < outputDepth; k++) {
                                int32_t hOutput = hOutputOrigin + static_cast<int32_t>(i);
                                int32_t wOutput = wOutputOrigin + static_cast<int32_t>(j);
                                if (hOutput >= 0 && hOutput < static_cast<int32_t>(outputHeight) &&
                                    wOutput >= 0 && wOutput < static_cast<int32_t>(outputWidth)) {
                                    // Filter layout is [depthOut, height, width, depthIn].
                                    uint32_t filterIndex =
                                            k * filterHeight * filterWidth * inputDepth +
                                            i * filterWidth * inputDepth + j * inputDepth + d;
                                    uint32_t outputIndex = hOutput * outputWidth * outputDepth +
                                                           wOutput * outputDepth + k;
                                    outputBase[outputIndex] +=
                                            (static_cast<int32_t>(*inputPtr) + inputOffset) *
                                            static_cast<int32_t>(filterData[filterIndex]);
                                }
                            }
                        }
                    }

                    inputPtr++;
                }
            }
        }
        outputBase += outputHeight * outputWidth * outputDepth;
    }

    // Requantization pass: bias add, per-channel fixed-point scale, output
    // offset, and activation clamp.
    const uint32_t outerSize = numBatches * outputHeight * outputWidth;
    int32_t* bufferPtr = tempBuffer;
    T* outPtr = outputData;
    for (uint32_t i = 0; i < outerSize; i++) {
        for (uint32_t d = 0; d < outputDepth; d++, bufferPtr++, outPtr++) {
            int32_t outVal = *bufferPtr + biasData[d];
            outVal = tflite::MultiplyByQuantizedMultiplier(outVal, outputMultiplier[d],
                                                           -outputShift[d]);
            outVal += outputOffset;
            outVal = std::max(std::min(outVal, outputActivationMax), outputActivationMin);
            *outPtr = static_cast<T>(outVal);
        }
    }

    return true;
}
414 
415 template <typename T>
transposeConvQuant8PerChannel(const T * inputData,const Shape & inputShape,const int8_t * filterData,const Shape & filterShape,const float * filterScales,const int32_t * biasData,const Shape & biasShape,const TransposeConv2dParam & param,T * outputData,const Shape & outputShape)416 bool transposeConvQuant8PerChannel(const T* inputData, const Shape& inputShape,
417                                    const int8_t* filterData, const Shape& filterShape,
418                                    const float* filterScales, const int32_t* biasData,
419                                    const Shape& biasShape, const TransposeConv2dParam& param,
420                                    T* outputData, const Shape& outputShape) {
421     InputWithLayout<T> input(param.useNchw);
422     OutputWithLayout<T> output(param.useNchw);
423     NN_RET_CHECK(input.initialize(inputData, inputShape));
424     NN_RET_CHECK(output.initialize(outputData, outputShape));
425     NN_RET_CHECK(transposeConvQuant8PerChannelNhwc(
426             input.getNhwcBuffer(), input.getNhwcShape(), filterData, filterShape, filterScales,
427             biasData, biasShape, param, output.getNhwcBuffer(), output.getNhwcShape()));
428     NN_RET_CHECK(output.commit());
429     return true;
430 }
431 
432 #undef ANDROID_NN_TRANSPOSE_CONV_PARAMETERS
433 
434 }  // namespace
435 
prepare(IOperationExecutionContext * context)436 bool prepare(IOperationExecutionContext* context) {
437     Shape input = context->getInputShape(kInputTensor);
438     Shape filter = context->getInputShape(kFilterTensor);
439     Shape bias = context->getInputShape(kBiasTensor);
440 
441     if (filter.type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
442         NN_RET_CHECK(input.type == OperandType::TENSOR_QUANT8_ASYMM ||
443                      input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
444     } else {
445         NN_RET_CHECK(input.type == filter.type);
446     }
447     if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
448         input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
449         NN_RET_CHECK(bias.type == OperandType::TENSOR_INT32);
450     } else {
451         NN_RET_CHECK(input.type == bias.type);
452     }
453     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4u);
454     NN_RET_CHECK_EQ(getNumberOfDimensions(filter), 4u);
455     NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1u);
456 
457     TransposeConv2dParam param;
458     NN_RET_CHECK(param.initialize(context));
459 
460     uint32_t batches = getSizeOfDimension(input, 0);
461     uint32_t height = getSizeOfDimension(input, param.useNchw ? 2 : 1);
462     uint32_t width = getSizeOfDimension(input, param.useNchw ? 3 : 2);
463     uint32_t channels_in = getSizeOfDimension(input, param.useNchw ? 1 : 3);
464     uint32_t channels_out = getSizeOfDimension(filter, 0);
465     uint32_t filterHeight = getSizeOfDimension(filter, 1);
466     uint32_t filterWidth = getSizeOfDimension(filter, 2);
467     // Only batches can be zero.
468     NN_RET_CHECK_EQ(channels_in, getSizeOfDimension(filter, 3));
469     NN_RET_CHECK_EQ(channels_out, getSizeOfDimension(bias, 0));
470     NN_RET_CHECK_GT(height, 0u);
471     NN_RET_CHECK_GT(width, 0u);
472     NN_RET_CHECK_GT(channels_in, 0u);
473     NN_RET_CHECK_GT(channels_out, 0u);
474     NN_RET_CHECK_GT(filterWidth, 0u);
475     NN_RET_CHECK_GT(filterHeight, 0u);
476 
477     uint32_t outWidth = computeOutSizeTransposeConv(width, filterWidth, param.strideWidth,
478                                                     param.paddingLeft, param.paddingRight);
479     uint32_t outHeight = computeOutSizeTransposeConv(height, filterHeight, param.strideHeight,
480                                                      param.paddingTop, param.paddingBottom);
481     NN_RET_CHECK_GT(outWidth, 0u);
482     NN_RET_CHECK_GT(outHeight, 0u);
483 
484     Shape output = context->getOutputShape(kOutputTensor);
485     output.type = input.type;
486     if (param.useNchw) {
487         output.dimensions = {batches, channels_out, outHeight, outWidth};
488     } else {
489         output.dimensions = {batches, outHeight, outWidth, channels_out};
490     }
491     return context->setOutputShape(kOutputTensor, output);
492 }
493 
execute(IOperationExecutionContext * context)494 bool execute(IOperationExecutionContext* context) {
495     // Bypass execution in the case of zero-sized input.
496     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
497     TransposeConv2dParam param;
498     NN_RET_CHECK(param.initialize(context));
499     switch (context->getInputType(kInputTensor)) {
500         case OperandType::TENSOR_FLOAT32:
501             return transposeConv(context->getInputBuffer<float>(kInputTensor),
502                                  context->getInputShape(kInputTensor),
503                                  context->getInputBuffer<float>(kFilterTensor),
504                                  context->getInputShape(kFilterTensor),
505                                  context->getInputBuffer<float>(kBiasTensor),
506                                  context->getInputShape(kBiasTensor), param,
507                                  context->getOutputBuffer<float>(kOutputTensor),
508                                  context->getOutputShape(kOutputTensor));
509         case OperandType::TENSOR_FLOAT16:
510             return transposeConv(context->getInputBuffer<_Float16>(kInputTensor),
511                                  context->getInputShape(kInputTensor),
512                                  context->getInputBuffer<_Float16>(kFilterTensor),
513                                  context->getInputShape(kFilterTensor),
514                                  context->getInputBuffer<_Float16>(kBiasTensor),
515                                  context->getInputShape(kBiasTensor), param,
516                                  context->getOutputBuffer<_Float16>(kOutputTensor),
517                                  context->getOutputShape(kOutputTensor));
518         case OperandType::TENSOR_QUANT8_ASYMM:
519             if (context->getInputType(kFilterTensor) ==
520                 OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
521                 return transposeConvQuant8PerChannel(
522                         context->getInputBuffer<uint8_t>(kInputTensor),
523                         context->getInputShape(kInputTensor),
524                         context->getInputBuffer<int8_t>(kFilterTensor),
525                         context->getInputShape(kFilterTensor),
526                         std::get<Operand::SymmPerChannelQuantParams>(
527                                 context->getInputExtraParams(kFilterTensor))
528                                 .scales.data(),
529                         context->getInputBuffer<int32_t>(kBiasTensor),
530                         context->getInputShape(kBiasTensor), param,
531                         context->getOutputBuffer<uint8_t>(kOutputTensor),
532                         context->getOutputShape(kOutputTensor));
533             } else if (context->getInputType(kFilterTensor) == OperandType::TENSOR_QUANT8_ASYMM) {
534                 return transposeConv(context->getInputBuffer<uint8_t>(kInputTensor),
535                                      context->getInputShape(kInputTensor),
536                                      context->getInputBuffer<uint8_t>(kFilterTensor),
537                                      context->getInputShape(kFilterTensor),
538                                      context->getInputBuffer<int32_t>(kBiasTensor),
539                                      context->getInputShape(kBiasTensor), param,
540                                      context->getOutputBuffer<uint8_t>(kOutputTensor),
541                                      context->getOutputShape(kOutputTensor));
542             } else {
543                 NN_RET_CHECK_FAIL() << "Unsupported filter type for operation " << kOperationName;
544             }
545         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
546             if (context->getInputType(kFilterTensor) ==
547                 OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
548                 return transposeConvQuant8PerChannel(
549                         context->getInputBuffer<int8_t>(kInputTensor),
550                         context->getInputShape(kInputTensor),
551                         context->getInputBuffer<int8_t>(kFilterTensor),
552                         context->getInputShape(kFilterTensor),
553                         std::get<Operand::SymmPerChannelQuantParams>(
554                                 context->getInputExtraParams(kFilterTensor))
555                                 .scales.data(),
556                         context->getInputBuffer<int32_t>(kBiasTensor),
557                         context->getInputShape(kBiasTensor), param,
558                         context->getOutputBuffer<int8_t>(kOutputTensor),
559                         context->getOutputShape(kOutputTensor));
560             } else if (context->getInputType(kFilterTensor) ==
561                        OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
562                 return transposeConv(context->getInputBuffer<int8_t>(kInputTensor),
563                                      context->getInputShape(kInputTensor),
564                                      context->getInputBuffer<int8_t>(kFilterTensor),
565                                      context->getInputShape(kFilterTensor),
566                                      context->getInputBuffer<int32_t>(kBiasTensor),
567                                      context->getInputShape(kBiasTensor), param,
568                                      context->getOutputBuffer<int8_t>(kOutputTensor),
569                                      context->getOutputShape(kOutputTensor));
570             } else {
571                 NN_RET_CHECK_FAIL() << "Unsupported filter type for operation " << kOperationName;
572             }
573         default:
574             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
575     }
576 }
577 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
578 
579 }  // namespace transpose_conv_2d
580 
581 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(TRANSPOSE_CONV_2D, transpose_conv_2d::prepare,
582                                          transpose_conv_2d::execute, .allowZeroSizedInput = true);
583 
584 }  // namespace nn
585 }  // namespace android
586