diff --git a/g3proxy/CHANGELOG b/g3proxy/CHANGELOG index 7e3bc8893..645b52eb1 100644 --- a/g3proxy/CHANGELOG +++ b/g3proxy/CHANGELOG @@ -1,6 +1,7 @@ v1.7.29: - Feature: add OpenSSL based NativeTlsPort server + - Feature: allow to chain TcpStream after ports server v1.7.28: - BUG FIX: fix server ingress net acl check for traffic from extra tls ports diff --git a/g3proxy/src/serve/tcp_stream/server.rs b/g3proxy/src/serve/tcp_stream/server.rs index 4913f85f6..6fefcdab2 100644 --- a/g3proxy/src/serve/tcp_stream/server.rs +++ b/g3proxy/src/serve/tcp_stream/server.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use anyhow::{anyhow, Context}; use async_trait::async_trait; use slog::Logger; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::sync::broadcast; use tokio_openssl::SslStream; @@ -30,7 +31,7 @@ use g3_daemon::server::ClientConnectionInfo; use g3_types::acl::{AclAction, AclNetworkRule}; use g3_types::collection::{SelectivePickPolicy, SelectiveVec, SelectiveVecBuilder}; use g3_types::metrics::MetricsName; -use g3_types::net::{OpensslClientConfig, WeightedUpstreamAddr}; +use g3_types::net::{OpensslClientConfig, UpstreamAddr, WeightedUpstreamAddr}; use super::common::CommonTaskContext; use super::stats::TcpStreamServerStats; @@ -155,12 +156,11 @@ impl TcpStreamServer { false } - async fn run_task( + fn get_ctx_and_upstream( &self, - stream: TcpStream, cc_info: ClientConnectionInfo, run_ctx: ServerRunContext, - ) { + ) -> (CommonTaskContext, &UpstreamAddr) { let client_ip = cc_info.client_ip(); let ctx = CommonTaskContext { server_config: Arc::clone(&self.config), @@ -193,8 +193,34 @@ impl TcpStreamServer { } }; - TcpStreamTask::new(ctx, upstream.inner()) - .into_running(stream) + (ctx, upstream.inner()) + } + + async fn run_task_with_tcp( + &self, + stream: TcpStream, + cc_info: ClientConnectionInfo, + run_ctx: ServerRunContext, + ) { + let (ctx, upstream) = self.get_ctx_and_upstream(cc_info, run_ctx); + + TcpStreamTask::new(ctx, upstream) + .tcp_into_running(stream) + .await; + } + + async fn run_task_with_stream( + &self, + stream: T, + cc_info: ClientConnectionInfo, + run_ctx: ServerRunContext, + ) where + T: AsyncRead + AsyncWrite + Send + Sync + 'static, + { + let (ctx, upstream) = self.get_ctx_and_upstream(cc_info, run_ctx); + + TcpStreamTask::new(ctx, upstream) + .stream_into_running(stream) .await; } } @@ -311,22 +337,36 @@ impl Server for TcpStreamServer { return; } - self.run_task(stream, cc_info, ctx).await + self.run_task_with_tcp(stream, cc_info, ctx).await } async fn run_rustls_task( &self, - _stream: TlsStream, - _cc_info: ClientConnectionInfo, - _ctx: ServerRunContext, + stream: TlsStream, + cc_info: ClientConnectionInfo, + ctx: ServerRunContext, ) { + let client_addr = cc_info.client_addr(); + self.server_stats.add_conn(client_addr); + if self.drop_early(client_addr) { + return; + } + + self.run_task_with_stream(stream, cc_info, ctx).await } async fn run_openssl_task( &self, - _stream: SslStream, - _cc_info: ClientConnectionInfo, - _ctx: ServerRunContext, + stream: SslStream, + cc_info: ClientConnectionInfo, + ctx: ServerRunContext, ) { + let client_addr = cc_info.client_addr(); + self.server_stats.add_conn(client_addr); + if self.drop_early(client_addr) { + return; + } + + self.run_task_with_stream(stream, cc_info, ctx).await } } diff --git a/g3proxy/src/serve/tcp_stream/task.rs b/g3proxy/src/serve/tcp_stream/task.rs index c0995a8f6..361ee5448 100644 --- a/g3proxy/src/serve/tcp_stream/task.rs +++ b/g3proxy/src/serve/tcp_stream/task.rs @@ -64,9 +64,26 @@ impl TcpStreamTask { } } - pub(super) async fn into_running(mut self, stream: TcpStream) { + pub(super) async fn tcp_into_running(self, stream: TcpStream) { + let (clt_r, clt_w) = self.split_tcp_clt(stream); + self.into_running(clt_r, clt_w).await; + } + + pub(super) async fn stream_into_running(self, stream: T) + where + T: AsyncRead + AsyncWrite + Send + Sync + 'static, + { + let (clt_r, clt_w) = self.split_stream_clt(stream); + self.into_running(clt_r, clt_w).await; + } + + async fn into_running(mut self, clt_r: CR, clt_w: CW) + where + CR: AsyncRead + Send + Sync + Unpin + 'static, + CW: AsyncWrite + Send + Sync + Unpin + 'static, + { self.pre_start(); - match self.run(stream).await { + match self.run(clt_r, clt_w).await { Ok(_) => self .get_log_context() .log(&self.ctx.task_logger, &ServerTaskError::Finished), @@ -91,7 +108,11 @@ impl TcpStreamTask { self.ctx.server_stats.dec_alive_task(); } - async fn run(&mut self, clt_stream: TcpStream) -> ServerTaskResult<()> { + async fn run(&mut self, clt_r: CR, clt_w: CW) -> ServerTaskResult<()> + where + CR: AsyncRead + Send + Sync + Unpin + 'static, + CW: AsyncWrite + Send + Sync + Unpin + 'static, + { // set client side socket options self.ctx .cc_info @@ -138,35 +159,39 @@ impl TcpStreamTask { }; self.task_notes.stage = ServerTaskStage::Connected; - self.run_connected(clt_stream, ups_r, ups_w).await + self.run_connected(clt_r, clt_w, ups_r, ups_w).await } - async fn run_connected( + async fn run_connected( &mut self, - clt_stream: TcpStream, - ups_r: R, - ups_w: W, + clt_r: CR, + clt_w: CW, + ups_r: UR, + ups_w: UW, ) -> ServerTaskResult<()> where - R: AsyncRead + Send + Sync + Unpin + 'static, - W: AsyncWrite + Send + Sync + Unpin + 'static, + CR: AsyncRead + Send + Sync + Unpin + 'static, + CW: AsyncWrite + Send + Sync + Unpin + 'static, + UR: AsyncRead + Send + Sync + Unpin + 'static, + UW: AsyncWrite + Send + Sync + Unpin + 'static, { self.task_notes.mark_relaying(); - self.relay(clt_stream, ups_r, ups_w).await + self.relay(clt_r, clt_w, ups_r, ups_w).await } - async fn relay( + async fn relay( &mut self, - clt_stream: TcpStream, - ups_r: R, - ups_w: W, + clt_r: CR, + clt_w: CW, + ups_r: UR, + ups_w: UW, ) -> ServerTaskResult<()> where - R: AsyncRead + Send + Sync + Unpin + 'static, - W: AsyncWrite + Send + Sync + Unpin + 'static, + CR: AsyncRead + Send + Sync + Unpin + 'static, + CW: AsyncWrite + Send + Sync + Unpin + 'static, + UR: AsyncRead + Send + Sync + Unpin + 'static, + UW: AsyncWrite + Send + Sync + Unpin + 'static, { - let (clt_r, clt_w) = self.split_clt(clt_stream); - if let Some(audit_handle) = self.ctx.audit_handle.take() { let ctx = StreamInspectContext::new( audit_handle, @@ -199,7 +224,7 @@ impl TcpStreamTask { } } - fn split_clt( + fn split_tcp_clt( &self, clt_stream: TcpStream, ) -> ( @@ -207,7 +232,32 @@ impl TcpStreamTask { LimitedWriter, ) { let (clt_r, clt_w) = clt_stream.into_split(); + self.setup_limit_and_stats(clt_r, clt_w) + } + + fn split_stream_clt( + &self, + clt_stream: T, + ) -> ( + LimitedReader, + LimitedWriter, + ) + where + T: AsyncRead + AsyncWrite, + { + let (clt_r, clt_w) = tokio::io::split(clt_stream); + self.setup_limit_and_stats(clt_r, clt_w) + } + fn setup_limit_and_stats( + &self, + clt_r: CR, + clt_w: CW, + ) -> (LimitedReader, LimitedWriter) + where + CR: AsyncRead, + CW: AsyncWrite, + { let (clt_r_stats, clt_w_stats) = TcpStreamTaskCltWrapperStats::new_pair(&self.ctx.server_stats, &self.task_stats); let clt_speed_limit = &self.ctx.server_config.tcp_sock_speed_limit;