From c9890229cec17bc4be07baf30d1ab2b6ae9db3fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Tue, 19 Nov 2024 15:04:52 +0100 Subject: [PATCH] Use a separate lock in esp-println --- esp-println/src/defmt.rs | 16 +++++++-- esp-println/src/lib.rs | 76 +++++++++++++++++++++++++--------------- 2 files changed, 60 insertions(+), 32 deletions(-) diff --git a/esp-println/src/defmt.rs b/esp-println/src/defmt.rs index b790f859164..509eacb7e10 100644 --- a/esp-println/src/defmt.rs +++ b/esp-println/src/defmt.rs @@ -4,7 +4,7 @@ #[cfg(feature = "critical-section")] use critical_section::RestoreState; -use super::PrinterImpl; +use super::{LockToken, PrinterImpl}; /// Global logger lock. #[cfg(feature = "critical-section")] @@ -74,7 +74,12 @@ unsafe impl defmt::Logger for Logger { } unsafe fn flush() { - PrinterImpl::flush(); + let token = unsafe { + // Safety: the implementation ensures this is only called in a critical + // section. + LockToken::conjure() + }; + PrinterImpl::flush(token); } unsafe fn write(bytes: &[u8]) { @@ -85,5 +90,10 @@ unsafe impl defmt::Logger for Logger { } fn do_write(bytes: &[u8]) { - PrinterImpl::write_bytes_assume_cs(bytes) + let token = unsafe { + // Safety: the above implementation ensures this is only called in a critical + // section. + LockToken::conjure() + }; + PrinterImpl::write_bytes_in_cs(bytes, token) } diff --git a/esp-println/src/lib.rs b/esp-println/src/lib.rs index 7609c244129..efa71327678 100644 --- a/esp-println/src/lib.rs +++ b/esp-println/src/lib.rs @@ -87,9 +87,9 @@ impl core::fmt::Write for Printer { impl Printer { /// Writes a byte slice to the configured output. pub fn write_bytes(bytes: &[u8]) { - with(|| { - PrinterImpl::write_bytes_assume_cs(bytes); - PrinterImpl::flush(); + with(|token| { + PrinterImpl::write_bytes_in_cs(bytes, token); + PrinterImpl::flush(token); }) } } @@ -114,10 +114,10 @@ type PrinterImpl = auto_printer::Printer; ) ))] mod auto_printer { - use super::with; use crate::{ serial_jtag_printer::Printer as PrinterSerialJtag, uart_printer::Printer as PrinterUart, + LockToken, }; pub struct Printer; @@ -145,27 +145,19 @@ mod auto_printer { unsafe { (USB_DEVICE_INT_RAW.read_volatile() & SOF_INT_MASK) != 0 } } - pub fn write_bytes_assume_cs(bytes: &[u8]) { + pub fn write_bytes_in_cs(bytes: &[u8], token: LockToken<'_>) { if Self::use_jtag() { - with(|| { - PrinterSerialJtag::write_bytes_assume_cs(bytes); - }) + PrinterSerialJtag::write_bytes_in_cs(bytes, token); } else { - with(|| { - PrinterUart::write_bytes_assume_cs(bytes); - }) + PrinterUart::write_bytes_in_cs(bytes, token); } } - pub fn flush() { + pub fn flush(token: LockToken<'_>) { if Self::use_jtag() { - with(|| { - PrinterSerialJtag::flush(); - }) + PrinterSerialJtag::flush(token); } else { - with(|| { - PrinterUart::flush(); - }) + PrinterUart::flush(token); } } } @@ -198,6 +190,8 @@ mod auto_printer { ))] mod serial_jtag_printer { use portable_atomic::{AtomicBool, Ordering}; + + use super::LockToken; pub struct Printer; #[cfg(feature = "esp32c3")] @@ -256,7 +250,7 @@ mod serial_jtag_printer { } impl Printer { - pub fn write_bytes_assume_cs(bytes: &[u8]) { + pub fn write_bytes_in_cs(bytes: &[u8], _token: LockToken<'_>) { if fifo_full() { // The FIFO is full. Let's see if we can progress. @@ -289,7 +283,7 @@ mod serial_jtag_printer { } } - pub fn flush() { + pub fn flush(_token: LockToken<'_>) { fifo_flush(); } } @@ -297,11 +291,12 @@ mod serial_jtag_printer { #[cfg(all(any(feature = "uart", feature = "auto"), feature = "esp32"))] mod uart_printer { + use super::LockToken; const UART_TX_ONE_CHAR: usize = 0x4000_9200; pub struct Printer; impl Printer { - pub fn write_bytes_assume_cs(bytes: &[u8]) { + pub fn write_bytes_in_cs(bytes: &[u8], _token: LockToken<'_>) { for &b in bytes { unsafe { let uart_tx_one_char: unsafe extern "C" fn(u8) -> i32 = @@ -311,15 +306,16 @@ mod uart_printer { } } - pub fn flush() {} + pub fn flush(_token: LockToken<'_>) {} } } #[cfg(all(any(feature = "uart", feature = "auto"), feature = "esp32s2"))] mod uart_printer { + use super::LockToken; pub struct Printer; impl Printer { - pub fn write_bytes_assume_cs(bytes: &[u8]) { + pub fn write_bytes_in_cs(bytes: &[u8], _token: LockToken<'_>) { // On ESP32-S2 the UART_TX_ONE_CHAR ROM-function seems to have some issues. for chunk in bytes.chunks(64) { for &b in chunk { @@ -338,7 +334,7 @@ mod uart_printer { } } - pub fn flush() {} + pub fn flush(_token: LockToken<'_>) {} } } @@ -347,6 +343,7 @@ mod uart_printer { not(any(feature = "esp32", feature = "esp32s2")) ))] mod uart_printer { + use super::LockToken; trait Functions { const TX_ONE_CHAR: usize; const CHUNK_SIZE: usize = 32; @@ -459,7 +456,7 @@ mod uart_printer { pub struct Printer; impl Printer { - pub fn write_bytes_assume_cs(bytes: &[u8]) { + pub fn write_bytes_in_cs(bytes: &[u8], _token: LockToken<'_>) { for chunk in bytes.chunks(Device::CHUNK_SIZE) { for &b in chunk { Device::tx_byte(b); @@ -469,15 +466,36 @@ mod uart_printer { } } - pub fn flush() {} + pub fn flush(_token: LockToken<'_>) {} + } +} + +#[cfg(not(feature = "critical-section"))] +type LockInner<'a> = PhantomData<&'a ()>; +#[cfg(feature = "critical-section")] +type LockInner<'a> = critical_section::CriticalSection<'a>; + +#[derive(Clone, Copy)] +struct LockToken<'a>(LockInner<'a>); + +impl<'a> LockToken<'a> { + #[allow(unused)] + unsafe fn conjure() -> Self { + #[cfg(feature = "critical-section")] + let inner = critical_section::CriticalSection::new(); + #[cfg(not(feature = "critical-section"))] + let inner = PhantomData; + + LockToken(inner) } } +/// Runs the callback in a critical section, if enabled. #[inline] -fn with(f: impl FnOnce() -> R) -> R { +fn with(f: impl FnOnce(LockToken) -> R) -> R { #[cfg(feature = "critical-section")] - return critical_section::with(|_| f()); + return critical_section::with(|cs| f(LockToken(cs))); #[cfg(not(feature = "critical-section"))] - f() + f(unsafe { LockToken::conjure() }) }