diff --git a/rtic-sync/Cargo.toml b/rtic-sync/Cargo.toml
index 60d8be2a839..cb54eef8862 100644
--- a/rtic-sync/Cargo.toml
+++ b/rtic-sync/Cargo.toml
@@ -25,15 +25,23 @@ portable-atomic = { version = "1", default-features = false }
 embedded-hal = { version = "1.0.0" }
 embedded-hal-async = { version = "1.0.0" }
 embedded-hal-bus = { version = "0.2.0", features = ["async"] }
-
 defmt-03 = { package = "defmt", version = "0.3", optional = true }
 
 [dev-dependencies]
 cassette = "0.3.0"
 static_cell = "2.1.0"
-tokio = { version = "1", features = ["rt", "macros", "time"] }
+
+[target.'cfg(not(loom))'.dev-dependencies]
+tokio = { version = "1", features = ["rt", "macros", "time"], default-features = false }
 
 [features]
 default = []
 testing = ["critical-section/std", "rtic-common/testing"]
 defmt-03 = ["dep:defmt-03", "embedded-hal/defmt-03", "embedded-hal-async/defmt-03", "embedded-hal-bus/defmt-03"]
+
+[lints.rust]
+unexpected_cfgs = { level = "allow", check-cfg = ['cfg(loom)'] }
+
+[target.'cfg(loom)'.dependencies]
+loom = { version = "0.7.2", features = [ "futures" ] }
+critical-section = { version = "1", features = [ "restore-state-bool" ] }
diff --git a/rtic-sync/src/arbiter.rs b/rtic-sync/src/arbiter.rs
index 768e2000c98..60559dffab8 100644
--- a/rtic-sync/src/arbiter.rs
+++ b/rtic-sync/src/arbiter.rs
@@ -381,6 +381,7 @@ pub mod i2c {
     }
 }
 
+#[cfg(not(loom))]
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/rtic-sync/src/channel.rs b/rtic-sync/src/channel.rs
index 0bd2cd26926..9c2111fd098 100644
--- a/rtic-sync/src/channel.rs
+++ b/rtic-sync/src/channel.rs
@@ -1,7 +1,7 @@
 //! An async aware MPSC channel that can be used on no-alloc systems.
 
+use crate::unsafecell::UnsafeCell;
 use core::{
-    cell::UnsafeCell,
     future::poll_fn,
     mem::MaybeUninit,
     pin::Pin,
@@ -48,11 +48,21 @@ unsafe impl<T, const N: usize> Send for Channel<T, N> {}
 
 unsafe impl<T, const N: usize> Sync for Channel<T, N> {}
 
-struct UnsafeAccess<'a, const N: usize> {
-    freeq: &'a mut Deque<u8, N>,
-    readyq: &'a mut Deque<u8, N>,
-    receiver_dropped: &'a mut bool,
-    num_senders: &'a mut usize,
+macro_rules! cs_access {
+    ($name:ident, $type:ty) => {
+        /// Access the value mutably.
+        ///
+        /// SAFETY: this function must not be called recursively within `f`.
+        unsafe fn $name<F, R>(&self, _cs: critical_section::CriticalSection, f: F) -> R
+        where
+            F: FnOnce(&mut $type) -> R,
+        {
+            self.$name.with_mut(|v| {
+                let v = unsafe { &mut *v };
+                f(v)
+            })
+        }
+    };
 }
 
 impl<T, const N: usize> Default for Channel<T, N> {
@@ -65,6 +75,7 @@ impl<T, const N: usize> Channel<T, N> {
     const _CHECK: () = assert!(N < 256, "This queue support a maximum of 255 entries");
 
     /// Create a new channel.
+    #[cfg(not(loom))]
     pub const fn new() -> Self {
         Self {
             freeq: UnsafeCell::new(Deque::new()),
@@ -77,37 +88,49 @@ impl<T, const N: usize> Channel<T, N> {
         }
     }
 
+    /// Create a new channel.
+    #[cfg(loom)]
+    pub fn new() -> Self {
+        Self {
+            freeq: UnsafeCell::new(Deque::new()),
+            readyq: UnsafeCell::new(Deque::new()),
+            receiver_waker: WakerRegistration::new(),
+            slots: core::array::from_fn(|_| UnsafeCell::new(MaybeUninit::uninit())),
+            wait_queue: WaitQueue::new(),
+            receiver_dropped: UnsafeCell::new(false),
+            num_senders: UnsafeCell::new(0),
+        }
+    }
+
     /// Split the queue into a `Sender`/`Receiver` pair.
     pub fn split(&mut self) -> (Sender<'_, T, N>, Receiver<'_, T, N>) {
+        // SAFETY: we have exclusive access to `self`.
+        let freeq = self.freeq.get_mut();
+        let freeq = unsafe { freeq.deref() };
+
         // Fill free queue
         for idx in 0..N as u8 {
-            assert!(!self.freeq.get_mut().is_full());
+            assert!(!freeq.is_full());
 
             // SAFETY: This safe as the loop goes from 0 to the capacity of the underlying queue.
             unsafe {
-                self.freeq.get_mut().push_back_unchecked(idx);
+                freeq.push_back_unchecked(idx);
             }
         }
 
-        assert!(self.freeq.get_mut().is_full());
+        assert!(freeq.is_full());
 
         // There is now 1 sender
-        *self.num_senders.get_mut() = 1;
+        // SAFETY: we have exclusive access to `self`.
+        unsafe { *self.num_senders.get_mut().deref() = 1 };
 
         (Sender(self), Receiver(self))
     }
 
-    fn access<'a>(&'a self, _cs: critical_section::CriticalSection) -> UnsafeAccess<'a, N> {
-        // SAFETY: This is safe as are in a critical section.
-        unsafe {
-            UnsafeAccess {
-                freeq: &mut *self.freeq.get(),
-                readyq: &mut *self.readyq.get(),
-                receiver_dropped: &mut *self.receiver_dropped.get(),
-                num_senders: &mut *self.num_senders.get(),
-            }
-        }
-    }
+    cs_access!(freeq, Deque<u8, N>);
+    cs_access!(readyq, Deque<u8, N>);
+    cs_access!(receiver_dropped, bool);
+    cs_access!(num_senders, usize);
 
     /// Return free slot `slot` to the channel.
     ///
@@ -127,8 +150,14 @@
             unsafe { freeq_slot.replace(Some(slot), cs) };
             wait_head.wake();
         } else {
-            assert!(!self.access(cs).freeq.is_full());
-            unsafe { self.access(cs).freeq.push_back_unchecked(slot) }
+            // SAFETY: `self.freeq` is not called recursively.
+            unsafe {
+                self.freeq(cs, |freeq| {
+                    assert!(!freeq.is_full());
+                    // SAFETY: `freeq` is not full.
+                    freeq.push_back_unchecked(slot);
+                });
+            }
         }
     })
 }
@@ -136,6 +165,7 @@
 
 /// Creates a split channel with `'static` lifetime.
 #[macro_export]
+#[cfg(not(loom))]
 macro_rules! make_channel {
     ($type:ty, $size:expr) => {{
         static mut CHANNEL: $crate::channel::Channel<$type, $size> =
@@ -285,16 +315,21 @@ impl<T, const N: usize> Sender<'_, T, N> {
     fn send_footer(&mut self, idx: u8, val: T) {
         // Write the value to the slots, note; this memcpy is not under a critical section.
         unsafe {
-            ptr::write(
-                self.0.slots.get_unchecked(idx as usize).get() as *mut T,
-                val,
-            )
+            let first_element = self.0.slots.get_unchecked(idx as usize).get_mut();
+            let ptr = first_element.deref().as_mut_ptr();
+            ptr::write(ptr, val)
         }
 
         // Write the value into the ready queue.
         critical_section::with(|cs| {
-            assert!(!self.0.access(cs).readyq.is_full());
-            unsafe { self.0.access(cs).readyq.push_back_unchecked(idx) }
+            // SAFETY: `self.0.readyq` is not called recursively.
+            unsafe {
+                self.0.readyq(cs, |readyq| {
+                    assert!(!readyq.is_full());
+                    // SAFETY: `readyq` is not full.
+                    readyq.push_back_unchecked(idx);
+                });
+            }
         });
 
         fence(Ordering::SeqCst);
@@ -315,12 +350,16 @@ impl<T, const N: usize> Sender<'_, T, N> {
             return Err(TrySendError::NoReceiver(val));
         }
 
-        let idx =
-            if let Some(idx) = critical_section::with(|cs| self.0.access(cs).freeq.pop_front()) {
-                idx
-            } else {
-                return Err(TrySendError::Full(val));
-            };
+        let free_slot = critical_section::with(|cs| unsafe {
+            // SAFETY: `self.0.freeq` is not called recursively.
+            self.0.freeq(cs, |q| q.pop_front())
+        });
+
+        let idx = if let Some(idx) = free_slot {
+            idx
+        } else {
+            return Err(TrySendError::Full(val));
+        };
 
         self.send_footer(idx, val);
@@ -368,7 +407,8 @@
             }
 
             let wq_empty = self.0.wait_queue.is_empty();
-            let freeq_empty = self.0.access(cs).freeq.is_empty();
+            // SAFETY: `self.0.freeq` is not called recursively.
+            let freeq_empty = unsafe { self.0.freeq(cs, |q| q.is_empty()) };
 
             // SAFETY: This pointer is only dereferenced here and on drop of the future
             // which happens outside this `poll_fn`'s stack frame.
@@ -416,9 +456,15 @@
             }
             // We are not in the wait queue, no one else is waiting, and there is a free slot available.
             else {
-                assert!(!self.0.access(cs).freeq.is_empty());
-                let slot = unsafe { self.0.access(cs).freeq.pop_back_unchecked() };
-                Poll::Ready(Ok(slot))
+                // SAFETY: `self.0.freeq` is not called recursively.
+                unsafe {
+                    self.0.freeq(cs, |freeq| {
+                        assert!(!freeq.is_empty());
+                        // SAFETY: `freeq` is non-empty
+                        let slot = freeq.pop_back_unchecked();
+                        Poll::Ready(Ok(slot))
+                    })
+                }
             }
         })
     })
@@ -438,17 +484,26 @@
 
     /// Returns true if there is no `Receiver`s.
     pub fn is_closed(&self) -> bool {
-        critical_section::with(|cs| *self.0.access(cs).receiver_dropped)
+        critical_section::with(|cs| unsafe {
+            // SAFETY: `self.0.receiver_dropped` is not called recursively.
+            self.0.receiver_dropped(cs, |v| *v)
+        })
     }
 
     /// Is the queue full.
     pub fn is_full(&self) -> bool {
-        critical_section::with(|cs| self.0.access(cs).freeq.is_empty())
+        critical_section::with(|cs| unsafe {
+            // SAFETY: `self.0.freeq` is not called recursively.
+            self.0.freeq(cs, |v| v.is_empty())
+        })
    }
 
     /// Is the queue empty.
     pub fn is_empty(&self) -> bool {
-        critical_section::with(|cs| self.0.access(cs).freeq.is_full())
+        critical_section::with(|cs| unsafe {
+            // SAFETY: `self.0.freeq` is not called recursively.
+            self.0.freeq(cs, |v| v.is_full())
+        })
     }
 }
@@ -456,9 +511,13 @@ impl<T, const N: usize> Drop for Sender<'_, T, N> {
     fn drop(&mut self) {
         // Count down the reference counter
         let num_senders = critical_section::with(|cs| {
-            *self.0.access(cs).num_senders -= 1;
-
-            *self.0.access(cs).num_senders
+            unsafe {
+                // SAFETY: `self.0.num_senders` is not called recursively.
+                self.0.num_senders(cs, |s| {
+                    *s -= 1;
+                    *s
+                })
+            }
         });
 
         // If there are no senders, wake the receiver to do error handling.
@@ -471,7 +530,10 @@ impl<T, const N: usize> Clone for Sender<'_, T, N> {
     fn clone(&self) -> Self {
         // Count up the reference counter
-        critical_section::with(|cs| *self.0.access(cs).num_senders += 1);
+        critical_section::with(|cs| unsafe {
+            // SAFETY: `self.0.num_senders` is not called recursively.
+            self.0.num_senders(cs, |v| *v += 1);
+        });
 
         Self(self.0)
     }
@@ -511,11 +573,18 @@ impl<T, const N: usize> Receiver<'_, T, N> {
     /// Receives a value if there is one in the channel, non-blocking.
     pub fn try_recv(&mut self) -> Result<T, ReceiveError> {
         // Try to get a ready slot.
-        let ready_slot = critical_section::with(|cs| self.0.access(cs).readyq.pop_front());
+        let ready_slot = critical_section::with(|cs| unsafe {
+            // SAFETY: `self.0.readyq` is not called recursively.
+            self.0.readyq(cs, |q| q.pop_front())
+        });
 
         if let Some(rs) = ready_slot {
             // Read the value from the slots, note; this memcpy is not under a critical section.
-            let r = unsafe { ptr::read(self.0.slots.get_unchecked(rs as usize).get() as *const T) };
+            let r = unsafe {
+                let first_element = self.0.slots.get_unchecked(rs as usize).get_mut();
+                let ptr = first_element.deref().as_ptr();
+                ptr::read(ptr)
+            };
 
             // Return the index to the free queue after we've read the value.
             // SAFETY: `rs` comes directly from `readyq`.
@@ -556,24 +625,36 @@
 
     /// Returns true if there are no `Sender`s.
     pub fn is_closed(&self) -> bool {
-        critical_section::with(|cs| *self.0.access(cs).num_senders == 0)
+        critical_section::with(|cs| unsafe {
+            // SAFETY: `self.0.num_senders` is not called recursively.
+            self.0.num_senders(cs, |v| *v == 0)
+        })
     }
 
     /// Is the queue full.
     pub fn is_full(&self) -> bool {
-        critical_section::with(|cs| self.0.access(cs).readyq.is_full())
+        critical_section::with(|cs| unsafe {
+            // SAFETY: `self.0.readyq` is not called recursively.
+            self.0.readyq(cs, |v| v.is_full())
+        })
     }
 
     /// Is the queue empty.
     pub fn is_empty(&self) -> bool {
-        critical_section::with(|cs| self.0.access(cs).readyq.is_empty())
+        critical_section::with(|cs| unsafe {
+            // SAFETY: `self.0.readyq` is not called recursively.
+            self.0.readyq(cs, |v| v.is_empty())
+        })
     }
 }
 
 impl<T, const N: usize> Drop for Receiver<'_, T, N> {
     fn drop(&mut self) {
         // Mark the receiver as dropped and wake all waiters
-        critical_section::with(|cs| *self.0.access(cs).receiver_dropped = true);
+        critical_section::with(|cs| unsafe {
+            // SAFETY: `self.0.receiver_dropped` is not called recursively.
+            self.0.receiver_dropped(cs, |v| *v = true);
+        });
 
         while let Some((waker, _)) = self.0.wait_queue.pop() {
             waker.wake();
@@ -582,6 +663,7 @@
 }
 
 #[cfg(test)]
+#[cfg(not(loom))]
 mod tests {
     use cassette::Cassette;
@@ -666,35 +748,6 @@ mod tests {
         assert_eq!(s.try_send(11), Err(TrySendError::NoReceiver(11)));
     }
 
-    #[tokio::test]
-    async fn stress_channel() {
-        const NUM_RUNS: usize = 1_000;
-        const QUEUE_SIZE: usize = 10;
-
-        let (s, mut r) = make_channel!(u32, QUEUE_SIZE);
-        let mut v = std::vec::Vec::new();
-
-        for i in 0..NUM_RUNS {
-            let mut s = s.clone();
-
-            v.push(tokio::spawn(async move {
-                s.send(i as _).await.unwrap();
-            }));
-        }
-
-        let mut map = std::collections::BTreeSet::new();
-
-        for _ in 0..NUM_RUNS {
-            map.insert(r.recv().await.unwrap());
-        }
-
-        assert_eq!(map.len(), NUM_RUNS);
-
-        for v in v {
-            v.await.unwrap();
-        }
-    }
-
     fn make() {
         let _ = make_channel!(u32, 10);
     }
@@ -715,7 +768,7 @@
     where
        F: FnOnce(&mut Deque<u8, N>) -> R,
    {
-        critical_section::with(|cs| f(channel.access(cs).freeq))
+        critical_section::with(|cs| unsafe { channel.freeq(cs, f) })
    }
 
    #[test]
@@ -750,3 +803,36 @@
         drop((tx, rx));
     }
 }
+
+#[cfg(not(loom))]
+#[cfg(test)]
+mod tokio_tests {
+    #[tokio::test]
+    async fn stress_channel() {
+        const NUM_RUNS: usize = 1_000;
+        const QUEUE_SIZE: usize = 10;
+
+        let (s, mut r) = make_channel!(u32, QUEUE_SIZE);
+        let mut v = std::vec::Vec::new();
+
+        for i in 0..NUM_RUNS {
+            let mut s = s.clone();
+
+            v.push(tokio::spawn(async move {
+                s.send(i as _).await.unwrap();
+            }));
+        }
+
+        let mut map = std::collections::BTreeSet::new();
+
+        for _ in 0..NUM_RUNS {
+            map.insert(r.recv().await.unwrap());
+        }
+
+        assert_eq!(map.len(), NUM_RUNS);
+
+        for v in v {
+            v.await.unwrap();
+        }
+    }
+}
diff --git a/rtic-sync/src/lib.rs b/rtic-sync/src/lib.rs
index f8845888ed5..c2f323f0448 100644
--- a/rtic-sync/src/lib.rs
+++ b/rtic-sync/src/lib.rs
@@ -1,6 +1,6 @@
 //! Synchronization primitives for asynchronous contexts.
 
-#![no_std]
+#![cfg_attr(not(loom), no_std)]
 #![deny(missing_docs)]
 
 #[cfg(feature = "defmt-03")]
@@ -11,6 +11,11 @@ pub mod channel;
 pub use portable_atomic;
 pub mod signal;
 
+mod unsafecell;
+
 #[cfg(test)]
 #[macro_use]
 extern crate std;
+
+#[cfg(loom)]
+mod loom_cs;
diff --git a/rtic-sync/src/loom_cs.rs b/rtic-sync/src/loom_cs.rs
new file mode 100644
index 00000000000..3291f52ff9d
--- /dev/null
+++ b/rtic-sync/src/loom_cs.rs
@@ -0,0 +1,69 @@
+//! A loom-based implementation of CriticalSection, effectively copied from the critical_section::std module.
+
+use core::cell::RefCell;
+use core::mem::MaybeUninit;
+
+use loom::cell::Cell;
+use loom::sync::{Mutex, MutexGuard};
+
+loom::lazy_static! {
+    static ref GLOBAL_MUTEX: Mutex<()> = Mutex::new(());
+    // This is initialized if a thread has acquired the CS, uninitialized otherwise.
+    static ref GLOBAL_GUARD: RefCell<MaybeUninit<MutexGuard<'static, ()>>> = RefCell::new(MaybeUninit::uninit());
+}
+
+loom::thread_local!(static IS_LOCKED: Cell<bool> = Cell::new(false));
+
+struct StdCriticalSection;
+critical_section::set_impl!(StdCriticalSection);
+
+unsafe impl critical_section::Impl for StdCriticalSection {
+    unsafe fn acquire() -> bool {
+        // Allow reentrancy by checking thread local state
+        IS_LOCKED.with(|l| {
+            if l.get() {
+                // CS already acquired in the current thread.
+                return true;
+            }
+
+            // Note: it is fine to set this flag *before* acquiring the mutex because it's thread local.
+            // No other thread can see its value, there's no potential for races.
+            // This way, we hold the mutex for slightly less time.
+            l.set(true);
+
+            // Not acquired in the current thread, acquire it.
+            let guard = match GLOBAL_MUTEX.lock() {
+                Ok(guard) => guard,
+                Err(err) => {
+                    // Ignore poison on the global mutex in case a panic occurred
+                    // while the mutex was held.
+                    err.into_inner()
+                }
+            };
+            GLOBAL_GUARD.borrow_mut().write(guard);
+
+            false
+        })
+    }
+
+    unsafe fn release(nested_cs: bool) {
+        if !nested_cs {
+            // SAFETY: As per the acquire/release safety contract, release can only be called
+            // if the critical section is acquired in the current thread,
+            // in which case we know the GLOBAL_GUARD is initialized.
+            //
+            // We have to `assume_init_read` then drop instead of `assume_init_drop` because:
+            // - drop requires exclusive access (&mut) to the contents
+            // - mutex guard drop first unlocks the mutex, then returns. In between those, there's a brief
+            //   moment where the mutex is unlocked but a `&mut` to the contents exists.
+            // - During this moment, another thread can go and use GLOBAL_GUARD, causing `&mut` aliasing.
+            #[allow(let_underscore_lock)]
+            let _ = GLOBAL_GUARD.borrow_mut().assume_init_read();
+
+            // Note: it is fine to clear this flag *after* releasing the mutex because it's thread local.
+            // No other thread can see its value, there's no potential for races.
+            // This way, we hold the mutex for slightly less time.
+            IS_LOCKED.with(|l| l.set(false));
+        }
+    }
+}
diff --git a/rtic-sync/src/signal.rs b/rtic-sync/src/signal.rs
index f3c8ceb3ed5..d43e9d5da37 100644
--- a/rtic-sync/src/signal.rs
+++ b/rtic-sync/src/signal.rs
@@ -168,10 +168,10 @@ macro_rules! make_signal {
 }
 
 #[cfg(test)]
+#[cfg(not(loom))]
 mod tests {
-    use static_cell::StaticCell;
-
     use super::*;
+    use static_cell::StaticCell;
 
     #[test]
     fn empty() {
diff --git a/rtic-sync/src/unsafecell.rs b/rtic-sync/src/unsafecell.rs
new file mode 100644
index 00000000000..e1774f8fa16
--- /dev/null
+++ b/rtic-sync/src/unsafecell.rs
@@ -0,0 +1,43 @@
+//! Compat layer for [`core::cell::UnsafeCell`] and `loom::cell::UnsafeCell`.
+
+#[cfg(loom)]
+pub use loom::cell::UnsafeCell;
+
+#[cfg(not(loom))]
+pub use self::core::UnsafeCell;
+
+#[cfg(not(loom))]
+mod core {
+    /// An [`core::cell::UnsafeCell`] wrapper that provides compatibility with
+    /// loom's UnsafeCell.
+    #[derive(Debug)]
+    pub struct UnsafeCell<T>(::core::cell::UnsafeCell<T>);
+
+    impl<T> UnsafeCell<T> {
+        /// Create a new `UnsafeCell`.
+        pub const fn new(data: T) -> UnsafeCell<T> {
+            UnsafeCell(::core::cell::UnsafeCell::new(data))
+        }
+
+        /// Access the contents of the `UnsafeCell` through a mut pointer.
+        pub fn get_mut(&self) -> MutPtr<T> {
+            MutPtr(self.0.get())
+        }
+
+        pub unsafe fn with_mut<F, R>(&self, f: F) -> R
+        where
+            F: FnOnce(*mut T) -> R,
+        {
+            f(self.0.get())
+        }
+    }
+
+    pub struct MutPtr<T>(*mut T);
+
+    impl<T> MutPtr<T> {
+        #[allow(clippy::mut_from_ref)]
+        pub unsafe fn deref(&self) -> &mut T {
+            &mut *self.0
+        }
+    }
+}
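
Note on the `cs_access!` accessors introduced in channel.rs: each invocation expands to a private method on `Channel<T, N>`. Hand-expanding `cs_access!(freeq, Deque<u8, N>)` gives roughly the following (shown for illustration only; this code is generated by the macro and does not appear in the patch):

```rust
impl<T, const N: usize> Channel<T, N> {
    /// Access the value mutably.
    ///
    /// SAFETY: this function must not be called recursively within `f`.
    unsafe fn freeq<F, R>(&self, _cs: critical_section::CriticalSection, f: F) -> R
    where
        F: FnOnce(&mut Deque<u8, N>) -> R,
    {
        // `with_mut` hands `f` a raw pointer to the cell contents; the
        // `CriticalSection` token plus the no-recursion contract make the
        // `&mut` created below unique for the duration of `f`.
        self.freeq.with_mut(|v| {
            let v = unsafe { &mut *v };
            f(v)
        })
    }
}
```

Unlike the removed `UnsafeAccess` struct, which handed out four long-lived `&mut` borrows at once, the closure keeps each `&mut` scoped to a single call, which is the shape that loom's `UnsafeCell` instrumentation can track and check.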
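The patch adds the loom scaffolding (the non-const `cfg(loom)` constructor, the `loom_cs` critical-section implementation, and the `unsafecell` compat layer) but does not itself show a loom test. Below is a minimal sketch of the kind of test this enables, e.g. appended to channel.rs; the module name, queue size, and the use of `Box::leak` to obtain the `'static` borrow that `split` hands out are illustrative assumptions, not part of the patch:

```rust
#[cfg(loom)]
#[cfg(test)]
mod loom_tests {
    use super::Channel;

    #[test]
    fn concurrent_send_and_recv() {
        // loom runs the closure once per possible thread interleaving.
        loom::model(|| {
            // Leak the channel to get the `'static` borrows `split` needs.
            // (One small leak per explored interleaving is fine in a test.)
            let channel: &'static mut Channel<u32, 2> = Box::leak(Box::new(Channel::new()));
            let (mut tx, mut rx) = channel.split();

            let sender = loom::thread::spawn(move || {
                // Every `try_send` goes through the loom-backed critical
                // section registered by `loom_cs`.
                let _ = tx.try_send(1);
            });

            // Either ordering relative to the sender is legal; loom checks both.
            let _ = rx.try_recv();

            sender.join().unwrap();
        });
    }
}
```

Such tests would be run with `RUSTFLAGS="--cfg loom" cargo test --release`, the cfg anticipated by the new `[target.'cfg(loom)'.dependencies]` table and the `unexpected_cfgs` lint entry in Cargo.toml.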