1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <algorithm>
#include <initializer_list>
#include <memory>
#include <string>
#include <vector>
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include <utils/Log.h>
23 
24 // Note: tflite headers have warnings, so we disable them here.
25 #pragma clang diagnostic ignored "-Wsign-compare"
26 #pragma clang diagnostic ignored "-Wunused-parameter"
27 
28 #include <flatbuffers/flexbuffers.h>
29 #include <tensorflow/lite/core/api/error_reporter.h>
30 #include <tensorflow/lite/delegates/gpu/delegate.h>
31 #include <tensorflow/lite/delegates/nnapi/nnapi_delegate.h>
32 #include <tensorflow/lite/interpreter.h>
33 #include <tensorflow/lite/kernels/register.h>
34 #include <tensorflow/lite/model.h>
35 #include <tensorflow/lite/schema/schema_conversion_utils.h>
36 #include <tensorflow/lite/schema/schema_generated.h>
37 #include <tensorflow/lite/version.h>
38 
39 namespace tflite {
40 namespace {
41 
42 using flatbuffers::Offset;
43 using flatbuffers::Vector;
44 using ::testing::ElementsAre;
45 
46 // A tflite ErrorReporter that logs to logcat.
47 class LoggingErrorReporter : public tflite::ErrorReporter {
48   public:
Report(const char * format,va_list args)49     int Report(const char* format, va_list args) override {
50         return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
51     }
52 };
53 
// The stable tflite API is exercised by building a FlatBuffer model against a
// specific schema version.
//
// Here we build a simple model containing a single 2D convolution and test that
// it can be constructed and produces the expected output.
59 
60 // Adapted from external/tensorflow/tensorflow/lite/toco/tflite/import_test.cc
61 
62 class ModelTest : public ::testing::Test {
63   protected:
~ModelTest()64     ~ModelTest() { CleanUp(); }
65 
66     template <typename T>
CreateDataVector(const std::vector<T> & data)67     Offset<Vector<unsigned char>> CreateDataVector(const std::vector<T>& data) {
68         return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
69                                      sizeof(T) * data.size());
70     }
71 
BuildBuffers()72     Offset<Vector<Offset<::tflite::Buffer>>> BuildBuffers() {
73         auto input1881 = ::tflite::CreateBuffer(builder_);  // input buffer not assigned data.
74         auto filter1331 = ::tflite::CreateBuffer(
75                 builder_, CreateDataVector<float>({1.f, 2.f, 1.f, 2.f, 4.f, 2.f, 1.f, 2.f, 1.f}));
76         auto bias1 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({0.f}));
77         auto output1881 = ::tflite::CreateBuffer(builder_);  // output buffer not assigned data.
78         return builder_.CreateVector(
79                 std::vector<Offset<::tflite::Buffer>>({input1881, filter1331, bias1, output1881}));
80     }
81 
BuildTensors()82     Offset<Vector<Offset<::tflite::Tensor>>> BuildTensors() {
83         auto input1881 = ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 8, 8, 1}),
84                                                 ::tflite::TensorType_FLOAT32, 0 /* buffer */,
85                                                 builder_.CreateString("tensor_input"));
86         auto filter1331 = ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 3, 3, 1}),
87                                                  ::tflite::TensorType_FLOAT32, 1 /* buffer */,
88                                                  builder_.CreateString("tensor_filter"));
89         auto bias1 = ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1}),
90                                             ::tflite::TensorType_FLOAT32, 2 /* buffer */,
91                                             builder_.CreateString("tensor_bias"));
92         auto output1881 = ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 8, 8, 1}),
93                                                  ::tflite::TensorType_FLOAT32, 3 /* buffer */,
94                                                  builder_.CreateString("tensor_output"));
95         return builder_.CreateVector(
96                 std::vector<Offset<::tflite::Tensor>>({input1881, filter1331, bias1, output1881}));
97     }
98 
BuildOpCodes(std::initializer_list<::tflite::BuiltinOperator> op_codes)99     Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes(
100             std::initializer_list<::tflite::BuiltinOperator> op_codes) {
101         std::vector<Offset<::tflite::OperatorCode>> op_codes_vector;
102         for (auto op : op_codes) {
103             op_codes_vector.push_back(::tflite::CreateOperatorCode(builder_, (int8_t)op, 0));
104         }
105         return builder_.CreateVector(op_codes_vector);
106     }
107 
BuildConv2DOperator(std::initializer_list<int> inputs,std::initializer_list<int> outputs)108     Offset<::tflite::Operator> BuildConv2DOperator(std::initializer_list<int> inputs,
109                                                    std::initializer_list<int> outputs) {
110         auto is = builder_.CreateVector<int>(inputs);
111         if (inputs.size() == 0) is = 0;
112         auto os = builder_.CreateVector<int>(outputs);
113         if (outputs.size() == 0) os = 0;
114         auto op = ::tflite::CreateOperator(
115                 builder_, 0 /* opcode_index */, is, os, ::tflite::BuiltinOptions_Conv2DOptions,
116                 ::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_SAME, 1 /* stride_w */,
117                                               1 /* stride_h */,
118                                               ::tflite::ActivationFunctionType_NONE)
119                         .Union());
120         return op;
121     }
122 
BuildOperators(std::initializer_list<Offset<tflite::Operator>> operators)123     Offset<Vector<Offset<::tflite::Operator>>> BuildOperators(
124             std::initializer_list<Offset<tflite::Operator>> operators) {
125         std::vector<Offset<::tflite::Operator>> operators_vector;
126         for (auto op : operators) {
127             operators_vector.push_back(op);
128         }
129         return builder_.CreateVector(operators_vector);
130     }
131 
BuildSubGraphs(Offset<Vector<Offset<::tflite::Tensor>>> tensors,Offset<Vector<Offset<::tflite::Operator>>> operators,int num_sub_graphs=1)132     Offset<Vector<Offset<::tflite::SubGraph>>> BuildSubGraphs(
133             Offset<Vector<Offset<::tflite::Tensor>>> tensors,
134             Offset<Vector<Offset<::tflite::Operator>>> operators, int num_sub_graphs = 1) {
135         std::vector<int32_t> inputs = {0};
136         std::vector<int32_t> outputs = {3};
137         std::vector<Offset<::tflite::SubGraph>> v;
138         for (int i = 0; i < num_sub_graphs; ++i) {
139             v.push_back(::tflite::CreateSubGraph(builder_, tensors, builder_.CreateVector(inputs),
140                                                  builder_.CreateVector(outputs), operators,
141                                                  builder_.CreateString("subgraph")));
142         }
143         return builder_.CreateVector(v);
144     }
145 
146     enum class DelegateType {
147         kCpu,
148         kTpu,
149     };
150 
BuildTestModel(DelegateType type)151     void BuildTestModel(DelegateType type) {
152         delegate_type_ = type;
153 
154         // auto errorReporter = std::make_unique<LoggingErrorReporter>();
155 
156         auto buffers = BuildBuffers();
157         auto tensors = BuildTensors();
158         auto opcodes = BuildOpCodes({::tflite::BuiltinOperator_CONV_2D});
159         auto conv2DOp = BuildConv2DOperator({0, 1, 2}, {3});
160         auto operators = BuildOperators({conv2DOp});
161         auto subgraphs = BuildSubGraphs(tensors, operators);
162         auto description = builder_.CreateString("ModelTest");
163 
164         ::tflite::FinishModelBuffer(
165                 builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, subgraphs,
166                                                 description, buffers));
167 
168         input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
169 
170         ASSERT_NE(nullptr, input_model_);
171 
172         tflite::ops::builtin::BuiltinOpResolver resolver;
173         tflite::InterpreterBuilder builder(input_model_, resolver);
174 
175         ASSERT_EQ(kTfLiteOk, builder(&interpreter_));
176         ASSERT_NE(nullptr, interpreter_);
177 
178         TfLiteStatus delegate_status = kTfLiteOk;
179         switch (type) {
180             default:
181             case DelegateType::kCpu:
182                 break;
183             case DelegateType::kTpu:
184                 interpreter_->SetAllowFp16PrecisionForFp32(true);
185                 delegate_ = NnApiDelegate();  // singleton
186                 delegate_status = interpreter_->ModifyGraphWithDelegate(delegate_);
187                 EXPECT_EQ(kTfLiteOk, delegate_status);
188                 break;
189         }
190     }
191 
CleanUp()192     void CleanUp() {
193         // reset the values.
194         delegate_ = nullptr;
195         input_model_ = nullptr;
196         builder_.Clear();
197         interpreter_.reset();
198     }
199 
asString()200     std::string asString() {
201         return std::string(reinterpret_cast<char*>(builder_.GetBufferPointer()),
202                            builder_.GetSize());
203     }
204 
205     flatbuffers::FlatBufferBuilder builder_;
206     const ::tflite::Model* input_model_ = nullptr;
207     TfLiteDelegate* delegate_ = nullptr;
208     std::unique_ptr<tflite::Interpreter> interpreter_;
209 
210     DelegateType delegate_type_ = DelegateType::kCpu;
211 };
212 
// For each delegate type and impulse value, builds the CONV_2D model. It feeds a
// single-pixel impulse at element 0 and checks the top-left output elements
// against the impulse scaled by the matching 3x3 filter taps.
TEST_F(ModelTest, BuildConv) {
    for (auto type : {DelegateType::kCpu, DelegateType::kTpu}) {
        for (float inputValue : {10.f, 11.f}) {
            SCOPED_TRACE(::testing::Message() << "delegate type " << static_cast<int>(type)
                                              << ", input value " << inputValue);

            // ASSERT_* inside BuildTestModel() only returns from the helper. Stop
            // here rather than dereference a null interpreter_.
            ASSERT_NO_FATAL_FAILURE(BuildTestModel(type));

            ASSERT_EQ(kTfLiteOk, interpreter_->AllocateTensors());
            TfLiteTensor* input = interpreter_->input_tensor(0);
            TfLiteTensor* output = interpreter_->output_tensor(0);
            ASSERT_NE(nullptr, input);
            ASSERT_NE(nullptr, output);
            ASSERT_EQ(kTfLiteFloat32, input->type);
            ASSERT_EQ(kTfLiteFloat32, output->type);

            // AllocateTensors() is not documented to zero the input, and arena
            // memory may be reused from an earlier iteration. Clear the whole input
            // so it really is a single-pixel impulse.
            std::fill_n(input->data.f, input->bytes / sizeof(float), 0.f);
            input->data.f[0] = inputValue;
            ASSERT_EQ(kTfLiteOk, interpreter_->Invoke());

            // Result is the point impulse multiplied by the
            // tap value of the 3 x 3 filter (starting from center).
            EXPECT_FLOAT_EQ(inputValue * 4.f, output->data.f[0]);
            EXPECT_FLOAT_EQ(inputValue * 2.f, output->data.f[1]);
            EXPECT_FLOAT_EQ(0.f, output->data.f[2]);
            EXPECT_FLOAT_EQ(inputValue * 2.f, output->data.f[8]);
            EXPECT_FLOAT_EQ(inputValue * 1.f, output->data.f[9]);
            EXPECT_FLOAT_EQ(0.f, output->data.f[10]);
            CleanUp();
        }
    }
}
238 
239 }  // namespace
240 }  // namespace tflite
241