1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
//! Input filter thread implementation in Rust.
//! Uses the `IInputThread` AIDL interface to create a continuously looping thread with JNI
//! support; the rest of the thread handling is done on the Rust side.
20 //!
//! NOTE: We tried using Rust's own threading infrastructure, but it is built on std::thread, which
//! has no JNI support and therefore cannot call into the Java policy we currently use. libutils'
//! Thread.h also recommends against std::thread in favor of its own infrastructure, which already
//! provides a way to attach a JNIEnv to the created thread. So we use an AIDL interface to expose
//! the InputThread infrastructure to Rust.
26 
use crate::input_filter::InputFilterThreadCreator;
use binder::{BinderFeatures, Interface, Strong};
use com_android_server_inputflinger::aidl::com::android::server::inputflinger::IInputThread::{
    IInputThread, IInputThreadCallback::BnInputThreadCallback,
    IInputThreadCallback::IInputThreadCallback,
};
use log::{debug, error};
use nix::{sys::time::TimeValLike, time::clock_gettime, time::ClockId};
use std::sync::{Arc, RwLock, RwLockWriteGuard, TryLockError};
36 
/// Listener that the input filter thread calls back when a requested timeout expires.
pub trait ThreadCallback {
    /// Invoked on the input filter thread once the pending timeout has expired.
    ///
    /// Timeouts are requested through `InputFilterThread::request_timeout_at_time`. Only the
    /// earliest pending request is tracked. When it expires, *every* registered listener is
    /// notified, so each listener must request its own next timeout if it needs one.
    ///
    /// `when_nanos` is the time the looper observed when it detected the expiry.
    fn notify_timeout_expired(&self, when_nanos: i64);

