diff --git a/src/server/server.rs b/src/server/server.rs index 1d38537e2d..1756e3a41c 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -62,18 +62,7 @@ impl Server { let (send_packets, receive_packets) = mpsc::channel::(1024); self.run_receive_packet(send_socket, receive_packets).await; - - let log = self.log.clone(); - let prune_sessions = sessions.clone(); - // Pruning will occur ~ every interval period. So the timeout expiration may sometimes - // exceed the expected, but we don't have to write lock the SessionMap as often to clean up. - tokio::spawn(async move { - loop { - delay_for(Duration::from_secs(SESSION_TIMEOUT_SECONDS)).await; - debug!(log, "Attempting to Prune Sessions"); - Server::prune_sessions(&log, prune_sessions.clone()).await; - } - }); + self.run_prune_sessions(&sessions).await; loop { if let Err(err) = Server::process_receive_socket( @@ -90,6 +79,22 @@ impl Server { } } + /// run_prune_sessions starts the timer for pruning sessions and runs prune_sessions every + /// SESSION_TIMEOUT_SECONDS, via a tokio::spawn, i.e. it's non-blocking. + /// Pruning will occur ~ every interval period. So the timeout expiration may sometimes + /// exceed the expected, but we don't have to write lock the SessionMap as often to clean up. + async fn run_prune_sessions(&self, sessions: &SessionMap) { + let log = self.log.clone(); + let sessions = sessions.clone(); + tokio::spawn(async move { + loop { + delay_for(Duration::from_secs(SESSION_TIMEOUT_SECONDS)).await; + debug!(log, "Attempting to Prune Sessions"); + Server::prune_sessions(&log, sessions.clone()).await; + } + }); + } + /// process_receive_socket takes packets from the local socket and asynchronously /// processes them to send them out to endpoints. async fn process_receive_socket( @@ -266,6 +271,7 @@ mod tests { use tokio::time::{Duration, Instant}; use crate::config::{Config, ConnectionConfig, EndPoint, Local}; + use crate::extensions::default_filters; use crate::server::sessions::{Packet, SESSION_TIMEOUT_SECONDS}; use crate::test_utils::{assert_recv_udp, ephemeral_socket, logger, recv_socket_done}; @@ -526,4 +532,49 @@ mod tests { info!(log, "test complete"); time::resume(); } + + #[tokio::test] + async fn run_prune_sessions() { + time::pause(); + let log = logger(); + let server = Server::new(log.clone(), default_filters(&log)); + let sessions: SessionMap = Arc::new(RwLock::new(HashMap::new())); + let from: SocketAddr = "127.0.0.1:7000".parse().unwrap(); + let to: SocketAddr = "127.0.0.1:7001".parse().unwrap(); + let (send, _recv) = mpsc::channel::(1); + let key = (from, to); + + server.run_prune_sessions(&sessions).await; + Server::ensure_session(&log, sessions.clone(), from, to, send) + .await + .unwrap(); + + // session map should be the same since, we haven't passed expiry + time::advance(Duration::new(SESSION_TIMEOUT_SECONDS / 2, 0)).await; + { + let map = sessions.read().await; + + assert!(map.contains_key(&key)); + assert_eq!(1, map.len()); + } + time::advance(Duration::new(2 * SESSION_TIMEOUT_SECONDS, 0)).await; + + // poll, since cleanup is async, and may not have happened yet + for _ in 1..10 { + time::delay_for(Duration::from_secs(1)).await; + let map = sessions.read().await; + if !map.contains_key(&key) && map.len() == 0 { + break; + } + } + // do final assertion + { + let map = sessions.read().await; + assert!( + !map.contains_key(&key), + "should not contain the key after prune" + ); + assert_eq!(0, map.len(), "len should be 0, bit is {}", map.len()); + } + } }