1 use std::collections::HashMap;
2 use std::sync::{Arc, Mutex};
3 use tokio::sync::oneshot;
4 use tokio::time::Duration;
5 
6 /// Helper for managing an async topshim function. It takes care of calling the function, preparing
7 /// the channel, waiting for the callback, and returning it in a Result.
8 ///
9 /// `R` is the type of the return.
10 pub(crate) struct AsyncHelper<R> {
11     // Name of the method that this struct helps. Useful for logging.
12     method_name: String,
13 
14     // Keeps track of call_id. Increment each time and wrap to 0 when u32 max is reached.
15     last_call_id: u32,
16 
17     // Keep pending calls' ids and senders.
18     senders: Arc<Mutex<HashMap<u32, oneshot::Sender<R>>>>,
19 }
20 
/// Shared, thread-safe handle to the closure that delivers a callback's return value.
///
/// The closure is invoked as `(call_id, ret)`: `call_id` is the id the pending call was
/// registered under and `ret` is the value to hand back to it.
pub(crate) type CallbackSender<R> = Arc<Mutex<Box<dyn Fn(u32, R) + Send>>>;
22 
23 impl<R: 'static + Send> AsyncHelper<R> {
new(method_name: &str) -> Self24     pub(crate) fn new(method_name: &str) -> Self {
25         Self {
26             method_name: String::from(method_name),
27             last_call_id: 0,
28             senders: Arc::new(Mutex::new(HashMap::new())),
29         }
30     }
31 
32     /// Calls a topshim method that expects the async return to be delivered via a callback.
call_method<F>(&mut self, f: F, timeout_ms: Option<u64>) -> Result<R, ()> where F: Fn(u32),33     pub(crate) async fn call_method<F>(&mut self, f: F, timeout_ms: Option<u64>) -> Result<R, ()>
34     where
35         F: Fn(u32),
36     {
37         // Create a oneshot channel to be used by the callback to notify us that the return is
38         // available.
39         let (tx, rx) = oneshot::channel();
40 
41         // Use a unique method call ID so that we know which callback is corresponding to which
42         // method call. The actual value of the ID does not matter as long as it's always unique,
43         // so a simple increment (and wraps back to 0) is good enough.
44         self.last_call_id = self.last_call_id.wrapping_add(1);
45 
46         // Keep track of the sender belonging to this call id.
47         self.senders.lock().unwrap().insert(self.last_call_id, tx);
48 
49         // Call the method. `f` is freely defined by the user of this utility. This must be an
50         // operation that expects a callback that will trigger sending of the return via the
51         // oneshot channel.
52         f(self.last_call_id);
53 
54         if let Some(timeout_ms) = timeout_ms {
55             let senders = self.senders.clone();
56             let call_id = self.last_call_id;
57             tokio::spawn(async move {
58                 tokio::time::sleep(Duration::from_millis(timeout_ms)).await;
59 
60                 // If the timer expires first before a callback is triggered, we remove the sender
61                 // which will invalidate the channel which in turn will notify the receiver of
62                 // an error.
63                 // If the callback gets triggered first, this does nothing since the entry has been
64                 // removed when sending the response.
65                 senders.lock().unwrap().remove(&call_id);
66             });
67         }
68 
69         // Wait for the callback and return when available.
70         rx.await.map_err(|_| ())
71     }
72 
73     /// Returns a function to be invoked when callback is triggered.
get_callback_sender(&self) -> CallbackSender<R>74     pub(crate) fn get_callback_sender(&self) -> CallbackSender<R> {
75         let senders = self.senders.clone();
76         let method_name = self.method_name.clone();
77         return Arc::new(Mutex::new(Box::new(move |call_id, ret| {
78             if let Some(sender) = senders.lock().unwrap().remove(&call_id) {
79                 sender.send(ret).ok();
80             } else {
81                 log::warn!("AsyncHelper {}: Sender no longer exists.", method_name);
82             }
83         })));
84     }
85 }
86