    /// Unique name of this listener.
    ///
    /// It is used to reject duplicate registrations and to find the listener again when it is
    /// unregistered.
    fn name(&self) -> &str;
}
49 
50 #[derive(Clone)]
51 pub struct InputFilterThread {
52     thread_creator: InputFilterThreadCreator,
53     thread_callback_handler: ThreadCallbackHandler,
54     inner: Arc<RwLock<InputFilterThreadInner>>,
55     looper: Arc<RwLock<Looper>>,
56 }
57 
58 struct InputFilterThreadInner {
59     next_timeout: i64,
60     is_finishing: bool,
61 }
62 
63 struct Looper {
64     cpp_thread: Option<Strong<dyn IInputThread>>,
65 }
66 
67 impl InputFilterThread {
68     /// Create a new InputFilterThread instance.
69     /// NOTE: This will create a new thread. Clone the existing instance to reuse the same thread.
new(thread_creator: InputFilterThreadCreator) -> InputFilterThread70     pub fn new(thread_creator: InputFilterThreadCreator) -> InputFilterThread {
71         Self {
72             thread_creator,
73             thread_callback_handler: ThreadCallbackHandler::new(),
74             inner: Arc::new(RwLock::new(InputFilterThreadInner {
75                 next_timeout: i64::MAX,
76                 is_finishing: false,
77             })),
78             looper: Arc::new(RwLock::new(Looper { cpp_thread: None })),
79         }
80     }
81 
82     /// Listener requesting a timeout in future will receive a callback at or before the requested
83     /// time on the input filter thread.
84     /// {@see ThreadCallback.notify_timeout_expired(...)}
request_timeout_at_time(&self, when_nanos: i64)85     pub fn request_timeout_at_time(&self, when_nanos: i64) {
86         let mut need_wake = false;
87         {
88             // acquire filter lock
89             let filter_thread = &mut self.filter_thread();
90             if when_nanos < filter_thread.next_timeout {
91                 filter_thread.next_timeout = when_nanos;
92                 need_wake = true;
93             }
94         } // release filter lock
95         if need_wake {
96             self.wake();
97         }
98     }
99 
100     /// Registers a callback listener.
101     ///
102     /// NOTE: If a listener with the same name already exists when registering using
103     /// {@see InputFilterThread.register_thread_callback(...)}, we will ignore the listener. You
104     /// must clear any previously registered listeners using
105     /// {@see InputFilterThread.unregister_thread_callback(...) before registering the new listener.
106     ///
107     /// NOTE: Also, registering a callback will start the looper if not already started.
register_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>)108     pub fn register_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>) {
109         self.thread_callback_handler.register_thread_callback(callback);
110         self.start();
111     }
112 
113     /// Unregisters a callback listener.
114     ///
115     /// NOTE: Unregistering a callback will stop the looper if not other callback registered.
unregister_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>)116     pub fn unregister_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>) {
117         self.thread_callback_handler.unregister_thread_callback(callback);
118         // Stop the thread if no registered callbacks exist. We will recreate the thread when new
119         // callbacks are registered.
120         let has_callbacks = self.thread_callback_handler.has_callbacks();
121         if !has_callbacks {
122             self.stop();
123         }
124     }
125 
start(&self)126     fn start(&self) {
127         debug!("InputFilterThread: start thread");
128         {
129             // acquire looper lock
130             let looper = &mut self.looper();
131             if looper.cpp_thread.is_none() {
132                 looper.cpp_thread = Some(self.thread_creator.create(
133                     &BnInputThreadCallback::new_binder(self.clone(), BinderFeatures::default()),
134                 ));
135             }
136         } // release looper lock
137         self.set_finishing(false);
138     }
139 
stop(&self)140     fn stop(&self) {
141         debug!("InputFilterThread: stop thread");
142         self.set_finishing(true);
143         self.wake();
144         {
145             // acquire looper lock
146             let looper = &mut self.looper();
147             if let Some(cpp_thread) = &looper.cpp_thread {
148                 let _ = cpp_thread.finish();
149             }
150             // Clear all references
151             looper.cpp_thread = None;
152         } // release looper lock
153     }
154 
set_finishing(&self, is_finishing: bool)155     fn set_finishing(&self, is_finishing: bool) {
156         let filter_thread = &mut self.filter_thread();
157         filter_thread.is_finishing = is_finishing;
158     }
159 
loop_once(&self, now: i64)160     fn loop_once(&self, now: i64) {
161         let mut wake_up_time = i64::MAX;
162         let mut timeout_expired = false;
163         {
164             // acquire thread lock
165             let filter_thread = &mut self.filter_thread();
166             if filter_thread.is_finishing {
167                 // Thread is finishing so don't block processing on it and let it loop.
168                 return;
169             }
170             if filter_thread.next_timeout != i64::MAX {
171                 if filter_thread.next_timeout <= now {
172                     timeout_expired = true;
173                     filter_thread.next_timeout = i64::MAX;
174                 } else {
175                     wake_up_time = filter_thread.next_timeout;
176                 }
177             }
178         } // release thread lock
179         if timeout_expired {
180             self.thread_callback_handler.notify_timeout_expired(now);
181         }
182         self.sleep_until(wake_up_time);
183     }
184 
filter_thread(&self) -> RwLockWriteGuard<'_, InputFilterThreadInner>185     fn filter_thread(&self) -> RwLockWriteGuard<'_, InputFilterThreadInner> {
186         self.inner.write().unwrap()
187     }
188 
sleep_until(&self, when_nanos: i64)189     fn sleep_until(&self, when_nanos: i64) {
190         let looper = self.looper.read().unwrap();
191         if let Some(cpp_thread) = &looper.cpp_thread {
192             let _ = cpp_thread.sleepUntil(when_nanos);
193         }
194     }
195 
wake(&self)196     fn wake(&self) {
197         let looper = self.looper.read().unwrap();
198         if let Some(cpp_thread) = &looper.cpp_thread {
199             let _ = cpp_thread.wake();
200         }
201     }
202 
looper(&self) -> RwLockWriteGuard<'_, Looper>203     fn looper(&self) -> RwLockWriteGuard<'_, Looper> {
204         self.looper.write().unwrap()
205     }
206 }
207 
208 impl Interface for InputFilterThread {}
209 
210 impl IInputThreadCallback for InputFilterThread {
loopOnce(&self) -> binder::Result<()>211     fn loopOnce(&self) -> binder::Result<()> {
212         self.loop_once(clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds());
213         Result::Ok(())
214     }
215 }
216 
217 #[derive(Default, Clone)]
218 struct ThreadCallbackHandler(Arc<RwLock<ThreadCallbackHandlerInner>>);
219 
220 #[derive(Default)]
221 struct ThreadCallbackHandlerInner {
222     callbacks: Vec<Box<dyn ThreadCallback + Send + Sync>>,
223 }
224 
225 impl ThreadCallbackHandler {
new() -> Self226     fn new() -> Self {
227         Default::default()
228     }
229 
has_callbacks(&self) -> bool230     fn has_callbacks(&self) -> bool {
231         !&self.0.read().unwrap().callbacks.is_empty()
232     }
233 
register_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>)234     fn register_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>) {
235         let callbacks = &mut self.0.write().unwrap().callbacks;
236         if callbacks.iter().any(|x| x.name() == callback.name()) {
237             error!(
238                 "InputFilterThread: register_thread_callback, callback {:?} already exists!",
239                 callback.name()
240             );
241             return;
242         }
243         debug!(
244             "InputFilterThread: register_thread_callback, callback {:?} added!",
245             callback.name()
246         );
247         callbacks.push(callback);
248     }
249 
unregister_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>)250     fn unregister_thread_callback(&self, callback: Box<dyn ThreadCallback + Send + Sync>) {
251         let callbacks = &mut self.0.write().unwrap().callbacks;
252         if let Some(index) = callbacks.iter().position(|x| x.name() == callback.name()) {
253             callbacks.remove(index);
254             debug!(
255                 "InputFilterThread: unregister_thread_callback, callback {:?} removed!",
256                 callback.name()
257             );
258             return;
259         }
260         error!(
261             "InputFilterThread: unregister_thread_callback, callback {:?} doesn't exist",
262             callback.name()
263         );
264     }
265 
notify_timeout_expired(&self, when_nanos: i64)266     fn notify_timeout_expired(&self, when_nanos: i64) {
267         let callbacks = &self.0.read().unwrap().callbacks;
268         for callback in callbacks.iter() {
269             callback.notify_timeout_expired(when_nanos);
270         }
271     }
272 }
273 
#[cfg(test)]
mod tests {
    use crate::input_filter::{test_callbacks::TestCallbacks, InputFilterThreadCreator};
    use crate::input_filter_thread::{test_thread_callback::TestThreadCallback, InputFilterThread};
    use binder::Strong;
    use nix::{sys::time::TimeValLike, time::clock_gettime, time::ClockId};
    use std::sync::{Arc, RwLock};
    use std::time::Duration;

