diff --git a/macros/src/codegen/monotonic.rs b/macros/src/codegen/monotonic.rs index 685502edbb..8931dbbdbe 100644 --- a/macros/src/codegen/monotonic.rs +++ b/macros/src/codegen/monotonic.rs @@ -60,7 +60,6 @@ pub fn codegen(app: &App, _analysis: &Analysis, _extra: &Extra) -> TokenStream2 #[doc = #doc] #[allow(non_snake_case)] pub mod #m { - /// Read the current time from this monotonic pub fn now() -> ::Instant { rtic::export::interrupt::free(|_| { @@ -73,39 +72,13 @@ pub fn codegen(app: &App, _analysis: &Analysis, _extra: &Extra) -> TokenStream2 }) } - fn enqueue_waker( - instant: ::Instant, - waker: core::task::Waker - ) -> Result { - unsafe { - rtic::export::interrupt::free(|_| { - let marker = super::super::#tq_marker.get().read(); - super::super::#tq_marker.get_mut().write(marker.wrapping_add(1)); - - let nr = rtic::export::WakerNotReady { - waker, - instant, - marker, - }; - - let tq = &mut *super::super::#tq.get_mut(); - - tq.enqueue_waker( - nr, - || #enable_interrupt, - || #pend, - (&mut *super::super::#m_ident.get_mut()).as_mut()).map(|_| marker) - }) - } - } - /// Delay #[inline(always)] #[allow(non_snake_case)] pub fn delay(duration: ::Duration) -> DelayFuture { let until = now() + duration; - DelayFuture { until, tq_marker: None } + DelayFuture { until, waker_storage: None } } /// Delay future. @@ -113,11 +86,22 @@ pub fn codegen(app: &App, _analysis: &Analysis, _extra: &Extra) -> TokenStream2 #[allow(non_camel_case_types)] pub struct DelayFuture { until: ::Instant, - tq_marker: Option, + waker_storage: Option>>, + } + + impl Drop for DelayFuture { + fn drop(&mut self) { + if let Some(waker_storage) = &mut self.waker_storage { + rtic::export::interrupt::free(|_| unsafe { + let tq = &mut *super::super::#tq.get_mut(); + tq.cancel_waker_marker(waker_storage.val.marker); + }); + } + } } impl core::future::Future for DelayFuture { - type Output = Result<(), ()>; + type Output = (); fn poll( mut self: core::pin::Pin<&mut Self>, @@ -125,22 +109,33 @@ pub fn codegen(app: &App, _analysis: &Analysis, _extra: &Extra) -> TokenStream2 ) -> core::task::Poll { let mut s = self.as_mut(); let now = now(); + let until = s.until; + let is_ws_none = s.waker_storage.is_none(); - if now >= s.until { - core::task::Poll::Ready(Ok(())) - } else { - if s.tq_marker.is_some() { - core::task::Poll::Pending - } else { - match enqueue_waker(s.until, cx.waker().clone()) { - Ok(marker) => { - s.tq_marker = Some(marker); - core::task::Poll::Pending - }, - Err(()) => core::task::Poll::Ready(Err(())), - } - } + if now >= until { + return core::task::Poll::Ready(()); + } else if is_ws_none { + rtic::export::interrupt::free(|_| unsafe { + let marker = super::super::#tq_marker.get().read(); + super::super::#tq_marker.get_mut().write(marker.wrapping_add(1)); + + let nr = s.waker_storage.insert(rtic::export::IntrusiveNode::new(rtic::export::WakerNotReady { + waker: cx.waker().clone(), + instant: until, + marker, + })); + + let tq = &mut *super::super::#tq.get_mut(); + + tq.enqueue_waker( + core::mem::transmute(nr), // Transmute the reference to static + || #enable_interrupt, + || #pend, + (&mut *super::super::#m_ident.get_mut()).as_mut()); + }); } + + core::task::Poll::Pending } } @@ -150,7 +145,18 @@ pub fn codegen(app: &App, _analysis: &Analysis, _extra: &Extra) -> TokenStream2 pub struct TimeoutFuture { future: F, until: ::Instant, - tq_marker: Option, + waker_storage: Option>>, + } + + impl Drop for TimeoutFuture { + fn drop(&mut self) { + if let Some(waker_storage) = &mut self.waker_storage { + rtic::export::interrupt::free(|_| unsafe { + let tq = &mut *super::super::#tq.get_mut(); + tq.cancel_waker_marker(waker_storage.val.marker); + }); + } + } } /// Timeout after @@ -164,7 +170,7 @@ pub fn codegen(app: &App, _analysis: &Analysis, _extra: &Extra) -> TokenStream2 TimeoutFuture { future, until, - tq_marker: None, + waker_storage: None, } } @@ -178,7 +184,7 @@ pub fn codegen(app: &App, _analysis: &Analysis, _extra: &Extra) -> TokenStream2 TimeoutFuture { future, until: instant, - tq_marker: None, + waker_storage: None, } } @@ -186,46 +192,58 @@ pub fn codegen(app: &App, _analysis: &Analysis, _extra: &Extra) -> TokenStream2 where F: core::future::Future, { - type Output = Result, ()>; + type Output = Result; fn poll( self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_> ) -> core::task::Poll { - let now = now(); - // SAFETY: We don't move the underlying pinned value. let mut s = unsafe { self.get_unchecked_mut() }; let future = unsafe { core::pin::Pin::new_unchecked(&mut s.future) }; + let now = now(); + let until = s.until; + let is_ws_none = s.waker_storage.is_none(); match future.poll(cx) { core::task::Poll::Ready(r) => { - if let Some(marker) = s.tq_marker { + if let Some(waker_storage) = &mut s.waker_storage { rtic::export::interrupt::free(|_| unsafe { let tq = &mut *super::super::#tq.get_mut(); - tq.cancel_waker_marker(marker); + tq.cancel_waker_marker(waker_storage.val.marker); }); } - core::task::Poll::Ready(Ok(Ok(r))) + return core::task::Poll::Ready(Ok(r)); } core::task::Poll::Pending => { - if now >= s.until { + if now >= until { // Timeout - core::task::Poll::Ready(Ok(Err(super::TimeoutError))) - } else if s.tq_marker.is_none() { - match enqueue_waker(s.until, cx.waker().clone()) { - Ok(marker) => { - s.tq_marker = Some(marker); - core::task::Poll::Pending - }, - Err(()) => core::task::Poll::Ready(Err(())), // TQ full - } - } else { - core::task::Poll::Pending + return core::task::Poll::Ready(Err(super::TimeoutError)); + } else if is_ws_none { + rtic::export::interrupt::free(|_| unsafe { + let marker = super::super::#tq_marker.get().read(); + super::super::#tq_marker.get_mut().write(marker.wrapping_add(1)); + + let nr = s.waker_storage.insert(rtic::export::IntrusiveNode::new(rtic::export::WakerNotReady { + waker: cx.waker().clone(), + instant: until, + marker, + })); + + let tq = &mut *super::super::#tq.get_mut(); + + tq.enqueue_waker( + core::mem::transmute(nr), // Transmute the reference to static + || #enable_interrupt, + || #pend, + (&mut *super::super::#m_ident.get_mut()).as_mut()); + }); } } } + + core::task::Poll::Pending } } } diff --git a/macros/src/codegen/timer_queue.rs b/macros/src/codegen/timer_queue.rs index 513f78af8d..db6a9e3d28 100644 --- a/macros/src/codegen/timer_queue.rs +++ b/macros/src/codegen/timer_queue.rs @@ -67,13 +67,7 @@ pub fn codegen(app: &App, analysis: &Analysis, _extra: &Extra) -> Vec); + let tq_ty = quote!(rtic::export::TimerQueue<#mono_type, #t, #n_task>); // For future use // let doc = format!(" RTIC internal: {}:{}", file!(), line!()); @@ -84,7 +78,7 @@ pub fn codegen(app: &App, analysis: &Analysis, _extra: &Extra) -> Vec = rtic::RacyCell::new( rtic::export::TimerQueue { task_queue: rtic::export::SortedLinkedList::new_u16(), - waker_queue: rtic::export::SortedLinkedList::new_u16(), + waker_queue: rtic::export::IntrusiveSortedLinkedList::new(), } ); )); @@ -148,7 +142,7 @@ pub fn codegen(app: &App, analysis: &Analysis, _extra: &Extra) -> Vec Vec waker.wake(), - rtic::export::TaskOrWaker::Task((task, index)) => { - match task { - #(#arms)* - } - } + match task { + #(#arms)* } } diff --git a/src/export.rs b/src/export.rs index 9ef721f986..08101cdd49 100644 --- a/src/export.rs +++ b/src/export.rs @@ -1,11 +1,13 @@ #![allow(clippy::inline_always)] +pub use crate::{ + sll::{IntrusiveSortedLinkedList, Node as IntrusiveNode}, + tq::{TaskNotReady, TimerQueue, WakerNotReady}, +}; +pub use bare_metal::CriticalSection; use core::{ cell::Cell, sync::atomic::{AtomicBool, Ordering}, }; - -pub use crate::tq::{TaskNotReady, TaskOrWaker, TimerQueue, WakerNotReady}; -pub use bare_metal::CriticalSection; pub use cortex_m::{ asm::nop, asm::wfi, diff --git a/src/lib.rs b/src/lib.rs index 0c0d0cc7dc..1db6a2869e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,8 @@ pub mod mutex { #[doc(hidden)] pub mod export; #[doc(hidden)] +pub mod sll; +#[doc(hidden)] mod tq; /// Sets the given `interrupt` as pending diff --git a/src/sll.rs b/src/sll.rs new file mode 100644 index 0000000000..43b53c1749 --- /dev/null +++ b/src/sll.rs @@ -0,0 +1,421 @@ +//! An intrusive sorted priority linked list, designed for use in `Future`s in RTIC. +use core::cmp::Ordering; +use core::fmt; +use core::marker::PhantomData; +use core::ops::{Deref, DerefMut}; +use core::ptr::NonNull; + +/// Marker for Min sorted [`IntrusiveSortedLinkedList`]. +pub struct Min; + +/// Marker for Max sorted [`IntrusiveSortedLinkedList`]. +pub struct Max; + +/// The linked list kind: min-list or max-list +pub trait Kind: private::Sealed { + #[doc(hidden)] + fn ordering() -> Ordering; +} + +impl Kind for Min { + fn ordering() -> Ordering { + Ordering::Less + } +} + +impl Kind for Max { + fn ordering() -> Ordering { + Ordering::Greater + } +} + +/// Sealed traits +mod private { + pub trait Sealed {} +} + +impl private::Sealed for Max {} +impl private::Sealed for Min {} + +/// A node in the [`IntrusiveSortedLinkedList`]. +pub struct Node { + pub val: T, + next: Option>>, +} + +impl Node { + pub fn new(val: T) -> Self { + Self { val, next: None } + } +} + +/// The linked list. +pub struct IntrusiveSortedLinkedList<'a, T, K> { + head: Option>>, + _kind: PhantomData, + _lt: PhantomData<&'a ()>, +} + +impl<'a, T, K> fmt::Debug for IntrusiveSortedLinkedList<'a, T, K> +where + T: Ord + core::fmt::Debug, + K: Kind, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut l = f.debug_list(); + let mut current = self.head; + + while let Some(head) = current { + let head = unsafe { head.as_ref() }; + current = head.next; + + l.entry(&head.val); + } + + l.finish() + } +} + +impl<'a, T, K> IntrusiveSortedLinkedList<'a, T, K> +where + T: Ord, + K: Kind, +{ + pub const fn new() -> Self { + Self { + head: None, + _kind: PhantomData, + _lt: PhantomData, + } + } + + // Push to the list. + pub fn push(&mut self, new: &'a mut Node) { + unsafe { + if let Some(head) = self.head { + if head.as_ref().val.cmp(&new.val) != K::ordering() { + // This is newer than head, replace head + new.next = self.head; + self.head = Some(NonNull::new_unchecked(new)); + } else { + // It's not head, search the list for the correct placement + let mut current = head; + + while let Some(next) = current.as_ref().next { + if next.as_ref().val.cmp(&new.val) != K::ordering() { + break; + } + + current = next; + } + + new.next = current.as_ref().next; + current.as_mut().next = Some(NonNull::new_unchecked(new)); + } + } else { + // List is empty, place at head + self.head = Some(NonNull::new_unchecked(new)) + } + } + } + + /// Get an iterator over the sorted list. + pub fn iter(&self) -> Iter<'_, T, K> { + Iter { + _list: self, + index: self.head, + } + } + + /// Find an element in the list that can be changed and resorted. + pub fn find_mut(&mut self, mut f: F) -> Option> + where + F: FnMut(&T) -> bool, + { + let head = self.head?; + + // Special-case, first element + if f(&unsafe { head.as_ref() }.val) { + return Some(FindMut { + is_head: true, + prev_index: None, + index: self.head, + list: self, + maybe_changed: false, + }); + } + + let mut current = head; + + while let Some(next) = unsafe { current.as_ref() }.next { + if f(&unsafe { next.as_ref() }.val) { + return Some(FindMut { + is_head: false, + prev_index: Some(current), + index: Some(next), + list: self, + maybe_changed: false, + }); + } + + current = next; + } + + None + } + + /// Peek at the first element. + pub fn peek(&self) -> Option<&T> { + self.head.map(|head| unsafe { &head.as_ref().val }) + } + + /// Pops the first element in the list. + /// + /// Complexity is worst-case `O(1)`. + pub fn pop(&mut self) -> Option<&'a Node> { + if let Some(head) = self.head { + let v = unsafe { head.as_ref() }; + self.head = v.next; + Some(v) + } else { + None + } + } + + /// Checks if the linked list is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.head.is_none() + } +} + +/// Iterator for the linked list. +pub struct Iter<'a, T, K> +where + T: Ord, + K: Kind, +{ + _list: &'a IntrusiveSortedLinkedList<'a, T, K>, + index: Option>>, +} + +impl<'a, T, K> Iterator for Iter<'a, T, K> +where + T: Ord, + K: Kind, +{ + type Item = &'a T; + + fn next(&mut self) -> Option { + let index = self.index?; + + let node = unsafe { index.as_ref() }; + self.index = node.next; + + Some(&node.val) + } +} + +/// Comes from [`IntrusiveSortedLinkedList::find_mut`]. +pub struct FindMut<'a, 'b, T, K> +where + T: Ord + 'b, + K: Kind, +{ + list: &'a mut IntrusiveSortedLinkedList<'b, T, K>, + is_head: bool, + prev_index: Option>>, + index: Option>>, + maybe_changed: bool, +} + +impl<'a, 'b, T, K> FindMut<'a, 'b, T, K> +where + T: Ord, + K: Kind, +{ + unsafe fn pop_internal(&mut self) -> &'b mut Node { + if self.is_head { + // If it is the head element, we can do a normal pop + let mut head = self.list.head.unwrap_unchecked(); + let v = head.as_mut(); + self.list.head = v.next; + v + } else { + // Somewhere in the list + let mut prev = self.prev_index.unwrap_unchecked(); + let mut curr = self.index.unwrap_unchecked(); + + // Re-point the previous index + prev.as_mut().next = curr.as_ref().next; + + curr.as_mut() + } + } + + /// This will pop the element from the list. + /// + /// Complexity is worst-case `O(1)`. + #[inline] + pub fn pop(mut self) -> &'b mut Node { + unsafe { self.pop_internal() } + } + + /// This will resort the element into the correct position in the list if needed. The resorting + /// will only happen if the element has been accessed mutably. + /// + /// Same as calling `drop`. + /// + /// Complexity is worst-case `O(N)`. + #[inline] + pub fn finish(self) { + drop(self) + } +} + +impl<'b, T, K> Drop for FindMut<'_, 'b, T, K> +where + T: Ord + 'b, + K: Kind, +{ + fn drop(&mut self) { + // Only resort the list if the element has changed + if self.maybe_changed { + unsafe { + let val = self.pop_internal(); + self.list.push(val); + } + } + } +} + +impl Deref for FindMut<'_, '_, T, K> +where + T: Ord, + K: Kind, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &self.index.unwrap_unchecked().as_ref().val } + } +} + +impl DerefMut for FindMut<'_, '_, T, K> +where + T: Ord, + K: Kind, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + self.maybe_changed = true; + unsafe { &mut self.index.unwrap_unchecked().as_mut().val } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn const_new() { + static mut _V1: IntrusiveSortedLinkedList = IntrusiveSortedLinkedList::new(); + } + + #[test] + fn test_peek() { + let mut ll: IntrusiveSortedLinkedList = IntrusiveSortedLinkedList::new(); + + let mut a = Node { val: 1, next: None }; + ll.push(&mut a); + assert_eq!(ll.peek().unwrap(), &1); + + let mut a = Node { val: 2, next: None }; + ll.push(&mut a); + assert_eq!(ll.peek().unwrap(), &2); + + let mut a = Node { val: 3, next: None }; + ll.push(&mut a); + assert_eq!(ll.peek().unwrap(), &3); + + let mut ll: IntrusiveSortedLinkedList = IntrusiveSortedLinkedList::new(); + + let mut a = Node { val: 2, next: None }; + ll.push(&mut a); + assert_eq!(ll.peek().unwrap(), &2); + + let mut a = Node { val: 1, next: None }; + ll.push(&mut a); + assert_eq!(ll.peek().unwrap(), &1); + + let mut a = Node { val: 3, next: None }; + ll.push(&mut a); + assert_eq!(ll.peek().unwrap(), &1); + } + + #[test] + fn test_empty() { + let ll: IntrusiveSortedLinkedList = IntrusiveSortedLinkedList::new(); + + assert!(ll.is_empty()) + } + + #[test] + fn test_updating() { + let mut ll: IntrusiveSortedLinkedList = IntrusiveSortedLinkedList::new(); + + let mut a = Node { val: 1, next: None }; + ll.push(&mut a); + + let mut a = Node { val: 2, next: None }; + ll.push(&mut a); + + let mut a = Node { val: 3, next: None }; + ll.push(&mut a); + + let mut find = ll.find_mut(|v| *v == 2).unwrap(); + + *find += 1000; + find.finish(); + + assert_eq!(ll.peek().unwrap(), &1002); + + let mut find = ll.find_mut(|v| *v == 3).unwrap(); + + *find += 1000; + find.finish(); + + assert_eq!(ll.peek().unwrap(), &1003); + + // Remove largest element + ll.find_mut(|v| *v == 1003).unwrap().pop(); + + assert_eq!(ll.peek().unwrap(), &1002); + } + + #[test] + fn test_updating_1() { + let mut ll: IntrusiveSortedLinkedList = IntrusiveSortedLinkedList::new(); + + let mut a = Node { val: 1, next: None }; + ll.push(&mut a); + + let v = ll.pop().unwrap(); + + assert_eq!(v.val, 1); + } + + #[test] + fn test_updating_2() { + let mut ll: IntrusiveSortedLinkedList = IntrusiveSortedLinkedList::new(); + + let mut a = Node { val: 1, next: None }; + ll.push(&mut a); + + let mut find = ll.find_mut(|v| *v == 1).unwrap(); + + *find += 1000; + find.finish(); + + assert_eq!(ll.peek().unwrap(), &1001); + } +} diff --git a/src/tq.rs b/src/tq.rs index 90542e7308..ed4016eced 100644 --- a/src/tq.rs +++ b/src/tq.rs @@ -1,20 +1,23 @@ -use crate::Monotonic; +use crate::{ + sll::{IntrusiveSortedLinkedList, Min as IsslMin, Node as IntrusiveNode}, + Monotonic, +}; use core::cmp::Ordering; use core::task::Waker; -use heapless::sorted_linked_list::{LinkedIndexU16, Min, SortedLinkedList}; +use heapless::sorted_linked_list::{LinkedIndexU16, Min as SllMin, SortedLinkedList}; -pub struct TimerQueue +pub struct TimerQueue<'a, Mono, Task, const N_TASK: usize> where Mono: Monotonic, Task: Copy, { - pub task_queue: SortedLinkedList, LinkedIndexU16, Min, N_TASK>, - pub waker_queue: SortedLinkedList, LinkedIndexU16, Min, N_WAKER>, + pub task_queue: SortedLinkedList, LinkedIndexU16, SllMin, N_TASK>, + pub waker_queue: IntrusiveSortedLinkedList<'a, WakerNotReady, IsslMin>, } -impl TimerQueue +impl<'a, Mono, Task, const N_TASK: usize> TimerQueue<'a, Mono, Task, N_TASK> where - Mono: Monotonic, + Mono: Monotonic + 'a, Task: Copy, { fn check_if_enable( @@ -70,17 +73,16 @@ where #[inline] pub fn enqueue_waker( &mut self, - nr: WakerNotReady, + nr: &'a mut IntrusiveNode>, enable_interrupt: F1, pend_handler: F2, mono: Option<&mut Mono>, - ) -> Result<(), ()> - where + ) where F1: FnOnce(), F2: FnOnce(), { - self.check_if_enable(nr.instant, enable_interrupt, pend_handler, mono); - self.waker_queue.push(nr).map_err(|_| ()) + self.check_if_enable(nr.val.instant, enable_interrupt, pend_handler, mono); + self.waker_queue.push(nr); } /// Check if all the timer queue is empty. @@ -133,12 +135,12 @@ where &mut self, instant: Mono::Instant, mono: &mut Mono, - ) -> Option> { + ) -> Option<(Task, u8)> { let now = mono.now(); if instant <= now { // task became ready let nr = unsafe { self.task_queue.pop_unchecked() }; - Some(TaskOrWaker::Task((nr.task, nr.index))) + Some((nr.task, nr.index)) } else { // Set compare mono.set_compare(instant); @@ -149,23 +151,18 @@ where // guard against this. if instant <= now { let nr = unsafe { self.task_queue.pop_unchecked() }; - Some(TaskOrWaker::Task((nr.task, nr.index))) + Some((nr.task, nr.index)) } else { None } } } - fn dequeue_waker_queue( - &mut self, - instant: Mono::Instant, - mono: &mut Mono, - ) -> Option> { + fn dequeue_waker_queue(&mut self, instant: Mono::Instant, mono: &mut Mono) { let now = mono.now(); if instant <= now { - // task became ready - let nr = unsafe { self.waker_queue.pop_unchecked() }; - Some(TaskOrWaker::Waker(nr.waker)) + // Task became ready, wake the waker + self.waker_queue.pop().map(|v| v.val.waker.wake_by_ref()); } else { // Set compare mono.set_compare(instant); @@ -175,16 +172,13 @@ where // read of now to the set of the compare, the time can overflow. This is to // guard against this. if instant <= now { - let nr = unsafe { self.waker_queue.pop_unchecked() }; - Some(TaskOrWaker::Waker(nr.waker)) - } else { - None + self.waker_queue.pop().map(|v| v.val.waker.wake_by_ref()); } } } /// Dequeue a task from the ``TimerQueue`` - pub fn dequeue(&mut self, disable_interrupt: F, mono: &mut Mono) -> Option> + pub fn dequeue(&mut self, disable_interrupt: F, mono: &mut Mono) -> Option<(Task, u8)> where F: FnOnce(), { @@ -228,16 +222,12 @@ where if dequeue_task { self.dequeue_task_queue(instant, mono) } else { - self.dequeue_waker_queue(instant, mono) + self.dequeue_waker_queue(instant, mono); + None } } } -pub enum TaskOrWaker { - Task((Task, u8)), - Waker(Waker), -} - pub struct TaskNotReady where Task: Copy,