1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
20 #pragma clang diagnostic push
21 #pragma clang diagnostic ignored "-Wunused-parameter"
22 #pragma clang diagnostic ignored "-Wsign-compare"
23 #pragma clang diagnostic ignored "-Winvalid-partial-specialization"
24 #include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
25 #include <tensorflow/lite/kernels/internal/runtime_shape.h>
26 #pragma clang diagnostic pop
27 
28 #include <limits>
29 #include <memory>
30 #include <vector>
31 
32 #include "CpuOperationUtils.h"
33 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
34 
35 #include "BatchMatmul.h"
36 #include "OperationResolver.h"
37 #include "OperationsExecutionUtils.h"
38 #include "Tracing.h"
39 
40 namespace android {
41 namespace nn {
42 namespace batch_matmul_op {
43 
44 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
45 namespace {
46 
// Checks if two matrices can be multiplied: no dimension may be zero, and
// the effective inner dimensions (after applying the adjoint flags) must
// agree.
bool canMatrixMul(uint32_t LHSRow, uint32_t LHSCol, uint32_t RHSRow, uint32_t RHSCol, bool adjX,
                  bool adjY) {
    const bool hasEmptyDim = (LHSRow == 0) || (LHSCol == 0) || (RHSRow == 0) || (RHSCol == 0);
    if (hasEmptyDim) {
        return false;
    }
    // An adjoint flag transposes its operand, so the inner dimension comes
    // from the opposite axis.
    const uint32_t innerLHS = adjX ? LHSRow : LHSCol;
    const uint32_t innerRHS = adjY ? RHSCol : RHSRow;
    return innerLHS == innerRHS;
}
61 
62 // Computes the dimensions of output tensor.
computeOutputDimensions(const Shape & LHSTensorShape,const Shape & RHSTensorShape,bool adjX,bool adjY)63 std::vector<uint32_t> computeOutputDimensions(const Shape& LHSTensorShape,
64                                               const Shape& RHSTensorShape, bool adjX, bool adjY) {
65     uint32_t numDims = getNumberOfDimensions(LHSTensorShape);
66     auto outputTensorDimensions = LHSTensorShape.dimensions;
67     outputTensorDimensions[numDims - 2] =
68             adjX ? LHSTensorShape.dimensions[numDims - 1] : LHSTensorShape.dimensions[numDims - 2];
69     outputTensorDimensions[numDims - 1] =
70             adjY ? RHSTensorShape.dimensions[numDims - 2] : RHSTensorShape.dimensions[numDims - 1];
71     return outputTensorDimensions;
72 }
73 
74 // Swaps row and column dimensions for a shape.
swapRowColumnDims(const Shape & shape)75 Shape swapRowColumnDims(const Shape& shape) {
76     Shape swappedShape = shape;
77     uint32_t numDims = getNumberOfDimensions(shape);
78     swappedShape.dimensions[numDims - 2] = shape.dimensions[numDims - 1];
79     swappedShape.dimensions[numDims - 1] = shape.dimensions[numDims - 2];
80     return swappedShape;
81 }
82 
83 // Transposes a matrix.
84 template <typename T>
transposeRowsColumns(const T * inputData,const Shape & inputShape,T * outputData)85 void transposeRowsColumns(const T* inputData, const Shape& inputShape, T* outputData) {
86     Shape transposedShape = swapRowColumnDims(inputShape);
87     tflite::TransposeParams params;
88     int rank = getNumberOfDimensions(inputShape);
89     params.perm_count = rank;
90     for (int i = 0; i < rank - 2; ++i) {
91         params.perm[i] = i;
92     }
93     params.perm[rank - 2] = rank - 1;
94     params.perm[rank - 1] = rank - 2;
95     tflite::reference_ops::Transpose(params, convertShapeToTflshape(inputShape), inputData,
96                                      convertShapeToTflshape(transposedShape), outputData);
97 }
98 
// Allocates an uninitialized heap buffer of `numElems` elements of type T.
// The returned unique_ptr owns the memory and frees it automatically when it
// goes out of scope. Uses the non-throwing allocator, so the returned pointer
// is null if the allocation fails — callers must check before dereferencing.
template <typename T>
std::unique_ptr<T[]> getTempData(uint32_t numElems) {
    return std::unique_ptr<T[]>(new (std::nothrow) T[numElems]);
}
105 
106 // Performs batch matmul.
107 // LHS <..., A, B>  X  RHS<..., B, C>
108 // We assume that LHS and RHS are both row oriented (adjacent values in memory
109 // are in the same row) and will output in the same memory layout. However,
110 // TFLite's fast GEMM libraries assume RCC layout (LHS row oriented,
111 // RHS column oriented, output column oriented). Therefore, we perform
112 // RHS <..., C, B> X LHS <..., B, A>
113 // where output is a C X A column-oriented, which is equivalent to
114 // A X C row-oriented.
115 template <typename T>
batchMatMulGeneric(const T * inputLHSData,const Shape & inputLHSShape,const T * inputRHSData,const Shape & inputRHSShape,const bool adjX,const bool adjY,T * outputData,const Shape & outputShape)116 bool batchMatMulGeneric(const T* inputLHSData, const Shape& inputLHSShape, const T* inputRHSData,
117                         const Shape& inputRHSShape, const bool adjX, const bool adjY, T* outputData,
118                         const Shape& outputShape) {
119     NNTRACE_TRANS("batchMatMulGeneric");
120     // Only performs transpose without conjugation for adjoint since complex number is not
121     // supported.
122     NNTRACE_COMP_SWITCH("reference_ops::Transpose");
123     const T* realInputLHSData = inputLHSData;
124     const T* realInputRHSData = inputRHSData;
125     auto tempInputLHSData = getTempData<T>(getNumberOfElements(inputLHSShape));
126     auto tempInputRHSData = getTempData<T>(getNumberOfElements(inputRHSShape));
127     // For LHS, it's passed as RHS and column-oriented.
128     // If adjX is false, needs to swap shape but no need to do data transpose.
129     // If adjX is true, no need to swap shape but needs to do data transpose.
130     // For RHS, it's passed as LHS and row-oriented.
131     // If adjY is false, needs to swap shape also needs to do data transpose.
132     // If adjY is true, no need to swap shape also no need to do data transpose.
133     if (adjX) {
134         transposeRowsColumns(inputLHSData, inputLHSShape, tempInputLHSData.get());
135         realInputLHSData = tempInputLHSData.get();
136     }
137     if (!adjY) {
138         transposeRowsColumns(inputRHSData, inputRHSShape, tempInputRHSData.get());
139         realInputRHSData = tempInputRHSData.get();
140     }
141     Shape realInputLHSShape = adjX ? inputLHSShape : swapRowColumnDims(inputLHSShape);
142     Shape realInputRHSShape = adjY ? inputRHSShape : swapRowColumnDims(inputRHSShape);
143     NNTRACE_COMP_SWITCH("reference_ops::BatchMatMul");
144     tflite::reference_ops::BatchMatMul(convertShapeToTflshape(realInputRHSShape), realInputRHSData,
145                                        convertShapeToTflshape(realInputLHSShape), realInputLHSData,
146                                        convertShapeToTflshape(outputShape), outputData);
147     return true;
148 }
149 
150 // Performs batch matmul for quantized types.
151 template <typename T>
batchMatMulQuantized(const T * inputLHSData,const Shape & inputLHSShape,const T * inputRHSData,const Shape & inputRHSShape,const bool adjX,const bool adjY,T * outputData,const Shape & outputShape)152 bool batchMatMulQuantized(const T* inputLHSData, const Shape& inputLHSShape, const T* inputRHSData,
153                           const Shape& inputRHSShape, const bool adjX, const bool adjY,
154                           T* outputData, const Shape& outputShape) {
155     NNTRACE_TRANS("batchMatMulQuantized");
156     NNTRACE_COMP_SWITCH("reference_ops::Transpose");
157     const T* realInputLHSData = inputLHSData;
158     const T* realInputRHSData = inputRHSData;
159     auto tempInputLHSData = getTempData<T>(getNumberOfElements(inputLHSShape));
160     auto tempInputRHSData = getTempData<T>(getNumberOfElements(inputRHSShape));
161     if (adjX) {
162         transposeRowsColumns(inputLHSData, inputLHSShape, tempInputLHSData.get());
163         realInputLHSData = tempInputLHSData.get();
164     }
165     if (!adjY) {
166         transposeRowsColumns(inputRHSData, inputRHSShape, tempInputRHSData.get());
167         realInputRHSData = tempInputRHSData.get();
168     }
169     Shape realInputLHSShape = adjX ? inputLHSShape : swapRowColumnDims(inputLHSShape);
170     Shape realInputRHSShape = adjY ? inputRHSShape : swapRowColumnDims(inputRHSShape);
171 
172     NNTRACE_COMP_SWITCH("reference_ops::BatchMatMul");
173 
174     double realMultiplier = 0.0;
175     int32_t outputMultiplier = 0;
176     int32_t outputShift = 0;
177     NN_RET_CHECK(GetQuantizedConvolutionMultiplier(realInputLHSShape, realInputRHSShape,
178                                                    outputShape, &realMultiplier));
179     NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &outputShift));
180     tflite::FullyConnectedParams params;
181     params.input_offset = -realInputLHSShape.offset;
182     params.weights_offset = -realInputRHSShape.offset;
183     params.output_offset = outputShape.offset;
184     params.output_multiplier = outputMultiplier;
185     params.output_shift = outputShift;
186     // BatchMatMul has no fused activation functions. Therefore, sets
187     // output activation min and max to min and max of int8_t.
188     params.quantized_activation_min = std::numeric_limits<int8_t>::min();
189     params.quantized_activation_max = std::numeric_limits<int8_t>::max();
190     params.lhs_cacheable = false;
191     params.rhs_cacheable = false;
192 
193     tflite::reference_ops::BatchMatMul<T, int32_t>(
194             params, convertShapeToTflshape(realInputRHSShape), realInputRHSData,
195             convertShapeToTflshape(realInputLHSShape), realInputLHSData,
196             convertShapeToTflshape(outputShape), outputData);
197     return true;
198 }
199 
200 }  // namespace
201 
prepare(IOperationExecutionContext * context)202 bool prepare(IOperationExecutionContext* context) {
203     Shape inputLHSTensorShape = context->getInputShape(kInputLHSTensor);
204     Shape inputRHSTensorShape = context->getInputShape(kInputRHSTensor);
205     // Checks two input tensors have same number of dimensions.
206     NN_RET_CHECK_EQ(getNumberOfDimensions(inputLHSTensorShape),
207                     getNumberOfDimensions(inputRHSTensorShape))
208             << "Input tensor ranks do not match with each other.";
209     NN_RET_CHECK_GE(getNumberOfDimensions(inputLHSTensorShape), 2u)
210             << "Input tensor rank should be at least 2.";
211     NN_RET_CHECK_LE(getNumberOfDimensions(inputLHSTensorShape), 4u)
212             << "Input tensor rank should be at most 4.";
213     uint32_t numDims = getNumberOfDimensions(inputLHSTensorShape);
214     const bool adjX = context->getInputValue<bool>(kInputLHSAdj);
215     const bool adjY = context->getInputValue<bool>(kInputRHSAdj);
216     // Checks dimensions work for matrix multiplication.
217     NN_RET_CHECK(canMatrixMul(getSizeOfDimension(inputLHSTensorShape, numDims - 2),
218                               getSizeOfDimension(inputLHSTensorShape, numDims - 1),
219                               getSizeOfDimension(inputRHSTensorShape, numDims - 2),
220                               getSizeOfDimension(inputRHSTensorShape, numDims - 1), adjX, adjY))
221             << "Input tensors are not able to perform matrix multiplication.";
222 
223     Shape outputTensorShape = context->getOutputShape(kOutputTensor);
224     outputTensorShape.dimensions =
225             computeOutputDimensions(inputLHSTensorShape, inputRHSTensorShape, adjX, adjY);
226     return context->setOutputShape(kOutputTensor, outputTensorShape);
227 }
228 
// Dispatches the batch matmul to the implementation matching the LHS tensor's
// operand type. Float and integer tensors use the generic path; signed
// quantized tensors use the quantized path (which also requantizes the
// accumulator). Returns false (via NN_RET_CHECK_FAIL) for unsupported types.
bool execute(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputLHSTensor)) {
        case OperandType::TENSOR_FLOAT32:
            return batchMatMulGeneric(context->getInputBuffer<float>(kInputLHSTensor),
                                      context->getInputShape(kInputLHSTensor),
                                      context->getInputBuffer<float>(kInputRHSTensor),
                                      context->getInputShape(kInputRHSTensor),
                                      context->getInputValue<bool>(kInputLHSAdj),
                                      context->getInputValue<bool>(kInputRHSAdj),
                                      context->getOutputBuffer<float>(kOutputTensor),
                                      context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT16:
            return batchMatMulGeneric(context->getInputBuffer<_Float16>(kInputLHSTensor),
                                      context->getInputShape(kInputLHSTensor),
                                      context->getInputBuffer<_Float16>(kInputRHSTensor),
                                      context->getInputShape(kInputRHSTensor),
                                      context->getInputValue<bool>(kInputLHSAdj),
                                      context->getInputValue<bool>(kInputRHSAdj),
                                      context->getOutputBuffer<_Float16>(kOutputTensor),
                                      context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return batchMatMulGeneric(context->getInputBuffer<int32_t>(kInputLHSTensor),
                                      context->getInputShape(kInputLHSTensor),
                                      context->getInputBuffer<int32_t>(kInputRHSTensor),
                                      context->getInputShape(kInputRHSTensor),
                                      context->getInputValue<bool>(kInputLHSAdj),
                                      context->getInputValue<bool>(kInputRHSAdj),
                                      context->getOutputBuffer<int32_t>(kOutputTensor),
                                      context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return batchMatMulQuantized(context->getInputBuffer<int8_t>(kInputLHSTensor),
                                        context->getInputShape(kInputLHSTensor),
                                        context->getInputBuffer<int8_t>(kInputRHSTensor),
                                        context->getInputShape(kInputRHSTensor),
                                        context->getInputValue<bool>(kInputLHSAdj),
                                        context->getInputValue<bool>(kInputRHSAdj),
                                        context->getOutputBuffer<int8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
    // Unreachable: every switch branch returns.
    return true;
}
272 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
273 
274 }  // namespace batch_matmul_op
275 
// Registers BATCH_MATMUL with the operation resolver, wiring up the prepare
// and execute entry points with the framework's default validation.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(BATCH_MATMUL, batch_matmul_op::prepare,
                                         batch_matmul_op::execute);
278 
279 }  // namespace nn
280 }  // namespace android
281