1 //! libfmq Rust wrapper
2 
3 /*
4 * Copyright (C) 2024 The Android Open Source Project
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 *      http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18 
19 use fmq_bindgen::{
20     convertDesc, convertGrantor, descFlags, descGrantors, descHandleFDs, descHandleInts,
21     descHandleNumFDs, descHandleNumInts, descNumGrantors, descQuantum, freeDesc,
22     ndk_ScopedFileDescriptor, ErasedMessageQueue, ErasedMessageQueueDesc, GrantorDescriptor,
23     MQDescriptor, MemTransaction, NativeHandle, ParcelFileDescriptor, SynchronizedReadWrite,
24 };
25 
26 use std::ptr::addr_of_mut;
27 
28 use log::error;
29 
/// A trait indicating that a type is safe to pass through shared memory.
///
/// # Safety
///
/// This requires that the type must not contain any capabilities such as file
/// descriptors or heap allocations, and that it must be permitted to access
/// all bytes of its representation (so it must not contain any padding bytes).
///
/// Because being stored in shared memory allows the type to be accessed from
/// different processes, it may also be accessed from different threads in the
/// same process. As such, `Share` is a supertrait of `Sync`.
pub unsafe trait Share: Sync {}
42 
// Blanket implementation: every `Send + Sync` type that implements both
// `zerocopy::AsBytes` and `zerocopy::FromBytes` is `Share`.
//
// SAFETY: the `AsBytes` bound is relied upon for "every byte of the value may
// be read" and the `FromBytes` bound for "whatever bytes are found in shared
// memory form a valid value"; see the zerocopy documentation for the exact
// guarantees of those traits.
unsafe impl<T: zerocopy::AsBytes + zerocopy::FromBytes + Send + Sync> Share for T {}
45 
/// An IPC message queue for values of type T.
///
/// This is a typed wrapper around the bindgen'd, type-erased
/// `ErasedMessageQueue`; the element type `T` is tracked only at the type
/// level.
pub struct MessageQueue<T> {
    /// The underlying type-erased queue from `fmq_bindgen`.
    inner: ErasedMessageQueue,
    /// Zero-sized marker recording the element type `T`.
    ty: core::marker::PhantomData<T>,
}
51 
/** A write completion from the MessageQueue::write() method.

This completion mutably borrows the MessageQueue to prevent concurrent writes;
these must be forbidden because the underlying AidlMessageQueue only stores the
number of outstanding writes, not which have and have not completed, so they
must complete in order.

Dropping the completion commits the transaction to the queue. */
#[must_use]
pub struct WriteCompletion<'a, T: Share> {
    /// The memory transaction describing where the elements are to be stored.
    inner: MemTransaction,
    /// The queue the transaction belongs to; borrowed mutably for as long as
    /// the completion is alive.
    queue: &'a mut MessageQueue<T>,
    /// Number of elements the transaction was opened for.
    n_elems: usize,
    /// Number of elements recorded as written so far.
    n_written: usize,
}
65 
66 impl<'a, T: Share> WriteCompletion<'a, T> {
67     /// Obtain a pointer to the location at which the idx'th item should be
68     /// stored.
69     ///
70     /// The returned pointer is only valid while `self` has not been dropped and
71     /// is invalidated by any call to `self.write`. The pointer should be used
72     /// with `std::ptr::write` or a DMA API to initialize the underlying storage
73     /// before calling `assume_written` to indicate how many elements were
74     /// written.
75     ///
76     /// It is only permitted to access at most `contiguous_count(idx)` items
77     /// via offsets from the returned address.
78     ///
79     /// Calling this method with a greater `idx` may return a pointer to another
80     /// memory region of different size than the first.
ptr(&self, idx: usize) -> *mut T81     pub fn ptr(&self, idx: usize) -> *mut T {
82         if idx >= self.n_elems {
83             panic!(
84                 "indexing out of bound: ReadCompletion for {} elements but idx {} accessed",
85                 self.n_elems, idx
86             )
87         }
88         ptr(&self.inner, idx)
89     }
90 
91     /// Return the number of contiguous elements that may be stored starting at
92     /// the given index in the backing buffer corresponding to the given index.
93     ///
94     /// Intended for use with the `ptr` method.
95     ///
96     /// Returns 0 if `idx` is greater than or equal to the completion's element
97     /// count.
contiguous_count(&self, idx: usize) -> usize98     pub fn contiguous_count(&self, idx: usize) -> usize {
99         contiguous_count(&self.inner, idx, self.n_elems)
100     }
101 
102     /// Returns how many elements still must be written to this WriteCompletion
103     /// before dropping it.
required_elements(&self) -> usize104     pub fn required_elements(&self) -> usize {
105         assert!(self.n_written <= self.n_elems);
106         self.n_elems - self.n_written
107     }
108 
109     /// Write one item to `self`. Fails and returns the item if `self` is full.
write(&mut self, data: T) -> Result<(), T>110     pub fn write(&mut self, data: T) -> Result<(), T> {
111         if self.required_elements() > 0 {
112             // SAFETY: `self.ptr(self.n_written)` is known to be uninitialized.
113             // The dtor of data, if any, will not run because `data` is moved
114             // out of here.
115             unsafe { self.ptr(self.n_written).write(data) };
116             self.n_written += 1;
117             Ok(())
118         } else {
119             Err(data)
120         }
121     }
122 
123     /// Promise to the `WriteCompletion` that `n_newly_written` elements have
124     /// been written with unsafe code or DMA to the pointer returned by the
125     /// `ptr` method.
126     ///
127     /// Panics if `n_newly_written` exceeds the number of elements yet required.
128     ///
129     /// # Safety
130     /// It is UB to call this method except after calling the `ptr` method and
131     /// writing the specified number of values of type T to that location.
assume_written(&mut self, n_newly_written: usize)132     pub unsafe fn assume_written(&mut self, n_newly_written: usize) {
133         assert!(n_newly_written < self.required_elements());
134         self.n_written += n_newly_written;
135     }
136 }
137 
138 impl<'a, T: Share> Drop for WriteCompletion<'a, T> {
drop(&mut self)139     fn drop(&mut self) {
140         if self.n_written < self.n_elems {
141             error!(
142                 "WriteCompletion dropped without writing to all elements ({}/{} written)",
143                 self.n_written, self.n_elems
144             );
145         }
146         let txn = std::mem::take(&mut self.inner);
147         self.queue.commit_write(txn);
148     }
149 }
150 
impl<T: Share> MessageQueue<T> {
    /// Size in bytes of one queue element; passed to the erased queue as its
    /// quantum.
    const fn type_size() -> usize {
        std::mem::size_of::<T>()
    }

