g3proxy: allow to chain TcpStream after ports server
zh-jq-b committed Nov 3, 2023
1 parent 7c6a5d5 commit 2290ea8
Showing 3 changed files with 124 additions and 33 deletions.
1 change: 1 addition & 0 deletions g3proxy/CHANGELOG
@@ -1,6 +1,7 @@

v1.7.29:
- Feature: add OpenSSL based NativeTlsPort server
- Feature: allow to chain TcpStream after ports server

v1.7.28:
- BUG FIX: fix server ingress net acl check for traffic from extra tls ports
66 changes: 53 additions & 13 deletions g3proxy/src/serve/tcp_stream/server.rs
@@ -20,6 +20,7 @@ use std::sync::Arc;
use anyhow::{anyhow, Context};
use async_trait::async_trait;
use slog::Logger;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::broadcast;
use tokio_openssl::SslStream;
@@ -30,7 +31,7 @@ use g3_daemon::server::ClientConnectionInfo;
use g3_types::acl::{AclAction, AclNetworkRule};
use g3_types::collection::{SelectivePickPolicy, SelectiveVec, SelectiveVecBuilder};
use g3_types::metrics::MetricsName;
use g3_types::net::{OpensslClientConfig, WeightedUpstreamAddr};
use g3_types::net::{OpensslClientConfig, UpstreamAddr, WeightedUpstreamAddr};

use super::common::CommonTaskContext;
use super::stats::TcpStreamServerStats;
@@ -155,12 +156,11 @@ impl TcpStreamServer {
false
}

async fn run_task(
fn get_ctx_and_upstream(
&self,
stream: TcpStream,
cc_info: ClientConnectionInfo,
run_ctx: ServerRunContext,
) {
) -> (CommonTaskContext, &UpstreamAddr) {
let client_ip = cc_info.client_ip();
let ctx = CommonTaskContext {
server_config: Arc::clone(&self.config),
@@ -193,8 +193,34 @@ impl TcpStreamServer {
}
};

TcpStreamTask::new(ctx, upstream.inner())
.into_running(stream)
(ctx, upstream.inner())
}

async fn run_task_with_tcp(
&self,
stream: TcpStream,
cc_info: ClientConnectionInfo,
run_ctx: ServerRunContext,
) {
let (ctx, upstream) = self.get_ctx_and_upstream(cc_info, run_ctx);

TcpStreamTask::new(ctx, upstream)
.tcp_into_running(stream)
.await;
}

async fn run_task_with_stream<T>(
&self,
stream: T,
cc_info: ClientConnectionInfo,
run_ctx: ServerRunContext,
) where
T: AsyncRead + AsyncWrite + Send + Sync + 'static,
{
let (ctx, upstream) = self.get_ctx_and_upstream(cc_info, run_ctx);

TcpStreamTask::new(ctx, upstream)
.stream_into_running(stream)
.await;
}
}
@@ -311,22 +337,36 @@ impl Server for TcpStreamServer {
return;
}

self.run_task(stream, cc_info, ctx).await
self.run_task_with_tcp(stream, cc_info, ctx).await
}

async fn run_rustls_task(
&self,
_stream: TlsStream<TcpStream>,
_cc_info: ClientConnectionInfo,
_ctx: ServerRunContext,
stream: TlsStream<TcpStream>,
cc_info: ClientConnectionInfo,
ctx: ServerRunContext,
) {
let client_addr = cc_info.client_addr();
self.server_stats.add_conn(client_addr);
if self.drop_early(client_addr) {
return;
}

self.run_task_with_stream(stream, cc_info, ctx).await
}

async fn run_openssl_task(
&self,
_stream: SslStream<TcpStream>,
_cc_info: ClientConnectionInfo,
_ctx: ServerRunContext,
stream: SslStream<TcpStream>,
cc_info: ClientConnectionInfo,
ctx: ServerRunContext,
) {
let client_addr = cc_info.client_addr();
self.server_stats.add_conn(client_addr);
if self.drop_early(client_addr) {
return;
}

self.run_task_with_stream(stream, cc_info, ctx).await
}
}
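
With these changes, `run_rustls_task` and `run_openssl_task` stop being stubs: context construction and upstream selection move into `get_ctx_and_upstream`, both TLS entry points do the same connection accounting (`add_conn` / `drop_early`) as plain TCP, and every port type funnels into a task started through a generic `run_task_with_stream` bounded by `AsyncRead + AsyncWrite`. A minimal standalone sketch of that generic-bound dispatch follows; the echo body, the extra `Unpin` bound, and all names in `main` are assumptions for illustration, not g3proxy's actual API.

```rust
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

// Generic entry point: TcpStream, tokio-rustls TlsStream and tokio-openssl
// SslStream all satisfy these bounds, so one function can serve every port type.
async fn run_task_with_stream<T>(mut stream: T) -> std::io::Result<()>
where
    T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
{
    // Placeholder body: echo one chunk back instead of relaying to an upstream.
    let mut buf = [0u8; 512];
    let n = stream.read(&mut buf).await?;
    stream.write_all(&buf[..n]).await?;
    Ok(())
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // An in-memory duplex pipe stands in for a TLS-wrapped client connection:
    // it is not a TcpStream, yet it satisfies the same generic bounds.
    let (client, server) = tokio::io::duplex(512);
    let task = tokio::spawn(run_task_with_stream(server));

    let (mut clt_r, mut clt_w) = tokio::io::split(client);
    clt_w.write_all(b"hello").await?;
    let mut echoed = [0u8; 5];
    clt_r.read_exact(&mut echoed).await?;
    assert_eq!(&echoed, b"hello");
    task.await.unwrap()
}
```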
90 changes: 70 additions & 20 deletions g3proxy/src/serve/tcp_stream/task.rs
@@ -64,9 +64,26 @@ impl TcpStreamTask {
}
}

pub(super) async fn into_running(mut self, stream: TcpStream) {
pub(super) async fn tcp_into_running(self, stream: TcpStream) {
let (clt_r, clt_w) = self.split_tcp_clt(stream);
self.into_running(clt_r, clt_w).await;
}

pub(super) async fn stream_into_running<T>(self, stream: T)
where
T: AsyncRead + AsyncWrite + Send + Sync + 'static,
{
let (clt_r, clt_w) = self.split_stream_clt(stream);
self.into_running(clt_r, clt_w).await;
}

async fn into_running<CR, CW>(mut self, clt_r: CR, clt_w: CW)
where
CR: AsyncRead + Send + Sync + Unpin + 'static,
CW: AsyncWrite + Send + Sync + Unpin + 'static,
{
self.pre_start();
match self.run(stream).await {
match self.run(clt_r, clt_w).await {
Ok(_) => self
.get_log_context()
.log(&self.ctx.task_logger, &ServerTaskError::Finished),
@@ -91,7 +108,11 @@ impl TcpStreamTask {
self.ctx.server_stats.dec_alive_task();
}

async fn run(&mut self, clt_stream: TcpStream) -> ServerTaskResult<()> {
async fn run<CR, CW>(&mut self, clt_r: CR, clt_w: CW) -> ServerTaskResult<()>
where
CR: AsyncRead + Send + Sync + Unpin + 'static,
CW: AsyncWrite + Send + Sync + Unpin + 'static,
{
// set client side socket options
self.ctx
.cc_info
@@ -138,35 +159,39 @@ impl TcpStreamTask {
};

self.task_notes.stage = ServerTaskStage::Connected;
self.run_connected(clt_stream, ups_r, ups_w).await
self.run_connected(clt_r, clt_w, ups_r, ups_w).await
}

async fn run_connected<R, W>(
async fn run_connected<CR, CW, UR, UW>(
&mut self,
clt_stream: TcpStream,
ups_r: R,
ups_w: W,
clt_r: CR,
clt_w: CW,
ups_r: UR,
ups_w: UW,
) -> ServerTaskResult<()>
where
R: AsyncRead + Send + Sync + Unpin + 'static,
W: AsyncWrite + Send + Sync + Unpin + 'static,
CR: AsyncRead + Send + Sync + Unpin + 'static,
CW: AsyncWrite + Send + Sync + Unpin + 'static,
UR: AsyncRead + Send + Sync + Unpin + 'static,
UW: AsyncWrite + Send + Sync + Unpin + 'static,
{
self.task_notes.mark_relaying();
self.relay(clt_stream, ups_r, ups_w).await
self.relay(clt_r, clt_w, ups_r, ups_w).await
}

async fn relay<R, W>(
async fn relay<CR, CW, UR, UW>(
&mut self,
clt_stream: TcpStream,
ups_r: R,
ups_w: W,
clt_r: CR,
clt_w: CW,
ups_r: UR,
ups_w: UW,
) -> ServerTaskResult<()>
where
R: AsyncRead + Send + Sync + Unpin + 'static,
W: AsyncWrite + Send + Sync + Unpin + 'static,
CR: AsyncRead + Send + Sync + Unpin + 'static,
CW: AsyncWrite + Send + Sync + Unpin + 'static,
UR: AsyncRead + Send + Sync + Unpin + 'static,
UW: AsyncWrite + Send + Sync + Unpin + 'static,
{
let (clt_r, clt_w) = self.split_clt(clt_stream);

if let Some(audit_handle) = self.ctx.audit_handle.take() {
let ctx = StreamInspectContext::new(
audit_handle,
@@ -199,15 +224,40 @@ impl TcpStreamTask {
}
}

fn split_clt(
fn split_tcp_clt(
&self,
clt_stream: TcpStream,
) -> (
LimitedReader<impl AsyncRead>,
LimitedWriter<impl AsyncWrite>,
) {
let (clt_r, clt_w) = clt_stream.into_split();
self.setup_limit_and_stats(clt_r, clt_w)
}

fn split_stream_clt<T>(
&self,
clt_stream: T,
) -> (
LimitedReader<impl AsyncRead>,
LimitedWriter<impl AsyncWrite>,
)
where
T: AsyncRead + AsyncWrite,
{
let (clt_r, clt_w) = tokio::io::split(clt_stream);
self.setup_limit_and_stats(clt_r, clt_w)
}

fn setup_limit_and_stats<CR, CW>(
&self,
clt_r: CR,
clt_w: CW,
) -> (LimitedReader<CR>, LimitedWriter<CW>)
where
CR: AsyncRead,
CW: AsyncWrite,
{
let (clt_r_stats, clt_w_stats) =
TcpStreamTaskCltWrapperStats::new_pair(&self.ctx.server_stats, &self.task_stats);
let clt_speed_limit = &self.ctx.server_config.tcp_sock_speed_limit;
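The one part of the task that cannot be fully generic is splitting the client stream: a plain `TcpStream` has `into_split()`, which yields independently owned halves, while an arbitrary `AsyncRead + AsyncWrite` value (such as a TLS stream) goes through `tokio::io::split()`, whose halves share the stream behind an internal lock. The new `split_tcp_clt` / `split_stream_clt` pair handles those two shapes and feeds both into the shared `setup_limit_and_stats` helper. Below is a small runnable comparison of the two splitting strategies using plain tokio; nothing in it is g3proxy-specific, and the in-memory duplex pipe merely stands in for a TLS stream.

```rust
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // TCP path: into_split() yields owned read/write halves with no locking.
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;
    let client = tokio::spawn(async move {
        let mut s = TcpStream::connect(addr).await.unwrap();
        s.write_all(b"ping").await.unwrap();
    });
    let (accepted, _peer) = listener.accept().await?;
    let (mut tcp_r, _tcp_w) = accepted.into_split();
    let mut buf = [0u8; 4];
    tcp_r.read_exact(&mut buf).await?;
    assert_eq!(&buf, b"ping");
    client.await.unwrap();

    // Generic path: tokio::io::split() works on any AsyncRead + AsyncWrite
    // value; here an in-memory duplex pipe plays the role of a TLS stream.
    let (a, b) = tokio::io::duplex(64);
    let (_a_r, mut a_w) = tokio::io::split(a);
    let (mut b_r, _b_w) = tokio::io::split(b);
    a_w.write_all(b"pong").await?;
    b_r.read_exact(&mut buf).await?;
    assert_eq!(&buf, b"pong");
    Ok(())
}
```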
