Skip to content

Commit

Permalink
use more generic udp relay code
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-jq-b committed Nov 6, 2023
1 parent c22c6c0 commit 9270aa9
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 108 deletions.
12 changes: 6 additions & 6 deletions g3proxy/src/serve/socks_proxy/task/udp_associate/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,10 @@ impl SocksProxyUdpAssociateTask {
async fn run_relay<'a, R>(
&'a mut self,
mut clt_tcp_r: R,
clt_r: Box<dyn UdpRelayClientRecv + Unpin + Send>,
clt_w: Box<dyn UdpRelayClientSend + Unpin + Send>,
ups_r: Box<dyn UdpRelayRemoteRecv + Unpin + Send>,
ups_w: Box<dyn UdpRelayRemoteSend + Unpin + Send>,
mut clt_r: Box<dyn UdpRelayClientRecv + Unpin + Send>,
mut clt_w: Box<dyn UdpRelayClientSend + Unpin + Send>,
mut ups_r: Box<dyn UdpRelayRemoteRecv + Unpin + Send>,
mut ups_w: Box<dyn UdpRelayRemoteSend + Unpin + Send>,
escape_logger: &'a Logger,
) -> ServerTaskResult<()>
where
Expand All @@ -278,9 +278,9 @@ impl SocksProxyUdpAssociateTask {
let task_id = &self.task_notes.id;

let mut c_to_r =
UdpRelayClientToRemote::new(clt_r, ups_w, self.ctx.server_config.udp_relay);
UdpRelayClientToRemote::new(&mut *clt_r, &mut *ups_w, self.ctx.server_config.udp_relay);
let mut r_to_c =
UdpRelayRemoteToClient::new(clt_w, ups_r, self.ctx.server_config.udp_relay);
UdpRelayRemoteToClient::new(&mut *clt_w, &mut *ups_r, self.ctx.server_config.udp_relay);

let idle_duration = self.ctx.server_config.task_idle_check_duration;
let mut idle_interval =
Expand Down
278 changes: 176 additions & 102 deletions lib/g3-io-ext/src/udp/relay/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct UdpRelayPacket {
buf: Box<[u8]>,
buf_data_off: usize,
buf_data_end: usize,
to: UpstreamAddr,
ups: UpstreamAddr,
}

impl UdpRelayPacket {
Expand All @@ -44,7 +44,7 @@ impl UdpRelayPacket {
buf: vec![0; buf_size].into_boxed_slice(),
buf_data_off: 0,
buf_data_end: 0,
to: UpstreamAddr::empty(),
ups: UpstreamAddr::empty(),
}
}
}
Expand All @@ -57,26 +57,93 @@ pub enum UdpRelayError {
RemoteError(Option<UpstreamAddr>, UdpRelayRemoteError),
}

pub struct UdpRelayClientToRemote<C: ?Sized, R: ?Sized> {
client: Box<C>,
remote: Box<R>,
trait UdpRelayRecv {
fn poll_recv_packet(
&mut self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<(usize, usize, UpstreamAddr), UdpRelayError>>;
}

struct ClientRecv<'a, T: UdpRelayClientRecv + ?Sized>(&'a mut T);

impl<'a, T: UdpRelayClientRecv + ?Sized> UdpRelayRecv for ClientRecv<'a, T> {
fn poll_recv_packet(
&mut self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<(usize, usize, UpstreamAddr), UdpRelayError>> {
self.0
.poll_recv_packet(cx, buf)
.map_err(UdpRelayError::ClientError)
}
}

struct RemoteRecv<'a, T: UdpRelayRemoteRecv + ?Sized>(&'a mut T);

impl<'a, T: UdpRelayRemoteRecv + ?Sized> UdpRelayRecv for RemoteRecv<'a, T> {
fn poll_recv_packet(
&mut self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<(usize, usize, UpstreamAddr), UdpRelayError>> {
self.0
.poll_recv_packet(cx, buf)
.map_err(|e| UdpRelayError::RemoteError(None, e))
}
}

trait UdpRelaySend {
fn poll_send_packet(
&mut self,
cx: &mut Context<'_>,
buf: &[u8],
ups: &UpstreamAddr,
) -> Poll<Result<usize, UdpRelayError>>;
}

struct ClientSend<'a, T: UdpRelayClientSend + ?Sized>(&'a mut T);

impl<'a, T: UdpRelayClientSend + ?Sized> UdpRelaySend for ClientSend<'a, T> {
fn poll_send_packet(
&mut self,
cx: &mut Context<'_>,
buf: &[u8],
from: &UpstreamAddr,
) -> Poll<Result<usize, UdpRelayError>> {
self.0
.poll_send_packet(cx, buf, from)
.map_err(UdpRelayError::ClientError)
}
}

struct RemoteSend<'a, T: UdpRelayRemoteSend + ?Sized>(&'a mut T);

