1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "FullyConnected.h"
20 
21 #include <vector>
22 
23 #include "OperationResolver.h"
24 #include "Tracing.h"
25 
26 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
27 #pragma clang diagnostic push
28 #pragma clang diagnostic ignored "-Wunused-parameter"
29 #pragma clang diagnostic ignored "-Wsign-compare"
30 #pragma clang diagnostic ignored "-Winvalid-partial-specialization"
31 #include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
32 #include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
33 #include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
34 #include <tensorflow/lite/kernels/internal/types.h>
35 #pragma clang diagnostic pop
36 
37 #include "CpuOperationUtils.h"
38 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
39 
40 namespace android {
41 namespace nn {
42 namespace fully_connected {
43 
44 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
45 namespace {
46 
// executionMutex is used to protect concurrent access of non-threadsafe resources
// like gemmlowp::GemmContext (the function-local static gemmContext in
// fullyConnectedQuant8 below is shared across all executions).
// std::mutex is safe for pthreads on Android.
static std::mutex executionMutex;
51 
fullyConnectedFloat32(const float * inputData,const Shape & inputShape,const float * weightsData,const Shape & weightsShape,const float * biasData,const Shape & biasShape,int32_t activation,float * outputData,const Shape & outputShape)52 bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
53                            const float* weightsData, const Shape& weightsShape,
54                            const float* biasData, const Shape& biasShape, int32_t activation,
55                            float* outputData, const Shape& outputShape) {
56     NNTRACE_TRANS("fullyConnectedFloat32");
57     float output_activation_min, output_activation_max;
58     CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
59 
60     // b/80425683, optimized implementation produces incorrect results when the
61     // number of input elements is the squre of batch_size.
62     uint32_t batch_size = getSizeOfDimension(outputShape, 0);
63     uint32_t input_n_elements = getNumberOfElements(inputShape);
64     if (batch_size * batch_size == input_n_elements) {
65         NNTRACE_COMP_SWITCH("reference_ops::FullyConnected");
66         tflite::reference_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
67                                               weightsData, convertShapeToDims(weightsShape),
68                                               biasData, convertShapeToDims(biasShape),
69                                               output_activation_min, output_activation_max,
70                                               outputData, convertShapeToDims(outputShape));
71     } else {
72         NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
73         tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
74                                               weightsData, convertShapeToDims(weightsShape),
75                                               biasData, convertShapeToDims(biasShape),
76                                               output_activation_min, output_activation_max,
77                                               outputData, convertShapeToDims(outputShape));
78     }
79     return true;
80 }
81 
fullyConnectedFloat16(const _Float16 * inputData,const Shape & inputShape,const _Float16 * weightsData,const Shape & weightsShape,const _Float16 * biasData,const Shape & biasShape,int32_t activation,_Float16 * outputData,const Shape & outputShape)82 bool fullyConnectedFloat16(const _Float16* inputData, const Shape& inputShape,
83                            const _Float16* weightsData, const Shape& weightsShape,
84                            const _Float16* biasData, const Shape& biasShape, int32_t activation,
85                            _Float16* outputData, const Shape& outputShape) {
86     NNTRACE_TRANS("fullyConnectedFloat16");
87     std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
88     convertFloat16ToFloat32(inputData, &inputDataFloat32);
89     std::vector<float> weightsDataFloat32(getNumberOfElements(weightsShape));
90     convertFloat16ToFloat32(weightsData, &weightsDataFloat32);
91     std::vector<float> biasDataFloat32(getNumberOfElements(biasShape));
92     convertFloat16ToFloat32(biasData, &biasDataFloat32);
93 
94     std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
95     fullyConnectedFloat32(inputDataFloat32.data(), inputShape, weightsDataFloat32.data(),
96                           weightsShape, biasDataFloat32.data(), biasShape, activation,
97                           outputDataFloat32.data(), outputShape);
98     convertFloat32ToFloat16(outputDataFloat32, outputData);
99 
100     return true;
101 }
102 
// Runs an asymmetric-quantized (uint8) fully-connected layer via the
// gemmlowp-backed optimized kernel. Returns false if requantization
// parameters cannot be derived from the operand scales.
bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
                          const uint8_t* weightsData, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape, int32_t activation,
                          uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedQuant8");
    // The kernel expects negated zero points for input and weights, but the
    // un-negated zero point for the output.
    int32_t inputOffset = -inputShape.offset;
    int32_t weightsOffset = -weightsShape.offset;
    int32_t outputOffset = outputShape.offset;

    double realMultiplier = 0.0;
    int32_t outputMultiplier = 0;
    int32_t outputShift = 0;
    int32_t outputActivationMin = 0;
    int32_t outputActivationMax = 0;

    // Derive the real requantization multiplier from the input/weights/bias/
    // output scales, then express it as a fixed-point multiplier + shift.
    NN_RET_CHECK(GetQuantizedConvolutionMultiplier(inputShape, weightsShape, biasShape, outputShape,
                                                   &realMultiplier));
    int exponent;
    NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &exponent));
    // The legacy optimized kernel takes a right shift, i.e. the negated exponent.
    outputShift = -exponent;
    CalculateActivationRangeUint8(activation, outputShape, &outputActivationMin,
                                  &outputActivationMax);

    // Shared across executions; not thread-safe on its own.
    static gemmlowp::GemmContext gemmContext;

    // Prevent concurrent executions that access gemmContext.
    std::unique_lock<std::mutex> lock(executionMutex);
    // Allow gemmlowp to automatically decide how many threads to use.
    gemmContext.set_max_num_threads(0);

    NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
    tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape), inputOffset,
                                          weightsData, convertShapeToDims(weightsShape),
                                          weightsOffset, biasData, convertShapeToDims(biasShape),
                                          outputOffset, outputMultiplier, outputShift,
                                          outputActivationMin, outputActivationMax, outputData,
                                          convertShapeToDims(outputShape), &gemmContext);

    return true;
}
143 
fullyConnectedQuant8(const int8_t * inputData,const Shape & inputShape,const int8_t * weightsData,const Shape & weightsShape,const int32_t * biasData,const Shape & biasShape,int32_t activation,int8_t * outputData,const Shape & outputShape)144 bool fullyConnectedQuant8(const int8_t* inputData, const Shape& inputShape,
145                           const int8_t* weightsData, const Shape& weightsShape,
146                           const int32_t* biasData, const Shape& biasShape, int32_t activation,
147                           int8_t* outputData, const Shape& outputShape) {
148     NNTRACE_TRANS("fullyConnectedQuant8Signed");
149 
150     double realMultiplier = 0.0;
151     int32_t outputMultiplier = 0;
152     int32_t outputShift = 0;
153     int32_t outputActivationMin = 0;
154     int32_t outputActivationMax = 0;
155 
156     NN_RET_CHECK(GetQuantizedConvolutionMultiplier(inputShape, weightsShape, biasShape, outputShape,
157                                                    &realMultiplier));
158     NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &outputShift));
159     CalculateActivationRangeInt8(activation, outputShape, &outputActivationMin,
160                                  &outputActivationMax);
161 
162     tflite::FullyConnectedParams params;
163     params.input_offset = -inputShape.offset;
164     params.weights_offset = -weightsShape.offset;
165     params.output_offset = outputShape.offset;
166     params.output_multiplier = outputMultiplier;
167     params.output_shift = outputShift;
168     params.quantized_activation_min = outputActivationMin;
169     params.quantized_activation_max = outputActivationMax;
170 
171     NNTRACE_COMP_SWITCH("reference_integer_ops::FullyConnected");
172     tflite::reference_integer_ops::FullyConnected(
173             params, convertShapeToTflshape(inputShape), inputData,
174             convertShapeToTflshape(weightsShape), weightsData, convertShapeToTflshape(biasShape),
175             biasData, convertShapeToTflshape(outputShape), outputData);
176 
177     return true;
178 }
179 
180 }  // namespace
181 
prepare(IOperationExecutionContext * context)182 bool prepare(IOperationExecutionContext* context) {
183     Shape input = context->getInputShape(kInputTensor);
184     Shape weights = context->getInputShape(kWeightsTensor);
185     Shape bias = context->getInputShape(kBiasTensor);
186     Shape output = context->getOutputShape(kOutputTensor);
187     NN_RET_CHECK(validateShapes(input, weights, bias, &output));
188     return context->setOutputShape(kOutputTensor, output);
189 }
190 
execute(IOperationExecutionContext * context)191 bool execute(IOperationExecutionContext* context) {
192     // Bypass execution in the case of zero-sized input.
193     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
194     switch (context->getInputType(kInputTensor)) {
195         case OperandType::TENSOR_FLOAT32:
196             return fullyConnectedFloat32(context->getInputBuffer<float>(kInputTensor),
197                                          context->getInputShape(kInputTensor),
198                                          context->getInputBuffer<float>(kWeightsTensor),
199                                          context->getInputShape(kWeightsTensor),
200                                          context->getInputBuffer<float>(kBiasTensor),
201                                          context->getInputShape(kBiasTensor),
202                                          context->getInputValue<int32_t>(kActivationScalar),
203                                          context->getOutputBuffer<float>(kOutputTensor),
204                                          context->getOutputShape(kOutputTensor));
205         case OperandType::TENSOR_FLOAT16:
206             return fullyConnectedFloat16(context->getInputBuffer<_Float16>(kInputTensor),
207                                          context->getInputShape(kInputTensor),
208                                          context->getInputBuffer<_Float16>(kWeightsTensor),
209                                          context->getInputShape(kWeightsTensor),
210                                          context->getInputBuffer<_Float16>(kBiasTensor),
211                                          context->getInputShape(kBiasTensor),
212                                          context->getInputValue<int32_t>(kActivationScalar),
213                                          context->getOutputBuffer<_Float16>(kOutputTensor),
214                                          context->getOutputShape(kOutputTensor));
215         case OperandType::TENSOR_QUANT8_ASYMM:
216             return fullyConnectedQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
217                                         context->getInputShape(kInputTensor),
218                                         context->getInputBuffer<uint8_t>(kWeightsTensor),
219                                         context->getInputShape(kWeightsTensor),
220                                         context->getInputBuffer<int32_t>(kBiasTensor),
221                                         context->getInputShape(kBiasTensor),
222                                         context->getInputValue<int32_t>(kActivationScalar),
223                                         context->getOutputBuffer<uint8_t>(kOutputTensor),
224                                         context->getOutputShape(kOutputTensor));
225         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
226             return fullyConnectedQuant8(context->getInputBuffer<int8_t>(kInputTensor),
227                                         context->getInputShape(kInputTensor),
228                                         context->getInputBuffer<int8_t>(kWeightsTensor),
229                                         context->getInputShape(kWeightsTensor),
230                                         context->getInputBuffer<int32_t>(kBiasTensor),
231                                         context->getInputShape(kBiasTensor),
232                                         context->getInputValue<int32_t>(kActivationScalar),
233                                         context->getOutputBuffer<int8_t>(kOutputTensor),
234                                         context->getOutputShape(kOutputTensor));
235         default:
236             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
237     }
238 }
239 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
240 
241 }  // namespace fully_connected
242 
// Register FULLY_CONNECTED with the operation resolver; zero-sized inputs are
// permitted and short-circuited at the top of execute().
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(FULLY_CONNECTED, fully_connected::prepare,
                                         fully_connected::execute, .allowZeroSizedInput = true);
245 
246 }  // namespace nn
247 }  // namespace android
248