1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
20 #pragma clang diagnostic push
21 #pragma clang diagnostic ignored "-Wunused-parameter"
22 #pragma clang diagnostic ignored "-Wsign-compare"
23 #pragma clang diagnostic ignored "-Winvalid-partial-specialization"
24 #include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
25 #include <tensorflow/lite/kernels/internal/runtime_shape.h>
26 #pragma clang diagnostic pop
27
28 #include <limits>
29 #include <memory>
30 #include <vector>
31
32 #include "CpuOperationUtils.h"
33 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
34
35 #include "BatchMatmul.h"
36 #include "OperationResolver.h"
37 #include "OperationsExecutionUtils.h"
38 #include "Tracing.h"
39
40 namespace android {
41 namespace nn {
42 namespace batch_matmul_op {
43
44 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
45 namespace {
46
// Checks if two matrices can be multiplied, taking the adjoint flags into
// account. Zero-sized dimensions are rejected outright; otherwise the
// effective inner dimensions (LHS columns vs. RHS rows) must match.
bool canMatrixMul(uint32_t LHSRow, uint32_t LHSCol, uint32_t RHSRow, uint32_t RHSCol, bool adjX,
                  bool adjY) {
    const bool anyDimEmpty = (LHSRow == 0) || (LHSCol == 0) || (RHSRow == 0) || (RHSCol == 0);
    if (anyDimEmpty) {
        return false;
    }
    // adjX transposes the LHS, so its inner dimension becomes its row count;
    // adjY transposes the RHS, so its inner dimension becomes its column count.
    const uint32_t lhsInner = adjX ? LHSRow : LHSCol;
    const uint32_t rhsInner = adjY ? RHSCol : RHSRow;
    return lhsInner == rhsInner;
}
61
62 // Computes the dimensions of output tensor.
computeOutputDimensions(const Shape & LHSTensorShape,const Shape & RHSTensorShape,bool adjX,bool adjY)63 std::vector<uint32_t> computeOutputDimensions(const Shape& LHSTensorShape,
64 const Shape& RHSTensorShape, bool adjX, bool adjY) {
65 uint32_t numDims = getNumberOfDimensions(LHSTensorShape);
66 auto outputTensorDimensions = LHSTensorShape.dimensions;
67 outputTensorDimensions[numDims - 2] =
68 adjX ? LHSTensorShape.dimensions[numDims - 1] : LHSTensorShape.dimensions[numDims - 2];
69 outputTensorDimensions[numDims - 1] =
70 adjY ? RHSTensorShape.dimensions[numDims - 2] : RHSTensorShape.dimensions[numDims - 1];
71 return outputTensorDimensions;
72 }
73
74 // Swaps row and column dimensions for a shape.
swapRowColumnDims(const Shape & shape)75 Shape swapRowColumnDims(const Shape& shape) {
76 Shape swappedShape = shape;
77 uint32_t numDims = getNumberOfDimensions(shape);
78 swappedShape.dimensions[numDims - 2] = shape.dimensions[numDims - 1];
79 swappedShape.dimensions[numDims - 1] = shape.dimensions[numDims - 2];
80 return swappedShape;
81 }
82
83 // Transposes a matrix.
84 template <typename T>
transposeRowsColumns(const T * inputData,const Shape & inputShape,T * outputData)85 void transposeRowsColumns(const T* inputData, const Shape& inputShape, T* outputData) {
86 Shape transposedShape = swapRowColumnDims(inputShape);
87 tflite::TransposeParams params;
88 int rank = getNumberOfDimensions(inputShape);
89 params.perm_count = rank;
90 for (int i = 0; i < rank - 2; ++i) {
91 params.perm[i] = i;
92 }
93 params.perm[rank - 2] = rank - 1;
94 params.perm[rank - 1] = rank - 2;
95 tflite::reference_ops::Transpose(params, convertShapeToTflshape(inputShape), inputData,
96 convertShapeToTflshape(transposedShape), outputData);
97 }
98
// Allocates a temporary heap buffer of |numElems| elements. Returns nullptr on
// allocation failure (non-throwing new), so callers must check the result.
// The returned unique_ptr releases the buffer automatically.
template <typename T>
std::unique_ptr<T[]> getTempData(uint32_t numElems) {
    T* rawBuffer = new (std::nothrow) T[numElems];
    return std::unique_ptr<T[]>(rawBuffer);
}
105
106 // Performs batch matmul.
107 // LHS <..., A, B> X RHS<..., B, C>
108 // We assume that LHS and RHS are both row oriented (adjacent values in memory
109 // are in the same row) and will output in the same memory layout. However,
110 // TFLite's fast GEMM libraries assume RCC layout (LHS row oriented,
111 // RHS column oriented, output column oriented). Therefore, we perform
112 // RHS <..., C, B> X LHS <..., B, A>
113 // where output is a C X A column-oriented, which is equivalent to
114 // A X C row-oriented.
115 template <typename T>
batchMatMulGeneric(const T * inputLHSData,const Shape & inputLHSShape,const T * inputRHSData,const Shape & inputRHSShape,const bool adjX,const bool adjY,T * outputData,const Shape & outputShape)116 bool batchMatMulGeneric(const T* inputLHSData, const Shape& inputLHSShape, const T* inputRHSData,
117 const Shape& inputRHSShape, const bool adjX, const bool adjY, T* outputData,
118 const Shape& outputShape) {
119 NNTRACE_TRANS("batchMatMulGeneric");
120 // Only performs transpose without conjugation for adjoint since complex number is not
121 // supported.
122 NNTRACE_COMP_SWITCH("reference_ops::Transpose");
123 const T* realInputLHSData = inputLHSData;
124 const T* realInputRHSData = inputRHSData;
125 auto tempInputLHSData = getTempData<T>(getNumberOfElements(inputLHSShape));
126 auto tempInputRHSData = getTempData<T>(getNumberOfElements(inputRHSShape));
127 // For LHS, it's passed as RHS and column-oriented.
128 // If adjX is false, needs to swap shape but no need to do data transpose.
129 // If adjX is true, no need to swap shape but needs to do data transpose.
130 // For RHS, it's passed as LHS and row-oriented.
131 // If adjY is false, needs to swap shape also needs to do data transpose.
132 // If adjY is true, no need to swap shape also no need to do data transpose.
133 if (adjX) {
134 transposeRowsColumns(inputLHSData, inputLHSShape, tempInputLHSData.get());
135 realInputLHSData = tempInputLHSData.get();
136 }
137 if (!adjY) {
138 transposeRowsColumns(inputRHSData, inputRHSShape, tempInputRHSData.get());
139 realInputRHSData = tempInputRHSData.get();
140 }
141 Shape realInputLHSShape = adjX ? inputLHSShape : swapRowColumnDims(inputLHSShape);
142 Shape realInputRHSShape = adjY ? inputRHSShape : swapRowColumnDims(inputRHSShape);
143 NNTRACE_COMP_SWITCH("reference_ops::BatchMatMul");
144 tflite::reference_ops::BatchMatMul(convertShapeToTflshape(realInputRHSShape), realInputRHSData,
145 convertShapeToTflshape(realInputLHSShape), realInputLHSData,
146 convertShapeToTflshape(outputShape), outputData);
147 return true;
148 }
149
150 // Performs batch matmul for quantized types.
151 template <typename T>
batchMatMulQuantized(const T * inputLHSData,const Shape & inputLHSShape,const T * inputRHSData,const Shape & inputRHSShape,const bool adjX,const bool adjY,T * outputData,const Shape & outputShape)152 bool batchMatMulQuantized(const T* inputLHSData, const Shape& inputLHSShape, const T* inputRHSData,
153 const Shape& inputRHSShape, const bool adjX, const bool adjY,
154 T* outputData, const Shape& outputShape) {
155 NNTRACE_TRANS("batchMatMulQuantized");
156 NNTRACE_COMP_SWITCH("reference_ops::Transpose");
157 const T* realInputLHSData = inputLHSData;
158 const T* realInputRHSData = inputRHSData;
159 auto tempInputLHSData = getTempData<T>(getNumberOfElements(inputLHSShape));
160 auto tempInputRHSData = getTempData<T>(getNumberOfElements(inputRHSShape));
161 if (adjX) {
162 transposeRowsColumns(inputLHSData, inputLHSShape, tempInputLHSData.get());
163 realInputLHSData = tempInputLHSData.get();
164 }
165 if (!adjY) {
166 transposeRowsColumns(inputRHSData, inputRHSShape, tempInputRHSData.get());
167 realInputRHSData = tempInputRHSData.get();
168 }
169 Shape realInputLHSShape = adjX ? inputLHSShape : swapRowColumnDims(inputLHSShape);
170 Shape realInputRHSShape = adjY ? inputRHSShape : swapRowColumnDims(inputRHSShape);
171
172 NNTRACE_COMP_SWITCH("reference_ops::BatchMatMul");
173
174 double realMultiplier = 0.0;
175 int32_t outputMultiplier = 0;
176 int32_t outputShift = 0;
177 NN_RET_CHECK(GetQuantizedConvolutionMultiplier(realInputLHSShape, realInputRHSShape,
178 outputShape, &realMultiplier));
179 NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &outputShift));
180 tflite::FullyConnectedParams params;
181 params.input_offset = -realInputLHSShape.offset;
182 params.weights_offset = -realInputRHSShape.offset;
183 params.output_offset = outputShape.offset;
184 params.output_multiplier = outputMultiplier;
185 params.output_shift = outputShift;
186 // BatchMatMul has no fused activation functions. Therefore, sets
187 // output activation min and max to min and max of int8_t.
188 params.quantized_activation_min = std::numeric_limits<int8_t>::min();
189 params.quantized_activation_max = std::numeric_limits<int8_t>::max();
190 params.lhs_cacheable = false;
191 params.rhs_cacheable = false;
192
193 tflite::reference_ops::BatchMatMul<T, int32_t>(
194 params, convertShapeToTflshape(realInputRHSShape), realInputRHSData,
195 convertShapeToTflshape(realInputLHSShape), realInputLHSData,
196 convertShapeToTflshape(outputShape), outputData);
197 return true;
198 }
199
200 } // namespace
201
prepare(IOperationExecutionContext * context)202 bool prepare(IOperationExecutionContext* context) {
203 Shape inputLHSTensorShape = context->getInputShape(kInputLHSTensor);
204 Shape inputRHSTensorShape = context->getInputShape(kInputRHSTensor);
205 // Checks two input tensors have same number of dimensions.
206 NN_RET_CHECK_EQ(getNumberOfDimensions(inputLHSTensorShape),
207 getNumberOfDimensions(inputRHSTensorShape))
208 << "Input tensor ranks do not match with each other.";
209 NN_RET_CHECK_GE(getNumberOfDimensions(inputLHSTensorShape), 2u)
210 << "Input tensor rank should be at least 2.";
211 NN_RET_CHECK_LE(getNumberOfDimensions(inputLHSTensorShape), 4u)
212 << "Input tensor rank should be at most 4.";
213 uint32_t numDims = getNumberOfDimensions(inputLHSTensorShape);
214 const bool adjX = context->getInputValue<bool>(kInputLHSAdj);
215 const bool adjY = context->getInputValue<bool>(kInputRHSAdj);
216 // Checks dimensions work for matrix multiplication.
217 NN_RET_CHECK(canMatrixMul(getSizeOfDimension(inputLHSTensorShape, numDims - 2),
218 getSizeOfDimension(inputLHSTensorShape, numDims - 1),
219 getSizeOfDimension(inputRHSTensorShape, numDims - 2),
220 getSizeOfDimension(inputRHSTensorShape, numDims - 1), adjX, adjY))
221 << "Input tensors are not able to perform matrix multiplication.";
222
223 Shape outputTensorShape = context->getOutputShape(kOutputTensor);
224 outputTensorShape.dimensions =
225 computeOutputDimensions(inputLHSTensorShape, inputRHSTensorShape, adjX, adjY);
226 return context->setOutputShape(kOutputTensor, outputTensorShape);
227 }
228
execute(IOperationExecutionContext * context)229 bool execute(IOperationExecutionContext* context) {
230 switch (context->getInputType(kInputLHSTensor)) {
231 case OperandType::TENSOR_FLOAT32:
232 return batchMatMulGeneric(context->getInputBuffer<float>(kInputLHSTensor),
233 context->getInputShape(kInputLHSTensor),
234 context->getInputBuffer<float>(kInputRHSTensor),
235 context->getInputShape(kInputRHSTensor),
236 context->getInputValue<bool>(kInputLHSAdj),
237 context->getInputValue<bool>(kInputRHSAdj),
238 context->getOutputBuffer<float>(kOutputTensor),
239 context->getOutputShape(kOutputTensor));
240 case OperandType::TENSOR_FLOAT16:
241 return batchMatMulGeneric(context->getInputBuffer<_Float16>(kInputLHSTensor),
242 context->getInputShape(kInputLHSTensor),
243 context->getInputBuffer<_Float16>(kInputRHSTensor),
244 context->getInputShape(kInputRHSTensor),
245 context->getInputValue<bool>(kInputLHSAdj),
246 context->getInputValue<bool>(kInputRHSAdj),
247 context->getOutputBuffer<_Float16>(kOutputTensor),
248 context->getOutputShape(kOutputTensor));
249 case OperandType::TENSOR_INT32:
250 return batchMatMulGeneric(context->getInputBuffer<int32_t>(kInputLHSTensor),
251 context->getInputShape(kInputLHSTensor),
252 context->getInputBuffer<int32_t>(kInputRHSTensor),
253 context->getInputShape(kInputRHSTensor),
254 context->getInputValue<bool>(kInputLHSAdj),
255 context->getInputValue<bool>(kInputRHSAdj),
256 context->getOutputBuffer<int32_t>(kOutputTensor),
257 context->getOutputShape(kOutputTensor));
258 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
259 return batchMatMulQuantized(context->getInputBuffer<int8_t>(kInputLHSTensor),
260 context->getInputShape(kInputLHSTensor),
261 context->getInputBuffer<int8_t>(kInputRHSTensor),
262 context->getInputShape(kInputRHSTensor),
263 context->getInputValue<bool>(kInputLHSAdj),
264 context->getInputValue<bool>(kInputRHSAdj),
265 context->getOutputBuffer<int8_t>(kOutputTensor),
266 context->getOutputShape(kOutputTensor));
267 default:
268 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
269 }
270 return true;
271 }
272 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
273
274 } // namespace batch_matmul_op
275
276 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(BATCH_MATMUL, batch_matmul_op::prepare,
277 batch_matmul_op::execute);
278
279 } // namespace nn
280 } // namespace android
281