From 51204cba42e035ffeeb1ca08a3a595bbebf3f3bc Mon Sep 17 00:00:00 2001 From: Eivind Alexander Bergem Date: Thu, 21 Apr 2022 14:22:49 +0200 Subject: [PATCH 1/8] Re-factored executor and tasks for better ergonomics - Added `TaskHandle` as a non-generic reference to `Task`. - Use array to initialize executor with tasks, avoiding macro. - Wrap `Executor` and `Scheduler` in handlers requiring static references. This way we don't need unsafe when calling from the handler. --- hyperloop-macros/src/lib.rs | 69 +------- hyperloop-priority-queue/src/lib.rs | 10 +- hyperloop/Cargo.toml | 2 +- hyperloop/src/executor.rs | 239 ++++++++++++++++++++++++---- hyperloop/src/interrupt.rs | 4 +- hyperloop/src/lib.rs | 1 + hyperloop/src/notify.rs | 31 ++-- hyperloop/src/task.rs | 195 +++++------------------ hyperloop/src/timer.rs | 124 ++++++++------- 9 files changed, 338 insertions(+), 337 deletions(-) diff --git a/hyperloop-macros/src/lib.rs b/hyperloop-macros/src/lib.rs index d2ef941..936ee9c 100644 --- a/hyperloop-macros/src/lib.rs +++ b/hyperloop-macros/src/lib.rs @@ -5,12 +5,10 @@ use darling::FromMeta; use proc_macro::{self, TokenStream}; use quote::{format_ident, quote}; use syn::{ - parse::Parse, - parse_quote, punctuated::{Pair, Punctuated}, spanned::Spanned, token::Comma, - FnArg, Ident, Pat, Stmt, Token, + FnArg, Ident, Pat, }; #[derive(Debug, FromMeta)] @@ -69,7 +67,7 @@ pub fn task(args: TokenStream, item: TokenStream) -> TokenStream { let result = quote! { #(#attrs)* - #visibility fn #name(#args) -> Option<&'static mut crate::task::Task<#future_type>> { + #visibility fn #name(#args) -> Option { type F = #future_type; fn wrapper(#args) -> impl FnOnce() -> F { @@ -84,7 +82,7 @@ pub fn task(args: TokenStream, item: TokenStream) -> TokenStream { unsafe { if let None = TASK { TASK = Some(Task::new(wrapper(#arg_values), #priority)); - Some(TASK.as_mut().unwrap()) + Some(TASK.as_mut().unwrap().get_handle()) } else { None } @@ -93,64 +91,3 @@ pub fn task(args: TokenStream, item: TokenStream) -> TokenStream { }; result.into() } - -struct Args { - args: Punctuated, -} - -impl Parse for Args { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - match Punctuated::::parse_terminated(&input) { - Ok(args) => Ok(Self { args }), - Err(err) => Err(err), - } - } -} - -struct Statements { - data: Vec, -} - -impl quote::ToTokens for Statements { - fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { - for stmt in self.data.iter() { - stmt.to_tokens(tokens); - } - } -} - -#[proc_macro] -pub fn executor_from_tasks(tokens: TokenStream) -> TokenStream { - let args = syn::parse_macro_input!(tokens as Args).args; - - let n_tasks = args.len(); - - let tasks = Statements { - data: args - .pairs() - .map(|pair| { - let task = pair.into_value(); - let stmt: Stmt = parse_quote!( - #task.add_to_executor(executor.get_sender()).unwrap(); - ); - stmt - }) - .collect(), - }; - - let result = quote! { - { - static mut EXECUTOR: Option> = None; - - let executor = unsafe { - EXECUTOR.get_or_insert(Executor::new()) - }; - - #tasks - - executor - } - }; - - result.into() -} diff --git a/hyperloop-priority-queue/src/lib.rs b/hyperloop-priority-queue/src/lib.rs index 6f81c94..c111904 100644 --- a/hyperloop-priority-queue/src/lib.rs +++ b/hyperloop-priority-queue/src/lib.rs @@ -326,8 +326,8 @@ where } } - pub fn get_sender(&self) -> PrioritySender { - let queue: &'static Self = unsafe { &*(self as *const Self) }; + pub unsafe fn get_sender(&self) -> PrioritySender { + let queue: &'static Self = &*(self as *const Self); PrioritySender { slots: &queue.slots, @@ -540,7 +540,7 @@ mod tests { #[test] fn stack() { let mut heap: PriorityQueue = PriorityQueue::new(); - let sender = heap.get_sender(); + let sender = unsafe { heap.get_sender() }; for i in 0..10 { sender.stack_push(i).unwrap(); @@ -568,7 +568,7 @@ mod tests { #[test] fn channel() { let mut queue: PriorityQueue = PriorityQueue::new(); - let sender = queue.get_sender(); + let sender = unsafe { queue.get_sender() }; for i in 0..10 { sender.send(i).unwrap(); @@ -619,7 +619,7 @@ mod tests { let n_items = n_threads * n_items_per_thread; for i in 0..n_threads { - let sender = queue.get_sender(); + let sender = unsafe { queue.get_sender() }; let handler = thread::spawn(move || { for j in 0..n_items_per_thread { loop { diff --git a/hyperloop/Cargo.toml b/hyperloop/Cargo.toml index 05e0477..06ff151 100644 --- a/hyperloop/Cargo.toml +++ b/hyperloop/Cargo.toml @@ -2,7 +2,7 @@ name = "hyperloop" version = "0.1.0" authors = ["Eivind Alexander Bergem "] -edition = "2018" +edition = "2021" [dependencies] futures = {version = "0.3.15", default-features = false} diff --git a/hyperloop/src/executor.rs b/hyperloop/src/executor.rs index 095d500..d79ee6f 100644 --- a/hyperloop/src/executor.rs +++ b/hyperloop/src/executor.rs @@ -1,25 +1,23 @@ use core::cmp::Ordering; +use core::sync::atomic::AtomicBool; +use core::task::{Poll, RawWaker, RawWakerVTable, Waker}; use crate::priority_queue::{Max, PriorityQueue, PrioritySender}; -use crate::task::PollTask; +use crate::task::TaskHandle; pub(crate) type Priority = u8; +type TaskId = u16; -#[derive(Debug, Clone, Copy)] pub struct Ticket { - task: *const dyn PollTask, + task: TaskId, priority: Priority, } impl Ticket { - pub(crate) fn new(task: *const dyn PollTask, priority: Priority) -> Self { + pub(crate) fn new(task: TaskId, priority: Priority) -> Self { Self { task, priority } } - - unsafe fn get_task(&self) -> &dyn PollTask { - &*self.task - } } impl PartialEq for Ticket { @@ -44,17 +42,134 @@ impl Ord for Ticket { pub(crate) type TaskSender = PrioritySender; +const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake, drop); + +unsafe fn clone(ptr: *const ()) -> RawWaker { + RawWaker::new(ptr, &VTABLE) +} + +unsafe fn wake(ptr: *const ()) { + let task = &*(ptr as *const ExecutorTask); + task.wake(); +} + +unsafe fn drop(_ptr: *const ()) {} + +struct ExecutorTask { + task: TaskHandle, + task_id: TaskId, + priority: Priority, + sender: Option, + pending_wake: AtomicBool, +} + +impl ExecutorTask { + fn new( + task: TaskHandle, + task_id: TaskId, + priority: Priority, + sender: Option, + ) -> Self { + Self { + task, + task_id, + priority, + sender, + pending_wake: AtomicBool::new(false), + } + } + + fn set_sender(&mut self, sender: TaskSender) { + self.sender = Some(sender); + } + + fn get_waker(&self) -> Waker { + let ptr: *const () = (self as *const ExecutorTask).cast(); + let vtable = &VTABLE; + + unsafe { Waker::from_raw(RawWaker::new(ptr, vtable)) } + } + + fn send_ticket(&self, ticket: Ticket) -> Result<(), ()> { + let sender = unsafe { self.sender.as_ref().unwrap_unchecked() }; + + sender.send(ticket).map_err(|_| ()) + } + + fn wake(&self) { + if let Ok(_) = self.pending_wake.compare_exchange( + false, + true, + atomig::Ordering::Acquire, + atomig::Ordering::Acquire, + ) { + let ticket = Ticket::new(self.task_id, self.priority); + + self.send_ticket(ticket).unwrap_or_else(|_| unreachable!()); + } + } + + fn clear_pending_wake_flag(&self) { + let _ = self.pending_wake.compare_exchange( + true, + false, + atomig::Ordering::Acquire, + atomig::Ordering::Acquire, + ); + } + + fn poll(&mut self, waker: Waker) -> Poll<()> { + self.task.poll(waker) + } +} + pub struct Executor { + tasks: [ExecutorTask; N], queue: PriorityQueue, } impl Executor { - pub fn new() -> Self { + pub fn new(tasks: [TaskHandle; N]) -> Self { + let mut i = 0; + let tasks = tasks.map(|task| { + let priority = task.priority; + let task = ExecutorTask::new(task, i, priority, None); + i += 1; + task + }); + Self { + tasks, queue: PriorityQueue::new(), } } + unsafe fn get_task(&mut self, task_id: TaskId) -> &mut ExecutorTask { + let index = task_id as usize; + + let task = &mut self.tasks[index]; + task.clear_pending_wake_flag(); + + task + } + + unsafe fn init(&mut self) { + for i in 0..N { + let sender = self.queue.get_sender(); + let task = self.get_task(i as u16); + + task.set_sender(sender); + task.wake(); + } + } + + unsafe fn poll_task(&mut self, task_id: TaskId) { + let task = self.get_task(task_id); + let waker = task.get_waker(); + + let _ = task.poll(waker); + } + /// Poll all tasks in the queue /// /// # Safety @@ -64,29 +179,46 @@ impl Executor { /// pointers to the tasks stored in the executor. The pointers can /// be dereferenced at any time and will be dangling if the /// exeutor is moved or dropped. - pub unsafe fn poll_tasks(&mut self) { + unsafe fn poll_tasks(&mut self) { while let Some(ticket) = self.queue.pop() { - let _ = ticket.get_task().poll(); + self.poll_task(ticket.task); } } - pub fn get_sender(&self) -> TaskSender { - self.queue.get_sender() + pub fn get_handle(&'static mut self) -> ExecutorHandle { + ExecutorHandle::new(self) + } +} + +pub struct ExecutorHandle { + executor: &'static mut Executor, +} + +impl ExecutorHandle { + pub fn new(executor: &'static mut Executor) -> Self { + unsafe { executor.init() }; + Self { executor } + } + + /// Poll all tasks in the queue + pub fn poll_tasks(&mut self) { + unsafe { self.executor.poll_tasks() } } } #[cfg(test)] mod tests { use crossbeam_queue::ArrayQueue; - use hyperloop_macros::{executor_from_tasks, task}; + use hyperloop_macros::task; + use std::boxed::Box; use std::sync::Arc; use super::*; + use crate::notify::Notification; use crate::task::Task; #[test] fn test_executor() { - let mut executor = Executor::<10>::new(); let queue = Arc::new(ArrayQueue::new(10)); let test_future = |queue, value| { @@ -99,19 +231,19 @@ mod tests { } }; - let task1 = Task::new(test_future(queue.clone(), 1), 1); - let task2 = Task::new(test_future(queue.clone(), 2), 3); - let task3 = Task::new(test_future(queue.clone(), 3), 2); - let task4 = Task::new(test_future(queue.clone(), 4), 4); + let task1 = Box::leak(Box::new(Task::new(test_future(queue.clone(), 1), 1))); + let task2 = Box::leak(Box::new(Task::new(test_future(queue.clone(), 2), 3))); + let task3 = Box::leak(Box::new(Task::new(test_future(queue.clone(), 3), 2))); + let task4 = Box::leak(Box::new(Task::new(test_future(queue.clone(), 4), 4))); - task1.add_to_executor(executor.get_sender()).unwrap(); - task2.add_to_executor(executor.get_sender()).unwrap(); - task3.add_to_executor(executor.get_sender()).unwrap(); - task4.add_to_executor(executor.get_sender()).unwrap(); + let mut executor = ExecutorHandle::new(Box::leak(Box::new(Executor::new([ + task1.get_handle(), + task2.get_handle(), + task3.get_handle(), + task4.get_handle(), + ])))); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop().unwrap(), 4); assert_eq!(queue.pop().unwrap(), 2); @@ -119,6 +251,55 @@ mod tests { assert_eq!(queue.pop().unwrap(), 1); } + #[test] + fn test_pending_wake() { + let queue = Arc::new(ArrayQueue::new(10)); + let notify = Box::leak(Box::new(Notification::new())); + + let test_future = |queue, notify| { + move || { + async fn future(queue: Arc>, notify: &'static Notification) { + for i in 0..10 { + queue.push(i).unwrap(); + notify.wait().await; + } + } + + future(queue, notify) + } + }; + + let task = Box::leak(Box::new(Task::new(test_future(queue.clone(), notify), 1))); + + let mut executor = + ExecutorHandle::new(Box::leak(Box::new(Executor::new([task.get_handle()])))); + + executor.poll_tasks(); + + assert_eq!(queue.pop().unwrap(), 0); + assert!(queue.pop().is_none()); + + notify.notify(); + + executor.poll_tasks(); + + assert_eq!(queue.pop().unwrap(), 1); + assert!(queue.pop().is_none()); + + executor.poll_tasks(); + assert!(queue.pop().is_none()); + + let waker = unsafe { executor.executor.get_task(0).get_waker() }; + + waker.wake(); + + notify.notify(); + executor.poll_tasks(); + + assert_eq!(queue.pop().unwrap(), 2); + assert!(queue.pop().is_none()); + } + #[test] fn macros() { #[task(priority = 1)] @@ -136,11 +317,9 @@ mod tests { let task1 = test_task1(queue.clone()).unwrap(); let task2 = test_task2(queue.clone()).unwrap(); - let executor = executor_from_tasks!(task1, task2); + let mut executor = ExecutorHandle::new(Box::leak(Box::new(Executor::new([task1, task2])))); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop().unwrap(), 2); assert_eq!(queue.pop().unwrap(), 1); diff --git a/hyperloop/src/interrupt.rs b/hyperloop/src/interrupt.rs index 6468c15..a1cc9f3 100644 --- a/hyperloop/src/interrupt.rs +++ b/hyperloop/src/interrupt.rs @@ -29,9 +29,11 @@ impl YieldFuture { impl Future for YieldFuture { type Output = (); - fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if !self.done { self.done = true; + cx.waker().wake_by_ref(); + Poll::Pending } else { Poll::Ready(()) diff --git a/hyperloop/src/lib.rs b/hyperloop/src/lib.rs index 0d1ab5a..f55f9d5 100644 --- a/hyperloop/src/lib.rs +++ b/hyperloop/src/lib.rs @@ -2,6 +2,7 @@ #![feature(const_fn_trait_bound)] #![feature(type_alias_impl_trait)] #![feature(once_cell)] +#![feature(option_result_unwrap_unchecked)] mod common; diff --git a/hyperloop/src/notify.rs b/hyperloop/src/notify.rs index 98575bd..8444ddf 100644 --- a/hyperloop/src/notify.rs +++ b/hyperloop/src/notify.rs @@ -64,7 +64,10 @@ mod tests { use std::boxed::Box; use std::sync::Arc; - use crate::{executor::Executor, task::Task}; + use crate::{ + executor::{Executor, ExecutorHandle}, + task::Task, + }; use super::*; @@ -72,7 +75,6 @@ mod tests { fn notify() { let notification = Box::leak(Box::new(Notification::new())); - let mut executor = Executor::<10>::new(); let queue = Arc::new(ArrayQueue::new(10)); let wait = |receiver, queue| { @@ -88,43 +90,34 @@ mod tests { } }; - let task1 = Task::new(wait(notification, queue.clone()), 1); + let task = Box::leak(Box::new(Task::new(wait(notification, queue.clone()), 1))); - task1.add_to_executor(executor.get_sender()).unwrap(); + let mut executor = + ExecutorHandle::new(Box::leak(Box::new(Executor::new([task.get_handle()])))); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(1)); assert_eq!(queue.pop(), None); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), None); notification.notify(); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(2)); assert_eq!(queue.pop(), None); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), None); notification.notify(); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(3)); assert_eq!(queue.pop(), None); diff --git a/hyperloop/src/task.rs b/hyperloop/src/task.rs index 7217c7e..bfffb08 100644 --- a/hyperloop/src/task.rs +++ b/hyperloop/src/task.rs @@ -1,37 +1,11 @@ use core::{ - lazy::OnceCell, pin::Pin, - task::{Context, Poll, RawWaker, RawWakerVTable, Waker}, + task::{Context, Poll, Waker}, }; -use atomig::{Atom, Atomic, Ordering}; use futures::Future; -use crate::executor::{Priority, TaskSender, Ticket}; - -unsafe fn clone + 'static>(ptr: *const ()) -> RawWaker { - let task = &*(ptr as *const Task); - - RawWaker::new(ptr, &task.vtable) -} - -unsafe fn wake + 'static>(ptr: *const ()) { - let task = &*(ptr as *const Task); - task.wake(); -} - -unsafe fn drop(_ptr: *const ()) {} - -pub(crate) trait PollTask { - unsafe fn poll(&self) -> Poll<()>; -} - -#[repr(u8)] -#[derive(Copy, Clone, Eq, PartialEq, Debug, Atom)] -enum TaskState { - NotQueued, - Queued, -} +use crate::executor::Priority; pub struct Task where @@ -39,9 +13,6 @@ where { future: F, priority: Priority, - sender: OnceCell, - vtable: RawWakerVTable, - state: Atomic, } impl Task @@ -52,165 +23,81 @@ where Self { future: future_fn(), priority, - sender: OnceCell::new(), - vtable: RawWakerVTable::new(clone::, wake::, wake::, drop), - state: Atomic::new(TaskState::NotQueued), - } - } - - fn update_state(&self, old: TaskState, new: TaskState) -> bool { - if let Ok(_) = self - .state - .compare_exchange(old, new, Ordering::Relaxed, Ordering::Relaxed) - { - true - } else { - false - } - } - - #[cfg(test)] - fn get_state(&self) -> TaskState { - self.state.load(Ordering::Relaxed) - } - - unsafe fn as_static(&self) -> &'static Self { - &*(self as *const Self) - } - - unsafe fn as_mut(&self) -> &mut Self { - &mut *((self as *const Self) as *mut Self) - } - - unsafe fn get_waker(&self) -> Waker { - let ptr: *const () = (self as *const Task).cast(); - let vtable = &self.as_static().vtable; - - Waker::from_raw(RawWaker::new(ptr, vtable)) - } - - pub fn wake(&self) { - self.schedule().unwrap(); - } - - pub fn add_to_executor(&self, sender: TaskSender) -> Result<(), ()> { - self.set_sender(sender)?; - self.schedule() - } - - fn set_sender(&self, sender: TaskSender) -> Result<(), ()> { - match self.sender.set(sender) { - Ok(_) => Ok(()), - Err(_) => Err(()), } } - fn send_ticket(&self, ticket: Ticket) -> Result<(), ()> { - if let Some(sender) = self.sender.get() { - if let Ok(_) = sender.send(ticket) { - return Ok(()); - } - } - - Err(()) + pub fn get_handle(&'static mut self) -> TaskHandle { + TaskHandle::new(self) } +} - fn schedule(&self) -> Result<(), ()> { - if self.update_state(TaskState::NotQueued, TaskState::Queued) { - let ticket = Ticket::new(self as *const Self, self.priority); +pub struct TaskHandle { + future: *mut (), + poll: fn(*mut (), &mut Context<'_>) -> Poll<()>, + pub priority: Priority, +} - match self.send_ticket(ticket) { - Ok(_) => Ok(()), - Err(_) => { - assert!(self.update_state(TaskState::Queued, TaskState::NotQueued)); - Err(()) - } +impl TaskHandle { + pub fn new>(task: &'static mut Task) -> Self { + unsafe { + Self { + future: core::mem::transmute::<_, _>(&mut task.future), + poll: core::mem::transmute::< + fn(Pin<&mut F>, &mut Context<'_>) -> Poll, + fn(*mut (), &mut Context<'_>) -> Poll<()>, + >(F::poll), + priority: task.priority, } - } else { - Ok(()) } } -} -impl PollTask for Task -where - F: Future + 'static, -{ - unsafe fn poll(&self) -> Poll<()> { - let waker = self.get_waker(); + pub fn poll(&mut self, waker: Waker) -> Poll<()> { let mut cx = Context::from_waker(&waker); - let future = Pin::new_unchecked(&mut self.as_mut().future); - assert!(self.update_state(TaskState::Queued, TaskState::NotQueued)); - let result = future.poll(&mut cx); + let poll = self.poll; - result + poll(self.future, &mut cx) } } #[cfg(test)] mod tests { - use crate::{ - interrupt::yield_now, - priority_queue::{Max, PriorityQueue}, - }; + use crossbeam_queue::ArrayQueue; + use std::boxed::Box; + use std::sync::Arc; + + use crate::{common::tests::MockWaker, interrupt::yield_now}; use super::*; #[test] fn task() { - let mut queue: PriorityQueue = PriorityQueue::new(); + let queue = Arc::new(ArrayQueue::new(10)); - let test_future = || { + let test_future = |queue| { || { - async fn future() { - loop { - yield_now().await + async fn future(queue: Arc>) { + for i in 0.. { + queue.push(i).unwrap(); + yield_now().await; } } - future() + future(queue) } }; - let task = Task::new(test_future(), 1); - - task.set_sender(queue.get_sender()).unwrap(); + let mut task = Box::leak(Box::new(Task::new(test_future(queue.clone()), 1))).get_handle(); + let waker: Waker = Arc::new(MockWaker::new()).into(); - assert_eq!(task.get_state(), TaskState::NotQueued); + assert_eq!(task.poll(waker.clone()), Poll::Pending); - task.schedule().unwrap(); - - assert_eq!(task.get_state(), TaskState::Queued); - - assert!(queue.pop().is_some()); - assert!(queue.pop().is_none()); - - task.schedule().unwrap(); - - assert!(queue.pop().is_none()); - - unsafe { - assert_eq!(task.poll(), Poll::Pending); - } - - assert_eq!(task.get_state(), TaskState::NotQueued); - - task.wake(); - task.wake(); - task.wake(); - - assert_eq!(task.get_state(), TaskState::Queued); - - assert!(queue.pop().is_some()); + assert_eq!(queue.pop().unwrap(), 0); assert!(queue.pop().is_none()); - task.wake(); - task.wake(); - task.wake(); - - assert_eq!(task.get_state(), TaskState::Queued); + assert_eq!(task.poll(waker.clone()), Poll::Pending); + assert_eq!(queue.pop().unwrap(), 1); assert!(queue.pop().is_none()); } } diff --git a/hyperloop/src/timer.rs b/hyperloop/src/timer.rs index ec9c8fb..ddf5b06 100644 --- a/hyperloop/src/timer.rs +++ b/hyperloop/src/timer.rs @@ -302,7 +302,7 @@ impl Scheduler { } } - pub fn get_timer(&self) -> Timer { + unsafe fn get_timer(&self) -> Timer { Timer::new(self.rate, self.counter.clone(), self.queue.get_sender()) } @@ -316,7 +316,7 @@ impl Scheduler { None } - pub async fn task(&mut self) { + async fn task(&mut self) { let mut timer = TimerFuture::new(self.counter.clone()); loop { @@ -327,6 +327,28 @@ impl Scheduler { } } } + + pub fn get_handle(&'static mut self) -> SchedulerHandle { + SchedulerHandle::new(self) + } +} + +pub struct SchedulerHandle { + scheduler: &'static mut Scheduler, +} + +impl SchedulerHandle { + pub fn new(scheduler: &'static mut Scheduler) -> Self { + Self { scheduler } + } + + pub fn get_timer(&self) -> Timer { + unsafe { self.scheduler.get_timer() } + } + + pub fn into_task(self) -> impl Future { + self.scheduler.task() + } } #[cfg(test)] @@ -377,9 +399,11 @@ mod tests { fn delay() { let counter = Box::leak(Box::new(TickCounter::new())); let token = counter.get_token(); - let scheduler: &'static mut Scheduler<10> = - Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))); - let sender = scheduler.queue.get_sender(); + let scheduler: SchedulerHandle<10> = + Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))).get_handle(); + + let queue = &mut scheduler.scheduler.queue; + let sender = unsafe { queue.get_sender() }; let mockwaker = Arc::new(MockWaker::new()); let waker: Waker = mockwaker.clone().into(); let mut cx = Context::from_waker(&waker); @@ -406,17 +430,17 @@ mod tests { assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); - if let Some(ticket) = scheduler.queue.pop() { + if let Some(ticket) = queue.pop() { assert_eq!(ticket.expires, 1); ticket.waker.wake(); assert_eq!(mockwaker.woke.load(Ordering::Relaxed), true) } - if let Some(ticket) = scheduler.queue.pop() { + if let Some(ticket) = queue.pop() { assert_eq!(ticket.expires, 15); } - if let Some(ticket) = scheduler.queue.pop() { + if let Some(ticket) = queue.pop() { assert_eq!(ticket.expires, 20); } } @@ -425,11 +449,9 @@ mod tests { fn timer() { let counter = Box::leak(Box::new(TickCounter::new())); let token = counter.get_token(); - let scheduler: &'static mut Scheduler<10> = - Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))); + let scheduler: SchedulerHandle<10> = + Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))).get_handle(); let timer = scheduler.get_timer(); - let mut executor = Box::new(Executor::<10>::new()); - let queue = Arc::new(ArrayQueue::new(10)); log_init(); @@ -463,15 +485,18 @@ mod tests { } }; - let task1 = Task::new(move || scheduler.task(), 1); - let task2 = Task::new(test_future(queue.clone(), timer.clone()), 1); + let queue = Arc::new(ArrayQueue::new(10)); - task1.add_to_executor(executor.get_sender()).unwrap(); - task2.add_to_executor(executor.get_sender()).unwrap(); + let task1 = Box::leak(Box::new(Task::new(move || scheduler.into_task(), 1))).get_handle(); + let task2 = Box::leak(Box::new(Task::new( + test_future(queue.clone(), timer.clone()), + 1, + ))) + .get_handle(); - unsafe { - executor.poll_tasks(); - } + let mut executor = Box::leak(Box::new(Executor::new([task1, task2]))).get_handle(); + + executor.poll_tasks(); assert_eq!(queue.pop(), Some(1)); assert_eq!(queue.pop(), None); @@ -479,17 +504,13 @@ mod tests { unsafe { counter.tick(); } - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(2)); assert_eq!(queue.pop(), None); counter.wake(); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), None); @@ -499,9 +520,7 @@ mod tests { unsafe { counter.tick(); } - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(3)); assert_eq!(queue.pop(), None); @@ -512,9 +531,7 @@ mod tests { unsafe { counter.tick(); } - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(4)); assert_eq!(queue.pop(), None); @@ -525,9 +542,7 @@ mod tests { unsafe { counter.tick(); } - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(5)); assert_eq!(queue.pop(), None); @@ -536,9 +551,7 @@ mod tests { unsafe { counter.tick(); } - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), None); } @@ -546,9 +559,7 @@ mod tests { unsafe { counter.tick(); } - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(6)); assert_eq!(queue.pop(), None); @@ -558,11 +569,9 @@ mod tests { fn timeout() { let counter = Box::leak(Box::new(TickCounter::new())); let token = counter.get_token(); - let scheduler: &'static mut Scheduler<10> = - Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))); + let scheduler: SchedulerHandle<10> = + Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))).get_handle(); let timer = scheduler.get_timer(); - let mut executor = Executor::<10>::new(); - let queue = Arc::new(ArrayQueue::new(10)); log_init(); @@ -595,15 +604,15 @@ mod tests { } }; - let task1 = Task::new(move || scheduler.task(), 1); - let task2 = Task::new(waiting_future(queue.clone(), timer), 1); + let queue = Arc::new(ArrayQueue::new(10)); + + let task1 = Box::leak(Box::new(Task::new(move || scheduler.into_task(), 1))).get_handle(); + let task2 = + Box::leak(Box::new(Task::new(waiting_future(queue.clone(), timer), 1))).get_handle(); - task1.add_to_executor(executor.get_sender()).unwrap(); - task2.add_to_executor(executor.get_sender()).unwrap(); + let mut executor = Box::leak(Box::new(Executor::new([task1, task2]))).get_handle(); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(1)); assert_eq!(queue.pop(), None); @@ -616,9 +625,7 @@ mod tests { counter.wake(); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(2)); assert_eq!(queue.pop(), None); @@ -629,10 +636,7 @@ mod tests { } counter.wake(); - unsafe { - executor.poll_tasks(); - } - + executor.poll_tasks(); assert_eq!(queue.pop(), None); } @@ -641,9 +645,7 @@ mod tests { } counter.wake(); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(3)); assert_eq!(queue.pop(), None); From 85285a82874b261210a58d3efd7761ebf8a1d972 Mon Sep 17 00:00:00 2001 From: Eivind Alexander Bergem Date: Fri, 22 Apr 2022 08:38:43 +0200 Subject: [PATCH 2/8] Removed timer task. - Removed async timer task to avoid having N + 1 tasks. - Moved wake logic into interrupt handler. --- hyperloop/src/timer.rs | 296 +++++++++++++---------------------------- 1 file changed, 95 insertions(+), 201 deletions(-) diff --git a/hyperloop/src/timer.rs b/hyperloop/src/timer.rs index ddf5b06..3fc99ba 100644 --- a/hyperloop/src/timer.rs +++ b/hyperloop/src/timer.rs @@ -10,66 +10,84 @@ use embedded_time::{ }; use core::future::Future; -use futures::{task::AtomicWaker, Stream, StreamExt}; use log::error; use crate::priority_queue::{Min, PeekMut, PriorityQueue, PrioritySender}; type Tick = u64; -pub struct TickCounter { - count: Tick, - waker: AtomicWaker, +pub struct Scheduler { + rate: Hertz, + counter: Tick, + queue: PriorityQueue, } -impl TickCounter { - pub const fn new() -> Self { +impl Scheduler { + pub fn new(rate: Hertz) -> Self { Self { - count: 0, - waker: AtomicWaker::new(), + rate, + counter: 0, + queue: PriorityQueue::new(), } } - /// Increment tick count - /// - /// # Safety - /// - /// Updating the tick value is not atomic on 32-bit systems, so it - /// would be possible to get an invalid reading if reading during - /// a write. For this reason, this function should only be called - /// from a high priority interrupt handler. pub unsafe fn increment(&mut self) { - self.count += 1; + self.counter += 1; } - pub fn wake(&self) { - self.waker.wake(); - } + fn next_waker(&mut self) -> Option { + if let Some(ticket) = self.queue.peek_mut().as_mut() { + if self.counter > ticket.expires { + return Some(PeekMut::pop(ticket).waker); + } + } - pub unsafe fn tick(&mut self) { - self.increment(); - self.wake(); + None } - pub fn get_token(&self) -> TickCounterToken { - TickCounterToken { - counter: unsafe { &*(self as *const Self) }, + fn wake_tasks(&mut self) { + while let Some(waker) = self.next_waker() { + waker.wake(); } } + + pub fn split(&'static mut self) -> (Ticker, Timer) { + let counter: &'static Tick = unsafe { &*(&self.counter as *const Tick) }; + let sender = unsafe { self.queue.get_sender() }; + let timer = Timer::new(self.rate, TickReader::new(counter), sender); + let ticker = Ticker::new(self); + + (ticker, timer) + } } #[derive(Clone)] -pub struct TickCounterToken { - counter: &'static TickCounter, +pub struct TickReader { + counter: &'static Tick, } -impl TickCounterToken { - pub fn register_waker(&self, waker: &Waker) { - self.counter.waker.register(waker); +impl TickReader { + fn new(counter: &'static Tick) -> Self { + Self { counter } } - pub fn get_count(&self) -> Tick { - self.counter.count + fn get_count(&self) -> Tick { + *self.counter + } +} + +pub struct Ticker { + scheduler: &'static mut Scheduler, +} + +impl Ticker { + pub fn new(scheduler: &'static mut Scheduler) -> Self { + Self { scheduler } + } + + pub unsafe fn tick(&mut self) { + self.scheduler.increment(); + self.scheduler.wake_tasks(); } } @@ -107,13 +125,13 @@ impl Ord for Ticket { struct DelayFuture { sender: PrioritySender, - counter: TickCounterToken, + counter: TickReader, expires: Tick, started: bool, } impl DelayFuture { - fn new(sender: PrioritySender, counter: TickCounterToken, expires: Tick) -> Self { + fn new(sender: PrioritySender, counter: TickReader, expires: Tick) -> Self { Self { sender, counter, @@ -162,12 +180,7 @@ impl TimeoutFuture where F: Future, { - fn new( - future: F, - sender: PrioritySender, - counter: TickCounterToken, - expires: Tick, - ) -> Self { + fn new(future: F, sender: PrioritySender, counter: TickReader, expires: Tick) -> Self { Self { future, delay: DelayFuture::new(sender, counter, expires), @@ -205,12 +218,12 @@ where #[derive(Clone)] pub struct Timer { rate: Hertz, - counter: TickCounterToken, + counter: TickReader, sender: PrioritySender, } impl Timer { - pub fn new(rate: Hertz, counter: TickCounterToken, sender: PrioritySender) -> Self { + pub fn new(rate: Hertz, counter: TickReader, sender: PrioritySender) -> Self { Self { rate, counter, @@ -255,102 +268,6 @@ impl Timer { } } -struct TimerFuture { - counter: TickCounterToken, - expires: Option, -} - -impl TimerFuture { - fn new(counter: TickCounterToken) -> Self { - Self { - counter, - expires: None, - } - } -} - -impl Stream for TimerFuture { - type Item = (); - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let Some(expires) = self.expires { - if self.counter.get_count() >= expires { - self.expires = None; - return Poll::Ready(Some(())); - } - } else { - self.expires = Some(self.counter.get_count() + 1_u64); - } - - self.counter.register_waker(cx.waker()); - Poll::Pending - } -} - -pub struct Scheduler { - rate: Hertz, - counter: TickCounterToken, - queue: PriorityQueue, -} - -impl Scheduler { - pub fn new(rate: Hertz, counter: TickCounterToken) -> Self { - Self { - rate, - counter, - queue: PriorityQueue::new(), - } - } - - unsafe fn get_timer(&self) -> Timer { - Timer::new(self.rate, self.counter.clone(), self.queue.get_sender()) - } - - fn next_waker(&mut self) -> Option { - if let Some(ticket) = self.queue.peek_mut().as_mut() { - if self.counter.get_count() > ticket.expires { - return Some(PeekMut::pop(ticket).waker); - } - } - - None - } - - async fn task(&mut self) { - let mut timer = TimerFuture::new(self.counter.clone()); - - loop { - if let Some(waker) = self.next_waker() { - waker.wake(); - } else { - timer.next().await.unwrap() - } - } - } - - pub fn get_handle(&'static mut self) -> SchedulerHandle { - SchedulerHandle::new(self) - } -} - -pub struct SchedulerHandle { - scheduler: &'static mut Scheduler, -} - -impl SchedulerHandle { - pub fn new(scheduler: &'static mut Scheduler) -> Self { - Self { scheduler } - } - - pub fn get_timer(&self) -> Timer { - unsafe { self.scheduler.get_timer() } - } - - pub fn into_task(self) -> impl Future { - self.scheduler.task() - } -} - #[cfg(test)] mod tests { use core::sync::atomic::Ordering; @@ -369,46 +286,31 @@ mod tests { #[test] fn state() { - let counter = Box::leak(Box::new(TickCounter::new())); - let token = counter.get_token(); - - assert_eq!(token.get_count(), 0); - - unsafe { counter.increment() }; - assert_eq!(token.get_count(), 1); + let (mut ticker, timer) = Box::leak(Box::new(Scheduler::<0>::new(1000.Hz()))).split(); + let counter = timer.counter; - let mockwaker = Arc::new(MockWaker::new()); - let waker: Waker = mockwaker.clone().into(); - - token.register_waker(&waker); - counter.wake(); + assert_eq!(counter.get_count(), 0); - assert_eq!(mockwaker.woke.load(Ordering::Relaxed), true); - - mockwaker.woke.store(false, Ordering::Relaxed); - token.register_waker(&waker); + unsafe { ticker.tick() }; + assert_eq!(counter.get_count(), 1); unsafe { - counter.tick(); + ticker.tick(); } - assert_eq!(token.get_count(), 2); - assert_eq!(mockwaker.woke.load(Ordering::Relaxed), true); + + assert_eq!(counter.get_count(), 2); } #[test] fn delay() { - let counter = Box::leak(Box::new(TickCounter::new())); - let token = counter.get_token(); - let scheduler: SchedulerHandle<10> = - Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))).get_handle(); + let (ticker, timer) = Box::leak(Box::new(Scheduler::<10>::new(1000.Hz()))).split(); + let counter = timer.counter.clone(); - let queue = &mut scheduler.scheduler.queue; - let sender = unsafe { queue.get_sender() }; let mockwaker = Arc::new(MockWaker::new()); let waker: Waker = mockwaker.clone().into(); let mut cx = Context::from_waker(&waker); - let mut future = DelayFuture::new(sender.clone(), token.clone(), 1); + let mut future = DelayFuture::new(timer.sender.clone(), counter.clone(), 1); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); @@ -416,17 +318,19 @@ mod tests { assert_eq!(future.started, true); unsafe { - counter.tick(); - counter.tick(); + ticker.scheduler.increment(); + ticker.scheduler.increment(); } assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Ready(())); - let mut future = DelayFuture::new(sender.clone(), token.clone(), 20); + let queue = &mut ticker.scheduler.queue; + + let mut future = DelayFuture::new(timer.sender.clone(), counter.clone(), 20); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); - let mut future = DelayFuture::new(sender.clone(), token.clone(), 15); + let mut future = DelayFuture::new(timer.sender.clone(), counter.clone(), 15); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); @@ -447,11 +351,7 @@ mod tests { #[test] fn timer() { - let counter = Box::leak(Box::new(TickCounter::new())); - let token = counter.get_token(); - let scheduler: SchedulerHandle<10> = - Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))).get_handle(); - let timer = scheduler.get_timer(); + let (mut ticker, timer) = Box::leak(Box::new(Scheduler::<10>::new(1000.Hz()))).split(); log_init(); @@ -487,14 +387,13 @@ mod tests { let queue = Arc::new(ArrayQueue::new(10)); - let task1 = Box::leak(Box::new(Task::new(move || scheduler.into_task(), 1))).get_handle(); - let task2 = Box::leak(Box::new(Task::new( + let task = Box::leak(Box::new(Task::new( test_future(queue.clone(), timer.clone()), 1, ))) .get_handle(); - let mut executor = Box::leak(Box::new(Executor::new([task1, task2]))).get_handle(); + let mut executor = Box::leak(Box::new(Executor::new([task]))).get_handle(); executor.poll_tasks(); @@ -502,23 +401,23 @@ mod tests { assert_eq!(queue.pop(), None); unsafe { - counter.tick(); + ticker.tick(); } executor.poll_tasks(); assert_eq!(queue.pop(), Some(2)); assert_eq!(queue.pop(), None); - counter.wake(); + ticker.scheduler.wake_tasks(); executor.poll_tasks(); assert_eq!(queue.pop(), None); unsafe { - counter.tick(); + ticker.tick(); } unsafe { - counter.tick(); + ticker.tick(); } executor.poll_tasks(); @@ -526,10 +425,10 @@ mod tests { assert_eq!(queue.pop(), None); unsafe { - counter.tick(); + ticker.tick(); } unsafe { - counter.tick(); + ticker.tick(); } executor.poll_tasks(); @@ -537,10 +436,10 @@ mod tests { assert_eq!(queue.pop(), None); unsafe { - counter.tick(); + ticker.tick(); } unsafe { - counter.tick(); + ticker.tick(); } executor.poll_tasks(); @@ -549,7 +448,7 @@ mod tests { for _ in 0..10 { unsafe { - counter.tick(); + ticker.tick(); } executor.poll_tasks(); @@ -557,7 +456,7 @@ mod tests { } unsafe { - counter.tick(); + ticker.tick(); } executor.poll_tasks(); @@ -567,11 +466,7 @@ mod tests { #[test] fn timeout() { - let counter = Box::leak(Box::new(TickCounter::new())); - let token = counter.get_token(); - let scheduler: SchedulerHandle<10> = - Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))).get_handle(); - let timer = scheduler.get_timer(); + let (ticker, timer) = Box::leak(Box::new(Scheduler::<10>::new(1000.Hz()))).split(); log_init(); @@ -606,11 +501,10 @@ mod tests { let queue = Arc::new(ArrayQueue::new(10)); - let task1 = Box::leak(Box::new(Task::new(move || scheduler.into_task(), 1))).get_handle(); - let task2 = + let task = Box::leak(Box::new(Task::new(waiting_future(queue.clone(), timer), 1))).get_handle(); - let mut executor = Box::leak(Box::new(Executor::new([task1, task2]))).get_handle(); + let mut executor = Box::leak(Box::new(Executor::new([task]))).get_handle(); executor.poll_tasks(); @@ -619,11 +513,11 @@ mod tests { for _ in 0..101 { unsafe { - counter.increment(); + ticker.scheduler.increment(); } } - counter.wake(); + ticker.scheduler.wake_tasks(); executor.poll_tasks(); @@ -632,18 +526,18 @@ mod tests { for _ in 0..1000 { unsafe { - counter.increment(); + ticker.scheduler.increment(); } - counter.wake(); + ticker.scheduler.wake_tasks(); executor.poll_tasks(); assert_eq!(queue.pop(), None); } unsafe { - counter.increment(); + ticker.scheduler.increment(); } - counter.wake(); + ticker.scheduler.wake_tasks(); executor.poll_tasks(); From 7e730d7b2a977f014e5d44124e9e5cce3a0e2276 Mon Sep 17 00:00:00 2001 From: Eivind Alexander Bergem Date: Fri, 22 Apr 2022 09:50:15 +0200 Subject: [PATCH 3/8] Fixed miri errors in priority queue. - Use UnsafeCell for interior mutability. - Added special handling when popping root from heap of size 1 to avoid swapping item with itself. --- hyperloop-priority-queue/src/lib.rs | 51 ++++++++++++++++------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/hyperloop-priority-queue/src/lib.rs b/hyperloop-priority-queue/src/lib.rs index c111904..b886ce8 100644 --- a/hyperloop-priority-queue/src/lib.rs +++ b/hyperloop-priority-queue/src/lib.rs @@ -1,6 +1,6 @@ #![no_std] -use core::{marker::PhantomData, ops::Deref, sync::atomic::Ordering}; +use core::{cell::UnsafeCell, marker::PhantomData, mem, ops::Deref, sync::atomic::Ordering}; #[cfg(not(loom))] use core::sync::atomic::AtomicUsize; @@ -102,7 +102,7 @@ where } fn item(&self) -> &T { - self.heap.slots[self.pos].as_ref().unwrap() + unsafe { self.heap.slot_mut(self.pos).as_ref().unwrap() } } unsafe fn slot_mut(&self) -> &mut Option { @@ -113,10 +113,7 @@ where let slot = unsafe { self.slot_mut() }; let other_slot = unsafe { other.slot_mut() }; - let item = slot.take(); - *slot = other_slot.take(); - *other_slot = item; - + mem::swap(slot, other_slot); other } @@ -187,7 +184,7 @@ impl AtomicStackPosition { fn compare_exchange(&self, current: usize, new: usize) -> Result { self.atomic - .compare_exchange_weak(current, new, Ordering::Release, Ordering::Relaxed) + .compare_exchange_weak(current, new, Ordering::AcqRel, Ordering::Relaxed) } } @@ -195,7 +192,7 @@ pub struct PrioritySender where T: 'static, { - slots: &'static [Option], + slots: &'static [UnsafeCell>], available: &'static AtomicUsize, stack_pos: &'static AtomicStackPosition, } @@ -210,9 +207,12 @@ impl Clone for PrioritySender { } } +unsafe impl Send for PrioritySender {} +unsafe impl Sync for PrioritySender {} + impl PrioritySender { unsafe fn slot_mut(&self, index: usize) -> &mut Option { - &mut *((&self.slots[index] as *const Option) as *mut Option) + &mut *self.slots[index].get() } fn stack_push(&self, item: T) -> Result<(), T> { @@ -289,13 +289,13 @@ where impl<'a, T, K, const N: usize> Deref for PeekMut<'a, T, K, N> where - T: PartialOrd, - K: Kind, + T: PartialOrd + 'static, + K: Kind + 'static, { type Target = T; fn deref(&self) -> &Self::Target { - self.queue.slots[0].as_ref().unwrap() + unsafe { self.queue.slot_mut(0).as_ref().unwrap() } } } @@ -304,7 +304,7 @@ where T: PartialOrd, K: Kind, { - slots: [Option; N], + slots: [UnsafeCell>; N], available: AtomicUsize, stack_pos: AtomicStackPosition, heap_size: usize, @@ -318,7 +318,7 @@ where { pub fn new() -> Self { Self { - slots: [(); N].map(|_| None), + slots: [(); N].map(|_| UnsafeCell::new(None)), available: AtomicUsize::new(N), stack_pos: AtomicStackPosition::new(N), heap_size: 0, @@ -337,7 +337,7 @@ where } unsafe fn slot_mut(&self, index: usize) -> &mut Option { - &mut *((&self.slots[index] as *const Option) as *mut Option) + &mut *self.slots[index].get() } fn get_node(&self, index: usize) -> Node { @@ -364,7 +364,7 @@ where break Err(()); } else { let new = current.popped(); - let item = self.slots[current.pos()].take(); + let item = unsafe { self.slot_mut(current.pos()).take() }; if let Ok(_) = self .stack_pos @@ -372,7 +372,9 @@ where { break Ok(item); } else { - self.slots[current.pos()] = item; + unsafe { + *self.slot_mut(current.pos()) = item; + } } } } @@ -407,7 +409,9 @@ where let index = self.heap_size; if index < N { - self.slots[index] = Some(item); + unsafe { + *self.slot_mut(index) = Some(item); + } self.heap_size += 1; @@ -431,7 +435,8 @@ where } fn take_root(&mut self) -> Option { - if let Some(item) = self.slots[0].take() { + if self.heap_size > 1 { + let item = unsafe { self.slot_mut(0).take() }.unwrap(); { let root = self.get_root(); let last = self.get_last(); @@ -440,6 +445,9 @@ where self.heap_size -= 1; Some(item) + } else if self.heap_size == 1 { + self.heap_size -= 1; + Some(unsafe { self.slot_mut(0).take() }.unwrap()) } else { None } @@ -501,9 +509,7 @@ where #[cfg(not(loom))] #[cfg(test)] mod tests { - use std::thread; - - use std::vec::Vec; + use std::{thread, vec::Vec}; use super::*; @@ -608,6 +614,7 @@ mod tests { } #[test] + #[cfg_attr(miri, ignore)] fn channel_thread() { const N: usize = 1000; let mut queue: PriorityQueue = PriorityQueue::new(); From 9eb15f2abe264ce5a5cc04a65cb64f7d2eb988b0 Mon Sep 17 00:00:00 2001 From: Eivind Alexander Bergem Date: Mon, 25 Apr 2022 12:42:32 +0200 Subject: [PATCH 4/8] Use non-specific nightly. --- hyperloop/src/lib.rs | 2 -- rust-toolchain.toml | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/hyperloop/src/lib.rs b/hyperloop/src/lib.rs index f55f9d5..9bcdae2 100644 --- a/hyperloop/src/lib.rs +++ b/hyperloop/src/lib.rs @@ -1,8 +1,6 @@ #![no_std] -#![feature(const_fn_trait_bound)] #![feature(type_alias_impl_trait)] #![feature(once_cell)] -#![feature(option_result_unwrap_unchecked)] mod common; diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 8315f48..e9d4cf2 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "nightly-2021-10-16" +channel = "nightly" components = [ "rustfmt", "clippy" ] From 949f99396a4fc0e16837f03acd43722815ccaa70 Mon Sep 17 00:00:00 2001 From: Eivind Alexander Bergem Date: Mon, 25 Apr 2022 13:07:04 +0200 Subject: [PATCH 5/8] Use `UnsafeCell` for timer counter. --- hyperloop/src/timer.rs | 148 ++++++++++++++++++----------------------- 1 file changed, 63 insertions(+), 85 deletions(-) diff --git a/hyperloop/src/timer.rs b/hyperloop/src/timer.rs index 3fc99ba..347be4d 100644 --- a/hyperloop/src/timer.rs +++ b/hyperloop/src/timer.rs @@ -1,4 +1,5 @@ use core::{ + cell::UnsafeCell, pin::Pin, task::{Context, Poll, Waker}, }; @@ -18,7 +19,7 @@ type Tick = u64; pub struct Scheduler { rate: Hertz, - counter: Tick, + counter: UnsafeCell, queue: PriorityQueue, } @@ -26,18 +27,18 @@ impl Scheduler { pub fn new(rate: Hertz) -> Self { Self { rate, - counter: 0, + counter: UnsafeCell::new(0), queue: PriorityQueue::new(), } } pub unsafe fn increment(&mut self) { - self.counter += 1; + *self.counter.get() += 1; } fn next_waker(&mut self) -> Option { if let Some(ticket) = self.queue.peek_mut().as_mut() { - if self.counter > ticket.expires { + if unsafe { *self.counter.get() } > ticket.expires { return Some(PeekMut::pop(ticket).waker); } } @@ -51,43 +52,17 @@ impl Scheduler { } } - pub fn split(&'static mut self) -> (Ticker, Timer) { - let counter: &'static Tick = unsafe { &*(&self.counter as *const Tick) }; - let sender = unsafe { self.queue.get_sender() }; - let timer = Timer::new(self.rate, TickReader::new(counter), sender); - let ticker = Ticker::new(self); - - (ticker, timer) - } -} - -#[derive(Clone)] -pub struct TickReader { - counter: &'static Tick, -} - -impl TickReader { - fn new(counter: &'static Tick) -> Self { - Self { counter } - } - - fn get_count(&self) -> Tick { - *self.counter + pub unsafe fn tick(&mut self) { + self.increment(); + self.wake_tasks(); } -} -pub struct Ticker { - scheduler: &'static mut Scheduler, -} + pub fn get_timer(&self) -> Timer { + let counter = self.counter.get() as *const _; + let sender = unsafe { self.queue.get_sender() }; + let timer = Timer::new(self.rate, counter, sender); -impl Ticker { - pub fn new(scheduler: &'static mut Scheduler) -> Self { - Self { scheduler } - } - - pub unsafe fn tick(&mut self) { - self.scheduler.increment(); - self.scheduler.wake_tasks(); + timer } } @@ -125,13 +100,13 @@ impl Ord for Ticket { struct DelayFuture { sender: PrioritySender, - counter: TickReader, + counter: *const Tick, expires: Tick, started: bool, } impl DelayFuture { - fn new(sender: PrioritySender, counter: TickReader, expires: Tick) -> Self { + fn new(sender: PrioritySender, counter: *const Tick, expires: Tick) -> Self { Self { sender, counter, @@ -159,7 +134,7 @@ impl Future for DelayFuture { // expiration. This ensures that we wait for no less than // the specified duration, and possibly one tick longer // than desired. - if self.counter.get_count() > self.expires { + if unsafe { *self.counter } > self.expires { Poll::Ready(()) } else { Poll::Pending @@ -180,7 +155,7 @@ impl TimeoutFuture where F: Future, { - fn new(future: F, sender: PrioritySender, counter: TickReader, expires: Tick) -> Self { + fn new(future: F, sender: PrioritySender, counter: *const Tick, expires: Tick) -> Self { Self { future, delay: DelayFuture::new(sender, counter, expires), @@ -218,12 +193,12 @@ where #[derive(Clone)] pub struct Timer { rate: Hertz, - counter: TickReader, + counter: *const Tick, sender: PrioritySender, } impl Timer { - pub fn new(rate: Hertz, counter: TickReader, sender: PrioritySender) -> Self { + pub fn new(rate: Hertz, counter: *const Tick, sender: PrioritySender) -> Self { Self { rate, counter, @@ -236,7 +211,7 @@ impl Timer { } fn get_count(&self) -> Tick { - self.counter.get_count() + unsafe { *self.counter } } fn delay_to_ticks>(&self, duration: D) -> Tick { @@ -253,7 +228,7 @@ impl Timer { pub fn delay(&self, duration: Milliseconds) -> impl Future { DelayFuture::new( self.sender.clone(), - self.counter.clone(), + self.counter, self.delay_to_ticks(duration), ) } @@ -262,7 +237,7 @@ impl Timer { TimeoutFuture::new( future, self.sender.clone(), - self.counter.clone(), + self.counter, self.delay_to_ticks(duration), ) } @@ -286,31 +261,32 @@ mod tests { #[test] fn state() { - let (mut ticker, timer) = Box::leak(Box::new(Scheduler::<0>::new(1000.Hz()))).split(); - let counter = timer.counter; + let scheduler = Box::leak(Box::new(Scheduler::<0>::new(1000.Hz()))); + let timer = scheduler.get_timer(); - assert_eq!(counter.get_count(), 0); + assert_eq!(unsafe { *timer.counter }, 0); - unsafe { ticker.tick() }; - assert_eq!(counter.get_count(), 1); + unsafe { scheduler.tick() }; + assert_eq!(unsafe { *timer.counter }, 1); unsafe { - ticker.tick(); + scheduler.tick(); } - assert_eq!(counter.get_count(), 2); + assert_eq!(unsafe { *timer.counter }, 2); } #[test] fn delay() { - let (ticker, timer) = Box::leak(Box::new(Scheduler::<10>::new(1000.Hz()))).split(); - let counter = timer.counter.clone(); + let scheduler = Box::leak(Box::new(Scheduler::<10>::new(1000.Hz()))); + let timer = scheduler.get_timer(); + let counter = timer.counter; let mockwaker = Arc::new(MockWaker::new()); let waker: Waker = mockwaker.clone().into(); let mut cx = Context::from_waker(&waker); - let mut future = DelayFuture::new(timer.sender.clone(), counter.clone(), 1); + let mut future = DelayFuture::new(timer.sender.clone(), counter, 1); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); @@ -318,40 +294,39 @@ mod tests { assert_eq!(future.started, true); unsafe { - ticker.scheduler.increment(); - ticker.scheduler.increment(); + scheduler.increment(); + scheduler.increment(); } assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Ready(())); - let queue = &mut ticker.scheduler.queue; - - let mut future = DelayFuture::new(timer.sender.clone(), counter.clone(), 20); + let mut future = DelayFuture::new(timer.sender.clone(), counter, 20); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); - let mut future = DelayFuture::new(timer.sender.clone(), counter.clone(), 15); + let mut future = DelayFuture::new(timer.sender.clone(), counter, 15); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); - if let Some(ticket) = queue.pop() { + if let Some(ticket) = scheduler.queue.pop() { assert_eq!(ticket.expires, 1); ticket.waker.wake(); assert_eq!(mockwaker.woke.load(Ordering::Relaxed), true) } - if let Some(ticket) = queue.pop() { + if let Some(ticket) = scheduler.queue.pop() { assert_eq!(ticket.expires, 15); } - if let Some(ticket) = queue.pop() { + if let Some(ticket) = scheduler.queue.pop() { assert_eq!(ticket.expires, 20); } } #[test] fn timer() { - let (mut ticker, timer) = Box::leak(Box::new(Scheduler::<10>::new(1000.Hz()))).split(); + let scheduler = Box::leak(Box::new(Scheduler::new(1000.Hz()))); + let timer = scheduler.get_timer(); log_init(); @@ -393,7 +368,9 @@ mod tests { ))) .get_handle(); - let mut executor = Box::leak(Box::new(Executor::new([task]))).get_handle(); + let mut executor = Box::leak(Box::new(Executor::new([task]))) + .get_handle() + .with_scheduler(scheduler); executor.poll_tasks(); @@ -401,23 +378,23 @@ mod tests { assert_eq!(queue.pop(), None); unsafe { - ticker.tick(); + scheduler.tick(); } executor.poll_tasks(); assert_eq!(queue.pop(), Some(2)); assert_eq!(queue.pop(), None); - ticker.scheduler.wake_tasks(); + scheduler.wake_tasks(); executor.poll_tasks(); assert_eq!(queue.pop(), None); unsafe { - ticker.tick(); + scheduler.tick(); } unsafe { - ticker.tick(); + scheduler.tick(); } executor.poll_tasks(); @@ -425,10 +402,10 @@ mod tests { assert_eq!(queue.pop(), None); unsafe { - ticker.tick(); + scheduler.tick(); } unsafe { - ticker.tick(); + scheduler.tick(); } executor.poll_tasks(); @@ -436,10 +413,10 @@ mod tests { assert_eq!(queue.pop(), None); unsafe { - ticker.tick(); + scheduler.tick(); } unsafe { - ticker.tick(); + scheduler.tick(); } executor.poll_tasks(); @@ -448,7 +425,7 @@ mod tests { for _ in 0..10 { unsafe { - ticker.tick(); + scheduler.tick(); } executor.poll_tasks(); @@ -456,7 +433,7 @@ mod tests { } unsafe { - ticker.tick(); + scheduler.tick(); } executor.poll_tasks(); @@ -466,7 +443,8 @@ mod tests { #[test] fn timeout() { - let (ticker, timer) = Box::leak(Box::new(Scheduler::<10>::new(1000.Hz()))).split(); + let scheduler = Box::leak(Box::new(Scheduler::<2>::new(1000.Hz()))); + let timer = scheduler.get_timer(); log_init(); @@ -513,11 +491,11 @@ mod tests { for _ in 0..101 { unsafe { - ticker.scheduler.increment(); + scheduler.increment(); } } - ticker.scheduler.wake_tasks(); + scheduler.wake_tasks(); executor.poll_tasks(); @@ -526,18 +504,18 @@ mod tests { for _ in 0..1000 { unsafe { - ticker.scheduler.increment(); + scheduler.increment(); } - ticker.scheduler.wake_tasks(); + scheduler.wake_tasks(); executor.poll_tasks(); assert_eq!(queue.pop(), None); } unsafe { - ticker.scheduler.increment(); + scheduler.increment(); } - ticker.scheduler.wake_tasks(); + scheduler.wake_tasks(); executor.poll_tasks(); From 8ac6d125dbf9bf33f7ee1620939f57beea6442d3 Mon Sep 17 00:00:00 2001 From: Eivind Alexander Bergem Date: Wed, 27 Apr 2022 10:05:36 +0200 Subject: [PATCH 6/8] Use raw pointers in priority queue sender to keep miri happy. --- hyperloop-priority-queue/src/lib.rs | 33 +++++++++++++++-------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/hyperloop-priority-queue/src/lib.rs b/hyperloop-priority-queue/src/lib.rs index b886ce8..b3230c8 100644 --- a/hyperloop-priority-queue/src/lib.rs +++ b/hyperloop-priority-queue/src/lib.rs @@ -192,9 +192,9 @@ pub struct PrioritySender where T: 'static, { - slots: &'static [UnsafeCell>], - available: &'static AtomicUsize, - stack_pos: &'static AtomicStackPosition, + slots: *const [UnsafeCell>], + available: *const AtomicUsize, + stack_pos: *const AtomicStackPosition, } impl Clone for PrioritySender { @@ -212,20 +212,19 @@ unsafe impl Sync for PrioritySender {} impl PrioritySender { unsafe fn slot_mut(&self, index: usize) -> &mut Option { - &mut *self.slots[index].get() + &mut *(*self.slots)[index].get() } fn stack_push(&self, item: T) -> Result<(), T> { + let stack_pos = unsafe { &*self.stack_pos }; + loop { - let current = self.stack_pos.load(); + let current = stack_pos.load(); if current.pos() > 0 { let new = current.reserved(); - if let Ok(_) = self - .stack_pos - .compare_exchange(current.value(), new.value()) - { + if let Ok(_) = stack_pos.compare_exchange(current.value(), new.value()) { let slot = unsafe { self.slot_mut(new.pos()) }; *slot = Some(item); break; @@ -236,10 +235,10 @@ impl PrioritySender { } loop { - let old = self.stack_pos.load(); + let old = stack_pos.load(); let new = old.pushed(); - if let Ok(_) = self.stack_pos.compare_exchange(old.value(), new.value()) { + if let Ok(_) = stack_pos.compare_exchange(old.value(), new.value()) { break; } } @@ -248,13 +247,15 @@ impl PrioritySender { } pub fn send(&self, item: T) -> Result<(), T> { + let available = unsafe { &*self.available }; + loop { - let available = self.available.load(Ordering::Acquire); + let n_available = available.load(Ordering::Acquire); - if available > 0 { - if let Ok(_) = self.available.compare_exchange( - available, - available - 1, + if n_available > 0 { + if let Ok(_) = available.compare_exchange( + n_available, + n_available - 1, Ordering::Release, Ordering::Relaxed, ) { From f8c65d04282301473bf3d6961fb5a237d48536fd Mon Sep 17 00:00:00 2001 From: Eivind Alexander Bergem Date: Wed, 27 Apr 2022 10:06:28 +0200 Subject: [PATCH 7/8] Use UnsafeCell and raw pointers in executor to keep miri happy. --- hyperloop/src/executor.rs | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/hyperloop/src/executor.rs b/hyperloop/src/executor.rs index d79ee6f..d36abe0 100644 --- a/hyperloop/src/executor.rs +++ b/hyperloop/src/executor.rs @@ -1,10 +1,13 @@ +use core::cell::UnsafeCell; use core::cmp::Ordering; +use core::mem; use core::sync::atomic::AtomicBool; use core::task::{Poll, RawWaker, RawWakerVTable, Waker}; use crate::priority_queue::{Max, PriorityQueue, PrioritySender}; use crate::task::TaskHandle; +use crate::timer::Scheduler; pub(crate) type Priority = u8; type TaskId = u16; @@ -124,7 +127,7 @@ impl ExecutorTask { } pub struct Executor { - tasks: [ExecutorTask; N], + tasks: [UnsafeCell; N], queue: PriorityQueue, } @@ -133,7 +136,7 @@ impl Executor { let mut i = 0; let tasks = tasks.map(|task| { let priority = task.priority; - let task = ExecutorTask::new(task, i, priority, None); + let task = UnsafeCell::new(ExecutorTask::new(task, i, priority, None)); i += 1; task }); @@ -147,7 +150,7 @@ impl Executor { unsafe fn get_task(&mut self, task_id: TaskId) -> &mut ExecutorTask { let index = task_id as usize; - let task = &mut self.tasks[index]; + let task = &mut *self.tasks[index].get(); task.clear_pending_wake_flag(); task @@ -186,23 +189,27 @@ impl Executor { } pub fn get_handle(&'static mut self) -> ExecutorHandle { - ExecutorHandle::new(self) + ExecutorHandle::new(unsafe { &mut *mem::transmute::<_, *mut Self>(self) }) } } pub struct ExecutorHandle { - executor: &'static mut Executor, + executor: *mut Executor, } impl ExecutorHandle { - pub fn new(executor: &'static mut Executor) -> Self { - unsafe { executor.init() }; + pub fn new(executor: *mut Executor) -> Self { + unsafe { (*executor).init() }; Self { executor } } /// Poll all tasks in the queue pub fn poll_tasks(&mut self) { - unsafe { self.executor.poll_tasks() } + unsafe { (*self.executor).poll_tasks() } + } + + pub fn with_scheduler(self, _scheduler: &Scheduler) -> Self { + self } } @@ -271,8 +278,7 @@ mod tests { let task = Box::leak(Box::new(Task::new(test_future(queue.clone(), notify), 1))); - let mut executor = - ExecutorHandle::new(Box::leak(Box::new(Executor::new([task.get_handle()])))); + let mut executor = Box::leak(Box::new(Executor::new([task.get_handle()]))).get_handle(); executor.poll_tasks(); @@ -289,7 +295,7 @@ mod tests { executor.poll_tasks(); assert!(queue.pop().is_none()); - let waker = unsafe { executor.executor.get_task(0).get_waker() }; + let waker = unsafe { (*executor.executor).get_task(0).get_waker() }; waker.wake(); From cac7c945cb5379a885687227876d33ed467c97fc Mon Sep 17 00:00:00 2001 From: Eivind Alexander Bergem Date: Wed, 27 Apr 2022 14:26:54 +0200 Subject: [PATCH 8/8] Added macro to create static executor. --- hyperloop-macros/src/lib.rs | 49 ++++++++++++++++++++++++++++++++++++- hyperloop/src/executor.rs | 8 +++--- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/hyperloop-macros/src/lib.rs b/hyperloop-macros/src/lib.rs index 936ee9c..a210a08 100644 --- a/hyperloop-macros/src/lib.rs +++ b/hyperloop-macros/src/lib.rs @@ -5,10 +5,11 @@ use darling::FromMeta; use proc_macro::{self, TokenStream}; use quote::{format_ident, quote}; use syn::{ + parse::Parse, punctuated::{Pair, Punctuated}, spanned::Spanned, token::Comma, - FnArg, Ident, Pat, + Expr, FnArg, Ident, Pat, Stmt, Token, }; #[derive(Debug, FromMeta)] @@ -91,3 +92,49 @@ pub fn task(args: TokenStream, item: TokenStream) -> TokenStream { }; result.into() } + +struct Args { + args: Punctuated, +} + +impl Parse for Args { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + match Punctuated::::parse_terminated(&input) { + Ok(args) => Ok(Self { args }), + Err(err) => Err(err), + } + } +} + +struct Statements { + data: Vec, +} + +impl quote::ToTokens for Statements { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for stmt in self.data.iter() { + stmt.to_tokens(tokens); + } + } +} + +#[proc_macro] +pub fn static_executor(tokens: TokenStream) -> TokenStream { + let args = syn::parse_macro_input!(tokens as Args).args; + + let n_tasks = args.len(); + + let result = quote! { + { + static mut EXECUTOR: Option> = None; + + let executor = unsafe { + EXECUTOR.get_or_insert(Executor::new([#args])) + }; + + executor.get_handle() + } + }; + + result.into() +} diff --git a/hyperloop/src/executor.rs b/hyperloop/src/executor.rs index d36abe0..73f0c7a 100644 --- a/hyperloop/src/executor.rs +++ b/hyperloop/src/executor.rs @@ -216,7 +216,7 @@ impl ExecutorHandle { #[cfg(test)] mod tests { use crossbeam_queue::ArrayQueue; - use hyperloop_macros::task; + use hyperloop_macros::{static_executor, task}; use std::boxed::Box; use std::sync::Arc; @@ -243,12 +243,12 @@ mod tests { let task3 = Box::leak(Box::new(Task::new(test_future(queue.clone(), 3), 2))); let task4 = Box::leak(Box::new(Task::new(test_future(queue.clone(), 4), 4))); - let mut executor = ExecutorHandle::new(Box::leak(Box::new(Executor::new([ + let mut executor = static_executor!( task1.get_handle(), task2.get_handle(), task3.get_handle(), task4.get_handle(), - ])))); + ); executor.poll_tasks(); @@ -278,7 +278,7 @@ mod tests { let task = Box::leak(Box::new(Task::new(test_future(queue.clone(), notify), 1))); - let mut executor = Box::leak(Box::new(Executor::new([task.get_handle()]))).get_handle(); + let mut executor = static_executor!(task.get_handle()); executor.poll_tasks();