1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "L2Normalization.h"
20 
#include <algorithm>
#include <cmath>
#include <vector>
23 
24 #include "OperationResolver.h"
25 #include "Tracing.h"
26 
27 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
28 #pragma clang diagnostic push
29 #pragma clang diagnostic ignored "-Wunused-parameter"
30 #pragma clang diagnostic ignored "-Wsign-compare"
31 #pragma clang diagnostic ignored "-Winvalid-partial-specialization"
32 #include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
33 #include <tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h>
34 #pragma clang diagnostic pop
35 
36 #include "CpuOperationUtils.h"
37 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
38 
39 namespace android {
40 namespace nn {
41 namespace l2_norm {
42 
43 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
44 namespace {
45 
l2normFloat32Impl(const float * inputData,const Shape & inputShape,int32_t axis,float * outputData,const Shape &)46 inline bool l2normFloat32Impl(const float* inputData, const Shape& inputShape, int32_t axis,
47                               float* outputData, const Shape& /*outputShape*/) {
48     NNTRACE_TRANS("l2normFloat32");
49     constexpr float kEpsilon = 1e-6f;
50     const uint32_t outerSize = getNumberOfElements(inputShape, 0, axis);
51     const uint32_t axisSize = getSizeOfDimension(inputShape, axis);
52     const uint32_t innerSize =
53             getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape));
54     for (uint32_t outer = 0; outer < outerSize; ++outer) {
55         const float* inputBeg = inputData + outer * axisSize * innerSize;
56         const float* inputEnd = inputBeg + axisSize * innerSize;
57         float* outputBeg = outputData + outer * axisSize * innerSize;
58         for (uint32_t inner = 0; inner < innerSize; ++inner, ++inputBeg, ++inputEnd, ++outputBeg) {
59             float sum = 0.0f;
60             for (const float* p = inputBeg; p < inputEnd; p += innerSize) {
61                 float val = *p;
62                 sum += val * val;
63             }
64             float l2_norm = std::max(std::sqrt(sum), kEpsilon);
65             float* pOut = outputBeg;
66             for (const float* p = inputBeg; p < inputEnd; p += innerSize, pOut += innerSize) {
67                 *pOut = *p / l2_norm;
68             }
69         }
70     }
71     return true;
72 }
73 
l2normQuant8Impl(const uint8_t * inputData,const Shape & inputShape,int32_t axis,uint8_t * outputData,const Shape &)74 inline bool l2normQuant8Impl(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
75                              uint8_t* outputData, const Shape& /*outputShape*/) {
76     NNTRACE_TRANS("l2normQuant8");
77     const uint32_t outerSize = getNumberOfElements(inputShape, 0, axis);
78     const uint32_t axisSize = getSizeOfDimension(inputShape, axis);
79     const uint32_t innerSize =
80             getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape));
81     for (uint32_t outer = 0; outer < outerSize; ++outer) {
82         const uint8_t* inputBeg = inputData + outer * axisSize * innerSize;
83         const uint8_t* inputEnd = inputBeg + axisSize * innerSize;
84         uint8_t* outputBeg = outputData + outer * axisSize * innerSize;
85         for (uint32_t inner = 0; inner < innerSize; ++inner, ++inputBeg, ++inputEnd, ++outputBeg) {
86             int32_t sum = 0;
87             for (const uint8_t* p = inputBeg; p < inputEnd; p += innerSize) {
88                 int32_t val = static_cast<int32_t>(*p) - inputShape.offset;
89                 sum += val * val;
90             }
91             int32_t invMultiplier, invShift;
92             tflite::GetInvSqrtQuantizedMultiplierExp(sum, -1, &invMultiplier, &invShift);
93             uint8_t* pOut = outputBeg;
94             for (const uint8_t* p = inputBeg; p < inputEnd; p += innerSize, pOut += innerSize) {
95                 int32_t val = static_cast<int32_t>(*p) - inputShape.offset;
96                 int32_t scaledVal = tflite::MultiplyByQuantizedMultiplierSmallerThanOneExp(
97                                             val * 128, invMultiplier, invShift) +
98                                     128;
99                 *pOut = static_cast<uint8_t>(std::min(std::max(scaledVal, 0), 255));
100             }
101         }
102     }
103     return true;
104 }
105 
l2normQuant8SignedImpl(const int8_t * inputData,const Shape & inputShape,int32_t axis,int8_t * outputData,const Shape &)106 inline bool l2normQuant8SignedImpl(const int8_t* inputData, const Shape& inputShape, int32_t axis,
107                                    int8_t* outputData, const Shape& /*outputShape*/) {
108     NNTRACE_TRANS("l2normQuant8Signed");
109     const uint32_t outerSize = getNumberOfElements(inputShape, 0, axis);
110     const uint32_t axisSize = getSizeOfDimension(inputShape, axis);
111     const uint32_t innerSize =
112             getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape));
113     for (uint32_t outer = 0; outer < outerSize; ++outer) {
114         const int8_t* inputBeg = inputData + outer * axisSize * innerSize;
115         const int8_t* inputEnd = inputBeg + axisSize * innerSize;
116         int8_t* outputBeg = outputData + outer * axisSize * innerSize;
117         for (uint32_t inner = 0; inner < innerSize; ++inner, ++inputBeg, ++inputEnd, ++outputBeg) {
118             int32_t sum = 0;
119             for (const int8_t* p = inputBeg; p < inputEnd; p += innerSize) {
120                 int32_t val = static_cast<int32_t>(*p) - inputShape.offset;
121                 sum += val * val;
122             }
123             int32_t invMultiplier, invShift;
124             tflite::GetInvSqrtQuantizedMultiplierExp(sum, -1, &invMultiplier, &invShift);
125             int8_t* pOut = outputBeg;
126             for (const int8_t* p = inputBeg; p < inputEnd; p += innerSize, pOut += innerSize) {
127                 int32_t val = static_cast<int32_t>(*p) - inputShape.offset;
128                 int32_t scaledVal = tflite::MultiplyByQuantizedMultiplierSmallerThanOneExp(
129                         val * 128, invMultiplier, invShift);
130                 *pOut = static_cast<int8_t>(std::min(std::max(scaledVal, -128), 127));
131             }
132         }
133     }
134     return true;
135 }
136 
l2normFloat32(const float * inputData,const Shape & inputShape,int32_t axis,float * outputData,const Shape & outputShape)137 bool l2normFloat32(const float* inputData, const Shape& inputShape, int32_t axis, float* outputData,
138                    const Shape& outputShape) {
139     int32_t ndim = getNumberOfDimensions(inputShape);
140     NN_CHECK(handleNegativeAxis(inputShape, &axis));
141     // TFLite optimized implementation only supports computation along the last axis
142     if (axis == ndim - 1) {
143         NNTRACE_COMP("optimized_ops::L2Normalization::float");
144         tflite::L2NormalizationParams param = {.input_zero_point = 0};
145         tflite::optimized_ops::L2Normalization(param, convertShapeToTflshape(inputShape), inputData,
146                                                convertShapeToTflshape(outputShape), outputData);
147         return true;
148     } else {
149         return l2normFloat32Impl(inputData, inputShape, axis, outputData, outputShape);
150     }
151 }
152 
l2normFloat16(const _Float16 * inputData,const Shape & inputShape,int32_t axis,_Float16 * outputData,const Shape & outputShape)153 bool l2normFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis,
154                    _Float16* outputData, const Shape& outputShape) {
155     NNTRACE_TRANS("l2normFloat16");
156     std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
157     convertFloat16ToFloat32(inputData, &inputDataFloat32);
158     std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
159 
160     l2normFloat32(inputDataFloat32.data(), inputShape, axis, outputDataFloat32.data(), outputShape);
161     convertFloat32ToFloat16(outputDataFloat32, outputData);
162 
163     return true;
164 }
165 
l2normQuant8(const uint8_t * inputData,const Shape & inputShape,int32_t axis,uint8_t * outputData,const Shape & outputShape)166 bool l2normQuant8(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
167                   uint8_t* outputData, const Shape& outputShape) {
168     int32_t ndim = getNumberOfDimensions(inputShape);
169     NN_CHECK(handleNegativeAxis(inputShape, &axis));
170     // TFLite optimized implementation only supports computation along the last axis
171     if (axis == ndim - 1) {
172         NNTRACE_COMP("optimized_ops::L2Normalization::uint8");
173         tflite::L2NormalizationParams param = {.input_zero_point = inputShape.offset};
174         tflite::optimized_ops::L2Normalization(param, convertShapeToTflshape(inputShape), inputData,
175                                                convertShapeToTflshape(outputShape), outputData);
176         return true;
177     } else {
178         return l2normQuant8Impl(inputData, inputShape, axis, outputData, outputShape);
179     }
180 }
181 
l2normQuant8Signed(const int8_t * inputData,const Shape & inputShape,int32_t axis,int8_t * outputData,const Shape & outputShape)182 bool l2normQuant8Signed(const int8_t* inputData, const Shape& inputShape, int32_t axis,
183                         int8_t* outputData, const Shape& outputShape) {
184     int32_t ndim = getNumberOfDimensions(inputShape);
185     NN_CHECK(handleNegativeAxis(inputShape, &axis));
186     // TFLite implementation only supports computation along the last axis
187     if (axis == ndim - 1) {
188         NNTRACE_COMP("reference_integer_ops::L2Normalization");
189         const int32_t outerSize = getNumberOfElements(inputShape, 0, axis);
190         const int32_t axisSize = getSizeOfDimension(inputShape, axis);
191         tflite::reference_integer_ops::L2Normalization(inputShape.offset, outerSize, axisSize,
192                                                        inputData, outputData);
193         return true;
194     } else {
195         return l2normQuant8SignedImpl(inputData, inputShape, axis, outputData, outputShape);
196     }
197 }
198 
199 }  // namespace
200 
prepare(IOperationExecutionContext * context)201 bool prepare(IOperationExecutionContext* context) {
202     const Shape& input = context->getInputShape(kInputTensor);
203     int32_t numDimensions = getNumberOfDimensions(input);
204     int32_t axis = context->getNumInputs() == kNumInputs
205                            ? context->getInputValue<int32_t>(kAxisScalar)
206                            : -1;
207     NN_RET_CHECK_LE(numDimensions, 4);
208     NN_RET_CHECK_GE(axis, -numDimensions);
209     NN_RET_CHECK_LT(axis, numDimensions);
210     Shape output = context->getOutputShape(kOutputTensor);
211     output.type = input.type;
212     output.dimensions = input.dimensions;
213     if (output.type == OperandType::TENSOR_QUANT8_ASYMM) {
214         output.scale = 1.0f / 128.0f;
215         output.offset = 128;
216     } else if (output.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
217         output.scale = 1.0f / 128.0f;
218         output.offset = 0;
219     } else {
220         output.scale = 0;
221         output.offset = 0;
222     }
223     return context->setOutputShape(kOutputTensor, output);
224 }
225 
execute(IOperationExecutionContext * context)226 bool execute(IOperationExecutionContext* context) {
227     int32_t axis = context->getNumInputs() == kNumInputs
228                            ? context->getInputValue<int32_t>(kAxisScalar)
229                            : -1;
230     NN_RET_CHECK(handleNegativeAxis(context->getInputShape(kInputTensor), &axis));
231     switch (context->getInputType(kInputTensor)) {
232         case OperandType::TENSOR_FLOAT32:
233             return l2normFloat32(context->getInputBuffer<float>(kInputTensor),
234                                  context->getInputShape(kInputTensor), axis,
235                                  context->getOutputBuffer<float>(kOutputTensor),
236                                  context->getOutputShape(kOutputTensor));
237         case OperandType::TENSOR_FLOAT16:
238             return l2normFloat16(context->getInputBuffer<_Float16>(kInputTensor),
239                                  context->getInputShape(kInputTensor), axis,
240                                  context->getOutputBuffer<_Float16>(kOutputTensor),
241                                  context->getOutputShape(kOutputTensor));
242         case OperandType::TENSOR_QUANT8_ASYMM:
243             return l2normQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
244                                 context->getInputShape(kInputTensor), axis,
245                                 context->getOutputBuffer<uint8_t>(kOutputTensor),
246                                 context->getOutputShape(kOutputTensor));
247         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
248             return l2normQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
249                                       context->getInputShape(kInputTensor), axis,
250                                       context->getOutputBuffer<int8_t>(kOutputTensor),
251                                       context->getOutputShape(kOutputTensor));
252         default:
253             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
254     }
255 }
256 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
257 
258 }  // namespace l2_norm
259 
260 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(L2_NORMALIZATION, l2_norm::prepare, l2_norm::execute);
261 
262 }  // namespace nn
263 }  // namespace android
264