1 use std::collections::HashMap; 2 use std::sync::{Arc, Mutex}; 3 use tokio::sync::oneshot; 4 use tokio::time::Duration; 5 6 /// Helper for managing an async topshim function. It takes care of calling the function, preparing 7 /// the channel, waiting for the callback, and returning it in a Result. 8 /// 9 /// `R` is the type of the return. 10 pub(crate) struct AsyncHelper<R> { 11 // Name of the method that this struct helps. Useful for logging. 12 method_name: String, 13 14 // Keeps track of call_id. Increment each time and wrap to 0 when u32 max is reached. 15 last_call_id: u32, 16 17 // Keep pending calls' ids and senders. 18 senders: Arc<Mutex<HashMap<u32, oneshot::Sender<R>>>>, 19 } 20 21 pub(crate) type CallbackSender<R> = Arc<Mutex<Box<(dyn Fn(u32, R) + Send)>>>; 22 23 impl<R: 'static + Send> AsyncHelper<R> { new(method_name: &str) -> Self24 pub(crate) fn new(method_name: &str) -> Self { 25 Self { 26 method_name: String::from(method_name), 27 last_call_id: 0, 28 senders: Arc::new(Mutex::new(HashMap::new())), 29 } 30 } 31 32 /// Calls a topshim method that expects the async return to be delivered via a callback. call_method<F>(&mut self, f: F, timeout_ms: Option<u64>) -> Result<R, ()> where F: Fn(u32),33 pub(crate) async fn call_method<F>(&mut self, f: F, timeout_ms: Option<u64>) -> Result<R, ()> 34 where 35 F: Fn(u32), 36 { 37 // Create a oneshot channel to be used by the callback to notify us that the return is 38 // available. 39 let (tx, rx) = oneshot::channel(); 40 41 // Use a unique method call ID so that we know which callback is corresponding to which 42 // method call. The actual value of the ID does not matter as long as it's always unique, 43 // so a simple increment (and wraps back to 0) is good enough. 44 self.last_call_id = self.last_call_id.wrapping_add(1); 45 46 // Keep track of the sender belonging to this call id. 47 self.senders.lock().unwrap().insert(self.last_call_id, tx); 48 49 // Call the method. `f` is freely defined by the user of this utility. This must be an 50 // operation that expects a callback that will trigger sending of the return via the 51 // oneshot channel. 52 f(self.last_call_id); 53 54 if let Some(timeout_ms) = timeout_ms { 55 let senders = self.senders.clone(); 56 let call_id = self.last_call_id; 57 tokio::spawn(async move { 58 tokio::time::sleep(Duration::from_millis(timeout_ms)).await; 59 60 // If the timer expires first before a callback is triggered, we remove the sender 61 // which will invalidate the channel which in turn will notify the receiver of 62 // an error. 63 // If the callback gets triggered first, this does nothing since the entry has been 64 // removed when sending the response. 65 senders.lock().unwrap().remove(&call_id); 66 }); 67 } 68 69 // Wait for the callback and return when available. 70 rx.await.map_err(|_| ()) 71 } 72 73 /// Returns a function to be invoked when callback is triggered. get_callback_sender(&self) -> CallbackSender<R>74 pub(crate) fn get_callback_sender(&self) -> CallbackSender<R> { 75 let senders = self.senders.clone(); 76 let method_name = self.method_name.clone(); 77 return Arc::new(Mutex::new(Box::new(move |call_id, ret| { 78 if let Some(sender) = senders.lock().unwrap().remove(&call_id) { 79 sender.send(ret).ok(); 80 } else { 81 log::warn!("AsyncHelper {}: Sender no longer exists.", method_name); 82 } 83 }))); 84 } 85 } 86