/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the CPU implementation of the broadcast operations: ADD, MUL, SUB, and DIV.

#define LOG_TAG "Operations"

#include "Broadcast.h"

#include <algorithm>
#include <functional>
#include <vector>

#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "Tracing.h"
#include "nnapi/Types.h"
#include "nnapi/Validation.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Winvalid-partial-specialization"
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/types.h>
#pragma clang diagnostic pop

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {

namespace broadcast {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

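// Maps the runtime NNAPI activation scalar onto TFLite's compile-time
// fused-activation tag and expands `macro` with it; for example, with
// activation == RELU, ANDROID_NN_MACRO_DISPATCH(FOO) expands FOO(kRelu).
// Unknown activation values make the enclosing function return false.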
#define ANDROID_NN_MACRO_DISPATCH(macro)                                \
    switch (activation) {                                               \
        case static_cast<int32_t>(FusedActivationFunc::NONE):           \
            macro(kNone);                                               \
            break;                                                      \
        case static_cast<int32_t>(FusedActivationFunc::RELU):           \
            macro(kRelu);                                               \
            break;                                                      \
        case static_cast<int32_t>(FusedActivationFunc::RELU1):          \
            macro(kRelu1);                                              \
            break;                                                      \
        case static_cast<int32_t>(FusedActivationFunc::RELU6):          \
            macro(kRelu6);                                              \
            break;                                                      \
        default:                                                        \
            LOG(ERROR) << "Unsupported fused activation function type"; \
            return false;                                               \
    }

using binaryFunctionFloat32 = std::function<bool(
        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
        int32_t activation, float* out, const Shape& shapeOut)>;

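// There are no dedicated fp16 kernels: _Float16 inputs are widened to float,
// run through the corresponding fp32 kernel, and the result is narrowed back.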
bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
                            const Shape& shape2, int32_t activation, _Float16* out,
                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
    std::vector<float> in1_float32(getNumberOfElements(shape1));
    convertFloat16ToFloat32(in1, &in1_float32);
    std::vector<float> in2_float32(getNumberOfElements(shape2));
    convertFloat16ToFloat32(in2, &in2_float32);
    std::vector<float> out_float32(getNumberOfElements(shapeOut));

    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
                     out_float32.data(), shapeOut);
    convertFloat32ToFloat16(out_float32, out);

    return true;
}

bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                              \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_ADD(activation)                                                 \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(   \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
#undef ANDROID_NN_ADD
    }

    return true;
}

bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
}

template <typename T>
bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addQuant8");
    const bool needBroadcast = !SameShape(shape1, shape2);

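    // Requantization scheme for quantized addition (mirroring TFLite): with
    // real = scale * (q - zeroPoint), the exact result is
    //   qOut = zOut + (s1 * (q1 - z1) + s2 * (q2 - z2)) / sOut.
    // Both inputs are first rescaled onto a common scale, twice the larger
    // input scale, so each input multiplier is at most 1/2; a left shift of
    // 20 bits adds headroom so precision survives the fixed-point multiplies
    // before the final rescale to the output scale.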
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    int32_t output_activation_min;
    int32_t output_activation_max;
    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if (needBroadcast) {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
            tflite::reference_integer_ops::BroadcastAdd4DSlow(
                    op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2),
                    in2, convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
            tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                      in1, convertShapeToTflshape(shape2), in2,
                                                      convertShapeToTflshape(shapeOut), out);
        }
    } else {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("optimized_integer_ops::Add");
            tflite::optimized_integer_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                               convertShapeToTflshape(shape2), in2,
                                               convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("optimized_ops::Add");
            tflite::optimized_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                       convertShapeToTflshape(shape2), in2,
                                       convertShapeToTflshape(shapeOut), out);
        }
    }

    return true;
}

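// Generic elementwise kernel used for TENSOR_INT32: walks every output index,
// maps it back to the (possibly broadcast) input indices, and applies `func`.
// Fused activations are not supported for int32 tensors, hence the NONE check.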
bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData,
                  const Shape& bShape, int32_t activation, int32_t* outputData,
                  const Shape& outputShape, int32_t func(int32_t, int32_t)) {
    NN_RET_CHECK_EQ(static_cast<FusedActivationFunc>(activation), FusedActivationFunc::NONE);
    IndexedShapeWrapper aShapeIndexed(aShape);
    IndexedShapeWrapper bShapeIndexed(bShape);
    IndexedShapeWrapper outputShapeIndexed(outputShape);
    std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
    bool lastIndex = false;
    do {
        uint32_t outputFlatIndex;
        NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
        uint32_t aFlatIndex;
        NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
        uint32_t bFlatIndex;
        NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));

        outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);

        NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
    } while (!lastIndex);
    return true;
}

bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
#define ANDROID_NN_BROADCAST_MUL(activation)                                              \
    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
#undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}

bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
}

template <typename T>
bool mulQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulQuant8");
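    // For multiplication a single rescale suffices: since
    //   s1 * (q1 - z1) * s2 * (q2 - z2) = (s1 * s2) * (q1 - z1) * (q2 - z2),
    // the integer product only needs to be scaled by s1 * s2 / sOut.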
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.input1_offset = input1_offset;
    op_params.input2_offset = input2_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastMul4DSlow");
        tflite::reference_integer_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastMul4DSlow");
        tflite::reference_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat32");
    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                               out, convertShapeToDims(shapeOut));

    // TFLite does not apply activation to broadcast sub.
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    uint32_t numOutputElements = getNumberOfElements(shapeOut);
    for (uint32_t i = 0; i < numOutputElements; i++) {
        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
    }
    return true;
}

bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
}

template <typename T>
bool subQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subQuant8");

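    // Same rescaling scheme as addQuant8; subtraction is obtained by negating
    // the second input's multiplier below and reusing the add kernels.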
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    // Negate multiplier of the second input, so that we can use Add kernels.
    input2_multiplier *= -1;

    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    // The broadcast kernels (BroadcastAdd4DSlow) are used unconditionally
    // here, even when the shapes match, because the non-broadcast Add kernels
    // fail to pass some of the sub_quantized_different_scales tests.
    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
        tflite::reference_integer_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
        tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Div");
        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }
    return true;
}

bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
}

}  // namespace

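// Computes the NumPy-style broadcast of the two input shapes, e.g. dimensions
// [4, 1, 3] and [2, 3] broadcast to [4, 2, 3]. The TFLite kernels used above
// handle at most 4 dimensions, hence the rank checks.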
bool prepare(IOperationExecutionContext* context) {
    Shape input1 = context->getInputShape(kInputTensor1);
    Shape input2 = context->getInputShape(kInputTensor2);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4u);
    NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4u);
    NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    return context->setOutputShape(kOutputTensor, output);
}

bool executeAdd(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return addQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a + b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
    }
}

bool executeMul(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return mulQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a * b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
    }
}

bool executeSub(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return subQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a - b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
    }
}

bool executeDiv(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) {
                                    // In NNAPI, DIV by zero is undefined, but should not crash.
                                    if (b == 0) return 0;
                                    int32_t result = a / b;
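                                    // C++ '/' truncates toward zero (e.g. -7 / 2 == -3),
                                    // while NNAPI DIV is floor division (-7 / 2 == -4),
                                    // so round down when the signs differ and the
                                    // division is inexact.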
                                    if (a % b != 0 && ((a < 0) != (b < 0))) {
                                        // Implement "floor division".
                                        --result;
                                    }
                                    return result;
                                });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace broadcast

NN_REGISTER_OPERATION_DEFAULT_VALIDATION(ADD, broadcast::prepare, broadcast::executeAdd,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(MUL, broadcast::prepare, broadcast::executeMul,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(SUB, broadcast::prepare, broadcast::executeSub,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(DIV, broadcast::prepare, broadcast::executeDiv,
                                         .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android