1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "UnwantedInteractionBlocker"
18 #include "UnwantedInteractionBlocker.h"
19 
20 #include <android-base/stringprintf.h>
21 #include <com_android_input_flags.h>
22 #include <ftl/enum.h>
23 #include <input/PrintTools.h>
24 #include <inttypes.h>
25 #include <linux/input-event-codes.h>
26 #include <linux/input.h>
27 #include <server_configurable_flags/get_flags.h>
28 
29 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.h"
30 #include "ui/events/ozone/evdev/touch_filter/palm_model/onedevice_train_palm_detection_filter_model.h"
31 
32 namespace input_flags = com::android::input::flags;
33 
34 using android::base::StringPrintf;
35 
/**
 * The concrete palm detection filter type that PalmRejector instantiates.
 *
 * This alias keeps two types consistent:
 * - the type that is instantiated (used in the PalmRejector constructor via std::make_unique);
 * - the type that is cast to (used in PalmRejector::dump() via static_cast).
 *
 * RTTI is not available, so dynamic_cast cannot be used and a mismatch cannot be detected at
 * runtime. A mismatch between the two would be undefined behaviour, hence the single alias.
 */
using PalmFilterImplementation = ::ui::NeuralStylusPalmDetectionFilter;
43 
44 namespace android {
45 
/**
 * Log detailed debug messages about each inbound motion event notification to the blocker.
 * Enable this via "adb shell setprop log.tag.UnwantedInteractionBlockerInboundMotion DEBUG"
 * (requires restart: the value is computed once, when this namespace-scope constant is
 * initialized).
 */
const bool DEBUG_INBOUND_MOTION =
        __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG "InboundMotion", ANDROID_LOG_INFO);

/**
 * Log detailed debug messages about each outbound motion event processed by the blocker.
 * Enable this via "adb shell setprop log.tag.UnwantedInteractionBlockerOutboundMotion DEBUG"
 * (requires restart: the value is computed once, when this namespace-scope constant is
 * initialized).
 */
const bool DEBUG_OUTBOUND_MOTION =
        __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG "OutboundMotion", ANDROID_LOG_INFO);

/**
 * Log the data sent to the model and received back from the model.
 * Enable this via "adb shell setprop log.tag.UnwantedInteractionBlockerModel DEBUG"
 * (requires restart: the value is computed once, when this namespace-scope constant is
 * initialized).
 */
const bool DEBUG_MODEL =
        __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG "Model", ANDROID_LOG_INFO);

/**
 * When multi-device input is enabled, we shouldn't use PreferStylusOverTouchBlocker at all.
 * However, multi-device input has the following default behaviour: hovering stylus rejects touch.
 * Therefore, if we want to disable that behaviour (and go back to a place where stylus down
 * blocks touch, but hovering stylus doesn't interact with touch), we should just disable the entire
 * multi-device input feature.
 *
 * The flags are read once, when this constant is initialized. When true, notifyMotion() skips
 * mPreferStylusOverTouchBlocker entirely.
 */
const bool ENABLE_MULTI_DEVICE_INPUT = input_flags::enable_multi_device_input() &&
        !input_flags::disable_reject_touch_on_stylus_hover();
79 
// Category (=namespace) name for the input settings that are applied at boot time.
// constexpr makes the pointer itself immutable; a plain 'const char*' could be reassigned.
static constexpr const char* INPUT_NATIVE_BOOT = "input_native_boot";
/**
 * Feature flag name. This flag determines whether palm rejection is enabled. To enable, specify
 * 'true' (not case sensitive) or '1'. To disable, specify any other value.
 */
static constexpr const char* PALM_REJECTION_ENABLED = "palm_rejection_enabled";
87 
/**
 * Return a copy of 's' in which every character has been passed through std::tolower.
 * Each char is converted to unsigned char first, as std::tolower requires.
 */