    /// Create a new MessageQueue with capacity for `elems` elements.
    ///
    /// `event_word` is forwarded unchanged to the bindgen'd constructor.
    // NOTE(review): presumably `event_word` selects whether an event flag word
    // is set up for the queue — confirm against the C++ side.
    pub fn new(elems: usize, event_word: bool) -> Self {
        Self {
            // SAFETY: Calling bindgen'd constructor. The only argument that
            // can't be validated by the implementation is the quantum, which
            // must equal the element size.
            inner: unsafe { ErasedMessageQueue::new1(elems, event_word, Self::type_size()) },
            ty: core::marker::PhantomData,
        }
    }

    /// Create a MessageQueue connected to another existing instance from its
    /// descriptor.
    ///
    /// The descriptor is converted into its C++ representation, used to
    /// construct the erased queue, and the temporary C++ descriptor is freed
    /// again before returning. `reset_pointers` is forwarded unchanged to the
    /// bindgen'd constructor.
    // NOTE(review): `desc.quantum` is forwarded without being compared to
    // `size_of::<T>()`; presumably callers guarantee they match — confirm.
    pub fn from_desc(desc: &MQDescriptor<T, SynchronizedReadWrite>, reset_pointers: bool) -> Self {
        // Convert every Rust-side grantor into its C++ counterpart.
        let mut grantors = desc
            .grantors
            .iter()
            // SAFETY: this just forwards the integers to the GrantorDescriptor
            // constructor; GrantorDescriptor is POD.
            .map(|g| unsafe { convertGrantor(g.fdIndex, g.offset, g.extent) })
            .collect::<Vec<_>>();

        // SAFETY: These pointer/length pairs come from Vecs that will outlive
        // this function call, and the call itself copies all data it needs out
        // of them.
        let cpp_desc = unsafe {
            convertDesc(
                grantors.as_mut_ptr(),
                grantors.len(),
                desc.handle.fds.as_ptr().cast(),
                desc.handle.fds.len(),
                desc.handle.ints.as_ptr(),
                desc.handle.ints.len(),
                desc.quantum,
                desc.flags,
            )
        };
        // SAFETY: Calling bindgen'd constructor which does not store cpp_desc,
        // but just passes it to the initializer of AidlMQDescriptorShim, which
        // deep-copies it.
        let inner = unsafe { ErasedMessageQueue::new(cpp_desc, reset_pointers) };
        // SAFETY: we must free the desc returned by convertDesc; the pointer
        // was just returned above so we know it is valid.
        unsafe { freeDesc(cpp_desc) };
        Self { inner, ty: core::marker::PhantomData }
    }

