/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "TransposeConv2D.h"

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstring>
#include <memory>
#include <mutex>
#include <vector>

#include "OperationResolver.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#include <tensorflow/lite/kernels/internal/common.h>
#pragma clang diagnostic pop

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace transpose_conv_2d {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

// If possible, we will use this static buffer for the temporary tensor.
constexpr size_t kStaticBufferSize = 1605632;
[[maybe_unused]] char static_scratch_buffer[kStaticBufferSize];

// executionMutex is used to protect concurrent access of the static_scratch_buffer.
// std::mutex is safe for pthreads on Android.
std::mutex executionMutex;
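// Sizing note (an observation, not from the original authors): 1605632 bytes
// holds 401,408 int32 accumulators, e.g. a 224x224 feature map with 8 output
// channels; any output larger than that falls back to a heap allocation in
// the quantized kernels below.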

struct TransposeConv2dParam {
    int32_t paddingLeft, paddingRight;
    int32_t paddingTop, paddingBottom;
    int32_t strideWidth, strideHeight;
    int32_t activation;
    bool useNchw = false;

    bool initialize(const IOperationExecutionContext* context) {
        uint32_t inCount = context->getNumInputs();
        int32_t paddingImplicit = 0;
        if (inCount == 9) {
            // Implicit padding signature: input 3 holds the explicit output shape;
            // inputs 4..8 are the padding scheme, strides, activation, and layout.
            paddingImplicit = context->getInputValue<int32_t>(4);
            strideWidth = context->getInputValue<int32_t>(5);
            strideHeight = context->getInputValue<int32_t>(6);
            activation = context->getInputValue<int32_t>(7);
            useNchw = context->getInputValue<bool>(8);
            Shape filterShape = context->getInputShape(kFilterTensor);
            int32_t filterWidth = getSizeOfDimension(filterShape, 2);
            int32_t filterHeight = getSizeOfDimension(filterShape, 1);
            NN_RET_CHECK_EQ(getNumberOfDimensions(context->getInputShape(3)), 1u);
            NN_RET_CHECK_EQ(getSizeOfDimension(context->getInputShape(3), 0), 4u);
            const int32_t* outputShapeData = context->getInputBuffer<int32_t>(3);
            int32_t outputWidth = useNchw ? outputShapeData[3] : outputShapeData[2];
            int32_t outputHeight = useNchw ? outputShapeData[2] : outputShapeData[1];
            calculateExplicitPaddingTransposeConv(outputWidth, strideWidth, filterWidth,
                                                  paddingImplicit, &paddingLeft, &paddingRight);
            calculateExplicitPaddingTransposeConv(outputHeight, strideHeight, filterHeight,
                                                  paddingImplicit, &paddingTop, &paddingBottom);
        } else if (inCount == 11) {
            // Explicit padding signature: inputs 3..10 are the four paddings,
            // strides, activation, and layout.
            paddingLeft = context->getInputValue<int32_t>(3);
            paddingRight = context->getInputValue<int32_t>(4);
            paddingTop = context->getInputValue<int32_t>(5);
            paddingBottom = context->getInputValue<int32_t>(6);
            strideWidth = context->getInputValue<int32_t>(7);
            strideHeight = context->getInputValue<int32_t>(8);
            activation = context->getInputValue<int32_t>(9);
            useNchw = context->getInputValue<bool>(10);
        } else {
            NN_RET_CHECK_FAIL() << "Unsupported input spec for operation " << kOperationName;
        }
        // paddingRight and paddingBottom in transpose conv may be less than 0 to resolve the
        // ambiguous output shape issue in the case of stride > 1.
        NN_RET_CHECK_GE(paddingLeft, 0);
        NN_RET_CHECK_GE(paddingTop, 0);
        NN_RET_CHECK_GT(strideWidth, 0);
        NN_RET_CHECK_GT(strideHeight, 0);
        NN_RET_CHECK_GE(activation, 0);
        return true;
    }
};

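// The helper macro below unpacks NHWC tensor dimensions (batch, height, width,
// depth) plus the stride/padding/activation parameters into local variables.
// Callers pass NHWC-shaped tensors; NCHW inputs are converted beforehand by
// the transposeConv* wrappers further down.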
#define ANDROID_NN_TRANSPOSE_CONV_PARAMETERS                                    \
    uint32_t numBatches = getSizeOfDimension(inputShape, 0);                    \
    uint32_t inputHeight = getSizeOfDimension(inputShape, 1);                   \
    uint32_t inputWidth = getSizeOfDimension(inputShape, 2);                    \
    uint32_t inputDepth = getSizeOfDimension(inputShape, 3);                    \
    uint32_t filterHeight = getSizeOfDimension(filterShape, 1);                 \
    uint32_t filterWidth = getSizeOfDimension(filterShape, 2);                  \
    uint32_t outputHeight = getSizeOfDimension(outputShape, 1);                 \
    uint32_t outputWidth = getSizeOfDimension(outputShape, 2);                  \
    uint32_t outputDepth = getSizeOfDimension(outputShape, 3);                  \
    int32_t paddingLeft = param.paddingLeft;                                    \
    [[maybe_unused]] int32_t paddingRight = param.paddingRight;                 \
    int32_t paddingTop = param.paddingTop;                                      \
    [[maybe_unused]] int32_t paddingBottom = param.paddingBottom;               \
    int32_t strideWidth = param.strideWidth, strideHeight = param.strideHeight; \
    int32_t activation = param.activation;

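// transposeConvNhwc implements transpose convolution as a scatter: each input
// element at (b, h, w) is multiplied by the whole filter and accumulated into
// a filterHeight x filterWidth window of the output anchored at
// (h * strideHeight - paddingTop, w * strideWidth - paddingLeft), the standard
// "gradient of Conv2D" formulation. Bias addition and activation clamping
// happen in a second pass over the output.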
bool transposeConvNhwc(const float* inputData, const Shape& inputShape, const float* filterData,
                       const Shape& filterShape, const float* biasData, const Shape& /*biasShape*/,
                       const TransposeConv2dParam& param, float* outputData,
                       const Shape& outputShape) {
    NNTRACE_TRANS("transposeConvFloat32");
    ANDROID_NN_TRANSPOSE_CONV_PARAMETERS

    float outputActivationMin = 0.0f, outputActivationMax = 0.0f;
    CalculateActivationRangeFloat(activation, &outputActivationMin, &outputActivationMax);

    memset(outputData, 0, getNumberOfElements(outputShape) * sizeof(float));

    const float* inputBase = inputData;
    float* outputBase = outputData;
    for (uint32_t b = 0; b < numBatches; b++) {
        for (uint32_t h = 0; h < inputHeight; h++) {
            for (uint32_t w = 0; w < inputWidth; w++) {
                int32_t wOutputOrigin = static_cast<int32_t>(w) * strideWidth - paddingLeft;
                int32_t hOutputOrigin = static_cast<int32_t>(h) * strideHeight - paddingTop;

                const float* filterBase = filterData;
                for (uint32_t k = 0; k < outputDepth; k++) {
                    for (uint32_t i = 0; i < filterHeight; i++) {
                        for (uint32_t j = 0; j < filterWidth; j++, filterBase += inputDepth) {
                            int32_t hOutput = hOutputOrigin + static_cast<int32_t>(i);
                            int32_t wOutput = wOutputOrigin + static_cast<int32_t>(j);
                            if (hOutput >= 0 && hOutput < static_cast<int32_t>(outputHeight) &&
                                wOutput >= 0 && wOutput < static_cast<int32_t>(outputWidth)) {
                                for (uint32_t d = 0; d < inputDepth; d++) {
                                    uint32_t outputIndex = hOutput * outputWidth * outputDepth +
                                                           wOutput * outputDepth + k;
                                    outputBase[outputIndex] += inputBase[d] * filterBase[d];
                                }
                            }
                        }
                    }
                }

                inputBase += inputDepth;
            }
        }
        outputBase += outputHeight * outputWidth * outputDepth;
    }

    const uint32_t outerSize = numBatches * outputHeight * outputWidth;
    float* outPtr = outputData;
    for (uint32_t i = 0; i < outerSize; i++) {
        for (uint32_t d = 0; d < outputDepth; d++, outPtr++) {
            *outPtr += biasData[d];
            *outPtr = std::max(std::min(*outPtr, outputActivationMax), outputActivationMin);
        }
    }

    return true;
}

template <typename T>
bool transposeConvNhwc(const T* inputData, const Shape& inputShape, const T* filterData,
                       const Shape& filterShape, const int32_t* biasData, const Shape& biasShape,
                       const TransposeConv2dParam& param, T* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("transposeConvQuant8");
    ANDROID_NN_TRANSPOSE_CONV_PARAMETERS

    int32_t* tempBuffer = nullptr;
    std::unique_ptr<int32_t[]> bufferGuard;
    uint32_t tempBufferByteSize = getNumberOfElements(outputShape) * sizeof(int32_t);
    if (tempBufferByteSize <= kStaticBufferSize) {
        tempBuffer = reinterpret_cast<int32_t*>(static_scratch_buffer);
    } else {
        tempBuffer = new (std::nothrow) int32_t[tempBufferByteSize / sizeof(int32_t)];
        if (tempBuffer == nullptr) {
            LOG(ERROR) << "ConvTranspose size is too large, not enough memory";
            return false;
        }
        bufferGuard.reset(tempBuffer);
    }

    int32_t inputOffset = -inputShape.offset;
    int32_t filterOffset = -filterShape.offset;
    int32_t outputOffset = outputShape.offset;

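    // Requantization (a summary of the helpers used below, following the usual
    // NNAPI convention): GetQuantizedConvolutionMultiplier computes
    //   realMultiplier = inputScale * filterScale / outputScale,
    // and QuantizeMultiplier expresses it as a 32-bit fixed-point mantissa
    // times 2^exponent, so each int32 accumulator is scaled by
    // tflite::MultiplyByQuantizedMultiplier(acc, outputMultiplier, -outputShift).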
    double realMultiplier = 0.0;
    int32_t outputMultiplier = 0;
    int32_t outputShift = 0;
    NN_RET_CHECK(GetQuantizedConvolutionMultiplier(inputShape, filterShape, biasShape, outputShape,
                                                   &realMultiplier));
    int exponent;
    NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &exponent));
    outputShift = -exponent;

    int32_t outputActivationMin = 0, outputActivationMax = 0;
    CalculateActivationRange<T>(activation, outputShape, &outputActivationMin,
                                &outputActivationMax);

    // Prevent concurrent executions that may access the scratch buffer.
    std::unique_lock<std::mutex> lock(executionMutex);
    memset(tempBuffer, 0, tempBufferByteSize);

    const T* inputPtr = inputData;
    int32_t* outputBase = tempBuffer;
    for (uint32_t b = 0; b < numBatches; b++) {
        for (uint32_t h = 0; h < inputHeight; h++) {
            for (uint32_t w = 0; w < inputWidth; w++) {
                for (uint32_t d = 0; d < inputDepth; d++) {
                    int32_t wOutputOrigin = static_cast<int32_t>(w) * strideWidth - paddingLeft;
                    int32_t hOutputOrigin = static_cast<int32_t>(h) * strideHeight - paddingTop;

                    for (uint32_t i = 0; i < filterHeight; i++) {
                        for (uint32_t j = 0; j < filterWidth; j++) {
                            for (uint32_t k = 0; k < outputDepth; k++) {
                                int32_t hOutput = hOutputOrigin + static_cast<int32_t>(i);
                                int32_t wOutput = wOutputOrigin + static_cast<int32_t>(j);
                                if (hOutput >= 0 && hOutput < static_cast<int32_t>(outputHeight) &&
                                    wOutput >= 0 && wOutput < static_cast<int32_t>(outputWidth)) {
                                    uint32_t filterIndex =
                                            k * filterHeight * filterWidth * inputDepth +
                                            i * filterWidth * inputDepth + j * inputDepth + d;
                                    uint32_t outputIndex = hOutput * outputWidth * outputDepth +
                                                           wOutput * outputDepth + k;
                                    outputBase[outputIndex] +=
                                            (static_cast<int32_t>(*inputPtr) + inputOffset) *
                                            (static_cast<int32_t>(filterData[filterIndex]) +
                                             filterOffset);
                                }
                            }
                        }
                    }

                    inputPtr++;
                }
            }
        }
        outputBase += outputHeight * outputWidth * outputDepth;
    }

    const uint32_t outerSize = numBatches * outputHeight * outputWidth;
    int32_t* bufferPtr = tempBuffer;
    T* outPtr = outputData;
    for (uint32_t i = 0; i < outerSize; i++) {
        for (uint32_t d = 0; d < outputDepth; d++, bufferPtr++, outPtr++) {
            int32_t outVal = *bufferPtr + biasData[d];
            outVal = tflite::MultiplyByQuantizedMultiplier(outVal, outputMultiplier, -outputShift);
            outVal += outputOffset;
            outVal = std::max(std::min(outVal, outputActivationMax), outputActivationMin);
            *outPtr = static_cast<T>(outVal);
        }
    }

    return true;
}

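// The _Float16 variant widens all operands to fp32, reuses the float kernel,
// and converts the result back, so accumulation always happens in full
// precision.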
bool transposeConvNhwc(const _Float16* inputData, const Shape& inputShape,
                       const _Float16* filterData, const Shape& filterShape,
                       const _Float16* biasData, const Shape& biasShape,
                       const TransposeConv2dParam& param, _Float16* outputData,
                       const Shape& outputShape) {
    NNTRACE_TRANS("transposeConvFloat16");
    std::vector<float> inputData_float32(getNumberOfElements(inputShape));
    std::vector<float> filterData_float32(getNumberOfElements(filterShape));
    std::vector<float> biasData_float32(getNumberOfElements(biasShape));
    std::vector<float> outputData_float32(getNumberOfElements(outputShape));

    convertFloat16ToFloat32(inputData, &inputData_float32);
    convertFloat16ToFloat32(filterData, &filterData_float32);
    convertFloat16ToFloat32(biasData, &biasData_float32);

    transposeConvNhwc(inputData_float32.data(), inputShape, filterData_float32.data(), filterShape,
                      biasData_float32.data(), biasShape, param, outputData_float32.data(),
                      outputShape);
    convertFloat32ToFloat16(outputData_float32, outputData);

    return true;
}

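// transposeConv adapts the data layout around the NHWC kernels: the
// InputWithLayout / OutputWithLayout helpers convert NCHW tensors to NHWC on
// the way in and back again on commit(), so the kernels above only ever see
// NHWC data.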
template <typename T_Input, typename T_Filter, typename T_Bias>
bool transposeConv(const T_Input* inputData, const Shape& inputShape, const T_Filter* filterData,
                   const Shape& filterShape, const T_Bias* biasData, const Shape& biasShape,
                   const TransposeConv2dParam& param, T_Input* outputData,
                   const Shape& outputShape) {
    InputWithLayout<T_Input> input(param.useNchw);
    OutputWithLayout<T_Input> output(param.useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(transposeConvNhwc(input.getNhwcBuffer(), input.getNhwcShape(), filterData,
                                   filterShape, biasData, biasShape, param, output.getNhwcBuffer(),
                                   output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

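// Per-channel quantization: the filter is TENSOR_QUANT8_SYMM_PER_CHANNEL, so
// it has a zero offset and a separate scale per output channel. The kernel
// below therefore omits the filter offset from the accumulation and computes
// one requantization multiplier/shift pair per output channel.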
template <typename T>
bool transposeConvQuant8PerChannelNhwc(const T* inputData, const Shape& inputShape,
                                       const int8_t* filterData, const Shape& filterShape,
                                       const float* filterScales, const int32_t* biasData,
                                       const Shape& biasShape, const TransposeConv2dParam& param,
                                       T* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("transposeConvQuant8PerChannel");
    ANDROID_NN_TRANSPOSE_CONV_PARAMETERS

    int32_t* tempBuffer = nullptr;
    std::unique_ptr<int32_t[]> bufferGuard;
    uint32_t tempBufferByteSize = getNumberOfElements(outputShape) * sizeof(int32_t);
    if (tempBufferByteSize <= kStaticBufferSize) {
        tempBuffer = reinterpret_cast<int32_t*>(static_scratch_buffer);
    } else {
        tempBuffer = new (std::nothrow) int32_t[tempBufferByteSize / sizeof(int32_t)];
        if (tempBuffer == nullptr) {
            LOG(ERROR) << "ConvTranspose size is too large, not enough memory";
            return false;
        }
        bufferGuard.reset(tempBuffer);
    }

    int32_t inputOffset = -inputShape.offset;
    int32_t outputOffset = outputShape.offset;

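    // For each output channel, build per-channel Shapes whose scales are the
    // channel's filter scale and the bias scale it implies
    // (inputScale * filterScale), then derive that channel's requantization
    // multiplier and shift.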
    std::vector<double> realMultiplier(outputDepth, 0.0);
    std::vector<int32_t> outputMultiplier(outputDepth, 0);
    std::vector<int32_t> outputShift(outputDepth, 0);
    for (uint32_t i = 0; i < outputDepth; ++i) {
        Shape filterChannelShape = filterShape;
        filterChannelShape.scale = filterScales[i];
        Shape biasChannelShape = biasShape;
        biasChannelShape.scale = filterScales[i] * inputShape.scale;

        NN_RET_CHECK(GetQuantizedConvolutionMultiplier(
                inputShape, filterChannelShape, biasChannelShape, outputShape, &realMultiplier[i]));
        int exponent;
        NN_RET_CHECK(QuantizeMultiplier(realMultiplier[i], &outputMultiplier[i], &exponent));
        outputShift[i] = -exponent;
    }

    int32_t outputActivationMin = 0, outputActivationMax = 0;
    CalculateActivationRange<T>(activation, outputShape, &outputActivationMin,
                                &outputActivationMax);

    // Prevent concurrent executions that may access the scratch buffer.
    std::unique_lock<std::mutex> lock(executionMutex);
    memset(tempBuffer, 0, tempBufferByteSize);

    const T* inputPtr = inputData;
    int32_t* outputBase = tempBuffer;
    for (uint32_t b = 0; b < numBatches; b++) {
        for (uint32_t h = 0; h < inputHeight; h++) {
            for (uint32_t w = 0; w < inputWidth; w++) {
                for (uint32_t d = 0; d < inputDepth; d++) {
                    int32_t wOutputOrigin = static_cast<int32_t>(w) * strideWidth - paddingLeft;
                    int32_t hOutputOrigin = static_cast<int32_t>(h) * strideHeight - paddingTop;

                    for (uint32_t i = 0; i < filterHeight; i++) {
                        for (uint32_t j = 0; j < filterWidth; j++) {
                            for (uint32_t k = 0; k < outputDepth; k++) {
                                int32_t hOutput = hOutputOrigin + static_cast<int32_t>(i);
                                int32_t wOutput = wOutputOrigin + static_cast<int32_t>(j);
                                if (hOutput >= 0 && hOutput < static_cast<int32_t>(outputHeight) &&
                                    wOutput >= 0 && wOutput < static_cast<int32_t>(outputWidth)) {
                                    uint32_t filterIndex =
                                            k * filterHeight * filterWidth * inputDepth +
                                            i * filterWidth * inputDepth + j * inputDepth + d;
                                    uint32_t outputIndex = hOutput * outputWidth * outputDepth +
                                                           wOutput * outputDepth + k;
                                    outputBase[outputIndex] +=
                                            (static_cast<int32_t>(*inputPtr) + inputOffset) *
                                            static_cast<int32_t>(filterData[filterIndex]);
                                }
                            }
                        }
                    }

                    inputPtr++;
                }
            }
        }
        outputBase += outputHeight * outputWidth * outputDepth;
    }

    const uint32_t outerSize = numBatches * outputHeight * outputWidth;
    int32_t* bufferPtr = tempBuffer;
    T* outPtr = outputData;
    for (uint32_t i = 0; i < outerSize; i++) {
        for (uint32_t d = 0; d < outputDepth; d++, bufferPtr++, outPtr++) {
            int32_t outVal = *bufferPtr + biasData[d];
            outVal = tflite::MultiplyByQuantizedMultiplier(outVal, outputMultiplier[d],
                                                           -outputShift[d]);
            outVal += outputOffset;
            outVal = std::max(std::min(outVal, outputActivationMax), outputActivationMin);
            *outPtr = static_cast<T>(outVal);
        }
    }

    return true;
}

template <typename T>
bool transposeConvQuant8PerChannel(const T* inputData, const Shape& inputShape,
                                   const int8_t* filterData, const Shape& filterShape,
                                   const float* filterScales, const int32_t* biasData,
                                   const Shape& biasShape, const TransposeConv2dParam& param,
                                   T* outputData, const Shape& outputShape) {
    InputWithLayout<T> input(param.useNchw);
    OutputWithLayout<T> output(param.useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(transposeConvQuant8PerChannelNhwc(
            input.getNhwcBuffer(), input.getNhwcShape(), filterData, filterShape, filterScales,
            biasData, biasShape, param, output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

#undef ANDROID_NN_TRANSPOSE_CONV_PARAMETERS

}  // namespace

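// prepare() validates the operand configuration before shape inference: the
// filter type either matches the input type or is
// TENSOR_QUANT8_SYMM_PER_CHANNEL paired with a quant8 input, and quantized
// paths require an int32 bias while float paths require a bias of the input
// type.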
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape filter = context->getInputShape(kFilterTensor);
    Shape bias = context->getInputShape(kBiasTensor);

    if (filter.type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
        NN_RET_CHECK(input.type == OperandType::TENSOR_QUANT8_ASYMM ||
                     input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
    } else {
        NN_RET_CHECK(input.type == filter.type);
    }
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
        input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        NN_RET_CHECK(bias.type == OperandType::TENSOR_INT32);
    } else {
        NN_RET_CHECK(input.type == bias.type);
    }
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(filter), 4u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1u);

    TransposeConv2dParam param;
    NN_RET_CHECK(param.initialize(context));

    uint32_t batches = getSizeOfDimension(input, 0);
    uint32_t height = getSizeOfDimension(input, param.useNchw ? 2 : 1);
    uint32_t width = getSizeOfDimension(input, param.useNchw ? 3 : 2);
    uint32_t channels_in = getSizeOfDimension(input, param.useNchw ? 1 : 3);
    uint32_t channels_out = getSizeOfDimension(filter, 0);
    uint32_t filterHeight = getSizeOfDimension(filter, 1);
    uint32_t filterWidth = getSizeOfDimension(filter, 2);
    // Only batches can be zero.
    NN_RET_CHECK_EQ(channels_in, getSizeOfDimension(filter, 3));
    NN_RET_CHECK_EQ(channels_out, getSizeOfDimension(bias, 0));
    NN_RET_CHECK_GT(height, 0u);
    NN_RET_CHECK_GT(width, 0u);
    NN_RET_CHECK_GT(channels_in, 0u);
    NN_RET_CHECK_GT(channels_out, 0u);
    NN_RET_CHECK_GT(filterWidth, 0u);
    NN_RET_CHECK_GT(filterHeight, 0u);

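    // Output size follows the usual transpose-conv relation (assumed to be
    // what computeOutSizeTransposeConv implements):
    //   outSize = (inSize - 1) * stride + filterSize - paddingHead - paddingTail
    // e.g. inSize = 4, stride = 2, filterSize = 3, zero padding -> outSize = 9.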
    uint32_t outWidth = computeOutSizeTransposeConv(width, filterWidth, param.strideWidth,
                                                    param.paddingLeft, param.paddingRight);
    uint32_t outHeight = computeOutSizeTransposeConv(height, filterHeight, param.strideHeight,
                                                     param.paddingTop, param.paddingBottom);
    NN_RET_CHECK_GT(outWidth, 0u);
    NN_RET_CHECK_GT(outHeight, 0u);

    Shape output = context->getOutputShape(kOutputTensor);
    output.type = input.type;
    if (param.useNchw) {
        output.dimensions = {batches, channels_out, outHeight, outWidth};
    } else {
        output.dimensions = {batches, outHeight, outWidth, channels_out};
    }
    return context->setOutputShape(kOutputTensor, output);
}

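// execute() dispatches on the input tensor type; for quant8 inputs it further
// branches on whether the filter is per-channel
// (TENSOR_QUANT8_SYMM_PER_CHANNEL) or per-tensor quantized.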
bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    TransposeConv2dParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT32:
            return transposeConv(context->getInputBuffer<float>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getInputBuffer<float>(kFilterTensor),
                                 context->getInputShape(kFilterTensor),
                                 context->getInputBuffer<float>(kBiasTensor),
                                 context->getInputShape(kBiasTensor), param,
                                 context->getOutputBuffer<float>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT16:
            return transposeConv(context->getInputBuffer<_Float16>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getInputBuffer<_Float16>(kFilterTensor),
                                 context->getInputShape(kFilterTensor),
                                 context->getInputBuffer<_Float16>(kBiasTensor),
                                 context->getInputShape(kBiasTensor), param,
                                 context->getOutputBuffer<_Float16>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            if (context->getInputType(kFilterTensor) ==
                OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
                return transposeConvQuant8PerChannel(
                        context->getInputBuffer<uint8_t>(kInputTensor),
                        context->getInputShape(kInputTensor),
                        context->getInputBuffer<int8_t>(kFilterTensor),
                        context->getInputShape(kFilterTensor),
                        std::get<Operand::SymmPerChannelQuantParams>(
                                context->getInputExtraParams(kFilterTensor))
                                .scales.data(),
                        context->getInputBuffer<int32_t>(kBiasTensor),
                        context->getInputShape(kBiasTensor), param,
                        context->getOutputBuffer<uint8_t>(kOutputTensor),
                        context->getOutputShape(kOutputTensor));
            } else if (context->getInputType(kFilterTensor) == OperandType::TENSOR_QUANT8_ASYMM) {
                return transposeConv(context->getInputBuffer<uint8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getInputBuffer<uint8_t>(kFilterTensor),
                                     context->getInputShape(kFilterTensor),
                                     context->getInputBuffer<int32_t>(kBiasTensor),
                                     context->getInputShape(kBiasTensor), param,
                                     context->getOutputBuffer<uint8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
            } else {
                NN_RET_CHECK_FAIL() << "Unsupported filter type for operation " << kOperationName;
            }
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            if (context->getInputType(kFilterTensor) ==
                OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
                return transposeConvQuant8PerChannel(
                        context->getInputBuffer<int8_t>(kInputTensor),
                        context->getInputShape(kInputTensor),
                        context->getInputBuffer<int8_t>(kFilterTensor),
                        context->getInputShape(kFilterTensor),
                        std::get<Operand::SymmPerChannelQuantParams>(
                                context->getInputExtraParams(kFilterTensor))
                                .scales.data(),
                        context->getInputBuffer<int32_t>(kBiasTensor),
                        context->getInputShape(kBiasTensor), param,
                        context->getOutputBuffer<int8_t>(kOutputTensor),
                        context->getOutputShape(kOutputTensor));
            } else if (context->getInputType(kFilterTensor) ==
                       OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
                return transposeConv(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getInputBuffer<int8_t>(kFilterTensor),
                                     context->getInputShape(kFilterTensor),
                                     context->getInputBuffer<int32_t>(kBiasTensor),
                                     context->getInputShape(kBiasTensor), param,
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
            } else {
                NN_RET_CHECK_FAIL() << "Unsupported filter type for operation " << kOperationName;
            }
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace transpose_conv_2d

NN_REGISTER_OPERATION_DEFAULT_VALIDATION(TRANSPOSE_CONV_2D, transpose_conv_2d::prepare,
                                         transpose_conv_2d::execute, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android
586