/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "Activation.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

#include "ActivationFunctor.h"
#include "OperationResolver.h"
#include "OperationsExecutionUtils.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Winvalid-partial-specialization"
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h>
#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
#pragma clang diagnostic pop

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {

namespace activation {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

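// Clamps each element to [reluMin, reluMax]; with the default bounds this is
// plain RELU. relu1Float and relu6Float below reuse it with [-1, 1] and
// [0, 6] respectively.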
template <typename T>
bool reluFloat(const T* inputData, const Shape& inputShape, T* outputData,
               const Shape& /*outputShape*/, float reluMin = 0.f,
               float reluMax = std::numeric_limits<float>::max()) {
    NNTRACE_COMP("reluX");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<T>(
                std::min(std::max(reluMin, static_cast<float>(*inputData)), reluMax));
    }
    return true;
}
template bool reluFloat<float>(const float* inputData, const Shape& inputShape, float* outputData,
                               const Shape& outputShape, float reluMin, float reluMax);
template bool reluFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                  _Float16* outputData, const Shape& outputShape, float reluMin,
                                  float reluMax);

template <typename T>
bool relu1Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, -1.f, 1.f);
}
template bool relu1Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu1Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);

template <typename T>
bool relu6Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, 0.f, 6.f);
}
template bool relu6Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu6Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);

bool tanhFloat16(const _Float16* inputData, const Shape& inputShape, _Float16* outputData,
                 const Shape& /*outputShape*/) {
    NNTRACE_COMP("tanhFloat16");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<_Float16>(std::tanh(static_cast<float>(*inputData)));
    }
    return true;
}

bool tanhFloat32(const float* inputData, const Shape& inputShape, float* outputData,
                 const Shape& /*outputShape*/) {
    NNTRACE_COMP("tanhFloat32");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::tanh(*inputData);
    }
    return true;
}

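// Logistic (sigmoid): 1 / (1 + e^-x), computed elementwise in float even for
// _Float16 data.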
template <typename T>
bool logisticFloat(const T* inputData, const Shape& inputShape, T* outputData,
                   const Shape& /*outputShape*/) {
    NNTRACE_COMP("logisticFloat");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<T>(1.f / (1.f + std::exp(static_cast<float>(-*inputData))));
    }
    return true;
}
template bool logisticFloat<float>(const float* inputData, const Shape& inputShape,
                                   float* outputData, const Shape& outputShape);
template bool logisticFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                      _Float16* outputData, const Shape& outputShape);

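// Quantized RELU variants clamp directly in the quantized domain:
// CalculateActivationRangeUint8 converts the real-valued bounds of
// `activation` (e.g. [0, 6] for RELU6) into quantized values using the
// input's scale and zero point, so no dequantization is needed.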
template <ActivationFn activation>
inline bool reluXQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                        const Shape& /*outputShape*/) {
    int numElements = getNumberOfElements(inputShape);
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    CalculateActivationRangeUint8(activation, inputShape, &output_activation_min,
                                  &output_activation_max);

    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min((uint8_t)output_activation_max,
                               std::max((uint8_t)output_activation_min, *inputData));
    }
    return true;
}

bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8");
    return reluXQuant8<kActivationRelu>(inputData, inputShape, outputData, outputShape);
}

bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8");
    return reluXQuant8<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
}

bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8");
    return reluXQuant8<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
}

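// Fixed-point TANH. The output encoding must be scale = 1/128, zero point =
// 128, so the uint8 range covers tanh's output range [-1, 1). The input is
// rescaled into the fixed-point format expected by the TFLite kernel
// (kInputIntegerBits = 4, i.e. Q4.27) via a normalized multiplier plus left
// shift; input_range_radius is the input magnitude beyond which the kernel
// saturates the output to +/-1.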
bool tanhQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8");
    if (outputShape.offset != 128 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

    [[maybe_unused]] int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("optimized_ops::Tanh");
    tflite::optimized_ops::Tanh(inputData, convertShapeToTflshape(inputShape), inputShape.offset,
                                input_range_radius, input_multiplier, input_left_shift, outputData,
                                convertShapeToTflshape(outputShape));

    return true;
}

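// Fixed-point LOGISTIC, using the same input rescaling scheme as tanhQuant8.
// The output encoding must be scale = 1/256, zero point = 0, so the uint8
// range covers the sigmoid's output range [0, 1).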
bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                    const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale or offset for LOGISTIC output";
        return false;
    }

    [[maybe_unused]] int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("optimized_ops::Logistic");
    tflite::optimized_ops::Logistic(
            inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius,
            input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape));

    return true;
}

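// TENSOR_QUANT8_ASYMM_SIGNED counterparts of the kernels above. The int8
// encodings shift the required zero points: TANH uses scale = 1/128 with
// zero point 0, LOGISTIC uses scale = 1/256 with zero point -128.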
template <ActivationFn activation>
inline bool reluXQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                              const Shape& /*outputShape*/) {
    int numElements = getNumberOfElements(inputShape);
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    CalculateActivationRangeInt8(activation, inputShape, &output_activation_min,
                                 &output_activation_max);

    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min((int8_t)output_activation_max,
                               std::max((int8_t)output_activation_min, *inputData));
    }
    return true;
}

bool reluQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8Signed");
    return reluXQuant8Signed<kActivationRelu>(inputData, inputShape, outputData, outputShape);
}

bool relu1Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8Signed");
    return reluXQuant8Signed<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
}

bool relu6Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8Signed");
    return reluXQuant8Signed<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
}

bool tanhQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8Signed");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

    [[maybe_unused]] int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Tanh");
    tflite::reference_integer_ops::Tanh(inputShape.offset, input_range_radius, input_multiplier,
                                        input_left_shift, convertShapeToTflshape(inputShape),
                                        inputData, convertShapeToTflshape(outputShape), outputData);

    return true;
}

bool logisticQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                          const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8Signed");
    if (outputShape.offset != -128 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale or offset for LOGISTIC output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Logistic");
    tflite::reference_integer_ops::Logistic(inputShape.offset, input_range_radius, input_multiplier,
                                            input_left_shift, numElements, inputData, outputData);

    return true;
}

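// Rounds a non-negative Q31 fixed-point multiplier down to Q15, shifting out
// the low 16 bits with round-to-nearest and saturating to INT16_MAX when the
// rounded value would overflow.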
void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32, int16_t* multiplier_int16) {
    TFLITE_DCHECK_GE(multiplier_int32, 0);
    static constexpr int32_t kRoundingOffset = 1 << 15;
    if (multiplier_int32 >= std::numeric_limits<int32_t>::max() - kRoundingOffset) {
        *multiplier_int16 = std::numeric_limits<int16_t>::max();
        return;
    }
    const int32_t result = (multiplier_int32 + kRoundingOffset) >> 16;
    TFLITE_DCHECK_LE(result << 16, multiplier_int32 + kRoundingOffset);
    TFLITE_DCHECK_GT(result << 16, multiplier_int32 - kRoundingOffset);
    *multiplier_int16 = result;
    TFLITE_DCHECK_EQ(*multiplier_int16, result);
}

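// Quantized hard-swish, h-swish(x) = x * relu6(x + 3) / 6, delegated to
// TFLite's reference kernel. The kernel's fixed-point parameters are derived
// from a higher-resolution input scale (input scale / 128) and a "reluish"
// scale of 3/32768 for the relu6(x + 3) / 6 factor; each multiplier is split
// into an int16 fixed-point mantissa plus a power-of-two exponent.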
template <typename T>
bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData,
                    const Shape& outputShape) {
    tflite::HardSwishParams params;
    params.input_zero_point = inputShape.offset;
    params.output_zero_point = outputShape.offset;
    const float input_scale = inputShape.scale;
    const float hires_input_scale = (1.0f / 128.0f) * input_scale;
    const float reluish_scale = 3.0f / 32768.0f;
    const float output_scale = outputShape.scale;

    const float output_multiplier = hires_input_scale / output_scale;

    int32_t output_multiplier_fixedpoint_int32;
    NN_RET_CHECK(QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
                                    &params.output_multiplier_exponent));
    DownScaleInt32ToInt16Multiplier(output_multiplier_fixedpoint_int32,
                                    &params.output_multiplier_fixedpoint_int16);
    NN_RET_CHECK(params.output_multiplier_exponent <= 0);

    const float reluish_multiplier = hires_input_scale / reluish_scale;
    int32_t reluish_multiplier_fixedpoint_int32;
    NN_RET_CHECK(QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
                                    &params.reluish_multiplier_exponent));
    DownScaleInt32ToInt16Multiplier(reluish_multiplier_fixedpoint_int32,
                                    &params.reluish_multiplier_fixedpoint_int16);

    tflite::reference_ops::HardSwish(params, convertShapeToTflshape(inputShape), inputData,
                                     convertShapeToTflshape(outputShape), outputData);
    return true;
}

}  // namespace

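// Computes the output shape and, for quantized tensors, the output encoding
// dictated by the operation: RELU variants keep the input's scale and zero
// point, LOGISTIC and TANH force the fixed encodings validated in the
// kernels above, and HARD_SWISH uses whatever encoding the model specified
// for its output operand.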
bool prepare(OperationType opType, IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    if (opType != OperationType::HARD_SWISH) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4u);
    }
    Shape output = input;
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
        input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        bool isSigned = input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED;
        switch (opType) {
            case OperationType::HARD_SWISH: {
                auto outputShape = context->getOutputShape(kOutputTensor);
                output.scale = outputShape.scale;
                output.offset = outputShape.offset;
            } break;
            case OperationType::RELU:
            case OperationType::RELU1:
            case OperationType::RELU6:
                break;
            case OperationType::LOGISTIC:
                output.scale = 1.f / 256;
                output.offset = isSigned ? -128 : 0;
                break;
            case OperationType::TANH:
                output.scale = 1.f / 128;
                output.offset = isSigned ? 0 : 128;
                break;
            default:
                NN_RET_CHECK_FAIL() << "Unsupported operation type";
        }
    }
    return context->setOutputShape(kOutputTensor, output);
}

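// The execute* entry points below all follow the same pattern: short-circuit
// zero-sized tensors, then dispatch on the input operand type to the
// matching typed kernel.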
bool executeRelu(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return reluFloat(context->getInputBuffer<_Float16>(kInputTensor),
                             context->getInputShape(kInputTensor),
                             context->getOutputBuffer<_Float16>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return reluFloat(context->getInputBuffer<float>(kInputTensor),
                             context->getInputShape(kInputTensor),
                             context->getOutputBuffer<float>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return reluQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return reluQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getOutputBuffer<int8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU";
    }
}

bool executeRelu1(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return relu1Float(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return relu1Float(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return relu1Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<uint8_t>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return relu1Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU1";
    }
}

bool executeRelu6(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return relu6Float(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return relu6Float(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return relu6Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<uint8_t>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return relu6Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU6";
    }
}

bool executeLogistic(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getOutputBuffer<_Float16>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return logisticFloat(context->getInputBuffer<float>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getOutputBuffer<float>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return logisticQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return logisticQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getOutputBuffer<int8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
    }
}

bool executeTanh(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<_Float16>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return tanhFloat32(context->getInputBuffer<float>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<float>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return tanhQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return tanhQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getOutputBuffer<int8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
    }
}

bool executeHardSwish(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16: {
            const Shape& inputShape = context->getInputShape(kInputTensor);
            const Shape& outputShape = context->getOutputShape(kOutputTensor);
            std::vector<float> inputFloat(getNumberOfElements(inputShape));
            std::vector<float> outputFloat(getNumberOfElements(outputShape));
            convertFloat16ToFloat32(context->getInputBuffer<_Float16>(kInputTensor), &inputFloat);
            tflite::reference_ops::HardSwish(convertShapeToTflshape(inputShape), inputFloat.data(),
                                             convertShapeToTflshape(outputShape),
                                             outputFloat.data());
            convertFloat32ToFloat16(outputFloat, context->getOutputBuffer<_Float16>(kOutputTensor));
            return true;
        }
        case OperandType::TENSOR_FLOAT32: {
            tflite::reference_ops::HardSwish(
                    convertShapeToTflshape(context->getInputShape(kInputTensor)),
                    context->getInputBuffer<float>(kInputTensor),
                    convertShapeToTflshape(context->getOutputShape(kOutputTensor)),
                    context->getOutputBuffer<float>(kOutputTensor));
            return true;
        }
        case OperandType::TENSOR_QUANT8_ASYMM:
            return hardSwishQuant(context->getInputBuffer<uint8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return hardSwishQuant(context->getInputBuffer<int8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<int8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation HARD_SWISH";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace activation

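// All activations share prepare(), bound to the concrete OperationType.
// allowZeroSizedInput is set because the execute* functions above explicitly
// handle zero-sized tensors.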
using std::placeholders::_1;
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(RELU,
                                         std::bind(activation::prepare, OperationType::RELU, _1),
                                         activation::executeRelu, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(RELU1,
                                         std::bind(activation::prepare, OperationType::RELU1, _1),
                                         activation::executeRelu1, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(RELU6,
                                         std::bind(activation::prepare, OperationType::RELU6, _1),
                                         activation::executeRelu6, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(LOGISTIC,
                                         std::bind(activation::prepare, OperationType::LOGISTIC,
                                                   _1),
                                         activation::executeLogistic, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(TANH,
                                         std::bind(activation::prepare, OperationType::TANH, _1),
                                         activation::executeTanh, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(HARD_SWISH,
                                         std::bind(activation::prepare, OperationType::HARD_SWISH,
                                                   _1),
                                         activation::executeHardSwish, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android