diff --git a/src/server/server.rs b/src/server/server.rs index cde932959c..dd7d2e6ff4 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -28,7 +28,7 @@ use tokio::sync::{mpsc, oneshot}; use tokio::sync::{Mutex, RwLock}; use tokio::time::{delay_for, Duration, Instant}; -use crate::config::{Config, ConnectionConfig}; +use crate::config::{Config, ConnectionConfig, EndPoint}; use crate::extensions::{Filter, FilterChain, FilterRegistry}; use crate::load_balancer_policy::LoadBalancerPolicy; use crate::server::sessions::{Packet, Session, SESSION_TIMEOUT_SECONDS}; @@ -161,7 +161,7 @@ impl Server { &log, sessions.clone(), recv_addr, - endpoint.address, + &endpoint, send_packets.clone(), ) .await @@ -251,16 +251,16 @@ impl Server { log: &Logger, sessions: SessionMap, from: SocketAddr, - dest: SocketAddr, + dest: &EndPoint, sender: mpsc::Sender, ) -> Result<()> { { let map = sessions.read().await; - if map.contains_key(&(from, dest)) { + if map.contains_key(&(from, dest.address)) { return Ok(()); } } - let s = Session::new(log, from, dest, sender).await?; + let s = Session::new(log, from, dest.clone(), sender).await?; { let mut map = sessions.write().await; map.insert(s.key(), Mutex::new(s)); @@ -565,12 +565,17 @@ mod tests { let from: SocketAddr = "127.0.0.1:27890".parse().unwrap(); let dest: SocketAddr = "127.0.0.1:27891".parse().unwrap(); let (sender, mut recv) = mpsc::channel::(1); + let endpoint = EndPoint { + name: "endpoint".to_string(), + address: dest, + connection_ids: vec![], + }; // gate { assert!(map.read().await.is_empty()); } - Server::ensure_session(&log, map.clone(), from, dest, sender) + Server::ensure_session(&log, map.clone(), from, &endpoint, sender) .await .unwrap(); @@ -642,8 +647,13 @@ mod tests { let from: SocketAddr = "127.0.0.1:7000".parse().unwrap(); let to: SocketAddr = "127.0.0.1:7001".parse().unwrap(); let (send, _recv) = mpsc::channel::(1); + let endpoint = EndPoint { + name: "endpoint".to_string(), + address: to, + connection_ids: vec![], + }; - Server::ensure_session(&log, sessions.clone(), from, to, send) + Server::ensure_session(&log, sessions.clone(), from, &endpoint, send) .await .unwrap(); @@ -689,9 +699,14 @@ mod tests { let to: SocketAddr = "127.0.0.1:7001".parse().unwrap(); let (send, _recv) = mpsc::channel::(1); let key = (from, to); + let endpoint = EndPoint { + name: "endpoint".to_string(), + address: to, + connection_ids: vec![], + }; server.run_prune_sessions(&sessions); - Server::ensure_session(&log, sessions.clone(), from, to, send) + Server::ensure_session(&log, sessions.clone(), from, &endpoint, send) .await .unwrap(); diff --git a/src/server/sessions.rs b/src/server/sessions.rs index 06a77cddce..47f446dc45 100644 --- a/src/server/sessions.rs +++ b/src/server/sessions.rs @@ -29,21 +29,17 @@ use tokio::select; use tokio::sync::{mpsc, watch, RwLock}; use tokio::time::{Duration, Instant}; +use crate::config::EndPoint; + /// SESSION_TIMEOUT_SECONDS is the default session timeout - which is one minute. pub const SESSION_TIMEOUT_SECONDS: u64 = 60; -/// Packet represents a packet that needs to go somewhere -pub struct Packet { - dest: SocketAddr, - contents: Vec, -} - /// Session encapsulates a UDP stream session pub struct Session { log: Logger, send: SendHalf, /// dest is where to send data to - dest: SocketAddr, + dest: EndPoint, /// from is the original sender from: SocketAddr, /// session expiration timestamp @@ -54,6 +50,12 @@ pub struct Session { is_closed: Arc, } +/// Packet represents a packet that needs to go somewhere +pub struct Packet { + dest: SocketAddr, + contents: Vec, +} + impl Packet { pub fn new(dest: SocketAddr, contents: Vec) -> Packet { Packet { dest, contents } @@ -74,14 +76,14 @@ impl Session { pub async fn new( base: &Logger, from: SocketAddr, - dest: SocketAddr, + dest: EndPoint, sender: mpsc::Sender, ) -> Result { let addr = SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0); let (recv, send) = UdpSocket::bind(addr).await?.split(); let (closer, closed) = watch::channel::(false); let mut s = Session { - log: base.new(o!("source" => "server::Session", "from" => from, "dest" => dest)), + log: base.new(o!("source" => "server::Session", "from" => from, "dest_name" => dest.name.clone(), "dest_address" => dest.address.clone())), send, from, dest, @@ -150,7 +152,7 @@ impl Session { /// key returns the key to be used for this session in a SessionMap pub fn key(&self) -> (SocketAddr, SocketAddr) { - (self.from, self.dest) + (self.from, self.dest.address) } /// process_recv_packet processes a packet that is received by this session. @@ -183,8 +185,8 @@ impl Session { /// Sends a packet to the Session's dest. pub async fn send_to(&mut self, buf: &[u8]) -> Result { - debug!(self.log, "Sending packet"; "dest" => self.dest, "contents" => from_utf8(buf).unwrap()); - return self.send.send_to(buf, &self.dest).await; + debug!(self.log, "Sending packet"; "dest_name" => &self.dest.name, "dest_address" => &self.dest.address, "contents" => from_utf8(buf).unwrap()); + return self.send.send_to(buf, &self.dest.address).await; } /// is_closed returns if the Session is closed or not. @@ -195,7 +197,7 @@ impl Session { /// close closes this Session. pub fn close(&self) -> result::Result<(), watch::error::SendError> { - debug!(self.log, "Session closed"; "from" => %self.from, "dest" => %self.dest); + debug!(self.log, "Session closed"; "from" => self.from, "dest_name" => &self.dest.name, "dest_address" => &self.dest.address); self.closer.broadcast(true) } } @@ -217,9 +219,14 @@ mod tests { let log = logger(); let mut socket = ephemeral_socket().await; let local_addr = socket.local_addr().unwrap(); + let endpoint = EndPoint { + name: "endpoint".to_string(), + address: local_addr, + connection_ids: vec![], + }; let (send_packet, mut recv_packet) = mpsc::channel::(5); - let mut sess = Session::new(&log, local_addr, local_addr, send_packet) + let mut sess = Session::new(&log, local_addr, endpoint, send_packet) .await .unwrap(); @@ -267,8 +274,13 @@ mod tests { let msg = "hello"; let (sender, _) = mpsc::channel::(1); let (local_addr, wait) = recv_udp().await; + let endpoint = EndPoint { + name: "endpoint".to_string(), + address: local_addr, + connection_ids: vec![], + }; - let mut session = Session::new(&log, local_addr, local_addr, sender) + let mut session = Session::new(&log, local_addr, endpoint.clone(), sender) .await .unwrap(); session.send_to(msg.as_bytes()).await.unwrap(); @@ -281,9 +293,14 @@ mod tests { let socket = ephemeral_socket().await; let local_addr = socket.local_addr().unwrap(); let (send_packet, _) = mpsc::channel::(5); + let endpoint = EndPoint { + name: "endpoint".to_string(), + address: local_addr, + connection_ids: vec![], + }; info!(log, ">> creating sessions"); - let sess = Session::new(&log, local_addr, local_addr, send_packet) + let sess = Session::new(&log, local_addr, endpoint, send_packet) .await .unwrap(); info!(log, ">> session created and running");