static std::string toLower(std::string s) {
    for (char& ch : s) {
        ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
    }
    return s;
}
92 
isFromTouchscreen(int32_t source)93 static bool isFromTouchscreen(int32_t source) {
94     return isFromSource(source, AINPUT_SOURCE_TOUCHSCREEN);
95 }
96 
toChromeTimestamp(nsecs_t eventTime)97 static ::base::TimeTicks toChromeTimestamp(nsecs_t eventTime) {
98     return ::base::TimeTicks::UnixEpoch() + ::base::TimeDelta::FromNanosecondsD(eventTime);
99 }
100 
101 /**
102  * Return true if palm rejection is enabled via the server configurable flags. Return false
103  * otherwise.
104  */
isPalmRejectionEnabled()105 static bool isPalmRejectionEnabled() {
106     std::string value = toLower(
107             server_configurable_flags::GetServerConfigurableFlag(INPUT_NATIVE_BOOT,
108                                                                  PALM_REJECTION_ENABLED, "0"));
109     if (value == "1") {
110         return true;
111     }
112     return false;
113 }
114 
getLinuxToolCode(ToolType toolType)115 static int getLinuxToolCode(ToolType toolType) {
116     switch (toolType) {
117         case ToolType::STYLUS:
118             return BTN_TOOL_PEN;
119         case ToolType::ERASER:
120             return BTN_TOOL_RUBBER;
121         case ToolType::FINGER:
122             return BTN_TOOL_FINGER;
123         case ToolType::UNKNOWN:
124         case ToolType::MOUSE:
125         case ToolType::PALM:
126             break;
127     }
128     ALOGW("Got tool type %s, converting to BTN_TOOL_FINGER", ftl::enum_string(toolType).c_str());
129     return BTN_TOOL_FINGER;
130 }
131 
getActionUpForPointerId(const NotifyMotionArgs & args,int32_t pointerId)132 static int32_t getActionUpForPointerId(const NotifyMotionArgs& args, int32_t pointerId) {
133     for (size_t i = 0; i < args.getPointerCount(); i++) {
134         if (pointerId == args.pointerProperties[i].id) {
135             return AMOTION_EVENT_ACTION_POINTER_UP |
136                     (i << AMOTION_EVENT_ACTION_POINTER_INDEX_SHIFT);
137         }
138     }
139     LOG_ALWAYS_FATAL("Can't find pointerId %" PRId32 " in %s", pointerId, args.dump().c_str());
140 }
141 
142 /**
143  * Find the action for individual pointer at the given pointer index.
144  * This is always equal to MotionEvent::getActionMasked, except for
145  * POINTER_UP or POINTER_DOWN events. For example, in a POINTER_UP event, the action for
146  * the active pointer is ACTION_POINTER_UP, while the action for the other pointers is ACTION_MOVE.
147  */
resolveActionForPointer(uint8_t pointerIndex,int32_t action)148 static int32_t resolveActionForPointer(uint8_t pointerIndex, int32_t action) {
149     const int32_t actionMasked = MotionEvent::getActionMasked(action);
150     if (actionMasked != AMOTION_EVENT_ACTION_POINTER_DOWN &&
151         actionMasked != AMOTION_EVENT_ACTION_POINTER_UP) {
152         return actionMasked;
153     }
154     // This is a POINTER_DOWN or POINTER_UP event
155     const uint8_t actionIndex = MotionEvent::getActionIndex(action);
156     if (pointerIndex == actionIndex) {
157         return actionMasked;
158     }
159     // When POINTER_DOWN or POINTER_UP happens, it's actually a MOVE for all of the other
160     // pointers
161     return AMOTION_EVENT_ACTION_MOVE;
162 }
163 
/**
 * Return a copy of 'args' with every pointer whose id is in 'pointerIds' removed.
 *
 * For POINTER_DOWN / POINTER_UP events:
 * - If the active pointer itself is removed, the returned action is ACTION_UNKNOWN. The caller
 *   must replace it with a proper value.
 * - Otherwise the action index is rewritten to the active pointer's new position.
 * - If exactly one pointer remains, the action becomes DOWN / UP respectively.
 *
 * All other actions are copied unchanged.
 */
