diff --git a/src/proxy.rs b/src/proxy.rs index 3ad585db05..8d3fc08bc6 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -19,6 +19,7 @@ pub use builder::{logger, Builder, PendingValidation, Validated}; pub(crate) use health::Health; pub(crate) use metrics::Metrics; pub use server::Server; +pub use sessions::SessionKey; mod admin; mod builder; diff --git a/src/proxy/server.rs b/src/proxy/server.rs index 6f3b8a47d8..72d5aecdfd 100644 --- a/src/proxy/server.rs +++ b/src/proxy/server.rs @@ -34,7 +34,7 @@ use crate::proxy::builder::{ValidatedConfig, ValidatedSource}; use crate::proxy::server::error::Error; use crate::proxy::sessions::metrics::Metrics as SessionMetrics; use crate::proxy::sessions::session_manager::SessionManager; -use crate::proxy::sessions::{Packet, Session, SESSION_TIMEOUT_SECONDS}; +use crate::proxy::sessions::{Packet, Session, SessionKey, SESSION_TIMEOUT_SECONDS}; use crate::proxy::Admin; use crate::utils::debug; @@ -349,7 +349,10 @@ impl Server { endpoint: &Endpoint, args: &ProcessDownstreamReceiveConfig, ) { - let session_key = (recv_addr, endpoint.address); + let session_key = SessionKey { + source: recv_addr, + destination: endpoint.address, + }; // Grab a read lock and find the session. let guard = args.session_manager.get_sessions().await; @@ -381,7 +384,7 @@ impl Server { &args.log, args.session_metrics.clone(), args.filter_manager.clone(), - session_key.0, + session_key.source, endpoint.clone(), args.send_packets.clone(), args.session_ttl, @@ -411,7 +414,7 @@ impl Server { warn!( args.log, "Could not find session"; - "key" => format!("({}:{})", session_key.0.to_string(), session_key.1.to_string()) + "key" => format!("({}:{})", session_key.source.to_string(), session_key.destination.to_string()) ) } } @@ -691,7 +694,7 @@ mod tests { let map = session_manager.get_sessions().await; assert_eq!(expected.session_len, map.len()); - let build_key = (receive_addr, endpoint.socket.local_addr().unwrap()); + let build_key = (receive_addr, endpoint.socket.local_addr().unwrap()).into(); assert!(map.contains_key(&build_key)); let session = map.get(&build_key).unwrap(); let now_secs = SystemTime::now() diff --git a/src/proxy/sessions.rs b/src/proxy/sessions.rs index e9f37ee3c6..54afc8a8d8 100644 --- a/src/proxy/sessions.rs +++ b/src/proxy/sessions.rs @@ -14,7 +14,7 @@ * limitations under the License. */ -pub use session::{Packet, Session}; +pub use session::{Packet, Session, SessionKey}; pub use session_manager::SESSION_TIMEOUT_SECONDS; pub(crate) mod error; diff --git a/src/proxy/sessions/session.rs b/src/proxy/sessions/session.rs index 535fbb4a19..42d9f55deb 100644 --- a/src/proxy/sessions/session.rs +++ b/src/proxy/sessions/session.rs @@ -51,6 +51,22 @@ pub struct Session { shutdown_tx: watch::Sender<()>, } +// A (source, destination) address pair that uniquely identifies a session. +#[derive(Clone, Eq, Hash, PartialEq, Debug, PartialOrd, Ord)] +pub struct SessionKey { + pub source: SocketAddr, + pub destination: SocketAddr, +} + +impl From<(SocketAddr, SocketAddr)> for SessionKey { + fn from(pair: (SocketAddr, SocketAddr)) -> Self { + SessionKey { + source: pair.0, + destination: pair.1, + } + } +} + /// ReceivedPacketContext contains state needed to process a received packet. struct ReceivedPacketContext<'a> { packet: &'a [u8], @@ -179,8 +195,11 @@ impl Session { } /// key returns the key to be used for this session in a SessionMap - pub fn key(&self) -> (SocketAddr, SocketAddr) { - (self.from, self.dest.address) + pub fn key(&self) -> SessionKey { + SessionKey { + source: self.from, + destination: self.dest.address, + } } /// process_recv_packet processes a packet that is received by this session. diff --git a/src/proxy/sessions/session_manager.rs b/src/proxy/sessions/session_manager.rs index 68c1195578..ef2be309bf 100644 --- a/src/proxy/sessions/session_manager.rs +++ b/src/proxy/sessions/session_manager.rs @@ -15,17 +15,16 @@ */ use std::collections::HashMap; -use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use slog::{debug, warn, Logger}; use tokio::sync::{watch, RwLock, RwLockReadGuard, RwLockWriteGuard}; -use crate::proxy::sessions::Session; +use crate::proxy::sessions::{Session, SessionKey}; -// Tracks current sessions keyed by key (source_address,destination_address) pair. -type SessionsMap = HashMap<(SocketAddr, SocketAddr), Session>; +// Tracks current sessions by their [`SessionKey`] +type SessionsMap = HashMap; type Sessions = Arc>; /// SESSION_TIMEOUT_SECONDS is the default session timeout. @@ -126,7 +125,7 @@ mod tests { use crate::filters::{manager::FilterManager, FilterChain}; use crate::proxy::sessions::metrics::Metrics; use crate::proxy::sessions::session_manager::Sessions; - use crate::proxy::sessions::{Packet, Session}; + use crate::proxy::sessions::{Packet, Session, SessionKey}; use crate::test_utils::TestHelper; use super::SessionManager; @@ -154,14 +153,14 @@ mod tests { shutdown_rx, ); - let key = (from, to); + let key = SessionKey::from((from, to)); // Insert key. { let registry = Registry::default(); let mut sessions = sessions.write().await; sessions.insert( - key, + key.clone(), Session::new( &t.log, Metrics::new(®istry).unwrap(), @@ -215,14 +214,14 @@ mod tests { let (send, _recv) = mpsc::channel::(1); let endpoint = Endpoint::from_address(to); - let key = (from, to); + let key = SessionKey::from((from, to)); let ttl = Duration::from_secs(1); { let registry = Registry::default(); let mut sessions = sessions.write().await; sessions.insert( - key, + key.clone(), Session::new( &t.log, Metrics::new(®istry).unwrap(),