diff --git a/hyperloop-macros/src/lib.rs b/hyperloop-macros/src/lib.rs index d2ef941..a210a08 100644 --- a/hyperloop-macros/src/lib.rs +++ b/hyperloop-macros/src/lib.rs @@ -6,11 +6,10 @@ use proc_macro::{self, TokenStream}; use quote::{format_ident, quote}; use syn::{ parse::Parse, - parse_quote, punctuated::{Pair, Punctuated}, spanned::Spanned, token::Comma, - FnArg, Ident, Pat, Stmt, Token, + Expr, FnArg, Ident, Pat, Stmt, Token, }; #[derive(Debug, FromMeta)] @@ -69,7 +68,7 @@ pub fn task(args: TokenStream, item: TokenStream) -> TokenStream { let result = quote! { #(#attrs)* - #visibility fn #name(#args) -> Option<&'static mut crate::task::Task<#future_type>> { + #visibility fn #name(#args) -> Option { type F = #future_type; fn wrapper(#args) -> impl FnOnce() -> F { @@ -84,7 +83,7 @@ pub fn task(args: TokenStream, item: TokenStream) -> TokenStream { unsafe { if let None = TASK { TASK = Some(Task::new(wrapper(#arg_values), #priority)); - Some(TASK.as_mut().unwrap()) + Some(TASK.as_mut().unwrap().get_handle()) } else { None } @@ -95,12 +94,12 @@ pub fn task(args: TokenStream, item: TokenStream) -> TokenStream { } struct Args { - args: Punctuated, + args: Punctuated, } impl Parse for Args { fn parse(input: syn::parse::ParseStream) -> syn::Result { - match Punctuated::::parse_terminated(&input) { + match Punctuated::::parse_terminated(&input) { Ok(args) => Ok(Self { args }), Err(err) => Err(err), } @@ -120,35 +119,20 @@ impl quote::ToTokens for Statements { } #[proc_macro] -pub fn executor_from_tasks(tokens: TokenStream) -> TokenStream { +pub fn static_executor(tokens: TokenStream) -> TokenStream { let args = syn::parse_macro_input!(tokens as Args).args; let n_tasks = args.len(); - let tasks = Statements { - data: args - .pairs() - .map(|pair| { - let task = pair.into_value(); - let stmt: Stmt = parse_quote!( - #task.add_to_executor(executor.get_sender()).unwrap(); - ); - stmt - }) - .collect(), - }; - let result = quote! { { static mut EXECUTOR: Option> = None; let executor = unsafe { - EXECUTOR.get_or_insert(Executor::new()) + EXECUTOR.get_or_insert(Executor::new([#args])) }; - #tasks - - executor + executor.get_handle() } }; diff --git a/hyperloop-priority-queue/src/lib.rs b/hyperloop-priority-queue/src/lib.rs index 6f81c94..b3230c8 100644 --- a/hyperloop-priority-queue/src/lib.rs +++ b/hyperloop-priority-queue/src/lib.rs @@ -1,6 +1,6 @@ #![no_std] -use core::{marker::PhantomData, ops::Deref, sync::atomic::Ordering}; +use core::{cell::UnsafeCell, marker::PhantomData, mem, ops::Deref, sync::atomic::Ordering}; #[cfg(not(loom))] use core::sync::atomic::AtomicUsize; @@ -102,7 +102,7 @@ where } fn item(&self) -> &T { - self.heap.slots[self.pos].as_ref().unwrap() + unsafe { self.heap.slot_mut(self.pos).as_ref().unwrap() } } unsafe fn slot_mut(&self) -> &mut Option { @@ -113,10 +113,7 @@ where let slot = unsafe { self.slot_mut() }; let other_slot = unsafe { other.slot_mut() }; - let item = slot.take(); - *slot = other_slot.take(); - *other_slot = item; - + mem::swap(slot, other_slot); other } @@ -187,7 +184,7 @@ impl AtomicStackPosition { fn compare_exchange(&self, current: usize, new: usize) -> Result { self.atomic - .compare_exchange_weak(current, new, Ordering::Release, Ordering::Relaxed) + .compare_exchange_weak(current, new, Ordering::AcqRel, Ordering::Relaxed) } } @@ -195,9 +192,9 @@ pub struct PrioritySender where T: 'static, { - slots: &'static [Option], - available: &'static AtomicUsize, - stack_pos: &'static AtomicStackPosition, + slots: *const [UnsafeCell>], + available: *const AtomicUsize, + stack_pos: *const AtomicStackPosition, } impl Clone for PrioritySender { @@ -210,22 +207,24 @@ impl Clone for PrioritySender { } } +unsafe impl Send for PrioritySender {} +unsafe impl Sync for PrioritySender {} + impl PrioritySender { unsafe fn slot_mut(&self, index: usize) -> &mut Option { - &mut *((&self.slots[index] as *const Option) as *mut Option) + &mut *(*self.slots)[index].get() } fn stack_push(&self, item: T) -> Result<(), T> { + let stack_pos = unsafe { &*self.stack_pos }; + loop { - let current = self.stack_pos.load(); + let current = stack_pos.load(); if current.pos() > 0 { let new = current.reserved(); - if let Ok(_) = self - .stack_pos - .compare_exchange(current.value(), new.value()) - { + if let Ok(_) = stack_pos.compare_exchange(current.value(), new.value()) { let slot = unsafe { self.slot_mut(new.pos()) }; *slot = Some(item); break; @@ -236,10 +235,10 @@ impl PrioritySender { } loop { - let old = self.stack_pos.load(); + let old = stack_pos.load(); let new = old.pushed(); - if let Ok(_) = self.stack_pos.compare_exchange(old.value(), new.value()) { + if let Ok(_) = stack_pos.compare_exchange(old.value(), new.value()) { break; } } @@ -248,13 +247,15 @@ impl PrioritySender { } pub fn send(&self, item: T) -> Result<(), T> { + let available = unsafe { &*self.available }; + loop { - let available = self.available.load(Ordering::Acquire); + let n_available = available.load(Ordering::Acquire); - if available > 0 { - if let Ok(_) = self.available.compare_exchange( - available, - available - 1, + if n_available > 0 { + if let Ok(_) = available.compare_exchange( + n_available, + n_available - 1, Ordering::Release, Ordering::Relaxed, ) { @@ -289,13 +290,13 @@ where impl<'a, T, K, const N: usize> Deref for PeekMut<'a, T, K, N> where - T: PartialOrd, - K: Kind, + T: PartialOrd + 'static, + K: Kind + 'static, { type Target = T; fn deref(&self) -> &Self::Target { - self.queue.slots[0].as_ref().unwrap() + unsafe { self.queue.slot_mut(0).as_ref().unwrap() } } } @@ -304,7 +305,7 @@ where T: PartialOrd, K: Kind, { - slots: [Option; N], + slots: [UnsafeCell>; N], available: AtomicUsize, stack_pos: AtomicStackPosition, heap_size: usize, @@ -318,7 +319,7 @@ where { pub fn new() -> Self { Self { - slots: [(); N].map(|_| None), + slots: [(); N].map(|_| UnsafeCell::new(None)), available: AtomicUsize::new(N), stack_pos: AtomicStackPosition::new(N), heap_size: 0, @@ -326,8 +327,8 @@ where } } - pub fn get_sender(&self) -> PrioritySender { - let queue: &'static Self = unsafe { &*(self as *const Self) }; + pub unsafe fn get_sender(&self) -> PrioritySender { + let queue: &'static Self = &*(self as *const Self); PrioritySender { slots: &queue.slots, @@ -337,7 +338,7 @@ where } unsafe fn slot_mut(&self, index: usize) -> &mut Option { - &mut *((&self.slots[index] as *const Option) as *mut Option) + &mut *self.slots[index].get() } fn get_node(&self, index: usize) -> Node { @@ -364,7 +365,7 @@ where break Err(()); } else { let new = current.popped(); - let item = self.slots[current.pos()].take(); + let item = unsafe { self.slot_mut(current.pos()).take() }; if let Ok(_) = self .stack_pos @@ -372,7 +373,9 @@ where { break Ok(item); } else { - self.slots[current.pos()] = item; + unsafe { + *self.slot_mut(current.pos()) = item; + } } } } @@ -407,7 +410,9 @@ where let index = self.heap_size; if index < N { - self.slots[index] = Some(item); + unsafe { + *self.slot_mut(index) = Some(item); + } self.heap_size += 1; @@ -431,7 +436,8 @@ where } fn take_root(&mut self) -> Option { - if let Some(item) = self.slots[0].take() { + if self.heap_size > 1 { + let item = unsafe { self.slot_mut(0).take() }.unwrap(); { let root = self.get_root(); let last = self.get_last(); @@ -440,6 +446,9 @@ where self.heap_size -= 1; Some(item) + } else if self.heap_size == 1 { + self.heap_size -= 1; + Some(unsafe { self.slot_mut(0).take() }.unwrap()) } else { None } @@ -501,9 +510,7 @@ where #[cfg(not(loom))] #[cfg(test)] mod tests { - use std::thread; - - use std::vec::Vec; + use std::{thread, vec::Vec}; use super::*; @@ -540,7 +547,7 @@ mod tests { #[test] fn stack() { let mut heap: PriorityQueue = PriorityQueue::new(); - let sender = heap.get_sender(); + let sender = unsafe { heap.get_sender() }; for i in 0..10 { sender.stack_push(i).unwrap(); @@ -568,7 +575,7 @@ mod tests { #[test] fn channel() { let mut queue: PriorityQueue = PriorityQueue::new(); - let sender = queue.get_sender(); + let sender = unsafe { queue.get_sender() }; for i in 0..10 { sender.send(i).unwrap(); @@ -608,6 +615,7 @@ mod tests { } #[test] + #[cfg_attr(miri, ignore)] fn channel_thread() { const N: usize = 1000; let mut queue: PriorityQueue = PriorityQueue::new(); @@ -619,7 +627,7 @@ mod tests { let n_items = n_threads * n_items_per_thread; for i in 0..n_threads { - let sender = queue.get_sender(); + let sender = unsafe { queue.get_sender() }; let handler = thread::spawn(move || { for j in 0..n_items_per_thread { loop { diff --git a/hyperloop/Cargo.toml b/hyperloop/Cargo.toml index 05e0477..06ff151 100644 --- a/hyperloop/Cargo.toml +++ b/hyperloop/Cargo.toml @@ -2,7 +2,7 @@ name = "hyperloop" version = "0.1.0" authors = ["Eivind Alexander Bergem "] -edition = "2018" +edition = "2021" [dependencies] futures = {version = "0.3.15", default-features = false} diff --git a/hyperloop/src/executor.rs b/hyperloop/src/executor.rs index 095d500..73f0c7a 100644 --- a/hyperloop/src/executor.rs +++ b/hyperloop/src/executor.rs @@ -1,25 +1,26 @@ +use core::cell::UnsafeCell; use core::cmp::Ordering; +use core::mem; +use core::sync::atomic::AtomicBool; +use core::task::{Poll, RawWaker, RawWakerVTable, Waker}; use crate::priority_queue::{Max, PriorityQueue, PrioritySender}; -use crate::task::PollTask; +use crate::task::TaskHandle; +use crate::timer::Scheduler; pub(crate) type Priority = u8; +type TaskId = u16; -#[derive(Debug, Clone, Copy)] pub struct Ticket { - task: *const dyn PollTask, + task: TaskId, priority: Priority, } impl Ticket { - pub(crate) fn new(task: *const dyn PollTask, priority: Priority) -> Self { + pub(crate) fn new(task: TaskId, priority: Priority) -> Self { Self { task, priority } } - - unsafe fn get_task(&self) -> &dyn PollTask { - &*self.task - } } impl PartialEq for Ticket { @@ -44,17 +45,134 @@ impl Ord for Ticket { pub(crate) type TaskSender = PrioritySender; +const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake, drop); + +unsafe fn clone(ptr: *const ()) -> RawWaker { + RawWaker::new(ptr, &VTABLE) +} + +unsafe fn wake(ptr: *const ()) { + let task = &*(ptr as *const ExecutorTask); + task.wake(); +} + +unsafe fn drop(_ptr: *const ()) {} + +struct ExecutorTask { + task: TaskHandle, + task_id: TaskId, + priority: Priority, + sender: Option, + pending_wake: AtomicBool, +} + +impl ExecutorTask { + fn new( + task: TaskHandle, + task_id: TaskId, + priority: Priority, + sender: Option, + ) -> Self { + Self { + task, + task_id, + priority, + sender, + pending_wake: AtomicBool::new(false), + } + } + + fn set_sender(&mut self, sender: TaskSender) { + self.sender = Some(sender); + } + + fn get_waker(&self) -> Waker { + let ptr: *const () = (self as *const ExecutorTask).cast(); + let vtable = &VTABLE; + + unsafe { Waker::from_raw(RawWaker::new(ptr, vtable)) } + } + + fn send_ticket(&self, ticket: Ticket) -> Result<(), ()> { + let sender = unsafe { self.sender.as_ref().unwrap_unchecked() }; + + sender.send(ticket).map_err(|_| ()) + } + + fn wake(&self) { + if let Ok(_) = self.pending_wake.compare_exchange( + false, + true, + atomig::Ordering::Acquire, + atomig::Ordering::Acquire, + ) { + let ticket = Ticket::new(self.task_id, self.priority); + + self.send_ticket(ticket).unwrap_or_else(|_| unreachable!()); + } + } + + fn clear_pending_wake_flag(&self) { + let _ = self.pending_wake.compare_exchange( + true, + false, + atomig::Ordering::Acquire, + atomig::Ordering::Acquire, + ); + } + + fn poll(&mut self, waker: Waker) -> Poll<()> { + self.task.poll(waker) + } +} + pub struct Executor { + tasks: [UnsafeCell; N], queue: PriorityQueue, } impl Executor { - pub fn new() -> Self { + pub fn new(tasks: [TaskHandle; N]) -> Self { + let mut i = 0; + let tasks = tasks.map(|task| { + let priority = task.priority; + let task = UnsafeCell::new(ExecutorTask::new(task, i, priority, None)); + i += 1; + task + }); + Self { + tasks, queue: PriorityQueue::new(), } } + unsafe fn get_task(&mut self, task_id: TaskId) -> &mut ExecutorTask { + let index = task_id as usize; + + let task = &mut *self.tasks[index].get(); + task.clear_pending_wake_flag(); + + task + } + + unsafe fn init(&mut self) { + for i in 0..N { + let sender = self.queue.get_sender(); + let task = self.get_task(i as u16); + + task.set_sender(sender); + task.wake(); + } + } + + unsafe fn poll_task(&mut self, task_id: TaskId) { + let task = self.get_task(task_id); + let waker = task.get_waker(); + + let _ = task.poll(waker); + } + /// Poll all tasks in the queue /// /// # Safety @@ -64,29 +182,50 @@ impl Executor { /// pointers to the tasks stored in the executor. The pointers can /// be dereferenced at any time and will be dangling if the /// exeutor is moved or dropped. - pub unsafe fn poll_tasks(&mut self) { + unsafe fn poll_tasks(&mut self) { while let Some(ticket) = self.queue.pop() { - let _ = ticket.get_task().poll(); + self.poll_task(ticket.task); } } - pub fn get_sender(&self) -> TaskSender { - self.queue.get_sender() + pub fn get_handle(&'static mut self) -> ExecutorHandle { + ExecutorHandle::new(unsafe { &mut *mem::transmute::<_, *mut Self>(self) }) + } +} + +pub struct ExecutorHandle { + executor: *mut Executor, +} + +impl ExecutorHandle { + pub fn new(executor: *mut Executor) -> Self { + unsafe { (*executor).init() }; + Self { executor } + } + + /// Poll all tasks in the queue + pub fn poll_tasks(&mut self) { + unsafe { (*self.executor).poll_tasks() } + } + + pub fn with_scheduler(self, _scheduler: &Scheduler) -> Self { + self } } #[cfg(test)] mod tests { use crossbeam_queue::ArrayQueue; - use hyperloop_macros::{executor_from_tasks, task}; + use hyperloop_macros::{static_executor, task}; + use std::boxed::Box; use std::sync::Arc; use super::*; + use crate::notify::Notification; use crate::task::Task; #[test] fn test_executor() { - let mut executor = Executor::<10>::new(); let queue = Arc::new(ArrayQueue::new(10)); let test_future = |queue, value| { @@ -99,19 +238,19 @@ mod tests { } }; - let task1 = Task::new(test_future(queue.clone(), 1), 1); - let task2 = Task::new(test_future(queue.clone(), 2), 3); - let task3 = Task::new(test_future(queue.clone(), 3), 2); - let task4 = Task::new(test_future(queue.clone(), 4), 4); + let task1 = Box::leak(Box::new(Task::new(test_future(queue.clone(), 1), 1))); + let task2 = Box::leak(Box::new(Task::new(test_future(queue.clone(), 2), 3))); + let task3 = Box::leak(Box::new(Task::new(test_future(queue.clone(), 3), 2))); + let task4 = Box::leak(Box::new(Task::new(test_future(queue.clone(), 4), 4))); - task1.add_to_executor(executor.get_sender()).unwrap(); - task2.add_to_executor(executor.get_sender()).unwrap(); - task3.add_to_executor(executor.get_sender()).unwrap(); - task4.add_to_executor(executor.get_sender()).unwrap(); + let mut executor = static_executor!( + task1.get_handle(), + task2.get_handle(), + task3.get_handle(), + task4.get_handle(), + ); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop().unwrap(), 4); assert_eq!(queue.pop().unwrap(), 2); @@ -119,6 +258,54 @@ mod tests { assert_eq!(queue.pop().unwrap(), 1); } + #[test] + fn test_pending_wake() { + let queue = Arc::new(ArrayQueue::new(10)); + let notify = Box::leak(Box::new(Notification::new())); + + let test_future = |queue, notify| { + move || { + async fn future(queue: Arc>, notify: &'static Notification) { + for i in 0..10 { + queue.push(i).unwrap(); + notify.wait().await; + } + } + + future(queue, notify) + } + }; + + let task = Box::leak(Box::new(Task::new(test_future(queue.clone(), notify), 1))); + + let mut executor = static_executor!(task.get_handle()); + + executor.poll_tasks(); + + assert_eq!(queue.pop().unwrap(), 0); + assert!(queue.pop().is_none()); + + notify.notify(); + + executor.poll_tasks(); + + assert_eq!(queue.pop().unwrap(), 1); + assert!(queue.pop().is_none()); + + executor.poll_tasks(); + assert!(queue.pop().is_none()); + + let waker = unsafe { (*executor.executor).get_task(0).get_waker() }; + + waker.wake(); + + notify.notify(); + executor.poll_tasks(); + + assert_eq!(queue.pop().unwrap(), 2); + assert!(queue.pop().is_none()); + } + #[test] fn macros() { #[task(priority = 1)] @@ -136,11 +323,9 @@ mod tests { let task1 = test_task1(queue.clone()).unwrap(); let task2 = test_task2(queue.clone()).unwrap(); - let executor = executor_from_tasks!(task1, task2); + let mut executor = ExecutorHandle::new(Box::leak(Box::new(Executor::new([task1, task2])))); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop().unwrap(), 2); assert_eq!(queue.pop().unwrap(), 1); diff --git a/hyperloop/src/interrupt.rs b/hyperloop/src/interrupt.rs index 6468c15..a1cc9f3 100644 --- a/hyperloop/src/interrupt.rs +++ b/hyperloop/src/interrupt.rs @@ -29,9 +29,11 @@ impl YieldFuture { impl Future for YieldFuture { type Output = (); - fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if !self.done { self.done = true; + cx.waker().wake_by_ref(); + Poll::Pending } else { Poll::Ready(()) diff --git a/hyperloop/src/lib.rs b/hyperloop/src/lib.rs index 0d1ab5a..9bcdae2 100644 --- a/hyperloop/src/lib.rs +++ b/hyperloop/src/lib.rs @@ -1,5 +1,4 @@ #![no_std] -#![feature(const_fn_trait_bound)] #![feature(type_alias_impl_trait)] #![feature(once_cell)] diff --git a/hyperloop/src/notify.rs b/hyperloop/src/notify.rs index 98575bd..8444ddf 100644 --- a/hyperloop/src/notify.rs +++ b/hyperloop/src/notify.rs @@ -64,7 +64,10 @@ mod tests { use std::boxed::Box; use std::sync::Arc; - use crate::{executor::Executor, task::Task}; + use crate::{ + executor::{Executor, ExecutorHandle}, + task::Task, + }; use super::*; @@ -72,7 +75,6 @@ mod tests { fn notify() { let notification = Box::leak(Box::new(Notification::new())); - let mut executor = Executor::<10>::new(); let queue = Arc::new(ArrayQueue::new(10)); let wait = |receiver, queue| { @@ -88,43 +90,34 @@ mod tests { } }; - let task1 = Task::new(wait(notification, queue.clone()), 1); + let task = Box::leak(Box::new(Task::new(wait(notification, queue.clone()), 1))); - task1.add_to_executor(executor.get_sender()).unwrap(); + let mut executor = + ExecutorHandle::new(Box::leak(Box::new(Executor::new([task.get_handle()])))); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(1)); assert_eq!(queue.pop(), None); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), None); notification.notify(); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(2)); assert_eq!(queue.pop(), None); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), None); notification.notify(); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(3)); assert_eq!(queue.pop(), None); diff --git a/hyperloop/src/task.rs b/hyperloop/src/task.rs index 7217c7e..bfffb08 100644 --- a/hyperloop/src/task.rs +++ b/hyperloop/src/task.rs @@ -1,37 +1,11 @@ use core::{ - lazy::OnceCell, pin::Pin, - task::{Context, Poll, RawWaker, RawWakerVTable, Waker}, + task::{Context, Poll, Waker}, }; -use atomig::{Atom, Atomic, Ordering}; use futures::Future; -use crate::executor::{Priority, TaskSender, Ticket}; - -unsafe fn clone + 'static>(ptr: *const ()) -> RawWaker { - let task = &*(ptr as *const Task); - - RawWaker::new(ptr, &task.vtable) -} - -unsafe fn wake + 'static>(ptr: *const ()) { - let task = &*(ptr as *const Task); - task.wake(); -} - -unsafe fn drop(_ptr: *const ()) {} - -pub(crate) trait PollTask { - unsafe fn poll(&self) -> Poll<()>; -} - -#[repr(u8)] -#[derive(Copy, Clone, Eq, PartialEq, Debug, Atom)] -enum TaskState { - NotQueued, - Queued, -} +use crate::executor::Priority; pub struct Task where @@ -39,9 +13,6 @@ where { future: F, priority: Priority, - sender: OnceCell, - vtable: RawWakerVTable, - state: Atomic, } impl Task @@ -52,165 +23,81 @@ where Self { future: future_fn(), priority, - sender: OnceCell::new(), - vtable: RawWakerVTable::new(clone::, wake::, wake::, drop), - state: Atomic::new(TaskState::NotQueued), - } - } - - fn update_state(&self, old: TaskState, new: TaskState) -> bool { - if let Ok(_) = self - .state - .compare_exchange(old, new, Ordering::Relaxed, Ordering::Relaxed) - { - true - } else { - false - } - } - - #[cfg(test)] - fn get_state(&self) -> TaskState { - self.state.load(Ordering::Relaxed) - } - - unsafe fn as_static(&self) -> &'static Self { - &*(self as *const Self) - } - - unsafe fn as_mut(&self) -> &mut Self { - &mut *((self as *const Self) as *mut Self) - } - - unsafe fn get_waker(&self) -> Waker { - let ptr: *const () = (self as *const Task).cast(); - let vtable = &self.as_static().vtable; - - Waker::from_raw(RawWaker::new(ptr, vtable)) - } - - pub fn wake(&self) { - self.schedule().unwrap(); - } - - pub fn add_to_executor(&self, sender: TaskSender) -> Result<(), ()> { - self.set_sender(sender)?; - self.schedule() - } - - fn set_sender(&self, sender: TaskSender) -> Result<(), ()> { - match self.sender.set(sender) { - Ok(_) => Ok(()), - Err(_) => Err(()), } } - fn send_ticket(&self, ticket: Ticket) -> Result<(), ()> { - if let Some(sender) = self.sender.get() { - if let Ok(_) = sender.send(ticket) { - return Ok(()); - } - } - - Err(()) + pub fn get_handle(&'static mut self) -> TaskHandle { + TaskHandle::new(self) } +} - fn schedule(&self) -> Result<(), ()> { - if self.update_state(TaskState::NotQueued, TaskState::Queued) { - let ticket = Ticket::new(self as *const Self, self.priority); +pub struct TaskHandle { + future: *mut (), + poll: fn(*mut (), &mut Context<'_>) -> Poll<()>, + pub priority: Priority, +} - match self.send_ticket(ticket) { - Ok(_) => Ok(()), - Err(_) => { - assert!(self.update_state(TaskState::Queued, TaskState::NotQueued)); - Err(()) - } +impl TaskHandle { + pub fn new>(task: &'static mut Task) -> Self { + unsafe { + Self { + future: core::mem::transmute::<_, _>(&mut task.future), + poll: core::mem::transmute::< + fn(Pin<&mut F>, &mut Context<'_>) -> Poll, + fn(*mut (), &mut Context<'_>) -> Poll<()>, + >(F::poll), + priority: task.priority, } - } else { - Ok(()) } } -} -impl PollTask for Task -where - F: Future + 'static, -{ - unsafe fn poll(&self) -> Poll<()> { - let waker = self.get_waker(); + pub fn poll(&mut self, waker: Waker) -> Poll<()> { let mut cx = Context::from_waker(&waker); - let future = Pin::new_unchecked(&mut self.as_mut().future); - assert!(self.update_state(TaskState::Queued, TaskState::NotQueued)); - let result = future.poll(&mut cx); + let poll = self.poll; - result + poll(self.future, &mut cx) } } #[cfg(test)] mod tests { - use crate::{ - interrupt::yield_now, - priority_queue::{Max, PriorityQueue}, - }; + use crossbeam_queue::ArrayQueue; + use std::boxed::Box; + use std::sync::Arc; + + use crate::{common::tests::MockWaker, interrupt::yield_now}; use super::*; #[test] fn task() { - let mut queue: PriorityQueue = PriorityQueue::new(); + let queue = Arc::new(ArrayQueue::new(10)); - let test_future = || { + let test_future = |queue| { || { - async fn future() { - loop { - yield_now().await + async fn future(queue: Arc>) { + for i in 0.. { + queue.push(i).unwrap(); + yield_now().await; } } - future() + future(queue) } }; - let task = Task::new(test_future(), 1); - - task.set_sender(queue.get_sender()).unwrap(); + let mut task = Box::leak(Box::new(Task::new(test_future(queue.clone()), 1))).get_handle(); + let waker: Waker = Arc::new(MockWaker::new()).into(); - assert_eq!(task.get_state(), TaskState::NotQueued); + assert_eq!(task.poll(waker.clone()), Poll::Pending); - task.schedule().unwrap(); - - assert_eq!(task.get_state(), TaskState::Queued); - - assert!(queue.pop().is_some()); - assert!(queue.pop().is_none()); - - task.schedule().unwrap(); - - assert!(queue.pop().is_none()); - - unsafe { - assert_eq!(task.poll(), Poll::Pending); - } - - assert_eq!(task.get_state(), TaskState::NotQueued); - - task.wake(); - task.wake(); - task.wake(); - - assert_eq!(task.get_state(), TaskState::Queued); - - assert!(queue.pop().is_some()); + assert_eq!(queue.pop().unwrap(), 0); assert!(queue.pop().is_none()); - task.wake(); - task.wake(); - task.wake(); - - assert_eq!(task.get_state(), TaskState::Queued); + assert_eq!(task.poll(waker.clone()), Poll::Pending); + assert_eq!(queue.pop().unwrap(), 1); assert!(queue.pop().is_none()); } } diff --git a/hyperloop/src/timer.rs b/hyperloop/src/timer.rs index ec9c8fb..347be4d 100644 --- a/hyperloop/src/timer.rs +++ b/hyperloop/src/timer.rs @@ -1,4 +1,5 @@ use core::{ + cell::UnsafeCell, pin::Pin, task::{Context, Poll, Waker}, }; @@ -10,66 +11,58 @@ use embedded_time::{ }; use core::future::Future; -use futures::{task::AtomicWaker, Stream, StreamExt}; use log::error; use crate::priority_queue::{Min, PeekMut, PriorityQueue, PrioritySender}; type Tick = u64; -pub struct TickCounter { - count: Tick, - waker: AtomicWaker, +pub struct Scheduler { + rate: Hertz, + counter: UnsafeCell, + queue: PriorityQueue, } -impl TickCounter { - pub const fn new() -> Self { +impl Scheduler { + pub fn new(rate: Hertz) -> Self { Self { - count: 0, - waker: AtomicWaker::new(), + rate, + counter: UnsafeCell::new(0), + queue: PriorityQueue::new(), } } - /// Increment tick count - /// - /// # Safety - /// - /// Updating the tick value is not atomic on 32-bit systems, so it - /// would be possible to get an invalid reading if reading during - /// a write. For this reason, this function should only be called - /// from a high priority interrupt handler. pub unsafe fn increment(&mut self) { - self.count += 1; + *self.counter.get() += 1; } - pub fn wake(&self) { - self.waker.wake(); - } + fn next_waker(&mut self) -> Option { + if let Some(ticket) = self.queue.peek_mut().as_mut() { + if unsafe { *self.counter.get() } > ticket.expires { + return Some(PeekMut::pop(ticket).waker); + } + } - pub unsafe fn tick(&mut self) { - self.increment(); - self.wake(); + None } - pub fn get_token(&self) -> TickCounterToken { - TickCounterToken { - counter: unsafe { &*(self as *const Self) }, + fn wake_tasks(&mut self) { + while let Some(waker) = self.next_waker() { + waker.wake(); } } -} -#[derive(Clone)] -pub struct TickCounterToken { - counter: &'static TickCounter, -} - -impl TickCounterToken { - pub fn register_waker(&self, waker: &Waker) { - self.counter.waker.register(waker); + pub unsafe fn tick(&mut self) { + self.increment(); + self.wake_tasks(); } - pub fn get_count(&self) -> Tick { - self.counter.count + pub fn get_timer(&self) -> Timer { + let counter = self.counter.get() as *const _; + let sender = unsafe { self.queue.get_sender() }; + let timer = Timer::new(self.rate, counter, sender); + + timer } } @@ -107,13 +100,13 @@ impl Ord for Ticket { struct DelayFuture { sender: PrioritySender, - counter: TickCounterToken, + counter: *const Tick, expires: Tick, started: bool, } impl DelayFuture { - fn new(sender: PrioritySender, counter: TickCounterToken, expires: Tick) -> Self { + fn new(sender: PrioritySender, counter: *const Tick, expires: Tick) -> Self { Self { sender, counter, @@ -141,7 +134,7 @@ impl Future for DelayFuture { // expiration. This ensures that we wait for no less than // the specified duration, and possibly one tick longer // than desired. - if self.counter.get_count() > self.expires { + if unsafe { *self.counter } > self.expires { Poll::Ready(()) } else { Poll::Pending @@ -162,12 +155,7 @@ impl TimeoutFuture where F: Future, { - fn new( - future: F, - sender: PrioritySender, - counter: TickCounterToken, - expires: Tick, - ) -> Self { + fn new(future: F, sender: PrioritySender, counter: *const Tick, expires: Tick) -> Self { Self { future, delay: DelayFuture::new(sender, counter, expires), @@ -205,12 +193,12 @@ where #[derive(Clone)] pub struct Timer { rate: Hertz, - counter: TickCounterToken, + counter: *const Tick, sender: PrioritySender, } impl Timer { - pub fn new(rate: Hertz, counter: TickCounterToken, sender: PrioritySender) -> Self { + pub fn new(rate: Hertz, counter: *const Tick, sender: PrioritySender) -> Self { Self { rate, counter, @@ -223,7 +211,7 @@ impl Timer { } fn get_count(&self) -> Tick { - self.counter.get_count() + unsafe { *self.counter } } fn delay_to_ticks>(&self, duration: D) -> Tick { @@ -240,7 +228,7 @@ impl Timer { pub fn delay(&self, duration: Milliseconds) -> impl Future { DelayFuture::new( self.sender.clone(), - self.counter.clone(), + self.counter, self.delay_to_ticks(duration), ) } @@ -249,86 +237,12 @@ impl Timer { TimeoutFuture::new( future, self.sender.clone(), - self.counter.clone(), + self.counter, self.delay_to_ticks(duration), ) } } -struct TimerFuture { - counter: TickCounterToken, - expires: Option, -} - -impl TimerFuture { - fn new(counter: TickCounterToken) -> Self { - Self { - counter, - expires: None, - } - } -} - -impl Stream for TimerFuture { - type Item = (); - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let Some(expires) = self.expires { - if self.counter.get_count() >= expires { - self.expires = None; - return Poll::Ready(Some(())); - } - } else { - self.expires = Some(self.counter.get_count() + 1_u64); - } - - self.counter.register_waker(cx.waker()); - Poll::Pending - } -} - -pub struct Scheduler { - rate: Hertz, - counter: TickCounterToken, - queue: PriorityQueue, -} - -impl Scheduler { - pub fn new(rate: Hertz, counter: TickCounterToken) -> Self { - Self { - rate, - counter, - queue: PriorityQueue::new(), - } - } - - pub fn get_timer(&self) -> Timer { - Timer::new(self.rate, self.counter.clone(), self.queue.get_sender()) - } - - fn next_waker(&mut self) -> Option { - if let Some(ticket) = self.queue.peek_mut().as_mut() { - if self.counter.get_count() > ticket.expires { - return Some(PeekMut::pop(ticket).waker); - } - } - - None - } - - pub async fn task(&mut self) { - let mut timer = TimerFuture::new(self.counter.clone()); - - loop { - if let Some(waker) = self.next_waker() { - waker.wake(); - } else { - timer.next().await.unwrap() - } - } - } -} - #[cfg(test)] mod tests { use core::sync::atomic::Ordering; @@ -347,44 +261,32 @@ mod tests { #[test] fn state() { - let counter = Box::leak(Box::new(TickCounter::new())); - let token = counter.get_token(); - - assert_eq!(token.get_count(), 0); - - unsafe { counter.increment() }; - assert_eq!(token.get_count(), 1); - - let mockwaker = Arc::new(MockWaker::new()); - let waker: Waker = mockwaker.clone().into(); - - token.register_waker(&waker); - counter.wake(); + let scheduler = Box::leak(Box::new(Scheduler::<0>::new(1000.Hz()))); + let timer = scheduler.get_timer(); - assert_eq!(mockwaker.woke.load(Ordering::Relaxed), true); + assert_eq!(unsafe { *timer.counter }, 0); - mockwaker.woke.store(false, Ordering::Relaxed); - token.register_waker(&waker); + unsafe { scheduler.tick() }; + assert_eq!(unsafe { *timer.counter }, 1); unsafe { - counter.tick(); + scheduler.tick(); } - assert_eq!(token.get_count(), 2); - assert_eq!(mockwaker.woke.load(Ordering::Relaxed), true); + + assert_eq!(unsafe { *timer.counter }, 2); } #[test] fn delay() { - let counter = Box::leak(Box::new(TickCounter::new())); - let token = counter.get_token(); - let scheduler: &'static mut Scheduler<10> = - Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))); - let sender = scheduler.queue.get_sender(); + let scheduler = Box::leak(Box::new(Scheduler::<10>::new(1000.Hz()))); + let timer = scheduler.get_timer(); + let counter = timer.counter; + let mockwaker = Arc::new(MockWaker::new()); let waker: Waker = mockwaker.clone().into(); let mut cx = Context::from_waker(&waker); - let mut future = DelayFuture::new(sender.clone(), token.clone(), 1); + let mut future = DelayFuture::new(timer.sender.clone(), counter, 1); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); @@ -392,17 +294,17 @@ mod tests { assert_eq!(future.started, true); unsafe { - counter.tick(); - counter.tick(); + scheduler.increment(); + scheduler.increment(); } assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Ready(())); - let mut future = DelayFuture::new(sender.clone(), token.clone(), 20); + let mut future = DelayFuture::new(timer.sender.clone(), counter, 20); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); - let mut future = DelayFuture::new(sender.clone(), token.clone(), 15); + let mut future = DelayFuture::new(timer.sender.clone(), counter, 15); assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Pending); @@ -423,13 +325,8 @@ mod tests { #[test] fn timer() { - let counter = Box::leak(Box::new(TickCounter::new())); - let token = counter.get_token(); - let scheduler: &'static mut Scheduler<10> = - Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))); + let scheduler = Box::leak(Box::new(Scheduler::new(1000.Hz()))); let timer = scheduler.get_timer(); - let mut executor = Box::new(Executor::<10>::new()); - let queue = Arc::new(ArrayQueue::new(10)); log_init(); @@ -463,92 +360,82 @@ mod tests { } }; - let task1 = Task::new(move || scheduler.task(), 1); - let task2 = Task::new(test_future(queue.clone(), timer.clone()), 1); + let queue = Arc::new(ArrayQueue::new(10)); - task1.add_to_executor(executor.get_sender()).unwrap(); - task2.add_to_executor(executor.get_sender()).unwrap(); + let task = Box::leak(Box::new(Task::new( + test_future(queue.clone(), timer.clone()), + 1, + ))) + .get_handle(); - unsafe { - executor.poll_tasks(); - } + let mut executor = Box::leak(Box::new(Executor::new([task]))) + .get_handle() + .with_scheduler(scheduler); + + executor.poll_tasks(); assert_eq!(queue.pop(), Some(1)); assert_eq!(queue.pop(), None); unsafe { - counter.tick(); - } - unsafe { - executor.poll_tasks(); + scheduler.tick(); } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(2)); assert_eq!(queue.pop(), None); - counter.wake(); - unsafe { - executor.poll_tasks(); - } + scheduler.wake_tasks(); + executor.poll_tasks(); assert_eq!(queue.pop(), None); unsafe { - counter.tick(); + scheduler.tick(); } unsafe { - counter.tick(); - } - unsafe { - executor.poll_tasks(); + scheduler.tick(); } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(3)); assert_eq!(queue.pop(), None); unsafe { - counter.tick(); - } - unsafe { - counter.tick(); + scheduler.tick(); } unsafe { - executor.poll_tasks(); + scheduler.tick(); } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(4)); assert_eq!(queue.pop(), None); unsafe { - counter.tick(); - } - unsafe { - counter.tick(); + scheduler.tick(); } unsafe { - executor.poll_tasks(); + scheduler.tick(); } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(5)); assert_eq!(queue.pop(), None); for _ in 0..10 { unsafe { - counter.tick(); - } - unsafe { - executor.poll_tasks(); + scheduler.tick(); } + executor.poll_tasks(); assert_eq!(queue.pop(), None); } unsafe { - counter.tick(); - } - unsafe { - executor.poll_tasks(); + scheduler.tick(); } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(6)); assert_eq!(queue.pop(), None); @@ -556,13 +443,8 @@ mod tests { #[test] fn timeout() { - let counter = Box::leak(Box::new(TickCounter::new())); - let token = counter.get_token(); - let scheduler: &'static mut Scheduler<10> = - Box::leak(Box::new(Scheduler::new(1000.Hz(), token.clone()))); + let scheduler = Box::leak(Box::new(Scheduler::<2>::new(1000.Hz()))); let timer = scheduler.get_timer(); - let mut executor = Executor::<10>::new(); - let queue = Arc::new(ArrayQueue::new(10)); log_init(); @@ -595,55 +477,47 @@ mod tests { } }; - let task1 = Task::new(move || scheduler.task(), 1); - let task2 = Task::new(waiting_future(queue.clone(), timer), 1); + let queue = Arc::new(ArrayQueue::new(10)); - task1.add_to_executor(executor.get_sender()).unwrap(); - task2.add_to_executor(executor.get_sender()).unwrap(); + let task = + Box::leak(Box::new(Task::new(waiting_future(queue.clone(), timer), 1))).get_handle(); - unsafe { - executor.poll_tasks(); - } + let mut executor = Box::leak(Box::new(Executor::new([task]))).get_handle(); + + executor.poll_tasks(); assert_eq!(queue.pop(), Some(1)); assert_eq!(queue.pop(), None); for _ in 0..101 { unsafe { - counter.increment(); + scheduler.increment(); } } - counter.wake(); + scheduler.wake_tasks(); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(2)); assert_eq!(queue.pop(), None); for _ in 0..1000 { unsafe { - counter.increment(); - } - counter.wake(); - - unsafe { - executor.poll_tasks(); + scheduler.increment(); } + scheduler.wake_tasks(); + executor.poll_tasks(); assert_eq!(queue.pop(), None); } unsafe { - counter.increment(); + scheduler.increment(); } - counter.wake(); + scheduler.wake_tasks(); - unsafe { - executor.poll_tasks(); - } + executor.poll_tasks(); assert_eq!(queue.pop(), Some(3)); assert_eq!(queue.pop(), None); diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 8315f48..e9d4cf2 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "nightly-2021-10-16" +channel = "nightly" components = [ "rustfmt", "clippy" ]