NotifyMotionArgs removePointerIds(const NotifyMotionArgs& args,
                                  const std::set<int32_t>& pointerIds) {
    const int32_t maskedAction = MotionEvent::getActionMasked(args.action);
    const uint8_t activeIndex = MotionEvent::getActionIndex(args.action);
    const bool hasActivePointer = maskedAction == AMOTION_EVENT_ACTION_POINTER_DOWN ||
            maskedAction == AMOTION_EVENT_ACTION_POINTER_UP;

    NotifyMotionArgs result{args};
    result.pointerProperties.clear();
    result.pointerCoords.clear();
    int32_t remappedActiveIndex = 0;

    for (uint32_t index = 0; index < args.getPointerCount(); index++) {
        const bool dropPointer = pointerIds.count(args.pointerProperties[index].id) > 0;
        if (!dropPointer) {
            result.pointerProperties.push_back(args.pointerProperties[index]);
            result.pointerCoords.push_back(args.pointerCoords[index]);
            if (index == activeIndex) {
                remappedActiveIndex = result.getPointerCount() - 1;
            }
        } else if (hasActivePointer && index == activeIndex) {
            // The pointer this action refers to is gone, so the action is no longer meaningful.
            result.action = ACTION_UNKNOWN;
        }
    }

    if (!hasActivePointer || result.action == ACTION_UNKNOWN) {
        return result;
    }
    if (result.getPointerCount() == 1) {
        // A single remaining pointer means this is the start / end of the whole gesture.
        result.action = (maskedAction == AMOTION_EVENT_ACTION_POINTER_DOWN)
                ? AMOTION_EVENT_ACTION_DOWN
                : AMOTION_EVENT_ACTION_UP;
    } else {
        result.action =
                maskedAction | (remappedActiveIndex << AMOTION_EVENT_ACTION_POINTER_INDEX_SHIFT);
    }
    return result;
}
208 
209 /**
210  * Remove stylus pointers from the provided NotifyMotionArgs.
211  *
212  * Return NotifyMotionArgs where the stylus pointers have been removed.
213  * If this results in removal of the active pointer, then return nullopt.
214  */
removeStylusPointerIds(const NotifyMotionArgs & args)215 static std::optional<NotifyMotionArgs> removeStylusPointerIds(const NotifyMotionArgs& args) {
216     std::set<int32_t> stylusPointerIds;
217     for (uint32_t i = 0; i < args.getPointerCount(); i++) {
218         if (isStylusToolType(args.pointerProperties[i].toolType)) {
219             stylusPointerIds.insert(args.pointerProperties[i].id);
220         }
221     }
222     NotifyMotionArgs withoutStylusPointers = removePointerIds(args, stylusPointerIds);
223     if (withoutStylusPointers.getPointerCount() == 0 ||
224         withoutStylusPointers.action == ACTION_UNKNOWN) {
225         return std::nullopt;
226     }
227     return withoutStylusPointers;
228 }
229 
createPalmFilterDeviceInfo(const InputDeviceInfo & deviceInfo)230 std::optional<AndroidPalmFilterDeviceInfo> createPalmFilterDeviceInfo(
231         const InputDeviceInfo& deviceInfo) {
232     if (!isFromTouchscreen(deviceInfo.getSources())) {
233         return std::nullopt;
234     }
235     AndroidPalmFilterDeviceInfo out;
236     const InputDeviceInfo::MotionRange* axisX =
237             deviceInfo.getMotionRange(AMOTION_EVENT_AXIS_X, AINPUT_SOURCE_TOUCHSCREEN);
238     if (axisX != nullptr) {
239         out.max_x = axisX->max;
240         out.x_res = axisX->resolution;
241     } else {
242         ALOGW("Palm rejection is disabled for %s because AXIS_X is not supported",
243               deviceInfo.getDisplayName().c_str());
244         return std::nullopt;
245     }
246     const InputDeviceInfo::MotionRange* axisY =
247             deviceInfo.getMotionRange(AMOTION_EVENT_AXIS_Y, AINPUT_SOURCE_TOUCHSCREEN);
248     if (axisY != nullptr) {
249         out.max_y = axisY->max;
250         out.y_res = axisY->resolution;
251     } else {
252         ALOGW("Palm rejection is disabled for %s because AXIS_Y is not supported",
253               deviceInfo.getDisplayName().c_str());
254         return std::nullopt;
255     }
256     const InputDeviceInfo::MotionRange* axisMajor =
257             deviceInfo.getMotionRange(AMOTION_EVENT_AXIS_TOUCH_MAJOR, AINPUT_SOURCE_TOUCHSCREEN);
258     if (axisMajor != nullptr) {
259         out.major_radius_res = axisMajor->resolution;
260         out.touch_major_res = axisMajor->resolution;
261     } else {
262         return std::nullopt;
263     }
264     const InputDeviceInfo::MotionRange* axisMinor =
265             deviceInfo.getMotionRange(AMOTION_EVENT_AXIS_TOUCH_MINOR, AINPUT_SOURCE_TOUCHSCREEN);
266     if (axisMinor != nullptr) {
267         out.minor_radius_res = axisMinor->resolution;
268         out.touch_minor_res = axisMinor->resolution;
269         out.minor_radius_supported = true;
270     } else {
271         out.minor_radius_supported = false;
272     }
273 
274     return out;
275 }
276 
277 /**
278  * Synthesize CANCEL events for any new pointers that should be canceled, while removing pointers
279  * that have already been canceled.
280  * The flow of the function is as follows:
281  * 1. Remove all already canceled pointers
282  * 2. Cancel all newly suppressed pointers
283  * 3. Decide what to do with the current event : keep it, or drop it
284  * The pointers can never be "unsuppressed": once a pointer is canceled, it will never become valid.
285  */
cancelSuppressedPointers(const NotifyMotionArgs & args,const std::set<int32_t> & oldSuppressedPointerIds,const std::set<int32_t> & newSuppressedPointerIds)286 std::vector<NotifyMotionArgs> cancelSuppressedPointers(
287         const NotifyMotionArgs& args, const std::set<int32_t>& oldSuppressedPointerIds,
288         const std::set<int32_t>& newSuppressedPointerIds) {
289     LOG_ALWAYS_FATAL_IF(args.getPointerCount() == 0, "0 pointers in %s", args.dump().c_str());
290 
291     // First, let's remove the old suppressed pointers. They've already been canceled previously.
292     NotifyMotionArgs oldArgs = removePointerIds(args, oldSuppressedPointerIds);
293 
294     // Cancel any newly suppressed pointers.
295     std::vector<NotifyMotionArgs> out;
296     const int32_t activePointerId =
297             args.pointerProperties[MotionEvent::getActionIndex(args.action)].id;
298     const int32_t actionMasked = MotionEvent::getActionMasked(args.action);
299     // We will iteratively remove pointers from 'removedArgs'.
300     NotifyMotionArgs removedArgs{oldArgs};
301     for (uint32_t i = 0; i < oldArgs.getPointerCount(); i++) {
302         const int32_t pointerId = oldArgs.pointerProperties[i].id;
303         if (newSuppressedPointerIds.find(pointerId) == newSuppressedPointerIds.end()) {
304             // This is a pointer that should not be canceled. Move on.
305             continue;
306         }
307         if (pointerId == activePointerId && actionMasked == AMOTION_EVENT_ACTION_POINTER_DOWN) {
308             // Remove this pointer, but don't cancel it. We'll just not send the POINTER_DOWN event
309             removedArgs = removePointerIds(removedArgs, {pointerId});
310             continue;
311         }
312 
313         if (removedArgs.getPointerCount() == 1) {
314             // We are about to remove the last pointer, which means there will be no more gesture
315             // remaining. This is identical to canceling all pointers, so just send a single CANCEL
316             // event, without any of the preceding POINTER_UP with FLAG_CANCELED events.
317             oldArgs.flags |= AMOTION_EVENT_FLAG_CANCELED;
318             oldArgs.action = AMOTION_EVENT_ACTION_CANCEL;
319             return {oldArgs};
320         }
321         // Cancel the current pointer
322         out.push_back(removedArgs);
323         out.back().flags |= AMOTION_EVENT_FLAG_CANCELED;
324         out.back().action = getActionUpForPointerId(out.back(), pointerId);
325 
326         // Remove the newly canceled pointer from the args
327         removedArgs = removePointerIds(removedArgs, {pointerId});
328     }
329 
330     // Now 'removedArgs' contains only pointers that are valid.
331     if (removedArgs.getPointerCount() <= 0 || removedArgs.action == ACTION_UNKNOWN) {
332         return out;
333     }
334     out.push_back(removedArgs);
335     return out;
336 }
337 
UnwantedInteractionBlocker(InputListenerInterface & listener)338 UnwantedInteractionBlocker::UnwantedInteractionBlocker(InputListenerInterface& listener)
339       : UnwantedInteractionBlocker(listener, isPalmRejectionEnabled()){};
340 
UnwantedInteractionBlocker(InputListenerInterface & listener,bool enablePalmRejection)341 UnwantedInteractionBlocker::UnwantedInteractionBlocker(InputListenerInterface& listener,
342                                                        bool enablePalmRejection)
343       : mQueuedListener(listener), mEnablePalmRejection(enablePalmRejection) {}
344 
notifyConfigurationChanged(const NotifyConfigurationChangedArgs & args)345 void UnwantedInteractionBlocker::notifyConfigurationChanged(
346         const NotifyConfigurationChangedArgs& args) {
347     mQueuedListener.notifyConfigurationChanged(args);
348     mQueuedListener.flush();
349 }
350 
notifyKey(const NotifyKeyArgs & args)351 void UnwantedInteractionBlocker::notifyKey(const NotifyKeyArgs& args) {
352     mQueuedListener.notifyKey(args);
353     mQueuedListener.flush();
354 }
355 
notifyMotion(const NotifyMotionArgs & args)356 void UnwantedInteractionBlocker::notifyMotion(const NotifyMotionArgs& args) {
357     ALOGD_IF(DEBUG_INBOUND_MOTION, "%s: %s", __func__, args.dump().c_str());
358     { // acquire lock
359         std::scoped_lock lock(mLock);
360         if (ENABLE_MULTI_DEVICE_INPUT) {
361             notifyMotionLocked(args);
362         } else {
363             const std::vector<NotifyMotionArgs> processedArgs =
364                     mPreferStylusOverTouchBlocker.processMotion(args);
365             for (const NotifyMotionArgs& loopArgs : processedArgs) {
366                 notifyMotionLocked(loopArgs);
367             }
368         }
369     } // release lock
370 
371     // Call out to the next stage without holding the lock
372     mQueuedListener.flush();
373 }
374 
enqueueOutboundMotionLocked(const NotifyMotionArgs & args)375 void UnwantedInteractionBlocker::enqueueOutboundMotionLocked(const NotifyMotionArgs& args) {
376     ALOGD_IF(DEBUG_OUTBOUND_MOTION, "%s: %s", __func__, args.dump().c_str());
377     mQueuedListener.notifyMotion(args);
378 }
379 
notifyMotionLocked(const NotifyMotionArgs & args)380 void UnwantedInteractionBlocker::notifyMotionLocked(const NotifyMotionArgs& args) {
381     auto it = mPalmRejectors.find(args.deviceId);
382     const bool sendToPalmRejector = it != mPalmRejectors.end() && isFromTouchscreen(args.source);
383     if (!sendToPalmRejector) {
384         enqueueOutboundMotionLocked(args);
385         return;
386     }
387 
388     std::vector<NotifyMotionArgs> processedArgs = it->second.processMotion(args);
389     for (const NotifyMotionArgs& loopArgs : processedArgs) {
390         enqueueOutboundMotionLocked(loopArgs);
391     }
392 }
393 
notifySwitch(const NotifySwitchArgs & args)394 void UnwantedInteractionBlocker::notifySwitch(const NotifySwitchArgs& args) {
395     mQueuedListener.notifySwitch(args);
396     mQueuedListener.flush();
397 }
398 
notifySensor(const NotifySensorArgs & args)399 void UnwantedInteractionBlocker::notifySensor(const NotifySensorArgs& args) {
400     mQueuedListener.notifySensor(args);
401     mQueuedListener.flush();
402 }
403 
notifyVibratorState(const NotifyVibratorStateArgs & args)404 void UnwantedInteractionBlocker::notifyVibratorState(const NotifyVibratorStateArgs& args) {
405     mQueuedListener.notifyVibratorState(args);
406     mQueuedListener.flush();
407 }
notifyDeviceReset(const NotifyDeviceResetArgs & args)408 void UnwantedInteractionBlocker::notifyDeviceReset(const NotifyDeviceResetArgs& args) {
409     { // acquire lock
410         std::scoped_lock lock(mLock);
411         auto it = mPalmRejectors.find(args.deviceId);
412         if (it != mPalmRejectors.end()) {
413             AndroidPalmFilterDeviceInfo info = it->second.getPalmFilterDeviceInfo();
414             // Re-create the object instead of resetting it
415             mPalmRejectors.erase(it);
416             mPalmRejectors.emplace(args.deviceId, info);
417         }
418         mQueuedListener.notifyDeviceReset(args);
419         mPreferStylusOverTouchBlocker.notifyDeviceReset(args);
420     } // release lock
421     // Send events to the next stage without holding the lock
422     mQueuedListener.flush();
423 }
424 
notifyPointerCaptureChanged(const NotifyPointerCaptureChangedArgs & args)425 void UnwantedInteractionBlocker::notifyPointerCaptureChanged(
426         const NotifyPointerCaptureChangedArgs& args) {
427     mQueuedListener.notifyPointerCaptureChanged(args);
428     mQueuedListener.flush();
429 }
430 
notifyInputDevicesChanged(const NotifyInputDevicesChangedArgs & args)431 void UnwantedInteractionBlocker::notifyInputDevicesChanged(
432         const NotifyInputDevicesChangedArgs& args) {
433     onInputDevicesChanged(args.inputDeviceInfos);
434     mQueuedListener.notify(args);
435     mQueuedListener.flush();
436 }
437 
onInputDevicesChanged(const std::vector<InputDeviceInfo> & inputDevices)438 void UnwantedInteractionBlocker::onInputDevicesChanged(
439         const std::vector<InputDeviceInfo>& inputDevices) {
440     std::scoped_lock lock(mLock);
441     if (!mEnablePalmRejection) {
442         // Palm rejection is disabled. Don't create any palm rejector objects.
443         return;
444     }
445 
446     // Let's see which of the existing devices didn't change, so that we can keep them
447     // and prevent event stream disruption
448     std::set<int32_t /*deviceId*/> devicesToKeep;
449     for (const InputDeviceInfo& device : inputDevices) {
450         std::optional<AndroidPalmFilterDeviceInfo> info = createPalmFilterDeviceInfo(device);
451         if (!info) {
452             continue;
453         }
454 
455         auto [it, emplaced] = mPalmRejectors.try_emplace(device.getId(), *info);
456         if (!emplaced && *info != it->second.getPalmFilterDeviceInfo()) {
457             // Re-create the PalmRejector because the device info has changed.
458             mPalmRejectors.erase(it);
459             mPalmRejectors.emplace(device.getId(), *info);
460         }
461         devicesToKeep.insert(device.getId());
462     }
463     // Delete all devices that we don't need to keep
464     std::erase_if(mPalmRejectors, [&devicesToKeep](const auto& item) {
465         auto const& [deviceId, _] = item;
466         return devicesToKeep.find(deviceId) == devicesToKeep.end();
467     });
468     mPreferStylusOverTouchBlocker.notifyInputDevicesChanged(inputDevices);
469 }
470 
dump(std::string & dump)471 void UnwantedInteractionBlocker::dump(std::string& dump) {
472     std::scoped_lock lock(mLock);
473     dump += "UnwantedInteractionBlocker:\n";
474     dump += "  mPreferStylusOverTouchBlocker:\n";
475     dump += addLinePrefix(mPreferStylusOverTouchBlocker.dump(), "    ");
476     dump += StringPrintf("  mEnablePalmRejection: %s\n",
477                          std::to_string(mEnablePalmRejection).c_str());
478     dump += StringPrintf("  isPalmRejectionEnabled (flag value): %s\n",
479                          std::to_string(isPalmRejectionEnabled()).c_str());
480     dump += mPalmRejectors.empty() ? "  mPalmRejectors: None\n" : "  mPalmRejectors:\n";
481     for (const auto& [deviceId, palmRejector] : mPalmRejectors) {
482         dump += StringPrintf("    deviceId = %" PRId32 ":\n", deviceId);
483         dump += addLinePrefix(palmRejector.dump(), "      ");
484     }
485 }
486 
monitor()487 void UnwantedInteractionBlocker::monitor() {
488     std::scoped_lock lock(mLock);
489 }
490 
~UnwantedInteractionBlocker()491 UnwantedInteractionBlocker::~UnwantedInteractionBlocker() {}
492 
update(const NotifyMotionArgs & args)493 void SlotState::update(const NotifyMotionArgs& args) {
494     for (size_t i = 0; i < args.getPointerCount(); i++) {
495         const int32_t pointerId = args.pointerProperties[i].id;
496         const int32_t resolvedAction = resolveActionForPointer(i, args.action);
497         processPointerId(pointerId, resolvedAction);
498     }
499 }
500 
findUnusedSlot() const501 size_t SlotState::findUnusedSlot() const {
502     size_t unusedSlot = 0;
503     // Since the collection is ordered, we can rely on the in-order traversal
504     for (const auto& [slot, trackingId] : mPointerIdsBySlot) {
505         if (unusedSlot != slot) {
506             break;
507         }
508         unusedSlot++;
509     }
510     return unusedSlot;
511 }
512 
processPointerId(int pointerId,int32_t actionMasked)513 void SlotState::processPointerId(int pointerId, int32_t actionMasked) {
514     switch (MotionEvent::getActionMasked(actionMasked)) {
515         case AMOTION_EVENT_ACTION_DOWN:
516         case AMOTION_EVENT_ACTION_POINTER_DOWN:
517         case AMOTION_EVENT_ACTION_HOVER_ENTER: {
518             // New pointer going down
519             size_t newSlot = findUnusedSlot();
520             mPointerIdsBySlot[newSlot] = pointerId;
521             mSlotsByPointerId[pointerId] = newSlot;
522             return;
523         }
524         case AMOTION_EVENT_ACTION_MOVE:
525         case AMOTION_EVENT_ACTION_HOVER_MOVE: {
526             return;
527         }
528         case AMOTION_EVENT_ACTION_CANCEL:
529         case AMOTION_EVENT_ACTION_POINTER_UP:
530         case AMOTION_EVENT_ACTION_UP:
531         case AMOTION_EVENT_ACTION_HOVER_EXIT: {
532             auto it = mSlotsByPointerId.find(pointerId);
533             LOG_ALWAYS_FATAL_IF(it == mSlotsByPointerId.end());
534             size_t slot = it->second;
535             // Erase this pointer from both collections
536             mPointerIdsBySlot.erase(slot);
537             mSlotsByPointerId.erase(pointerId);
538             return;
539         }
540     }
541     LOG_ALWAYS_FATAL("Unhandled action : %s", MotionEvent::actionToString(actionMasked).c_str());
542     return;
543 }
544 
getSlotForPointerId(int32_t pointerId) const545 std::optional<size_t> SlotState::getSlotForPointerId(int32_t pointerId) const {
546     auto it = mSlotsByPointerId.find(pointerId);
547     if (it == mSlotsByPointerId.end()) {
548         return std::nullopt;
549     }
550     return it->second;
551 }
552 
dump() const553 std::string SlotState::dump() const {
554     std::string out = "mSlotsByPointerId:\n";
555     out += addLinePrefix(dumpMap(mSlotsByPointerId), "  ") + "\n";
556     out += "mPointerIdsBySlot:\n";
557     out += addLinePrefix(dumpMap(mPointerIdsBySlot), "  ") + "\n";
558     return out;
559 }
560 
561 class AndroidPalmRejectionModel : public ::ui::OneDeviceTrainNeuralStylusPalmDetectionFilterModel {
562 public:
AndroidPalmRejectionModel()563     AndroidPalmRejectionModel()
564           : ::ui::OneDeviceTrainNeuralStylusPalmDetectionFilterModel(/*default version*/ "",
565                                                                      std::vector<float>()) {
566         config_.resample_period = ::ui::kResamplePeriod;
567     }
568 };
569 
PalmRejector(const AndroidPalmFilterDeviceInfo & info,std::unique_ptr<::ui::PalmDetectionFilter> filter)570 PalmRejector::PalmRejector(const AndroidPalmFilterDeviceInfo& info,
571                            std::unique_ptr<::ui::PalmDetectionFilter> filter)
572       : mSharedPalmState(std::make_unique<::ui::SharedPalmDetectionFilterState>()),
573         mDeviceInfo(info),
574         mPalmDetectionFilter(std::move(filter)) {
575     if (mPalmDetectionFilter != nullptr) {
576         // This path is used for testing. Non-testing invocations should let this constructor
577         // create a real PalmDetectionFilter
578         return;
579     }
580     std::unique_ptr<::ui::NeuralStylusPalmDetectionFilterModel> model =
581             std::make_unique<AndroidPalmRejectionModel>();
582     mPalmDetectionFilter = std::make_unique<PalmFilterImplementation>(mDeviceInfo, std::move(model),
583                                                                       mSharedPalmState.get());
584 }
585 
getTouches(const NotifyMotionArgs & args,const AndroidPalmFilterDeviceInfo & deviceInfo,const SlotState & oldSlotState,const SlotState & newSlotState)586 std::vector<::ui::InProgressTouchEvdev> getTouches(const NotifyMotionArgs& args,
587                                                    const AndroidPalmFilterDeviceInfo& deviceInfo,
588                                                    const SlotState& oldSlotState,
589                                                    const SlotState& newSlotState) {
590     std::vector<::ui::InProgressTouchEvdev> touches;
591 
592     for (size_t i = 0; i < args.getPointerCount(); i++) {
593         const int32_t pointerId = args.pointerProperties[i].id;
594         touches.emplace_back(::ui::InProgressTouchEvdev());
595         touches.back().major = args.pointerCoords[i].getAxisValue(AMOTION_EVENT_AXIS_TOUCH_MAJOR);
596         touches.back().minor = args.pointerCoords[i].getAxisValue(AMOTION_EVENT_AXIS_TOUCH_MINOR);
597         // The field 'tool_type' is not used for palm rejection
598 
599         // Whether there is new information for the touch.
600         touches.back().altered = true;
601 
602         // Whether the touch was cancelled. Touch events should be ignored till a
603         // new touch is initiated.
604         touches.back().was_cancelled = false;
605 
606         // Whether the touch is going to be canceled.
607         touches.back().cancelled = false;
608 
609         // Whether the touch is delayed at first appearance. Will not be reported yet.
610         touches.back().delayed = false;
611 
612         // Whether the touch was delayed before.
613         touches.back().was_delayed = false;
614 
615         // Whether the touch is held until end or no longer held.
616         touches.back().held = false;
617 
618         // Whether this touch was held before being sent.
619         touches.back().was_held = false;
620 
621         const int32_t resolvedAction = resolveActionForPointer(i, args.action);
622         const bool isDown = resolvedAction == AMOTION_EVENT_ACTION_POINTER_DOWN ||
623                 resolvedAction == AMOTION_EVENT_ACTION_DOWN;
624         touches.back().was_touching = !isDown;
625 
626         const bool isUpOrCancel = resolvedAction == AMOTION_EVENT_ACTION_CANCEL ||
627                 resolvedAction == AMOTION_EVENT_ACTION_UP ||
628                 resolvedAction == AMOTION_EVENT_ACTION_POINTER_UP;
629 
630         touches.back().x = args.pointerCoords[i].getAxisValue(AMOTION_EVENT_AXIS_X);
631         touches.back().y = args.pointerCoords[i].getAxisValue(AMOTION_EVENT_AXIS_Y);
632 
633         std::optional<size_t> slot = newSlotState.getSlotForPointerId(pointerId);
634         if (!slot) {
635             slot = oldSlotState.getSlotForPointerId(pointerId);
636         }
637         LOG_ALWAYS_FATAL_IF(!slot, "Could not find slot for pointer %d", pointerId);
638         touches.back().slot = *slot;
639         touches.back().tracking_id = (!isUpOrCancel) ? pointerId : -1;
640         touches.back().touching = !isUpOrCancel;
641 
642         // The fields 'radius_x' and 'radius_x' are not used for palm rejection
643         touches.back().pressure = args.pointerCoords[i].getAxisValue(AMOTION_EVENT_AXIS_PRESSURE);
644         touches.back().tool_code = getLinuxToolCode(args.pointerProperties[i].toolType);
645         // The field 'orientation' is not used for palm rejection
646         // The fields 'tilt_x' and 'tilt_y' are not used for palm rejection
647         // The field 'reported_tool_type' is not used for palm rejection
648         touches.back().stylus_button = false;
649     }
650     return touches;
651 }
652 
detectPalmPointers(const NotifyMotionArgs & args)653 std::set<int32_t> PalmRejector::detectPalmPointers(const NotifyMotionArgs& args) {
654     std::bitset<::ui::kNumTouchEvdevSlots> slotsToHold;
655     std::bitset<::ui::kNumTouchEvdevSlots> slotsToSuppress;
656 
657     // Store the slot state before we call getTouches and update it. This way, we can find
658     // the slots that have been removed due to the incoming event.
659     SlotState oldSlotState = mSlotState;
660     mSlotState.update(args);
661 
662     std::vector<::ui::InProgressTouchEvdev> touches =
663             getTouches(args, mDeviceInfo, oldSlotState, mSlotState);
664     ::base::TimeTicks chromeTimestamp = toChromeTimestamp(args.eventTime);
665 
666     if (DEBUG_MODEL) {
667         std::stringstream touchesStream;
668         for (const ::ui::InProgressTouchEvdev& touch : touches) {
669             touchesStream << touch.tracking_id << " : " << touch << "\n";
670         }
671         ALOGD("Filter: touches = %s", touchesStream.str().c_str());
672     }
673 
674     mPalmDetectionFilter->Filter(touches, chromeTimestamp, &slotsToHold, &slotsToSuppress);
675 
676     ALOGD_IF(DEBUG_MODEL, "Response: slotsToHold = %s, slotsToSuppress = %s",
677              slotsToHold.to_string().c_str(), slotsToSuppress.to_string().c_str());
678 
679     // Now that we know which slots should be suppressed, let's convert those to pointer id's.
680     std::set<int32_t> newSuppressedIds;
681     for (size_t i = 0; i < args.getPointerCount(); i++) {
682         const int32_t pointerId = args.pointerProperties[i].id;
683         std::optional<size_t> slot = oldSlotState.getSlotForPointerId(pointerId);
684         if (!slot) {
685             slot = mSlotState.getSlotForPointerId(pointerId);
686             LOG_ALWAYS_FATAL_IF(!slot, "Could not find slot for pointer id %" PRId32, pointerId);
687         }
688         if (slotsToSuppress.test(*slot)) {
689             newSuppressedIds.insert(pointerId);
690         }
691     }
692     return newSuppressedIds;
693 }
694 
processMotion(const NotifyMotionArgs & args)695 std::vector<NotifyMotionArgs> PalmRejector::processMotion(const NotifyMotionArgs& args) {
696     if (mPalmDetectionFilter == nullptr) {
697         return {args};
698     }
699     const bool skipThisEvent = args.action == AMOTION_EVENT_ACTION_HOVER_ENTER ||
700             args.action == AMOTION_EVENT_ACTION_HOVER_MOVE ||
701             args.action == AMOTION_EVENT_ACTION_HOVER_EXIT ||
702             args.action == AMOTION_EVENT_ACTION_BUTTON_PRESS ||
703             args.action == AMOTION_EVENT_ACTION_BUTTON_RELEASE ||
704             args.action == AMOTION_EVENT_ACTION_SCROLL;
705     if (skipThisEvent) {
706         // Lets not process hover events, button events, or scroll for now.
707         return {args};
708     }
709     if (args.action == AMOTION_EVENT_ACTION_DOWN) {
710         mSuppressedPointerIds.clear();
711     }
712 
713     std::set<int32_t> oldSuppressedIds;
714     std::swap(oldSuppressedIds, mSuppressedPointerIds);
715 
716     std::optional<NotifyMotionArgs> touchOnlyArgs = removeStylusPointerIds(args);
717     if (touchOnlyArgs) {
718         mSuppressedPointerIds = detectPalmPointers(*touchOnlyArgs);
719     } else {
720         // This is a stylus-only event.
721         // We can skip this event and just keep the suppressed pointer ids the same as before.
722         mSuppressedPointerIds = oldSuppressedIds;
723     }
724 
725     std::vector<NotifyMotionArgs> argsWithoutUnwantedPointers =
726             cancelSuppressedPointers(args, oldSuppressedIds, mSuppressedPointerIds);
727     for (const NotifyMotionArgs& checkArgs : argsWithoutUnwantedPointers) {
728         LOG_ALWAYS_FATAL_IF(checkArgs.action == ACTION_UNKNOWN, "%s", checkArgs.dump().c_str());
729     }
730 
731     // Only log if new pointers are getting rejected. That means mSuppressedPointerIds is not a
732     // subset of oldSuppressedIds.
733     if (!std::includes(oldSuppressedIds.begin(), oldSuppressedIds.end(),
734                        mSuppressedPointerIds.begin(), mSuppressedPointerIds.end())) {
735         ALOGI("Palm detected, removing pointer ids %s after %" PRId64 "ms from %s",
736               dumpSet(mSuppressedPointerIds).c_str(), ns2ms(args.eventTime - args.downTime),
737               args.dump().c_str());
738     }
739 
740     return argsWithoutUnwantedPointers;
741 }
742 
getPalmFilterDeviceInfo() const743 const AndroidPalmFilterDeviceInfo& PalmRejector::getPalmFilterDeviceInfo() const {
744     return mDeviceInfo;
745 }
746 
dump() const747 std::string PalmRejector::dump() const {
748     std::string out;
749     out += "mDeviceInfo:\n";
750     std::stringstream deviceInfo;
751     deviceInfo << mDeviceInfo << ", touch_major_res=" << mDeviceInfo.touch_major_res
752                << ", touch_minor_res=" << mDeviceInfo.touch_minor_res << "\n";
753     out += addLinePrefix(deviceInfo.str(), "  ");
754     out += "mSlotState:\n";
755     out += addLinePrefix(mSlotState.dump(), "  ");
756     out += "mSuppressedPointerIds: ";
757     out += dumpSet(mSuppressedPointerIds) + "\n";
758     std::stringstream state;
759     state << *mSharedPalmState;
760     out += "mSharedPalmState: " + state.str() + "\n";
761     std::stringstream filter;
762     filter << static_cast<const PalmFilterImplementation&>(*mPalmDetectionFilter);
763     out += "mPalmDetectionFilter:\n";
764     out += addLinePrefix(filter.str(), "  ") + "\n";
765     return out;
766 }
767 
768 } // namespace android
769