1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <initializer_list>
18 #include <string>
19
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include <utils/Log.h>
23
24 // Note: tflite headers have warnings, so we disable them here.
25 #pragma clang diagnostic ignored "-Wsign-compare"
26 #pragma clang diagnostic ignored "-Wunused-parameter"
27
28 #include <flatbuffers/flexbuffers.h>
29 #include <tensorflow/lite/core/api/error_reporter.h>
30 #include <tensorflow/lite/delegates/gpu/delegate.h>
31 #include <tensorflow/lite/delegates/nnapi/nnapi_delegate.h>
32 #include <tensorflow/lite/interpreter.h>
33 #include <tensorflow/lite/kernels/register.h>
34 #include <tensorflow/lite/model.h>
35 #include <tensorflow/lite/schema/schema_conversion_utils.h>
36 #include <tensorflow/lite/schema/schema_generated.h>
37 #include <tensorflow/lite/version.h>
38
39 namespace tflite {
40 namespace {
41
42 using flatbuffers::Offset;
43 using flatbuffers::Vector;
44 using ::testing::ElementsAre;
45
46 // A tflite ErrorReporter that logs to logcat.
47 class LoggingErrorReporter : public tflite::ErrorReporter {
48 public:
Report(const char * format,va_list args)49 int Report(const char* format, va_list args) override {
50 return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
51 }
52 };
53
// The stable way to use tflite is to build a FlatBuffer model that targets a
// specific schema version.
56 //
57 // Here we create a simple model that performs a 2D convolutional filter
58 // and test that it can be created and works.
59
60 // Adapted from external/tensorflow/tensorflow/lite/toco/tflite/import_test.cc
61
62 class ModelTest : public ::testing::Test {
63 protected:
~ModelTest()64 ~ModelTest() { CleanUp(); }
65
66 template <typename T>
CreateDataVector(const std::vector<T> & data)67 Offset<Vector<unsigned char>> CreateDataVector(const std::vector<T>& data) {
68 return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
69 sizeof(T) * data.size());
70 }
71
BuildBuffers()72 Offset<Vector<Offset<::tflite::Buffer>>> BuildBuffers() {
73 auto input1881 = ::tflite::CreateBuffer(builder_); // input buffer not assigned data.
74 auto filter1331 = ::tflite::CreateBuffer(
75 builder_, CreateDataVector<float>({1.f, 2.f, 1.f, 2.f, 4.f, 2.f, 1.f, 2.f, 1.f}));
76 auto bias1 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({0.f}));
77 auto output1881 = ::tflite::CreateBuffer(builder_); // output buffer not assigned data.
78 return builder_.CreateVector(
79 std::vector<Offset<::tflite::Buffer>>({input1881, filter1331, bias1, output1881}));
80 }
81
BuildTensors()82 Offset<Vector<Offset<::tflite::Tensor>>> BuildTensors() {
83 auto input1881 = ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 8, 8, 1}),
84 ::tflite::TensorType_FLOAT32, 0 /* buffer */,
85 builder_.CreateString("tensor_input"));
86 auto filter1331 = ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 3, 3, 1}),
87 ::tflite::TensorType_FLOAT32, 1 /* buffer */,
88 builder_.CreateString("tensor_filter"));
89 auto bias1 = ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1}),
90 ::tflite::TensorType_FLOAT32, 2 /* buffer */,
91 builder_.CreateString("tensor_bias"));
92 auto output1881 = ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 8, 8, 1}),
93 ::tflite::TensorType_FLOAT32, 3 /* buffer */,
94 builder_.CreateString("tensor_output"));
95 return builder_.CreateVector(
96 std::vector<Offset<::tflite::Tensor>>({input1881, filter1331, bias1, output1881}));
97 }
98
BuildOpCodes(std::initializer_list<::tflite::BuiltinOperator> op_codes)99 Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes(
100 std::initializer_list<::tflite::BuiltinOperator> op_codes) {
101 std::vector<Offset<::tflite::OperatorCode>> op_codes_vector;
102 for (auto op : op_codes) {
103 op_codes_vector.push_back(::tflite::CreateOperatorCode(builder_, (int8_t)op, 0));
104 }
105 return builder_.CreateVector(op_codes_vector);
106 }
107
BuildConv2DOperator(std::initializer_list<int> inputs,std::initializer_list<int> outputs)108 Offset<::tflite::Operator> BuildConv2DOperator(std::initializer_list<int> inputs,
109 std::initializer_list<int> outputs) {
110 auto is = builder_.CreateVector<int>(inputs);
111 if (inputs.size() == 0) is = 0;
112 auto os = builder_.CreateVector<int>(outputs);
113 if (outputs.size() == 0) os = 0;
114 auto op = ::tflite::CreateOperator(
115 builder_, 0 /* opcode_index */, is, os, ::tflite::BuiltinOptions_Conv2DOptions,
116 ::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_SAME, 1 /* stride_w */,
117 1 /* stride_h */,
118 ::tflite::ActivationFunctionType_NONE)
119 .Union());
120 return op;
121 }
122
BuildOperators(std::initializer_list<Offset<tflite::Operator>> operators)123 Offset<Vector<Offset<::tflite::Operator>>> BuildOperators(
124 std::initializer_list<Offset<tflite::Operator>> operators) {
125 std::vector<Offset<::tflite::Operator>> operators_vector;
126 for (auto op : operators) {
127 operators_vector.push_back(op);
128 }
129 return builder_.CreateVector(operators_vector);
130 }
131
BuildSubGraphs(Offset<Vector<Offset<::tflite::Tensor>>> tensors,Offset<Vector<Offset<::tflite::Operator>>> operators,int num_sub_graphs=1)132 Offset<Vector<Offset<::tflite::SubGraph>>> BuildSubGraphs(
133 Offset<Vector<Offset<::tflite::Tensor>>> tensors,
134 Offset<Vector<Offset<::tflite::Operator>>> operators, int num_sub_graphs = 1) {
135 std::vector<int32_t> inputs = {0};
136 std::vector<int32_t> outputs = {3};
137 std::vector<Offset<::tflite::SubGraph>> v;
138 for (int i = 0; i < num_sub_graphs; ++i) {
139 v.push_back(::tflite::CreateSubGraph(builder_, tensors, builder_.CreateVector(inputs),
140 builder_.CreateVector(outputs), operators,
141 builder_.CreateString("subgraph")));
142 }
143 return builder_.CreateVector(v);
144 }
145
146 enum class DelegateType {
147 kCpu,
148 kTpu,
149 };
150
BuildTestModel(DelegateType type)151 void BuildTestModel(DelegateType type) {
152 delegate_type_ = type;
153
154 // auto errorReporter = std::make_unique<LoggingErrorReporter>();
155
156 auto buffers = BuildBuffers();
157 auto tensors = BuildTensors();
158 auto opcodes = BuildOpCodes({::tflite::BuiltinOperator_CONV_2D});
159 auto conv2DOp = BuildConv2DOperator({0, 1, 2}, {3});
160 auto operators = BuildOperators({conv2DOp});
161 auto subgraphs = BuildSubGraphs(tensors, operators);
162 auto description = builder_.CreateString("ModelTest");
163
164 ::tflite::FinishModelBuffer(
165 builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, subgraphs,
166 description, buffers));
167
168 input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
169
170 ASSERT_NE(nullptr, input_model_);
171
172 tflite::ops::builtin::BuiltinOpResolver resolver;
173 tflite::InterpreterBuilder builder(input_model_, resolver);
174
175 ASSERT_EQ(kTfLiteOk, builder(&interpreter_));
176 ASSERT_NE(nullptr, interpreter_);
177
178 TfLiteStatus delegate_status = kTfLiteOk;
179 switch (type) {
180 default:
181 case DelegateType::kCpu:
182 break;
183 case DelegateType::kTpu:
184 interpreter_->SetAllowFp16PrecisionForFp32(true);
185 delegate_ = NnApiDelegate(); // singleton
186 delegate_status = interpreter_->ModifyGraphWithDelegate(delegate_);
187 EXPECT_EQ(kTfLiteOk, delegate_status);
188 break;
189 }
190 }
191
CleanUp()192 void CleanUp() {
193 // reset the values.
194 delegate_ = nullptr;
195 input_model_ = nullptr;
196 builder_.Clear();
197 interpreter_.reset();
198 }
199
asString()200 std::string asString() {
201 return std::string(reinterpret_cast<char*>(builder_.GetBufferPointer()),
202 builder_.GetSize());
203 }
204
205 flatbuffers::FlatBufferBuilder builder_;
206 const ::tflite::Model* input_model_ = nullptr;
207 TfLiteDelegate* delegate_ = nullptr;
208 std::unique_ptr<tflite::Interpreter> interpreter_;
209
210 DelegateType delegate_type_ = DelegateType::kCpu;
211 };
212
// Builds the single CONV_2D model for every delegate type, feeds it a point
// impulse at input element 0 and checks the outputs around the impulse against
// the 3 x 3 filter taps.
TEST_F(ModelTest, BuildConv) {
    for (auto type : {DelegateType::kCpu, DelegateType::kTpu}) {
        for (float inputValue : {10.f, 11.f}) {
            // BuildTestModel() uses ASSERT_* internally; a fatal failure there only
            // returns from the helper, so stop here rather than dereferencing a
            // null interpreter_.
            ASSERT_NO_FATAL_FAILURE(BuildTestModel(type));

            ASSERT_EQ(kTfLiteOk, interpreter_->AllocateTensors());
            TfLiteTensor* input = interpreter_->input_tensor(0);
            TfLiteTensor* output = interpreter_->output_tensor(0);
            ASSERT_NE(nullptr, input);
            ASSERT_NE(nullptr, output);

            // The expectations below assume every input element other than the
            // impulse is zero, so clear the whole input tensor explicitly instead
            // of relying on the contents of freshly allocated tensor memory.
            const size_t inputCount = input->bytes / sizeof(float);
            ASSERT_GT(inputCount, 0u);
            for (size_t i = 0; i < inputCount; ++i) {
                input->data.f[i] = 0.f;
            }
            input->data.f[0] = inputValue;

            TfLiteStatus invoke_status = interpreter_->Invoke();
            ASSERT_EQ(kTfLiteOk, invoke_status);

            // Result is the point impulse multiplied by the
            // tap value of the 3 x 3 filter (starting from center).
            // Indices 0..2 are the first output row, 8..10 the second (row width 8).
            EXPECT_EQ(inputValue * 4.f, output->data.f[0]);
            EXPECT_EQ(inputValue * 2.f, output->data.f[1]);
            EXPECT_EQ(0.f, output->data.f[2]);
            EXPECT_EQ(inputValue * 2.f, output->data.f[8]);
            EXPECT_EQ(inputValue * 1.f, output->data.f[9]);
            EXPECT_EQ(0.f, output->data.f[10]);
            CleanUp();
        }
    }
}
238
239 } // namespace
240 } // namespace tflite
241