Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
fix: get rid lockup in task locals
Browse files Browse the repository at this point in the history
  • Loading branch information
Gavin-Niederman committed Dec 9, 2023
1 parent 54ee3d2 commit 72348f1
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 150 deletions.
82 changes: 82 additions & 0 deletions pros/src/task/local.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use core::{cell::RefCell, ptr::NonNull, sync::atomic::AtomicU32};

use alloc::{boxed::Box, collections::BTreeMap};

use spin::Once;

use super::current;

static mut INDEX: AtomicU32 = AtomicU32::new(0);

// Unsafe because you can change the thread local storage while it is being read.
// This requires you to leak val so that you can be sure it lives the entire task.
unsafe fn task_local_storage_set<T>(task: pros_sys::task_t, val: &'static T, index: u32) {
// Yes, we transmute val. This is the intended use of this function.
pros_sys::vTaskSetThreadLocalStoragePointer(task, index as _, (val as *const T).cast());
}

// Unsafe because we can't check if the type is the same as the one that was set.
unsafe fn task_local_storage_get<T>(task: pros_sys::task_t, index: u32) -> Option<&'static T> {
let val = pros_sys::pvTaskGetThreadLocalStoragePointer(task, index as _);
val.cast::<T>().as_ref()
}

struct ThreadLocalStorage {
pub data: BTreeMap<usize, NonNull<()>>,
}

pub struct LocalKey<T: 'static> {
index: Once<usize>,
init: fn() -> T,
}

impl<T: 'static> LocalKey<T> {
pub const fn new(init: fn() -> T) -> Self {
Self {
index: Once::new(),
init,
}
}

pub fn with<F, R>(&'static self, f: F) -> R
where
F: FnOnce(&'static T) -> R,
{
let index = *self.index.call_once(|| unsafe {
INDEX.fetch_add(1, core::sync::atomic::Ordering::SeqCst) as usize
});
let current = current();

// Get the thread local storage for this task.
// Creating it if it doesn't exist.
let storage = unsafe {
task_local_storage_get(current.task, 0).unwrap_or_else(|| {
let storage = Box::leak(Box::new(RefCell::new(ThreadLocalStorage {
data: BTreeMap::new(),
})));
task_local_storage_set(current.task, storage, 0);
storage
})
};

{
if let Some(val) = storage.borrow_mut().data.get(&index) {
return f(unsafe { val.cast().as_ref() });
}
}

let val = Box::leak(Box::new((self.init)()));
storage.borrow_mut().data.insert(index, NonNull::new((val as *mut T).cast::<()>()).unwrap());
f(val)
}
}

#[macro_export]
macro_rules! os_task_local {
($($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty = $init:expr;)*) => {
$(
$(#[$attr])*
$vis static $name: $crate::task::local::LocalKey<$t> = $crate::task::local::LocalKey::new(|| $init);
)*
};
}
69 changes: 69 additions & 0 deletions pros/src/task/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
pub mod local;
pub mod task;

pub use task::*;

use core::{future::Future, task::Poll};

use crate::async_runtime::executor::EXECUTOR;

/// Blocks the current task for the given amount of time, if you are in an async function.
/// ## you probably don't want to use this.
/// This function will block the entire task, including the async executor!
/// Instead, you should use [`sleep`].
pub fn delay(duration: core::time::Duration) {
unsafe { pros_sys::delay(duration.as_millis() as u32) }
}

pub struct SleepFuture {
target_millis: u32,
}
impl Future for SleepFuture {
type Output = ();

fn poll(
self: core::pin::Pin<&mut Self>,
cx: &mut core::task::Context<'_>,
) -> core::task::Poll<Self::Output> {
if self.target_millis < unsafe { pros_sys::millis() } {
Poll::Ready(())
} else {
EXECUTOR.with(|e| {
e.reactor
.borrow_mut()
.sleepers
.push(cx.waker().clone(), self.target_millis)
});
Poll::Pending
}
}
}

pub fn sleep(duration: core::time::Duration) -> SleepFuture {
SleepFuture {
target_millis: unsafe { pros_sys::millis() + duration.as_millis() as u32 },
}
}

/// Returns the task the function was called from.
pub fn current() -> TaskHandle {
unsafe {
let task = pros_sys::task_get_current();
TaskHandle { task }
}
}

/// Gets the first notification in the queue.
/// If there is none, blocks until a notification is received.
/// I am unsure what happens if the thread is unblocked while waiting.
/// returns the value of the notification
pub fn get_notification() -> u32 {
unsafe { pros_sys::task_notify_take(false, pros_sys::TIMEOUT_MAX) }
}

#[doc(hidden)]
pub fn __init_main() {
unsafe {
pros_sys::lcd_initialize();
}
}
154 changes: 4 additions & 150 deletions pros/src/task.rs → pros/src/task/task.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
use core::{cell::RefCell, future::Future, hash::Hash, panic, ptr::NonNull, task::Poll};

use alloc::boxed::Box;
use hashbrown::HashMap;
use slab::Slab;
use core::hash::Hash;

use snafu::Snafu;
use spin::Once;

use crate::{
async_runtime::executor::EXECUTOR,
error::{bail_on, map_errno},
sync::Mutex,
};
use crate::error::{bail_on, map_errno};

/// Creates a task to be run 'asynchronously' (More information at the [FreeRTOS docs](https://www.freertos.org/taskandcr.html)).
/// Takes in a closure that can move variables if needed.
Expand Down Expand Up @@ -54,7 +47,7 @@ fn spawn_inner<F: FnOnce() + Send + 'static>(
/// An owned permission to perform actions on a task.
#[derive(Clone)]
pub struct TaskHandle {
task: pros_sys::task_t,
pub(crate) task: pros_sys::task_t,
}
unsafe impl Send for TaskHandle {}
impl Hash for TaskHandle {
Expand Down Expand Up @@ -260,142 +253,3 @@ map_errno! {
ENOMEM => SpawnError::TCBNotCreated,
}
}

/// Blocks the current task for the given amount of time, if you are in an async function.
/// ## you probably don't want to use this.
/// This function will block the entire task, including the async executor!
/// Instead, you should use [`sleep`].
pub fn delay(duration: core::time::Duration) {
unsafe { pros_sys::delay(duration.as_millis() as u32) }
}

pub struct SleepFuture {
target_millis: u32,
}
impl Future for SleepFuture {
type Output = ();

fn poll(
self: core::pin::Pin<&mut Self>,
cx: &mut core::task::Context<'_>,
) -> core::task::Poll<Self::Output> {
if self.target_millis < unsafe { pros_sys::millis() } {
Poll::Ready(())
} else {
EXECUTOR.with(|e| {
e.reactor
.borrow_mut()
.sleepers
.push(cx.waker().clone(), self.target_millis)
});
Poll::Pending
}
}
}

pub fn sleep(duration: core::time::Duration) -> SleepFuture {
SleepFuture {
target_millis: unsafe { pros_sys::millis() + duration.as_millis() as u32 },
}
}

/// Returns the task the function was called from.
pub fn current() -> TaskHandle {
unsafe {
let task = pros_sys::task_get_current();
TaskHandle { task }
}
}

/// Gets the first notification in the queue.
/// If there is none, blocks until a notification is received.
/// I am unsure what happens if the thread is unblocked while waiting.
/// returns the value of the notification
pub fn get_notification() -> u32 {
unsafe { pros_sys::task_notify_take(false, pros_sys::TIMEOUT_MAX) }
}

// Unsafe because you can change the thread local storage while it is being read.
// This requires you to leak val so that you can be sure it lives the entire task.
unsafe fn task_local_storage_set<T>(task: pros_sys::task_t, val: &'static T, index: u32) {
// Yes, we transmute val. This is the intended use of this function.
pros_sys::vTaskSetThreadLocalStoragePointer(task, index as _, (val as *const T).cast());
}

// Unsafe because we can't check if the type is the same as the one that was set.
unsafe fn task_local_storage_get<T>(task: pros_sys::task_t, index: u32) -> Option<&'static T> {
let val = pros_sys::pvTaskGetThreadLocalStoragePointer(task, index as _);
val.cast::<T>().as_ref()
}

struct ThreadLocalStorage {
pub data: Slab<NonNull<()>>,
}

pub struct LocalKey<T: 'static> {
index_map: Once<Mutex<HashMap<TaskHandle, usize>>>,
init: fn() -> T,
}

