From 3400b2538dff73d524409c2f6b5ab9d8f863122f Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Wed, 3 Jan 2024 23:49:40 +0700 Subject: [PATCH 1/8] feat: virtual socket --- Cargo.toml | 1 + packages/integration_tests/src/lib.rs | 1 + .../integration_tests/src/virtual_socket.rs | 123 +++++++ packages/runner/Cargo.toml | 4 +- packages/runner/src/lib.rs | 3 + packages/services/virtual_socket/Cargo.toml | 20 ++ .../services/virtual_socket/src/behavior.rs | 82 +++++ .../services/virtual_socket/src/handler.rs | 44 +++ packages/services/virtual_socket/src/lib.rs | 11 + packages/services/virtual_socket/src/msg.rs | 25 ++ packages/services/virtual_socket/src/sdk.rs | 25 ++ packages/services/virtual_socket/src/state.rs | 315 ++++++++++++++++++ .../virtual_socket/src/state/connector.rs | 28 ++ .../virtual_socket/src/state/listener.rs | 20 ++ .../virtual_socket/src/state/socket.rs | 120 +++++++ 15 files changed, 821 insertions(+), 1 deletion(-) create mode 100644 packages/integration_tests/src/virtual_socket.rs create mode 100644 packages/services/virtual_socket/Cargo.toml create mode 100644 packages/services/virtual_socket/src/behavior.rs create mode 100644 packages/services/virtual_socket/src/handler.rs create mode 100644 packages/services/virtual_socket/src/lib.rs create mode 100644 packages/services/virtual_socket/src/msg.rs create mode 100644 packages/services/virtual_socket/src/sdk.rs create mode 100644 packages/services/virtual_socket/src/state.rs create mode 100644 packages/services/virtual_socket/src/state/connector.rs create mode 100644 packages/services/virtual_socket/src/state/listener.rs create mode 100644 packages/services/virtual_socket/src/state/socket.rs diff --git a/Cargo.toml b/Cargo.toml index 6dcf69a7..10c23bfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ members = [ "packages/services/tun_tap", "packages/services/pub_sub", "packages/services/rpc", + "packages/services/virtual_socket", "packages/transports/vnet", "packages/transports/tcp", "packages/transports/udp", diff --git a/packages/integration_tests/src/lib.rs b/packages/integration_tests/src/lib.rs index c1e12fc0..a301724b 100644 --- a/packages/integration_tests/src/lib.rs +++ b/packages/integration_tests/src/lib.rs @@ -1,3 +1,4 @@ mod key_value; mod pubsub; mod rpc; +mod virtual_socket; diff --git a/packages/integration_tests/src/virtual_socket.rs b/packages/integration_tests/src/virtual_socket.rs new file mode 100644 index 00000000..b42aff9c --- /dev/null +++ b/packages/integration_tests/src/virtual_socket.rs @@ -0,0 +1,123 @@ +#[cfg(test)] +mod test { + use std::{collections::HashMap, sync::Arc, time::Duration}; + + use async_std::task::JoinHandle; + use atm0s_sdn::{ + convert_enum, KeyValueBehaviorEvent, KeyValueHandlerEvent, KeyValueSdkEvent, LayersSpreadRouterSyncBehavior, LayersSpreadRouterSyncBehaviorEvent, LayersSpreadRouterSyncHandlerEvent, + ManualBehavior, ManualBehaviorConf, ManualBehaviorEvent, ManualHandlerEvent, NetworkPlane, NetworkPlaneConfig, NodeAddr, NodeAddrBuilder, NodeId, SharedRouter, SystemTimer, + VirtualSocketBehavior, VirtualSocketSdk, + }; + use atm0s_sdn_transport_vnet::VnetEarth; + + #[derive(convert_enum::From, convert_enum::TryInto)] + enum BE { + KeyValue(KeyValueBehaviorEvent), + RouterSync(LayersSpreadRouterSyncBehaviorEvent), + Manual(ManualBehaviorEvent), + } + + #[derive(convert_enum::From, convert_enum::TryInto)] + enum HE { + KeyValue(KeyValueHandlerEvent), + RouterSync(LayersSpreadRouterSyncHandlerEvent), + Manual(ManualHandlerEvent), + } + + #[derive(convert_enum::From, convert_enum::TryInto)] + enum SE { + KeyValue(KeyValueSdkEvent), + } + + async fn run_node(vnet: Arc, node_id: NodeId, seeds: Vec) -> (VirtualSocketSdk, NodeAddr, JoinHandle<()>) { + log::info!("Run node {} connect to {:?}", node_id, seeds); + let node_addr = Arc::new(NodeAddrBuilder::new(node_id)); + let transport = Box::new(atm0s_sdn_transport_vnet::VnetTransport::new(vnet, node_addr.addr())); + let timer = Arc::new(SystemTimer()); + + let router = SharedRouter::new(node_id); + let manual = ManualBehavior::::new(ManualBehaviorConf { + node_id, + node_addr: node_addr.addr(), + seeds, + local_tags: vec![], + connect_tags: vec![], + }); + + let (virtual_socket_behaviour, virtual_socket_sdk) = VirtualSocketBehavior::new(node_id); + let router_sync_behaviour = LayersSpreadRouterSyncBehavior::new(router.clone()); + + let mut plane = NetworkPlane::::new(NetworkPlaneConfig { + node_id, + tick_ms: 100, + behaviors: vec![Box::new(virtual_socket_behaviour), Box::new(router_sync_behaviour), Box::new(manual)], + transport, + timer, + router: Arc::new(router.clone()), + }); + + let join = async_std::task::spawn(async move { + plane.started(); + while let Ok(_) = plane.recv().await {} + plane.stopped(); + }); + + (virtual_socket_sdk, node_addr.addr(), join) + } + + #[async_std::test] + async fn local_socket() { + let node_id = 1; + let vnet = Arc::new(VnetEarth::default()); + let (sdk, _addr, join) = run_node(vnet.clone(), node_id, vec![]).await; + async_std::task::sleep(Duration::from_millis(300)).await; + + let mut listener = sdk.listen("DEMO"); + let connector = sdk.connector(); + async_std::task::spawn(async move { + let mut socket = connector + .connect_to(node_id, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) + .await + .expect("Should connect"); + socket.write(vec![1, 2, 3]).await.expect("Should write"); + }); + + if let Some(mut socket) = listener.recv().await { + assert_eq!(socket.meta(), &HashMap::from([("k1".to_string(), "k2".to_string())])); + assert_eq!(socket.read().await.expect("Should read"), vec![1, 2, 3]); + assert_eq!(socket.read().await, None); + } + + join.cancel().await; + } + + #[async_std::test] + async fn remote_socket() { + let vnet = Arc::new(VnetEarth::default()); + + let node_id1 = 1; + let node_id2 = 2; + let (sdk1, addr1, join1) = run_node(vnet.clone(), node_id1, vec![]).await; + let (sdk2, _addr2, join2) = run_node(vnet.clone(), node_id2, vec![addr1]).await; + async_std::task::sleep(Duration::from_millis(300)).await; + + let mut listener1 = sdk1.listen("DEMO"); + let connector2 = sdk2.connector(); + async_std::task::spawn(async move { + let mut socket = connector2 + .connect_to(node_id1, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) + .await + .expect("Should connect"); + socket.write(vec![1, 2, 3]).await.expect("Should write"); + }); + + if let Some(mut socket) = listener1.recv().await { + assert_eq!(socket.meta(), &HashMap::from([("k1".to_string(), "k2".to_string())])); + assert_eq!(socket.read().await.expect("Should read"), vec![1, 2, 3]); + assert_eq!(socket.read().await, None); + } + + join1.cancel().await; + join2.cancel().await; + } +} diff --git a/packages/runner/Cargo.toml b/packages/runner/Cargo.toml index a8d0f9a8..9c6e3339 100644 --- a/packages/runner/Cargo.toml +++ b/packages/runner/Cargo.toml @@ -25,6 +25,7 @@ atm0s-sdn-manual-discovery = { path = "../services/manual_discovery", version = atm0s-sdn-key-value = { path = "../services/key_value", version = "0.1.7", optional = true } atm0s-sdn-pub-sub = { path = "../services/pub_sub", version = "0.1.6", optional = true } atm0s-sdn-rpc = { path = "../services/rpc", version = "0.1.3", optional = true } +atm0s-sdn-virtual-socket = { path = "../services/virtual_socket", version = "0.1.0", optional = true } async-trait = { workspace = true } futures-util = "0.3" @@ -39,4 +40,5 @@ pub-sub = ["atm0s-sdn-pub-sub"] spread-router = ["atm0s-sdn-layers-spread-router", "atm0s-sdn-layers-spread-router-sync"] manual-discovery = ["atm0s-sdn-manual-discovery"] rpc = ["atm0s-sdn-rpc"] -all = ["transport-udp", "transport-tcp", "transport-compose", "key-value", "pub-sub", "spread-router", "manual-discovery", "rpc"] +virtual-socket = ["atm0s-sdn-virtual-socket"] +all = ["transport-udp", "transport-tcp", "transport-compose", "key-value", "pub-sub", "spread-router", "manual-discovery", "rpc", "virtual-socket"] diff --git a/packages/runner/src/lib.rs b/packages/runner/src/lib.rs index c06a3bd0..90818671 100644 --- a/packages/runner/src/lib.rs +++ b/packages/runner/src/lib.rs @@ -33,6 +33,9 @@ pub use atm0s_sdn_pub_sub::{ #[cfg(feature = "rpc")] pub use atm0s_sdn_rpc::{RpcBehavior, RpcBox, RpcEmitter, RpcError, RpcHandler, RpcIdGenerate, RpcMsg, RpcMsgParam, RpcQueue, RpcRequest}; +#[cfg(feature = "virtual-socket")] +pub use atm0s_sdn_virtual_socket::{VirtualSocketBehavior, VirtualSocketSdk}; + #[cfg(feature = "transport-tcp")] pub use atm0s_sdn_transport_tcp::TcpTransport; #[cfg(feature = "transport-udp")] diff --git a/packages/services/virtual_socket/Cargo.toml b/packages/services/virtual_socket/Cargo.toml new file mode 100644 index 00000000..481a94db --- /dev/null +++ b/packages/services/virtual_socket/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "atm0s-sdn-virtual-socket" +version = "0.1.0" +edition = "2021" +description = "Virtual Socket service in atm0s-sdn" +license = "MIT" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +atm0s-sdn-identity = { path = "../../core/identity", version = "0.2.0" } +atm0s-sdn-router = { path = "../../core/router", version = "0.1.4" } +atm0s-sdn-utils = { path = "../../core/utils", version = "0.1.1" } +atm0s-sdn-network = { path = "../../network", version = "0.3.0" } +log = { workspace = true } +futures = "0.3" +async-trait = { workspace = true } +async-std = { workspace = true } +parking_lot = { workspace = true } +serde = { workspace = true } \ No newline at end of file diff --git a/packages/services/virtual_socket/src/behavior.rs b/packages/services/virtual_socket/src/behavior.rs new file mode 100644 index 00000000..eeb80f05 --- /dev/null +++ b/packages/services/virtual_socket/src/behavior.rs @@ -0,0 +1,82 @@ +use std::sync::Arc; + +use atm0s_sdn_identity::{ConnId, NodeId}; +use atm0s_sdn_network::{ + behaviour::{BehaviorContext, ConnectionHandler, NetworkBehavior, NetworkBehaviorAction}, + msg::TransportMsg, + transport::{ConnectionEvent, ConnectionRejectReason, ConnectionSender, OutgoingConnectionError}, +}; +use parking_lot::RwLock; + +use crate::{ + handler::VirtualSocketHandler, + state::{process_incoming_data, State}, + VirtualSocketSdk, VIRTUAL_SOCKET_SERVICE_ID, +}; + +pub struct VirtualSocketBehavior { + node_id: NodeId, + state: Arc>, +} + +impl VirtualSocketBehavior { + pub fn new(node_id: NodeId) -> (Self, VirtualSocketSdk) { + log::info!("[VirtualSocketBehavior] create new on node: {}", node_id); + let state = Arc::new(RwLock::new(State::default())); + (Self { node_id, state: state.clone() }, VirtualSocketSdk::new(state)) + } +} + +impl NetworkBehavior for VirtualSocketBehavior { + fn service_id(&self) -> u8 { + VIRTUAL_SOCKET_SERVICE_ID + } + + fn on_started(&mut self, ctx: &BehaviorContext, _now_ms: u64) { + self.state.write().set_awaker(ctx.awaker.clone()); + } + + fn on_tick(&mut self, _ctx: &BehaviorContext, now_ms: u64, _interval_ms: u64) { + self.state.write().on_tick(now_ms); + } + + fn on_awake(&mut self, _ctx: &BehaviorContext, _now_ms: u64) {} + + fn on_sdk_msg(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _from_service: u8, _event: SE) {} + + fn on_local_msg(&mut self, _ctx: &BehaviorContext, now_ms: u64, msg: TransportMsg) { + process_incoming_data(now_ms, &self.state, ConnectionEvent::Msg(msg)); + } + + fn check_incoming_connection(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node: NodeId, _conn_id: ConnId) -> Result<(), ConnectionRejectReason> { + Ok(()) + } + + fn check_outgoing_connection(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node: NodeId, _conn_id: ConnId) -> Result<(), ConnectionRejectReason> { + Ok(()) + } + + fn on_incoming_connection_connected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _conn: Arc) -> Option>> { + Some(Box::new(VirtualSocketHandler { state: self.state.clone() })) + } + + fn on_outgoing_connection_connected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _conn: Arc) -> Option>> { + Some(Box::new(VirtualSocketHandler { state: self.state.clone() })) + } + + fn on_incoming_connection_disconnected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, _conn_id: ConnId) {} + + fn on_outgoing_connection_disconnected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, _conn_id: ConnId) {} + + fn on_outgoing_connection_error(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, _conn_id: ConnId, _err: &OutgoingConnectionError) {} + + fn on_handler_event(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, _conn_id: ConnId, _event: BE) {} + + fn on_stopped(&mut self, _ctx: &BehaviorContext, _now_ms: u64) {} + + fn pop_action(&mut self) -> Option> { + let (socket_id, out) = self.state.write().pop_outgoing()?; + let msg = out.into_transport_msg(self.node_id, socket_id.node_id(), socket_id.client_id()); + Some(NetworkBehaviorAction::ToNet(msg)) + } +} diff --git a/packages/services/virtual_socket/src/handler.rs b/packages/services/virtual_socket/src/handler.rs new file mode 100644 index 00000000..ef11e4b4 --- /dev/null +++ b/packages/services/virtual_socket/src/handler.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use atm0s_sdn_identity::{ConnId, NodeId}; +use atm0s_sdn_network::{ + behaviour::{ConnectionContext, ConnectionHandler, ConnectionHandlerAction}, + transport::ConnectionEvent, +}; +use parking_lot::RwLock; + +use crate::state::{process_incoming_data, State}; + +pub struct VirtualSocketHandler { + pub(crate) state: Arc>, +} + +impl ConnectionHandler for VirtualSocketHandler { + /// Called when the connection is opened. + fn on_opened(&mut self, _ctx: &ConnectionContext, _now_ms: u64) {} + + /// Called on each tick of the connection. + fn on_tick(&mut self, _ctx: &ConnectionContext, _now_ms: u64, _interval_ms: u64) {} + + /// Called when the connection is awake. + fn on_awake(&mut self, _ctx: &ConnectionContext, _now_ms: u64) {} + + /// Called when an event occurs on the connection. + fn on_event(&mut self, _ctx: &ConnectionContext, now_ms: u64, event: ConnectionEvent) { + process_incoming_data(now_ms, &self.state, event); + } + + /// Called when an event occurs on another handler. + fn on_other_handler_event(&mut self, _ctx: &ConnectionContext, _now_ms: u64, _from_node: NodeId, _from_conn: ConnId, _event: HE) {} + + /// Called when an event occurs on the behavior. + fn on_behavior_event(&mut self, _ctx: &ConnectionContext, _now_ms: u64, _event: HE) {} + + /// Called when the connection is closed. + fn on_closed(&mut self, _ctx: &ConnectionContext, _now_ms: u64) {} + + /// Pops the next action to be taken by the connection handler. + fn pop_action(&mut self) -> Option> { + None + } +} diff --git a/packages/services/virtual_socket/src/lib.rs b/packages/services/virtual_socket/src/lib.rs new file mode 100644 index 00000000..cb80af23 --- /dev/null +++ b/packages/services/virtual_socket/src/lib.rs @@ -0,0 +1,11 @@ +pub(crate) const VIRTUAL_SOCKET_SERVICE_ID: u8 = 6; + +mod behavior; +mod handler; +mod msg; +mod sdk; +pub(crate) mod state; + +pub use behavior::VirtualSocketBehavior; +pub use sdk::VirtualSocketSdk; +pub use state::{socket::VirtualSocket, VirtualSocketConnectResult}; diff --git a/packages/services/virtual_socket/src/msg.rs b/packages/services/virtual_socket/src/msg.rs new file mode 100644 index 00000000..6243e8d8 --- /dev/null +++ b/packages/services/virtual_socket/src/msg.rs @@ -0,0 +1,25 @@ +use atm0s_sdn_identity::NodeId; +use serde::{Deserialize, Serialize}; + +use std::collections::HashMap; + +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum VirtualSocketControlMsg { + ConnectRequest(String, HashMap), + ConnectReponse(bool), + ConnectingPing, + ConnectingPong, + ConnectionClose(), +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct SocketId(pub NodeId, pub u32); +impl SocketId { + pub fn node_id(&self) -> NodeId { + self.0 + } + + pub fn client_id(&self) -> u32 { + self.1 + } +} diff --git a/packages/services/virtual_socket/src/sdk.rs b/packages/services/virtual_socket/src/sdk.rs new file mode 100644 index 00000000..4934dcad --- /dev/null +++ b/packages/services/virtual_socket/src/sdk.rs @@ -0,0 +1,25 @@ +use std::sync::Arc; + +use parking_lot::RwLock; + +use crate::state::{connector::VirtualSocketConnector, listener::VirtualSocketListener, State}; + +pub struct VirtualSocketSdk { + state: Arc>, +} + +impl VirtualSocketSdk { + pub fn new(state: Arc>) -> Self { + Self { state } + } + + pub fn connector(&self) -> VirtualSocketConnector { + VirtualSocketConnector { state: self.state.clone() } + } + + pub fn listen(&self, id: &str) -> VirtualSocketListener { + log::info!("[VirtualSocketSdk] listen on: {}", id); + let rx = self.state.write().new_listener(id); + VirtualSocketListener { rx, state: self.state.clone() } + } +} diff --git a/packages/services/virtual_socket/src/state.rs b/packages/services/virtual_socket/src/state.rs new file mode 100644 index 00000000..0974b21b --- /dev/null +++ b/packages/services/virtual_socket/src/state.rs @@ -0,0 +1,315 @@ +use std::{collections::HashMap, sync::Arc}; + +use async_std::channel::{Receiver, Sender}; +use atm0s_sdn_identity::NodeId; +use atm0s_sdn_network::transport::ConnectionEvent; +use atm0s_sdn_utils::{awaker::Awaker, error_handle::ErrorUtils, option_handle::OptionUtils, vec_dequeue::VecDeque}; +use parking_lot::RwLock; + +use crate::msg::{SocketId, VirtualSocketControlMsg}; + +use self::socket::{VirtualSocketBuilder, VirtualSocketEvent, CONTROL_CLIENT_META, CONTROL_SERVER_META, DATA_CLIENT_META, DATA_SERVER_META}; + +pub mod connector; +pub mod listener; +pub mod socket; + +const CONNECT_TIMEOUT_MS: u64 = 10000; +const PING_INTERVAL_MS: u64 = 1000; +const PING_TIMEOUT_MS: u64 = 10000; + +enum OutgoingState { + Connecting { started_at: u64, res_tx: Sender }, + Connected { last_ping: u64, pong_time: u64, tx: Sender> }, +} + +struct IncommingState { + ping_time: u64, + tx: Sender>, +} + +pub enum VirtualSocketConnectResult { + Success(VirtualSocketBuilder), + Timeout, + Unreachable, +} + +pub struct State { + client_idx: u32, + last_tick_ms: u64, + listeners: HashMap>, + outgoings: HashMap, + incomings: HashMap, + outgoing_queue: VecDeque<(SocketId, VirtualSocketEvent)>, + awaker: Option>, +} + +impl Default for State { + fn default() -> Self { + Self { + client_idx: 0, + last_tick_ms: 0, + listeners: HashMap::new(), + outgoings: HashMap::new(), + incomings: HashMap::new(), + outgoing_queue: VecDeque::new(), + awaker: None, + } + } +} + +impl State { + pub fn set_awaker(&mut self, awaker: Arc) { + self.awaker = Some(awaker); + } + + pub fn new_listener(&mut self, id: &str) -> Receiver { + log::info!("[VirtualSocketState] new listener: {}", id); + let (tx, rx) = async_std::channel::bounded(10); + self.listeners.insert(id.to_string(), tx); + rx + } + + pub fn new_outgoing(&mut self, dest_node_id: NodeId, dest_listener_id: &str, meta: HashMap) -> Option> { + let client_idx = self.client_idx; + self.client_idx += 1; + log::info!("[VirtualSocketState] new outgoing: {}/{} with meta {:?} => idx {}", dest_node_id, dest_listener_id, meta, client_idx); + let socket_id = SocketId(dest_node_id, client_idx); + + let (tx, rx) = async_std::channel::bounded(1); + self.outgoings.insert( + socket_id.clone(), + OutgoingState::Connecting { + started_at: self.last_tick_ms, + res_tx: tx, + }, + ); + self.outgoing_queue.push_back(( + socket_id, + VirtualSocketEvent::ClientControl(VirtualSocketControlMsg::ConnectRequest(dest_listener_id.to_string(), meta)), + )); + Some(rx) + } + + pub fn on_tick(&mut self, now_ms: u64) { + self.last_tick_ms = now_ms; + + // Remove timed out outgoing connections + let mut to_remove = Vec::new(); + for (socket_id, state) in self.outgoings.iter() { + if let OutgoingState::Connecting { started_at, res_tx: tx } = state { + if now_ms - started_at > CONNECT_TIMEOUT_MS { + log::info!("[VirtualSocketState] outgoing timeout: node {} idx: {}", socket_id.node_id(), socket_id.client_id()); + to_remove.push(socket_id.clone()); + tx.try_send(VirtualSocketConnectResult::Timeout).print_error("Should send timeout to waiting connector"); + } + } + } + for socket_id in to_remove { + self.outgoings.remove(&socket_id).print_none("Should remove timed out outgoing connection"); + } + + // send ping from outgoing sockets + for (socket_id, state) in self.outgoings.iter_mut() { + if let OutgoingState::Connected { last_ping, .. } = state { + if now_ms - *last_ping > PING_INTERVAL_MS { + *last_ping = now_ms; + self.outgoing_queue + .push_back((socket_id.clone(), VirtualSocketEvent::ClientControl(VirtualSocketControlMsg::ConnectingPing))); + } + } + } + + // Remote ping timeout outgoing sockets + let mut to_remove = Vec::new(); + for (socket_id, state) in self.outgoings.iter() { + if let OutgoingState::Connected { pong_time, .. } = state { + if now_ms - *pong_time > PING_TIMEOUT_MS { + log::info!("[VirtualSocketState] outgoing ping timeout: node {} idx: {}", socket_id.node_id(), socket_id.client_id()); + to_remove.push(socket_id.clone()); + } + } + } + for socket_id in to_remove { + self.outgoings.remove(&socket_id).print_none("Should remove timed out outgoing connection"); + } + + // Remove timed out incoming sockets + let mut to_remove = Vec::new(); + for (socket_id, state) in self.incomings.iter() { + if now_ms - state.ping_time > PING_TIMEOUT_MS { + log::info!("[VirtualSocketState] incoming ping timeout: node {} idx: {}", socket_id.node_id(), socket_id.client_id()); + to_remove.push(socket_id.clone()); + } + } + for socket_id in to_remove { + self.incomings.remove(&socket_id).print_none("Should remove timed out incoming connection"); + } + } + + pub fn on_recv_server_data(&self, _now_ms: u64, socket_id: SocketId, data: &[u8]) { + log::debug!("[VirtualSocketState] on_recv_server_data: {:?} {:?}", socket_id, data); + if let Some(OutgoingState::Connected { tx, .. }) = self.outgoings.get(&socket_id) { + tx.try_send(data.to_vec()).ok(); + } + } + + pub fn on_recv_client_data(&self, _now_ms: u64, socket_id: SocketId, data: &[u8]) { + log::debug!("[VirtualSocketState] on_recv_client_data: {:?} {:?}", socket_id, data); + if let Some(state) = self.incomings.get(&socket_id) { + state.tx.try_send(data.to_vec()).ok(); + } + } + + pub fn send_out(&mut self, socket_id: SocketId, event: VirtualSocketEvent) { + log::debug!("[VirtualSocketState] send_out to : {:?} {:?}", socket_id, event); + self.outgoing_queue.push_back((socket_id, event)); + if self.outgoing_queue.len() == 1 { + if let Some(awaker) = self.awaker.as_ref() { + awaker.notify(); + } + } + } + + pub fn on_recv_server_control(&mut self, now_ms: u64, socket_id: SocketId, control: VirtualSocketControlMsg) { + log::debug!("[VirtualSocketState] on_recv_server_control from : {:?} {:?}", socket_id, control); + match control { + VirtualSocketControlMsg::ConnectReponse(success) => { + if let Some(state) = self.outgoings.get_mut(&socket_id) { + if let OutgoingState::Connecting { res_tx, .. } = state { + if success { + let (socket_tx, socket_rx) = async_std::channel::bounded(10); + res_tx + .try_send(VirtualSocketConnectResult::Success(VirtualSocketBuilder::new(true, socket_id, HashMap::new(), socket_rx))) + .print_error("Should send connect response to waiting connector"); + *state = OutgoingState::Connected { + last_ping: now_ms, + pong_time: now_ms, + tx: socket_tx, + }; + } else { + res_tx + .try_send(VirtualSocketConnectResult::Unreachable) + .print_error("Should send connect response to waiting connector"); + self.outgoings.remove(&socket_id).print_none("Should remove failed outgoing connection"); + }; + } else { + log::warn!("[VirtualSocketState] on_recv_server_control socket already connected: {:?}", socket_id); + } + } else { + log::warn!("[VirtualSocketState] on_recv_server_control socket not found: {:?}", socket_id); + } + } + VirtualSocketControlMsg::ConnectionClose() => { + if self.incomings.remove(&socket_id).is_some() { + log::info!("[VirtualSocketState] closed outgoing socket: {:?}", socket_id); + } + } + VirtualSocketControlMsg::ConnectingPong => { + //update pong time to outgoing sockets + if let Some(state) = self.outgoings.get_mut(&socket_id) { + if let OutgoingState::Connected { pong_time, .. } = state { + *pong_time = now_ms; + } + } + } + _ => { + log::warn!("[VirtualSocketState] on_recv_server_control Unknown control message: {:?}", control); + } + } + } + + pub fn on_recv_client_control(&mut self, now_ms: u64, socket_id: SocketId, control: VirtualSocketControlMsg) { + log::debug!("[VirtualSocketState] on_recv_client_control from : {:?} {:?}", socket_id, control); + match control { + VirtualSocketControlMsg::ConnectRequest(listener_id, meta) => { + if self.incomings.contains_key(&socket_id) { + log::warn!("[VirtualSocketState] on_recv_client_control socket already connected: {:?}", socket_id); + return; + } + if let Some(tx) = self.listeners.get(&listener_id) { + let (socket_tx, socket_rx) = async_std::channel::bounded(10); + self.incomings.insert(socket_id.clone(), IncommingState { ping_time: now_ms, tx: socket_tx }); + tx.try_send(VirtualSocketBuilder::new(false, socket_id.clone(), meta, socket_rx)) + .print_error("Should send new virtual socket to listener"); + self.outgoing_queue + .push_back((socket_id, VirtualSocketEvent::ServerControl(VirtualSocketControlMsg::ConnectReponse(true)))); + } else { + self.outgoing_queue + .push_back((socket_id, VirtualSocketEvent::ServerControl(VirtualSocketControlMsg::ConnectReponse(false)))); + } + } + VirtualSocketControlMsg::ConnectionClose() => { + if self.incomings.remove(&socket_id).is_some() { + log::info!("[VirtualSocketState] closed incoming socket: {:?}", socket_id); + } + } + VirtualSocketControlMsg::ConnectingPing => { + //update ping time to incoming sockets + if let Some(state) = self.incomings.get_mut(&socket_id) { + state.ping_time = now_ms; + } + self.outgoing_queue.push_back((socket_id, VirtualSocketEvent::ServerControl(VirtualSocketControlMsg::ConnectingPong))); + } + _ => { + log::warn!("[VirtualSocketState] on_recv_client_control Unknown control message: {:?}", control); + } + } + } + + pub fn close_socket(&mut self, is_client: bool, socket_id: &SocketId) { + if is_client { + if self.outgoings.remove(socket_id).is_some() { + log::debug!("[VirtualSocketState] will close outgoing socket: {:?}", socket_id); + let socket_id: SocketId = socket_id.clone(); + self.outgoing_queue + .push_back((socket_id, VirtualSocketEvent::ClientControl(VirtualSocketControlMsg::ConnectionClose()))); + } + } else { + if self.incomings.remove(socket_id).is_some() { + log::debug!("[VirtualSocketState] will close incomming socket: {:?}", socket_id); + let socket_id: SocketId = socket_id.clone(); + self.outgoing_queue + .push_back((socket_id, VirtualSocketEvent::ServerControl(VirtualSocketControlMsg::ConnectionClose()))); + } + } + } + + pub fn pop_outgoing(&mut self) -> Option<(SocketId, VirtualSocketEvent)> { + self.outgoing_queue.pop_front() + } +} + +pub fn process_incoming_data(now_ms: u64, state: &RwLock, event: ConnectionEvent) { + if let ConnectionEvent::Msg(data) = event { + if let Some(from) = data.header.from_node { + match data.header.meta { + DATA_CLIENT_META => { + let socket_id = SocketId(from, data.header.stream_id); + state.read().on_recv_client_data(now_ms, socket_id, data.payload()); + } + DATA_SERVER_META => { + let socket_id = SocketId(from, data.header.stream_id); + state.read().on_recv_server_data(now_ms, socket_id, data.payload()); + } + CONTROL_CLIENT_META => { + if let Ok(control) = data.get_payload_bincode::() { + //is control + if let Some(from) = data.header.from_node { + state.write().on_recv_client_control(now_ms, SocketId(from, data.header.stream_id), control); + } + } + } + CONTROL_SERVER_META => { + if let Ok(control) = data.get_payload_bincode::() { + //is control + if let Some(from) = data.header.from_node { + state.write().on_recv_server_control(now_ms, SocketId(from, data.header.stream_id), control); + } + } + } + _ => {} + } + } + } +} diff --git a/packages/services/virtual_socket/src/state/connector.rs b/packages/services/virtual_socket/src/state/connector.rs new file mode 100644 index 00000000..e4611f86 --- /dev/null +++ b/packages/services/virtual_socket/src/state/connector.rs @@ -0,0 +1,28 @@ +use std::{collections::HashMap, sync::Arc}; + +use atm0s_sdn_identity::NodeId; +use parking_lot::RwLock; + +use super::{socket::VirtualSocket, State, VirtualSocketConnectResult}; + +#[derive(Debug)] +pub enum VirtualSocketConnectorError { + Timeout, + Unreachable, +} + +pub struct VirtualSocketConnector { + pub(crate) state: Arc>, +} + +impl VirtualSocketConnector { + pub async fn connect_to(&self, dest: NodeId, listener: &str, meta: HashMap) -> Result { + let rx = self.state.write().new_outgoing(dest, listener, meta).ok_or(VirtualSocketConnectorError::Unreachable)?; + match rx.recv().await { + Ok(VirtualSocketConnectResult::Success(builder)) => Ok(builder.build(self.state.clone())), + Ok(VirtualSocketConnectResult::Timeout) => Err(VirtualSocketConnectorError::Timeout), + Ok(VirtualSocketConnectResult::Unreachable) => Err(VirtualSocketConnectorError::Unreachable), + Err(_) => Err(VirtualSocketConnectorError::Unreachable), + } + } +} diff --git a/packages/services/virtual_socket/src/state/listener.rs b/packages/services/virtual_socket/src/state/listener.rs new file mode 100644 index 00000000..c4379786 --- /dev/null +++ b/packages/services/virtual_socket/src/state/listener.rs @@ -0,0 +1,20 @@ +use std::sync::Arc; + +use parking_lot::RwLock; + +use super::{ + socket::{VirtualSocket, VirtualSocketBuilder}, + State, +}; + +pub struct VirtualSocketListener { + pub(crate) rx: async_std::channel::Receiver, + pub(crate) state: Arc>, +} + +impl VirtualSocketListener { + pub async fn recv(&mut self) -> Option { + let builder = self.rx.recv().await.ok()?; + Some(builder.build(self.state.clone())) + } +} diff --git a/packages/services/virtual_socket/src/state/socket.rs b/packages/services/virtual_socket/src/state/socket.rs new file mode 100644 index 00000000..818f2b24 --- /dev/null +++ b/packages/services/virtual_socket/src/state/socket.rs @@ -0,0 +1,120 @@ +use std::{collections::HashMap, sync::Arc}; + +use async_std::channel::Receiver; +use atm0s_sdn_identity::NodeId; +use atm0s_sdn_network::msg::{MsgHeader, TransportMsg}; +use atm0s_sdn_router::RouteRule; +use parking_lot::RwLock; + +use crate::{ + msg::{SocketId, VirtualSocketControlMsg}, + VIRTUAL_SOCKET_SERVICE_ID, +}; + +use super::State; + +pub const CONTROL_CLIENT_META: u8 = 0; +pub const CONTROL_SERVER_META: u8 = 1; +pub const DATA_CLIENT_META: u8 = 2; +pub const DATA_SERVER_META: u8 = 3; + +pub struct VirtualSocketBuilder { + is_client: bool, + remote: SocketId, + rx: Receiver>, + meta: HashMap, +} + +impl VirtualSocketBuilder { + pub(crate) fn new(is_client: bool, remote: SocketId, meta: HashMap, rx: Receiver>) -> Self { + Self { is_client, remote, meta, rx } + } + + pub fn build(self, state: Arc>) -> VirtualSocket { + VirtualSocket::new(self.is_client, self.remote, self.meta, self.rx, state) + } +} + +#[derive(Debug)] +pub enum VirtualSocketEvent { + ServerControl(VirtualSocketControlMsg), + ClientControl(VirtualSocketControlMsg), + ServerData(Vec), + ClientData(Vec), +} + +impl VirtualSocketEvent { + pub fn into_transport_msg(self, local_node: NodeId, remote_node: NodeId, client_id: u32) -> TransportMsg { + match self { + VirtualSocketEvent::ServerData(data) => { + let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) + .set_from_node(Some(local_node)) + .set_stream_id(client_id) + .set_meta(DATA_SERVER_META); + TransportMsg::build_raw(header, &data) + } + VirtualSocketEvent::ClientData(data) => { + let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) + .set_from_node(Some(local_node)) + .set_stream_id(client_id) + .set_meta(DATA_CLIENT_META); + TransportMsg::build_raw(header, &data) + } + VirtualSocketEvent::ServerControl(control) => { + let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) + .set_from_node(Some(local_node)) + .set_stream_id(client_id) + .set_meta(CONTROL_SERVER_META); + TransportMsg::from_payload_bincode(header, &control) + } + VirtualSocketEvent::ClientControl(control) => { + let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) + .set_from_node(Some(local_node)) + .set_stream_id(client_id) + .set_meta(CONTROL_CLIENT_META); + TransportMsg::from_payload_bincode(header, &control) + } + } + } +} + +pub struct VirtualSocket { + is_client: bool, + remote: SocketId, + rx: Receiver>, + meta: HashMap, + state: Arc>, +} + +impl VirtualSocket { + pub(crate) fn new(is_client: bool, remote: SocketId, meta: HashMap, rx: Receiver>, state: Arc>) -> Self { + Self { is_client, remote, meta, rx, state } + } + + pub fn remote(&self) -> &SocketId { + &self.remote + } + + pub fn meta(&self) -> &HashMap { + &self.meta + } + + pub async fn read(&mut self) -> Option> { + self.rx.recv().await.ok() + } + + pub async fn write(&mut self, buf: Vec) -> Option<()> { + if self.is_client { + self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ClientData(buf)); + } else { + self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ServerData(buf)); + } + Some(()) + } +} + +impl Drop for VirtualSocket { + fn drop(&mut self) { + self.state.write().close_socket(self.is_client, &self.remote); + } +} From 825eca101d00b7a0b2d2ed6168f1dbfff451b92d Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Fri, 5 Jan 2024 11:39:06 +0700 Subject: [PATCH 2/8] added stream --- .../integration_tests/src/virtual_socket.rs | 72 +++++++++- packages/runner/src/lib.rs | 2 +- packages/services/virtual_socket/Cargo.toml | 3 +- packages/services/virtual_socket/src/lib.rs | 2 +- packages/services/virtual_socket/src/state.rs | 4 + .../virtual_socket/src/state/socket.rs | 81 +++++++++-- .../virtual_socket/src/state/stream.rs | 132 ++++++++++++++++++ 7 files changed, 276 insertions(+), 20 deletions(-) create mode 100644 packages/services/virtual_socket/src/state/stream.rs diff --git a/packages/integration_tests/src/virtual_socket.rs b/packages/integration_tests/src/virtual_socket.rs index b42aff9c..88ed83c6 100644 --- a/packages/integration_tests/src/virtual_socket.rs +++ b/packages/integration_tests/src/virtual_socket.rs @@ -6,7 +6,7 @@ mod test { use atm0s_sdn::{ convert_enum, KeyValueBehaviorEvent, KeyValueHandlerEvent, KeyValueSdkEvent, LayersSpreadRouterSyncBehavior, LayersSpreadRouterSyncBehaviorEvent, LayersSpreadRouterSyncHandlerEvent, ManualBehavior, ManualBehaviorConf, ManualBehaviorEvent, ManualHandlerEvent, NetworkPlane, NetworkPlaneConfig, NodeAddr, NodeAddrBuilder, NodeId, SharedRouter, SystemTimer, - VirtualSocketBehavior, VirtualSocketSdk, + VirtualSocketBehavior, VirtualSocketSdk, VirtualStream, }; use atm0s_sdn_transport_vnet::VnetEarth; @@ -79,7 +79,7 @@ mod test { .connect_to(node_id, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) .await .expect("Should connect"); - socket.write(vec![1, 2, 3]).await.expect("Should write"); + socket.write(&vec![1, 2, 3]).expect("Should write"); }); if let Some(mut socket) = listener.recv().await { @@ -91,6 +91,37 @@ mod test { join.cancel().await; } + #[async_std::test] + async fn local_stream() { + let node_id = 1; + let vnet = Arc::new(VnetEarth::default()); + let (sdk, _addr, join) = run_node(vnet.clone(), node_id, vec![]).await; + async_std::task::sleep(Duration::from_millis(300)).await; + + let mut listener = sdk.listen("DEMO"); + let connector = sdk.connector(); + async_std::task::spawn(async move { + let socket = connector + .connect_to(node_id, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) + .await + .expect("Should connect"); + let mut stream = VirtualStream::new(socket); + assert_eq!(stream.write(&vec![1, 2, 3]).await.expect("Should send"), 3); + async_std::task::sleep(Duration::from_secs(1)).await; + }); + + if let Some(socket) = listener.recv().await { + let mut stream = VirtualStream::new(socket); + assert_eq!(stream.meta(), &HashMap::from([("k1".to_string(), "k2".to_string())])); + let mut buf = vec![0; 1500]; + assert_eq!(stream.read(&mut buf).await.expect("Should read"), 3); + assert_eq!(buf[..3], [1, 2, 3]); + assert_eq!(stream.read(&mut buf).await.expect("Should read"), 0); + } + + join.cancel().await; + } + #[async_std::test] async fn remote_socket() { let vnet = Arc::new(VnetEarth::default()); @@ -108,7 +139,7 @@ mod test { .connect_to(node_id1, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) .await .expect("Should connect"); - socket.write(vec![1, 2, 3]).await.expect("Should write"); + socket.write(&vec![1, 2, 3]).expect("Should write"); }); if let Some(mut socket) = listener1.recv().await { @@ -120,4 +151,39 @@ mod test { join1.cancel().await; join2.cancel().await; } + + #[async_std::test] + async fn remote_stream() { + let vnet = Arc::new(VnetEarth::default()); + + let node_id1 = 1; + let node_id2 = 2; + let (sdk1, addr1, join1) = run_node(vnet.clone(), node_id1, vec![]).await; + let (sdk2, _addr2, join2) = run_node(vnet.clone(), node_id2, vec![addr1]).await; + async_std::task::sleep(Duration::from_millis(300)).await; + + let mut listener1 = sdk1.listen("DEMO"); + let connector2 = sdk2.connector(); + async_std::task::spawn(async move { + let socket = connector2 + .connect_to(node_id1, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) + .await + .expect("Should connect"); + let mut stream = VirtualStream::new(socket); + assert_eq!(stream.write(&vec![1, 2, 3]).await.expect("Should send"), 3); + async_std::task::sleep(Duration::from_secs(1)).await; + }); + + if let Some(socket) = listener1.recv().await { + let mut stream = VirtualStream::new(socket); + assert_eq!(stream.meta(), &HashMap::from([("k1".to_string(), "k2".to_string())])); + let mut buf = vec![0; 1500]; + assert_eq!(stream.read(&mut buf).await.expect("Should read"), 3); + assert_eq!(buf[..3], [1, 2, 3]); + assert_eq!(stream.read(&mut buf).await.expect("Should read"), 0); + } + + join1.cancel().await; + join2.cancel().await; + } } diff --git a/packages/runner/src/lib.rs b/packages/runner/src/lib.rs index 90818671..eaba4a46 100644 --- a/packages/runner/src/lib.rs +++ b/packages/runner/src/lib.rs @@ -34,7 +34,7 @@ pub use atm0s_sdn_pub_sub::{ pub use atm0s_sdn_rpc::{RpcBehavior, RpcBox, RpcEmitter, RpcError, RpcHandler, RpcIdGenerate, RpcMsg, RpcMsgParam, RpcQueue, RpcRequest}; #[cfg(feature = "virtual-socket")] -pub use atm0s_sdn_virtual_socket::{VirtualSocketBehavior, VirtualSocketSdk}; +pub use atm0s_sdn_virtual_socket::{VirtualSocket, VirtualSocketBehavior, VirtualSocketSdk, VirtualStream}; #[cfg(feature = "transport-tcp")] pub use atm0s_sdn_transport_tcp::TcpTransport; diff --git a/packages/services/virtual_socket/Cargo.toml b/packages/services/virtual_socket/Cargo.toml index 481a94db..99f09484 100644 --- a/packages/services/virtual_socket/Cargo.toml +++ b/packages/services/virtual_socket/Cargo.toml @@ -17,4 +17,5 @@ futures = "0.3" async-trait = { workspace = true } async-std = { workspace = true } parking_lot = { workspace = true } -serde = { workspace = true } \ No newline at end of file +serde = { workspace = true } +kcp = "0.5.3" diff --git a/packages/services/virtual_socket/src/lib.rs b/packages/services/virtual_socket/src/lib.rs index cb80af23..a70aaf99 100644 --- a/packages/services/virtual_socket/src/lib.rs +++ b/packages/services/virtual_socket/src/lib.rs @@ -8,4 +8,4 @@ pub(crate) mod state; pub use behavior::VirtualSocketBehavior; pub use sdk::VirtualSocketSdk; -pub use state::{socket::VirtualSocket, VirtualSocketConnectResult}; +pub use state::{socket::VirtualSocket, stream::VirtualStream, VirtualSocketConnectResult}; diff --git a/packages/services/virtual_socket/src/state.rs b/packages/services/virtual_socket/src/state.rs index 0974b21b..cb8e84f4 100644 --- a/packages/services/virtual_socket/src/state.rs +++ b/packages/services/virtual_socket/src/state.rs @@ -13,6 +13,7 @@ use self::socket::{VirtualSocketBuilder, VirtualSocketEvent, CONTROL_CLIENT_META pub mod connector; pub mod listener; pub mod socket; +pub mod stream; const CONNECT_TIMEOUT_MS: u64 = 10000; const PING_INTERVAL_MS: u64 = 1000; @@ -313,3 +314,6 @@ pub fn process_incoming_data(now_ms: u64, state: &RwLock, event: Connecti } } } + +#[cfg(test)] +mod tests {} diff --git a/packages/services/virtual_socket/src/state/socket.rs b/packages/services/virtual_socket/src/state/socket.rs index 818f2b24..34c9dd4e 100644 --- a/packages/services/virtual_socket/src/state/socket.rs +++ b/packages/services/virtual_socket/src/state/socket.rs @@ -1,3 +1,4 @@ +use std::io::Write; use std::{collections::HashMap, sync::Arc}; use async_std::channel::Receiver; @@ -79,18 +80,52 @@ impl VirtualSocketEvent { } pub struct VirtualSocket { - is_client: bool, - remote: SocketId, - rx: Receiver>, - meta: HashMap, - state: Arc>, + writer: VirtualSocketWriter, + reader: VirtualSocketReader, } impl VirtualSocket { pub(crate) fn new(is_client: bool, remote: SocketId, meta: HashMap, rx: Receiver>, state: Arc>) -> Self { - Self { is_client, remote, meta, rx, state } + Self { + writer: VirtualSocketWriter { + is_client, + remote: remote.clone(), + state: state.clone(), + }, + reader: VirtualSocketReader { is_client, rx, remote, meta, state }, + } + } + + pub fn remote(&self) -> &SocketId { + self.writer.remote() + } + + pub fn meta(&self) -> &HashMap { + self.reader.meta() } + pub async fn read(&mut self) -> Option> { + self.reader.read().await + } + + pub fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.writer.write(buf) + } + + pub fn split(self) -> (VirtualSocketReader, VirtualSocketWriter) { + (self.reader, self.writer) + } +} + +pub struct VirtualSocketReader { + is_client: bool, + rx: Receiver>, + remote: SocketId, + meta: HashMap, + state: Arc>, +} + +impl VirtualSocketReader { pub fn remote(&self) -> &SocketId { &self.remote } @@ -102,19 +137,37 @@ impl VirtualSocket { pub async fn read(&mut self) -> Option> { self.rx.recv().await.ok() } +} + +impl Drop for VirtualSocketReader { + fn drop(&mut self) { + self.state.write().close_socket(self.is_client, &self.remote); + } +} - pub async fn write(&mut self, buf: Vec) -> Option<()> { +pub struct VirtualSocketWriter { + is_client: bool, + remote: SocketId, + state: Arc>, +} + +impl VirtualSocketWriter { + pub fn remote(&self) -> &SocketId { + &self.remote + } +} + +impl Write for VirtualSocketWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { if self.is_client { - self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ClientData(buf)); + self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ClientData(buf.to_vec())); } else { - self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ServerData(buf)); + self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ServerData(buf.to_vec())); } - Some(()) + Ok(buf.len()) } -} -impl Drop for VirtualSocket { - fn drop(&mut self) { - self.state.write().close_socket(self.is_client, &self.remote); + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) } } diff --git a/packages/services/virtual_socket/src/state/stream.rs b/packages/services/virtual_socket/src/state/stream.rs new file mode 100644 index 00000000..c77720d3 --- /dev/null +++ b/packages/services/virtual_socket/src/state/stream.rs @@ -0,0 +1,132 @@ +use std::{collections::HashMap, sync::Arc}; + +use async_std::{channel::Receiver, stream::StreamExt, task::JoinHandle}; +use futures::{select, FutureExt as _}; +use kcp::{Error, Kcp}; +use parking_lot::RwLock; + +use super::socket::{VirtualSocket, VirtualSocketWriter}; + +const MAX_KCP_SEND_QUEUE: usize = 10; + +enum ReadEvent { + Continue, + Close, +} + +pub struct VirtualStream { + meta: HashMap, + kcp: Arc>>, + task: Option>, + write_awake_rx: Receiver<()>, + read_awake_rx: Receiver, +} + +impl VirtualStream { + pub fn new(socket: VirtualSocket) -> Self { + let (mut reader, writer) = socket.split(); + let meta = reader.meta().clone(); + let kcp = Arc::new(RwLock::new(Kcp::new_stream(writer.remote().client_id(), writer))); + let (write_awake_tx, write_awake_rx) = async_std::channel::bounded(1); + let (read_awake_tx, read_awake_rx) = async_std::channel::bounded(1); + let kcp_c = kcp.clone(); + let task = async_std::task::spawn(async move { + let mut timer = async_std::stream::interval(std::time::Duration::from_millis(10)); + let started_at = std::time::Instant::now(); + loop { + select! { + _ = timer.next().fuse() => { + if let Err(e) = kcp_c.write().update(started_at.elapsed().as_millis() as u32) { + log::error!("[VirtualStream] kcp update error: {:?}", e); + break; + } + } + e = reader.read().fuse() => { + if let Some(buf) = e { + if buf.len() == 0 { + log::info!("[VirtualStream] reader closed"); + read_awake_tx.try_send(ReadEvent::Close).ok(); + break; + } + + let mut kcp = kcp_c.write(); + if let Err(e) = kcp.input(&buf) { + log::error!("[VirtualStream] kcp input error: {:?}", e); + break; + } else { + if let Ok(len) = kcp.peeksize() { + if len > 0 { + read_awake_tx.try_send(ReadEvent::Continue).ok(); + } + } + if kcp.wait_snd() < MAX_KCP_SEND_QUEUE { + write_awake_tx.try_send(()).ok(); + } + } + } else { + log::info!("[VirtualStream] reader closed"); + read_awake_tx.try_send(ReadEvent::Close).ok(); + break; + } + } + } + } + }); + + Self { + meta, + kcp, + task: Some(task), + write_awake_rx, + read_awake_rx, + } + } + + pub fn meta(&self) -> &HashMap { + &self.meta + } + + pub async fn write(&mut self, buf: &[u8]) -> std::io::Result { + loop { + let kcp_wait_snd = self.kcp.read().wait_snd(); + if kcp_wait_snd < MAX_KCP_SEND_QUEUE { + return self.kcp.write().send(buf).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); + } else { + self.write_awake_rx.recv().await.map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "ConnectionInterrupted"))?; + } + } + } + + pub async fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + loop { + let size = match self.kcp.write().recv(buf) { + Ok(size) => size, + Err(e) => match e { + Error::RecvQueueEmpty => 0, + _ => { + return Err(std::io::Error::new(std::io::ErrorKind::Other, e)); + } + }, + }; + if size > 0 { + return Ok(size); + } else { + match self.read_awake_rx.recv().await { + Ok(ReadEvent::Continue) => {} + Ok(ReadEvent::Close) => return Ok(0), + Err(_) => return Err(std::io::Error::new(std::io::ErrorKind::Other, "ConnectionInterrupted")), + } + } + } + } +} + +impl Drop for VirtualStream { + fn drop(&mut self) { + if let Some(task) = self.task.take() { + async_std::task::spawn(async { + task.cancel().await; + }); + } + } +} From cb319205587ffcb0c7d838f044f13f55fd26981c Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Fri, 5 Jan 2024 15:50:57 +0700 Subject: [PATCH 3/8] add tcp_tunnel --- examples/examples/tcp_tunnel.rs | 236 ++++++++++++++++++ .../integration_tests/src/virtual_socket.rs | 4 +- .../virtual_socket/src/state/stream.rs | 10 +- 3 files changed, 245 insertions(+), 5 deletions(-) create mode 100644 examples/examples/tcp_tunnel.rs diff --git a/examples/examples/tcp_tunnel.rs b/examples/examples/tcp_tunnel.rs new file mode 100644 index 00000000..e4b217e8 --- /dev/null +++ b/examples/examples/tcp_tunnel.rs @@ -0,0 +1,236 @@ +use async_std::io::ReadExt; +use async_std::io::WriteExt; +use async_std::net::TcpListener; +use async_std::net::TcpStream; +use atm0s_sdn::NodeId; +use atm0s_sdn::VirtualSocketBehavior; +use atm0s_sdn::VirtualSocketSdk; +use atm0s_sdn::VirtualStream; +use atm0s_sdn::compose_transport_desp::FutureExt; +use atm0s_sdn::compose_transport_desp::select; +use atm0s_sdn::convert_enum; +use atm0s_sdn::SharedRouter; +use atm0s_sdn::SystemTimer; +use atm0s_sdn::{KeyValueBehavior, NodeAddr, NodeAddrBuilder, UdpTransport}; +use atm0s_sdn::{KeyValueBehaviorEvent, KeyValueHandlerEvent, KeyValueSdkEvent}; +use atm0s_sdn::{LayersSpreadRouterSyncBehavior, LayersSpreadRouterSyncBehaviorEvent, LayersSpreadRouterSyncHandlerEvent}; +use atm0s_sdn::{ManualBehavior, ManualBehaviorConf, ManualBehaviorEvent, ManualHandlerEvent}; +use atm0s_sdn::{NetworkPlane, NetworkPlaneConfig}; +use clap::Parser; +use clap::Subcommand; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; + +#[derive(convert_enum::From, convert_enum::TryInto)] +enum NodeBehaviorEvent { + Manual(ManualBehaviorEvent), + LayersSpreadRouterSync(LayersSpreadRouterSyncBehaviorEvent), + KeyValue(KeyValueBehaviorEvent), +} + +#[derive(convert_enum::From, convert_enum::TryInto)] +enum NodeHandleEvent { + Manual(ManualHandlerEvent), + LayersSpreadRouterSync(LayersSpreadRouterSyncHandlerEvent), + KeyValue(KeyValueHandlerEvent), +} + +#[derive(convert_enum::From, convert_enum::TryInto)] +enum NodeSdkEvent { + KeyValue(KeyValueSdkEvent), +} + +/// Node with manual network builder +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Current Node ID + #[arg(env, long)] + node_id: u32, + + /// Neighbours + #[arg(env, long)] + seeds: Vec, + + /// Local tags + #[arg(env, long)] + tags: Vec, + + /// Tags of nodes to connect + #[arg(env, long)] + connect_tags: Vec, + + /// + #[command(subcommand)] + mode: Mode +} + +#[derive(Debug, Subcommand)] +enum Mode { + Server(ServerOpts), + Agent(AgentOpts) +} + +#[derive(Parser, Debug)] +struct ServerOpts { + /// Tunnel Server listen TCP port + #[arg(env, long)] + listen: SocketAddr, + + /// Tunnel dest node_id + #[arg(env, long)] + dest: NodeId, +} + +#[derive(Parser, Debug)] +struct AgentOpts { + /// Tunnel Agent target addr, which will be forwarded to + #[arg(env, long)] + target: SocketAddr, +} + +async fn run_server(sdk: VirtualSocketSdk, opts: ServerOpts) { + let listener = TcpListener::bind(opts.listen).await.expect("Should bind"); + while let Ok((mut stream, remote_addr)) = listener.accept().await { + log::info!("[TcpTunnel][Server] incomming conn from {}", remote_addr); + let connector = sdk.connector(); + async_std::task::spawn(async move { + log::info!("[TcpTunnel][Server] connecting to dest node {}", opts.dest); + match connector.connect_to(opts.dest, "TUNNEL_APP", HashMap::new()).await { + Ok(socket_relay) => { + log::info!("[TcpTunnel][Server] connected to dest node {} remote {:?}", opts.dest, socket_relay.remote()); + let mut target = VirtualStream::new(socket_relay); + let mut buf1 = [0; 4096]; + let mut buf2 = [0; 4096]; + loop { + select! { + e = stream.read(&mut buf1).fuse() => { + if let Ok(len) = e { + if len == 0 { + break; + } + target.write_all(&buf1[..len]).await.expect("Should write"); + } else { + break; + } + }, + e = target.read(&mut buf2).fuse() => { + if let Ok(len) = e { + if len == 0 { + break; + } + stream.write_all(&buf2[..len]).await.expect("Should write"); + } else { + break; + } + } + } + } + } + Err(e) => { + log::info!("[TcpTunnel][Server] connect to dest node {} errpr {:?}", opts.dest, e); + } + } + }); + } +} + +async fn run_agent(sdk: VirtualSocketSdk, opts: AgentOpts) { + let mut listener = sdk.listen("TUNNEL_APP"); + while let Some(socket) = listener.recv().await { + log::info!("[TcpTunnel][Server] incomming conn from {:?}", socket.remote()); + let mut stream = VirtualStream::new(socket); + async_std::task::spawn(async move { + log::info!("[TcpTunnel][Server] connecting to target {}", opts.target); + match TcpStream::connect(&opts.target).await { + Ok(mut target) => { + log::info!("[TcpTunnel][Server] connected to target {}", opts.target); + let mut buf1 = [0; 4096]; + let mut buf2 = [0; 4096]; + loop { + select! { + e = stream.read(&mut buf1).fuse() => { + if let Ok(len) = e { + if len == 0 { + break; + } + target.write_all(&buf1[..len]).await.expect("Should write"); + } else { + break; + } + }, + e = target.read(&mut buf2).fuse() => { + if let Ok(len) = e { + if len == 0 { + break; + } + stream.write_all(&buf2[..len]).await.expect("Should write"); + } else { + break; + } + } + } + } + }, + Err(e) => { + log::info!("[TcpTunnel][Server] connect to target {} error {:?}", opts.target, e); + } + } + }); + } +} + +#[async_std::main] +async fn main() { + env_logger::builder().format_timestamp_millis().init(); + let args: Args = Args::parse(); + let secure = Arc::new(atm0s_sdn::StaticKeySecure::new("secure-token")); + let mut node_addr_builder = NodeAddrBuilder::new(args.node_id); + + let udp_socket = UdpTransport::prepare(50000 + args.node_id as u16, &mut node_addr_builder).await; + let transport = UdpTransport::new(node_addr_builder.addr(), udp_socket, secure.clone()); + + let node_addr = node_addr_builder.addr(); + log::info!("Listen on addr {}", node_addr); + + let timer = Arc::new(SystemTimer()); + let router = SharedRouter::new(args.node_id); + + let manual = ManualBehavior::new(ManualBehaviorConf { + node_id: args.node_id, + node_addr, + seeds: args.seeds.clone(), + local_tags: args.tags, + connect_tags: args.connect_tags, + }); + + let spreads_layer_router: LayersSpreadRouterSyncBehavior = LayersSpreadRouterSyncBehavior::new(router.clone()); + let key_value = KeyValueBehavior::new(args.node_id, 10000, None); + + let (virtual_socket, virtual_socket_sdk) = VirtualSocketBehavior::new(args.node_id); + + let plan_cfg = NetworkPlaneConfig { + router: Arc::new(router), + node_id: args.node_id, + tick_ms: 1000, + behaviors: vec![Box::new(manual), Box::new(spreads_layer_router), Box::new(key_value), Box::new(virtual_socket)], + transport: Box::new(transport), + timer, + }; + + let mut plane = NetworkPlane::::new(plan_cfg); + + plane.started(); + + async_std::task::spawn(async move { + match args.mode { + Mode::Server(opts) => run_server(virtual_socket_sdk, opts).await, + Mode::Agent(opts) => run_agent(virtual_socket_sdk, opts).await, + } + }); + + while let Ok(_) = plane.recv().await {} + + plane.stopped(); +} diff --git a/packages/integration_tests/src/virtual_socket.rs b/packages/integration_tests/src/virtual_socket.rs index 88ed83c6..31450492 100644 --- a/packages/integration_tests/src/virtual_socket.rs +++ b/packages/integration_tests/src/virtual_socket.rs @@ -106,7 +106,7 @@ mod test { .await .expect("Should connect"); let mut stream = VirtualStream::new(socket); - assert_eq!(stream.write(&vec![1, 2, 3]).await.expect("Should send"), 3); + assert_eq!(stream.write_all(&vec![1, 2, 3]).await.expect("Should send"), ()); async_std::task::sleep(Duration::from_secs(1)).await; }); @@ -170,7 +170,7 @@ mod test { .await .expect("Should connect"); let mut stream = VirtualStream::new(socket); - assert_eq!(stream.write(&vec![1, 2, 3]).await.expect("Should send"), 3); + assert_eq!(stream.write_all(&vec![1, 2, 3]).await.expect("Should send"), ()); async_std::task::sleep(Duration::from_secs(1)).await; }); diff --git a/packages/services/virtual_socket/src/state/stream.rs b/packages/services/virtual_socket/src/state/stream.rs index c77720d3..92fd73b1 100644 --- a/packages/services/virtual_socket/src/state/stream.rs +++ b/packages/services/virtual_socket/src/state/stream.rs @@ -26,7 +26,10 @@ impl VirtualStream { pub fn new(socket: VirtualSocket) -> Self { let (mut reader, writer) = socket.split(); let meta = reader.meta().clone(); - let kcp = Arc::new(RwLock::new(Kcp::new_stream(writer.remote().client_id(), writer))); + let mut kcp = Kcp::new_stream(writer.remote().client_id(), writer); + kcp.set_nodelay(true, 20, 2, true); + + let kcp = Arc::new(RwLock::new(kcp)); let (write_awake_tx, write_awake_rx) = async_std::channel::bounded(1); let (read_awake_tx, read_awake_rx) = async_std::channel::bounded(1); let kcp_c = kcp.clone(); @@ -86,11 +89,12 @@ impl VirtualStream { &self.meta } - pub async fn write(&mut self, buf: &[u8]) -> std::io::Result { + pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { loop { let kcp_wait_snd = self.kcp.read().wait_snd(); if kcp_wait_snd < MAX_KCP_SEND_QUEUE { - return self.kcp.write().send(buf).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); + self.kcp.write().send(buf).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + return Ok(()); } else { self.write_awake_rx.recv().await.map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "ConnectionInterrupted"))?; } From 9a4fddb40a354d55be48f976269e33d9f675100b Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Sat, 6 Jan 2024 15:04:29 +0700 Subject: [PATCH 4/8] added secure option --- examples/examples/tcp_tunnel.rs | 151 +++++++++++++++--- .../integration_tests/src/virtual_socket.rs | 8 +- packages/services/virtual_socket/src/msg.rs | 2 +- packages/services/virtual_socket/src/sdk.rs | 1 + packages/services/virtual_socket/src/state.rs | 27 ++-- .../virtual_socket/src/state/connector.rs | 4 +- .../virtual_socket/src/state/socket.rs | 27 ++-- packages/transports/udp/src/transport.rs | 4 +- 8 files changed, 173 insertions(+), 51 deletions(-) diff --git a/examples/examples/tcp_tunnel.rs b/examples/examples/tcp_tunnel.rs index e4b217e8..7f8bb0c4 100644 --- a/examples/examples/tcp_tunnel.rs +++ b/examples/examples/tcp_tunnel.rs @@ -2,15 +2,16 @@ use async_std::io::ReadExt; use async_std::io::WriteExt; use async_std::net::TcpListener; use async_std::net::TcpStream; -use atm0s_sdn::NodeId; -use atm0s_sdn::VirtualSocketBehavior; -use atm0s_sdn::VirtualSocketSdk; -use atm0s_sdn::VirtualStream; -use atm0s_sdn::compose_transport_desp::FutureExt; +use async_std::net::UdpSocket; use atm0s_sdn::compose_transport_desp::select; +use atm0s_sdn::compose_transport_desp::FutureExt; use atm0s_sdn::convert_enum; +use atm0s_sdn::NodeId; use atm0s_sdn::SharedRouter; use atm0s_sdn::SystemTimer; +use atm0s_sdn::VirtualSocketBehavior; +use atm0s_sdn::VirtualSocketSdk; +use atm0s_sdn::VirtualStream; use atm0s_sdn::{KeyValueBehavior, NodeAddr, NodeAddrBuilder, UdpTransport}; use atm0s_sdn::{KeyValueBehaviorEvent, KeyValueHandlerEvent, KeyValueSdkEvent}; use atm0s_sdn::{LayersSpreadRouterSyncBehavior, LayersSpreadRouterSyncBehaviorEvent, LayersSpreadRouterSyncHandlerEvent}; @@ -61,18 +62,18 @@ struct Args { #[arg(env, long)] connect_tags: Vec, - /// + /// #[command(subcommand)] - mode: Mode + mode: Mode, } #[derive(Debug, Subcommand)] enum Mode { Server(ServerOpts), - Agent(AgentOpts) + Agent(AgentOpts), } -#[derive(Parser, Debug)] +#[derive(Parser, Debug, Clone)] struct ServerOpts { /// Tunnel Server listen TCP port #[arg(env, long)] @@ -81,9 +82,13 @@ struct ServerOpts { /// Tunnel dest node_id #[arg(env, long)] dest: NodeId, + + /// Tunnel is encrypted or not + #[arg(env, long)] + secure: bool, } -#[derive(Parser, Debug)] +#[derive(Parser, Debug, Clone)] struct AgentOpts { /// Tunnel Agent target addr, which will be forwarded to #[arg(env, long)] @@ -97,7 +102,7 @@ async fn run_server(sdk: VirtualSocketSdk, opts: ServerOpts) { let connector = sdk.connector(); async_std::task::spawn(async move { log::info!("[TcpTunnel][Server] connecting to dest node {}", opts.dest); - match connector.connect_to(opts.dest, "TUNNEL_APP", HashMap::new()).await { + match connector.connect_to(opts.secure, opts.dest, "TUNNEL_APP", HashMap::new()).await { Ok(socket_relay) => { log::info!("[TcpTunnel][Server] connected to dest node {} remote {:?}", opts.dest, socket_relay.remote()); let mut target = VirtualStream::new(socket_relay); @@ -139,13 +144,13 @@ async fn run_server(sdk: VirtualSocketSdk, opts: ServerOpts) { async fn run_agent(sdk: VirtualSocketSdk, opts: AgentOpts) { let mut listener = sdk.listen("TUNNEL_APP"); while let Some(socket) = listener.recv().await { - log::info!("[TcpTunnel][Server] incomming conn from {:?}", socket.remote()); + log::info!("[TcpTunnel][Agent] incomming conn from {:?}", socket.remote()); let mut stream = VirtualStream::new(socket); async_std::task::spawn(async move { - log::info!("[TcpTunnel][Server] connecting to target {}", opts.target); + log::info!("[TcpTunnel][Agent] connecting to target {}", opts.target); match TcpStream::connect(&opts.target).await { Ok(mut target) => { - log::info!("[TcpTunnel][Server] connected to target {}", opts.target); + log::info!("[TcpTunnel][Agent] connected to target {}", opts.target); let mut buf1 = [0; 4096]; let mut buf2 = [0; 4096]; loop { @@ -172,9 +177,89 @@ async fn run_agent(sdk: VirtualSocketSdk, opts: AgentOpts) { } } } - }, + } Err(e) => { - log::info!("[TcpTunnel][Server] connect to target {} error {:?}", opts.target, e); + log::info!("[TcpTunnel][Agent] connect to target {} error {:?}", opts.target, e); + } + } + }); + } +} + +async fn run_udp_server(sdk: VirtualSocketSdk, opts: ServerOpts) { + let udp_server = UdpSocket::bind(opts.listen).await.expect("Should bind"); + log::info!("[UdpTunnel][Server] listen on {}", opts.listen); + let mut buf1 = [0; 1500]; + let (_, remote_addr) = udp_server.peek_from(&mut buf1).await.expect("Should peek"); + udp_server.connect(remote_addr).await.expect("Should connect"); + log::info!("[UdpTunnel][Server] incomming conn from {}", remote_addr); + log::info!("[UdpTunnel][Server] connecting to dest node {}", opts.dest); + match sdk.connector().connect_to(opts.secure, opts.dest, "TUNNEL_APP_UDP", HashMap::new()).await { + Ok(mut socket_relay) => { + log::info!("[UdpTunnel][Server] connected to dest node {} remote {:?}", opts.dest, socket_relay.remote()); + loop { + select! { + e = udp_server.recv(&mut buf1).fuse() => { + if let Ok(len) = e { + if len == 0 { + break; + } + socket_relay.write(&buf1[..len]).expect("Should write"); + } else { + break; + } + }, + e = socket_relay.read().fuse() => { + if let Some(buf) = e { + udp_server.send(&buf).await.expect("Should write"); + } else { + break; + } + } + } + } + } + Err(e) => { + log::info!("[UdpTunnel][Server] connect to dest node {} errpr {:?}", opts.dest, e); + } + } +} + +async fn run_udp_agent(sdk: VirtualSocketSdk, opts: AgentOpts) { + let mut listener = sdk.listen("TUNNEL_APP_UDP"); + while let Some(mut socket) = listener.recv().await { + log::info!("[UdpTunnel][Agent] incomming conn from {:?}", socket.remote()); + async_std::task::spawn(async move { + log::info!("[UdpTunnel][Agent] connecting to target {}", opts.target); + let udp_socket = UdpSocket::bind("0.0.0.0:0").await.expect("Should bind"); + match udp_socket.connect(&opts.target).await { + Ok(_) => { + log::info!("[UdpTunnel][Agent] connected to target {}", opts.target); + let mut buf2 = [0; 1500]; + loop { + select! { + e = socket.read().fuse() => { + if let Some(buf) = e { + udp_socket.send(&buf).await; + } else { + break; + } + }, + e = udp_socket.recv(&mut buf2).fuse() => { + if let Ok(len) = e { + if len == 0 { + break; + } + socket.write(&buf2[..len]).expect("Should write"); + } else { + break; + } + } + } + } + } + Err(e) => { + log::info!("[UdpTunnel][Agent] connect to target {} error {:?}", opts.target, e); } } }); @@ -188,7 +273,7 @@ async fn main() { let secure = Arc::new(atm0s_sdn::StaticKeySecure::new("secure-token")); let mut node_addr_builder = NodeAddrBuilder::new(args.node_id); - let udp_socket = UdpTransport::prepare(50000 + args.node_id as u16, &mut node_addr_builder).await; + let udp_socket = UdpTransport::prepare(50000 + args.node_id as u16, &mut node_addr_builder).await; let transport = UdpTransport::new(node_addr_builder.addr(), udp_socket, secure.clone()); let node_addr = node_addr_builder.addr(); @@ -207,7 +292,7 @@ async fn main() { let spreads_layer_router: LayersSpreadRouterSyncBehavior = LayersSpreadRouterSyncBehavior::new(router.clone()); let key_value = KeyValueBehavior::new(args.node_id, 10000, None); - + let (virtual_socket, virtual_socket_sdk) = VirtualSocketBehavior::new(args.node_id); let plan_cfg = NetworkPlaneConfig { @@ -223,12 +308,32 @@ async fn main() { plane.started(); - async_std::task::spawn(async move { - match args.mode { - Mode::Server(opts) => run_server(virtual_socket_sdk, opts).await, - Mode::Agent(opts) => run_agent(virtual_socket_sdk, opts).await, + match args.mode { + Mode::Server(opts) => { + let sdk_c = virtual_socket_sdk.clone(); + let opts_c = opts.clone(); + async_std::task::spawn(async move { + run_server(sdk_c, opts_c).await; + }); + let sdk_c = virtual_socket_sdk.clone(); + let opts_c = opts.clone(); + async_std::task::spawn(async move { + run_udp_server(sdk_c, opts_c).await; + }); } - }); + Mode::Agent(opts) => { + let sdk_c = virtual_socket_sdk.clone(); + let opts_c = opts.clone(); + async_std::task::spawn(async move { + run_agent(sdk_c, opts_c).await; + }); + let sdk_c = virtual_socket_sdk.clone(); + let opts_c = opts.clone(); + async_std::task::spawn(async move { + run_udp_agent(sdk_c, opts_c).await; + }); + } + } while let Ok(_) = plane.recv().await {} diff --git a/packages/integration_tests/src/virtual_socket.rs b/packages/integration_tests/src/virtual_socket.rs index 31450492..6abbda9b 100644 --- a/packages/integration_tests/src/virtual_socket.rs +++ b/packages/integration_tests/src/virtual_socket.rs @@ -76,7 +76,7 @@ mod test { let connector = sdk.connector(); async_std::task::spawn(async move { let mut socket = connector - .connect_to(node_id, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) + .connect_to(true, node_id, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) .await .expect("Should connect"); socket.write(&vec![1, 2, 3]).expect("Should write"); @@ -102,7 +102,7 @@ mod test { let connector = sdk.connector(); async_std::task::spawn(async move { let socket = connector - .connect_to(node_id, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) + .connect_to(true, node_id, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) .await .expect("Should connect"); let mut stream = VirtualStream::new(socket); @@ -136,7 +136,7 @@ mod test { let connector2 = sdk2.connector(); async_std::task::spawn(async move { let mut socket = connector2 - .connect_to(node_id1, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) + .connect_to(true, node_id1, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) .await .expect("Should connect"); socket.write(&vec![1, 2, 3]).expect("Should write"); @@ -166,7 +166,7 @@ mod test { let connector2 = sdk2.connector(); async_std::task::spawn(async move { let socket = connector2 - .connect_to(node_id1, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) + .connect_to(true, node_id1, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) .await .expect("Should connect"); let mut stream = VirtualStream::new(socket); diff --git a/packages/services/virtual_socket/src/msg.rs b/packages/services/virtual_socket/src/msg.rs index 6243e8d8..3579dfad 100644 --- a/packages/services/virtual_socket/src/msg.rs +++ b/packages/services/virtual_socket/src/msg.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum VirtualSocketControlMsg { - ConnectRequest(String, HashMap), + ConnectRequest(String, bool, HashMap), ConnectReponse(bool), ConnectingPing, ConnectingPong, diff --git a/packages/services/virtual_socket/src/sdk.rs b/packages/services/virtual_socket/src/sdk.rs index 4934dcad..37504501 100644 --- a/packages/services/virtual_socket/src/sdk.rs +++ b/packages/services/virtual_socket/src/sdk.rs @@ -4,6 +4,7 @@ use parking_lot::RwLock; use crate::state::{connector::VirtualSocketConnector, listener::VirtualSocketListener, State}; +#[derive(Clone)] pub struct VirtualSocketSdk { state: Arc>, } diff --git a/packages/services/virtual_socket/src/state.rs b/packages/services/virtual_socket/src/state.rs index cb8e84f4..987a52e7 100644 --- a/packages/services/virtual_socket/src/state.rs +++ b/packages/services/virtual_socket/src/state.rs @@ -20,8 +20,16 @@ const PING_INTERVAL_MS: u64 = 1000; const PING_TIMEOUT_MS: u64 = 10000; enum OutgoingState { - Connecting { started_at: u64, res_tx: Sender }, - Connected { last_ping: u64, pong_time: u64, tx: Sender> }, + Connecting { + started_at: u64, + secure: bool, + res_tx: Sender, + }, + Connected { + last_ping: u64, + pong_time: u64, + tx: Sender>, + }, } struct IncommingState { @@ -71,7 +79,7 @@ impl State { rx } - pub fn new_outgoing(&mut self, dest_node_id: NodeId, dest_listener_id: &str, meta: HashMap) -> Option> { + pub fn new_outgoing(&mut self, secure: bool, dest_node_id: NodeId, dest_listener_id: &str, meta: HashMap) -> Option> { let client_idx = self.client_idx; self.client_idx += 1; log::info!("[VirtualSocketState] new outgoing: {}/{} with meta {:?} => idx {}", dest_node_id, dest_listener_id, meta, client_idx); @@ -81,13 +89,14 @@ impl State { self.outgoings.insert( socket_id.clone(), OutgoingState::Connecting { + secure, started_at: self.last_tick_ms, res_tx: tx, }, ); self.outgoing_queue.push_back(( socket_id, - VirtualSocketEvent::ClientControl(VirtualSocketControlMsg::ConnectRequest(dest_listener_id.to_string(), meta)), + VirtualSocketEvent::ClientControl(VirtualSocketControlMsg::ConnectRequest(dest_listener_id.to_string(), secure, meta)), )); Some(rx) } @@ -98,7 +107,7 @@ impl State { // Remove timed out outgoing connections let mut to_remove = Vec::new(); for (socket_id, state) in self.outgoings.iter() { - if let OutgoingState::Connecting { started_at, res_tx: tx } = state { + if let OutgoingState::Connecting { started_at, res_tx: tx, .. } = state { if now_ms - started_at > CONNECT_TIMEOUT_MS { log::info!("[VirtualSocketState] outgoing timeout: node {} idx: {}", socket_id.node_id(), socket_id.client_id()); to_remove.push(socket_id.clone()); @@ -177,11 +186,11 @@ impl State { match control { VirtualSocketControlMsg::ConnectReponse(success) => { if let Some(state) = self.outgoings.get_mut(&socket_id) { - if let OutgoingState::Connecting { res_tx, .. } = state { + if let OutgoingState::Connecting { res_tx, secure, .. } = state { if success { let (socket_tx, socket_rx) = async_std::channel::bounded(10); res_tx - .try_send(VirtualSocketConnectResult::Success(VirtualSocketBuilder::new(true, socket_id, HashMap::new(), socket_rx))) + .try_send(VirtualSocketConnectResult::Success(VirtualSocketBuilder::new(true, *secure, socket_id, HashMap::new(), socket_rx))) .print_error("Should send connect response to waiting connector"); *state = OutgoingState::Connected { last_ping: now_ms, @@ -223,7 +232,7 @@ impl State { pub fn on_recv_client_control(&mut self, now_ms: u64, socket_id: SocketId, control: VirtualSocketControlMsg) { log::debug!("[VirtualSocketState] on_recv_client_control from : {:?} {:?}", socket_id, control); match control { - VirtualSocketControlMsg::ConnectRequest(listener_id, meta) => { + VirtualSocketControlMsg::ConnectRequest(listener_id, secure, meta) => { if self.incomings.contains_key(&socket_id) { log::warn!("[VirtualSocketState] on_recv_client_control socket already connected: {:?}", socket_id); return; @@ -231,7 +240,7 @@ impl State { if let Some(tx) = self.listeners.get(&listener_id) { let (socket_tx, socket_rx) = async_std::channel::bounded(10); self.incomings.insert(socket_id.clone(), IncommingState { ping_time: now_ms, tx: socket_tx }); - tx.try_send(VirtualSocketBuilder::new(false, socket_id.clone(), meta, socket_rx)) + tx.try_send(VirtualSocketBuilder::new(false, secure, socket_id.clone(), meta, socket_rx)) .print_error("Should send new virtual socket to listener"); self.outgoing_queue .push_back((socket_id, VirtualSocketEvent::ServerControl(VirtualSocketControlMsg::ConnectReponse(true)))); diff --git a/packages/services/virtual_socket/src/state/connector.rs b/packages/services/virtual_socket/src/state/connector.rs index e4611f86..8bbb772b 100644 --- a/packages/services/virtual_socket/src/state/connector.rs +++ b/packages/services/virtual_socket/src/state/connector.rs @@ -16,8 +16,8 @@ pub struct VirtualSocketConnector { } impl VirtualSocketConnector { - pub async fn connect_to(&self, dest: NodeId, listener: &str, meta: HashMap) -> Result { - let rx = self.state.write().new_outgoing(dest, listener, meta).ok_or(VirtualSocketConnectorError::Unreachable)?; + pub async fn connect_to(&self, secure: bool, dest: NodeId, listener: &str, meta: HashMap) -> Result { + let rx = self.state.write().new_outgoing(secure, dest, listener, meta).ok_or(VirtualSocketConnectorError::Unreachable)?; match rx.recv().await { Ok(VirtualSocketConnectResult::Success(builder)) => Ok(builder.build(self.state.clone())), Ok(VirtualSocketConnectResult::Timeout) => Err(VirtualSocketConnectorError::Timeout), diff --git a/packages/services/virtual_socket/src/state/socket.rs b/packages/services/virtual_socket/src/state/socket.rs index 34c9dd4e..21933ab5 100644 --- a/packages/services/virtual_socket/src/state/socket.rs +++ b/packages/services/virtual_socket/src/state/socket.rs @@ -21,18 +21,19 @@ pub const DATA_SERVER_META: u8 = 3; pub struct VirtualSocketBuilder { is_client: bool, + secure: bool, remote: SocketId, rx: Receiver>, meta: HashMap, } impl VirtualSocketBuilder { - pub(crate) fn new(is_client: bool, remote: SocketId, meta: HashMap, rx: Receiver>) -> Self { - Self { is_client, remote, meta, rx } + pub(crate) fn new(is_client: bool, secure: bool, remote: SocketId, meta: HashMap, rx: Receiver>) -> Self { + Self { is_client, secure, remote, meta, rx } } pub fn build(self, state: Arc>) -> VirtualSocket { - VirtualSocket::new(self.is_client, self.remote, self.meta, self.rx, state) + VirtualSocket::new(self.is_client, self.secure, self.remote, self.meta, self.rx, state) } } @@ -40,24 +41,26 @@ impl VirtualSocketBuilder { pub enum VirtualSocketEvent { ServerControl(VirtualSocketControlMsg), ClientControl(VirtualSocketControlMsg), - ServerData(Vec), - ClientData(Vec), + ServerData(Vec, bool), + ClientData(Vec, bool), } impl VirtualSocketEvent { pub fn into_transport_msg(self, local_node: NodeId, remote_node: NodeId, client_id: u32) -> TransportMsg { match self { - VirtualSocketEvent::ServerData(data) => { + VirtualSocketEvent::ServerData(data, secure) => { let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) .set_from_node(Some(local_node)) .set_stream_id(client_id) + .set_secure(secure) .set_meta(DATA_SERVER_META); TransportMsg::build_raw(header, &data) } - VirtualSocketEvent::ClientData(data) => { + VirtualSocketEvent::ClientData(data, secure) => { let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) .set_from_node(Some(local_node)) .set_stream_id(client_id) + .set_secure(secure) .set_meta(DATA_CLIENT_META); TransportMsg::build_raw(header, &data) } @@ -65,6 +68,7 @@ impl VirtualSocketEvent { let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) .set_from_node(Some(local_node)) .set_stream_id(client_id) + .set_secure(true) .set_meta(CONTROL_SERVER_META); TransportMsg::from_payload_bincode(header, &control) } @@ -72,6 +76,7 @@ impl VirtualSocketEvent { let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) .set_from_node(Some(local_node)) .set_stream_id(client_id) + .set_secure(true) .set_meta(CONTROL_CLIENT_META); TransportMsg::from_payload_bincode(header, &control) } @@ -85,10 +90,11 @@ pub struct VirtualSocket { } impl VirtualSocket { - pub(crate) fn new(is_client: bool, remote: SocketId, meta: HashMap, rx: Receiver>, state: Arc>) -> Self { + pub(crate) fn new(is_client: bool, secure: bool, remote: SocketId, meta: HashMap, rx: Receiver>, state: Arc>) -> Self { Self { writer: VirtualSocketWriter { is_client, + secure, remote: remote.clone(), state: state.clone(), }, @@ -147,6 +153,7 @@ impl Drop for VirtualSocketReader { pub struct VirtualSocketWriter { is_client: bool, + secure: bool, remote: SocketId, state: Arc>, } @@ -160,9 +167,9 @@ impl VirtualSocketWriter { impl Write for VirtualSocketWriter { fn write(&mut self, buf: &[u8]) -> std::io::Result { if self.is_client { - self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ClientData(buf.to_vec())); + self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ClientData(buf.to_vec(), self.secure)); } else { - self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ServerData(buf.to_vec())); + self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ServerData(buf.to_vec(), self.secure)); } Ok(buf.len()) } diff --git a/packages/transports/udp/src/transport.rs b/packages/transports/udp/src/transport.rs index ccd4c96d..0dc565c0 100644 --- a/packages/transports/udp/src/transport.rs +++ b/packages/transports/udp/src/transport.rs @@ -71,14 +71,14 @@ impl UdpTransport { if let Ok((size, addr)) = async_socket.recv_from(&mut buf).await { let current_ms = timer.now_ms(); if let Some(msg_tx) = connection.get_mut(&addr) { - msg_tx.0.try_send((buf, size)).expect("should forward to receiver"); + msg_tx.0.try_send((buf, size)); msg_tx.1 = current_ms; } else { log::info!("[UdpTransport] on new connection from {}", addr); conn_id_seed += 1; let conn_id = ConnId::from_in(UDP_PROTOCOL_ID, conn_id_seed); let (msg_tx, msg_rx) = async_std::channel::bounded(1024); - msg_tx.try_send((buf, size)).expect("should forward to receiver"); + msg_tx.try_send((buf, size)); connection.insert(addr, (msg_tx, current_ms)); let socket = socket.clone(); let async_socket = async_socket.clone(); From ef19c7ffa1250021d2c44d08e23ba343cc5f3c78 Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Mon, 8 Jan 2024 11:12:32 +0700 Subject: [PATCH 5/8] optimize virtual socket performance by switch to quinn instead of kcp --- examples/examples/tcp_tunnel.rs | 240 ++++--------- .../integration_tests/src/virtual_socket.rs | 173 ++++----- packages/runner/src/lib.rs | 2 +- packages/services/virtual_socket/Cargo.toml | 8 +- .../services/virtual_socket/src/behavior.rs | 56 ++- .../services/virtual_socket/src/handler.rs | 16 +- packages/services/virtual_socket/src/lib.rs | 37 +- packages/services/virtual_socket/src/msg.rs | 25 -- .../virtual_socket/src/quinn_utils.rs | 17 + packages/services/virtual_socket/src/sdk.rs | 26 -- packages/services/virtual_socket/src/state.rs | 328 ------------------ .../virtual_socket/src/state/connector.rs | 28 -- .../virtual_socket/src/state/listener.rs | 20 -- .../virtual_socket/src/state/socket.rs | 180 ---------- .../virtual_socket/src/state/stream.rs | 136 -------- packages/services/virtual_socket/src/vnet.rs | 44 +++ .../virtual_socket/src/vnet/async_queue.rs | 70 ++++ .../virtual_socket/src/vnet/internal.rs | 137 ++++++++ .../virtual_socket/src/vnet/udp_socket.rs | 95 +++++ 19 files changed, 597 insertions(+), 1041 deletions(-) delete mode 100644 packages/services/virtual_socket/src/msg.rs create mode 100644 packages/services/virtual_socket/src/quinn_utils.rs delete mode 100644 packages/services/virtual_socket/src/sdk.rs delete mode 100644 packages/services/virtual_socket/src/state.rs delete mode 100644 packages/services/virtual_socket/src/state/connector.rs delete mode 100644 packages/services/virtual_socket/src/state/listener.rs delete mode 100644 packages/services/virtual_socket/src/state/socket.rs delete mode 100644 packages/services/virtual_socket/src/state/stream.rs create mode 100644 packages/services/virtual_socket/src/vnet.rs create mode 100644 packages/services/virtual_socket/src/vnet/async_queue.rs create mode 100644 packages/services/virtual_socket/src/vnet/internal.rs create mode 100644 packages/services/virtual_socket/src/vnet/udp_socket.rs diff --git a/examples/examples/tcp_tunnel.rs b/examples/examples/tcp_tunnel.rs index 7f8bb0c4..04b66a37 100644 --- a/examples/examples/tcp_tunnel.rs +++ b/examples/examples/tcp_tunnel.rs @@ -1,17 +1,18 @@ -use async_std::io::ReadExt; -use async_std::io::WriteExt; use async_std::net::TcpListener; use async_std::net::TcpStream; -use async_std::net::UdpSocket; -use atm0s_sdn::compose_transport_desp::select; -use atm0s_sdn::compose_transport_desp::FutureExt; use atm0s_sdn::convert_enum; +use atm0s_sdn::virtual_socket::create_vnet; +use atm0s_sdn::virtual_socket::make_insecure_quinn_client; +use atm0s_sdn::virtual_socket::make_insecure_quinn_server; +use atm0s_sdn::virtual_socket::quinn::Connection; +use atm0s_sdn::virtual_socket::quinn::Endpoint; +use atm0s_sdn::virtual_socket::quinn::RecvStream; +use atm0s_sdn::virtual_socket::quinn::SendStream; +use atm0s_sdn::virtual_socket::vnet_addr; +use atm0s_sdn::virtual_socket::VirtualNet; use atm0s_sdn::NodeId; use atm0s_sdn::SharedRouter; use atm0s_sdn::SystemTimer; -use atm0s_sdn::VirtualSocketBehavior; -use atm0s_sdn::VirtualSocketSdk; -use atm0s_sdn::VirtualStream; use atm0s_sdn::{KeyValueBehavior, NodeAddr, NodeAddrBuilder, UdpTransport}; use atm0s_sdn::{KeyValueBehaviorEvent, KeyValueHandlerEvent, KeyValueSdkEvent}; use atm0s_sdn::{LayersSpreadRouterSyncBehavior, LayersSpreadRouterSyncBehaviorEvent, LayersSpreadRouterSyncHandlerEvent}; @@ -19,7 +20,7 @@ use atm0s_sdn::{ManualBehavior, ManualBehaviorConf, ManualBehaviorEvent, ManualH use atm0s_sdn::{NetworkPlane, NetworkPlaneConfig}; use clap::Parser; use clap::Subcommand; -use std::collections::HashMap; +use std::error::Error; use std::net::SocketAddr; use std::sync::Arc; @@ -82,10 +83,6 @@ struct ServerOpts { /// Tunnel dest node_id #[arg(env, long)] dest: NodeId, - - /// Tunnel is encrypted or not - #[arg(env, long)] - secure: bool, } #[derive(Parser, Debug, Clone)] @@ -95,43 +92,44 @@ struct AgentOpts { target: SocketAddr, } -async fn run_server(sdk: VirtualSocketSdk, opts: ServerOpts) { +async fn open_tunnel_to(client: &Endpoint, addr: SocketAddr) -> Result<(Connection, SendStream, RecvStream), Box> { + let connection = client.connect(addr, "localhost").unwrap().await?; + let (send, recv) = connection.open_bi().await?; + Ok((connection, send, recv)) +} + +async fn run_server(sdk: VirtualNet, opts: ServerOpts) { let listener = TcpListener::bind(opts.listen).await.expect("Should bind"); - while let Ok((mut stream, remote_addr)) = listener.accept().await { + while let Ok((stream, remote_addr)) = listener.accept().await { log::info!("[TcpTunnel][Server] incomming conn from {}", remote_addr); - let connector = sdk.connector(); + let client = make_insecure_quinn_client(sdk.create_udp_socket(0, 100).unwrap()).unwrap(); + let client_local = client.local_addr().expect(""); async_std::task::spawn(async move { log::info!("[TcpTunnel][Server] connecting to dest node {}", opts.dest); - match connector.connect_to(opts.secure, opts.dest, "TUNNEL_APP", HashMap::new()).await { - Ok(socket_relay) => { - log::info!("[TcpTunnel][Server] connected to dest node {} remote {:?}", opts.dest, socket_relay.remote()); - let mut target = VirtualStream::new(socket_relay); - let mut buf1 = [0; 4096]; - let mut buf2 = [0; 4096]; - loop { - select! { - e = stream.read(&mut buf1).fuse() => { - if let Ok(len) = e { - if len == 0 { - break; - } - target.write_all(&buf1[..len]).await.expect("Should write"); - } else { - break; - } - }, - e = target.read(&mut buf2).fuse() => { - if let Ok(len) = e { - if len == 0 { - break; - } - stream.write_all(&buf2[..len]).await.expect("Should write"); - } else { - break; - } + match open_tunnel_to(&client, vnet_addr(opts.dest, 80)).await { + Ok((connection, send, mut recv)) => { + let vnet_dest = connection.remote_address(); + log::info!("[TcpTunnel][Server] connected to dest node, pipe {} <==> {} <==> {}", remote_addr, client_local, vnet_dest); + let stream_c = stream.clone(); + let task1 = async_std::task::spawn(async move { + match async_std::io::copy(stream, send).await { + Ok(_) => {} + Err(e) => { + log::info!("[TcpTunnel][Server] copy to dest {} ==> {} ==> {} error {:?}", remote_addr, client_local, vnet_dest, e); } } - } + }); + let task2 = async_std::task::spawn(async move { + match async_std::io::copy(&mut recv, stream_c).await { + Ok(_) => {} + Err(e) => { + log::info!("[TcpTunnel][Server] copy from dest {} ==> {} ==> {} error {:?}", vnet_dest, client_local, remote_addr, e); + } + } + }); + task1.await; + task2.await; + log::info!("[TcpTunnel][Server] disconnected pipe {} <==> {} <==> {}", remote_addr, client_local, vnet_dest); } Err(e) => { log::info!("[TcpTunnel][Server] connect to dest node {} errpr {:?}", opts.dest, e); @@ -141,125 +139,42 @@ async fn run_server(sdk: VirtualSocketSdk, opts: ServerOpts) { } } -async fn run_agent(sdk: VirtualSocketSdk, opts: AgentOpts) { - let mut listener = sdk.listen("TUNNEL_APP"); - while let Some(socket) = listener.recv().await { - log::info!("[TcpTunnel][Agent] incomming conn from {:?}", socket.remote()); - let mut stream = VirtualStream::new(socket); +async fn run_agent(sdk: VirtualNet, opts: AgentOpts) { + let endpoint = make_insecure_quinn_server(sdk.create_udp_socket(80, 200).expect("")).expect(""); + while let Some(connecting) = endpoint.accept().await { + log::info!("[TcpTunnel][Agent] incomming conn from {:?}", connecting.remote_address()); async_std::task::spawn(async move { - log::info!("[TcpTunnel][Agent] connecting to target {}", opts.target); + let connection = connecting.await.expect("Should accept"); + let (send, recv) = connection.accept_bi().await.expect("Should open bi"); + let vnet_remote = connection.remote_address(); + log::info!("[TcpTunnel][Agent] vnet incomming {} connecting to target {}", vnet_remote, opts.target); match TcpStream::connect(&opts.target).await { - Ok(mut target) => { - log::info!("[TcpTunnel][Agent] connected to target {}", opts.target); - let mut buf1 = [0; 4096]; - let mut buf2 = [0; 4096]; - loop { - select! { - e = stream.read(&mut buf1).fuse() => { - if let Ok(len) = e { - if len == 0 { - break; - } - target.write_all(&buf1[..len]).await.expect("Should write"); - } else { - break; - } - }, - e = target.read(&mut buf2).fuse() => { - if let Ok(len) = e { - if len == 0 { - break; - } - stream.write_all(&buf2[..len]).await.expect("Should write"); - } else { - break; - } + Ok(target) => { + let local_addr = target.local_addr().expect(""); + log::info!("[TcpTunnel][Agent] connected to target, pipe {} <==> {} <==> {}", opts.target, local_addr, vnet_remote); + let target_c = target.clone(); + let task1 = async_std::task::spawn(async move { + match async_std::io::copy(target, send).await { + Ok(_) => {} + Err(e) => { + log::info!("[TcpTunnel][Agent] copy from target {} ==> {} ==> {} error {:?}", opts.target, local_addr, vnet_remote, e); } } - } - } - Err(e) => { - log::info!("[TcpTunnel][Agent] connect to target {} error {:?}", opts.target, e); - } - } - }); - } -} - -async fn run_udp_server(sdk: VirtualSocketSdk, opts: ServerOpts) { - let udp_server = UdpSocket::bind(opts.listen).await.expect("Should bind"); - log::info!("[UdpTunnel][Server] listen on {}", opts.listen); - let mut buf1 = [0; 1500]; - let (_, remote_addr) = udp_server.peek_from(&mut buf1).await.expect("Should peek"); - udp_server.connect(remote_addr).await.expect("Should connect"); - log::info!("[UdpTunnel][Server] incomming conn from {}", remote_addr); - log::info!("[UdpTunnel][Server] connecting to dest node {}", opts.dest); - match sdk.connector().connect_to(opts.secure, opts.dest, "TUNNEL_APP_UDP", HashMap::new()).await { - Ok(mut socket_relay) => { - log::info!("[UdpTunnel][Server] connected to dest node {} remote {:?}", opts.dest, socket_relay.remote()); - loop { - select! { - e = udp_server.recv(&mut buf1).fuse() => { - if let Ok(len) = e { - if len == 0 { - break; + }); + let task2 = async_std::task::spawn(async move { + match async_std::io::copy(recv, target_c).await { + Ok(_) => {} + Err(e) => { + log::info!("[TcpTunnel][Agent] copy to target {} ==> {} ==> {} error {:?}", vnet_remote, local_addr, opts.target, e); } - socket_relay.write(&buf1[..len]).expect("Should write"); - } else { - break; - } - }, - e = socket_relay.read().fuse() => { - if let Some(buf) = e { - udp_server.send(&buf).await.expect("Should write"); - } else { - break; } - } - } - } - } - Err(e) => { - log::info!("[UdpTunnel][Server] connect to dest node {} errpr {:?}", opts.dest, e); - } - } -} - -async fn run_udp_agent(sdk: VirtualSocketSdk, opts: AgentOpts) { - let mut listener = sdk.listen("TUNNEL_APP_UDP"); - while let Some(mut socket) = listener.recv().await { - log::info!("[UdpTunnel][Agent] incomming conn from {:?}", socket.remote()); - async_std::task::spawn(async move { - log::info!("[UdpTunnel][Agent] connecting to target {}", opts.target); - let udp_socket = UdpSocket::bind("0.0.0.0:0").await.expect("Should bind"); - match udp_socket.connect(&opts.target).await { - Ok(_) => { - log::info!("[UdpTunnel][Agent] connected to target {}", opts.target); - let mut buf2 = [0; 1500]; - loop { - select! { - e = socket.read().fuse() => { - if let Some(buf) = e { - udp_socket.send(&buf).await; - } else { - break; - } - }, - e = udp_socket.recv(&mut buf2).fuse() => { - if let Ok(len) = e { - if len == 0 { - break; - } - socket.write(&buf2[..len]).expect("Should write"); - } else { - break; - } - } - } - } + }); + task1.await; + task2.await; + log::info!("[TcpTunnel][Agent] disconnected {} <==> {} <==> {}", opts.target, local_addr, vnet_remote); } Err(e) => { - log::info!("[UdpTunnel][Agent] connect to target {} error {:?}", opts.target, e); + log::info!("[TcpTunnel][Agent] connect to target {} error {:?}", opts.target, e); } } }); @@ -293,10 +208,11 @@ async fn main() { let spreads_layer_router: LayersSpreadRouterSyncBehavior = LayersSpreadRouterSyncBehavior::new(router.clone()); let key_value = KeyValueBehavior::new(args.node_id, 10000, None); - let (virtual_socket, virtual_socket_sdk) = VirtualSocketBehavior::new(args.node_id); + let router = Arc::new(router); + let (virtual_socket, virtual_socket_sdk) = create_vnet(args.node_id, router.clone()); let plan_cfg = NetworkPlaneConfig { - router: Arc::new(router), + router, node_id: args.node_id, tick_ms: 1000, behaviors: vec![Box::new(manual), Box::new(spreads_layer_router), Box::new(key_value), Box::new(virtual_socket)], @@ -315,11 +231,6 @@ async fn main() { async_std::task::spawn(async move { run_server(sdk_c, opts_c).await; }); - let sdk_c = virtual_socket_sdk.clone(); - let opts_c = opts.clone(); - async_std::task::spawn(async move { - run_udp_server(sdk_c, opts_c).await; - }); } Mode::Agent(opts) => { let sdk_c = virtual_socket_sdk.clone(); @@ -327,11 +238,6 @@ async fn main() { async_std::task::spawn(async move { run_agent(sdk_c, opts_c).await; }); - let sdk_c = virtual_socket_sdk.clone(); - let opts_c = opts.clone(); - async_std::task::spawn(async move { - run_udp_agent(sdk_c, opts_c).await; - }); } } diff --git a/packages/integration_tests/src/virtual_socket.rs b/packages/integration_tests/src/virtual_socket.rs index 6abbda9b..d19479db 100644 --- a/packages/integration_tests/src/virtual_socket.rs +++ b/packages/integration_tests/src/virtual_socket.rs @@ -1,12 +1,13 @@ #[cfg(test)] mod test { - use std::{collections::HashMap, sync::Arc, time::Duration}; + use std::{net::SocketAddrV4, sync::Arc, time::Duration}; use async_std::task::JoinHandle; use atm0s_sdn::{ - convert_enum, KeyValueBehaviorEvent, KeyValueHandlerEvent, KeyValueSdkEvent, LayersSpreadRouterSyncBehavior, LayersSpreadRouterSyncBehaviorEvent, LayersSpreadRouterSyncHandlerEvent, - ManualBehavior, ManualBehaviorConf, ManualBehaviorEvent, ManualHandlerEvent, NetworkPlane, NetworkPlaneConfig, NodeAddr, NodeAddrBuilder, NodeId, SharedRouter, SystemTimer, - VirtualSocketBehavior, VirtualSocketSdk, VirtualStream, + convert_enum, + virtual_socket::{create_vnet, make_insecure_quinn_client, make_insecure_quinn_server, vnet_addr, vnet_addr_v4, VirtualNet, VirtualSocketPkt}, + KeyValueBehaviorEvent, KeyValueHandlerEvent, KeyValueSdkEvent, LayersSpreadRouterSyncBehavior, LayersSpreadRouterSyncBehaviorEvent, LayersSpreadRouterSyncHandlerEvent, ManualBehavior, + ManualBehaviorConf, ManualBehaviorEvent, ManualHandlerEvent, NetworkPlane, NetworkPlaneConfig, NodeAddr, NodeAddrBuilder, NodeId, SharedRouter, SystemTimer, }; use atm0s_sdn_transport_vnet::VnetEarth; @@ -29,7 +30,7 @@ mod test { KeyValue(KeyValueSdkEvent), } - async fn run_node(vnet: Arc, node_id: NodeId, seeds: Vec) -> (VirtualSocketSdk, NodeAddr, JoinHandle<()>) { + async fn run_node(vnet: Arc, node_id: NodeId, seeds: Vec) -> (VirtualNet, NodeAddr, JoinHandle<()>) { log::info!("Run node {} connect to {:?}", node_id, seeds); let node_addr = Arc::new(NodeAddrBuilder::new(node_id)); let transport = Box::new(atm0s_sdn_transport_vnet::VnetTransport::new(vnet, node_addr.addr())); @@ -44,8 +45,9 @@ mod test { connect_tags: vec![], }); - let (virtual_socket_behaviour, virtual_socket_sdk) = VirtualSocketBehavior::new(node_id); let router_sync_behaviour = LayersSpreadRouterSyncBehavior::new(router.clone()); + let router = Arc::new(router); + let (virtual_socket_behaviour, virtual_socket_sdk) = create_vnet(node_id, router.clone()); let mut plane = NetworkPlane::::new(NetworkPlaneConfig { node_id, @@ -53,7 +55,7 @@ mod test { behaviors: vec![Box::new(virtual_socket_behaviour), Box::new(router_sync_behaviour), Box::new(manual)], transport, timer, - router: Arc::new(router.clone()), + router, }); let join = async_std::task::spawn(async move { @@ -72,88 +74,86 @@ mod test { let (sdk, _addr, join) = run_node(vnet.clone(), node_id, vec![]).await; async_std::task::sleep(Duration::from_millis(300)).await; - let mut listener = sdk.listen("DEMO"); - let connector = sdk.connector(); - async_std::task::spawn(async move { - let mut socket = connector - .connect_to(true, node_id, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) - .await - .expect("Should connect"); - socket.write(&vec![1, 2, 3]).expect("Should write"); - }); - - if let Some(mut socket) = listener.recv().await { - assert_eq!(socket.meta(), &HashMap::from([("k1".to_string(), "k2".to_string())])); - assert_eq!(socket.read().await.expect("Should read"), vec![1, 2, 3]); - assert_eq!(socket.read().await, None); - } + let server = sdk.create_udp_socket(1000, 10).expect(""); + let client = sdk.create_udp_socket(0, 10).expect(""); + client.send_to(vnet_addr_v4(node_id, 1000), &vec![1, 2, 3], None).expect("Should write"); + assert_eq!( + server.try_recv_from(), + Some(VirtualSocketPkt { + src: SocketAddrV4::new(node_id.into(), client.local_port()), + payload: vec![1, 2, 3], + ecn: None, + }) + ); + assert_eq!(server.try_recv_from(), None); join.cancel().await; } #[async_std::test] - async fn local_stream() { - let node_id = 1; + async fn remote_socket() { let vnet = Arc::new(VnetEarth::default()); - let (sdk, _addr, join) = run_node(vnet.clone(), node_id, vec![]).await; + + let node_id1 = 1; + let node_id2 = 2; + let (sdk1, addr1, join1) = run_node(vnet.clone(), node_id1, vec![]).await; + let (sdk2, _addr2, join2) = run_node(vnet.clone(), node_id2, vec![addr1]).await; async_std::task::sleep(Duration::from_millis(300)).await; - let mut listener = sdk.listen("DEMO"); - let connector = sdk.connector(); - async_std::task::spawn(async move { - let socket = connector - .connect_to(true, node_id, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) - .await - .expect("Should connect"); - let mut stream = VirtualStream::new(socket); - assert_eq!(stream.write_all(&vec![1, 2, 3]).await.expect("Should send"), ()); - async_std::task::sleep(Duration::from_secs(1)).await; - }); + let server1 = sdk1.create_udp_socket(1000, 10).expect(""); + let client2 = sdk2.create_udp_socket(0, 10).expect(""); + client2.send_to(vnet_addr_v4(node_id1, 1000), &vec![1, 2, 3], None).expect("Should write"); - if let Some(socket) = listener.recv().await { - let mut stream = VirtualStream::new(socket); - assert_eq!(stream.meta(), &HashMap::from([("k1".to_string(), "k2".to_string())])); - let mut buf = vec![0; 1500]; - assert_eq!(stream.read(&mut buf).await.expect("Should read"), 3); - assert_eq!(buf[..3], [1, 2, 3]); - assert_eq!(stream.read(&mut buf).await.expect("Should read"), 0); - } + assert_eq!( + server1.recv_from().await, + Some(VirtualSocketPkt { + src: SocketAddrV4::new(node_id2.into(), client2.local_port()), + payload: vec![1, 2, 3], + ecn: None, + }) + ); - join.cancel().await; + join1.cancel().await; + join2.cancel().await; } #[async_std::test] - async fn remote_socket() { + async fn local_quinn() { + let node_id = 1; let vnet = Arc::new(VnetEarth::default()); - - let node_id1 = 1; - let node_id2 = 2; - let (sdk1, addr1, join1) = run_node(vnet.clone(), node_id1, vec![]).await; - let (sdk2, _addr2, join2) = run_node(vnet.clone(), node_id2, vec![addr1]).await; + let (sdk, _addr, join) = run_node(vnet.clone(), node_id, vec![]).await; async_std::task::sleep(Duration::from_millis(300)).await; - let mut listener1 = sdk1.listen("DEMO"); - let connector2 = sdk2.connector(); + let server = make_insecure_quinn_server(sdk.create_udp_socket(1000, 10).expect("")).expect(""); + let client = make_insecure_quinn_client(sdk.create_udp_socket(0, 10).expect("")).expect(""); + async_std::task::spawn(async move { - let mut socket = connector2 - .connect_to(true, node_id1, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) - .await - .expect("Should connect"); - socket.write(&vec![1, 2, 3]).expect("Should write"); + let connection = client.connect(vnet_addr(node_id, 1000), "localhost").unwrap().await.unwrap(); + let (mut send, mut recv) = connection.open_bi().await.unwrap(); + send.write(&vec![4, 5, 6]).await.unwrap(); + let mut buf = vec![0; 3]; + let len = recv.read(&mut buf).await.unwrap(); + assert_eq!(buf, vec![1, 2, 3]); + assert_eq!(len, Some(3)); + send.write(&vec![7, 8, 9]).await.unwrap(); + send.finish().await.unwrap(); }); - if let Some(mut socket) = listener1.recv().await { - assert_eq!(socket.meta(), &HashMap::from([("k1".to_string(), "k2".to_string())])); - assert_eq!(socket.read().await.expect("Should read"), vec![1, 2, 3]); - assert_eq!(socket.read().await, None); - } - - join1.cancel().await; - join2.cancel().await; + let connection = server.accept().await.unwrap().await.unwrap(); + let (mut send, mut recv) = connection.accept_bi().await.unwrap(); + let mut buf = vec![0; 3]; + let len = recv.read(&mut buf).await.unwrap(); + assert_eq!(buf, vec![4, 5, 6]); + assert_eq!(len, Some(3)); + send.write(&vec![1, 2, 3]).await.unwrap(); + let len = recv.read(&mut buf).await.unwrap(); + assert_eq!(buf, vec![7, 8, 9]); + assert_eq!(len, Some(3)); + join.cancel().await; } #[async_std::test] - async fn remote_stream() { + async fn remote_quinn() { let vnet = Arc::new(VnetEarth::default()); let node_id1 = 1; @@ -162,26 +162,31 @@ mod test { let (sdk2, _addr2, join2) = run_node(vnet.clone(), node_id2, vec![addr1]).await; async_std::task::sleep(Duration::from_millis(300)).await; - let mut listener1 = sdk1.listen("DEMO"); - let connector2 = sdk2.connector(); + let server1 = make_insecure_quinn_server(sdk1.create_udp_socket(1000, 10).expect("")).expect(""); + let client2 = make_insecure_quinn_client(sdk2.create_udp_socket(0, 10).expect("")).expect(""); + async_std::task::spawn(async move { - let socket = connector2 - .connect_to(true, node_id1, "DEMO", HashMap::from([("k1".to_string(), "k2".to_string())])) - .await - .expect("Should connect"); - let mut stream = VirtualStream::new(socket); - assert_eq!(stream.write_all(&vec![1, 2, 3]).await.expect("Should send"), ()); - async_std::task::sleep(Duration::from_secs(1)).await; + let connection = client2.connect(vnet_addr(node_id1, 1000), "localhost").unwrap().await.unwrap(); + let (mut send, mut recv) = connection.open_bi().await.unwrap(); + send.write(&vec![4, 5, 6]).await.unwrap(); + let mut buf = vec![0; 3]; + let len = recv.read(&mut buf).await.unwrap(); + assert_eq!(buf, vec![1, 2, 3]); + assert_eq!(len, Some(3)); + send.write(&vec![7, 8, 9]).await.unwrap(); + send.finish().await.unwrap(); }); - if let Some(socket) = listener1.recv().await { - let mut stream = VirtualStream::new(socket); - assert_eq!(stream.meta(), &HashMap::from([("k1".to_string(), "k2".to_string())])); - let mut buf = vec![0; 1500]; - assert_eq!(stream.read(&mut buf).await.expect("Should read"), 3); - assert_eq!(buf[..3], [1, 2, 3]); - assert_eq!(stream.read(&mut buf).await.expect("Should read"), 0); - } + let connection = server1.accept().await.unwrap().await.unwrap(); + let (mut send, mut recv) = connection.accept_bi().await.unwrap(); + let mut buf = vec![0; 3]; + let len = recv.read(&mut buf).await.unwrap(); + assert_eq!(buf, vec![4, 5, 6]); + assert_eq!(len, Some(3)); + send.write(&vec![1, 2, 3]).await.unwrap(); + let len = recv.read(&mut buf).await.unwrap(); + assert_eq!(buf, vec![7, 8, 9]); + assert_eq!(len, Some(3)); join1.cancel().await; join2.cancel().await; diff --git a/packages/runner/src/lib.rs b/packages/runner/src/lib.rs index eaba4a46..530f6304 100644 --- a/packages/runner/src/lib.rs +++ b/packages/runner/src/lib.rs @@ -34,7 +34,7 @@ pub use atm0s_sdn_pub_sub::{ pub use atm0s_sdn_rpc::{RpcBehavior, RpcBox, RpcEmitter, RpcError, RpcHandler, RpcIdGenerate, RpcMsg, RpcMsgParam, RpcQueue, RpcRequest}; #[cfg(feature = "virtual-socket")] -pub use atm0s_sdn_virtual_socket::{VirtualSocket, VirtualSocketBehavior, VirtualSocketSdk, VirtualStream}; +pub use atm0s_sdn_virtual_socket as virtual_socket; #[cfg(feature = "transport-tcp")] pub use atm0s_sdn_transport_tcp::TcpTransport; diff --git a/packages/services/virtual_socket/Cargo.toml b/packages/services/virtual_socket/Cargo.toml index 99f09484..da182ce8 100644 --- a/packages/services/virtual_socket/Cargo.toml +++ b/packages/services/virtual_socket/Cargo.toml @@ -17,5 +17,9 @@ futures = "0.3" async-trait = { workspace = true } async-std = { workspace = true } parking_lot = { workspace = true } -serde = { workspace = true } -kcp = "0.5.3" +quinn = { version = "0.10.2", default-features = false, features = ["runtime-async-std", "log", "futures-io"], optional = true } +quinn-plaintext = "0.2.0" + +[features] +default = ["quic"] +quic = ["quinn"] diff --git a/packages/services/virtual_socket/src/behavior.rs b/packages/services/virtual_socket/src/behavior.rs index eeb80f05..30781139 100644 --- a/packages/services/virtual_socket/src/behavior.rs +++ b/packages/services/virtual_socket/src/behavior.rs @@ -4,26 +4,18 @@ use atm0s_sdn_identity::{ConnId, NodeId}; use atm0s_sdn_network::{ behaviour::{BehaviorContext, ConnectionHandler, NetworkBehavior, NetworkBehaviorAction}, msg::TransportMsg, - transport::{ConnectionEvent, ConnectionRejectReason, ConnectionSender, OutgoingConnectionError}, + transport::{ConnectionRejectReason, ConnectionSender, OutgoingConnectionError}, }; -use parking_lot::RwLock; -use crate::{ - handler::VirtualSocketHandler, - state::{process_incoming_data, State}, - VirtualSocketSdk, VIRTUAL_SOCKET_SERVICE_ID, -}; +use crate::{handler::VirtualSocketHandler, vnet::internal::VirtualNetInternal, VIRTUAL_SOCKET_SERVICE_ID}; pub struct VirtualSocketBehavior { - node_id: NodeId, - state: Arc>, + internal: VirtualNetInternal, } impl VirtualSocketBehavior { - pub fn new(node_id: NodeId) -> (Self, VirtualSocketSdk) { - log::info!("[VirtualSocketBehavior] create new on node: {}", node_id); - let state = Arc::new(RwLock::new(State::default())); - (Self { node_id, state: state.clone() }, VirtualSocketSdk::new(state)) + pub fn new(internal: VirtualNetInternal) -> Self { + Self { internal } } } @@ -32,20 +24,16 @@ impl NetworkBehavior for VirtualSocketBehavior { VIRTUAL_SOCKET_SERVICE_ID } - fn on_started(&mut self, ctx: &BehaviorContext, _now_ms: u64) { - self.state.write().set_awaker(ctx.awaker.clone()); - } + fn on_started(&mut self, _ctx: &BehaviorContext, _now_ms: u64) {} - fn on_tick(&mut self, _ctx: &BehaviorContext, now_ms: u64, _interval_ms: u64) { - self.state.write().on_tick(now_ms); - } + fn on_tick(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _interval_ms: u64) {} fn on_awake(&mut self, _ctx: &BehaviorContext, _now_ms: u64) {} fn on_sdk_msg(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _from_service: u8, _event: SE) {} - fn on_local_msg(&mut self, _ctx: &BehaviorContext, now_ms: u64, msg: TransportMsg) { - process_incoming_data(now_ms, &self.state, ConnectionEvent::Msg(msg)); + fn on_local_msg(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _msg: TransportMsg) { + panic!("Should not happend"); } fn check_incoming_connection(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node: NodeId, _conn_id: ConnId) -> Result<(), ConnectionRejectReason> { @@ -56,27 +44,33 @@ impl NetworkBehavior for VirtualSocketBehavior { Ok(()) } - fn on_incoming_connection_connected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _conn: Arc) -> Option>> { - Some(Box::new(VirtualSocketHandler { state: self.state.clone() })) + fn on_incoming_connection_connected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, conn: Arc) -> Option>> { + self.internal.add_conn(conn); + Some(Box::new(VirtualSocketHandler { internal: self.internal.clone() })) } - fn on_outgoing_connection_connected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _conn: Arc) -> Option>> { - Some(Box::new(VirtualSocketHandler { state: self.state.clone() })) + fn on_outgoing_connection_connected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, conn: Arc) -> Option>> { + self.internal.add_conn(conn); + Some(Box::new(VirtualSocketHandler { internal: self.internal.clone() })) } - fn on_incoming_connection_disconnected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, _conn_id: ConnId) {} + fn on_incoming_connection_disconnected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, conn_id: ConnId) { + self.internal.remove_conn(conn_id); + } - fn on_outgoing_connection_disconnected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, _conn_id: ConnId) {} + fn on_outgoing_connection_disconnected(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, conn_id: ConnId) { + self.internal.remove_conn(conn_id); + } - fn on_outgoing_connection_error(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, _conn_id: ConnId, _err: &OutgoingConnectionError) {} + fn on_outgoing_connection_error(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, conn_id: ConnId, _err: &OutgoingConnectionError) { + self.internal.remove_conn(conn_id); + } fn on_handler_event(&mut self, _ctx: &BehaviorContext, _now_ms: u64, _node_id: NodeId, _conn_id: ConnId, _event: BE) {} fn on_stopped(&mut self, _ctx: &BehaviorContext, _now_ms: u64) {} fn pop_action(&mut self) -> Option> { - let (socket_id, out) = self.state.write().pop_outgoing()?; - let msg = out.into_transport_msg(self.node_id, socket_id.node_id(), socket_id.client_id()); - Some(NetworkBehaviorAction::ToNet(msg)) + None } } diff --git a/packages/services/virtual_socket/src/handler.rs b/packages/services/virtual_socket/src/handler.rs index ef11e4b4..4cd59fe3 100644 --- a/packages/services/virtual_socket/src/handler.rs +++ b/packages/services/virtual_socket/src/handler.rs @@ -1,16 +1,13 @@ -use std::sync::Arc; - use atm0s_sdn_identity::{ConnId, NodeId}; use atm0s_sdn_network::{ behaviour::{ConnectionContext, ConnectionHandler, ConnectionHandlerAction}, transport::ConnectionEvent, }; -use parking_lot::RwLock; -use crate::state::{process_incoming_data, State}; +use crate::vnet::internal::VirtualNetInternal; pub struct VirtualSocketHandler { - pub(crate) state: Arc>, + pub(crate) internal: VirtualNetInternal, } impl ConnectionHandler for VirtualSocketHandler { @@ -24,8 +21,13 @@ impl ConnectionHandler for VirtualSocketHandler { fn on_awake(&mut self, _ctx: &ConnectionContext, _now_ms: u64) {} /// Called when an event occurs on the connection. - fn on_event(&mut self, _ctx: &ConnectionContext, now_ms: u64, event: ConnectionEvent) { - process_incoming_data(now_ms, &self.state, event); + fn on_event(&mut self, _ctx: &ConnectionContext, _now_ms: u64, event: ConnectionEvent) { + match event { + ConnectionEvent::Msg(msg) => { + self.internal.on_incomming(msg); + } + _ => {} + } } /// Called when an event occurs on another handler. diff --git a/packages/services/virtual_socket/src/lib.rs b/packages/services/virtual_socket/src/lib.rs index a70aaf99..e151bb34 100644 --- a/packages/services/virtual_socket/src/lib.rs +++ b/packages/services/virtual_socket/src/lib.rs @@ -1,11 +1,36 @@ +use std::{ + net::{SocketAddr, SocketAddrV4}, + sync::Arc, +}; + +use atm0s_sdn_identity::NodeId; +use atm0s_sdn_router::RouterTable; +use behavior::VirtualSocketBehavior; + pub(crate) const VIRTUAL_SOCKET_SERVICE_ID: u8 = 6; mod behavior; mod handler; -mod msg; -mod sdk; -pub(crate) mod state; +#[cfg(feature = "quinn")] +mod quinn_utils; +mod vnet; + +#[cfg(feature = "quinn")] +pub use quinn; +#[cfg(feature = "quinn")] +pub use quinn_utils::{make_insecure_quinn_client, make_insecure_quinn_server}; +pub use vnet::{udp_socket::VirtualUdpSocket, VirtualNet, VirtualNetError, VirtualSocketPkt}; + +pub fn create_vnet(node_id: NodeId, router: Arc) -> (VirtualSocketBehavior, vnet::VirtualNet) { + let (net, interal) = vnet::VirtualNet::new(node_id, router); + let behavior = VirtualSocketBehavior::new(interal); + (behavior, net) +} + +pub fn vnet_addr_v4(node_id: NodeId, port: u16) -> SocketAddrV4 { + SocketAddrV4::new(node_id.into(), port) +} -pub use behavior::VirtualSocketBehavior; -pub use sdk::VirtualSocketSdk; -pub use state::{socket::VirtualSocket, stream::VirtualStream, VirtualSocketConnectResult}; +pub fn vnet_addr(node_id: NodeId, port: u16) -> SocketAddr { + SocketAddr::V4(vnet_addr_v4(node_id, port)) +} diff --git a/packages/services/virtual_socket/src/msg.rs b/packages/services/virtual_socket/src/msg.rs deleted file mode 100644 index 3579dfad..00000000 --- a/packages/services/virtual_socket/src/msg.rs +++ /dev/null @@ -1,25 +0,0 @@ -use atm0s_sdn_identity::NodeId; -use serde::{Deserialize, Serialize}; - -use std::collections::HashMap; - -#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] -pub enum VirtualSocketControlMsg { - ConnectRequest(String, bool, HashMap), - ConnectReponse(bool), - ConnectingPing, - ConnectingPong, - ConnectionClose(), -} - -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct SocketId(pub NodeId, pub u32); -impl SocketId { - pub fn node_id(&self) -> NodeId { - self.0 - } - - pub fn client_id(&self) -> u32 { - self.1 - } -} diff --git a/packages/services/virtual_socket/src/quinn_utils.rs b/packages/services/virtual_socket/src/quinn_utils.rs new file mode 100644 index 00000000..19787c81 --- /dev/null +++ b/packages/services/virtual_socket/src/quinn_utils.rs @@ -0,0 +1,17 @@ +use std::sync::Arc; + +use quinn::{AsyncStdRuntime, Endpoint, EndpointConfig}; + +use crate::VirtualUdpSocket; + +pub fn make_insecure_quinn_server(socket: VirtualUdpSocket) -> Result { + let runtime = Arc::new(AsyncStdRuntime); + Endpoint::new_with_abstract_socket(EndpointConfig::default(), Some(quinn_plaintext::server_config()), socket, runtime) +} + +pub fn make_insecure_quinn_client(socket: VirtualUdpSocket) -> Result { + let runtime = Arc::new(AsyncStdRuntime); + let mut endpoint = Endpoint::new_with_abstract_socket(EndpointConfig::default(), None, socket, runtime)?; + endpoint.set_default_client_config(quinn_plaintext::client_config()); + Ok(endpoint) +} diff --git a/packages/services/virtual_socket/src/sdk.rs b/packages/services/virtual_socket/src/sdk.rs deleted file mode 100644 index 37504501..00000000 --- a/packages/services/virtual_socket/src/sdk.rs +++ /dev/null @@ -1,26 +0,0 @@ -use std::sync::Arc; - -use parking_lot::RwLock; - -use crate::state::{connector::VirtualSocketConnector, listener::VirtualSocketListener, State}; - -#[derive(Clone)] -pub struct VirtualSocketSdk { - state: Arc>, -} - -impl VirtualSocketSdk { - pub fn new(state: Arc>) -> Self { - Self { state } - } - - pub fn connector(&self) -> VirtualSocketConnector { - VirtualSocketConnector { state: self.state.clone() } - } - - pub fn listen(&self, id: &str) -> VirtualSocketListener { - log::info!("[VirtualSocketSdk] listen on: {}", id); - let rx = self.state.write().new_listener(id); - VirtualSocketListener { rx, state: self.state.clone() } - } -} diff --git a/packages/services/virtual_socket/src/state.rs b/packages/services/virtual_socket/src/state.rs deleted file mode 100644 index 987a52e7..00000000 --- a/packages/services/virtual_socket/src/state.rs +++ /dev/null @@ -1,328 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use async_std::channel::{Receiver, Sender}; -use atm0s_sdn_identity::NodeId; -use atm0s_sdn_network::transport::ConnectionEvent; -use atm0s_sdn_utils::{awaker::Awaker, error_handle::ErrorUtils, option_handle::OptionUtils, vec_dequeue::VecDeque}; -use parking_lot::RwLock; - -use crate::msg::{SocketId, VirtualSocketControlMsg}; - -use self::socket::{VirtualSocketBuilder, VirtualSocketEvent, CONTROL_CLIENT_META, CONTROL_SERVER_META, DATA_CLIENT_META, DATA_SERVER_META}; - -pub mod connector; -pub mod listener; -pub mod socket; -pub mod stream; - -const CONNECT_TIMEOUT_MS: u64 = 10000; -const PING_INTERVAL_MS: u64 = 1000; -const PING_TIMEOUT_MS: u64 = 10000; - -enum OutgoingState { - Connecting { - started_at: u64, - secure: bool, - res_tx: Sender, - }, - Connected { - last_ping: u64, - pong_time: u64, - tx: Sender>, - }, -} - -struct IncommingState { - ping_time: u64, - tx: Sender>, -} - -pub enum VirtualSocketConnectResult { - Success(VirtualSocketBuilder), - Timeout, - Unreachable, -} - -pub struct State { - client_idx: u32, - last_tick_ms: u64, - listeners: HashMap>, - outgoings: HashMap, - incomings: HashMap, - outgoing_queue: VecDeque<(SocketId, VirtualSocketEvent)>, - awaker: Option>, -} - -impl Default for State { - fn default() -> Self { - Self { - client_idx: 0, - last_tick_ms: 0, - listeners: HashMap::new(), - outgoings: HashMap::new(), - incomings: HashMap::new(), - outgoing_queue: VecDeque::new(), - awaker: None, - } - } -} - -impl State { - pub fn set_awaker(&mut self, awaker: Arc) { - self.awaker = Some(awaker); - } - - pub fn new_listener(&mut self, id: &str) -> Receiver { - log::info!("[VirtualSocketState] new listener: {}", id); - let (tx, rx) = async_std::channel::bounded(10); - self.listeners.insert(id.to_string(), tx); - rx - } - - pub fn new_outgoing(&mut self, secure: bool, dest_node_id: NodeId, dest_listener_id: &str, meta: HashMap) -> Option> { - let client_idx = self.client_idx; - self.client_idx += 1; - log::info!("[VirtualSocketState] new outgoing: {}/{} with meta {:?} => idx {}", dest_node_id, dest_listener_id, meta, client_idx); - let socket_id = SocketId(dest_node_id, client_idx); - - let (tx, rx) = async_std::channel::bounded(1); - self.outgoings.insert( - socket_id.clone(), - OutgoingState::Connecting { - secure, - started_at: self.last_tick_ms, - res_tx: tx, - }, - ); - self.outgoing_queue.push_back(( - socket_id, - VirtualSocketEvent::ClientControl(VirtualSocketControlMsg::ConnectRequest(dest_listener_id.to_string(), secure, meta)), - )); - Some(rx) - } - - pub fn on_tick(&mut self, now_ms: u64) { - self.last_tick_ms = now_ms; - - // Remove timed out outgoing connections - let mut to_remove = Vec::new(); - for (socket_id, state) in self.outgoings.iter() { - if let OutgoingState::Connecting { started_at, res_tx: tx, .. } = state { - if now_ms - started_at > CONNECT_TIMEOUT_MS { - log::info!("[VirtualSocketState] outgoing timeout: node {} idx: {}", socket_id.node_id(), socket_id.client_id()); - to_remove.push(socket_id.clone()); - tx.try_send(VirtualSocketConnectResult::Timeout).print_error("Should send timeout to waiting connector"); - } - } - } - for socket_id in to_remove { - self.outgoings.remove(&socket_id).print_none("Should remove timed out outgoing connection"); - } - - // send ping from outgoing sockets - for (socket_id, state) in self.outgoings.iter_mut() { - if let OutgoingState::Connected { last_ping, .. } = state { - if now_ms - *last_ping > PING_INTERVAL_MS { - *last_ping = now_ms; - self.outgoing_queue - .push_back((socket_id.clone(), VirtualSocketEvent::ClientControl(VirtualSocketControlMsg::ConnectingPing))); - } - } - } - - // Remote ping timeout outgoing sockets - let mut to_remove = Vec::new(); - for (socket_id, state) in self.outgoings.iter() { - if let OutgoingState::Connected { pong_time, .. } = state { - if now_ms - *pong_time > PING_TIMEOUT_MS { - log::info!("[VirtualSocketState] outgoing ping timeout: node {} idx: {}", socket_id.node_id(), socket_id.client_id()); - to_remove.push(socket_id.clone()); - } - } - } - for socket_id in to_remove { - self.outgoings.remove(&socket_id).print_none("Should remove timed out outgoing connection"); - } - - // Remove timed out incoming sockets - let mut to_remove = Vec::new(); - for (socket_id, state) in self.incomings.iter() { - if now_ms - state.ping_time > PING_TIMEOUT_MS { - log::info!("[VirtualSocketState] incoming ping timeout: node {} idx: {}", socket_id.node_id(), socket_id.client_id()); - to_remove.push(socket_id.clone()); - } - } - for socket_id in to_remove { - self.incomings.remove(&socket_id).print_none("Should remove timed out incoming connection"); - } - } - - pub fn on_recv_server_data(&self, _now_ms: u64, socket_id: SocketId, data: &[u8]) { - log::debug!("[VirtualSocketState] on_recv_server_data: {:?} {:?}", socket_id, data); - if let Some(OutgoingState::Connected { tx, .. }) = self.outgoings.get(&socket_id) { - tx.try_send(data.to_vec()).ok(); - } - } - - pub fn on_recv_client_data(&self, _now_ms: u64, socket_id: SocketId, data: &[u8]) { - log::debug!("[VirtualSocketState] on_recv_client_data: {:?} {:?}", socket_id, data); - if let Some(state) = self.incomings.get(&socket_id) { - state.tx.try_send(data.to_vec()).ok(); - } - } - - pub fn send_out(&mut self, socket_id: SocketId, event: VirtualSocketEvent) { - log::debug!("[VirtualSocketState] send_out to : {:?} {:?}", socket_id, event); - self.outgoing_queue.push_back((socket_id, event)); - if self.outgoing_queue.len() == 1 { - if let Some(awaker) = self.awaker.as_ref() { - awaker.notify(); - } - } - } - - pub fn on_recv_server_control(&mut self, now_ms: u64, socket_id: SocketId, control: VirtualSocketControlMsg) { - log::debug!("[VirtualSocketState] on_recv_server_control from : {:?} {:?}", socket_id, control); - match control { - VirtualSocketControlMsg::ConnectReponse(success) => { - if let Some(state) = self.outgoings.get_mut(&socket_id) { - if let OutgoingState::Connecting { res_tx, secure, .. } = state { - if success { - let (socket_tx, socket_rx) = async_std::channel::bounded(10); - res_tx - .try_send(VirtualSocketConnectResult::Success(VirtualSocketBuilder::new(true, *secure, socket_id, HashMap::new(), socket_rx))) - .print_error("Should send connect response to waiting connector"); - *state = OutgoingState::Connected { - last_ping: now_ms, - pong_time: now_ms, - tx: socket_tx, - }; - } else { - res_tx - .try_send(VirtualSocketConnectResult::Unreachable) - .print_error("Should send connect response to waiting connector"); - self.outgoings.remove(&socket_id).print_none("Should remove failed outgoing connection"); - }; - } else { - log::warn!("[VirtualSocketState] on_recv_server_control socket already connected: {:?}", socket_id); - } - } else { - log::warn!("[VirtualSocketState] on_recv_server_control socket not found: {:?}", socket_id); - } - } - VirtualSocketControlMsg::ConnectionClose() => { - if self.incomings.remove(&socket_id).is_some() { - log::info!("[VirtualSocketState] closed outgoing socket: {:?}", socket_id); - } - } - VirtualSocketControlMsg::ConnectingPong => { - //update pong time to outgoing sockets - if let Some(state) = self.outgoings.get_mut(&socket_id) { - if let OutgoingState::Connected { pong_time, .. } = state { - *pong_time = now_ms; - } - } - } - _ => { - log::warn!("[VirtualSocketState] on_recv_server_control Unknown control message: {:?}", control); - } - } - } - - pub fn on_recv_client_control(&mut self, now_ms: u64, socket_id: SocketId, control: VirtualSocketControlMsg) { - log::debug!("[VirtualSocketState] on_recv_client_control from : {:?} {:?}", socket_id, control); - match control { - VirtualSocketControlMsg::ConnectRequest(listener_id, secure, meta) => { - if self.incomings.contains_key(&socket_id) { - log::warn!("[VirtualSocketState] on_recv_client_control socket already connected: {:?}", socket_id); - return; - } - if let Some(tx) = self.listeners.get(&listener_id) { - let (socket_tx, socket_rx) = async_std::channel::bounded(10); - self.incomings.insert(socket_id.clone(), IncommingState { ping_time: now_ms, tx: socket_tx }); - tx.try_send(VirtualSocketBuilder::new(false, secure, socket_id.clone(), meta, socket_rx)) - .print_error("Should send new virtual socket to listener"); - self.outgoing_queue - .push_back((socket_id, VirtualSocketEvent::ServerControl(VirtualSocketControlMsg::ConnectReponse(true)))); - } else { - self.outgoing_queue - .push_back((socket_id, VirtualSocketEvent::ServerControl(VirtualSocketControlMsg::ConnectReponse(false)))); - } - } - VirtualSocketControlMsg::ConnectionClose() => { - if self.incomings.remove(&socket_id).is_some() { - log::info!("[VirtualSocketState] closed incoming socket: {:?}", socket_id); - } - } - VirtualSocketControlMsg::ConnectingPing => { - //update ping time to incoming sockets - if let Some(state) = self.incomings.get_mut(&socket_id) { - state.ping_time = now_ms; - } - self.outgoing_queue.push_back((socket_id, VirtualSocketEvent::ServerControl(VirtualSocketControlMsg::ConnectingPong))); - } - _ => { - log::warn!("[VirtualSocketState] on_recv_client_control Unknown control message: {:?}", control); - } - } - } - - pub fn close_socket(&mut self, is_client: bool, socket_id: &SocketId) { - if is_client { - if self.outgoings.remove(socket_id).is_some() { - log::debug!("[VirtualSocketState] will close outgoing socket: {:?}", socket_id); - let socket_id: SocketId = socket_id.clone(); - self.outgoing_queue - .push_back((socket_id, VirtualSocketEvent::ClientControl(VirtualSocketControlMsg::ConnectionClose()))); - } - } else { - if self.incomings.remove(socket_id).is_some() { - log::debug!("[VirtualSocketState] will close incomming socket: {:?}", socket_id); - let socket_id: SocketId = socket_id.clone(); - self.outgoing_queue - .push_back((socket_id, VirtualSocketEvent::ServerControl(VirtualSocketControlMsg::ConnectionClose()))); - } - } - } - - pub fn pop_outgoing(&mut self) -> Option<(SocketId, VirtualSocketEvent)> { - self.outgoing_queue.pop_front() - } -} - -pub fn process_incoming_data(now_ms: u64, state: &RwLock, event: ConnectionEvent) { - if let ConnectionEvent::Msg(data) = event { - if let Some(from) = data.header.from_node { - match data.header.meta { - DATA_CLIENT_META => { - let socket_id = SocketId(from, data.header.stream_id); - state.read().on_recv_client_data(now_ms, socket_id, data.payload()); - } - DATA_SERVER_META => { - let socket_id = SocketId(from, data.header.stream_id); - state.read().on_recv_server_data(now_ms, socket_id, data.payload()); - } - CONTROL_CLIENT_META => { - if let Ok(control) = data.get_payload_bincode::() { - //is control - if let Some(from) = data.header.from_node { - state.write().on_recv_client_control(now_ms, SocketId(from, data.header.stream_id), control); - } - } - } - CONTROL_SERVER_META => { - if let Ok(control) = data.get_payload_bincode::() { - //is control - if let Some(from) = data.header.from_node { - state.write().on_recv_server_control(now_ms, SocketId(from, data.header.stream_id), control); - } - } - } - _ => {} - } - } - } -} - -#[cfg(test)] -mod tests {} diff --git a/packages/services/virtual_socket/src/state/connector.rs b/packages/services/virtual_socket/src/state/connector.rs deleted file mode 100644 index 8bbb772b..00000000 --- a/packages/services/virtual_socket/src/state/connector.rs +++ /dev/null @@ -1,28 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use atm0s_sdn_identity::NodeId; -use parking_lot::RwLock; - -use super::{socket::VirtualSocket, State, VirtualSocketConnectResult}; - -#[derive(Debug)] -pub enum VirtualSocketConnectorError { - Timeout, - Unreachable, -} - -pub struct VirtualSocketConnector { - pub(crate) state: Arc>, -} - -impl VirtualSocketConnector { - pub async fn connect_to(&self, secure: bool, dest: NodeId, listener: &str, meta: HashMap) -> Result { - let rx = self.state.write().new_outgoing(secure, dest, listener, meta).ok_or(VirtualSocketConnectorError::Unreachable)?; - match rx.recv().await { - Ok(VirtualSocketConnectResult::Success(builder)) => Ok(builder.build(self.state.clone())), - Ok(VirtualSocketConnectResult::Timeout) => Err(VirtualSocketConnectorError::Timeout), - Ok(VirtualSocketConnectResult::Unreachable) => Err(VirtualSocketConnectorError::Unreachable), - Err(_) => Err(VirtualSocketConnectorError::Unreachable), - } - } -} diff --git a/packages/services/virtual_socket/src/state/listener.rs b/packages/services/virtual_socket/src/state/listener.rs deleted file mode 100644 index c4379786..00000000 --- a/packages/services/virtual_socket/src/state/listener.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::sync::Arc; - -use parking_lot::RwLock; - -use super::{ - socket::{VirtualSocket, VirtualSocketBuilder}, - State, -}; - -pub struct VirtualSocketListener { - pub(crate) rx: async_std::channel::Receiver, - pub(crate) state: Arc>, -} - -impl VirtualSocketListener { - pub async fn recv(&mut self) -> Option { - let builder = self.rx.recv().await.ok()?; - Some(builder.build(self.state.clone())) - } -} diff --git a/packages/services/virtual_socket/src/state/socket.rs b/packages/services/virtual_socket/src/state/socket.rs deleted file mode 100644 index 21933ab5..00000000 --- a/packages/services/virtual_socket/src/state/socket.rs +++ /dev/null @@ -1,180 +0,0 @@ -use std::io::Write; -use std::{collections::HashMap, sync::Arc}; - -use async_std::channel::Receiver; -use atm0s_sdn_identity::NodeId; -use atm0s_sdn_network::msg::{MsgHeader, TransportMsg}; -use atm0s_sdn_router::RouteRule; -use parking_lot::RwLock; - -use crate::{ - msg::{SocketId, VirtualSocketControlMsg}, - VIRTUAL_SOCKET_SERVICE_ID, -}; - -use super::State; - -pub const CONTROL_CLIENT_META: u8 = 0; -pub const CONTROL_SERVER_META: u8 = 1; -pub const DATA_CLIENT_META: u8 = 2; -pub const DATA_SERVER_META: u8 = 3; - -pub struct VirtualSocketBuilder { - is_client: bool, - secure: bool, - remote: SocketId, - rx: Receiver>, - meta: HashMap, -} - -impl VirtualSocketBuilder { - pub(crate) fn new(is_client: bool, secure: bool, remote: SocketId, meta: HashMap, rx: Receiver>) -> Self { - Self { is_client, secure, remote, meta, rx } - } - - pub fn build(self, state: Arc>) -> VirtualSocket { - VirtualSocket::new(self.is_client, self.secure, self.remote, self.meta, self.rx, state) - } -} - -#[derive(Debug)] -pub enum VirtualSocketEvent { - ServerControl(VirtualSocketControlMsg), - ClientControl(VirtualSocketControlMsg), - ServerData(Vec, bool), - ClientData(Vec, bool), -} - -impl VirtualSocketEvent { - pub fn into_transport_msg(self, local_node: NodeId, remote_node: NodeId, client_id: u32) -> TransportMsg { - match self { - VirtualSocketEvent::ServerData(data, secure) => { - let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) - .set_from_node(Some(local_node)) - .set_stream_id(client_id) - .set_secure(secure) - .set_meta(DATA_SERVER_META); - TransportMsg::build_raw(header, &data) - } - VirtualSocketEvent::ClientData(data, secure) => { - let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) - .set_from_node(Some(local_node)) - .set_stream_id(client_id) - .set_secure(secure) - .set_meta(DATA_CLIENT_META); - TransportMsg::build_raw(header, &data) - } - VirtualSocketEvent::ServerControl(control) => { - let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) - .set_from_node(Some(local_node)) - .set_stream_id(client_id) - .set_secure(true) - .set_meta(CONTROL_SERVER_META); - TransportMsg::from_payload_bincode(header, &control) - } - VirtualSocketEvent::ClientControl(control) => { - let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, RouteRule::ToNode(remote_node)) - .set_from_node(Some(local_node)) - .set_stream_id(client_id) - .set_secure(true) - .set_meta(CONTROL_CLIENT_META); - TransportMsg::from_payload_bincode(header, &control) - } - } - } -} - -pub struct VirtualSocket { - writer: VirtualSocketWriter, - reader: VirtualSocketReader, -} - -impl VirtualSocket { - pub(crate) fn new(is_client: bool, secure: bool, remote: SocketId, meta: HashMap, rx: Receiver>, state: Arc>) -> Self { - Self { - writer: VirtualSocketWriter { - is_client, - secure, - remote: remote.clone(), - state: state.clone(), - }, - reader: VirtualSocketReader { is_client, rx, remote, meta, state }, - } - } - - pub fn remote(&self) -> &SocketId { - self.writer.remote() - } - - pub fn meta(&self) -> &HashMap { - self.reader.meta() - } - - pub async fn read(&mut self) -> Option> { - self.reader.read().await - } - - pub fn write(&mut self, buf: &[u8]) -> std::io::Result { - self.writer.write(buf) - } - - pub fn split(self) -> (VirtualSocketReader, VirtualSocketWriter) { - (self.reader, self.writer) - } -} - -pub struct VirtualSocketReader { - is_client: bool, - rx: Receiver>, - remote: SocketId, - meta: HashMap, - state: Arc>, -} - -impl VirtualSocketReader { - pub fn remote(&self) -> &SocketId { - &self.remote - } - - pub fn meta(&self) -> &HashMap { - &self.meta - } - - pub async fn read(&mut self) -> Option> { - self.rx.recv().await.ok() - } -} - -impl Drop for VirtualSocketReader { - fn drop(&mut self) { - self.state.write().close_socket(self.is_client, &self.remote); - } -} - -pub struct VirtualSocketWriter { - is_client: bool, - secure: bool, - remote: SocketId, - state: Arc>, -} - -impl VirtualSocketWriter { - pub fn remote(&self) -> &SocketId { - &self.remote - } -} - -impl Write for VirtualSocketWriter { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - if self.is_client { - self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ClientData(buf.to_vec(), self.secure)); - } else { - self.state.write().send_out(self.remote.clone(), VirtualSocketEvent::ServerData(buf.to_vec(), self.secure)); - } - Ok(buf.len()) - } - - fn flush(&mut self) -> std::io::Result<()> { - Ok(()) - } -} diff --git a/packages/services/virtual_socket/src/state/stream.rs b/packages/services/virtual_socket/src/state/stream.rs deleted file mode 100644 index 92fd73b1..00000000 --- a/packages/services/virtual_socket/src/state/stream.rs +++ /dev/null @@ -1,136 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use async_std::{channel::Receiver, stream::StreamExt, task::JoinHandle}; -use futures::{select, FutureExt as _}; -use kcp::{Error, Kcp}; -use parking_lot::RwLock; - -use super::socket::{VirtualSocket, VirtualSocketWriter}; - -const MAX_KCP_SEND_QUEUE: usize = 10; - -enum ReadEvent { - Continue, - Close, -} - -pub struct VirtualStream { - meta: HashMap, - kcp: Arc>>, - task: Option>, - write_awake_rx: Receiver<()>, - read_awake_rx: Receiver, -} - -impl VirtualStream { - pub fn new(socket: VirtualSocket) -> Self { - let (mut reader, writer) = socket.split(); - let meta = reader.meta().clone(); - let mut kcp = Kcp::new_stream(writer.remote().client_id(), writer); - kcp.set_nodelay(true, 20, 2, true); - - let kcp = Arc::new(RwLock::new(kcp)); - let (write_awake_tx, write_awake_rx) = async_std::channel::bounded(1); - let (read_awake_tx, read_awake_rx) = async_std::channel::bounded(1); - let kcp_c = kcp.clone(); - let task = async_std::task::spawn(async move { - let mut timer = async_std::stream::interval(std::time::Duration::from_millis(10)); - let started_at = std::time::Instant::now(); - loop { - select! { - _ = timer.next().fuse() => { - if let Err(e) = kcp_c.write().update(started_at.elapsed().as_millis() as u32) { - log::error!("[VirtualStream] kcp update error: {:?}", e); - break; - } - } - e = reader.read().fuse() => { - if let Some(buf) = e { - if buf.len() == 0 { - log::info!("[VirtualStream] reader closed"); - read_awake_tx.try_send(ReadEvent::Close).ok(); - break; - } - - let mut kcp = kcp_c.write(); - if let Err(e) = kcp.input(&buf) { - log::error!("[VirtualStream] kcp input error: {:?}", e); - break; - } else { - if let Ok(len) = kcp.peeksize() { - if len > 0 { - read_awake_tx.try_send(ReadEvent::Continue).ok(); - } - } - if kcp.wait_snd() < MAX_KCP_SEND_QUEUE { - write_awake_tx.try_send(()).ok(); - } - } - } else { - log::info!("[VirtualStream] reader closed"); - read_awake_tx.try_send(ReadEvent::Close).ok(); - break; - } - } - } - } - }); - - Self { - meta, - kcp, - task: Some(task), - write_awake_rx, - read_awake_rx, - } - } - - pub fn meta(&self) -> &HashMap { - &self.meta - } - - pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { - loop { - let kcp_wait_snd = self.kcp.read().wait_snd(); - if kcp_wait_snd < MAX_KCP_SEND_QUEUE { - self.kcp.write().send(buf).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - return Ok(()); - } else { - self.write_awake_rx.recv().await.map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "ConnectionInterrupted"))?; - } - } - } - - pub async fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - loop { - let size = match self.kcp.write().recv(buf) { - Ok(size) => size, - Err(e) => match e { - Error::RecvQueueEmpty => 0, - _ => { - return Err(std::io::Error::new(std::io::ErrorKind::Other, e)); - } - }, - }; - if size > 0 { - return Ok(size); - } else { - match self.read_awake_rx.recv().await { - Ok(ReadEvent::Continue) => {} - Ok(ReadEvent::Close) => return Ok(0), - Err(_) => return Err(std::io::Error::new(std::io::ErrorKind::Other, "ConnectionInterrupted")), - } - } - } - } -} - -impl Drop for VirtualStream { - fn drop(&mut self) { - if let Some(task) = self.task.take() { - async_std::task::spawn(async { - task.cancel().await; - }); - } - } -} diff --git a/packages/services/virtual_socket/src/vnet.rs b/packages/services/virtual_socket/src/vnet.rs new file mode 100644 index 00000000..19c39431 --- /dev/null +++ b/packages/services/virtual_socket/src/vnet.rs @@ -0,0 +1,44 @@ +use std::{net::SocketAddrV4, sync::Arc}; + +use atm0s_sdn_identity::NodeId; +use atm0s_sdn_router::RouterTable; + +use self::{internal::VirtualNetInternal, udp_socket::VirtualUdpSocket}; + +mod async_queue; +pub(crate) mod internal; +pub(crate) mod udp_socket; + +#[derive(Debug, PartialEq, Clone)] +pub struct VirtualSocketPkt { + pub src: SocketAddrV4, + pub payload: Vec, + /// ecn, only 2 bits + pub ecn: Option, +} + +#[derive(Debug, PartialEq)] +pub enum VirtualNetError { + QueueFull, + Unreachable, + AllreadyExists, + NoAvailablePort, +} + +#[derive(Clone)] +pub struct VirtualNet { + pub(crate) internal: VirtualNetInternal, +} + +impl VirtualNet { + pub(crate) fn new(node_id: NodeId, router: Arc) -> (Self, VirtualNetInternal) { + log::info!("[VirtualNet] Create new virtual socket service"); + let internal = VirtualNetInternal::new(node_id, router); + let net = Self { internal: internal.clone() }; + (net, internal) + } + + pub fn create_udp_socket(&self, port: u16, buffer_size: usize) -> Result { + Ok(VirtualUdpSocket::new(self.internal.clone(), port, buffer_size)?) + } +} diff --git a/packages/services/virtual_socket/src/vnet/async_queue.rs b/packages/services/virtual_socket/src/vnet/async_queue.rs new file mode 100644 index 00000000..cdab8db1 --- /dev/null +++ b/packages/services/virtual_socket/src/vnet/async_queue.rs @@ -0,0 +1,70 @@ +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll, Waker}, +}; + +use atm0s_sdn_utils::vec_dequeue::VecDeque; +use futures::Future; +use parking_lot::{Mutex, RwLock}; + +#[derive(Clone)] +pub struct AsyncQueue { + data: Arc>>, + awake: Arc>>, + max_size: usize, +} + +impl AsyncQueue { + pub fn new(max_size: usize) -> Self { + Self { + data: Default::default(), + awake: Default::default(), + max_size, + } + } + + pub fn try_push(&self, item: T) -> Result<(), T> { + let mut data = self.data.write(); + if data.len() >= self.max_size { + return Err(item); + } + data.push_back(item); + if data.len() == 1 { + if let Some(waker) = self.awake.lock().take() { + waker.wake(); + } + } + Ok(()) + } + + pub fn try_pop(&self) -> Option { + let mut data = self.data.write(); + data.pop_front() + } + + pub fn poll_pop(&self, cx: &mut std::task::Context) -> std::task::Poll> { + let mut data = self.data.write(); + if let Some(item) = data.pop_front() { + return std::task::Poll::Ready(Some(item)); + } + *self.awake.lock() = Some(cx.waker().clone()); + std::task::Poll::Pending + } + + pub fn recv(&self) -> Recv<'_, T> { + Recv { queue: self } + } +} + +pub struct Recv<'a, T> { + queue: &'a AsyncQueue, +} + +impl<'a, T> Future for Recv<'a, T> { + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.queue.poll_pop(cx) + } +} diff --git a/packages/services/virtual_socket/src/vnet/internal.rs b/packages/services/virtual_socket/src/vnet/internal.rs new file mode 100644 index 00000000..5a6994ca --- /dev/null +++ b/packages/services/virtual_socket/src/vnet/internal.rs @@ -0,0 +1,137 @@ +use std::{collections::HashMap, net::SocketAddrV4, sync::Arc}; + +use atm0s_sdn_identity::{ConnId, NodeId}; +use atm0s_sdn_network::{ + msg::{MsgHeader, TransportMsg}, + transport::ConnectionSender, +}; +use atm0s_sdn_router::{RouteAction, RouteRule, RouterTable}; +use parking_lot::RwLock; + +use crate::{VirtualSocketPkt, VIRTUAL_SOCKET_SERVICE_ID}; + +use super::{async_queue::AsyncQueue, VirtualNetError}; + +#[derive(Clone)] +pub struct VirtualNetInternal { + node_id: NodeId, + router: Arc, + conns: Arc>>>, + sockets: Arc>>>, + ports: Arc>>, +} + +impl VirtualNetInternal { + pub fn new(node_id: NodeId, router: Arc) -> Self { + Self { + node_id, + router, + conns: Default::default(), + sockets: Default::default(), + ports: Arc::new(RwLock::new((1..=65535).collect())), + } + } + + pub fn local_node(&self) -> NodeId { + self.node_id + } + + pub fn register_socket(&self, mut port: u16, buffer_size: usize) -> Result<(AsyncQueue, u16), VirtualNetError> { + let queue = AsyncQueue::new(buffer_size); + let mut sockets = self.sockets.write(); + let mut ports = self.ports.write(); + if port == 0 { + port = *ports.last().ok_or(VirtualNetError::NoAvailablePort)?; + log::info!("[VirtualNetInternal] No port specified, using {}", port) + } + if sockets.contains_key(&port) { + return Err(VirtualNetError::AllreadyExists); + } + log::info!("[VirtualNetInternal] Register socket on port {}", port); + sockets.insert(port, queue.clone()); + ports.pop(); + Ok((queue, port)) + } + + pub fn unregister_socket(&self, port: u16) { + let mut sockets = self.sockets.write(); + let mut ports = self.ports.write(); + if let Some(_) = sockets.remove(&port) { + log::info!("[VirtualNetInternal] Unregister socket on port {}", port); + ports.push(port); + } + } + + pub fn add_conn(&self, conn: Arc) { + self.conns.write().insert(conn.conn_id(), conn); + } + + pub fn remove_conn(&self, conn_id: ConnId) { + self.conns.write().remove(&conn_id); + } + + pub fn on_incomming(&self, msg: TransportMsg) { + let from_port = (msg.header.stream_id >> 16) as u16; + let dest_port = (msg.header.stream_id & 0xFFFF) as u16; + if let Some(from_node) = msg.header.from_node { + if let Some(sender) = self.sockets.read().get(&dest_port) { + if let Err(_e) = sender.try_push(VirtualSocketPkt { + src: SocketAddrV4::new(from_node.into(), from_port), + payload: msg.payload().to_vec(), + ecn: if msg.header.meta == 0b11 { + None + } else { + Some(msg.header.meta) + }, + }) { + log::warn!("Error sending to queue socket {} full", dest_port); + } + } else { + log::trace!("No socket for port {}", dest_port); + } + } + } + + pub fn send_to(&self, from: u16, dest: SocketAddrV4, payload: &[u8], ecn: Option) -> Result<(), VirtualNetError> { + let dest_node: NodeId = (*dest.ip()).into(); + let dest_port = dest.port(); + self.send_to_node(from, dest_node, dest_port, payload, ecn) + } + + pub fn send_to_node(&self, from: u16, dest_node: NodeId, dest_port: u16, payload: &[u8], ecn: Option) -> Result<(), VirtualNetError> { + let rule = RouteRule::ToNode(dest_node); + match self.router.derive_action(&rule, VIRTUAL_SOCKET_SERVICE_ID) { + RouteAction::Local => { + if let Some(sender) = self.sockets.read().get(&dest_port) { + sender + .try_push(VirtualSocketPkt { + src: SocketAddrV4::new(self.node_id.into(), from), + payload: payload.to_vec(), + ecn, + }) + .map_err(|_| VirtualNetError::QueueFull)?; + log::trace!("[VirtualNetInternal] Send {} bytes from {} to {}:{} via local socket", payload.len(), from, dest_node, dest_port); + Ok(()) + } else { + Err(VirtualNetError::Unreachable) + } + } + RouteAction::Next(conn_id, _) => { + if let Some(sender) = self.conns.read().get(&conn_id) { + let stream_id = (from as u32) << 16 | (dest_port as u32); + let header = MsgHeader::build(VIRTUAL_SOCKET_SERVICE_ID, VIRTUAL_SOCKET_SERVICE_ID, rule) + .set_from_node(Some(self.node_id)) + .set_secure(false) + .set_meta(ecn.unwrap_or(0b11)) + .set_stream_id(stream_id); + let msg = TransportMsg::build_raw(header, payload); + sender.send(msg); + Ok(()) + } else { + Err(VirtualNetError::Unreachable) + } + } + RouteAction::Reject => Err(VirtualNetError::Unreachable), + } + } +} diff --git a/packages/services/virtual_socket/src/vnet/udp_socket.rs b/packages/services/virtual_socket/src/vnet/udp_socket.rs new file mode 100644 index 00000000..60c03ad8 --- /dev/null +++ b/packages/services/virtual_socket/src/vnet/udp_socket.rs @@ -0,0 +1,95 @@ +use std::{ + fmt::Debug, + net::{SocketAddr, SocketAddrV4}, + ops::DerefMut, +}; + +use atm0s_sdn_identity::NodeId; +use quinn::{udp::EcnCodepoint, AsyncUdpSocket}; + +use crate::VirtualSocketPkt; + +use super::{async_queue::AsyncQueue, internal::VirtualNetInternal, VirtualNetError}; + +pub struct VirtualUdpSocket { + local_port: u16, + internal: VirtualNetInternal, + queue: AsyncQueue, +} + +impl VirtualUdpSocket { + pub(crate) fn new(internal: VirtualNetInternal, port: u16, buffer_size: usize) -> Result { + let (queue, local_port) = internal.register_socket(port, buffer_size)?; + Ok(Self { internal, queue, local_port }) + } + + pub fn local_port(&self) -> u16 { + self.local_port + } + + pub fn send_to_node(&self, node: NodeId, port: u16, payload: &[u8], ecn: Option) -> Result<(), VirtualNetError> { + self.internal.send_to_node(self.local_port, node, port, payload, ecn) + } + + pub fn send_to(&self, dest: SocketAddrV4, payload: &[u8], ecn: Option) -> Result<(), VirtualNetError> { + self.internal.send_to(self.local_port, dest, payload, ecn) + } + + pub fn try_recv_from(&self) -> Option { + self.queue.try_pop() + } + + pub async fn recv_from(&self) -> Option { + self.queue.recv().await + } +} + +impl Debug for VirtualUdpSocket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("VirtualUdpSocket").field("local_port", &self.local_port).finish() + } +} + +impl AsyncUdpSocket for VirtualUdpSocket { + fn poll_send(&self, _state: &quinn::udp::UdpState, _cx: &mut std::task::Context, transmits: &[quinn::udp::Transmit]) -> std::task::Poll> { + for transmit in transmits { + let res = match transmit.destination { + SocketAddr::V4(addr) => self.internal.send_to(self.local_port, addr, &transmit.contents, transmit.ecn.map(|x| x as u8)), + _ => return std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Only IPv4 supported"))), + }; + if res.is_err() { + break; + } + } + std::task::Poll::Ready(Ok(transmits.len())) + } + + fn poll_recv(&self, cx: &mut std::task::Context, bufs: &mut [std::io::IoSliceMut<'_>], meta: &mut [quinn::udp::RecvMeta]) -> std::task::Poll> { + match self.queue.poll_pop(cx) { + std::task::Poll::Pending => std::task::Poll::Pending, + std::task::Poll::Ready(Some(pkt)) => { + let len = pkt.payload.len(); + bufs[0].deref_mut()[0..len].copy_from_slice(&pkt.payload); + meta[0] = quinn::udp::RecvMeta { + addr: SocketAddr::V4(pkt.src), + len, + stride: len, + ecn: pkt.ecn.map(|x| EcnCodepoint::from_bits(x).expect("Invalid ECN codepoint")), + dst_ip: None, + }; + std::task::Poll::Ready(Ok(1)) + } + std::task::Poll::Ready(None) => std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "Socket closed"))), + } + } + + fn local_addr(&self) -> std::io::Result { + Ok(SocketAddr::V4(SocketAddrV4::new(self.internal.local_node().into(), self.local_port))) + } +} + +impl Drop for VirtualUdpSocket { + fn drop(&mut self) { + self.internal.unregister_socket(self.local_port); + } +} From c5a0e55fd7ec5f62e2325edfb2b3e3c321cbc336 Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Mon, 8 Jan 2024 12:09:39 +0700 Subject: [PATCH 6/8] added async-queue test --- .../virtual_socket/src/vnet/async_queue.rs | 51 +++++++++++++++++-- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/packages/services/virtual_socket/src/vnet/async_queue.rs b/packages/services/virtual_socket/src/vnet/async_queue.rs index cdab8db1..e9bc19f0 100644 --- a/packages/services/virtual_socket/src/vnet/async_queue.rs +++ b/packages/services/virtual_socket/src/vnet/async_queue.rs @@ -6,11 +6,11 @@ use std::{ use atm0s_sdn_utils::vec_dequeue::VecDeque; use futures::Future; -use parking_lot::{Mutex, RwLock}; +use parking_lot::Mutex; #[derive(Clone)] pub struct AsyncQueue { - data: Arc>>, + data: Arc>>, awake: Arc>>, max_size: usize, } @@ -25,7 +25,7 @@ impl AsyncQueue { } pub fn try_push(&self, item: T) -> Result<(), T> { - let mut data = self.data.write(); + let mut data = self.data.lock(); if data.len() >= self.max_size { return Err(item); } @@ -39,12 +39,12 @@ impl AsyncQueue { } pub fn try_pop(&self) -> Option { - let mut data = self.data.write(); + let mut data = self.data.lock(); data.pop_front() } pub fn poll_pop(&self, cx: &mut std::task::Context) -> std::task::Poll> { - let mut data = self.data.write(); + let mut data = self.data.lock(); if let Some(item) = data.pop_front() { return std::task::Poll::Ready(Some(item)); } @@ -68,3 +68,44 @@ impl<'a, T> Future for Recv<'a, T> { self.queue.poll_pop(cx) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_try_push_success() { + let queue = AsyncQueue::new(5); + assert_eq!(queue.try_push(1), Ok(())); + assert_eq!(queue.try_push(2), Ok(())); + assert_eq!(queue.try_push(3), Ok(())); + } + + #[test] + fn test_try_push_failure() { + let queue = AsyncQueue::new(2); + assert_eq!(queue.try_push(1), Ok(())); + assert_eq!(queue.try_push(2), Ok(())); + assert_eq!(queue.try_push(3), Err(3)); + } + + #[test] + fn test_try_pop() { + let queue = AsyncQueue::new(5); + queue.try_push(1).unwrap(); + queue.try_push(2).unwrap(); + assert_eq!(queue.try_pop(), Some(1)); + assert_eq!(queue.try_pop(), Some(2)); + assert_eq!(queue.try_pop(), None); + } + + #[test] + fn test_recv() { + let queue = AsyncQueue::new(5); + queue.try_push(1).unwrap(); + queue.try_push(2).unwrap(); + assert_eq!(futures::executor::block_on(queue.recv()), Some(1)); + assert_eq!(futures::executor::block_on(queue.recv()), Some(2)); + } +} + From 95ab880ced16f8361c434686abe9a07efe8c3281 Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Mon, 8 Jan 2024 12:14:27 +0700 Subject: [PATCH 7/8] added tests --- packages/integration_tests/src/virtual_socket.rs | 11 +++++++++++ .../services/virtual_socket/src/vnet/async_queue.rs | 1 - 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/packages/integration_tests/src/virtual_socket.rs b/packages/integration_tests/src/virtual_socket.rs index d19479db..ae0021f1 100644 --- a/packages/integration_tests/src/virtual_socket.rs +++ b/packages/integration_tests/src/virtual_socket.rs @@ -75,8 +75,11 @@ mod test { async_std::task::sleep(Duration::from_millis(300)).await; let server = sdk.create_udp_socket(1000, 10).expect(""); + assert_eq!(format!("{:?}", server), format!("VirtualUdpSocket {{ local_port: {} }}", server.local_port())); + let client = sdk.create_udp_socket(0, 10).expect(""); client.send_to(vnet_addr_v4(node_id, 1000), &vec![1, 2, 3], None).expect("Should write"); + client.send_to_node(node_id, 1000, &vec![4, 5, 6], None).expect("Should write"); assert_eq!( server.try_recv_from(), Some(VirtualSocketPkt { @@ -85,6 +88,14 @@ mod test { ecn: None, }) ); + assert_eq!( + server.try_recv_from(), + Some(VirtualSocketPkt { + src: SocketAddrV4::new(node_id.into(), client.local_port()), + payload: vec![4, 5, 6], + ecn: None, + }) + ); assert_eq!(server.try_recv_from(), None); join.cancel().await; diff --git a/packages/services/virtual_socket/src/vnet/async_queue.rs b/packages/services/virtual_socket/src/vnet/async_queue.rs index e9bc19f0..cd18cd23 100644 --- a/packages/services/virtual_socket/src/vnet/async_queue.rs +++ b/packages/services/virtual_socket/src/vnet/async_queue.rs @@ -108,4 +108,3 @@ mod tests { assert_eq!(futures::executor::block_on(queue.recv()), Some(2)); } } - From 2a7a3245dfc37f7a878e48ea8c90cd8c739301ed Mon Sep 17 00:00:00 2001 From: Giang Minh Date: Mon, 8 Jan 2024 12:18:28 +0700 Subject: [PATCH 8/8] fix warn --- packages/services/virtual_socket/src/handler.rs | 7 ++----- packages/services/virtual_socket/src/vnet.rs | 2 +- packages/services/virtual_socket/src/vnet/internal.rs | 2 +- packages/transports/udp/src/transport.rs | 4 ++-- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/packages/services/virtual_socket/src/handler.rs b/packages/services/virtual_socket/src/handler.rs index 4cd59fe3..e7ad37ef 100644 --- a/packages/services/virtual_socket/src/handler.rs +++ b/packages/services/virtual_socket/src/handler.rs @@ -22,11 +22,8 @@ impl ConnectionHandler for VirtualSocketHandler { /// Called when an event occurs on the connection. fn on_event(&mut self, _ctx: &ConnectionContext, _now_ms: u64, event: ConnectionEvent) { - match event { - ConnectionEvent::Msg(msg) => { - self.internal.on_incomming(msg); - } - _ => {} + if let ConnectionEvent::Msg(msg) = event { + self.internal.on_incomming(msg); } } diff --git a/packages/services/virtual_socket/src/vnet.rs b/packages/services/virtual_socket/src/vnet.rs index 19c39431..5a9150e5 100644 --- a/packages/services/virtual_socket/src/vnet.rs +++ b/packages/services/virtual_socket/src/vnet.rs @@ -39,6 +39,6 @@ impl VirtualNet { } pub fn create_udp_socket(&self, port: u16, buffer_size: usize) -> Result { - Ok(VirtualUdpSocket::new(self.internal.clone(), port, buffer_size)?) + VirtualUdpSocket::new(self.internal.clone(), port, buffer_size) } } diff --git a/packages/services/virtual_socket/src/vnet/internal.rs b/packages/services/virtual_socket/src/vnet/internal.rs index 5a6994ca..b80b083e 100644 --- a/packages/services/virtual_socket/src/vnet/internal.rs +++ b/packages/services/virtual_socket/src/vnet/internal.rs @@ -56,7 +56,7 @@ impl VirtualNetInternal { pub fn unregister_socket(&self, port: u16) { let mut sockets = self.sockets.write(); let mut ports = self.ports.write(); - if let Some(_) = sockets.remove(&port) { + if sockets.remove(&port).is_some() { log::info!("[VirtualNetInternal] Unregister socket on port {}", port); ports.push(port); } diff --git a/packages/transports/udp/src/transport.rs b/packages/transports/udp/src/transport.rs index 0dc565c0..303521bc 100644 --- a/packages/transports/udp/src/transport.rs +++ b/packages/transports/udp/src/transport.rs @@ -71,14 +71,14 @@ impl UdpTransport { if let Ok((size, addr)) = async_socket.recv_from(&mut buf).await { let current_ms = timer.now_ms(); if let Some(msg_tx) = connection.get_mut(&addr) { - msg_tx.0.try_send((buf, size)); + let _ = msg_tx.0.try_send((buf, size)); msg_tx.1 = current_ms; } else { log::info!("[UdpTransport] on new connection from {}", addr); conn_id_seed += 1; let conn_id = ConnId::from_in(UDP_PROTOCOL_ID, conn_id_seed); let (msg_tx, msg_rx) = async_std::channel::bounded(1024); - msg_tx.try_send((buf, size)); + let _ = msg_tx.try_send((buf, size)); connection.insert(addr, (msg_tx, current_ms)); let socket = socket.clone(); let async_socket = async_socket.clone();