/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the operations.

#define LOG_TAG "Operations"

#include "Broadcast.h"

#include <algorithm>
#include <functional>
#include <vector>

#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "Tracing.h"
#include "nnapi/Types.h"
#include "nnapi/Validation.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Winvalid-partial-specialization"
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/types.h>
#pragma clang diagnostic pop

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {

namespace broadcast {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

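// Maps the NNAPI fused-activation scalar (NONE/RELU/RELU1/RELU6) to the matching
// tflite::FusedActivationFunctionType enumerator and expands |macro| with it. Any
// other activation value logs an error and makes the calling kernel return false.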
#define ANDROID_NN_MACRO_DISPATCH(macro)                                  \
    switch (activation) {                                                 \
        case static_cast<int32_t>(FusedActivationFunc::NONE):             \
            macro(kNone);                                                 \
            break;                                                        \
        case static_cast<int32_t>(FusedActivationFunc::RELU):             \
            macro(kRelu);                                                 \
            break;                                                        \
        case static_cast<int32_t>(FusedActivationFunc::RELU1):            \
            macro(kRelu1);                                                \
            break;                                                        \
        case static_cast<int32_t>(FusedActivationFunc::RELU6):            \
            macro(kRelu6);                                                \
            break;                                                        \
        default:                                                          \
            LOG(ERROR) << "Unsupported fused activation function type";   \
            return false;                                                 \
    }

using binaryFunctionFloat32 = std::function<bool(
        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
        int32_t activation, float* out, const Shape& shapeOut)>;

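// There are no dedicated float16 kernels below. A float16 binary operation is
// evaluated by converting both inputs to float32, invoking the corresponding
// float32 kernel, and converting the result back to float16.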
bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
                            const Shape& shape2, int32_t activation, _Float16* out,
                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
    std::vector<float> in1_float32(getNumberOfElements(shape1));
    convertFloat16ToFloat32(in1, &in1_float32);
    std::vector<float> in2_float32(getNumberOfElements(shape2));
    convertFloat16ToFloat32(in2, &in2_float32);
    std::vector<float> out_float32(getNumberOfElements(shapeOut));

    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
                     out_float32.data(), shapeOut);
    convertFloat32ToFloat16(out_float32, out);

    return true;
}

bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                              \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_ADD(activation)                                                \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(  \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
#undef ANDROID_NN_ADD
    }

    return true;
}

bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
}

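// Quantized ADD rescales both inputs to a common scale before the integer addition and
// then rescales the sum to the output scale. Each real-valued multiplier below is
// smaller than one and is therefore encoded as a fixed-point multiplier plus a shift
// by QuantizeMultiplierSmallerThanOneExp:
//   real_input1_multiplier = scale1 / (2 * max(scale1, scale2))
//   real_input2_multiplier = scale2 / (2 * max(scale1, scale2))
//   real_output_multiplier = 2 * max(scale1, scale2) / (2^left_shift * output_scale)
// The inputs are shifted left by |left_shift| bits inside the TFLite kernels so that
// precision is preserved during the intermediate fixed-point arithmetic.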
template <typename T>
bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addQuant8");
    const bool needBroadcast = !SameShape(shape1, shape2);

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    int32_t output_activation_min;
    int32_t output_activation_max;
    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if (needBroadcast) {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
            tflite::reference_integer_ops::BroadcastAdd4DSlow(
                    op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2),
                    in2, convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
            tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                      in1, convertShapeToTflshape(shape2), in2,
                                                      convertShapeToTflshape(shapeOut), out);
        }
    } else {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("optimized_integer_ops::Add");
            tflite::optimized_integer_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                               convertShapeToTflshape(shape2), in2,
                                               convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("optimized_ops::Add");
            tflite::optimized_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                       convertShapeToTflshape(shape2), in2,
                                       convertShapeToTflshape(shapeOut), out);
        }
    }

    return true;
}

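// Generic element-wise path used for TENSOR_INT32: walks every output index, maps it
// back to the (possibly broadcast) input indices via IndexedShapeWrapper, and applies
// |func| to the corresponding pair of elements. Only FusedActivationFunc::NONE is
// supported on this path.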
bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData,
                  const Shape& bShape, int32_t activation, int32_t* outputData,
                  const Shape& outputShape, int32_t func(int32_t, int32_t)) {
    NN_RET_CHECK_EQ(static_cast<FusedActivationFunc>(activation), FusedActivationFunc::NONE);
    IndexedShapeWrapper aShapeIndexed(aShape);
    IndexedShapeWrapper bShapeIndexed(bShape);
    IndexedShapeWrapper outputShapeIndexed(outputShape);
    std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
    bool lastIndex = false;
    do {
        uint32_t outputFlatIndex;
        NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
        uint32_t aFlatIndex;
        NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
        uint32_t bFlatIndex;
        NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));

        outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);

        NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
    } while (!lastIndex);
    return true;
}

bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
#define ANDROID_NN_BROADCAST_MUL(activation)                                              \
    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
#undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}

bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
}

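// Quantized MUL needs a single output rescaling step: the real multiplier is
// (scale1 * scale2) / output_scale, encoded as a fixed-point multiplier plus a shift
// (the NN_RET_CHECK fails if it is not smaller than one).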
template <typename T>
bool mulQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulQuant8");
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.input1_offset = input1_offset;
    op_params.input2_offset = input2_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastMul4DSlow");
        tflite::reference_integer_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastMul4DSlow");
        tflite::reference_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat32");
    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                               out, convertShapeToDims(shapeOut));

    // TFLite does not apply activation to broadcast sub, so clamp the output here.
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    uint32_t numOutputElements = getNumberOfElements(shapeOut);
    for (uint32_t i = 0; i < numOutputElements; i++) {
        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
    }
    return true;
}

bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
}

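// Quantized SUB reuses the quantized ADD rescaling scheme; the fixed-point multiplier
// of the second input is negated below so that the Add kernels effectively compute
// in1 - in2.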
template <typename T>
bool subQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subQuant8");

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    // Negate the multiplier of the second input, so that the Add kernels can be reused.
    input2_multiplier *= -1;

    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    // Use the broadcast Add kernels (BroadcastAdd4DSlow) unconditionally here because
    // tflite::optimized_ops::Add fails to pass some of the
    // sub_quantized_different_scales tests.
    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
        tflite::reference_integer_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
        tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Div");
        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }
    return true;
}

bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
}

}  // namespace

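// Shared shape preparation for ADD, MUL, SUB, and DIV: both inputs must have rank at
// most 4, and the output shape is their broadcasted shape.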
bool prepare(IOperationExecutionContext* context) {
    Shape input1 = context->getInputShape(kInputTensor1);
    Shape input2 = context->getInputShape(kInputTensor2);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4u);
    NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4u);
    NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    return context->setOutputShape(kOutputTensor, output);
}

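// The execute* entry points below share the same structure: bypass zero-sized outputs,
// then dispatch on the type of the first input tensor to the matching typed kernel,
// forwarding the fused-activation scalar read from kActivationScalar.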
bool executeAdd(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return addQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a + b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
    }
}

bool executeMul(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return mulQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a * b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
    }
}

bool executeSub(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return subQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a - b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
    }
}

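// For TENSOR_INT32, DIV rounds toward negative infinity ("floor division") rather
// than truncating toward zero: for example, -7 DIV 2 yields -4 whereas C++ integer
// division would yield -3. A zero divisor produces 0 instead of crashing.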
bool executeDiv(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) {
                                    // In NNAPI, DIV by zero is undefined, but should not crash.
                                    if (b == 0) return 0;
                                    int32_t result = a / b;
                                    if (a % b != 0 && ((a < 0) != (b < 0))) {
                                        // Implement "floor division".
                                        --result;
                                    }
                                    return result;
                                });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace broadcast

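// ADD, MUL, SUB, and DIV share broadcast::prepare and differ only in their execute
// entry point. allowZeroSizedInput permits zero-sized input tensors, which the
// execute functions above handle by skipping computation.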
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(ADD, broadcast::prepare, broadcast::executeAdd,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(MUL, broadcast::prepare, broadcast::executeMul,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(SUB, broadcast::prepare, broadcast::executeSub,
                                         .allowZeroSizedInput = true);
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(DIV, broadcast::prepare, broadcast::executeDiv,
                                         .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android