impl<T: 'static> LocalKey<T> {
pub const fn new(init: fn() -> T) -> Self {
Self {
index_map: Once::new(),
init,
}
}

pub fn with<F, R>(&'static self, f: F) -> R
where
F: FnOnce(&'static T) -> R,
{
self.index_map.call_once(|| Mutex::new(HashMap::new()));

let current = current();

// Get the thread local storage for this task.
// Creating it if it doesn't exist.
let storage = unsafe {
task_local_storage_get(current.task, 0).unwrap_or_else(|| {
let storage = Box::leak(Box::new(RefCell::new(ThreadLocalStorage {
data: Slab::new(),
})));
task_local_storage_set(current.task, storage, 0);
storage
})
};

if let Some(index) = self.index_map.get().unwrap().lock().get(&current) {
let val = unsafe { storage.borrow().data[*index].cast::<T>().as_ref() };
f(val)
} else {
let val = Box::leak(Box::new((self.init)()));
let ptr = NonNull::from(val).cast();
let index = storage.borrow_mut().data.insert(ptr);
self.index_map
.get()
.unwrap()
.lock()
.insert(current.clone(), index);

f(unsafe { ptr.cast().as_ref() })
}
}
}

#[macro_export]
macro_rules! os_task_local {
($($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty = $init:expr;)*) => {
$(
$(#[$attr])*
$vis static $name: $crate::task::LocalKey<$t> = $crate::task::LocalKey::new(|| $init);
)*
};
}

#[doc(hidden)]
pub fn __init_main() {
unsafe {
pros_sys::lcd_initialize();
}
}

0 comments on commit 72348f1

Please sign in to comment.