diff --git a/src/tunnel/client/cnx_pool.rs b/src/tunnel/client/cnx_pool.rs index 0cef719..67a3644 100644 --- a/src/tunnel/client/cnx_pool.rs +++ b/src/tunnel/client/cnx_pool.rs @@ -4,6 +4,7 @@ use crate::tunnel::client::l4_transport_stream::TransportStream; use crate::tunnel::client::WsClientConfig; use async_trait::async_trait; use bb8::ManageConnection; +use bytes::Bytes; use std::ops::Deref; use std::sync::Arc; use tracing::instrument; @@ -58,9 +59,9 @@ impl ManageConnection for WsConnection { if self.remote_addr.tls().is_some() { let tls_stream = tls::connect(self, tcp_stream).await?; - Ok(Some(TransportStream::Tls(tls_stream))) + Ok(Some(TransportStream::from_client_tls(tls_stream, Bytes::default()))) } else { - Ok(Some(TransportStream::Plain(tcp_stream))) + Ok(Some(TransportStream::from_tcp(tcp_stream, Bytes::default()))) } } diff --git a/src/tunnel/client/l4_transport_stream.rs b/src/tunnel/client/l4_transport_stream.rs index bbf55e1..879416d 100644 --- a/src/tunnel/client/l4_transport_stream.rs +++ b/src/tunnel/client/l4_transport_stream.rs @@ -1,46 +1,160 @@ +use bytes::{Buf, Bytes}; +use std::cmp; use std::io::{Error, IoSlice}; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; -use tokio_rustls::client::TlsStream; -pub enum TransportStream { - Plain(TcpStream), - Tls(TlsStream), +pub struct TransportStream { + read: TransportReadHalf, + write: TransportWriteHalf, +} + +impl TransportStream { + pub fn from_tcp(tcp: TcpStream, read_buf: Bytes) -> Self { + let (read, write) = tcp.into_split(); + Self { + read: TransportReadHalf::Plain(read, read_buf), + write: TransportWriteHalf::Plain(write), + } + } + + pub fn from_client_tls(tls: tokio_rustls::client::TlsStream, read_buf: Bytes) -> Self { + let (read, write) = tokio::io::split(tls); + Self { + read: TransportReadHalf::Tls(read, read_buf), + write: TransportWriteHalf::Tls(write), + } + } + + pub fn from_server_tls(tls: tokio_rustls::server::TlsStream, read_buf: Bytes) -> Self { + let (read, write) = tokio::io::split(tls); + Self { + read: TransportReadHalf::TlsSrv(read, read_buf), + write: TransportWriteHalf::TlsSrv(write), + } + } + + pub fn from(self, read_buf: Bytes) -> Self { + let mut read = self.read; + *read.read_buf_mut() = read_buf; + Self { + read, + write: self.write, + } + } + + pub fn into_split(self) -> (TransportReadHalf, TransportWriteHalf) { + (self.read, self.write) + } +} + +pub enum TransportReadHalf { + Plain(OwnedReadHalf, Bytes), + Tls(ReadHalf>, Bytes), + TlsSrv(ReadHalf>, Bytes), +} + +impl TransportReadHalf { + fn read_buf_mut(&mut self) -> &mut Bytes { + match self { + Self::Plain(_, buf) => buf, + Self::Tls(_, buf) => buf, + Self::TlsSrv(_, buf) => buf, + } + } +} + +pub enum TransportWriteHalf { + Plain(OwnedWriteHalf), + Tls(WriteHalf>), + TlsSrv(WriteHalf>), } impl AsyncRead for TransportStream { fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { - match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf), - Self::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf), - } + unsafe { self.map_unchecked_mut(|s| &mut s.read).poll_read(cx, buf) } } } impl AsyncWrite for TransportStream { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_write(cx, buf) } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_flush(cx) } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_shutdown(cx) } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_write_vectored(cx, bufs) } + } + + fn is_write_vectored(&self) -> bool { + self.write.is_write_vectored() + } +} + +impl AsyncRead for TransportReadHalf { + #[inline] + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + let this = self.get_mut(); + + let read_buf = this.read_buf_mut(); + if !read_buf.is_empty() { + let copy_len = cmp::min(read_buf.len(), buf.remaining()); + buf.put_slice(&read_buf[..copy_len]); + read_buf.advance(copy_len); + return Poll::Ready(Ok(())); + } + + match this { + Self::Plain(cnx, _) => Pin::new(cnx).poll_read(cx, buf), + Self::Tls(cnx, _) => Pin::new(cnx).poll_read(cx, buf), + Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TransportWriteHalf { + #[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { match self.get_mut() { Self::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf), Self::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf), + Self::TlsSrv(cnx) => Pin::new(cnx).poll_write(cx, buf), } } + #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { Self::Plain(cnx) => Pin::new(cnx).poll_flush(cx), Self::Tls(cnx) => Pin::new(cnx).poll_flush(cx), + Self::TlsSrv(cnx) => Pin::new(cnx).poll_flush(cx), } } + #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { Self::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx), Self::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx), + Self::TlsSrv(cnx) => Pin::new(cnx).poll_shutdown(cx), } } + #[inline] fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -49,13 +163,16 @@ impl AsyncWrite for TransportStream { match self.get_mut() { Self::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), Self::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), + Self::TlsSrv(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), } } + #[inline] fn is_write_vectored(&self) -> bool { match &self { Self::Plain(cnx) => cnx.is_write_vectored(), Self::Tls(cnx) => cnx.is_write_vectored(), + Self::TlsSrv(cnx) => cnx.is_write_vectored(), } } } diff --git a/src/tunnel/server/handler_websocket.rs b/src/tunnel/server/handler_websocket.rs index 1c749ae..5599c9b 100644 --- a/src/tunnel/server/handler_websocket.rs +++ b/src/tunnel/server/handler_websocket.rs @@ -2,8 +2,9 @@ use crate::restrictions::types::RestrictionsRules; use crate::tunnel::server::utils::{bad_request, inject_cookie}; use crate::tunnel::server::WsServer; use crate::tunnel::transport; -use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; +use crate::tunnel::transport::websocket::mk_websocket_tunnel; use bytes::Bytes; +use fastwebsockets::Role; use http_body_util::combinators::BoxBody; use http_body_util::Either; use hyper::body::Incoming; @@ -46,31 +47,26 @@ pub(super) async fn ws_server_upgrade( tokio::spawn( async move { let (ws_rx, ws_tx) = match fut.await { - Ok(mut ws) => { - ws.set_auto_pong(false); - ws.set_auto_close(false); - ws.set_auto_apply_mask(mask_frame); - ws.split(tokio::io::split) - } + Ok(ws) => mk_websocket_tunnel(ws, Role::Server, mask_frame)?, Err(err) => { error!("Error during http upgrade request: {:?}", err); - return; + return Err(anyhow::Error::from(err)); } }; let (close_tx, close_rx) = oneshot::channel::<()>(); - let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx); tokio::task::spawn( transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()), ); let _ = transport::io::propagate_local_to_remote( local_rx, - WebsocketTunnelWrite::new(ws_tx, pending_ops), + ws_tx, close_tx, server.config.websocket_ping_frequency, ) .await; + Ok(()) } .instrument(Span::current()), ); diff --git a/src/tunnel/transport/websocket.rs b/src/tunnel/transport/websocket.rs index a6d29c5..7d15c5a 100644 --- a/src/tunnel/transport/websocket.rs +++ b/src/tunnel/transport/websocket.rs @@ -1,11 +1,12 @@ use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; +use crate::tunnel::client::l4_transport_stream::{TransportReadHalf, TransportStream, TransportWriteHalf}; use crate::tunnel::client::WsClient; use crate::tunnel::transport::headers_from_file; use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX}; use crate::tunnel::RemoteAddr; use anyhow::{anyhow, Context}; use bytes::{Bytes, BytesMut}; -use fastwebsockets::{CloseCode, Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; +use fastwebsockets::{CloseCode, Frame, OpCode, Payload, Role, WebSocket, WebSocketRead, WebSocketWrite}; use http_body_util::Empty; use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE}; use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; @@ -21,14 +22,16 @@ use std::ops::DerefMut; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; -use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::Notify; +use tokio_rustls::server::TlsStream; use tracing::trace; use uuid::Uuid; pub struct WebsocketTunnelWrite { - inner: WebSocketWrite>>, + inner: WebSocketWrite, buf: BytesMut, pending_operations: Receiver>, pending_ops_notify: Arc, @@ -37,7 +40,7 @@ pub struct WebsocketTunnelWrite { impl WebsocketTunnelWrite { pub fn new( - ws: WebSocketWrite>>, + ws: WebSocketWrite, (pending_operations, notify): (Receiver>, Arc), ) -> Self { Self { @@ -146,13 +149,13 @@ impl TunnelWrite for WebsocketTunnelWrite { } pub struct WebsocketTunnelRead { - inner: WebSocketRead>>, + inner: WebSocketRead, pending_operations: Sender>, notify_pending_ops: Arc, } impl WebsocketTunnelRead { - pub fn new(ws: WebSocketRead>>) -> (Self, (Receiver>, Arc)) { + pub fn new(ws: WebSocketRead) -> (Self, (Receiver>, Arc)) { let (tx, rx) = tokio::sync::mpsc::channel(10); let notify = Arc::new(Notify::new()); ( @@ -278,16 +281,52 @@ pub async fn connect( })?; debug!("with HTTP upgrade request {:?}", req); let transport = pooled_cnx.deref_mut().take().unwrap(); - let (mut ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport) + let (ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport) .await .with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?; - ws.set_auto_apply_mask(client_cfg.websocket_mask_frame); - ws.set_auto_close(false); - ws.set_auto_pong(false); + let (ws_rx, ws_tx) = mk_websocket_tunnel(ws, Role::Client, client_cfg.websocket_mask_frame)?; + Ok((ws_rx, ws_tx, response.into_parts().0)) +} - let (ws_rx, ws_tx) = ws.split(tokio::io::split); +pub fn mk_websocket_tunnel( + ws: WebSocket>, + role: Role, + mask_frame: bool, +) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite)> { + let mut ws = match role { + Role::Client => { + let stream = ws + .into_inner() + .into_inner() + .downcast::>() + .map_err(|_| anyhow!("cannot downcast websocket client stream"))?; + let transport = TransportStream::from(stream.io.into_inner(), stream.read_buf); + WebSocket::after_handshake(transport, role) + } + Role::Server => { + let upgraded = ws.into_inner().into_inner(); + match upgraded.downcast::>>() { + Ok(stream) => { + let transport = TransportStream::from_server_tls(stream.io.into_inner(), stream.read_buf); + WebSocket::after_handshake(transport, role) + } + Err(upgraded) => { + let stream = upgraded + .downcast::>() + .map_err(|_| anyhow!("cannot downcast websocket server stream"))?; + let transport = TransportStream::from_tcp(stream.io.into_inner(), stream.read_buf); + WebSocket::after_handshake(transport, role) + } + } + } + }; + + ws.set_auto_pong(false); + ws.set_auto_close(false); + ws.set_auto_apply_mask(mask_frame); + let (ws_rx, ws_tx) = ws.split(|x| x.into_split()); let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx); - Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops), response.into_parts().0)) + Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops))) }