    /// Obtain a copy of the MessageQueue's descriptor, which may be used to
    /// access it remotely.
    ///
    /// The grantors, handle ints, quantum and flags are copied out of a
    /// temporary C++ descriptor and every file descriptor is duplicated, so
    /// the returned value does not borrow from the temporary, which is freed
    /// before returning.
    pub fn dupe_desc(&mut self) -> MQDescriptor<T, SynchronizedReadWrite> {
        // SAFETY: dupeDesc may be called on any valid ErasedMessageQueue; it
        // simply forwards to dupeDesc on the inner AidlMessageQueue and wraps
        // in a heap allocation.
        let erased_desc: *mut ErasedMessageQueueDesc = unsafe { self.inner.dupeDesc() };
        // Field-by-field conversion of a C++ grantor into the Rust type.
        let grantor_to_rust =
            |g: &fmq_bindgen::aidl_android_hardware_common_fmq_GrantorDescriptor| {
                GrantorDescriptor { fdIndex: g.fdIndex, offset: g.offset, extent: g.extent }
            };

        // Duplicates a C++-owned fd into a `ParcelFileDescriptor` owned by the
        // Rust side.
        let scoped_to_parcel_fd = |fd: &ndk_ScopedFileDescriptor| {
            use std::os::fd::{BorrowedFd, FromRawFd, OwnedFd};
            // SAFETY: the fd is already open as an invariant of ndk::ScopedFileDescriptor, so
            // it will not be -1, as required by BorrowedFd.
            let borrowed = unsafe { BorrowedFd::borrow_raw(fd._base as i32) };
            ParcelFileDescriptor::new(match borrowed.try_clone_to_owned() {
                Ok(fd) => fd,
                Err(e) => {
                    error!("could not dup NativeHandle fd {}: {}", fd._base, e);
                    // SAFETY: OwnedFd requires the fd is not -1. If we failed to dup the fd,
                    // other code downstream will fail, but we can do no better than pass it on.
                    // NOTE(review): this wraps the original (not duplicated) fd in an
                    // `OwnedFd`; if `erased_desc` still owns that fd it would presumably
                    // be closed twice — confirm.
                    unsafe { OwnedFd::from_raw_fd(fd._base as i32) }
                }
            })
        };

        // First, we create a desc with the wrong type, because we cannot create one whole cloth of
        // our desired return type unless T implements Default. This Default requirement is
        // superfluous (T::default() is never called), so we then transmute to our desired type.
        let desc = MQDescriptor::<(), SynchronizedReadWrite>::default();
        // SAFETY: This transmute changes only the element type parameter of the MQDescriptor. The
        // layout of an MQDescriptor does not depend on T as T appears in it only in PhantomData.
        let mut desc: MQDescriptor<T, SynchronizedReadWrite> = unsafe { std::mem::transmute(desc) };
        // SAFETY: These slices are created out of the pointer and length pairs exposed by the
        // individual descFoo accessors, so we know they are valid pointer/lengths and point to
        // data that will continue to exist as long as the desc does.
        //
        // Calls to the descFoo accessors on erased_desc are sound because we know inner.dupeDesc
        // returns a valid pointer to a new heap-allocated ErasedMessageQueueDesc.
        let (grantors, fds, ints, quantum, flags) = unsafe {
            use std::slice::from_raw_parts;
            let grantors = from_raw_parts(descGrantors(erased_desc), descNumGrantors(erased_desc));
            let fds = from_raw_parts(descHandleFDs(erased_desc), descHandleNumFDs(erased_desc));
            let ints = from_raw_parts(descHandleInts(erased_desc), descHandleNumInts(erased_desc));
            let quantum = descQuantum(erased_desc);
            let flags = descFlags(erased_desc);
            (grantors, fds, ints, quantum, flags)
        };
        // Everything borrowed from `erased_desc` is copied (or dup'ed) into
        // owned Rust values here, before `erased_desc` is freed below.
        let fds = fds.iter().map(scoped_to_parcel_fd).collect();
        let ints = ints.to_vec();
        desc.grantors = grantors.iter().map(grantor_to_rust).collect();
        desc.handle = NativeHandle { fds, ints };
        desc.quantum = quantum;
        desc.flags = flags;
        // SAFETY: we must free the desc returned by dupeDesc; the pointer was
        // just returned above so we know it is valid.
        unsafe { freeDesc(erased_desc) };
        desc
    }

