rtic-sync: introduce loom compat layer and apply it to channel

This commit is contained in:
datdenkikniet 2025-03-16 12:46:23 +01:00 committed by Emil Fresk
parent d76252d767
commit b5db435501
7 changed files with 299 additions and 87 deletions

View file

@ -1,7 +1,7 @@
//! An async aware MPSC channel that can be used on no-alloc systems.
use crate::unsafecell::UnsafeCell;
use core::{
cell::UnsafeCell,
future::poll_fn,
mem::MaybeUninit,
pin::Pin,
@ -48,11 +48,21 @@ unsafe impl<T, const N: usize> Send for Channel<T, N> {}
unsafe impl<T, const N: usize> Sync for Channel<T, N> {}
struct UnsafeAccess<'a, const N: usize> {
freeq: &'a mut Deque<u8, N>,
readyq: &'a mut Deque<u8, N>,
receiver_dropped: &'a mut bool,
num_senders: &'a mut usize,
macro_rules! cs_access {
($name:ident, $type:ty) => {
/// Access the value mutably.
///
/// SAFETY: this function must not be called recursively within `f`.
unsafe fn $name<F, R>(&self, _cs: critical_section::CriticalSection, f: F) -> R
where
F: FnOnce(&mut $type) -> R,
{
self.$name.with_mut(|v| {
let v = unsafe { &mut *v };
f(v)
})
}
};
}
impl<T, const N: usize> Default for Channel<T, N> {
@ -65,6 +75,7 @@ impl<T, const N: usize> Channel<T, N> {
const _CHECK: () = assert!(N < 256, "This queue support a maximum of 255 entries");
/// Create a new channel.
#[cfg(not(loom))]
pub const fn new() -> Self {
Self {
freeq: UnsafeCell::new(Deque::new()),
@ -77,37 +88,49 @@ impl<T, const N: usize> Channel<T, N> {
}
}
/// Create a new channel.
#[cfg(loom)]
pub fn new() -> Self {
Self {
freeq: UnsafeCell::new(Deque::new()),
readyq: UnsafeCell::new(Deque::new()),
receiver_waker: WakerRegistration::new(),
slots: core::array::from_fn(|_| UnsafeCell::new(MaybeUninit::uninit())),
wait_queue: WaitQueue::new(),
receiver_dropped: UnsafeCell::new(false),
num_senders: UnsafeCell::new(0),
}
}
/// Split the queue into a `Sender`/`Receiver` pair.
pub fn split(&mut self) -> (Sender<'_, T, N>, Receiver<'_, T, N>) {
// SAFETY: we have exclusive access to `self`.
let freeq = self.freeq.get_mut();
let freeq = unsafe { freeq.deref() };
// Fill free queue
for idx in 0..N as u8 {
assert!(!self.freeq.get_mut().is_full());
assert!(!freeq.is_full());
// SAFETY: This safe as the loop goes from 0 to the capacity of the underlying queue.
unsafe {
self.freeq.get_mut().push_back_unchecked(idx);
freeq.push_back_unchecked(idx);
}
}
assert!(self.freeq.get_mut().is_full());
assert!(freeq.is_full());
// There is now 1 sender
*self.num_senders.get_mut() = 1;
// SAFETY: we have exclusive access to `self`.
unsafe { *self.num_senders.get_mut().deref() = 1 };
(Sender(self), Receiver(self))
}
fn access<'a>(&'a self, _cs: critical_section::CriticalSection) -> UnsafeAccess<'a, N> {
// SAFETY: This is safe as are in a critical section.
unsafe {
UnsafeAccess {
freeq: &mut *self.freeq.get(),
readyq: &mut *self.readyq.get(),
receiver_dropped: &mut *self.receiver_dropped.get(),
num_senders: &mut *self.num_senders.get(),
}
}
}
cs_access!(freeq, Deque<u8, N>);
cs_access!(readyq, Deque<u8, N>);
cs_access!(receiver_dropped, bool);
cs_access!(num_senders, usize);
/// Return free slot `slot` to the channel.
///
@ -127,8 +150,14 @@ impl<T, const N: usize> Channel<T, N> {
unsafe { freeq_slot.replace(Some(slot), cs) };
wait_head.wake();
} else {
assert!(!self.access(cs).freeq.is_full());
unsafe { self.access(cs).freeq.push_back_unchecked(slot) }
// SAFETY: `self.freeq` is not called recursively.
unsafe {
self.freeq(cs, |freeq| {
assert!(!freeq.is_full());
// SAFETY: `freeq` is not full.
freeq.push_back_unchecked(slot);
});
}
}
})
}
@ -136,6 +165,7 @@ impl<T, const N: usize> Channel<T, N> {
/// Creates a split channel with `'static` lifetime.
#[macro_export]
#[cfg(not(loom))]
macro_rules! make_channel {
($type:ty, $size:expr) => {{
static mut CHANNEL: $crate::channel::Channel<$type, $size> =
@ -285,16 +315,21 @@ impl<T, const N: usize> Sender<'_, T, N> {
fn send_footer(&mut self, idx: u8, val: T) {
// Write the value to the slots, note; this memcpy is not under a critical section.
unsafe {
ptr::write(
self.0.slots.get_unchecked(idx as usize).get() as *mut T,
val,
)
let first_element = self.0.slots.get_unchecked(idx as usize).get_mut();
let ptr = first_element.deref().as_mut_ptr();
ptr::write(ptr, val)
}
// Write the value into the ready queue.
critical_section::with(|cs| {
assert!(!self.0.access(cs).readyq.is_full());
unsafe { self.0.access(cs).readyq.push_back_unchecked(idx) }
// SAFETY: `self.0.readyq` is not called recursively.
unsafe {
self.0.readyq(cs, |readyq| {
assert!(!readyq.is_full());
// SAFETY: ready is not full.
readyq.push_back_unchecked(idx);
});
}
});
fence(Ordering::SeqCst);
@ -315,12 +350,16 @@ impl<T, const N: usize> Sender<'_, T, N> {
return Err(TrySendError::NoReceiver(val));
}
let idx =
if let Some(idx) = critical_section::with(|cs| self.0.access(cs).freeq.pop_front()) {
idx
} else {
return Err(TrySendError::Full(val));
};
let free_slot = critical_section::with(|cs| unsafe {
// SAFETY: `self.0.freeq` is not called recursively.
self.0.freeq(cs, |q| q.pop_front())
});
let idx = if let Some(idx) = free_slot {
idx
} else {
return Err(TrySendError::Full(val));
};
self.send_footer(idx, val);
@ -368,7 +407,8 @@ impl<T, const N: usize> Sender<'_, T, N> {
}
let wq_empty = self.0.wait_queue.is_empty();
let freeq_empty = self.0.access(cs).freeq.is_empty();
// SAFETY: `self.0.freeq` is not called recursively.
let freeq_empty = unsafe { self.0.freeq(cs, |q| q.is_empty()) };
// SAFETY: This pointer is only dereferenced here and on drop of the future
// which happens outside this `poll_fn`'s stack frame.
@ -416,9 +456,15 @@ impl<T, const N: usize> Sender<'_, T, N> {
}
// We are not in the wait queue, no one else is waiting, and there is a free slot available.
else {
assert!(!self.0.access(cs).freeq.is_empty());
let slot = unsafe { self.0.access(cs).freeq.pop_back_unchecked() };
Poll::Ready(Ok(slot))
// SAFETY: `self.0.freeq` is not called recursively.
unsafe {
self.0.freeq(cs, |freeq| {
assert!(!freeq.is_empty());
// SAFETY: `freeq` is non-empty
let slot = freeq.pop_back_unchecked();
Poll::Ready(Ok(slot))
})
}
}
})
})
@ -438,17 +484,26 @@ impl<T, const N: usize> Sender<'_, T, N> {
/// Returns true if there is no `Receiver`s.
pub fn is_closed(&self) -> bool {
critical_section::with(|cs| *self.0.access(cs).receiver_dropped)
critical_section::with(|cs| unsafe {
// SAFETY: `self.0.receiver_dropped` is not called recursively.
self.0.receiver_dropped(cs, |v| *v)
})
}
/// Is the queue full.
pub fn is_full(&self) -> bool {
critical_section::with(|cs| self.0.access(cs).freeq.is_empty())
critical_section::with(|cs| unsafe {
// SAFETY: `self.0.freeq` is not called recursively.
self.0.freeq(cs, |v| v.is_empty())
})
}
/// Is the queue empty.
pub fn is_empty(&self) -> bool {
critical_section::with(|cs| self.0.access(cs).freeq.is_full())
critical_section::with(|cs| unsafe {
// SAFETY: `self.0.freeq` is not called recursively.
self.0.freeq(cs, |v| v.is_full())
})
}
}
@ -456,9 +511,13 @@ impl<T, const N: usize> Drop for Sender<'_, T, N> {
fn drop(&mut self) {
// Count down the reference counter
let num_senders = critical_section::with(|cs| {
*self.0.access(cs).num_senders -= 1;
*self.0.access(cs).num_senders
unsafe {
// SAFETY: `self.0.num_senders` is not called recursively.
self.0.num_senders(cs, |s| {
*s -= 1;
*s
})
}
});
// If there are no senders, wake the receiver to do error handling.
@ -471,7 +530,10 @@ impl<T, const N: usize> Drop for Sender<'_, T, N> {
impl<T, const N: usize> Clone for Sender<'_, T, N> {
fn clone(&self) -> Self {
// Count up the reference counter
critical_section::with(|cs| *self.0.access(cs).num_senders += 1);
critical_section::with(|cs| unsafe {
// SAFETY: `self.0.num_senders` is not called recursively.
self.0.num_senders(cs, |v| *v += 1);
});
Self(self.0)
}
@ -511,11 +573,18 @@ impl<T, const N: usize> Receiver<'_, T, N> {
/// Receives a value if there is one in the channel, non-blocking.
pub fn try_recv(&mut self) -> Result<T, ReceiveError> {
// Try to get a ready slot.
let ready_slot = critical_section::with(|cs| self.0.access(cs).readyq.pop_front());
let ready_slot = critical_section::with(|cs| unsafe {
// SAFETY: `self.0.readyq` is not called recursively.
self.0.readyq(cs, |q| q.pop_front())
});
if let Some(rs) = ready_slot {
// Read the value from the slots, note; this memcpy is not under a critical section.
let r = unsafe { ptr::read(self.0.slots.get_unchecked(rs as usize).get() as *const T) };
let r = unsafe {
let first_element = self.0.slots.get_unchecked(rs as usize).get_mut();
let ptr = first_element.deref().as_ptr();
ptr::read(ptr)
};
// Return the index to the free queue after we've read the value.
// SAFETY: `rs` comes directly from `readyq`.
@ -556,24 +625,36 @@ impl<T, const N: usize> Receiver<'_, T, N> {
/// Returns true if there are no `Sender`s.
pub fn is_closed(&self) -> bool {
critical_section::with(|cs| *self.0.access(cs).num_senders == 0)
critical_section::with(|cs| unsafe {
// SAFETY: `self.0.num_senders` is not called recursively.
self.0.num_senders(cs, |v| *v == 0)
})
}
/// Is the queue full.
pub fn is_full(&self) -> bool {
critical_section::with(|cs| self.0.access(cs).readyq.is_full())
critical_section::with(|cs| unsafe {
// SAFETY: `self.0.readyq` is not called recursively.
self.0.readyq(cs, |v| v.is_full())
})
}
/// Is the queue empty.
pub fn is_empty(&self) -> bool {
critical_section::with(|cs| self.0.access(cs).readyq.is_empty())
critical_section::with(|cs| unsafe {
// SAFETY: `self.0.readyq` is not called recursively.
self.0.readyq(cs, |v| v.is_empty())
})
}
}
impl<T, const N: usize> Drop for Receiver<'_, T, N> {
fn drop(&mut self) {
// Mark the receiver as dropped and wake all waiters
critical_section::with(|cs| *self.0.access(cs).receiver_dropped = true);
critical_section::with(|cs| unsafe {
// SAFETY: `self.0.receiver_dropped` is not called recursively.
self.0.receiver_dropped(cs, |v| *v = true);
});
while let Some((waker, _)) = self.0.wait_queue.pop() {
waker.wake();
@ -582,6 +663,7 @@ impl<T, const N: usize> Drop for Receiver<'_, T, N> {
}
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
use cassette::Cassette;
@ -666,35 +748,6 @@ mod tests {
assert_eq!(s.try_send(11), Err(TrySendError::NoReceiver(11)));
}
#[tokio::test]
async fn stress_channel() {
const NUM_RUNS: usize = 1_000;
const QUEUE_SIZE: usize = 10;
let (s, mut r) = make_channel!(u32, QUEUE_SIZE);
let mut v = std::vec::Vec::new();
for i in 0..NUM_RUNS {
let mut s = s.clone();
v.push(tokio::spawn(async move {
s.send(i as _).await.unwrap();
}));
}
let mut map = std::collections::BTreeSet::new();
for _ in 0..NUM_RUNS {
map.insert(r.recv().await.unwrap());
}
assert_eq!(map.len(), NUM_RUNS);
for v in v {
v.await.unwrap();
}
}
fn make() {
let _ = make_channel!(u32, 10);
}
@ -715,7 +768,7 @@ mod tests {
where
F: FnOnce(&mut Deque<u8, N>) -> R,
{
critical_section::with(|cs| f(channel.access(cs).freeq))
critical_section::with(|cs| unsafe { channel.freeq(cs, f) })
}
#[test]
@ -750,3 +803,36 @@ mod tests {
drop((tx, rx));
}
}
#[cfg(not(loom))]
#[cfg(test)]
mod tokio_tests {
#[tokio::test]
async fn stress_channel() {
const NUM_RUNS: usize = 1_000;
const QUEUE_SIZE: usize = 10;
let (s, mut r) = make_channel!(u32, QUEUE_SIZE);
let mut v = std::vec::Vec::new();
for i in 0..NUM_RUNS {
let mut s = s.clone();
v.push(tokio::spawn(async move {
s.send(i as _).await.unwrap();
}));
}
let mut map = std::collections::BTreeSet::new();
for _ in 0..NUM_RUNS {
map.insert(r.recv().await.unwrap());
}
assert_eq!(map.len(), NUM_RUNS);
for v in v {
v.await.unwrap();
}
}
}