diff --git a/rtic-channel/Cargo.toml b/rtic-channel/Cargo.toml index 89623524e5..5d4cbd0e08 100644 --- a/rtic-channel/Cargo.toml +++ b/rtic-channel/Cargo.toml @@ -9,6 +9,9 @@ edition = "2021" heapless = "0.7" critical-section = "1" +[dev-dependencies] +tokio = { version = "1", features = ["rt", "macros", "time"] } + [features] default = [] diff --git a/rtic-channel/src/lib.rs b/rtic-channel/src/lib.rs index b6a317fb84..1077b5a6e2 100644 --- a/rtic-channel/src/lib.rs +++ b/rtic-channel/src/lib.rs @@ -10,6 +10,7 @@ use core::{ mem::MaybeUninit, pin::Pin, ptr, + sync::atomic::{fence, Ordering}, task::{Poll, Waker}, }; use heapless::Deque; @@ -40,6 +41,9 @@ pub struct Channel { num_senders: UnsafeCell, } +unsafe impl Send for Channel {} +unsafe impl Sync for Channel {} + struct UnsafeAccess<'a, const N: usize> { freeq: &'a mut Deque, readyq: &'a mut Deque, @@ -129,6 +133,21 @@ pub struct Sender<'a, T, const N: usize>(&'a Channel); unsafe impl<'a, T, const N: usize> Send for Sender<'a, T, N> {} +/// This is needed to make the async closure in `send` accept that we "share" +/// the link possible between threads. +#[derive(Clone)] +struct LinkPtr(*mut Option>); + +impl LinkPtr { + /// This will dereference the pointer stored within and give out an `&mut`. + unsafe fn get(&self) -> &mut Option> { + &mut *self.0 + } +} + +unsafe impl Send for LinkPtr {} +unsafe impl Sync for LinkPtr {} + impl<'a, T, const N: usize> core::fmt::Debug for Sender<'a, T, N> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "Sender") @@ -147,7 +166,12 @@ impl<'a, T, const N: usize> Sender<'a, T, N> { } // Write the value into the ready queue. - critical_section::with(|cs| unsafe { self.0.access(cs).readyq.push_back_unchecked(idx) }); + critical_section::with(|cs| { + debug_assert!(!self.0.access(cs).readyq.is_full()); + unsafe { self.0.access(cs).readyq.push_back_unchecked(idx) } + }); + + fence(Ordering::SeqCst); // If there is a receiver waker, wake it. self.0.receiver_waker.wake(); @@ -176,18 +200,17 @@ impl<'a, T, const N: usize> Sender<'a, T, N> { /// Send a value. If there is no place left in the queue this will wait until there is. /// If the receiver does not exist this will return an error. pub async fn send(&mut self, val: T) -> Result<(), NoReceiver> { - if self.is_closed() {} - let mut link_ptr: Option> = None; // Make this future `Drop`-safe, also shadow the original definition so we can't abuse it. - let link_ptr = &mut link_ptr as *mut Option>; + let link_ptr = LinkPtr(&mut link_ptr as *mut Option>); + let link_ptr2 = link_ptr.clone(); let dropper = OnDrop::new(|| { // SAFETY: We only run this closure and dereference the pointer if we have // exited the `poll_fn` below in the `drop(dropper)` call. The other dereference // of this pointer is in the `poll_fn`. - if let Some(link) = unsafe { &mut *link_ptr } { + if let Some(link) = unsafe { link_ptr2.get() } { link.remove_from_list(&self.0.wait_queue); } }); @@ -199,11 +222,19 @@ impl<'a, T, const N: usize> Sender<'a, T, N> { // Do all this in one critical section, else there can be race conditions let queue_idx = critical_section::with(|cs| { - if !self.0.wait_queue.is_empty() || self.0.access(cs).freeq.is_empty() { + let wq_empty = self.0.wait_queue.is_empty(); + let fq_empty = self.0.access(cs).freeq.is_empty(); + if !wq_empty || fq_empty { // SAFETY: This pointer is only dereferenced here and on drop of the future // which happens outside this `poll_fn`'s stack frame. - let link = unsafe { &mut *link_ptr }; - if link.is_none() { + let link = unsafe { link_ptr.get() }; + if let Some(link) = link { + if !link.is_poped() { + return None; + } else { + // Fall through to dequeue + } + } else { // Place the link in the wait queue on first run. let link_ref = link.insert(wait_queue::Link::new(cx.waker().clone())); @@ -212,11 +243,12 @@ impl<'a, T, const N: usize> Sender<'a, T, N> { self.0 .wait_queue .push(unsafe { Pin::new_unchecked(link_ref) }); - } - return None; + return None; + } } + debug_assert!(!self.0.access(cs).freeq.is_empty()); // Get index as the queue is guaranteed not empty and the wait queue is empty let idx = unsafe { self.0.access(cs).freeq.pop_front_unchecked() }; @@ -319,7 +351,12 @@ impl<'a, T, const N: usize> Receiver<'a, T, N> { let r = unsafe { ptr::read(self.0.slots.get_unchecked(rs as usize).get() as *const T) }; // Return the index to the free queue after we've read the value. - critical_section::with(|cs| unsafe { self.0.access(cs).freeq.push_back_unchecked(rs) }); + critical_section::with(|cs| { + debug_assert!(!self.0.access(cs).freeq.is_full()); + unsafe { self.0.access(cs).freeq.push_back_unchecked(rs) } + }); + + fence(Ordering::SeqCst); // If someone is waiting in the WaiterQueue, wake the first one up. if let Some(wait_head) = self.0.wait_queue.pop() { @@ -363,7 +400,7 @@ impl<'a, T, const N: usize> Receiver<'a, T, N> { /// Is the queue full. pub fn is_full(&self) -> bool { - critical_section::with(|cs| self.0.access(cs).readyq.is_empty()) + critical_section::with(|cs| self.0.access(cs).readyq.is_full()) } /// Is the queue empty. @@ -412,6 +449,113 @@ extern crate std; #[cfg(test)] mod tests { + use super::*; + #[test] - fn channel() {} + fn empty() { + let (mut s, mut r) = make_channel!(u32, 10); + + assert!(s.is_empty()); + assert!(r.is_empty()); + + s.try_send(1).unwrap(); + + assert!(!s.is_empty()); + assert!(!r.is_empty()); + + r.try_recv().unwrap(); + + assert!(s.is_empty()); + assert!(r.is_empty()); + } + + #[test] + fn full() { + let (mut s, mut r) = make_channel!(u32, 3); + + for _ in 0..3 { + assert!(!s.is_full()); + assert!(!r.is_full()); + + s.try_send(1).unwrap(); + } + + assert!(s.is_full()); + assert!(r.is_full()); + + for _ in 0..3 { + r.try_recv().unwrap(); + + assert!(!s.is_full()); + assert!(!r.is_full()); + } + } + + #[test] + fn send_recieve() { + let (mut s, mut r) = make_channel!(u32, 10); + + for i in 0..10 { + s.try_send(i).unwrap(); + } + + assert_eq!(s.try_send(11), Err(11)); + + for i in 0..10 { + assert_eq!(r.try_recv().unwrap(), i); + } + + assert_eq!(r.try_recv(), None); + } + + #[test] + fn closed_recv() { + let (s, mut r) = make_channel!(u32, 10); + + drop(s); + + assert!(r.is_closed()); + + assert_eq!(r.try_recv(), None); + } + + #[test] + fn closed_sender() { + let (mut s, r) = make_channel!(u32, 10); + + drop(r); + + assert!(s.is_closed()); + + assert_eq!(s.try_send(11), Ok(())); + } + + #[tokio::test] + async fn stress_channel() { + const NUM_RUNS: usize = 1_000; + const QUEUE_SIZE: usize = 10; + + let (s, mut r) = make_channel!(u32, QUEUE_SIZE); + let mut v = std::vec::Vec::new(); + + for i in 0..NUM_RUNS { + let mut s = s.clone(); + + v.push(tokio::spawn(async move { + s.send(i as _).await.unwrap(); + })); + } + + let mut map = std::collections::BTreeSet::new(); + + for _ in 0..NUM_RUNS { + map.insert(r.recv().await.unwrap()); + } + + assert_eq!(map.len(), NUM_RUNS); + + for v in v { + v.await.unwrap(); + } + } } diff --git a/rtic-channel/src/wait_queue.rs b/rtic-channel/src/wait_queue.rs index ba05e6bb75..e6d5a8b97e 100644 --- a/rtic-channel/src/wait_queue.rs +++ b/rtic-channel/src/wait_queue.rs @@ -3,7 +3,7 @@ use core::marker::PhantomPinned; use core::pin::Pin; use core::ptr::null_mut; -use core::sync::atomic::{AtomicPtr, Ordering}; +use core::sync::atomic::{AtomicBool, AtomicPtr, Ordering}; use core::task::Waker; use critical_section as cs; @@ -57,6 +57,7 @@ impl LinkedList { // Clear the pointers in the node. head_ref.next.store(null_mut(), Self::R); head_ref.prev.store(null_mut(), Self::R); + head_ref.is_poped.store(true, Self::R); return Some(head_val); } @@ -100,9 +101,12 @@ pub struct Link { pub(crate) val: T, next: AtomicPtr>, prev: AtomicPtr>, + is_poped: AtomicBool, _up: PhantomPinned, } +unsafe impl Send for Link {} + impl Link { const R: Ordering = Ordering::Relaxed; @@ -112,10 +116,15 @@ impl Link { val, next: AtomicPtr::new(null_mut()), prev: AtomicPtr::new(null_mut()), + is_poped: AtomicBool::new(false), _up: PhantomPinned, } } + pub fn is_poped(&self) -> bool { + self.is_poped.load(Self::R) + } + pub fn remove_from_list(&mut self, list: &LinkedList) { cs::with(|_| { // Make sure all previous writes are visible @@ -123,6 +132,7 @@ impl Link { let prev = self.prev.load(Self::R); let next = self.next.load(Self::R); + self.is_poped.store(true, Self::R); match unsafe { (prev.as_ref(), next.as_ref()) } { (None, None) => { @@ -217,7 +227,7 @@ mod tests { #[test] fn linked_list() { - let mut wq = LinkedList::::new(); + let wq = LinkedList::::new(); let mut i1 = Link::new(10); let mut i2 = Link::new(11);