/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "Activation.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

#include "ActivationFunctor.h"
#include "OperationResolver.h"
#include "OperationsExecutionUtils.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Winvalid-partial-specialization"
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h>
#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
#pragma clang diagnostic pop

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {

namespace activation {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

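// Clamps each element of the input to [reluMin, reluMax]. The defaults give
// plain RELU; relu1Float and relu6Float below reuse this with tighter bounds.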
template <typename T>
bool reluFloat(const T* inputData, const Shape& inputShape, T* outputData,
               const Shape& /*outputShape*/, float reluMin = 0.f,
               float reluMax = std::numeric_limits<float>::max()) {
    NNTRACE_COMP("reluX");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<T>(
                std::min(std::max(reluMin, static_cast<float>(*inputData)), reluMax));
    }
    return true;
}
template bool reluFloat<float>(const float* inputData, const Shape& inputShape, float* outputData,
                               const Shape& outputShape, float reluMin, float reluMax);
template bool reluFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                  _Float16* outputData, const Shape& outputShape, float reluMin,
                                  float reluMax);

template <typename T>
bool relu1Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, -1.f, 1.f);
}
template bool relu1Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu1Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);

template <typename T>
bool relu6Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, 0.f, 6.f);
}
template bool relu6Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu6Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);

bool tanhFloat16(const _Float16* inputData, const Shape& inputShape, _Float16* outputData,
                 const Shape& /*outputShape*/) {
    NNTRACE_COMP("tanhFloat16");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<_Float16>(std::tanh(static_cast<float>(*inputData)));
    }
    return true;
}

bool tanhFloat32(const float* inputData, const Shape& inputShape, float* outputData,
                 const Shape& /*outputShape*/) {
    NNTRACE_COMP("tanhFloat32");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::tanh(*inputData);
    }
    return true;
}

template <typename T>
bool logisticFloat(const T* inputData, const Shape& inputShape, T* outputData,
                   const Shape& /*outputShape*/) {
    NNTRACE_COMP("logisticFloat");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<T>(1.f / (1.f + std::exp(static_cast<float>(-*inputData))));
    }
    return true;
}
template bool logisticFloat<float>(const float* inputData, const Shape& inputShape,
                                   float* outputData, const Shape& outputShape);
template bool logisticFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                      _Float16* outputData, const Shape& outputShape);

template <ActivationFn activation>
inline bool reluXQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                        const Shape& /*outputShape*/) {
    int numElements = getNumberOfElements(inputShape);
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    CalculateActivationRangeUint8(activation, inputShape, &output_activation_min,
                                  &output_activation_max);

    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min((uint8_t)output_activation_max,
                               std::max((uint8_t)output_activation_min, *inputData));
    }
    return true;
}

bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8");
    return reluXQuant8<kActivationRelu>(inputData, inputShape, outputData, outputShape);
}

bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8");
    return reluXQuant8<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
}

bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8");
    return reluXQuant8<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
}

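// Quantized TANH. The NNAPI spec fixes the output quantization to
// scale = 1/128, zeroPoint = 128, so the uint8 range [0, 255] maps onto
// [-1, 127/128]. The input is rescaled into a 32-bit fixed-point format with
// kInputIntegerBits integer bits (Q4.27) before calling the TFLite kernel;
// quantized inputs farther than input_range_radius from the zero point
// saturate to the extremes of the output range.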
bool tanhQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8");
    if (outputShape.offset != 128 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

    [[maybe_unused]] int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("optimized_ops::Tanh");
    tflite::optimized_ops::Tanh(inputData, convertShapeToTflshape(inputShape), inputShape.offset,
                                input_range_radius, input_multiplier, input_left_shift, outputData,
                                convertShapeToTflshape(outputShape));

    return true;
}

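// Quantized LOGISTIC. The output quantization is fixed to scale = 1/256,
// zeroPoint = 0, mapping the uint8 range onto [0, 255/256]. The input
// rescaling mirrors tanhQuant8 above.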
bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                    const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

    [[maybe_unused]] int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("optimized_ops::Logistic");
    tflite::optimized_ops::Logistic(
            inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius,
            input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape));

    return true;
}

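// Signed counterparts of the quantized kernels above, for
// TENSOR_QUANT8_ASYMM_SIGNED (int8) operands.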
template <ActivationFn activation>
inline bool reluXQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                              const Shape& /*outputShape*/) {
    int numElements = getNumberOfElements(inputShape);
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    CalculateActivationRangeInt8(activation, inputShape, &output_activation_min,
                                 &output_activation_max);

    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min((int8_t)output_activation_max,
                               std::max((int8_t)output_activation_min, *inputData));
    }
    return true;
}

bool reluQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8Signed");
    return reluXQuant8Signed<kActivationRelu>(inputData, inputShape, outputData, outputShape);
}

bool relu1Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8Signed");
    return reluXQuant8Signed<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
}

bool relu6Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8Signed");
    return reluXQuant8Signed<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
}

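// Signed quantized TANH: output fixed to scale = 1/128, zeroPoint = 0, so the
// int8 range [-128, 127] maps onto [-1, 127/128].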
bool tanhQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8Signed");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

    [[maybe_unused]] int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Tanh");
    tflite::reference_integer_ops::Tanh(inputShape.offset, input_range_radius, input_multiplier,
                                        input_left_shift, convertShapeToTflshape(inputShape),
                                        inputData, convertShapeToTflshape(outputShape), outputData);

    return true;
}

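// Signed quantized LOGISTIC: output fixed to scale = 1/256, zeroPoint = -128,
// mapping the int8 range onto [0, 255/256].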
bool logisticQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                          const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8Signed");
    if (outputShape.offset != -128 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Logistic");
    tflite::reference_integer_ops::Logistic(inputShape.offset, input_range_radius, input_multiplier,
                                            input_left_shift, numElements, inputData, outputData);

    return true;
}

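// Rounds a non-negative Q31 multiplier down to Q15 by adding a half-ULP
// rounding offset and shifting right by 16, saturating at INT16_MAX. HardSwish
// consumes its fixed-point multipliers in int16 form.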
void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32, int16_t* multiplier_int16) {
    TFLITE_DCHECK_GE(multiplier_int32, 0);
    static constexpr int32_t kRoundingOffset = 1 << 15;
    if (multiplier_int32 >= std::numeric_limits<int32_t>::max() - kRoundingOffset) {
        *multiplier_int16 = std::numeric_limits<int16_t>::max();
        return;
    }
    const int32_t result = (multiplier_int32 + kRoundingOffset) >> 16;
    TFLITE_DCHECK_LE(result << 16, multiplier_int32 + kRoundingOffset);
    TFLITE_DCHECK_GT(result << 16, multiplier_int32 - kRoundingOffset);
    *multiplier_int16 = result;
    TFLITE_DCHECK_EQ(*multiplier_int16, result);
}

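// Quantized HARD_SWISH, h-swish(x) = x * relu6(x + 3) / 6, delegated to the
// TFLite reference kernel. Following TFLite's scheme, the input is rescaled
// both into a high-resolution intermediate (scale input_scale / 128) and into
// a "reluish" representation with scale 3/32768; the two fixed-point
// multipliers computed here express those rescalings.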
template <typename T>
bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData,
                    const Shape& outputShape) {
    tflite::HardSwishParams params;
    params.input_zero_point = inputShape.offset;
    params.output_zero_point = outputShape.offset;
    const float input_scale = inputShape.scale;
    const float hires_input_scale = (1.0f / 128.0f) * input_scale;
    const float reluish_scale = 3.0f / 32768.0f;
    const float output_scale = outputShape.scale;

    const float output_multiplier = hires_input_scale / output_scale;

    int32_t output_multiplier_fixedpoint_int32;
    NN_RET_CHECK(QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
                                    &params.output_multiplier_exponent));
    DownScaleInt32ToInt16Multiplier(output_multiplier_fixedpoint_int32,
                                    &params.output_multiplier_fixedpoint_int16);
    NN_RET_CHECK(params.output_multiplier_exponent <= 0);

    const float reluish_multiplier = hires_input_scale / reluish_scale;
    int32_t reluish_multiplier_fixedpoint_int32;
    NN_RET_CHECK(QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
                                    &params.reluish_multiplier_exponent));
    DownScaleInt32ToInt16Multiplier(reluish_multiplier_fixedpoint_int32,
                                    &params.reluish_multiplier_fixedpoint_int16);

    tflite::reference_ops::HardSwish(params, convertShapeToTflshape(inputShape), inputData,
                                     convertShapeToTflshape(outputShape), outputData);
    return true;
}

}  // namespace

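// Checks that the input rank is at most 4 (HARD_SWISH is exempt) and derives
// the output shape from the input. For quantized tensors, LOGISTIC and TANH
// outputs get the scale/offset mandated by the spec (re-checked at execution
// time above), while HARD_SWISH keeps the quantization declared by the model.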
bool prepare(OperationType opType, IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    if (opType != OperationType::HARD_SWISH) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4u);
    }
    Shape output = input;
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
        input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        bool isSigned = input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED;
        switch (opType) {
            case OperationType::HARD_SWISH: {
                auto outputShape = context->getOutputShape(kOutputTensor);
                output.scale = outputShape.scale;
                output.offset = outputShape.offset;
            } break;
            case OperationType::RELU:
            case OperationType::RELU1:
            case OperationType::RELU6:
                break;
            case OperationType::LOGISTIC:
                output.scale = 1.f / 256;
                output.offset = isSigned ? -128 : 0;
                break;
            case OperationType::TANH:
                output.scale = 1.f / 128;
                output.offset = isSigned ? 0 : 128;
                break;
            default:
                NN_RET_CHECK_FAIL() << "Unsupported operation type";
        }
    }
    return context->setOutputShape(kOutputTensor, output);
}

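// Each execute* entry point below dispatches on the input operand type and
// forwards to the matching kernel.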
bool executeRelu(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return reluFloat(context->getInputBuffer<_Float16>(kInputTensor),
                             context->getInputShape(kInputTensor),
                             context->getOutputBuffer<_Float16>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return reluFloat(context->getInputBuffer<float>(kInputTensor),
                             context->getInputShape(kInputTensor),
                             context->getOutputBuffer<float>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return reluQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return reluQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getOutputBuffer<int8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU";
    }
}

bool executeRelu1(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return relu1Float(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return relu1Float(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return relu1Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<uint8_t>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return relu1Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU1";
    }
}

bool executeRelu6(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return relu6Float(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return relu6Float(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return relu6Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<uint8_t>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return relu6Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU6";
    }
}

bool executeLogistic(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getOutputBuffer<_Float16>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return logisticFloat(context->getInputBuffer<float>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getOutputBuffer<float>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return logisticQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return logisticQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getOutputBuffer<int8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
    }
}

bool executeTanh(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<_Float16>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return tanhFloat32(context->getInputBuffer<float>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<float>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return tanhQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return tanhQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getOutputBuffer<int8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
    }
}

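// There is no native fp16 HARD_SWISH kernel here: the TENSOR_FLOAT16 path
// converts to fp32, runs the TFLite reference kernel, and converts back.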
bool executeHardSwish(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16: {
            const Shape& inputShape = context->getInputShape(kInputTensor);
            const Shape& outputShape = context->getOutputShape(kOutputTensor);
            std::vector<float> inputFloat(getNumberOfElements(inputShape));
            std::vector<float> outputFloat(getNumberOfElements(outputShape));
            convertFloat16ToFloat32(context->getInputBuffer<_Float16>(kInputTensor), &inputFloat);
            tflite::reference_ops::HardSwish(convertShapeToTflshape(inputShape), inputFloat.data(),
                                             convertShapeToTflshape(outputShape),
                                             outputFloat.data());
            convertFloat32ToFloat16(outputFloat, context->getOutputBuffer<_Float16>(kOutputTensor));
            return true;
        }
        case OperandType::TENSOR_FLOAT32: {
            tflite::reference_ops::HardSwish(
                    convertShapeToTflshape(context->getInputShape(kInputTensor)),
                    context->getInputBuffer<float>(kInputTensor),
                    convertShapeToTflshape(context->getOutputShape(kOutputTensor)),
                    context->getOutputBuffer<float>(kOutputTensor));
            return true;
        }
        case OperandType::TENSOR_QUANT8_ASYMM:
            return hardSwishQuant(context->getInputBuffer<uint8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return hardSwishQuant(context->getInputBuffer<int8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<int8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation HARD_SWISH";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace activation

using std::placeholders::_1;
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(RELU,
                                         std::bind(activation::prepare, OperationType::RELU, _1),
                                         activation::executeRelu, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(RELU1,
                                         std::bind(activation::prepare, OperationType::RELU1, _1),
                                         activation::executeRelu1, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(RELU6,
                                         std::bind(activation::prepare, OperationType::RELU6, _1),
                                         activation::executeRelu6, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(LOGISTIC,
                                         std::bind(activation::prepare, OperationType::LOGISTIC,
                                                   _1),
                                         activation::executeLogistic, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(TANH,
                                         std::bind(activation::prepare, OperationType::TANH, _1),
                                         activation::executeTanh, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(HARD_SWISH,
                                         std::bind(activation::prepare, OperationType::HARD_SWISH,
                                                   _1),
                                         activation::executeHardSwish, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android