1 use std::{cell::RefCell, collections::HashMap, rc::Rc, time::Duration};
2 
3 use async_trait::async_trait;
4 use log::{trace, warn};
5 use tokio::{sync::oneshot, time::timeout};
6 
7 use crate::{
8     gatt::{
9         ids::{AttHandle, ConnectionId, ServerId, TransactionId, TransportIndex},
10         GattCallbacks,
11     },
12     packets::AttErrorCode,
13 };
14 
15 use super::{
16     AttributeBackingType, GattWriteRequestType, GattWriteType, RawGattDatastore,
17     TransactionDecision,
18 };
19 
/// Sending half of a pending transaction. It is kept in the pending map until
/// `CallbackTransactionManager::send_response` consumes it, or until the
/// waiting side removes it after a timeout.
struct PendingTransaction {
    /// Delivers the upper layer's result to the task awaiting this transaction.
    response: oneshot::Sender<Result<Vec<u8>, AttErrorCode>>,
}
23 
/// Receiving half of a pending transaction, held by the datastore operation
/// that is waiting for the response.
///
/// `(conn_id, trans_id)` is the key of the matching entry in
/// `PendingTransactionsState::pending_transactions`.
#[derive(Debug)]
struct PendingTransactionWatcher {
    conn_id: ConnectionId,
    trans_id: TransactionId,
    /// Resolves with the value passed to `send_response`.
    rx: oneshot::Receiver<Result<Vec<u8>, AttErrorCode>>,
}
30 
/// This struct converts the asynchronous read/write operations of
/// `RawGattDatastore` into the callback-based interface expected by JNI.
///
/// Each operation is announced through `callbacks` together with a fresh
/// `TransactionId`; the answer is later delivered through `send_response`.
pub struct CallbackTransactionManager {
    /// Upper-layer callbacks notified of each read/write/execute request.
    callbacks: Rc<dyn GattCallbacks>,
    /// Transactions still awaiting a response, plus the transaction id counter.
    pending_transactions: RefCell<PendingTransactionsState>,
}
37 
/// Mutable bookkeeping shared by every datastore obtained from one
/// `CallbackTransactionManager`.
struct PendingTransactionsState {
    /// Response channels, keyed by the connection and transaction they belong to.
    pending_transactions: HashMap<(ConnectionId, TransactionId), PendingTransaction>,
    /// Raw value of the next `TransactionId` to hand out; starts at 1 and wraps
    /// on overflow.
    next_transaction_id: u32,
}
42 
/// Upper bound on how long a transaction may wait for its response.
///
/// Keep this comfortably below the 30 s ATT transaction timeout: once that
/// expires, the client disconnects.
const TIMEOUT: Duration = Duration::from_millis(15_000);
47 
/// The cause of a failure to dispatch a call to `send_response()`.
#[derive(Debug, PartialEq, Eq)]
pub enum CallbackResponseError {
    /// No pending transaction matches this TransactionId on the specified
    /// connection: it was never issued, was already answered, or was removed
    /// after timing out.
    NonExistentTransaction(TransactionId),
    /// The transaction was pending, but the task awaiting it has gone away
    /// (its receiver was dropped), so the response was discarded.
    ListenerHungUp(TransactionId),
}
56 
57 impl CallbackTransactionManager {
58     /// Constructor, wrapping a GattCallbacks instance with the GattDatastore
59     /// interface
new(callbacks: Rc<dyn GattCallbacks>) -> Self60     pub fn new(callbacks: Rc<dyn GattCallbacks>) -> Self {
61         Self {
62             callbacks,
63             pending_transactions: RefCell::new(PendingTransactionsState {
64                 pending_transactions: HashMap::new(),
65                 next_transaction_id: 1,
66             }),
67         }
68     }
69 
70     /// Invoked from server implementations in response to read/write requests
send_response( &self, conn_id: ConnectionId, trans_id: TransactionId, value: Result<Vec<u8>, AttErrorCode>, ) -> Result<(), CallbackResponseError>71     pub fn send_response(
72         &self,
73         conn_id: ConnectionId,
74         trans_id: TransactionId,
75         value: Result<Vec<u8>, AttErrorCode>,
76     ) -> Result<(), CallbackResponseError> {
77         let mut pending = self.pending_transactions.borrow_mut();
78         if let Some(transaction) = pending.pending_transactions.remove(&(conn_id, trans_id)) {
79             if transaction.response.send(value).is_err() {
80                 Err(CallbackResponseError::ListenerHungUp(trans_id))
81             } else {
82                 trace!("got expected response for transaction {trans_id:?}");
83                 Ok(())
84             }
85         } else {
86             Err(CallbackResponseError::NonExistentTransaction(trans_id))
87         }
88     }
89 
90     /// Get an impl GattDatastore tied to a particular server
get_datastore(self: &Rc<Self>, server_id: ServerId) -> impl RawGattDatastore91     pub fn get_datastore(self: &Rc<Self>, server_id: ServerId) -> impl RawGattDatastore {
92         GattDatastoreImpl { callback_transaction_manager: self.clone(), server_id }
93     }
94 }
95 
96 impl PendingTransactionsState {
alloc_transaction_id(&mut self) -> TransactionId97     fn alloc_transaction_id(&mut self) -> TransactionId {
98         let trans_id = TransactionId(self.next_transaction_id);
99         self.next_transaction_id = self.next_transaction_id.wrapping_add(1);
100         trans_id
101     }
102 
start_new_transaction(&mut self, conn_id: ConnectionId) -> PendingTransactionWatcher103     fn start_new_transaction(&mut self, conn_id: ConnectionId) -> PendingTransactionWatcher {
104         let trans_id = self.alloc_transaction_id();
105         let (tx, rx) = oneshot::channel();
106         self.pending_transactions.insert((conn_id, trans_id), PendingTransaction { response: tx });
107         PendingTransactionWatcher { conn_id, trans_id, rx }
108     }
109 }
110 
111 impl PendingTransactionWatcher {
112     /// Wait for the transaction to resolve, or to hit the timeout. If the
113     /// timeout is reached, clean up state related to transaction watching.
wait(self, manager: &CallbackTransactionManager) -> Result<Vec<u8>, AttErrorCode>114     async fn wait(self, manager: &CallbackTransactionManager) -> Result<Vec<u8>, AttErrorCode> {
115         if let Ok(Ok(result)) = timeout(TIMEOUT, self.rx).await {
116             result
117         } else {
118             manager
119                 .pending_transactions
120                 .borrow_mut()
121                 .pending_transactions
122                 .remove(&(self.conn_id, self.trans_id));
123             warn!("no response received from Java after timeout - returning UNLIKELY_ERROR");
124             Err(AttErrorCode::UNLIKELY_ERROR)
125         }
126     }
127 }
128 
/// `RawGattDatastore` implementation bound to a single GATT server. Each
/// request is forwarded to the shared manager's callbacks as a transaction.
struct GattDatastoreImpl {
    /// Shared with every other datastore obtained from the same manager.
    callback_transaction_manager: Rc<CallbackTransactionManager>,
    /// Combined with the transport index to form each request's `ConnectionId`.
    server_id: ServerId,
}
133 
134 #[async_trait(?Send)]
135 impl RawGattDatastore for GattDatastoreImpl {
read( &self, tcb_idx: TransportIndex, handle: AttHandle, offset: u32, attr_type: AttributeBackingType, ) -> Result<Vec<u8>, AttErrorCode>136     async fn read(
137         &self,
138         tcb_idx: TransportIndex,
139         handle: AttHandle,
140         offset: u32,
141         attr_type: AttributeBackingType,
142     ) -> Result<Vec<u8>, AttErrorCode> {
143         let conn_id = ConnectionId::new(tcb_idx, self.server_id);
144 
145         let pending_transaction = self
146             .callback_transaction_manager
147             .pending_transactions
148             .borrow_mut()
149             .start_new_transaction(conn_id);
150         let trans_id = pending_transaction.trans_id;
151 
152         self.callback_transaction_manager.callbacks.on_server_read(
153             ConnectionId::new(tcb_idx, self.server_id),
154             trans_id,
155             handle,
156             attr_type,
157             offset,
158         );
159 
160         pending_transaction.wait(&self.callback_transaction_manager).await
161     }
162 
write( &self, tcb_idx: TransportIndex, handle: AttHandle, attr_type: AttributeBackingType, write_type: GattWriteRequestType, data: &[u8], ) -> Result<(), AttErrorCode>163     async fn write(
164         &self,
165         tcb_idx: TransportIndex,
166         handle: AttHandle,
167         attr_type: AttributeBackingType,
168         write_type: GattWriteRequestType,
169         data: &[u8],
170     ) -> Result<(), AttErrorCode> {
171         let conn_id = ConnectionId::new(tcb_idx, self.server_id);
172 
173         let pending_transaction = self
174             .callback_transaction_manager
175             .pending_transactions
176             .borrow_mut()
177             .start_new_transaction(conn_id);
178         let trans_id = pending_transaction.trans_id;
179 
180         self.callback_transaction_manager.callbacks.on_server_write(
181             conn_id,
182             trans_id,
183             handle,
184             attr_type,
185             GattWriteType::Request(write_type),
186             data,
187         );
188 
189         // the data passed back is irrelevant for write requests
190         pending_transaction.wait(&self.callback_transaction_manager).await.map(|_| ())
191     }
192 
write_no_response( &self, tcb_idx: TransportIndex, handle: AttHandle, attr_type: AttributeBackingType, data: &[u8], )193     fn write_no_response(
194         &self,
195         tcb_idx: TransportIndex,
196         handle: AttHandle,
197         attr_type: AttributeBackingType,
198         data: &[u8],
199     ) {
200         let conn_id = ConnectionId::new(tcb_idx, self.server_id);
201 
202         let trans_id = self
203             .callback_transaction_manager
204             .pending_transactions
205             .borrow_mut()
206             .alloc_transaction_id();
207         self.callback_transaction_manager.callbacks.on_server_write(
208             conn_id,
209             trans_id,
210             handle,
211             attr_type,
212             GattWriteType::Command,
213             data,
214         );
215     }
216 
execute( &self, tcb_idx: TransportIndex, decision: TransactionDecision, ) -> Result<(), AttErrorCode>217     async fn execute(
218         &self,
219         tcb_idx: TransportIndex,
220         decision: TransactionDecision,
221     ) -> Result<(), AttErrorCode> {
222         let conn_id = ConnectionId::new(tcb_idx, self.server_id);
223 
224         let pending_transaction = self
225             .callback_transaction_manager
226             .pending_transactions
227             .borrow_mut()
228             .start_new_transaction(conn_id);
229         let trans_id = pending_transaction.trans_id;
230 
231         self.callback_transaction_manager.callbacks.on_execute(conn_id, trans_id, decision);
232 
233         // the data passed back is irrelevant for execute requests
234         pending_transaction.wait(&self.callback_transaction_manager).await.map(|_| ())
235     }
236 }
237