From 3d4d096cf025c8b45e8e335b0839b1766d2b3869 Mon Sep 17 00:00:00 2001 From: "M.Amin Rayej" Date: Wed, 30 Aug 2023 09:37:55 +0330 Subject: [PATCH 1/6] instrument the Readiness future with budgeting --- tokio/src/runtime/io/scheduled_io.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tokio/src/runtime/io/scheduled_io.rs b/tokio/src/runtime/io/scheduled_io.rs index ddce4b3ae4b..141bfa68381 100644 --- a/tokio/src/runtime/io/scheduled_io.rs +++ b/tokio/src/runtime/io/scheduled_io.rs @@ -463,6 +463,8 @@ impl Future for Readiness<'_> { (&me.scheduled_io, &mut me.state, &me.waiter) }; + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + loop { match *state { State::Init => { @@ -479,6 +481,7 @@ impl Future for Readiness<'_> { // Currently ready! let tick = TICK.unpack(curr) as u8; *state = State::Done; + coop.made_progress(); return Poll::Ready(ReadyEvent { tick, ready, @@ -503,6 +506,7 @@ impl Future for Readiness<'_> { // Currently ready! let tick = TICK.unpack(curr) as u8; *state = State::Done; + coop.made_progress(); return Poll::Ready(ReadyEvent { tick, ready, @@ -572,6 +576,7 @@ impl Future for Readiness<'_> { let curr_ready = Ready::from_usize(READINESS.unpack(curr)); let ready = curr_ready.intersection(w.interest); + coop.made_progress(); return Poll::Ready(ReadyEvent { tick, ready, From 2d4b02c859d576afe3a90492761b72e162d7b1a7 Mon Sep 17 00:00:00 2001 From: "M.Amin Rayej" Date: Thu, 31 Aug 2023 15:36:13 +0330 Subject: [PATCH 2/6] add test to assert UnixDatagram cooperates --- tokio/tests/uds_datagram.rs | 44 +++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tokio/tests/uds_datagram.rs b/tokio/tests/uds_datagram.rs index ad22a0b99dd..9c4f64fc3ed 100644 --- a/tokio/tests/uds_datagram.rs +++ b/tokio/tests/uds_datagram.rs @@ -411,3 +411,47 @@ async fn poll_ready() -> io::Result<()> { Ok(()) } + +#[tokio::test(flavor = "current_thread")] +async fn coop_uds() -> io::Result<()> { + use std::sync::atomic::{AtomicU64, Ordering}; + use std::time::{Duration, Instant}; + + const HELLO: &[u8] = b"hello world"; + const DURATION: Duration = Duration::from_secs(3); + + let dir = tempfile::tempdir().unwrap(); + let server_path = dir.path().join("server.sock"); + + let client = std::os::unix::net::UnixDatagram::unbound().unwrap(); + let server = UnixDatagram::bind(&server_path).unwrap(); + + let counter = Arc::new(AtomicU64::new(0)); + + let counter_jh = tokio::spawn({ + let counter = counter.clone(); + + async move { + loop { + tokio::time::sleep(Duration::from_millis(250)).await; + counter.fetch_add(1, Ordering::Relaxed); + } + } + }); + + let mut buf = [0; HELLO.len()]; + let start = Instant::now(); + while Instant::now().duration_since(start) < DURATION { + let _ = client.send_to(HELLO, &server_path); + let _ = server.recv(&mut buf[..]).await.unwrap(); + } + + counter_jh.abort(); + let _ = counter_jh.await; + + let expected = ((DURATION.as_secs() * 4) as f64 * 0.9) as u64; + let counter = counter.load(Ordering::Relaxed); + assert!(counter >= expected); + + Ok(()) +} From b55690e36957c90cb9e6971257994565d067ab29 Mon Sep 17 00:00:00 2001 From: "M.Amin Rayej" Date: Thu, 31 Aug 2023 16:26:58 +0330 Subject: [PATCH 3/6] reduce the test duration and relax its assertion --- tokio/tests/uds_datagram.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tokio/tests/uds_datagram.rs b/tokio/tests/uds_datagram.rs index 9c4f64fc3ed..ba3c13a79c4 100644 --- a/tokio/tests/uds_datagram.rs +++ b/tokio/tests/uds_datagram.rs @@ -418,7 +418,7 @@ async fn coop_uds() -> io::Result<()> { use std::time::{Duration, Instant}; const HELLO: &[u8] = b"hello world"; - const DURATION: Duration = Duration::from_secs(3); + const DURATION: Duration = Duration::from_secs(1); let dir = tempfile::tempdir().unwrap(); let server_path = dir.path().join("server.sock"); @@ -449,7 +449,7 @@ async fn coop_uds() -> io::Result<()> { counter_jh.abort(); let _ = counter_jh.await; - let expected = ((DURATION.as_secs() * 4) as f64 * 0.9) as u64; + let expected = ((DURATION.as_secs() * 4) as f64 * 0.5) as u64; let counter = counter.load(Ordering::Relaxed); assert!(counter >= expected); From f37a7b18027707bc30fa13c6460371d158f46a63 Mon Sep 17 00:00:00 2001 From: "M.Amin Rayej" Date: Sun, 3 Sep 2023 02:21:43 +0330 Subject: [PATCH 4/6] remove assumption about constantly being ready --- tokio/tests/tcp_stream.rs | 14 +------------- tokio/tests/uds_stream.rs | 14 +------------- 2 files changed, 2 insertions(+), 26 deletions(-) diff --git a/tokio/tests/tcp_stream.rs b/tokio/tests/tcp_stream.rs index 725a60169ea..3dd9d55f319 100644 --- a/tokio/tests/tcp_stream.rs +++ b/tokio/tests/tcp_stream.rs @@ -5,7 +5,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest}; use tokio::net::{TcpListener, TcpStream}; use tokio::try_join; use tokio_test::task; -use tokio_test::{assert_ok, assert_pending, assert_ready_ok}; +use tokio_test::{assert_ok, assert_pending}; use std::io; use std::task::Poll; @@ -57,10 +57,6 @@ async fn try_read_write() { // Fill the write buffer using non-vectored I/O loop { - // Still ready - let mut writable = task::spawn(client.writable()); - assert_ready_ok!(writable.poll()); - match client.try_write(DATA) { Ok(n) => written.extend(&DATA[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { @@ -98,10 +94,6 @@ async fn try_read_write() { // Fill the write buffer using vectored I/O let data_bufs: Vec<_> = DATA.chunks(10).map(io::IoSlice::new).collect(); loop { - // Still ready - let mut writable = task::spawn(client.writable()); - assert_ready_ok!(writable.poll()); - match client.try_write_vectored(&data_bufs) { Ok(n) => written.extend(&DATA[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { @@ -313,10 +305,6 @@ async fn try_read_buf() { // Fill the write buffer loop { - // Still ready - let mut writable = task::spawn(client.writable()); - assert_ready_ok!(writable.poll()); - match client.try_write(DATA) { Ok(n) => written.extend(&DATA[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { diff --git a/tokio/tests/uds_stream.rs b/tokio/tests/uds_stream.rs index b8c4e6a8eed..924aa1261e9 100644 --- a/tokio/tests/uds_stream.rs +++ b/tokio/tests/uds_stream.rs @@ -7,7 +7,7 @@ use std::task::Poll; use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest}; use tokio::net::{UnixListener, UnixStream}; -use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task}; +use tokio_test::{assert_ok, assert_pending, task}; use futures::future::{poll_fn, try_join}; @@ -92,10 +92,6 @@ async fn try_read_write() -> std::io::Result<()> { // Fill the write buffer using non-vectored I/O loop { - // Still ready - let mut writable = task::spawn(client.writable()); - assert_ready_ok!(writable.poll()); - match client.try_write(msg) { Ok(n) => written.extend(&msg[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { @@ -133,10 +129,6 @@ async fn try_read_write() -> std::io::Result<()> { // Fill the write buffer using vectored I/O let msg_bufs: Vec<_> = msg.chunks(3).map(io::IoSlice::new).collect(); loop { - // Still ready - let mut writable = task::spawn(client.writable()); - assert_ready_ok!(writable.poll()); - match client.try_write_vectored(&msg_bufs) { Ok(n) => written.extend(&msg[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { @@ -329,10 +321,6 @@ async fn try_read_buf() -> std::io::Result<()> { // Fill the write buffer loop { - // Still ready - let mut writable = task::spawn(client.writable()); - assert_ready_ok!(writable.poll()); - match client.try_write(msg) { Ok(n) => written.extend(&msg[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { From c995a75697e2d8c17c33dab6a625ed78e9bd2e0a Mon Sep 17 00:00:00 2001 From: "M.Amin Rayej" Date: Sun, 3 Sep 2023 02:27:35 +0330 Subject: [PATCH 5/6] use u32 instead of u64 --- tokio/tests/uds_datagram.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tokio/tests/uds_datagram.rs b/tokio/tests/uds_datagram.rs index ba3c13a79c4..5d7e7ddd81d 100644 --- a/tokio/tests/uds_datagram.rs +++ b/tokio/tests/uds_datagram.rs @@ -414,7 +414,7 @@ async fn poll_ready() -> io::Result<()> { #[tokio::test(flavor = "current_thread")] async fn coop_uds() -> io::Result<()> { - use std::sync::atomic::{AtomicU64, Ordering}; + use std::sync::atomic::{AtomicU32, Ordering}; use std::time::{Duration, Instant}; const HELLO: &[u8] = b"hello world"; @@ -426,7 +426,7 @@ async fn coop_uds() -> io::Result<()> { let client = std::os::unix::net::UnixDatagram::unbound().unwrap(); let server = UnixDatagram::bind(&server_path).unwrap(); - let counter = Arc::new(AtomicU64::new(0)); + let counter = Arc::new(AtomicU32::new(0)); let counter_jh = tokio::spawn({ let counter = counter.clone(); @@ -449,7 +449,7 @@ async fn coop_uds() -> io::Result<()> { counter_jh.abort(); let _ = counter_jh.await; - let expected = ((DURATION.as_secs() * 4) as f64 * 0.5) as u64; + let expected = ((DURATION.as_secs() * 4) as f64 * 0.5) as u32; let counter = counter.load(Ordering::Relaxed); assert!(counter >= expected); From b507e6393d436a5b1dee0b3928ca012a7ff92a56 Mon Sep 17 00:00:00 2001 From: "M.Amin Rayej" Date: Mon, 2 Oct 2023 00:00:18 +0330 Subject: [PATCH 6/6] consume budget when async_io makes progress --- tokio/src/runtime/coop.rs | 11 +++++ tokio/src/runtime/io/registration.rs | 6 ++- tokio/src/runtime/io/scheduled_io.rs | 68 ++++++++++++++++++++++------ tokio/tests/tcp_stream.rs | 14 +++++- tokio/tests/uds_stream.rs | 14 +++++- 5 files changed, 95 insertions(+), 18 deletions(-) diff --git a/tokio/src/runtime/coop.rs b/tokio/src/runtime/coop.rs index 15a4d98c08d..2d44bc69f07 100644 --- a/tokio/src/runtime/coop.rs +++ b/tokio/src/runtime/coop.rs @@ -196,6 +196,17 @@ cfg_coop! { }).unwrap_or(Poll::Ready(RestoreOnPending(Cell::new(Budget::unconstrained())))) } + #[inline] + pub(crate) fn try_decrement() { + let _ = context::budget(|cell| { + let mut budget = cell.get(); + + budget.decrement(); + + cell.set(budget); + }); + } + cfg_rt! { cfg_metrics! { #[inline(always)] diff --git a/tokio/src/runtime/io/registration.rs b/tokio/src/runtime/io/registration.rs index 759589863eb..c9437ab1974 100644 --- a/tokio/src/runtime/io/registration.rs +++ b/tokio/src/runtime/io/registration.rs @@ -223,7 +223,11 @@ impl Registration { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { self.clear_readiness(event); } - x => return x, + x => { + crate::runtime::coop::try_decrement(); + + return x; + } } } } diff --git a/tokio/src/runtime/io/scheduled_io.rs b/tokio/src/runtime/io/scheduled_io.rs index 141bfa68381..01056d4f3ec 100644 --- a/tokio/src/runtime/io/scheduled_io.rs +++ b/tokio/src/runtime/io/scheduled_io.rs @@ -159,6 +159,8 @@ struct Readiness<'a> { /// Entry in the waiter `LinkedList`. waiter: UnsafeCell, + + is_waiter_registered: bool, } enum State { @@ -429,6 +431,7 @@ impl ScheduledIo { interest, _p: PhantomPinned, }), + is_waiter_registered: false, } } } @@ -458,12 +461,45 @@ impl Future for Readiness<'_> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { use std::sync::atomic::Ordering::SeqCst; - let (scheduled_io, state, waiter) = unsafe { + let (scheduled_io, state, waiter, is_waiter_registered) = unsafe { let me = self.get_unchecked_mut(); - (&me.scheduled_io, &mut me.state, &me.waiter) + ( + &me.scheduled_io, + &mut me.state, + &me.waiter, + &mut me.is_waiter_registered, + ) }; - let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + if !crate::runtime::coop::has_budget_remaining() { + // Wasn't ready, take the lock (and check again while locked). + let mut waiters = scheduled_io.waiters.lock(); + + let w = unsafe { &mut *waiter.get() }; + + if *is_waiter_registered { + // Update the waker, if necessary. + if !w.waker.as_ref().unwrap().will_wake(cx.waker()) { + w.waker = Some(cx.waker().clone()); + } + } else { + // Safety: called while locked + w.waker = Some(cx.waker().clone()); + + // Insert the waiter into the linked list + // + // safety: pointers from `UnsafeCell` are never null. + waiters + .list + .push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); + + *is_waiter_registered = true; + } + + drop(waiters); + + return Poll::Pending; + } loop { match *state { @@ -481,7 +517,6 @@ impl Future for Readiness<'_> { // Currently ready! let tick = TICK.unpack(curr) as u8; *state = State::Done; - coop.made_progress(); return Poll::Ready(ReadyEvent { tick, ready, @@ -506,7 +541,6 @@ impl Future for Readiness<'_> { // Currently ready! let tick = TICK.unpack(curr) as u8; *state = State::Done; - coop.made_progress(); return Poll::Ready(ReadyEvent { tick, ready, @@ -516,17 +550,22 @@ impl Future for Readiness<'_> { // Not ready even after locked, insert into list... - // Safety: called while locked - unsafe { - (*waiter.get()).waker = Some(cx.waker().clone()); + if !*is_waiter_registered { + // Safety: called while locked + unsafe { + (*waiter.get()).waker = Some(cx.waker().clone()); + } + + // Insert the waiter into the linked list + // + // safety: pointers from `UnsafeCell` are never null. + waiters + .list + .push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); + + *is_waiter_registered = true; } - // Insert the waiter into the linked list - // - // safety: pointers from `UnsafeCell` are never null. - waiters - .list - .push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); *state = State::Waiting; } State::Waiting => { @@ -576,7 +615,6 @@ impl Future for Readiness<'_> { let curr_ready = Ready::from_usize(READINESS.unpack(curr)); let ready = curr_ready.intersection(w.interest); - coop.made_progress(); return Poll::Ready(ReadyEvent { tick, ready, diff --git a/tokio/tests/tcp_stream.rs b/tokio/tests/tcp_stream.rs index 3dd9d55f319..725a60169ea 100644 --- a/tokio/tests/tcp_stream.rs +++ b/tokio/tests/tcp_stream.rs @@ -5,7 +5,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest}; use tokio::net::{TcpListener, TcpStream}; use tokio::try_join; use tokio_test::task; -use tokio_test::{assert_ok, assert_pending}; +use tokio_test::{assert_ok, assert_pending, assert_ready_ok}; use std::io; use std::task::Poll; @@ -57,6 +57,10 @@ async fn try_read_write() { // Fill the write buffer using non-vectored I/O loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + match client.try_write(DATA) { Ok(n) => written.extend(&DATA[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { @@ -94,6 +98,10 @@ async fn try_read_write() { // Fill the write buffer using vectored I/O let data_bufs: Vec<_> = DATA.chunks(10).map(io::IoSlice::new).collect(); loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + match client.try_write_vectored(&data_bufs) { Ok(n) => written.extend(&DATA[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { @@ -305,6 +313,10 @@ async fn try_read_buf() { // Fill the write buffer loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + match client.try_write(DATA) { Ok(n) => written.extend(&DATA[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { diff --git a/tokio/tests/uds_stream.rs b/tokio/tests/uds_stream.rs index 924aa1261e9..b8c4e6a8eed 100644 --- a/tokio/tests/uds_stream.rs +++ b/tokio/tests/uds_stream.rs @@ -7,7 +7,7 @@ use std::task::Poll; use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest}; use tokio::net::{UnixListener, UnixStream}; -use tokio_test::{assert_ok, assert_pending, task}; +use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task}; use futures::future::{poll_fn, try_join}; @@ -92,6 +92,10 @@ async fn try_read_write() -> std::io::Result<()> { // Fill the write buffer using non-vectored I/O loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + match client.try_write(msg) { Ok(n) => written.extend(&msg[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { @@ -129,6 +133,10 @@ async fn try_read_write() -> std::io::Result<()> { // Fill the write buffer using vectored I/O let msg_bufs: Vec<_> = msg.chunks(3).map(io::IoSlice::new).collect(); loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + match client.try_write_vectored(&msg_bufs) { Ok(n) => written.extend(&msg[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { @@ -321,6 +329,10 @@ async fn try_read_buf() -> std::io::Result<()> { // Fill the write buffer loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + match client.try_write(msg) { Ok(n) => written.extend(&msg[..n]), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {