1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // TODO(b/331815574): Decouple this test from assumed config values.
18 #include <chrono>
19 #include <cmath>
20 
21 #include <com_android_input_flags.h>
22 #include <flag_macros.h>
23 #include <gmock/gmock.h>
24 #include <gtest/gtest.h>
25 #include <input/Input.h>
26 #include <input/MotionPredictor.h>
27 
28 using namespace std::literals::chrono_literals;
29 
30 namespace android {
31 
32 using ::testing::IsEmpty;
33 using ::testing::SizeIs;
34 using ::testing::UnorderedElementsAre;
35 
36 constexpr int32_t DOWN = AMOTION_EVENT_ACTION_DOWN;
37 constexpr int32_t MOVE = AMOTION_EVENT_ACTION_MOVE;
38 constexpr int32_t UP = AMOTION_EVENT_ACTION_UP;
39 constexpr nsecs_t NSEC_PER_MSEC = 1'000'000;
40 
getMotionEvent(int32_t action,float x,float y,std::chrono::nanoseconds eventTime,int32_t deviceId=0)41 static MotionEvent getMotionEvent(int32_t action, float x, float y,
42                                   std::chrono::nanoseconds eventTime, int32_t deviceId = 0) {
43     MotionEvent event;
44     constexpr size_t pointerCount = 1;
45     std::vector<PointerProperties> pointerProperties;
46     std::vector<PointerCoords> pointerCoords;
47     for (size_t i = 0; i < pointerCount; i++) {
48         PointerProperties properties;
49         properties.clear();
50         properties.id = i;
51         properties.toolType = ToolType::STYLUS;
52         pointerProperties.push_back(properties);
53         PointerCoords coords;
54         coords.clear();
55         coords.setAxisValue(AMOTION_EVENT_AXIS_X, x);
56         coords.setAxisValue(AMOTION_EVENT_AXIS_Y, y);
57         pointerCoords.push_back(coords);
58     }
59 
60     ui::Transform identityTransform;
61     event.initialize(InputEvent::nextId(), deviceId, AINPUT_SOURCE_STYLUS,
62                      ui::LogicalDisplayId::DEFAULT, {0}, action, /*actionButton=*/0, /*flags=*/0,
63                      AMOTION_EVENT_EDGE_FLAG_NONE, AMETA_NONE, /*buttonState=*/0,
64                      MotionClassification::NONE, identityTransform,
65                      /*xPrecision=*/0.1,
66                      /*yPrecision=*/0.2, /*xCursorPosition=*/280, /*yCursorPosition=*/540,
67                      identityTransform, /*downTime=*/100, eventTime.count(), pointerCount,
68                      pointerProperties.data(), pointerCoords.data());
69     return event;
70 }
71 
TEST(JerkTrackerTest, JerkReadiness) {
    JerkTracker tracker(true);

    // No jerk value is reported until four samples have been pushed.
    EXPECT_FALSE(tracker.jerkMagnitude().has_value());
    tracker.pushSample(/*timestamp=*/0, 20, 50);
    EXPECT_FALSE(tracker.jerkMagnitude().has_value());
    tracker.pushSample(/*timestamp=*/1, 25, 53);
    EXPECT_FALSE(tracker.jerkMagnitude().has_value());
    tracker.pushSample(/*timestamp=*/2, 30, 60);
    EXPECT_FALSE(tracker.jerkMagnitude().has_value());

    // The fourth sample makes a jerk value available.
    tracker.pushSample(/*timestamp=*/3, 35, 70);
    EXPECT_TRUE(tracker.jerkMagnitude().has_value());

    // reset() discards the history, so a single new sample is again not enough.
    tracker.reset();
    EXPECT_FALSE(tracker.jerkMagnitude().has_value());
    tracker.pushSample(/*timestamp=*/4, 30, 60);
    EXPECT_FALSE(tracker.jerkMagnitude().has_value());
}
88 
TEST(JerkTrackerTest, JerkCalculationNormalizedDtTrue) {
    JerkTracker jerkTracker(true);
    jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
    jerkTracker.pushSample(/*timestamp=*/1, 25, 53);
    jerkTracker.pushSample(/*timestamp=*/2, 30, 60);
    jerkTracker.pushSample(/*timestamp=*/3, 45, 70);
    /**
     * Jerk derivative table
     * x:    20   25   30   45
     * x':    5    5   15
     * x'':   0   10
     * x''': 10
     *
     * y:    50   53   60   70
     * y':    3    7   10
     * y'':   4    3
     * y''': -1
     */
    auto jerk = jerkTracker.jerkMagnitude();
    // Check presence first so an empty result fails this test cleanly instead of calling
    // value() on an empty optional (which throws, or aborts the binary without exceptions).
    ASSERT_TRUE(jerk.has_value());
    EXPECT_FLOAT_EQ(*jerk, std::hypot(10, -1));

    jerkTracker.pushSample(/*timestamp=*/4, 20, 65);
    /**
     * (continuing from above table)
     * x:    45 -> 20
     * x':   15 -> -25
     * x'':  10 -> -40
     * x''': -50
     *
     * y:    70 -> 65
     * y':   10 -> -5
     * y'':  3 -> -15
     * y''': -18
     */
    jerk = jerkTracker.jerkMagnitude();
    ASSERT_TRUE(jerk.has_value());
    EXPECT_FLOAT_EQ(*jerk, std::hypot(-50, -18));
}
123 
TEST(JerkTrackerTest, JerkCalculationNormalizedDtFalse) {
    JerkTracker jerkTracker(false);
    jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
    jerkTracker.pushSample(/*timestamp=*/10, 25, 53);
    jerkTracker.pushSample(/*timestamp=*/20, 30, 60);
    jerkTracker.pushSample(/*timestamp=*/30, 45, 70);
    /**
     * Jerk derivative table
     * x:     20   25   30   45
     * x':    .5   .5  1.5
     * x'':    0   .1
     * x''': .01
     *
     * y:       50   53   60   70
     * y':      .3   .7    1
     * y'':    .04  .03
     * y''': -.001
     */
    auto jerk = jerkTracker.jerkMagnitude();
    // Check presence first so an empty result fails this test cleanly instead of calling
    // value() on an empty optional (which throws, or aborts the binary without exceptions).
    ASSERT_TRUE(jerk.has_value());
    EXPECT_FLOAT_EQ(*jerk, std::hypot(.01, -.001));

    jerkTracker.pushSample(/*timestamp=*/50, 20, 65);
    /**
     * (continuing from above table)
     * x:    45 -> 20
     * x':   1.5 -> -1.25 (delta above, divide by 20)
     * x'':  .1 -> -.275 (delta above, divide by 10)
     * x''': -.0375 (delta above, divide by 10)
     *
     * y:    70 -> 65
     * y':   1 -> -.25 (delta above, divide by 20)
     * y'':  .03 -> -.125 (delta above, divide by 10)
     * y''': -.0155 (delta above, divide by 10)
     */
    jerk = jerkTracker.jerkMagnitude();
    ASSERT_TRUE(jerk.has_value());
    EXPECT_FLOAT_EQ(*jerk, std::hypot(-.0375, -.0155));
}
158 
TEST(JerkTrackerTest, JerkCalculationAfterReset) {
    JerkTracker jerkTracker(true);
    // First gesture: enough samples to produce a jerk value, which reset() must discard.
    jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
    jerkTracker.pushSample(/*timestamp=*/1, 25, 53);
    jerkTracker.pushSample(/*timestamp=*/2, 30, 60);
    jerkTracker.pushSample(/*timestamp=*/3, 45, 70);
    jerkTracker.pushSample(/*timestamp=*/4, 20, 65);
    jerkTracker.reset();

    // Second gesture uses the same four positions as JerkCalculationNormalizedDtTrue, so the
    // jerk must match the value computed there, unaffected by samples pushed before reset().
    jerkTracker.pushSample(/*timestamp=*/5, 20, 50);
    jerkTracker.pushSample(/*timestamp=*/6, 25, 53);
    jerkTracker.pushSample(/*timestamp=*/7, 30, 60);
    jerkTracker.pushSample(/*timestamp=*/8, 45, 70);

    const auto jerk = jerkTracker.jerkMagnitude();
    // Check presence first so an empty result fails cleanly rather than via value().
    ASSERT_TRUE(jerk.has_value());
    EXPECT_FLOAT_EQ(*jerk, std::hypot(10, -1));
}
173 
TEST(MotionPredictorTest, IsPredictionAvailable) {
    auto predictionEnabled = []() { return true /*enable prediction*/; };
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, predictionEnabled);

    // With the feature enabled, stylus input is predictable but touchscreen input is not.
    ASSERT_TRUE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_STYLUS));
    ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
}
180 
TEST(MotionPredictorTest, StationaryNoiseFloor) {
    auto predictionEnabled = []() { return true /*enable prediction*/; };
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1, predictionEnabled);

    // The stylus goes down and then stays exactly where it is.
    predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
    predictor.record(getMotionEvent(MOVE, 0, 1, 35ms)); // No movement.

    // Without movement there is nothing to extrapolate, so no prediction is produced.
    ASSERT_EQ(nullptr, predictor.predict(40 * NSEC_PER_MSEC));
}
189 
TEST(MotionPredictorTest, Offset) {
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1,
                              []() { return true /*enable prediction*/; });
    predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
    predictor.record(getMotionEvent(MOVE, 0, 5, 35ms)); // Move enough to overcome the noise floor.
    std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC);
    ASSERT_NE(nullptr, predicted);
    // Event times are in nanoseconds. The previous bound of 41 (ns) was met by any timestamp, so
    // it checked nothing. A prediction must lie after the last recorded sample (35ms).
    ASSERT_GT(predicted->getEventTime(), 35 * NSEC_PER_MSEC);
}
199 
TEST(MotionPredictorTest, FollowsGesture) {
    auto predictionEnabled = []() { return true /*enable prediction*/; };
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, predictionEnabled);

    // A horizontal, accelerating stroke at y = 3, one sample every 10ms.
    predictor.record(getMotionEvent(DOWN, 3.75, 3, 20ms));
    const double moveXs[] = {4.8, 6.2, 8};
    std::chrono::milliseconds sampleTime = 20ms;
    for (double moveX : moveXs) {
        sampleTime += 10ms;
        predictor.record(getMotionEvent(MOVE, moveX, 3, sampleTime));
    }
    // While the stroke is active a prediction is available.
    EXPECT_NE(nullptr, predictor.predict(90 * NSEC_PER_MSEC));

    // Once the stylus lifts, there is nothing left to predict.
    predictor.record(getMotionEvent(UP, 10.25, 3, 60ms));
    EXPECT_EQ(nullptr, predictor.predict(100 * NSEC_PER_MSEC));
}
212 
TEST(MotionPredictorTest, MultipleDevicesNotSupported) {
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
                              []() { return true /*enable prediction*/; });
    // Records one event from `deviceId` and reports whether the predictor accepted it.
    auto recordFrom = [&predictor](int32_t deviceId, int32_t action, float x, float y,
                                   std::chrono::nanoseconds eventTime) {
        return predictor.record(getMotionEvent(action, x, y, eventTime, deviceId)).ok();
    };

    // Device 0 starts a gesture and leaves it active (no UP is sent).
    ASSERT_TRUE(recordFrom(/*deviceId=*/0, DOWN, 1, 3, 0ms));
    ASSERT_TRUE(recordFrom(/*deviceId=*/0, MOVE, 1, 3, 10ms));
    ASSERT_TRUE(recordFrom(/*deviceId=*/0, MOVE, 2, 5, 20ms));
    ASSERT_TRUE(recordFrom(/*deviceId=*/0, MOVE, 3, 7, 30ms));

    // While that gesture is in progress, events from a second device are rejected.
    ASSERT_FALSE(recordFrom(/*deviceId=*/1, DOWN, 100, 300, 40ms));
    ASSERT_FALSE(recordFrom(/*deviceId=*/1, MOVE, 100, 300, 50ms));
}
225 
TEST(MotionPredictorTest, IndividualGesturesFromDifferentDevicesAreSupported) {
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
                              []() { return true /*enable prediction*/; });
    // Records one event from `deviceId` and reports whether the predictor accepted it.
    auto recordFrom = [&predictor](int32_t deviceId, int32_t action, float x, float y,
                                   std::chrono::nanoseconds eventTime) {
        return predictor.record(getMotionEvent(action, x, y, eventTime, deviceId)).ok();
    };

    // Device 0 performs a complete gesture, ending with UP.
    ASSERT_TRUE(recordFrom(/*deviceId=*/0, DOWN, 1, 3, 0ms));
    ASSERT_TRUE(recordFrom(/*deviceId=*/0, MOVE, 1, 3, 10ms));
    ASSERT_TRUE(recordFrom(/*deviceId=*/0, MOVE, 2, 5, 20ms));
    ASSERT_TRUE(recordFrom(/*deviceId=*/0, UP, 2, 5, 30ms));

    // No gesture is active any more, so a new gesture from a different device is accepted.
    ASSERT_TRUE(recordFrom(/*deviceId=*/1, DOWN, 100, 300, 40ms));
    ASSERT_TRUE(recordFrom(/*deviceId=*/1, MOVE, 100, 300, 50ms));
}
240 
TEST(MotionPredictorTest, FlagDisablesPrediction) {
    auto predictionDisabled = []() { return false /*disable prediction*/; };
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, predictionDisabled);

    predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
    predictor.record(getMotionEvent(MOVE, 0, 1, 35ms));

    // With the feature check returning false, nothing is predicted...
    ASSERT_EQ(nullptr, predictor.predict(40 * NSEC_PER_MSEC));
    // ...and prediction is reported as unavailable for every source.
    ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_STYLUS));
    ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
}
251 
TEST_WITH_FLAGS(
        MotionPredictorTest, LowJerkNoPruning,
        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
                                            enable_prediction_pruning_via_jerk_thresholding))) {
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
                              []() { return true /*enable prediction*/; });

    // Jerk is low (0.05 normalized):
    //   x': .75 1.05 1.4 1.8 -> x'': .3 .35 .4 -> x''': .05 .05
    predictor.record(getMotionEvent(DOWN, 2, 7, 20ms));
    predictor.record(getMotionEvent(MOVE, 2.75, 7, 30ms));
    predictor.record(getMotionEvent(MOVE, 3.8, 7, 40ms));
    predictor.record(getMotionEvent(MOVE, 5.2, 7, 50ms));
    predictor.record(getMotionEvent(MOVE, 7, 7, 60ms));
    std::unique_ptr<MotionEvent> predicted = predictor.predict(90 * NSEC_PER_MSEC);
    // ASSERT, not EXPECT: the next line dereferences `predicted`, so a null result must stop
    // the test instead of crashing the test binary.
    ASSERT_NE(nullptr, predicted);
    // No pruning: all five predicted samples (history plus the current one) are kept.
    EXPECT_EQ(static_cast<size_t>(5), predicted->getHistorySize() + 1);
}
269 
TEST_WITH_FLAGS(
        MotionPredictorTest, HighJerkPredictionsPruned,
        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
                                            enable_prediction_pruning_via_jerk_thresholding))) {
    auto predictionEnabled = []() { return true /*enable prediction*/; };
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, predictionEnabled);

    // Jerk is incredibly high: the y coordinate explodes over the last few samples.
    predictor.record(getMotionEvent(DOWN, 0, 5, 20ms));
    const float moveYs[] = {70, 139, 1421, 41233};
    std::chrono::milliseconds sampleTime = 20ms;
    for (float moveY : moveYs) {
        sampleTime += 10ms;
        predictor.record(getMotionEvent(MOVE, 0, moveY, sampleTime));
    }

    // Every prediction is pruned, so nothing is returned.
    EXPECT_EQ(nullptr, predictor.predict(90 * NSEC_PER_MSEC));
}
286 
TEST_WITH_FLAGS(
        MotionPredictorTest, MediumJerkPredictionsSomePruned,
        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
                                            enable_prediction_pruning_via_jerk_thresholding))) {
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
                              []() { return true /*enable prediction*/; });

    // Jerk is medium (1.05 normalized, halfway between the low and high jerk thresholds):
    //   y': 6.3 10.5 15.75 22.05 -> y'': 4.2 5.25 6.3 -> y''': 1.05 1.05
    predictor.record(getMotionEvent(DOWN, 0, 5.2, 20ms));
    predictor.record(getMotionEvent(MOVE, 0, 11.5, 30ms));
    predictor.record(getMotionEvent(MOVE, 0, 22, 40ms));
    predictor.record(getMotionEvent(MOVE, 0, 37.75, 50ms));
    predictor.record(getMotionEvent(MOVE, 0, 59.8, 60ms));
    std::unique_ptr<MotionEvent> predicted = predictor.predict(82 * NSEC_PER_MSEC);
    // ASSERT, not EXPECT: the next line dereferences `predicted`, so a null result must stop
    // the test instead of crashing the test binary.
    ASSERT_NE(nullptr, predicted);
    // A jerk halfway between the thresholds prunes part of the model's predictions. The
    // predict() time is chosen close to the model's prediction window, so the time cutoff
    // should not remove extra samples.
    // NOTE(review): an earlier comment said "5/2 -> 2" predictions are output, which disagrees
    // with the count of 3 asserted below; confirm the rounding against the pruning code.
    EXPECT_EQ(static_cast<size_t>(3), predicted->getHistorySize() + 1);
}
308 
309 using AtomFields = MotionPredictorMetricsManager::AtomFields;
310 using ReportAtomFunction = MotionPredictorMetricsManager::ReportAtomFunction;
311 
312 // Creates a mock atom reporting function that appends the reported atom to the given vector.
313 // The passed-in pointer must not be nullptr.
createMockReportAtomFunction(std::vector<AtomFields> * reportedAtomFields)314 ReportAtomFunction createMockReportAtomFunction(std::vector<AtomFields>* reportedAtomFields) {
315     return [reportedAtomFields](const AtomFields& atomFields) -> void {
316         reportedAtomFields->push_back(atomFields);
317     };
318 }
319 
TEST(MotionPredictorMetricsManagerIntegrationTest, ReportsMetrics) {
    std::vector<AtomFields> reportedAtomFields;
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
                              []() { return true /*enable prediction*/; },
                              createMockReportAtomFunction(&reportedAtomFields));

    // A diagonal stroke: DOWN at (1, 1), five MOVEs advancing one unit every 4ms, then UP.
    ASSERT_TRUE(predictor.record(getMotionEvent(DOWN, 1, 1, 0ms, /*deviceId=*/0)).ok());
    for (int step = 1; step <= 5; ++step) {
        const float position = static_cast<float>(1 + step);
        const std::chrono::milliseconds sampleTime = 4ms * step;
        ASSERT_TRUE(predictor
                            .record(getMotionEvent(MOVE, position, position, sampleTime,
                                                   /*deviceId=*/0))
                            .ok());
    }
    ASSERT_TRUE(predictor.record(getMotionEvent(UP, 7, 7, 24ms, /*deviceId=*/0)).ok());

    // Finishing the gesture reports one atom per prediction time bucket. The bucket count
    // equals the prediction model's output length. It is currently always 5 and is hardcoded
    // here because the MotionPredictor does not expose it publicly.
    EXPECT_EQ(5u, reportedAtomFields.size());
}
339 
340 } // namespace android
341