Skip to content

Commit

Permalink
feat(net4mqtt): add topic address src and dst
Browse files Browse the repository at this point in the history
  • Loading branch information
a-wing committed Oct 6, 2024
1 parent 6428478 commit 18322d7
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 106 deletions.
42 changes: 31 additions & 11 deletions libs/net4mqtt/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ enum Commands {
/// Listen local port mapping as agent's target address
#[arg(short, long, default_value_t = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 6666))]
listen: SocketAddr,
/// Agent's target address
#[arg(short, long)]
target: Option<String>,
/// agent id
#[arg(short, long, default_value_t = format!("-"))]
agent_id: String,
Expand All @@ -67,9 +70,9 @@ enum Commands {
/// Mqtt Broker Address (<scheme>://<host>:<port>/<prefix>?client_id=<client_id>)
#[arg(short, long, default_value_t = format!("mqtt://localhost:1883/net4mqtt"))]
mqtt_url: String,
/// Agent's target address
#[arg(short, long, default_value_t = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 7777))]
target: SocketAddr,
/// Default Agent's target address
#[arg(short, long, default_value_t = format!("{}", SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 7777)))]
target: String,
/// Set Current agent id
#[arg(short, long, default_value_t = format!("-"))]
id: String,
Expand Down Expand Up @@ -104,13 +107,22 @@ async fn main() {
debug!("use domain: {:?}", domain);

let listener = TcpListener::bind(listen).await.unwrap();
proxy::local_socks(&mqtt_url, listener, &agent_id, &id, None, None, kcp)
.await
.unwrap();
proxy::local_socks(
&mqtt_url,
listener,
(&agent_id, &id),
Some(domain),
None,
None,
kcp,
)
.await
.unwrap();
}
Commands::Local {
mqtt_url,
listen,
target,
agent_id,
id,
udp,
Expand All @@ -120,14 +132,22 @@ async fn main() {

if udp {
let sock = UdpSocket::bind(listen).await.unwrap();
proxy::local_ports_udp(&mqtt_url, sock, &agent_id, &id, None, None)
proxy::local_ports_udp(&mqtt_url, sock, target, (&agent_id, &id), None, None)
.await
.unwrap();
} else {
let listener = TcpListener::bind(listen).await.unwrap();
proxy::local_ports_tcp(&mqtt_url, listener, &agent_id, &id, None, None, kcp)
.await
.unwrap();
proxy::local_ports_tcp(
&mqtt_url,
listener,
target,
(&agent_id, &id),
None,
None,
kcp,
)
.await
.unwrap();
}
}
Commands::Agent {
Expand All @@ -137,7 +157,7 @@ async fn main() {
} => {
info!("Running as agent, {:?}", target);

proxy::agent(&mqtt_url, target, &id, None, None)
proxy::agent(&mqtt_url, &target, &id, None, None)
.await
.unwrap();
}
Expand Down
76 changes: 43 additions & 33 deletions libs/net4mqtt/src/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;

use anyhow::{anyhow, Error, Result};
use kcp::Kcp;
Expand Down Expand Up @@ -171,34 +172,33 @@ async fn up_udp_vnet(
}

async fn up_agent_vclient(
address: SocketAddr,
address: &str,
protocol: &str,
topic: String,
sender: UnboundedSender<(String, Vec<u8>)>,
receiver: UnboundedReceiver<(String, Vec<u8>)>,
) -> Result<(), Error> {
match protocol {
topic::protocol::KCP => {
let socket = TcpStream::connect(address).await.unwrap();
let socket = TcpStream::connect(address).await?;
up_kcp_vnet(socket, topic, sender, receiver).await
}
topic::protocol::TCP => {
let socket = TcpStream::connect(address).await.unwrap();
let socket = TcpStream::connect(address).await?;
up_tcp_vnet(socket, topic, sender, receiver).await
}
topic::protocol::UDP => {
let socket = UdpSocket::bind(SocketAddr::new(
// "0.0.0.0:0"
// "[::]:0"
match address {
match SocketAddr::from_str(address)? {
SocketAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
SocketAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
},
0,
))
.await
.unwrap();
socket.connect(address).await.unwrap();
.await?;
socket.connect(address).await?;
up_udp_vnet(socket, topic, sender, receiver).await
}
e => Err(anyhow!("unknown protocol {}", e)),
Expand Down Expand Up @@ -256,7 +256,7 @@ async fn mqtt_client_init(

pub async fn agent(
mqtt_url: &str,
address: SocketAddr,
address: &str,
agent_id: &str,
xdata: Option<(Vec<u8>, Option<Vec<u8>>)>,
on_xdata: Option<Sender<(String, String, Vec<u8>)>>,
Expand Down Expand Up @@ -288,8 +288,8 @@ pub async fn agent(
result = receiver.recv() => {
match result {
Some((key, data)) => {
let (prefix, agent_id, local_id, _label, protocol, address) = topic::parse(&key);
client.publish(topic::build(prefix, agent_id, local_id, topic::label::O, protocol, address),
let (prefix, agent_id, local_id, _label, protocol, src, dst) = topic::parse(&key);
client.publish(topic::build(prefix, agent_id, local_id, topic::label::O, protocol, src, dst),
QoS::AtMostOnce,
false,
data
Expand All @@ -303,7 +303,7 @@ pub async fn agent(
Ok(notification) => {
if let Some(p) = mqtt_receive(notification) {
let topic = p.topic.clone();
let (_prefix, agent_id, local_id, label, protocol, _address) = topic::parse(&topic);
let (_prefix, agent_id, local_id, label, protocol, _src, dst) = topic::parse(&topic);

match label {
topic::label::X => {
Expand All @@ -318,8 +318,9 @@ pub async fn agent(
let (vnet_tx, vnet_rx) = unbounded_channel::<(String, Vec<u8>)>();
let topic = p.topic.clone();
let protocol = protocol.to_string();
let dst = if dst == topic::NIL { address } else { dst }.to_string();
task::spawn(async move {
if let Err(e) = up_agent_vclient(address, &protocol, topic, sender, vnet_rx).await {
if let Err(e) = up_agent_vclient(&dst, &protocol, topic, sender, vnet_rx).await {
error!("agent vnet error: {:?}", e)
}
});
Expand Down Expand Up @@ -350,12 +351,14 @@ pub async fn agent(
pub async fn local_ports_tcp(
mqtt_url: &str,
listener: TcpListener,
agent_id: &str,
local_id: &str,
target: Option<String>,
id: (&str, &str),
xdata: Option<(Vec<u8>, Option<Vec<u8>>)>,
on_xdata: Option<Sender<(String, String, Vec<u8>)>>,
tcp_over_kcp: bool,
) -> Result<(), Error> {
let (agent_id, local_id) = id;
let target = target.unwrap_or(topic::NIL.to_string());
let mut senders =
LruCache::<String, UnboundedSender<(String, Vec<u8>)>>::with_expiry_duration_and_capacity(
LRU_TIME_TO_LIVE,
Expand Down Expand Up @@ -386,8 +389,8 @@ pub async fn local_ports_tcp(
let protocol = if tcp_over_kcp { topic::protocol::KCP } else { topic::protocol::TCP };

let addr = socket.peer_addr().unwrap().to_string();
let key_send = topic::build(prefix, agent_id, local_id, topic::label::I, protocol, &addr);
let key_recv = topic::build(prefix, agent_id, local_id, topic::label::O, protocol, &addr);
let key_send = topic::build(prefix, agent_id, local_id, topic::label::I, protocol, &addr, &target);
let key_recv = topic::build(prefix, agent_id, local_id, topic::label::O, protocol, &addr, &target);

senders.insert(key_recv, vnet_tx);
task::spawn(async move {
Expand Down Expand Up @@ -417,20 +420,20 @@ pub async fn local_ports_tcp(
Ok(notification) => {
if let Some(p) = mqtt_receive(notification) {
let topic = p.topic.clone();
let (_prefix, agent_id, local_id, label, protocol, _address) = topic::parse(&topic);
let (_prefix, agent_id, local_id, label, protocol, _src, _dst) = topic::parse(&topic);

match (label, protocol) {
(topic::label::X, _) => {
if let Some(s) = on_xdata {
s.send((agent_id.to_string(), local_id.to_string(), p.payload.to_vec())).await.unwrap();
s.send((agent_id.to_string(), local_id.to_string(), p.payload.to_vec())).await?;
}
},
(_, topic::protocol::KCP | topic::protocol::TCP) => {
if let Some(sender) = senders.get(&p.topic) {
if sender.is_closed() {
senders.remove(&p.topic);
} else {
sender.send((p.topic, p.payload.to_vec())).unwrap();
sender.send((p.topic, p.payload.to_vec()))?;
}
}
},
Expand All @@ -453,11 +456,13 @@ pub async fn local_ports_tcp(
pub async fn local_ports_udp(
mqtt_url: &str,
sock: UdpSocket,
agent_id: &str,
local_id: &str,
target: Option<String>,
id: (&str, &str),
xdata: Option<(Vec<u8>, Option<Vec<u8>>)>,
on_xdata: Option<Sender<(String, String, Vec<u8>)>>,
) -> Result<(), Error> {
let (agent_id, local_id) = id;
let target = target.unwrap_or(topic::NIL.to_string());
let (sender, mut receiver) = unbounded_channel::<(String, Vec<u8>)>();

let (url, prefix) = crate::utils::pre_url(mqtt_url.parse::<Url>()?);
Expand All @@ -480,7 +485,7 @@ pub async fn local_ports_udp(
select! {
Ok((len, addr)) = sock.recv_from(&mut buf) => {
sender.send((
topic::build(prefix, agent_id, local_id, topic::label::I, topic::protocol::UDP, &addr.to_string()),
topic::build(prefix, agent_id, local_id, topic::label::I, topic::protocol::UDP, &addr.to_string(), &target),
buf[..len].to_vec())).unwrap();
}
result = receiver.recv() => {
Expand All @@ -501,15 +506,15 @@ pub async fn local_ports_udp(
Ok(notification) => {
if let Some(p) = mqtt_receive(notification) {
let topic = p.topic.clone();
let (_prefix, _agent_id, _local_id, label, protocol, address) = topic::parse(&topic);
let (_prefix, _agent_id, _local_id, label, protocol, src, _dst) = topic::parse(&topic);

match (label, protocol) {
(topic::label::X, _) => {
if let Some(s) = on_xdata {
s.send((agent_id.to_string(), local_id.to_string(), p.payload.to_vec())).await.unwrap();
s.send((agent_id.to_string(), local_id.to_string(), p.payload.to_vec())).await?;
}
},
(_, topic::protocol::UDP) => { let _ = sock.send_to(&p.payload, address).await.unwrap(); },
(_, topic::protocol::UDP) => { let _ = sock.send_to(&p.payload, src).await?; },
(label, protocol) => info!("unknown label: {} and protocol: {}", label, protocol)
}
}
Expand Down Expand Up @@ -543,12 +548,13 @@ use std::sync::Arc;
pub async fn local_socks(
mqtt_url: &str,
listener: TcpListener,
agent_id: &str,
local_id: &str,
id: (&str, &str),
domain: Option<String>,
xdata: Option<(Vec<u8>, Option<Vec<u8>>)>,
on_xdata: Option<Sender<(String, String, Vec<u8>)>>,
tcp_over_kcp: bool,
) -> Result<(), Error> {
let (agent_id, local_id) = id;
let mut senders =
LruCache::<String, UnboundedSender<(String, Vec<u8>)>>::with_expiry_duration_and_capacity(
LRU_TIME_TO_LIVE,
Expand Down Expand Up @@ -576,20 +582,24 @@ pub async fn local_socks(
let on_xdata = on_xdata.clone();
select! {
Ok((conn, _)) = server.accept() => {
match crate::socks::handle(conn).await {
Ok((target, socket)) => {
let agent_id = match target {
match crate::socks::handle(conn, domain.clone()).await {
Ok((id, target, socket)) => {
let agent_id = match id {
Some(id) => id,
None => agent_id.to_string(),
};
let target = match target {
Some(t) => t,
None => topic::NIL.to_string(),
};

let (vnet_tx, vnet_rx) = unbounded_channel::<(String, Vec<u8>)>();

let protocol = if tcp_over_kcp { topic::protocol::KCP } else { topic::protocol::TCP };

let addr = socket.peer_addr().unwrap().to_string();
let key_send = topic::build(prefix, &agent_id, local_id, topic::label::I, protocol, &addr);
let key_recv = topic::build(prefix, &agent_id, local_id, topic::label::O, protocol, &addr);
let key_send = topic::build(prefix, &agent_id, local_id, topic::label::I, protocol, &addr, &target);
let key_recv = topic::build(prefix, &agent_id, local_id, topic::label::O, protocol, &addr, &target);

senders.insert(key_recv, vnet_tx);
task::spawn(async move {
Expand Down Expand Up @@ -623,7 +633,7 @@ pub async fn local_socks(
Ok(notification) => {
if let Some(p) = mqtt_receive(notification) {
let topic = p.topic.clone();
let (_prefix, agent_id, local_id, label, protocol, _address) = topic::parse(&topic);
let (_prefix, agent_id, local_id, label, protocol, _src, _dst) = topic::parse(&topic);

match (label, protocol) {
(topic::label::X, _) => {
Expand Down
29 changes: 21 additions & 8 deletions libs/net4mqtt/src/socks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use tokio::io::AsyncWriteExt;

pub(crate) async fn handle(
conn: IncomingConnection<(), NeedAuthenticate>,
) -> Result<(Option<String>, Connect<Ready>), Error> {
domain: Option<String>,
) -> Result<(Option<String>, Option<String>, Connect<Ready>), Error> {
let conn = match conn.authenticate().await {
Ok((conn, _)) => conn,
Err((err, mut conn)) => {
Expand Down Expand Up @@ -51,12 +52,24 @@ pub(crate) async fn handle(
let _ = conn.close().await;
}
Ok(Command::Connect(connect, addr)) => {
let target = match addr {
Address::DomainAddress(domain, _port) => match std::str::from_utf8(&domain) {
Ok(raw) => Some(crate::kxdns::Kxdns::resolver(raw).to_string()),
Err(_) => None,
},
Address::SocketAddress(_) => None,
let (id, target) = match addr {
Address::DomainAddress(domain_address, _port) => {
match std::str::from_utf8(&domain_address) {
Ok(raw) => {
if let Some(d) = domain {
if raw.ends_with(&d) {
(Some(crate::kxdns::Kxdns::resolver(raw).to_string()), None)
} else {
(None, Some(raw.to_string()))
}
} else {
(None, Some(raw.to_string()))
}
}
Err(_) => (None, None),
}
}
Address::SocketAddress(ip) => (None, Some(ip.to_string())),
};

let replied = connect
Expand All @@ -70,7 +83,7 @@ pub(crate) async fn handle(
return Err(anyhow!(err));
}
};
return Ok((target, conn));
return Ok((id, target, conn));
}
Err((err, mut conn)) => {
let _ = conn.shutdown().await;
Expand Down
Loading

0 comments on commit 18322d7

Please sign in to comment.