/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "FullyConnected.h"

#include <mutex>
#include <vector>

#include "OperationResolver.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Winvalid-partial-specialization"
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
#include <tensorflow/lite/kernels/internal/types.h>
#pragma clang diagnostic pop

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace fully_connected {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

// executionMutex is used to protect concurrent access of non-threadsafe resources
// like gemmlowp::GemmContext.
// std::mutex is safe for pthreads on Android.
static std::mutex executionMutex;

bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
                           const float* weightsData, const Shape& weightsShape,
                           const float* biasData, const Shape& biasShape, int32_t activation,
                           float* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    // b/80425683: the optimized implementation produces incorrect results when the
    // number of input elements is the square of batch_size.
    uint32_t batch_size = getSizeOfDimension(outputShape, 0);
    uint32_t input_n_elements = getNumberOfElements(inputShape);
    if (batch_size * batch_size == input_n_elements) {
        NNTRACE_COMP_SWITCH("reference_ops::FullyConnected");
        tflite::reference_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
                                              weightsData, convertShapeToDims(weightsShape),
                                              biasData, convertShapeToDims(biasShape),
                                              output_activation_min, output_activation_max,
                                              outputData, convertShapeToDims(outputShape));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
        tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
                                              weightsData, convertShapeToDims(weightsShape),
                                              biasData, convertShapeToDims(biasShape),
                                              output_activation_min, output_activation_max,
                                              outputData, convertShapeToDims(outputShape));
    }
    return true;
}

bool fullyConnectedFloat16(const _Float16* inputData, const Shape& inputShape,
                           const _Float16* weightsData, const Shape& weightsShape,
                           const _Float16* biasData, const Shape& biasShape, int32_t activation,
                           _Float16* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedFloat16");
    // There is no dedicated float16 kernel: widen the operands to float32, run the
    // float32 implementation, and narrow the result back to float16.
    std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
    convertFloat16ToFloat32(inputData, &inputDataFloat32);
    std::vector<float> weightsDataFloat32(getNumberOfElements(weightsShape));
    convertFloat16ToFloat32(weightsData, &weightsDataFloat32);
    std::vector<float> biasDataFloat32(getNumberOfElements(biasShape));
    convertFloat16ToFloat32(biasData, &biasDataFloat32);

    std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
    fullyConnectedFloat32(inputDataFloat32.data(), inputShape, weightsDataFloat32.data(),
                          weightsShape, biasDataFloat32.data(), biasShape, activation,
                          outputDataFloat32.data(), outputShape);
    convertFloat32ToFloat16(outputDataFloat32, outputData);

    return true;
}

bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
                          const uint8_t* weightsData, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape, int32_t activation,
                          uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedQuant8");
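    // The TFLite kernel adds these offsets to each quantized value, so the input and
    // weights offsets are the negated zero points; the output offset is the output
    // zero point itself, applied after requantization.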
    int32_t inputOffset = -inputShape.offset;
    int32_t weightsOffset = -weightsShape.offset;
    int32_t outputOffset = outputShape.offset;

    double realMultiplier = 0.0;
    int32_t outputMultiplier = 0;
    int32_t outputShift = 0;
    int32_t outputActivationMin = 0;
    int32_t outputActivationMax = 0;

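    // Requantization sketch: realMultiplier = (inputScale * weightsScale) / outputScale,
    // approximated as a Q31 fixed-point multiplier (outputMultiplier) scaled by
    // 2^-outputShift, which is applied to the int32 accumulator.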
    NN_RET_CHECK(GetQuantizedConvolutionMultiplier(inputShape, weightsShape, biasShape, outputShape,
                                                   &realMultiplier));
    int exponent;
    NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &exponent));
    // QuantizeMultiplier returns a left-shift exponent, but the legacy optimized kernel
    // expects a right shift, hence the negation.
    outputShift = -exponent;
    CalculateActivationRangeUint8(activation, outputShape, &outputActivationMin,
                                  &outputActivationMax);

    static gemmlowp::GemmContext gemmContext;

    // Prevent concurrent executions that access gemmContext.
    std::unique_lock<std::mutex> lock(executionMutex);
    // Allow gemmlowp to decide automatically how many threads to use.
    gemmContext.set_max_num_threads(0);

    NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
    tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape), inputOffset,
                                          weightsData, convertShapeToDims(weightsShape),
                                          weightsOffset, biasData, convertShapeToDims(biasShape),
                                          outputOffset, outputMultiplier, outputShift,
                                          outputActivationMin, outputActivationMax, outputData,
                                          convertShapeToDims(outputShape), &gemmContext);

    return true;
}

bool fullyConnectedQuant8(const int8_t* inputData, const Shape& inputShape,
                          const int8_t* weightsData, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape, int32_t activation,
                          int8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedQuant8Signed");

    double realMultiplier = 0.0;
    int32_t outputMultiplier = 0;
    int32_t outputShift = 0;
    int32_t outputActivationMin = 0;
    int32_t outputActivationMax = 0;

    NN_RET_CHECK(GetQuantizedConvolutionMultiplier(inputShape, weightsShape, biasShape, outputShape,
                                                   &realMultiplier));
    // Unlike the uint8 path above, reference_integer_ops uses the left-shift exponent
    // convention directly, so the shift is not negated here.
    NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &outputShift));
    CalculateActivationRangeInt8(activation, outputShape, &outputActivationMin,
                                 &outputActivationMax);

    tflite::FullyConnectedParams params;
    params.input_offset = -inputShape.offset;
    params.weights_offset = -weightsShape.offset;
    params.output_offset = outputShape.offset;
    params.output_multiplier = outputMultiplier;
    params.output_shift = outputShift;
    params.quantized_activation_min = outputActivationMin;
    params.quantized_activation_max = outputActivationMax;

    NNTRACE_COMP_SWITCH("reference_integer_ops::FullyConnected");
    tflite::reference_integer_ops::FullyConnected(
            params, convertShapeToTflshape(inputShape), inputData,
            convertShapeToTflshape(weightsShape), weightsData, convertShapeToTflshape(biasShape),
            biasData, convertShapeToTflshape(outputShape), outputData);

    return true;
}

}  // namespace

bool prepare(IOperationExecutionContext* context) {
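    // Per the FULLY_CONNECTED spec, validateShapes expects weights of shape
    // [num_units, input_size] and bias of shape [num_units], and computes an output of
    // shape [batch_size, num_units].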
    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK(validateShapes(input, weights, bias, &output));
    return context->setOutputShape(kOutputTensor, output);
}

bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT32:
            return fullyConnectedFloat32(context->getInputBuffer<float>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getInputBuffer<float>(kWeightsTensor),
                                         context->getInputShape(kWeightsTensor),
                                         context->getInputBuffer<float>(kBiasTensor),
                                         context->getInputShape(kBiasTensor),
                                         context->getInputValue<int32_t>(kActivationScalar),
                                         context->getOutputBuffer<float>(kOutputTensor),
                                         context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT16:
            return fullyConnectedFloat16(context->getInputBuffer<_Float16>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getInputBuffer<_Float16>(kWeightsTensor),
                                         context->getInputShape(kWeightsTensor),
                                         context->getInputBuffer<_Float16>(kBiasTensor),
                                         context->getInputShape(kBiasTensor),
                                         context->getInputValue<int32_t>(kActivationScalar),
                                         context->getOutputBuffer<_Float16>(kOutputTensor),
                                         context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return fullyConnectedQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getInputBuffer<uint8_t>(kWeightsTensor),
                                        context->getInputShape(kWeightsTensor),
                                        context->getInputBuffer<int32_t>(kBiasTensor),
                                        context->getInputShape(kBiasTensor),
                                        context->getInputValue<int32_t>(kActivationScalar),
                                        context->getOutputBuffer<uint8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return fullyConnectedQuant8(context->getInputBuffer<int8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getInputBuffer<int8_t>(kWeightsTensor),
                                        context->getInputShape(kWeightsTensor),
                                        context->getInputBuffer<int32_t>(kBiasTensor),
                                        context->getInputShape(kBiasTensor),
                                        context->getInputValue<int32_t>(kActivationScalar),
                                        context->getOutputBuffer<int8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace fully_connected

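// Registers the FULLY_CONNECTED operation with the operation resolver using the default
// validation; allowZeroSizedInput permits zero-sized tensors, which execute() handles
// via the early return above.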
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(FULLY_CONNECTED, fully_connected::prepare,
                                         fully_connected::execute, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android