1 use std::{cell::RefCell, collections::HashMap, rc::Rc, time::Duration}; 2 3 use async_trait::async_trait; 4 use log::{trace, warn}; 5 use tokio::{sync::oneshot, time::timeout}; 6 7 use crate::{ 8 gatt::{ 9 ids::{AttHandle, ConnectionId, ServerId, TransactionId, TransportIndex}, 10 GattCallbacks, 11 }, 12 packets::AttErrorCode, 13 }; 14 15 use super::{ 16 AttributeBackingType, GattWriteRequestType, GattWriteType, RawGattDatastore, 17 TransactionDecision, 18 }; 19 20 struct PendingTransaction { 21 response: oneshot::Sender<Result<Vec<u8>, AttErrorCode>>, 22 } 23 24 #[derive(Debug)] 25 struct PendingTransactionWatcher { 26 conn_id: ConnectionId, 27 trans_id: TransactionId, 28 rx: oneshot::Receiver<Result<Vec<u8>, AttErrorCode>>, 29 } 30 31 /// This struct converts the asynchronus read/write operations of GattDatastore 32 /// into the callback-based interface expected by JNI 33 pub struct CallbackTransactionManager { 34 callbacks: Rc<dyn GattCallbacks>, 35 pending_transactions: RefCell<PendingTransactionsState>, 36 } 37 38 struct PendingTransactionsState { 39 pending_transactions: HashMap<(ConnectionId, TransactionId), PendingTransaction>, 40 next_transaction_id: u32, 41 } 42 43 /// We expect all responses to be provided within this timeout 44 /// It should be less than 30s, as that is the ATT timeout that causes 45 /// the client to disconnect. 46 const TIMEOUT: Duration = Duration::from_secs(15); 47 48 /// The cause of a failure to dispatch a call to send_response() 49 #[derive(Debug, PartialEq, Eq)] 50 pub enum CallbackResponseError { 51 /// The TransactionId supplied was invalid for the specified connection 52 NonExistentTransaction(TransactionId), 53 /// The TransactionId was valid but has since terminated 54 ListenerHungUp(TransactionId), 55 } 56 57 impl CallbackTransactionManager { 58 /// Constructor, wrapping a GattCallbacks instance with the GattDatastore 59 /// interface new(callbacks: Rc<dyn GattCallbacks>) -> Self60 pub fn new(callbacks: Rc<dyn GattCallbacks>) -> Self { 61 Self { 62 callbacks, 63 pending_transactions: RefCell::new(PendingTransactionsState { 64 pending_transactions: HashMap::new(), 65 next_transaction_id: 1, 66 }), 67 } 68 } 69 70 /// Invoked from server implementations in response to read/write requests send_response( &self, conn_id: ConnectionId, trans_id: TransactionId, value: Result<Vec<u8>, AttErrorCode>, ) -> Result<(), CallbackResponseError>71 pub fn send_response( 72 &self, 73 conn_id: ConnectionId, 74 trans_id: TransactionId, 75 value: Result<Vec<u8>, AttErrorCode>, 76 ) -> Result<(), CallbackResponseError> { 77 let mut pending = self.pending_transactions.borrow_mut(); 78 if let Some(transaction) = pending.pending_transactions.remove(&(conn_id, trans_id)) { 79 if transaction.response.send(value).is_err() { 80 Err(CallbackResponseError::ListenerHungUp(trans_id)) 81 } else { 82 trace!("got expected response for transaction {trans_id:?}"); 83 Ok(()) 84 } 85 } else { 86 Err(CallbackResponseError::NonExistentTransaction(trans_id)) 87 } 88 } 89 90 /// Get an impl GattDatastore tied to a particular server get_datastore(self: &Rc<Self>, server_id: ServerId) -> impl RawGattDatastore91 pub fn get_datastore(self: &Rc<Self>, server_id: ServerId) -> impl RawGattDatastore { 92 GattDatastoreImpl { callback_transaction_manager: self.clone(), server_id } 93 } 94 } 95 96 impl PendingTransactionsState { alloc_transaction_id(&mut self) -> TransactionId97 fn alloc_transaction_id(&mut self) -> TransactionId { 98 let trans_id = TransactionId(self.next_transaction_id); 99 self.next_transaction_id = self.next_transaction_id.wrapping_add(1); 100 trans_id 101 } 102 start_new_transaction(&mut self, conn_id: ConnectionId) -> PendingTransactionWatcher103 fn start_new_transaction(&mut self, conn_id: ConnectionId) -> PendingTransactionWatcher { 104 let trans_id = self.alloc_transaction_id(); 105 let (tx, rx) = oneshot::channel(); 106 self.pending_transactions.insert((conn_id, trans_id), PendingTransaction { response: tx }); 107 PendingTransactionWatcher { conn_id, trans_id, rx } 108 } 109 } 110 111 impl PendingTransactionWatcher { 112 /// Wait for the transaction to resolve, or to hit the timeout. If the 113 /// timeout is reached, clean up state related to transaction watching. wait(self, manager: &CallbackTransactionManager) -> Result<Vec<u8>, AttErrorCode>114 async fn wait(self, manager: &CallbackTransactionManager) -> Result<Vec<u8>, AttErrorCode> { 115 if let Ok(Ok(result)) = timeout(TIMEOUT, self.rx).await { 116 result 117 } else { 118 manager 119 .pending_transactions 120 .borrow_mut() 121 .pending_transactions 122 .remove(&(self.conn_id, self.trans_id)); 123 warn!("no response received from Java after timeout - returning UNLIKELY_ERROR"); 124 Err(AttErrorCode::UNLIKELY_ERROR) 125 } 126 } 127 } 128 129 struct GattDatastoreImpl { 130 callback_transaction_manager: Rc<CallbackTransactionManager>, 131 server_id: ServerId, 132 } 133 134 #[async_trait(?Send)] 135 impl RawGattDatastore for GattDatastoreImpl { read( &self, tcb_idx: TransportIndex, handle: AttHandle, offset: u32, attr_type: AttributeBackingType, ) -> Result<Vec<u8>, AttErrorCode>136 async fn read( 137 &self, 138 tcb_idx: TransportIndex, 139 handle: AttHandle, 140 offset: u32, 141 attr_type: AttributeBackingType, 142 ) -> Result<Vec<u8>, AttErrorCode> { 143 let conn_id = ConnectionId::new(tcb_idx, self.server_id); 144 145 let pending_transaction = self 146 .callback_transaction_manager 147 .pending_transactions 148 .borrow_mut() 149 .start_new_transaction(conn_id); 150 let trans_id = pending_transaction.trans_id; 151 152 self.callback_transaction_manager.callbacks.on_server_read( 153 ConnectionId::new(tcb_idx, self.server_id), 154 trans_id, 155 handle, 156 attr_type, 157 offset, 158 ); 159 160 pending_transaction.wait(&self.callback_transaction_manager).await 161 } 162 write( &self, tcb_idx: TransportIndex, handle: AttHandle, attr_type: AttributeBackingType, write_type: GattWriteRequestType, data: &[u8], ) -> Result<(), AttErrorCode>163 async fn write( 164 &self, 165 tcb_idx: TransportIndex, 166 handle: AttHandle, 167 attr_type: AttributeBackingType, 168 write_type: GattWriteRequestType, 169 data: &[u8], 170 ) -> Result<(), AttErrorCode> { 171 let conn_id = ConnectionId::new(tcb_idx, self.server_id); 172 173 let pending_transaction = self 174 .callback_transaction_manager 175 .pending_transactions 176 .borrow_mut() 177 .start_new_transaction(conn_id); 178 let trans_id = pending_transaction.trans_id; 179 180 self.callback_transaction_manager.callbacks.on_server_write( 181 conn_id, 182 trans_id, 183 handle, 184 attr_type, 185 GattWriteType::Request(write_type), 186 data, 187 ); 188 189 // the data passed back is irrelevant for write requests 190 pending_transaction.wait(&self.callback_transaction_manager).await.map(|_| ()) 191 } 192 write_no_response( &self, tcb_idx: TransportIndex, handle: AttHandle, attr_type: AttributeBackingType, data: &[u8], )193 fn write_no_response( 194 &self, 195 tcb_idx: TransportIndex, 196 handle: AttHandle, 197 attr_type: AttributeBackingType, 198 data: &[u8], 199 ) { 200 let conn_id = ConnectionId::new(tcb_idx, self.server_id); 201 202 let trans_id = self 203 .callback_transaction_manager 204 .pending_transactions 205 .borrow_mut() 206 .alloc_transaction_id(); 207 self.callback_transaction_manager.callbacks.on_server_write( 208 conn_id, 209 trans_id, 210 handle, 211 attr_type, 212 GattWriteType::Command, 213 data, 214 ); 215 } 216 execute( &self, tcb_idx: TransportIndex, decision: TransactionDecision, ) -> Result<(), AttErrorCode>217 async fn execute( 218 &self, 219 tcb_idx: TransportIndex, 220 decision: TransactionDecision, 221 ) -> Result<(), AttErrorCode> { 222 let conn_id = ConnectionId::new(tcb_idx, self.server_id); 223 224 let pending_transaction = self 225 .callback_transaction_manager 226 .pending_transactions 227 .borrow_mut() 228 .start_new_transaction(conn_id); 229 let trans_id = pending_transaction.trans_id; 230 231 self.callback_transaction_manager.callbacks.on_execute(conn_id, trans_id, decision); 232 233 // the data passed back is irrelevant for execute requests 234 pending_transaction.wait(&self.callback_transaction_manager).await.map(|_| ()) 235 } 236 } 237