    /// Begin a write transaction for a single item. The returned
    /// WriteCompletion can be used to write into the region allocated for the
    /// transaction.
    ///
    /// Returns `None` if the transaction could not be started.
    pub fn write(&mut self) -> Option<WriteCompletion<T>> {
        self.write_many(1)
    }

    /// Begin a write transaction for multiple items. See `MessageQueue::write`.
    ///
    /// Returns `None` if a transaction for `n` items could not be started.
    pub fn write_many(&mut self, n: usize) -> Option<WriteCompletion<T>> {
        let txn = self.begin_write(n)?;
        Some(WriteCompletion { inner: txn, queue: self, n_elems: n, n_written: 0 })
    }

    /// Commit a write transaction, returning the result reported by the
    /// underlying queue.
    fn commit_write(&mut self, txn: MemTransaction) -> bool {
        // SAFETY: simply calls commitWrite with the combined length of the
        // txn's two MemRegions; only an integer is passed to the C++ side.
        unsafe { self.inner.commitWrite(txn.first.length + txn.second.length) }
    }

    /// Ask the underlying queue to start a write transaction for `n` items,
    /// returning the transaction if the queue reports success.
    fn begin_write(&self, n: usize) -> Option<MemTransaction> {
        let mut txn: MemTransaction = Default::default();
        // SAFETY: we pass a raw pointer to txn, which is used only during the
        // call to beginWrite to write the txn's MemRegion fields, which are raw
        // pointers and lengths pointing into the queue. The pointer to txn is
        // not stored.
        unsafe { self.inner.beginWrite(n, addr_of_mut!(txn)) }.then_some(txn)
    }
}
292 
293 #[inline(always)]
ptr<T: Share>(txn: &MemTransaction, idx: usize) -> *mut T294 fn ptr<T: Share>(txn: &MemTransaction, idx: usize) -> *mut T {
295     let (base, region_idx) = if idx < txn.first.length {
296         (txn.first.address, idx)
297     } else {
298         (txn.second.address, idx - txn.first.length)
299     };
300     let byte_count: usize = region_idx.checked_mul(MessageQueue::<T>::type_size()).unwrap();
301     base.wrapping_byte_offset(byte_count.try_into().unwrap()) as *mut T
302 }
303 
304 #[inline(always)]
contiguous_count(txn: &MemTransaction, idx: usize, n_elems: usize) -> usize305 fn contiguous_count(txn: &MemTransaction, idx: usize, n_elems: usize) -> usize {
306     if idx > n_elems {
307         return 0;
308     }
309     let region_len = if idx < txn.first.length { txn.first.length } else { txn.second.length };
310     region_len - idx
311 }
312 
/** A read completion from the MessageQueue::read() method.

This completion mutably borrows the MessageQueue to prevent concurrent reads;
these must be forbidden because the underlying AidlMessageQueue only stores the
number of outstanding reads, not which have and have not completed, so they
must complete in order.

Dropping the completion commits the transaction to the queue. */
#[must_use]
pub struct ReadCompletion<'a, T: Share> {
    /// The memory transaction describing where the elements are located.
    inner: MemTransaction,
    /// The queue the transaction belongs to; borrowed mutably for as long as
    /// the completion is alive.
    queue: &'a mut MessageQueue<T>,
    /// Number of elements the transaction was opened for.
    n_elems: usize,
    /// Number of elements recorded as read so far.
    n_read: usize,
}
326 
327 impl<'a, T: Share> ReadCompletion<'a, T> {
328     /// Obtain a pointer to the location at which the idx'th item is located.
329     ///
330     /// The returned pointer is only valid while `self` has not been dropped and
331     /// is invalidated by any call to `self.read`. The pointer should be used
332     /// with `std::ptr::read` or a DMA API before calling `assume_read` to
333     /// indicate how many elements were read.
334     ///
335     /// It is only permitted to access at most `contiguous_count(idx)` items
336     /// via offsets from the returned address.
337     ///
338     /// Calling this method with a greater `idx` may return a pointer to another
339     /// memory region of different size than the first.
ptr(&self, idx: usize) -> *mut T340     pub fn ptr(&self, idx: usize) -> *mut T {
341         if idx >= self.n_elems {
342             panic!(
343                 "indexing out of bound: ReadCompletion for {} elements but idx {} accessed",
344                 self.n_elems, idx
345             )
346         }
347         ptr(&self.inner, idx)
348     }
349 
350     /// Return the number of contiguous elements located starting at the given
351     /// index in the backing buffer corresponding to the given index.
352     ///
353     /// Intended for use with the `ptr` method.
354     ///
355     /// Returns 0 if `idx` is greater than or equal to the completion's element
356     /// count.
contiguous_count(&self, idx: usize) -> usize357     pub fn contiguous_count(&self, idx: usize) -> usize {
358         contiguous_count(&self.inner, idx, self.n_elems)
359     }
360 
361     /// Returns how many elements still must be read from `self` before dropping
362     /// it.
unread_elements(&self) -> usize363     pub fn unread_elements(&self) -> usize {
364         assert!(self.n_read <= self.n_elems);
365         self.n_elems - self.n_read
366     }
367 
368     /// Read one item from the `self`. Fails and returns `()` if `self` is empty.
read(&mut self) -> Option<T>369     pub fn read(&mut self) -> Option<T> {
370         if self.unread_elements() > 0 {
371             // SAFETY: `self.ptr(self.n_read)`is known to be filled with a valid
372             // instance of type `T`.
373             let data = unsafe { self.ptr(self.n_read).read() };
374             self.n_read += 1;
375             Some(data)
376         } else {
377             None
378         }
379     }
380 
381     /// Promise to the `ReadCompletion` that `n_newly_read` elements have
382     /// been read with unsafe code or DMA from the pointer returned by the
383     /// `ptr` method.
384     ///
385     /// Panics if `n_newly_read` exceeds the number of elements still unread.
386     ///
387     /// Calling this method without actually reading the elements will result
388     /// in them being leaked without destructors (if any) running.
assume_read(&mut self, n_newly_read: usize)389     pub fn assume_read(&mut self, n_newly_read: usize) {
390         assert!(n_newly_read < self.unread_elements());
391         self.n_read += n_newly_read;
392     }
393 }
394 
395 impl<'a, T: Share> Drop for ReadCompletion<'a, T> {
drop(&mut self)396     fn drop(&mut self) {
397         if self.n_read < self.n_elems {
398             error!(
399                 "ReadCompletion dropped without reading all elements ({}/{} read)",
400                 self.n_read, self.n_elems
401             );
402         }
403         let txn = std::mem::take(&mut self.inner);
404         self.queue.commit_read(txn);
405     }
406 }
407 
408 impl<T: Share> MessageQueue<T> {
409     /// Begin a read transaction. The returned `ReadCompletion` can be used to
410     /// write into the region allocated for the transaction.
read(&mut self) -> Option<ReadCompletion<T>>411     pub fn read(&mut self) -> Option<ReadCompletion<T>> {
412         self.read_many(1)
413     }
414 
415     /// Begin a read transaction for multiple items. See `MessageQueue::read`.
read_many(&mut self, n: usize) -> Option<ReadCompletion<T>>416     pub fn read_many(&mut self, n: usize) -> Option<ReadCompletion<T>> {
417         let txn = self.begin_read(n)?;
418         Some(ReadCompletion { inner: txn, queue: self, n_elems: n, n_read: 0 })
419     }
420 
commit_read(&mut self, txn: MemTransaction) -> bool421     fn commit_read(&mut self, txn: MemTransaction) -> bool {
422         // SAFETY: simply calls commitRead with the txn length. The txn must
423         // only use its first MemRegion.
424         unsafe { self.inner.commitRead(txn.first.length + txn.second.length) }
425     }
426 
begin_read(&self, n: usize) -> Option<MemTransaction>427     fn begin_read(&self, n: usize) -> Option<MemTransaction> {
428         let mut txn: MemTransaction = Default::default();
429         // SAFETY: we pass a raw pointer to txn, which is used only during the
430         // call to beginRead to write the txn's MemRegion fields, which are raw
431         // pointers and lengths pointing into the queue. The pointer to txn is
432         // not stored.
433         unsafe { self.inner.beginRead(n, addr_of_mut!(txn)) }.then_some(txn)
434     }
435 }
436