1 /*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include <cmath>
19 #include <fstream>
20 #include <ios>
21 #include <iterator>
22 #include <string>
23
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include <input/TfLiteMotionPredictor.h>
27
28 namespace android {
29 namespace {
30
31 using ::testing::Each;
32 using ::testing::ElementsAre;
33 using ::testing::FloatNear;
34
// Readiness of TfLiteMotionPredictorBuffers: it only becomes ready once two samples with
// different positions have been pushed, and reset() makes it not ready again.
TEST(TfLiteMotionPredictorTest, BuffersReadiness) {
    TfLiteMotionPredictorBuffers buffers(/*inputLength=*/5);
    ASSERT_FALSE(buffers.isReady());

    // Two samples at the same position (timestamps 0 and 1) are not enough.
    for (int timestamp = 0; timestamp < 2; ++timestamp) {
        buffers.pushSample(timestamp, {.position = {.x = 100, .y = 100}});
        ASSERT_FALSE(buffers.isReady());
    }

    // A sample whose position differs from the previous one makes the buffers ready.
    buffers.pushSample(/*timestamp=*/2, {.position = {.x = 100, .y = 110}});
    ASSERT_TRUE(buffers.isReady());

    // After reset() the buffers report not-ready again.
    buffers.reset();
    ASSERT_FALSE(buffers.isReady());
}
52
// Tracking of the most recent data: lastTimestamp() follows every pushed sample, while the
// axisFrom()/axisTo() pair only advances when the pushed position actually changes.
TEST(TfLiteMotionPredictorTest, BuffersRecentData) {
    TfLiteMotionPredictorBuffers buffers(/*inputLength=*/5);

    // First sample: only the timestamp is checked.
    buffers.pushSample(/*timestamp=*/1, {.position = {.x = 100, .y = 200}});
    ASSERT_EQ(1, buffers.lastTimestamp());

    // Second, distinct sample: the axes run from the first position to the second.
    buffers.pushSample(/*timestamp=*/2, {.position = {.x = 150, .y = 250}});
    ASSERT_EQ(2, buffers.lastTimestamp());
    ASSERT_TRUE(buffers.isReady());
    ASSERT_EQ(100, buffers.axisFrom().position.x);
    ASSERT_EQ(200, buffers.axisFrom().position.y);
    ASSERT_EQ(150, buffers.axisTo().position.x);
    ASSERT_EQ(250, buffers.axisTo().position.y);

    // Repeated position: the timestamp advances but the axes stay where they were.
    buffers.pushSample(/*timestamp=*/3, {.position = {.x = 150, .y = 250}});
    ASSERT_EQ(3, buffers.lastTimestamp());
    ASSERT_TRUE(buffers.isReady());
    ASSERT_EQ(100, buffers.axisFrom().position.x);
    ASSERT_EQ(200, buffers.axisFrom().position.y);
    ASSERT_EQ(150, buffers.axisTo().position.x);
    ASSERT_EQ(250, buffers.axisTo().position.y);

    // New position: the axes shift forward by one sample.
    buffers.pushSample(/*timestamp=*/4, {.position = {.x = 180, .y = 280}});
    ASSERT_EQ(4, buffers.lastTimestamp());
    ASSERT_TRUE(buffers.isReady());
    ASSERT_EQ(150, buffers.axisFrom().position.x);
    ASSERT_EQ(250, buffers.axisFrom().position.y);
    ASSERT_EQ(180, buffers.axisTo().position.x);
    ASSERT_EQ(280, buffers.axisTo().position.y);
}
84
// copyTo() fills the model's input tensors from the buffered samples. Four samples yield three
// entries (the expected R values 40, 20, 10 are the distances between consecutive positions),
// placed at the end of each tensor after a run of zero padding.
TEST(TfLiteMotionPredictorTest, BuffersCopyTo) {
    std::unique_ptr<TfLiteMotionPredictorModel> model = TfLiteMotionPredictorModel::create();
    TfLiteMotionPredictorBuffers buffers(model->inputLength());

    buffers.pushSample(/*timestamp=*/1,
                       {.position = {.x = 10, .y = 10},
                        .pressure = 0,
                        .orientation = 0,
                        .tilt = 0.2});
    buffers.pushSample(/*timestamp=*/2,
                       {.position = {.x = 10, .y = 50},
                        .pressure = 0.4,
                        .orientation = M_PI / 4,
                        .tilt = 0.3});
    buffers.pushSample(/*timestamp=*/3,
                       {.position = {.x = 30, .y = 50},
                        .pressure = 0.5,
                        .orientation = -M_PI / 4,
                        .tilt = 0.4});
    // Keep timestamps strictly increasing like a real input stream.
    buffers.pushSample(/*timestamp=*/4,
                       {.position = {.x = 30, .y = 60},
                        .pressure = 0,
                        .orientation = 0,
                        .tilt = 0.5});
    buffers.copyTo(*model);

    // Number of entries produced by the four samples above.
    constexpr size_t kFilledEntries = 3;
    // Check the length before subtracting: the length is unsigned (presumably size_t), so
    // `inputLength() - 3` would wrap around rather than go negative for a short model.
    ASSERT_GE(model->inputLength(), kFilledEntries);
    const size_t zeroPadding = model->inputLength() - kFilledEntries;

    EXPECT_THAT(model->inputR().subspan(0, zeroPadding), Each(0));
    EXPECT_THAT(model->inputPhi().subspan(0, zeroPadding), Each(0));
    EXPECT_THAT(model->inputPressure().subspan(0, zeroPadding), Each(0));
    EXPECT_THAT(model->inputTilt().subspan(0, zeroPadding), Each(0));
    EXPECT_THAT(model->inputOrientation().subspan(0, zeroPadding), Each(0));

    // Angles are derived values; compare them with a tolerance rather than exact equality.
    constexpr float kAngleTolerance = 1e-5;

    EXPECT_THAT(model->inputR().subspan(zeroPadding), ElementsAre(40, 20, 10));
    EXPECT_THAT(model->inputPhi().subspan(zeroPadding),
                ElementsAre(FloatNear(0, kAngleTolerance), FloatNear(-M_PI / 2, kAngleTolerance),
                            FloatNear(M_PI / 2, kAngleTolerance)));
    EXPECT_THAT(model->inputPressure().subspan(zeroPadding), ElementsAre(0.4, 0.5, 0));
    EXPECT_THAT(model->inputTilt().subspan(zeroPadding), ElementsAre(0.3, 0.4, 0.5));
    EXPECT_THAT(model->inputOrientation().subspan(zeroPadding),
                ElementsAre(FloatNear(-M_PI / 4, kAngleTolerance),
                            FloatNear(M_PI / 4, kAngleTolerance),
                            FloatNear(M_PI / 2, kAngleTolerance)));
}
128
// Tensor sizes of the model: every input tensor has inputLength() elements, and after a
// successful invoke() every output tensor has outputLength() elements.
TEST(TfLiteMotionPredictorTest, ModelInputOutputLength) {
    std::unique_ptr<TfLiteMotionPredictorModel> model = TfLiteMotionPredictorModel::create();
    ASSERT_GT(model->inputLength(), 0u);

    const size_t expectedInputSize = model->inputLength();
    ASSERT_EQ(static_cast<size_t>(model->inputR().size()), expectedInputSize);
    ASSERT_EQ(static_cast<size_t>(model->inputPhi().size()), expectedInputSize);
    ASSERT_EQ(static_cast<size_t>(model->inputPressure().size()), expectedInputSize);
    ASSERT_EQ(static_cast<size_t>(model->inputOrientation().size()), expectedInputSize);
    ASSERT_EQ(static_cast<size_t>(model->inputTilt().size()), expectedInputSize);

    // Output sizes are only checked once inference has run.
    ASSERT_TRUE(model->invoke());

    const size_t expectedOutputSize = model->outputLength();
    ASSERT_EQ(static_cast<size_t>(model->outputR().size()), expectedOutputSize);
    ASSERT_EQ(static_cast<size_t>(model->outputPhi().size()), expectedOutputSize);
    ASSERT_EQ(static_cast<size_t>(model->outputPressure().size()), expectedOutputSize);
}
147
// End-to-end smoke test: feed three samples through the buffers into the model, run
// inference, and check that every output value is plausible (non-zero and not NaN).
TEST(TfLiteMotionPredictorTest, ModelOutput) {
    std::unique_ptr<TfLiteMotionPredictorModel> model = TfLiteMotionPredictorModel::create();
    TfLiteMotionPredictorBuffers buffers(model->inputLength());

    buffers.pushSample(/*timestamp=*/1, {.position = {.x = 100, .y = 200}, .pressure = 0.2});
    buffers.pushSample(/*timestamp=*/2, {.position = {.x = 150, .y = 250}, .pressure = 0.4});
    buffers.pushSample(/*timestamp=*/3, {.position = {.x = 180, .y = 280}, .pressure = 0.6});
    buffers.copyTo(*model);

    ASSERT_TRUE(model->invoke());

    // The actual model output is implementation-defined, but it should at least be non-zero and
    // non-NaN. Use std::isnan: <cmath> only guarantees the name in namespace std.
    const auto isValid = [](float value) { return !std::isnan(value) && value != 0; };

    // Bind each output once so begin() and end() come from the same object, which stays valid
    // whether the accessor returns a view or a container by value.
    const auto& outputR = model->outputR();
    const auto& outputPhi = model->outputPhi();
    const auto& outputPressure = model->outputPressure();

    ASSERT_TRUE(std::all_of(outputR.begin(), outputR.end(), isValid));
    ASSERT_TRUE(std::all_of(outputPhi.begin(), outputPhi.end(), isValid));
    ASSERT_TRUE(std::all_of(outputPressure.begin(), outputPressure.end(), isValid));
}
167
168 } // namespace
169 } // namespace android
170