    /// Converts millisecond deadlines into the nanosecond timestamps taken by
    /// `request_timeout_at_time`.
    const NANOS_PER_MILLI: i64 = 1_000_000;

    #[test]
    fn test_register_callback_creates_cpp_thread() {
        let callbacks = TestCallbacks::new();
        let thread = get_thread(callbacks.clone());

        thread.register_thread_callback(Box::new(TestThreadCallback::new()));

        assert!(callbacks.is_thread_running());
    }

    #[test]
    fn test_unregister_callback_finishes_cpp_thread() {
        let callbacks = TestCallbacks::new();
        let thread = get_thread(callbacks.clone());
        let listener = TestThreadCallback::new();

        thread.register_thread_callback(Box::new(listener.clone()));
        thread.unregister_thread_callback(Box::new(listener));

        assert!(!callbacks.is_thread_running());
    }

    #[test]
    fn test_notify_timeout_called_after_timeout_expired() {
        let (_callbacks, thread, listener) = thread_with_registered_listener();

        thread.request_timeout_at_time((monotonic_now_millis() + 10) * NANOS_PER_MILLI);
        std::thread::sleep(Duration::from_millis(100));

        assert!(listener.is_notify_timeout_called());
    }

    #[test]
    fn test_notify_timeout_not_called_before_timeout_expired() {
        let (_callbacks, thread, listener) = thread_with_registered_listener();

        thread.request_timeout_at_time((monotonic_now_millis() + 100) * NANOS_PER_MILLI);
        std::thread::sleep(Duration::from_millis(10));

        assert!(!listener.is_notify_timeout_called());
    }

    /// Builds a thread on fresh `TestCallbacks` and registers a new `TestThreadCallback` on it.
    ///
    /// The `TestCallbacks` is returned too, so it stays alive for the whole test.
    fn thread_with_registered_listener() -> (TestCallbacks, InputFilterThread, TestThreadCallback)
    {
        let callbacks = TestCallbacks::new();
        let thread = get_thread(callbacks.clone());
        let listener = TestThreadCallback::new();
        thread.register_thread_callback(Box::new(listener.clone()));
        (callbacks, thread, listener)
    }

    /// Current `CLOCK_MONOTONIC` time in whole milliseconds.
    fn monotonic_now_millis() -> i64 {
        clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_milliseconds()
    }

    /// Wraps `callbacks` into an `InputFilterThreadCreator` and builds a thread from it.
    fn get_thread(callbacks: TestCallbacks) -> InputFilterThread {
        let shared_callbacks = Arc::new(RwLock::new(Strong::new(Box::new(callbacks))));
        InputFilterThread::new(InputFilterThreadCreator::new(shared_callbacks))
    }
}
336 
#[cfg(test)]
pub mod test_thread_callback {
    use crate::input_filter_thread::ThreadCallback;
    use std::sync::{Arc, RwLock, RwLockWriteGuard};

    /// State shared by all clones of a [`TestThreadCallback`].
    #[derive(Default)]
    struct TestThreadCallbackInner {
        is_notify_timeout_called: bool,
    }

    /// Test listener that records whether a timeout notification was delivered.
    ///
    /// Clones share state, so one clone can be registered while the test inspects another.
    #[derive(Default, Clone)]
    pub struct TestThreadCallback(Arc<RwLock<TestThreadCallbackInner>>);

    impl TestThreadCallback {
        /// Creates a listener that has not been notified yet.
        pub fn new() -> Self {
            Self::default()
        }

        /// Write access to the shared state.
        fn inner(&self) -> RwLockWriteGuard<'_, TestThreadCallbackInner> {
            self.0.write().unwrap()
        }

        /// Returns `true` once `notify_timeout_expired` has been called on any clone.
        pub fn is_notify_timeout_called(&self) -> bool {
            let state = self.0.read().unwrap();
            state.is_notify_timeout_called
        }
    }

    impl ThreadCallback for TestThreadCallback {
        fn notify_timeout_expired(&self, _when_nanos: i64) {
            let mut state = self.inner();
            state.is_notify_timeout_called = true;
        }

        fn name(&self) -> &str {
            "TestThreadCallback"
        }
    }
}
373