From cf646adb7383c071c706a618355e9d85ac28097b Mon Sep 17 00:00:00 2001 From: Emil Fresk Date: Thu, 26 Jan 2023 21:29:52 +0100 Subject: [PATCH] Fixes in MPSC linked list and dropper handling --- rtic-channel/Cargo.toml | 15 ++ rtic-channel/src/lib.rs | 380 +++++++++++++++++++++++++++++++++ rtic-channel/src/wait_queue.rs | 278 ++++++++++++++++++++++++ rtic-time/src/lib.rs | 1 - 4 files changed, 673 insertions(+), 1 deletion(-) create mode 100644 rtic-channel/Cargo.toml create mode 100644 rtic-channel/src/lib.rs create mode 100644 rtic-channel/src/wait_queue.rs diff --git a/rtic-channel/Cargo.toml b/rtic-channel/Cargo.toml new file mode 100644 index 0000000000..89623524e5 --- /dev/null +++ b/rtic-channel/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "rtic-channel" +version = "1.0.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +heapless = "0.7" +critical-section = "1" + + +[features] +default = [] +testing = ["critical-section/std"] diff --git a/rtic-channel/src/lib.rs b/rtic-channel/src/lib.rs new file mode 100644 index 0000000000..a7098ee245 --- /dev/null +++ b/rtic-channel/src/lib.rs @@ -0,0 +1,380 @@ +//! Crate + +#![no_std] +#![deny(missing_docs)] + +use core::{ + cell::UnsafeCell, + future::poll_fn, + mem::MaybeUninit, + ptr, + task::{Poll, Waker}, +}; +use heapless::Deque; +use wait_queue::WaitQueue; +use waker_registration::CriticalSectionWakerRegistration as WakerRegistration; + +mod wait_queue; +mod waker_registration; + +/// An MPSC channel for use in no-alloc systems. `N` sets the size of the queue. +/// +/// This channel uses critical sections, however there are extremely small and all `memcpy` +/// operations of `T` are done without critical sections. +pub struct Channel { + // Here are all indexes that are not used in `slots` and ready to be allocated. + freeq: UnsafeCell>, + // Here are wakers and indexes to slots that are ready to be dequeued by the receiver. + readyq: UnsafeCell>, + // Waker for the receiver. + receiver_waker: WakerRegistration, + // Storage for N `T`s, so we don't memcpy around a lot of `T`s. + slots: [UnsafeCell>; N], + // If there is no room in the queue a `Sender`s can wait for there to be place in the queue. + wait_queue: WaitQueue, + // Keep track of the receiver. + receiver_dropped: UnsafeCell, + // Keep track of the number of senders. + num_senders: UnsafeCell, +} + +struct UnsafeAccess<'a, const N: usize> { + freeq: &'a mut Deque, + readyq: &'a mut Deque, + receiver_dropped: &'a mut bool, + num_senders: &'a mut usize, +} + +impl Channel { + const _CHECK: () = assert!(N < 256, "This queue support a maximum of 255 entries"); + + const INIT_SLOTS: UnsafeCell> = UnsafeCell::new(MaybeUninit::uninit()); + + /// Create a new channel. + pub const fn new() -> Self { + Self { + freeq: UnsafeCell::new(Deque::new()), + readyq: UnsafeCell::new(Deque::new()), + receiver_waker: WakerRegistration::new(), + slots: [Self::INIT_SLOTS; N], + wait_queue: WaitQueue::new(), + receiver_dropped: UnsafeCell::new(false), + num_senders: UnsafeCell::new(0), + } + } + + /// Split the queue into a `Sender`/`Receiver` pair. + pub fn split<'a>(&'a mut self) -> (Sender<'a, T, N>, Receiver<'a, T, N>) { + // Fill free queue + for idx in 0..(N - 1) as u8 { + debug_assert!(!self.freeq.get_mut().is_full()); + + // SAFETY: This safe as the loop goes from 0 to the capacity of the underlying queue. + unsafe { + self.freeq.get_mut().push_back_unchecked(idx); + } + } + + debug_assert!(self.freeq.get_mut().is_full()); + + // There is now 1 sender + *self.num_senders.get_mut() = 1; + + (Sender(self), Receiver(self)) + } + + fn access<'a>(&'a self, _cs: critical_section::CriticalSection) -> UnsafeAccess<'a, N> { + // SAFETY: This is safe as are in a critical section. + unsafe { + UnsafeAccess { + freeq: &mut *self.freeq.get(), + readyq: &mut *self.readyq.get(), + receiver_dropped: &mut *self.receiver_dropped.get(), + num_senders: &mut *self.num_senders.get(), + } + } + } +} + +/// Creates a split channel with `'static` lifetime. +#[macro_export] +macro_rules! make_channel { + ($type:path, $size:expr) => {{ + static mut CHANNEL: Channel<$type, $size> = Channel::new(); + + // SAFETY: This is safe as we hide the static mut from others to access it. + // Only this point is where the mutable access happens. + unsafe { CHANNEL.split() } + }}; +} + +// -------- Sender + +/// Error state for when the receiver has been dropped. +pub struct NoReceiver(pub T); + +/// A `Sender` can send to the channel and can be cloned. +pub struct Sender<'a, T, const N: usize>(&'a Channel); + +unsafe impl<'a, T, const N: usize> Send for Sender<'a, T, N> {} + +impl<'a, T, const N: usize> Sender<'a, T, N> { + #[inline(always)] + fn send_footer(&mut self, idx: u8, val: T) { + // Write the value to the slots, note; this memcpy is not under a critical section. + unsafe { + ptr::write( + self.0.slots.get_unchecked(idx as usize).get() as *mut T, + val, + ) + } + + // Write the value into the ready queue. + critical_section::with(|cs| unsafe { self.0.access(cs).readyq.push_back_unchecked(idx) }); + + // If there is a receiver waker, wake it. + self.0.receiver_waker.wake(); + } + + /// Try to send a value, non-blocking. If the channel is full this will return an error. + /// Note; this does not check if the channel is closed. + pub fn try_send(&mut self, val: T) -> Result<(), T> { + // If the wait queue is not empty, we can't try to push into the queue. + if !self.0.wait_queue.is_empty() { + return Err(val); + } + + let idx = + if let Some(idx) = critical_section::with(|cs| self.0.access(cs).freeq.pop_front()) { + idx + } else { + return Err(val); + }; + + self.send_footer(idx, val); + + Ok(()) + } + + /// Send a value. If there is no place left in the queue this will wait until there is. + /// If the receiver does not exist this will return an error. + pub async fn send(&mut self, val: T) -> Result<(), NoReceiver> { + if self.is_closed() {} + + let mut __hidden_link: Option> = None; + + // Make this future `Drop`-safe + let link_ptr = &mut __hidden_link as *mut Option>; + let dropper = OnDrop::new(|| { + // SAFETY: We only run this closure and dereference the pointer if we have + // exited the `poll_fn` below in the `drop(dropper)` call. The other dereference + // of this pointer is in the `poll_fn`. + if let Some(link) = unsafe { &mut *link_ptr } { + link.remove_from_list(&self.0.wait_queue); + } + }); + + let idx = poll_fn(|cx| { + if self.is_closed() { + return Poll::Ready(Err(())); + } + + // Do all this in one critical section, else there can be race conditions + let queue_idx = critical_section::with(|cs| { + if !self.0.wait_queue.is_empty() || self.0.access(cs).freeq.is_empty() { + // SAFETY: This pointer is only dereferenced here and on drop of the future. + let link = unsafe { &mut *link_ptr }; + if link.is_none() { + // Place the link in the wait queue on first run. + let link_ref = link.insert(wait_queue::Link::new(cx.waker().clone())); + self.0.wait_queue.push(link_ref); + } + + return None; + } + + // Get index as the queue is guaranteed not empty and the wait queue is empty + let idx = unsafe { self.0.access(cs).freeq.pop_front_unchecked() }; + + Some(idx) + }); + + if let Some(idx) = queue_idx { + // Return the index + Poll::Ready(Ok(idx)) + } else { + return Poll::Pending; + } + }) + .await; + + // Make sure the link is removed from the queue. + drop(dropper); + + if let Ok(idx) = idx { + self.send_footer(idx, val); + + Ok(()) + } else { + Err(NoReceiver(val)) + } + } + + /// Returns true if there is no `Receiver`s. + pub fn is_closed(&self) -> bool { + critical_section::with(|cs| *self.0.access(cs).receiver_dropped) + } + + /// Is the queue full. + pub fn is_full(&self) -> bool { + critical_section::with(|cs| self.0.access(cs).freeq.is_empty()) + } + + /// Is the queue empty. + pub fn is_empty(&self) -> bool { + critical_section::with(|cs| self.0.access(cs).freeq.is_full()) + } +} + +impl<'a, T, const N: usize> Drop for Sender<'a, T, N> { + fn drop(&mut self) { + // Count down the reference counter + let num_senders = critical_section::with(|cs| { + *self.0.access(cs).num_senders -= 1; + + *self.0.access(cs).num_senders + }); + + // If there are no senders, wake the receiver to do error handling. + if num_senders == 0 { + self.0.receiver_waker.wake(); + } + } +} + +impl<'a, T, const N: usize> Clone for Sender<'a, T, N> { + fn clone(&self) -> Self { + // Count up the reference counter + critical_section::with(|cs| *self.0.access(cs).num_senders += 1); + + Self(self.0) + } +} + +// -------- Receiver + +/// A receiver of the channel. There can only be one receiver at any time. +pub struct Receiver<'a, T, const N: usize>(&'a Channel); + +/// Error state for when all senders has been dropped. +pub struct NoSender; + +impl<'a, T, const N: usize> Receiver<'a, T, N> { + /// Receives a value if there is one in the channel, non-blocking. + /// Note; this does not check if the channel is closed. + pub fn try_recv(&mut self) -> Option { + // Try to get a ready slot. + let ready_slot = + critical_section::with(|cs| self.0.access(cs).readyq.pop_front().map(|rs| rs)); + + if let Some(rs) = ready_slot { + // Read the value from the slots, note; this memcpy is not under a critical section. + let r = unsafe { ptr::read(self.0.slots.get_unchecked(rs as usize).get() as *const T) }; + + // Return the index to the free queue after we've read the value. + critical_section::with(|cs| unsafe { self.0.access(cs).freeq.push_back_unchecked(rs) }); + + // If someone is waiting in the WaiterQueue, wake the first one up. + if let Some(wait_head) = self.0.wait_queue.pop() { + wait_head.wake(); + } + + Some(r) + } else { + None + } + } + + /// Receives a value, waiting if the queue is empty. + /// If all senders are dropped this will error with `NoSender`. + pub async fn recv(&mut self) -> Result { + // There was nothing in the queue, setup the waiting. + poll_fn(|cx| { + // Register waker. + // TODO: Should it happen here or after the if? This might cause a spurious wake. + self.0.receiver_waker.register(cx.waker()); + + // Try to dequeue. + if let Some(val) = self.try_recv() { + return Poll::Ready(Ok(val)); + } + + // If the queue is empty and there is no sender, return the error. + if self.is_closed() { + return Poll::Ready(Err(NoSender)); + } + + Poll::Pending + }) + .await + } + + /// Returns true if there are no `Sender`s. + pub fn is_closed(&self) -> bool { + critical_section::with(|cs| *self.0.access(cs).num_senders == 0) + } + + /// Is the queue full. + pub fn is_full(&self) -> bool { + critical_section::with(|cs| self.0.access(cs).readyq.is_empty()) + } + + /// Is the queue empty. + pub fn is_empty(&self) -> bool { + critical_section::with(|cs| self.0.access(cs).readyq.is_empty()) + } +} + +impl<'a, T, const N: usize> Drop for Receiver<'a, T, N> { + fn drop(&mut self) { + // Mark the receiver as dropped and wake all waiters + critical_section::with(|cs| *self.0.access(cs).receiver_dropped = true); + + while let Some(waker) = self.0.wait_queue.pop() { + waker.wake(); + } + } +} + +struct OnDrop { + f: core::mem::MaybeUninit, +} + +impl OnDrop { + pub fn new(f: F) -> Self { + Self { + f: core::mem::MaybeUninit::new(f), + } + } + + #[allow(unused)] + pub fn defuse(self) { + core::mem::forget(self) + } +} + +impl Drop for OnDrop { + fn drop(&mut self) { + unsafe { self.f.as_ptr().read()() } + } +} + +#[cfg(test)] +#[macro_use] +extern crate std; + +#[cfg(test)] +mod tests { + #[test] + fn channel() {} +} diff --git a/rtic-channel/src/wait_queue.rs b/rtic-channel/src/wait_queue.rs new file mode 100644 index 0000000000..90d762bdf3 --- /dev/null +++ b/rtic-channel/src/wait_queue.rs @@ -0,0 +1,278 @@ +//! ... + +use core::cell::UnsafeCell; +use core::marker::PhantomPinned; +use core::ptr::null_mut; +use core::sync::atomic::{AtomicPtr, Ordering}; +use core::task::Waker; +use critical_section as cs; + +pub type WaitQueue = LinkedList; + +struct MyLinkPtr(UnsafeCell<*mut Link>); + +impl MyLinkPtr { + #[inline(always)] + fn new(val: *mut Link) -> Self { + Self(UnsafeCell::new(val)) + } + + /// SAFETY: Only use this in a critical section, and don't forget them barriers. + #[inline(always)] + unsafe fn load_relaxed(&self) -> *mut Link { + unsafe { *self.0.get() } + } + + /// SAFETY: Only use this in a critical section, and don't forget them barriers. + #[inline(always)] + unsafe fn store_relaxed(&self, val: *mut Link) { + unsafe { self.0.get().write(val) } + } +} + +/// A FIFO linked list for a wait queue. +pub struct LinkedList { + head: AtomicPtr>, // UnsafeCell<*mut Link> + tail: AtomicPtr>, +} + +impl LinkedList { + /// Create a new linked list. + pub const fn new() -> Self { + Self { + head: AtomicPtr::new(null_mut()), + tail: AtomicPtr::new(null_mut()), + } + } +} + +impl LinkedList { + const R: Ordering = Ordering::Relaxed; + + /// Pop the first element in the queue. + pub fn pop(&self) -> Option { + cs::with(|_| { + // Make sure all previous writes are visible + core::sync::atomic::fence(Ordering::SeqCst); + + let head = self.head.load(Self::R); + + // SAFETY: `as_ref` is safe as `insert` requires a valid reference to a link + if let Some(head_ref) = unsafe { head.as_ref() } { + // Move head to the next element + self.head.store(head_ref.next.load(Self::R), Self::R); + + // We read the value at head + let head_val = head_ref.val.clone(); + + let tail = self.tail.load(Self::R); + if head == tail { + // The queue is empty + self.tail.store(null_mut(), Self::R); + } + + if let Some(next_ref) = unsafe { head_ref.next.load(Self::R).as_ref() } { + next_ref.prev.store(null_mut(), Self::R); + } + + // Clear the pointers in the node. + head_ref.next.store(null_mut(), Self::R); + head_ref.prev.store(null_mut(), Self::R); + + return Some(head_val); + } + + None + }) + } + + /// Put an element at the back of the queue. + pub fn push(&self, link: &mut Link) { + cs::with(|_| { + // Make sure all previous writes are visible + core::sync::atomic::fence(Ordering::SeqCst); + + let tail = self.tail.load(Self::R); + + if let Some(tail_ref) = unsafe { tail.as_ref() } { + // Queue is not empty + link.prev.store(tail, Self::R); + self.tail.store(link, Self::R); + tail_ref.next.store(link, Self::R); + } else { + // Queue is empty + self.tail.store(link, Self::R); + self.head.store(link, Self::R); + } + }); + } + + /// Check if the queue is empty. + pub fn is_empty(&self) -> bool { + self.head.load(Self::R).is_null() + } +} + +/// A link in the linked list. +pub struct Link { + pub(crate) val: T, + next: AtomicPtr>, + prev: AtomicPtr>, + _up: PhantomPinned, +} + +impl Link { + const R: Ordering = Ordering::Relaxed; + + /// Create a new link. + pub const fn new(val: T) -> Self { + Self { + val, + next: AtomicPtr::new(null_mut()), + prev: AtomicPtr::new(null_mut()), + _up: PhantomPinned, + } + } + + pub fn remove_from_list(&mut self, list: &LinkedList) { + cs::with(|_| { + // Make sure all previous writes are visible + core::sync::atomic::fence(Ordering::SeqCst); + + let prev = self.prev.load(Self::R); + let next = self.next.load(Self::R); + + match unsafe { (prev.as_ref(), next.as_ref()) } { + (None, None) => { + // Not in the list or alone in the list, check if list head == node address + let sp = self as *const _; + + if sp == list.head.load(Ordering::Relaxed) { + list.head.store(null_mut(), Self::R); + list.tail.store(null_mut(), Self::R); + } + } + (None, Some(next_ref)) => { + // First in the list + next_ref.prev.store(null_mut(), Self::R); + list.head.store(next, Self::R); + } + (Some(prev_ref), None) => { + // Last in the list + prev_ref.next.store(null_mut(), Self::R); + list.tail.store(prev, Self::R); + } + (Some(prev_ref), Some(next_ref)) => { + // Somewhere in the list + + // Connect the `prev.next` and `next.prev` with each other to remove the node + prev_ref.next.store(next, Self::R); + next_ref.prev.store(prev, Self::R); + } + } + }) + } +} + +#[cfg(test)] +impl LinkedList { + fn print(&self) { + cs::with(|_| { + // Make sure all previous writes are visible + core::sync::atomic::fence(Ordering::SeqCst); + + let mut head = self.head.load(Self::R); + let tail = self.tail.load(Self::R); + + println!( + "List - h = 0x{:x}, t = 0x{:x}", + head as usize, tail as usize + ); + + let mut i = 0; + + // SAFETY: `as_ref` is safe as `insert` requires a valid reference to a link + while let Some(head_ref) = unsafe { head.as_ref() } { + println!( + " {}: {:?}, s = 0x{:x}, n = 0x{:x}, p = 0x{:x}", + i, + head_ref.val, + head as usize, + head_ref.next.load(Ordering::Relaxed) as usize, + head_ref.prev.load(Ordering::Relaxed) as usize + ); + + head = head_ref.next.load(Self::R); + + i += 1; + } + }); + } +} + +#[cfg(test)] +impl Link { + fn print(&self) { + cs::with(|_| { + // Make sure all previous writes are visible + core::sync::atomic::fence(Ordering::SeqCst); + + println!("Link:"); + + println!( + " val = {:?}, n = 0x{:x}, p = 0x{:x}", + self.val, + self.next.load(Ordering::Relaxed) as usize, + self.prev.load(Ordering::Relaxed) as usize + ); + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn linked_list() { + let mut wq = LinkedList::::new(); + + let mut i1 = Link::new(10); + let mut i2 = Link::new(11); + let mut i3 = Link::new(12); + let mut i4 = Link::new(13); + let mut i5 = Link::new(14); + + wq.push(&mut i1); + wq.push(&mut i2); + wq.push(&mut i3); + wq.push(&mut i4); + wq.push(&mut i5); + + wq.print(); + + wq.pop(); + i1.print(); + + wq.print(); + + i4.remove_from_list(&wq); + + wq.print(); + + // i1.remove_from_list(&wq); + // wq.print(); + + println!("i2"); + i2.remove_from_list(&wq); + wq.print(); + + println!("i3"); + i3.remove_from_list(&wq); + wq.print(); + + println!("i5"); + i5.remove_from_list(&wq); + wq.print(); + } +} diff --git a/rtic-time/src/lib.rs b/rtic-time/src/lib.rs index 34f93622aa..78ece1df20 100644 --- a/rtic-time/src/lib.rs +++ b/rtic-time/src/lib.rs @@ -1,7 +1,6 @@ //! Crate #![no_std] -#![no_main] #![deny(missing_docs)] #![allow(incomplete_features)] #![feature(async_fn_in_trait)]