impl<'a, T: UdpRelayRemoteSend + ?Sized> UdpRelaySend for RemoteSend<'a, T> {
fn poll_send_packet(
&mut self,
cx: &mut Context<'_>,
buf: &[u8],
to: &UpstreamAddr,
) -> Poll<Result<usize, UdpRelayError>> {
self.0
.poll_send_packet(cx, buf, to)
.map_err(|e| UdpRelayError::RemoteError(Some(to.clone()), e))
}
}

struct UdpRelayBuffer {
config: LimitedUdpRelayConfig,
packet: UdpRelayPacket,
total: u64,
active: bool,
to_send: bool,
}

impl<C, R> UdpRelayClientToRemote<C, R>
where
C: UdpRelayClientRecv + ?Sized,
R: UdpRelayRemoteSend + ?Sized,
{
pub fn new(client: Box<C>, remote: Box<R>, config: LimitedUdpRelayConfig) -> Self {
let packet = UdpRelayPacket::new(client.max_hdr_len(), config.packet_size);
UdpRelayClientToRemote {
client,
remote,
impl UdpRelayBuffer {
fn new(max_hdr_size: usize, config: LimitedUdpRelayConfig) -> Self {
let packet = UdpRelayPacket::new(max_hdr_size, config.packet_size);
UdpRelayBuffer {
config,
packet,
total: 0,
Expand All @@ -85,49 +152,40 @@ where
}
}

pub fn is_idle(&self) -> bool {
!self.active
}

pub fn reset_active(&mut self) {
self.active = false;
}
}

impl<C, R> Future for UdpRelayClientToRemote<C, R>
where
C: UdpRelayClientRecv + Unpin + ?Sized,
R: UdpRelayRemoteSend + Unpin + ?Sized,
{
type Output = Result<u64, UdpRelayError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll_relay<R, S>(
&mut self,
cx: &mut Context<'_>,
mut receiver: R,
mut sender: S,
) -> Poll<Result<u64, UdpRelayError>>
where
R: UdpRelayRecv,
S: UdpRelaySend,
{
let mut relay_this_round = 0usize;
loop {
let me = &mut *self;
if !me.to_send {
let (off, nr, to) = ready!(me.client.poll_recv_packet(cx, &mut me.packet.buf))?;
if !self.to_send {
let (off, nr, ups) = ready!(receiver.poll_recv_packet(cx, &mut self.packet.buf))?;
if nr == 0 {
break;
}
me.packet.buf_data_off = off;
me.packet.buf_data_end = nr;
me.packet.to = to;
me.to_send = true;
me.active = true;
self.packet.buf_data_off = off;
self.packet.buf_data_end = nr;
self.packet.ups = ups;
self.to_send = true;
self.active = true;
}

if me.to_send {
let nw = ready!(me.remote.poll_send_packet(
if self.to_send {
let nw = ready!(sender.poll_send_packet(
cx,
&me.packet.buf[me.packet.buf_data_off..me.packet.buf_data_end],
&me.packet.to
))
.map_err(|e| UdpRelayError::RemoteError(Some(me.packet.to.clone()), e))?;
&self.packet.buf[self.packet.buf_data_off..self.packet.buf_data_end],
&self.packet.ups
))?;
relay_this_round += nw;
me.total += nw as u64;
me.to_send = false;
me.active = true;
self.total += nw as u64;
self.to_send = false;
self.active = true;
}

if relay_this_round >= self.config.yield_size {
Expand All @@ -137,86 +195,102 @@ where
}
Poll::Ready(Ok(self.total))
}

fn is_idle(&self) -> bool {
!self.active
}

fn reset_active(&mut self) {
self.active = false;
}
}

pub struct UdpRelayRemoteToClient<C: ?Sized, R: ?Sized> {
client: Box<C>,
remote: Box<R>,
config: LimitedUdpRelayConfig,
packet: UdpRelayPacket,
total: u64,
active: bool,
to_send: bool,
pub struct UdpRelayClientToRemote<'a, C: ?Sized, R: ?Sized> {
client: &'a mut C,
remote: &'a mut R,
buffer: UdpRelayBuffer,
}

impl<'a, C, R> UdpRelayClientToRemote<'a, C, R>
where
C: UdpRelayClientRecv + ?Sized,
R: UdpRelayRemoteSend + ?Sized,
{
pub fn new(client: &'a mut C, remote: &'a mut R, config: LimitedUdpRelayConfig) -> Self {
let buffer = UdpRelayBuffer::new(client.max_hdr_len(), config);
UdpRelayClientToRemote {
client,
remote,
buffer,
}
}

#[inline]
pub fn is_idle(&self) -> bool {
self.buffer.is_idle()
}

#[inline]
pub fn reset_active(&mut self) {
self.buffer.reset_active()
}
}

impl<'a, C, R> Future for UdpRelayClientToRemote<'a, C, R>
where
C: UdpRelayClientRecv + Unpin + ?Sized,
R: UdpRelayRemoteSend + Unpin + ?Sized,
{
type Output = Result<u64, UdpRelayError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let me = &mut *self;
me.buffer
.poll_relay(cx, ClientRecv(me.client), RemoteSend(me.remote))
}
}

pub struct UdpRelayRemoteToClient<'a, C: ?Sized, R: ?Sized> {
client: &'a mut C,
remote: &'a mut R,
buffer: UdpRelayBuffer,
}

impl<C, R> UdpRelayRemoteToClient<C, R>
impl<'a, C, R> UdpRelayRemoteToClient<'a, C, R>
where
C: UdpRelayClientSend + ?Sized,
R: UdpRelayRemoteRecv + ?Sized,
{
pub fn new(client: Box<C>, remote: Box<R>, config: LimitedUdpRelayConfig) -> Self {
let packet = UdpRelayPacket::new(remote.max_hdr_len(), config.packet_size);
pub fn new(client: &'a mut C, remote: &'a mut R, config: LimitedUdpRelayConfig) -> Self {
let buffer = UdpRelayBuffer::new(remote.max_hdr_len(), config);
UdpRelayRemoteToClient {
client,
remote,
config,
packet,
total: 0,
active: false,
to_send: false,
buffer,
}
}

#[inline]
pub fn is_idle(&self) -> bool {
!self.active
self.buffer.is_idle()
}

#[inline]
pub fn reset_active(&mut self) {
self.active = false;
self.buffer.reset_active()
}
}

impl<C, R> Future for UdpRelayRemoteToClient<C, R>
impl<'a, C, R> Future for UdpRelayRemoteToClient<'a, C, R>
where
C: UdpRelayClientSend + Unpin + ?Sized,
R: UdpRelayRemoteRecv + Unpin + ?Sized,
{
type Output = Result<u64, UdpRelayError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut relay_this_round = 0usize;
loop {
let me = &mut *self;
if !me.to_send {
let (off, nr, to) = ready!(me.remote.poll_recv_packet(cx, &mut me.packet.buf))
.map_err(|e| UdpRelayError::RemoteError(None, e))?;
if nr == 0 {
break;
}
me.packet.buf_data_off = off;
me.packet.buf_data_end = nr;
me.packet.to = to;
me.to_send = true;
me.active = true;
}

if me.to_send {
let nw = ready!(me.client.poll_send_packet(
cx,
&me.packet.buf[me.packet.buf_data_off..me.packet.buf_data_end],
&me.packet.to
))?;
relay_this_round += nw;
me.total += nw as u64;
me.to_send = false;
me.active = true;
}

if relay_this_round > self.config.yield_size {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
Poll::Ready(Ok(self.total))
let me = &mut *self;
me.buffer
.poll_relay(cx, RemoteRecv(me.remote), ClientSend(me.client))
}
}

0 comments on commit 9270aa9

Please sign in to comment.