Skip to content

Commit

Permalink
use more generic udp copy code
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-jq-b committed Nov 6, 2023
1 parent a3a7d6f commit c22c6c0
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 108 deletions.
14 changes: 8 additions & 6 deletions g3proxy/src/serve/socks_proxy/task/udp_connect/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,19 +316,21 @@ impl SocksProxyUdpConnectTask {
async fn run_relay<'a, R>(
&'a mut self,
mut clt_tcp_r: R,
clt_r: Box<dyn UdpCopyClientRecv + Unpin + Send>,
clt_w: Box<dyn UdpCopyClientSend + Unpin + Send>,
ups_r: Box<dyn UdpCopyRemoteRecv + Unpin + Send>,
ups_w: Box<dyn UdpCopyRemoteSend + Unpin + Send>,
mut clt_r: Box<dyn UdpCopyClientRecv + Unpin + Send>,
mut clt_w: Box<dyn UdpCopyClientSend + Unpin + Send>,
mut ups_r: Box<dyn UdpCopyRemoteRecv + Unpin + Send>,
mut ups_w: Box<dyn UdpCopyRemoteSend + Unpin + Send>,
escape_logger: &'a Logger,
) -> ServerTaskResult<()>
where
R: AsyncRead + Unpin,
{
let task_id = &self.task_notes.id;

let mut c_to_r = UdpCopyClientToRemote::new(clt_r, ups_w, self.ctx.server_config.udp_relay);
let mut r_to_c = UdpCopyRemoteToClient::new(clt_w, ups_r, self.ctx.server_config.udp_relay);
let mut c_to_r =
UdpCopyClientToRemote::new(&mut *clt_r, &mut *ups_w, self.ctx.server_config.udp_relay);
let mut r_to_c =
UdpCopyRemoteToClient::new(&mut *clt_w, &mut *ups_r, self.ctx.server_config.udp_relay);

let idle_duration = self.ctx.server_config.task_idle_check_duration;
let mut idle_interval =
Expand Down
263 changes: 168 additions & 95 deletions lib/g3-io-ext/src/udp/copy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,90 @@ pub enum UdpCopyError {
RemoteError(#[from] UdpCopyRemoteError),
}

pub struct UdpCopyClientToRemote<C: ?Sized, R: ?Sized> {
client: Box<C>,
remote: Box<R>,
trait UdpCopyRecv {
fn poll_recv_packet(
&mut self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<(usize, usize), UdpCopyError>>;
}

struct ClientRecv<'a, T: UdpCopyClientRecv + ?Sized>(&'a mut T);

impl<'a, T: UdpCopyClientRecv + ?Sized> UdpCopyRecv for ClientRecv<'a, T> {
fn poll_recv_packet(
&mut self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<(usize, usize), UdpCopyError>> {
self.0
.poll_recv_packet(cx, buf)
.map_err(UdpCopyError::ClientError)
}
}

struct RemoteRecv<'a, T: UdpCopyRemoteRecv + ?Sized>(&'a mut T);

impl<'a, T: UdpCopyRemoteRecv + ?Sized> UdpCopyRecv for RemoteRecv<'a, T> {
fn poll_recv_packet(
&mut self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<(usize, usize), UdpCopyError>> {
self.0
.poll_recv_packet(cx, buf)
.map_err(UdpCopyError::RemoteError)
}
}

trait UdpCopySend {
fn poll_send_packet(
&mut self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, UdpCopyError>>;
}

struct ClientSend<'a, T: UdpCopyClientSend + ?Sized>(&'a mut T);

impl<'a, T: UdpCopyClientSend + ?Sized> UdpCopySend for ClientSend<'a, T> {
fn poll_send_packet(
&mut self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, UdpCopyError>> {
self.0
.poll_send_packet(cx, buf)
.map_err(UdpCopyError::ClientError)
}
}

struct RemoteSend<'a, T: UdpCopyRemoteSend + ?Sized>(&'a mut T);

impl<'a, T: UdpCopyRemoteSend + ?Sized> UdpCopySend for RemoteSend<'a, T> {
fn poll_send_packet(
&mut self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, UdpCopyError>> {
self.0
.poll_send_packet(cx, buf)
.map_err(UdpCopyError::RemoteError)
}
}

struct UdpCopyBuffer {
config: LimitedUdpRelayConfig,
packet: UdpCopyPacket,
total: u64,
active: bool,
to_send: bool,
}

impl<C, R> UdpCopyClientToRemote<C, R>
where
C: UdpCopyClientRecv + ?Sized,
R: UdpCopyRemoteSend + ?Sized,
{
pub fn new(client: Box<C>, remote: Box<R>, config: LimitedUdpRelayConfig) -> Self {
let packet = UdpCopyPacket::new(client.max_hdr_len(), config.packet_size);
UdpCopyClientToRemote {
client,
remote,
impl UdpCopyBuffer {
fn new(max_hdr_size: usize, config: LimitedUdpRelayConfig) -> Self {
let packet = UdpCopyPacket::new(max_hdr_size, config.packet_size);
UdpCopyBuffer {
config,
packet,
total: 0,
Expand All @@ -81,47 +145,38 @@ where
}
}

pub fn is_idle(&self) -> bool {
!self.active
}

pub fn reset_active(&mut self) {
self.active = false;
}
}

impl<C, R> Future for UdpCopyClientToRemote<C, R>
where
C: UdpCopyClientRecv + Unpin + ?Sized,
R: UdpCopyRemoteSend + Unpin + ?Sized,
{
type Output = Result<u64, UdpCopyError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll_copy<R, S>(
&mut self,
cx: &mut Context<'_>,
mut receiver: R,
mut sender: S,
) -> Poll<Result<u64, UdpCopyError>>
where
R: UdpCopyRecv,
S: UdpCopySend,
{
let mut copy_this_round = 0usize;
loop {
let me = &mut *self;
if !me.to_send {
let (off, nr) =
ready!(Pin::new(&mut *me.client).poll_recv_packet(cx, &mut me.packet.buf))?;
if !self.to_send {
let (off, nr) = ready!(receiver.poll_recv_packet(cx, &mut self.packet.buf))?;
if nr == 0 {
break;
}
me.packet.buf_data_off = off;
me.packet.buf_data_end = nr;
me.to_send = true;
me.active = true;
self.packet.buf_data_off = off;
self.packet.buf_data_end = nr;
self.to_send = true;
self.active = true;
}

if me.to_send {
let nw = ready!(Pin::new(&mut *me.remote).poll_send_packet(
if self.to_send {
let nw = ready!(sender.poll_send_packet(
cx,
&me.packet.buf[me.packet.buf_data_off..me.packet.buf_data_end],
&self.packet.buf[self.packet.buf_data_off..self.packet.buf_data_end],
))?;
copy_this_round += nw;
me.total += nw as u64;
me.to_send = false;
me.active = true;
self.total += nw as u64;
self.to_send = false;
self.active = true;
}

if copy_this_round >= self.config.yield_size {
Expand All @@ -131,84 +186,102 @@ where
}
Poll::Ready(Ok(self.total))
}

fn is_idle(&self) -> bool {
!self.active
}

fn reset_active(&mut self) {
self.active = false;
}
}

pub struct UdpCopyRemoteToClient<C: ?Sized, R: ?Sized> {
client: Box<C>,
remote: Box<R>,
config: LimitedUdpRelayConfig,
packet: UdpCopyPacket,
total: u64,
active: bool,
to_send: bool,
pub struct UdpCopyClientToRemote<'a, C: ?Sized, R: ?Sized> {
client: &'a mut C,
remote: &'a mut R,
buffer: UdpCopyBuffer,
}

impl<C, R> UdpCopyRemoteToClient<C, R>
impl<'a, C, R> UdpCopyClientToRemote<'a, C, R>
where
C: UdpCopyClientRecv + ?Sized,
R: UdpCopyRemoteSend + ?Sized,
{
pub fn new(client: &'a mut C, remote: &'a mut R, config: LimitedUdpRelayConfig) -> Self {
let buffer = UdpCopyBuffer::new(client.max_hdr_len(), config);
UdpCopyClientToRemote {
client,
remote,
buffer,
}
}

#[inline]
pub fn is_idle(&self) -> bool {
self.buffer.is_idle()
}

#[inline]
pub fn reset_active(&mut self) {
self.buffer.reset_active()
}
}

impl<'a, C, R> Future for UdpCopyClientToRemote<'a, C, R>
where
C: UdpCopyClientRecv + Unpin + ?Sized,
R: UdpCopyRemoteSend + Unpin + ?Sized,
{
type Output = Result<u64, UdpCopyError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let me = &mut *self;
me.buffer
.poll_copy(cx, ClientRecv(me.client), RemoteSend(me.remote))
}
}

pub struct UdpCopyRemoteToClient<'a, C: ?Sized, R: ?Sized> {
client: &'a mut C,
remote: &'a mut R,
buffer: UdpCopyBuffer,
}

impl<'a, C, R> UdpCopyRemoteToClient<'a, C, R>
where
C: UdpCopyClientSend + ?Sized,
R: UdpCopyRemoteRecv + ?Sized,
{
pub fn new(client: Box<C>, remote: Box<R>, config: LimitedUdpRelayConfig) -> Self {
let packet = UdpCopyPacket::new(remote.max_hdr_len(), config.packet_size);
pub fn new(client: &'a mut C, remote: &'a mut R, config: LimitedUdpRelayConfig) -> Self {
let buffer = UdpCopyBuffer::new(remote.max_hdr_len(), config);
UdpCopyRemoteToClient {
client,
remote,
config,
packet,
total: 0,
active: false,
to_send: false,
buffer,
}
}

#[inline]
pub fn is_idle(&self) -> bool {
!self.active
self.buffer.is_idle()
}

#[inline]
pub fn reset_active(&mut self) {
self.active = false;
self.buffer.reset_active()
}
}

impl<C, R> Future for UdpCopyRemoteToClient<C, R>
impl<'a, C, R> Future for UdpCopyRemoteToClient<'a, C, R>
where
C: UdpCopyClientSend + Unpin + ?Sized,
R: UdpCopyRemoteRecv + Unpin + ?Sized,
{
type Output = Result<u64, UdpCopyError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut copy_this_round = 0usize;
loop {
let me = &mut *self;
if !me.to_send {
let (off, nr) =
ready!(Pin::new(&mut *me.remote).poll_recv_packet(cx, &mut me.packet.buf))?;
if nr == 0 {
break;
}
me.packet.buf_data_off = off;
me.packet.buf_data_end = nr;
me.to_send = true;
me.active = true;
}

if me.to_send {
let nw = ready!(Pin::new(&mut *me.client).poll_send_packet(
cx,
&me.packet.buf[me.packet.buf_data_off..me.packet.buf_data_end]
))?;
copy_this_round += nw;
me.total += nw as u64;
me.to_send = false;
me.active = true;
}

if copy_this_round > self.config.yield_size {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
Poll::Ready(Ok(self.total))
let me = &mut *self;
me.buffer
.poll_copy(cx, RemoteRecv(&mut *me.remote), ClientSend(&mut *me.client))
}
}
12 changes: 5 additions & 7 deletions lib/g3-io-ext/src/udp/relay/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ where
loop {
let me = &mut *self;
if !me.to_send {
let (off, nr, to) =
ready!(Pin::new(&mut *me.client).poll_recv_packet(cx, &mut me.packet.buf))?;
let (off, nr, to) = ready!(me.client.poll_recv_packet(cx, &mut me.packet.buf))?;
if nr == 0 {
break;
}
Expand All @@ -119,7 +118,7 @@ where
}

if me.to_send {
let nw = ready!(Pin::new(&mut *me.remote).poll_send_packet(
let nw = ready!(me.remote.poll_send_packet(
cx,
&me.packet.buf[me.packet.buf_data_off..me.packet.buf_data_end],
&me.packet.to
Expand Down Expand Up @@ -189,9 +188,8 @@ where
loop {
let me = &mut *self;
if !me.to_send {
let (off, nr, to) =
ready!(Pin::new(&mut *me.remote).poll_recv_packet(cx, &mut me.packet.buf))
.map_err(|e| UdpRelayError::RemoteError(None, e))?;
let (off, nr, to) = ready!(me.remote.poll_recv_packet(cx, &mut me.packet.buf))
.map_err(|e| UdpRelayError::RemoteError(None, e))?;
if nr == 0 {
break;
}
Expand All @@ -203,7 +201,7 @@ where
}

if me.to_send {
let nw = ready!(Pin::new(&mut *me.client).poll_send_packet(
let nw = ready!(me.client.poll_send_packet(
cx,
&me.packet.buf[me.packet.buf_data_off..me.packet.buf_data_end],
&me.packet.to
Expand Down

0 comments on commit c22c6c0

Please sign in to comment.