From eb2455e95dc742475b29c652b3ef3246a92fcc4a Mon Sep 17 00:00:00 2001 From: Eric Rodrigues Pires Date: Wed, 4 Dec 2024 14:43:03 -0300 Subject: [PATCH] Attempt to remove busywait in ChannelTx --- russh/src/channels/channel_ref.rs | 6 ++--- russh/src/channels/io/tx.rs | 20 +++++++++------- russh/src/channels/mod.rs | 40 ++++++++++++++++++++++++++++--- russh/src/client/encrypted.rs | 2 +- russh/src/client/mod.rs | 6 ++--- russh/src/server/encrypted.rs | 2 +- russh/src/server/session.rs | 6 ++--- 7 files changed, 60 insertions(+), 22 deletions(-) diff --git a/russh/src/channels/channel_ref.rs b/russh/src/channels/channel_ref.rs index d924bb11..ad39ad8b 100644 --- a/russh/src/channels/channel_ref.rs +++ b/russh/src/channels/channel_ref.rs @@ -3,14 +3,14 @@ use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; use tokio::sync::Mutex; -use crate::ChannelMsg; +use crate::{channels::WindowSize, ChannelMsg}; /// A handle to the [`super::Channel`]'s to be able to transmit messages /// to it and update it's `window_size`. #[derive(Debug)] pub struct ChannelRef { pub(super) sender: UnboundedSender, - pub(super) window_size: Arc>, + pub(super) window_size: Arc>, } impl ChannelRef { @@ -21,7 +21,7 @@ impl ChannelRef { } } - pub fn window_size(&self) -> &Arc> { + pub(crate) fn window_size(&self) -> &Arc> { &self.window_size } } diff --git a/russh/src/channels/io/tx.rs b/russh/src/channels/io/tx.rs index 47ce2726..fd1d54e4 100644 --- a/russh/src/channels/io/tx.rs +++ b/russh/src/channels/io/tx.rs @@ -9,8 +9,7 @@ use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::{self, OwnedPermit}; use tokio::sync::{Mutex, OwnedMutexGuard}; -use super::ChannelMsg; -use crate::{ChannelId, CryptoVec}; +use crate::{channels::WindowSize, ChannelId, ChannelMsg, CryptoVec}; type BoxedThreadsafeFuture = Pin>>; type OwnedPermitFuture = @@ -21,8 +20,8 @@ pub struct ChannelTx { send_fut: Option>, id: ChannelId, - window_size_fut: Option>>, - window_size: Arc>, + window_size_fut: Option>>, + window_size: Arc>, max_packet_size: u32, ext: Option, } @@ -34,7 +33,7 @@ where pub fn new( sender: mpsc::Sender, id: ChannelId, - window_size: Arc>, + window_size: Arc>, max_packet_size: u32, ext: Option, ) -> Self { @@ -58,11 +57,13 @@ where self.window_size_fut.take(); let writable = (self.max_packet_size) - .min(*window_size) + .min(window_size.get()) .min(buf.len() as u32) as usize; if writable == 0 { - // TODO fix this busywait - cx.waker().wake_by_ref(); + if buf.is_empty() || self.sender.is_closed() { + return Poll::Ready((ChannelMsg::Eof, 0)); + } + window_size.add_waker(cx.waker().clone()); return Poll::Pending; } let mut data = CryptoVec::new_zeroed(writable); @@ -116,6 +117,9 @@ where cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + if self.sender.is_closed() { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } let send_fut = if let Some(x) = self.send_fut.as_mut() { x } else { diff --git a/russh/src/channels/mod.rs b/russh/src/channels/mod.rs index 5b095d9b..d0aa5f4e 100644 --- a/russh/src/channels/mod.rs +++ b/russh/src/channels/mod.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::task::Waker; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::mpsc::{Sender, UnboundedReceiver}; @@ -112,6 +113,35 @@ pub enum ChannelMsg { OpenFailure(ChannelOpenFailure), } +#[derive(Debug, Default)] +pub(crate) struct WindowSize { + window_size: u32, + waker: Option, +} + +impl WindowSize { + pub(crate) fn get(&self) -> u32 { + self.window_size + } + + pub(crate) fn set(&mut self, window_size: u32) { + self.window_size = window_size; + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } + + pub(crate) fn add_waker(&mut self, waker: Waker) { + self.waker = Some(waker); + } +} + +impl std::ops::SubAssign for WindowSize { + fn sub_assign(&mut self, rhs: u32) { + self.window_size -= rhs; + } +} + /// A handle to a session channel. /// /// Allows you to read and write from a channel without borrowing the session @@ -120,7 +150,7 @@ pub struct Channel> { pub(crate) sender: Sender, pub(crate) receiver: UnboundedReceiver, pub(crate) max_packet_size: u32, - pub(crate) window_size: Arc>, + pub(crate) window_size: Arc>, } impl> std::fmt::Debug for Channel { @@ -137,7 +167,10 @@ impl + Send + Sync + 'static> Channel { window_size: u32, ) -> (Self, ChannelRef) { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let window_size = Arc::new(Mutex::new(window_size)); + let window_size = Arc::new(Mutex::new(WindowSize { + window_size, + waker: None, + })); ( Self { @@ -157,7 +190,8 @@ impl + Send + Sync + 'static> Channel { /// Returns the min between the maximum packet size and the /// remaining window size in the channel. pub async fn writable_packet_size(&self) -> usize { - self.max_packet_size.min(*self.window_size.lock().await) as usize + self.max_packet_size + .min(self.window_size.lock().await.window_size) as usize } pub fn id(&self) -> ChannelId { diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index 8fe978e2..85fb2557 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -630,7 +630,7 @@ impl Session { new_size -= enc.flush_pending(channel_num)? as u32; } if let Some(chan) = self.channels.get(&channel_num) { - *chan.window_size().lock().await = new_size; + chan.window_size().lock().await.set(new_size); let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }); } diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index 3e4f95cb..4c2cde5d 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -56,7 +56,7 @@ use tokio::sync::mpsc::{ }; use tokio::sync::{oneshot, Mutex}; -use crate::channels::{Channel, ChannelMsg, ChannelRef}; +use crate::channels::{Channel, ChannelMsg, ChannelRef, WindowSize}; use crate::cipher::{self, clear, CipherPair, OpeningKey}; use crate::keys::key::parse_public_key; use crate::session::{ @@ -428,7 +428,7 @@ impl Handle { async fn wait_channel_confirmation( &self, mut receiver: UnboundedReceiver, - window_size_ref: Arc>, + window_size_ref: Arc>, ) -> Result, crate::Error> { loop { match receiver.recv().await { @@ -437,7 +437,7 @@ impl Handle { max_packet_size, window_size, }) => { - *window_size_ref.lock().await = window_size; + window_size_ref.lock().await.set(window_size); return Ok(Channel { id, diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index fd616b82..528b15c8 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -763,7 +763,7 @@ impl Session { enc.flush_pending(channel_num)?; } if let Some(chan) = self.channels.get(&channel_num) { - *chan.window_size().lock().await = new_size; + chan.window_size().lock().await.set(new_size); chan.send(ChannelMsg::WindowAdjusted { new_size }) .unwrap_or(()) diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index eb08579c..32848f4a 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -10,7 +10,7 @@ use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver}; use tokio::sync::{oneshot, Mutex}; use super::*; -use crate::channels::{Channel, ChannelMsg, ChannelRef}; +use crate::channels::{Channel, ChannelMsg, ChannelRef, WindowSize}; use crate::kex::EXTENSION_SUPPORT_AS_CLIENT; use crate::msg; @@ -346,7 +346,7 @@ impl Handle { async fn wait_channel_confirmation( &self, mut receiver: UnboundedReceiver, - window_size_ref: Arc>, + window_size_ref: Arc>, ) -> Result, Error> { loop { match receiver.recv().await { @@ -355,7 +355,7 @@ impl Handle { max_packet_size, window_size, }) => { - *window_size_ref.lock().await = window_size; + window_size_ref.lock().await.set(window_size); return Ok(Channel { id,