From 353877010b0f535996455af518557405de847760 Mon Sep 17 00:00:00 2001 From: Erin Power Date: Fri, 26 Aug 2022 10:05:24 +0200 Subject: [PATCH] Move proxy::Server to Proxy and sessions::session to sessions --- benches/throughput.rs | 4 +- docs/src/filters.md | 2 +- docs/src/filters/capture.md | 2 +- docs/src/filters/compress.md | 2 +- docs/src/filters/concatenate_bytes.md | 2 +- docs/src/filters/debug.md | 2 +- docs/src/filters/firewall.md | 2 +- docs/src/filters/load_balancer.md | 2 +- docs/src/filters/local_rate_limit.md | 2 +- docs/src/filters/match.md | 2 +- docs/src/filters/token_router.md | 4 +- docs/src/filters/writing_custom_filters.md | 4 +- examples/quilkin-filter-example/src/main.rs | 2 +- src/cli/run.rs | 2 +- src/lib.rs | 2 +- src/proxy.rs | 549 ++++++++++++++++- src/proxy/server.rs | 558 ------------------ src/proxy/sessions.rs | 504 +++++++++++++++- .../{session_manager.rs => manager.rs} | 4 +- src/proxy/sessions/session.rs | 511 ---------------- src/test_utils.rs | 2 +- src/xds.rs | 2 +- tests/capture.rs | 2 +- tests/compress.rs | 4 +- tests/concatenate_bytes.rs | 2 +- tests/filter_order.rs | 2 +- tests/filters.rs | 8 +- tests/firewall.rs | 2 +- tests/health.rs | 2 +- tests/load_balancer.rs | 2 +- tests/local_rate_limit.rs | 2 +- tests/match.rs | 2 +- tests/metrics.rs | 6 +- tests/no_filter.rs | 2 +- tests/token_router.rs | 2 +- 35 files changed, 1084 insertions(+), 1120 deletions(-) delete mode 100644 src/proxy/server.rs rename src/proxy/sessions/{session_manager.rs => manager.rs} (98%) delete mode 100644 src/proxy/sessions/session.rs diff --git a/benches/throughput.rs b/benches/throughput.rs index 2879811321..7eb64e65e1 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -29,7 +29,7 @@ const PACKETS: &[&[u8]] = &[ fn run_quilkin(port: u16, endpoint: SocketAddr) { std::thread::spawn(move || { let runtime = tokio::runtime::Runtime::new().unwrap(); - let config = quilkin::Server::builder() + let config = quilkin::Config::builder() .port(port) .admin(Admin { address: "[::]:0".parse().unwrap(), @@ -38,7 +38,7 @@ fn run_quilkin(port: u16, endpoint: SocketAddr) { .build() .unwrap(); - let server = quilkin::Server::try_from(config).unwrap(); + let server = quilkin::Proxy::try_from(config).unwrap(); runtime.block_on(async move { let (_shutdown_tx, shutdown_rx) = tokio::sync::watch::channel::<()>(()); server.run(shutdown_rx).await.unwrap(); diff --git a/docs/src/filters.md b/docs/src/filters.md index 1e9159e721..222c759f19 100644 --- a/docs/src/filters.md +++ b/docs/src/filters.md @@ -69,7 +69,7 @@ clusters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 2); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); # } ``` diff --git a/docs/src/filters/capture.md b/docs/src/filters/capture.md index e2613e0b14..2afd0f06fa 100644 --- a/docs/src/filters/capture.md +++ b/docs/src/filters/capture.md @@ -48,7 +48,7 @@ clusters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 1); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); ``` ### Configuration Options ([Rust Doc](../../api/quilkin/filters/capture/struct.Config.html)) diff --git a/docs/src/filters/compress.md b/docs/src/filters/compress.md index 53f867f1c8..046e8682f7 100644 --- a/docs/src/filters/compress.md +++ b/docs/src/filters/compress.md @@ -26,7 +26,7 @@ clusters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 1); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); ``` The above example shows a proxy that could be used with a typical game client, where the original client data is diff --git a/docs/src/filters/concatenate_bytes.md b/docs/src/filters/concatenate_bytes.md index 7bad5c7973..522386c278 100644 --- a/docs/src/filters/concatenate_bytes.md +++ b/docs/src/filters/concatenate_bytes.md @@ -26,7 +26,7 @@ clusters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 1); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); ``` ### Configuration Options ([Rust Doc](../../api/quilkin/filters/concatenate_bytes/struct.Config.html)) diff --git a/docs/src/filters/debug.md b/docs/src/filters/debug.md index 7933040761..e4a4ad12b3 100644 --- a/docs/src/filters/debug.md +++ b/docs/src/filters/debug.md @@ -25,7 +25,7 @@ clusters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 1); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); ``` ### Configuration Options ([Rust Doc](../../api/quilkin/filters/debug/struct.Config.html)) diff --git a/docs/src/filters/firewall.md b/docs/src/filters/firewall.md index 7b5e198abd..9ae17e914c 100644 --- a/docs/src/filters/firewall.md +++ b/docs/src/filters/firewall.md @@ -34,7 +34,7 @@ clusters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 1); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); ``` ### Configuration Options ([Rust Doc](../../api/quilkin/filters/firewall/struct.Config.html)) diff --git a/docs/src/filters/load_balancer.md b/docs/src/filters/load_balancer.md index 0efe2667c1..5f3e7d71ad 100644 --- a/docs/src/filters/load_balancer.md +++ b/docs/src/filters/load_balancer.md @@ -25,7 +25,7 @@ clusters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 1); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); # } ``` diff --git a/docs/src/filters/local_rate_limit.md b/docs/src/filters/local_rate_limit.md index 7c38f345d7..563b9c970e 100644 --- a/docs/src/filters/local_rate_limit.md +++ b/docs/src/filters/local_rate_limit.md @@ -29,7 +29,7 @@ clusters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 1); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); # } ``` To configure a rate limiter, we specify the maximum rate at which the proxy is allowed to forward packets. In the example above, we configured the proxy to forward a maximum of 1000 packets per second). diff --git a/docs/src/filters/match.md b/docs/src/filters/match.md index 0944fe0758..efe62b678e 100644 --- a/docs/src/filters/match.md +++ b/docs/src/filters/match.md @@ -39,7 +39,7 @@ filters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 2); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); ``` diff --git a/docs/src/filters/token_router.md b/docs/src/filters/token_router.md index c5f3abb65a..0bf97e1796 100644 --- a/docs/src/filters/token_router.md +++ b/docs/src/filters/token_router.md @@ -37,7 +37,7 @@ clusters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 1); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); ``` View the [CaptureBytes](./capture.md) filter documentation for more details. @@ -102,7 +102,7 @@ clusters: # "; # let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); # assert_eq!(config.filters.load().len(), 2); -# quilkin::Server::try_from(config).unwrap(); +# quilkin::Proxy::try_from(config).unwrap(); ``` On the game client side the [ConcatenateBytes](./concatenate_bytes.md) filter could also be used to add authentication diff --git a/docs/src/filters/writing_custom_filters.md b/docs/src/filters/writing_custom_filters.md index acf6ff887c..cf0ee44efb 100644 --- a/docs/src/filters/writing_custom_filters.md +++ b/docs/src/filters/writing_custom_filters.md @@ -73,7 +73,7 @@ impl StaticFilter for Greet { ## Running -We can run the proxy using [`Server::TryFrom`][Server::TryFrom] function. Let's +We can run the proxy using [`Proxy::TryFrom`][Proxy::TryFrom] function. Let's add a main function that does that. Quilkin relies on the [Tokio] async runtime, so we need to import that crate and wrap our main function with it. @@ -243,7 +243,7 @@ filter. Try it out with the following configuration: [filter-factory-name]: ../../api/quilkin/filters/trait.FilterFactory.html#tymethod.name [FilterRegistry]: ../../api/quilkin/filters/struct.FilterRegistry.html [FilterRegistry::register]: ../../api/quilkin/filters/struct.FilterRegistry.html#method.register -[Server::try_from]: ../../api/struct.Server.html#impl-TryFrom%3CConfig%3E +[Proxy::try_from]: ../../api/struct.Proxy.html#impl-TryFrom%3CConfig%3E [CreateFilterArgs::config]: ../../api/quilkin/filters/prelude/struct.CreateFilterArgs.html#structfield.config [ConfigType::dynamic]: ../../api/quilkin/config/enum.ConfigType.html#variant.Dynamic [ConfigType::static]: ../../api/quilkin/config/enum.ConfigType.html#variant.Static diff --git a/examples/quilkin-filter-example/src/main.rs b/examples/quilkin-filter-example/src/main.rs index 99b443bf3c..6b18ada422 100644 --- a/examples/quilkin-filter-example/src/main.rs +++ b/examples/quilkin-filter-example/src/main.rs @@ -95,7 +95,7 @@ async fn main() -> quilkin::Result<()> { quilkin::filters::FilterRegistry::register(vec![Greet::factory()].into_iter()); let (_shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(()); - let server: quilkin::Server = quilkin::Config::builder() + let server: quilkin::Proxy = quilkin::Config::builder() .port(7001) .filters(vec![quilkin::config::Filter { name: Greet::NAME.into(), diff --git a/src/cli/run.rs b/src/cli/run.rs index d3d8522f94..6bdefbe459 100644 --- a/src/cli/run.rs +++ b/src/cli/run.rs @@ -31,7 +31,7 @@ impl Run { let span = span!(Level::INFO, "source::run"); let _enter = span.enter(); - let server = crate::Server::try_from(config)?; + let server = crate::Proxy::try_from(config)?; #[cfg(target_os = "linux")] let mut sig_term_fut = signal::unix::signal(signal::unix::SignalKind::terminate())?; diff --git a/src/lib.rs b/src/lib.rs index a2bf7e5557..c024211435 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,7 +37,7 @@ pub mod xds; pub type Result = std::result::Result; #[doc(inline)] -pub use self::{cli::Cli, config::Config, proxy::Server}; +pub use self::{cli::Cli, config::Config, proxy::Proxy}; pub use quilkin_macros::include_proto; diff --git a/src/proxy.rs b/src/proxy.rs index 4f8924b460..f390c2623f 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -14,10 +14,551 @@ * limitations under the License. */ +mod health; +mod sessions; + pub(crate) use health::Health; -pub use server::Server; pub use sessions::SessionKey; -mod health; -mod server; -mod sessions; +use std::{ + net::{Ipv4Addr, SocketAddrV4}, + sync::Arc, +}; + +use prometheus::HistogramTimer; +use tokio::{net::UdpSocket, sync::watch, time::Duration}; + +use crate::{ + endpoint::{Endpoint, EndpointAddress}, + filters::{Filter, ReadContext}, + proxy::sessions::{ + manager::SessionManager, metrics::Metrics as SessionMetrics, Session, SessionArgs, + SESSION_TIMEOUT_SECONDS, + }, + utils::{debug, net}, + xds::ResourceType, + Config, Result, +}; + +/// The UDP proxy service. +pub struct Proxy { + config: Arc, + session_metrics: SessionMetrics, +} + +impl TryFrom for Proxy { + type Error = eyre::Error; + fn try_from(config: Config) -> Result { + Ok(Self { + config: Arc::from(config), + session_metrics: SessionMetrics::new()?, + }) + } +} + +/// Represents arguments to the `Proxy::run_recv_from` method. +struct RunRecvFromArgs { + session_manager: SessionManager, + session_ttl: Duration, + shutdown_rx: watch::Receiver<()>, +} + +/// Packet received from local port +#[derive(Debug)] +struct DownstreamPacket { + source: EndpointAddress, + contents: Vec, + timer: HistogramTimer, +} + +/// Represents the required arguments to run a worker task that +/// processes packets received downstream. +struct DownstreamReceiveWorkerConfig { + /// ID of the worker. + worker_id: usize, + /// Socket with reused port from which the worker receives packets. + socket: Arc, + /// Configuration required to process a received downstream packet. + receive_config: ProcessDownstreamReceiveConfig, + /// The worker task exits when a value is received from this shutdown channel. + shutdown_rx: watch::Receiver<()>, +} + +/// Contains arguments to process a received downstream packet, through the +/// filter chain and session pipeline. +struct ProcessDownstreamReceiveConfig { + config: Arc, + session_metrics: SessionMetrics, + session_manager: SessionManager, + session_ttl: Duration, + socket: Arc, +} + +impl Proxy { + /// Returns a builder for configuring a Quilkin proxy. + pub fn builder() -> crate::config::Builder { + <_>::default() + } + + /// start the async processing of incoming UDP packets. Will block until an + /// event is sent through the stop Receiver. + pub async fn run(self, mut shutdown_rx: watch::Receiver<()>) -> Result<()> { + let proxy = self.config.proxy.load(); + + tracing::info!(port = proxy.port, proxy_id = &*proxy.id, "Starting"); + + let session_manager = SessionManager::new(shutdown_rx.clone()); + let session_ttl = Duration::from_secs(SESSION_TIMEOUT_SECONDS); + + let management_servers = self.config.management_servers.load(); + let _xds_stream = if !management_servers.is_empty() { + let client = crate::xds::Client::connect(self.config.clone()).await?; + let mut stream = client.stream().await?; + + stream.send(ResourceType::Endpoint, &[]).await?; + stream.send(ResourceType::Listener, &[]).await?; + Some(stream) + } else { + None + }; + + if self.config.admin.is_some() { + tokio::spawn(crate::admin::server(self.config.clone())); + } + + self.run_recv_from(RunRecvFromArgs { + session_manager, + session_ttl, + shutdown_rx: shutdown_rx.clone(), + }) + .await?; + tracing::info!("Quilkin is ready"); + + shutdown_rx + .changed() + .await + .map_err(|error| eyre::eyre!(error)) + } + + /// Spawns a background task that sits in a loop, receiving packets from the passed in socket. + /// Each received packet is placed on a queue to be processed by a worker task. + /// This function also spawns the set of worker tasks responsible for consuming packets + /// off the aforementioned queue and processing them through the filter chain and session + /// pipeline. + async fn run_recv_from(&self, args: RunRecvFromArgs) -> Result<()> { + let session_manager = args.session_manager; + let session_metrics = self.session_metrics.clone(); + + // The number of worker tasks to spawn. Each task gets a dedicated queue to + // consume packets off. + let num_workers = num_cpus::get(); + + // Contains config for each worker task. + let mut worker_configs = vec![]; + for worker_id in 0..num_workers { + let socket = Arc::new(self.bind(self.config.proxy.load().port)?); + worker_configs.push(DownstreamReceiveWorkerConfig { + worker_id, + socket: socket.clone(), + shutdown_rx: args.shutdown_rx.clone(), + receive_config: ProcessDownstreamReceiveConfig { + config: self.config.clone(), + session_metrics: session_metrics.clone(), + session_manager: session_manager.clone(), + session_ttl: args.session_ttl, + socket, + }, + }) + } + + // Start the worker tasks that pick up received packets from their queue + // and processes them. + Self::spawn_downstream_receive_workers(worker_configs); + Ok(()) + } + + /// For each worker config provided, spawn a background task that sits in a + /// loop, receiving packets from a socket and processing them through + /// the filter chain. + fn spawn_downstream_receive_workers(worker_configs: Vec) { + for DownstreamReceiveWorkerConfig { + worker_id, + socket, + mut shutdown_rx, + receive_config, + } in worker_configs + { + tokio::spawn(async move { + // Initialize a buffer for the UDP packet. We use the maximum size of a UDP + // packet, which is the maximum value of 16 a bit integer. + let mut buf = vec![0; 1 << 16]; + loop { + tracing::debug!( + id = worker_id, + addr = ?socket.local_addr(), + "Awaiting packet" + ); + tokio::select! { + recv = socket.recv_from(&mut buf) => { + let timer = crate::metrics::PROCESSING_TIME.with_label_values(&[crate::metrics::READ_DIRECTION_LABEL]).start_timer(); + match recv { + Ok((size, source)) => { + let contents = (&buf[..size]).to_vec(); + tracing::trace!(id = worker_id, size = size, source = %source, contents=&*debug::bytes_to_string(&contents), "received packet from downstream"); + let packet = DownstreamPacket { + source: source.into(), + contents, + timer, + }; + Self::process_downstream_received_packet(packet, &receive_config).await + }, + Err(error) => { + tracing::error!(%error); + return; + } + } + } + _ = shutdown_rx.changed() => { + tracing::debug!(id = worker_id, "Received shutdown signal"); + return; + } + } + } + }); + } + } + + /// Processes a packet by running it through the filter chain. + async fn process_downstream_received_packet( + packet: DownstreamPacket, + args: &ProcessDownstreamReceiveConfig, + ) { + let clusters = args.config.clusters.load(); + + tracing::trace!(?clusters, "Clusters available"); + + let endpoints: Vec<_> = clusters.endpoints().collect(); + if endpoints.is_empty() { + tracing::trace!("dropping packet, no upstream endpoints available"); + crate::metrics::PACKETS_DROPPED + .with_label_values(&[crate::metrics::READ_DIRECTION_LABEL, "NoEndpointsAvailable"]) + .inc(); + return; + } + + let result = args.config.filters.load().read(ReadContext::new( + endpoints, + packet.source.clone(), + packet.contents, + )); + + if let Some(response) = result { + for endpoint in response.endpoints.iter() { + Self::session_send_packet( + &response.contents, + packet.source.clone(), + endpoint, + args, + ) + .await; + } + } + + packet.timer.stop_and_record(); + } + + /// Send a packet received from `recv_addr` to an endpoint. + #[tracing::instrument(skip_all, fields(source = %recv_addr, dest = %endpoint.address))] + async fn session_send_packet( + packet: &[u8], + recv_addr: EndpointAddress, + endpoint: &Endpoint, + args: &ProcessDownstreamReceiveConfig, + ) { + let session_key = SessionKey { + source: recv_addr, + dest: endpoint.address.clone(), + }; + + // Grab a read lock and find the session. + let guard = args.session_manager.get_sessions().await; + if let Some(session) = guard.get(&session_key) { + // If it exists then send the packet, we're done. + Self::session_send_packet_helper(session, packet, args.session_ttl).await + } else { + // If it does not exist, grab a write lock so that we can create it. + // + // NOTE: We must drop the lock guard to release the lock before + // trying to acquire a write lock since these lock aren't reentrant, + // otherwise we will deadlock with our self. + drop(guard); + + // Grab a write lock. + let mut guard = args.session_manager.get_sessions_mut().await; + + // Although we have the write lock now, check whether some other thread + // managed to create the session in-between our dropping the read + // lock and grabbing the write lock. + if let Some(session) = guard.get(&session_key) { + tracing::trace!("Reusing previous session"); + // If the session now exists then we have less work to do, + // simply send the packet. + Self::session_send_packet_helper(session, packet, args.session_ttl).await; + } else { + tracing::trace!("Creating new session"); + // Otherwise, create the session and insert into the map. + let session_args = SessionArgs { + config: args.config.clone(), + metrics: args.session_metrics.clone(), + source: session_key.source.clone(), + downstream_socket: args.socket.clone(), + dest: endpoint.clone(), + ttl: args.session_ttl, + }; + match session_args.into_session().await { + Ok(session) => { + // Insert the session into the map and release the write lock + // immediately since we don't want to block other threads while we send + // the packet. Instead, re-acquire a read lock and send the packet. + guard.insert(session.key(), session); + + // Release the write lock. + drop(guard); + + // Grab a read lock to send the packet. + let guard = args.session_manager.get_sessions().await; + if let Some(session) = guard.get(&session_key) { + Self::session_send_packet_helper(session, packet, args.session_ttl) + .await; + } else { + tracing::warn!( + key = %format!("({}:{})", session_key.source, session_key.dest), + "Could not find session" + ) + } + } + Err(error) => { + tracing::error!(%error, "Failed to ensure session exists"); + } + } + } + } + } + + // A helper function to push a session's packet on its socket. + async fn session_send_packet_helper(session: &Session, packet: &[u8], ttl: Duration) { + match session.send(packet).await { + Ok(_) => { + if let Err(error) = session.update_expiration(ttl) { + tracing::warn!(%error, "Error updating session expiration") + } + } + Err(error) => tracing::error!(%error, "Error sending packet from session"), + }; + } + + /// binds the local configured port with port and address reuse applied. + fn bind(&self, port: u16) -> Result { + let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port); + net::socket_with_reuse(addr.into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use tokio::time::{timeout, Duration}; + + use crate::{ + config, + endpoint::Endpoint, + test_utils::{available_addr, create_socket, load_test_filters, TestHelper}, + }; + + #[tokio::test] + async fn run_server() { + let mut t = TestHelper::default(); + + let endpoint1 = t.open_socket_and_recv_single_packet().await; + let endpoint2 = t.open_socket_and_recv_single_packet().await; + + let local_addr = available_addr().await; + let config = Config::builder() + .port(local_addr.port()) + .endpoints(vec![ + Endpoint::new(endpoint1.socket.local_addr().unwrap().into()), + Endpoint::new(endpoint2.socket.local_addr().unwrap().into()), + ]) + .build() + .unwrap(); + t.run_server_with_config(config); + + let msg = "hello"; + endpoint1 + .socket + .send_to(msg.as_bytes(), &local_addr) + .await + .unwrap(); + assert_eq!( + msg, + timeout(Duration::from_secs(1), endpoint1.packet_rx) + .await + .expect("should get a packet") + .unwrap() + ); + assert_eq!( + msg, + timeout(Duration::from_secs(1), endpoint2.packet_rx) + .await + .expect("should get a packet") + .unwrap() + ); + } + + #[tokio::test] + async fn run_client() { + let mut t = TestHelper::default(); + + let endpoint = t.open_socket_and_recv_single_packet().await; + + let local_addr = available_addr().await; + let config = Config::builder() + .port(local_addr.port()) + .endpoints(vec![Endpoint::new( + endpoint.socket.local_addr().unwrap().into(), + )]) + .build() + .unwrap(); + t.run_server_with_config(config); + + let msg = "hello"; + endpoint + .socket + .send_to(msg.as_bytes(), &local_addr) + .await + .unwrap(); + assert_eq!( + msg, + timeout(Duration::from_millis(100), endpoint.packet_rx) + .await + .unwrap() + .unwrap() + ); + } + + #[tokio::test] + async fn run_with_filter() { + let mut t = TestHelper::default(); + + load_test_filters(); + let endpoint = t.open_socket_and_recv_single_packet().await; + let local_addr = available_addr().await; + let config = Config::builder() + .port(local_addr.port()) + .filters(vec![config::Filter { + name: "TestFilter".to_string(), + config: None, + }]) + .endpoints(vec![Endpoint::new( + endpoint.socket.local_addr().unwrap().into(), + )]) + .build() + .unwrap(); + t.run_server_with_config(config); + + let msg = "hello"; + endpoint + .socket + .send_to(msg.as_bytes(), &local_addr) + .await + .unwrap(); + + // search for the filter strings. + let result = timeout(Duration::from_millis(100), endpoint.packet_rx) + .await + .unwrap() + .unwrap(); + assert!(result.contains(msg), "'{}' not found in '{}'", msg, result); + assert!(result.contains(":odr:"), ":odr: not found in '{}'", result); + } + + #[tokio::test] + async fn spawn_downstream_receive_workers() { + let t = TestHelper::default(); + + let socket = Arc::new(create_socket().await); + let addr = socket.local_addr().unwrap(); + let (_shutdown_tx, shutdown_rx) = watch::channel(()); + let endpoint = t.open_socket_and_recv_single_packet().await; + let msg = "hello"; + + // we'll test a single DownstreamReceiveWorkerConfig + let config = DownstreamReceiveWorkerConfig { + worker_id: 1, + socket: socket.clone(), + receive_config: ProcessDownstreamReceiveConfig { + config: Arc::new( + Config::builder() + .endpoints(&[endpoint.socket.local_addr().unwrap().into()][..]) + .build() + .unwrap(), + ), + session_metrics: SessionMetrics::new().unwrap(), + session_manager: SessionManager::new(shutdown_rx.clone()), + session_ttl: Duration::from_secs(10), + socket, + }, + shutdown_rx, + }; + + Proxy::spawn_downstream_receive_workers(vec![config]); + + let socket = create_socket().await; + socket.send_to(msg.as_bytes(), &addr).await.unwrap(); + + assert_eq!( + msg, + timeout(Duration::from_secs(1), endpoint.packet_rx) + .await + .expect("should receive a packet") + .unwrap() + ); + } + + #[tokio::test] + async fn run_recv_from() { + let t = TestHelper::default(); + let (_shutdown_tx, shutdown_rx) = watch::channel(()); + + let msg = "hello"; + let endpoint = t.open_socket_and_recv_single_packet().await; + let session_manager = SessionManager::new(shutdown_rx.clone()); + let local_addr = available_addr().await; + let config = Config::builder() + .port(local_addr.port()) + .endpoints(&[Endpoint::from(endpoint.socket.local_addr().unwrap())][..]) + .build() + .unwrap(); + let server = Proxy::try_from(config).unwrap(); + + server + .run_recv_from(RunRecvFromArgs { + session_manager: session_manager.clone(), + session_ttl: Duration::from_secs(10), + shutdown_rx, + }) + .await + .unwrap(); + + let socket = create_socket().await; + socket.send_to(msg.as_bytes(), &local_addr).await.unwrap(); + assert_eq!( + msg, + timeout(Duration::from_secs(1), endpoint.packet_rx) + .await + .expect("should receive a packet") + .unwrap() + ); + } +} diff --git a/src/proxy/server.rs b/src/proxy/server.rs deleted file mode 100644 index 703425abad..0000000000 --- a/src/proxy/server.rs +++ /dev/null @@ -1,558 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use std::{ - net::{Ipv4Addr, SocketAddrV4}, - sync::Arc, -}; - -use prometheus::HistogramTimer; -use tokio::{net::UdpSocket, sync::watch, time::Duration}; - -use crate::{ - endpoint::{Endpoint, EndpointAddress}, - filters::{Filter, ReadContext}, - proxy::sessions::{ - metrics::Metrics as SessionMetrics, session_manager::SessionManager, Session, SessionArgs, - SessionKey, SESSION_TIMEOUT_SECONDS, - }, - utils::{debug, net}, - xds::ResourceType, - Config, Result, -}; - -/// Server is the UDP server main implementation -pub struct Server { - config: Arc, - session_metrics: SessionMetrics, -} - -impl TryFrom for Server { - type Error = eyre::Error; - fn try_from(config: Config) -> Result { - Ok(Self { - config: Arc::from(config), - session_metrics: SessionMetrics::new()?, - }) - } -} - -/// Represents arguments to the `Server::run_recv_from` method. -struct RunRecvFromArgs { - session_manager: SessionManager, - session_ttl: Duration, - shutdown_rx: watch::Receiver<()>, -} - -/// Packet received from local port -#[derive(Debug)] -struct DownstreamPacket { - source: EndpointAddress, - contents: Vec, - timer: HistogramTimer, -} - -/// Represents the required arguments to run a worker task that -/// processes packets received downstream. -struct DownstreamReceiveWorkerConfig { - /// ID of the worker. - worker_id: usize, - /// Socket with reused port from which the worker receives packets. - socket: Arc, - /// Configuration required to process a received downstream packet. - receive_config: ProcessDownstreamReceiveConfig, - /// The worker task exits when a value is received from this shutdown channel. - shutdown_rx: watch::Receiver<()>, -} - -/// Contains arguments to process a received downstream packet, through the -/// filter chain and session pipeline. -struct ProcessDownstreamReceiveConfig { - config: Arc, - session_metrics: SessionMetrics, - session_manager: SessionManager, - session_ttl: Duration, - socket: Arc, -} - -impl Server { - /// Returns a builder for configuring a Quilkin Server. - pub fn builder() -> crate::config::Builder { - <_>::default() - } - - /// start the async processing of incoming UDP packets. Will block until an - /// event is sent through the stop Receiver. - pub async fn run(self, mut shutdown_rx: watch::Receiver<()>) -> Result<()> { - let proxy = self.config.proxy.load(); - - tracing::info!(port = proxy.port, proxy_id = &*proxy.id, "Starting"); - - let session_manager = SessionManager::new(shutdown_rx.clone()); - let session_ttl = Duration::from_secs(SESSION_TIMEOUT_SECONDS); - - let management_servers = self.config.management_servers.load(); - let _xds_stream = if !management_servers.is_empty() { - let client = crate::xds::Client::connect(self.config.clone()).await?; - let mut stream = client.stream().await?; - - stream.send(ResourceType::Endpoint, &[]).await?; - stream.send(ResourceType::Listener, &[]).await?; - Some(stream) - } else { - None - }; - - if self.config.admin.is_some() { - tokio::spawn(crate::admin::server(self.config.clone())); - } - - self.run_recv_from(RunRecvFromArgs { - session_manager, - session_ttl, - shutdown_rx: shutdown_rx.clone(), - }) - .await?; - tracing::info!("Quilkin is ready"); - - shutdown_rx - .changed() - .await - .map_err(|error| eyre::eyre!(error)) - } - - /// Spawns a background task that sits in a loop, receiving packets from the passed in socket. - /// Each received packet is placed on a queue to be processed by a worker task. - /// This function also spawns the set of worker tasks responsible for consuming packets - /// off the aforementioned queue and processing them through the filter chain and session - /// pipeline. - async fn run_recv_from(&self, args: RunRecvFromArgs) -> Result<()> { - let session_manager = args.session_manager; - let session_metrics = self.session_metrics.clone(); - - // The number of worker tasks to spawn. Each task gets a dedicated queue to - // consume packets off. - let num_workers = num_cpus::get(); - - // Contains config for each worker task. - let mut worker_configs = vec![]; - for worker_id in 0..num_workers { - let socket = Arc::new(self.bind(self.config.proxy.load().port)?); - worker_configs.push(DownstreamReceiveWorkerConfig { - worker_id, - socket: socket.clone(), - shutdown_rx: args.shutdown_rx.clone(), - receive_config: ProcessDownstreamReceiveConfig { - config: self.config.clone(), - session_metrics: session_metrics.clone(), - session_manager: session_manager.clone(), - session_ttl: args.session_ttl, - socket, - }, - }) - } - - // Start the worker tasks that pick up received packets from their queue - // and processes them. - Self::spawn_downstream_receive_workers(worker_configs); - Ok(()) - } - - /// For each worker config provided, spawn a background task that sits in a - /// loop, receiving packets from a socket and processing them through - /// the filter chain. - fn spawn_downstream_receive_workers(worker_configs: Vec) { - for DownstreamReceiveWorkerConfig { - worker_id, - socket, - mut shutdown_rx, - receive_config, - } in worker_configs - { - tokio::spawn(async move { - // Initialize a buffer for the UDP packet. We use the maximum size of a UDP - // packet, which is the maximum value of 16 a bit integer. - let mut buf = vec![0; 1 << 16]; - loop { - tracing::debug!( - id = worker_id, - addr = ?socket.local_addr(), - "Awaiting packet" - ); - tokio::select! { - recv = socket.recv_from(&mut buf) => { - let timer = crate::metrics::PROCESSING_TIME.with_label_values(&[crate::metrics::READ_DIRECTION_LABEL]).start_timer(); - match recv { - Ok((size, source)) => { - let contents = (&buf[..size]).to_vec(); - tracing::trace!(id = worker_id, size = size, source = %source, contents=&*debug::bytes_to_string(&contents), "received packet from downstream"); - let packet = DownstreamPacket { - source: source.into(), - contents, - timer, - }; - Self::process_downstream_received_packet(packet, &receive_config).await - }, - Err(error) => { - tracing::error!(%error); - return; - } - } - } - _ = shutdown_rx.changed() => { - tracing::debug!(id = worker_id, "Received shutdown signal"); - return; - } - } - } - }); - } - } - - /// Processes a packet by running it through the filter chain. - async fn process_downstream_received_packet( - packet: DownstreamPacket, - args: &ProcessDownstreamReceiveConfig, - ) { - let clusters = args.config.clusters.load(); - - tracing::trace!(?clusters, "Clusters available"); - - let endpoints: Vec<_> = clusters.endpoints().collect(); - if endpoints.is_empty() { - tracing::trace!("dropping packet, no upstream endpoints available"); - crate::metrics::PACKETS_DROPPED - .with_label_values(&[crate::metrics::READ_DIRECTION_LABEL, "NoEndpointsAvailable"]) - .inc(); - return; - } - - let result = args.config.filters.load().read(ReadContext::new( - endpoints, - packet.source.clone(), - packet.contents, - )); - - if let Some(response) = result { - for endpoint in response.endpoints.iter() { - Self::session_send_packet( - &response.contents, - packet.source.clone(), - endpoint, - args, - ) - .await; - } - } - - packet.timer.stop_and_record(); - } - - /// Send a packet received from `recv_addr` to an endpoint. - #[tracing::instrument(skip_all, fields(source = %recv_addr, dest = %endpoint.address))] - async fn session_send_packet( - packet: &[u8], - recv_addr: EndpointAddress, - endpoint: &Endpoint, - args: &ProcessDownstreamReceiveConfig, - ) { - let session_key = SessionKey { - source: recv_addr, - dest: endpoint.address.clone(), - }; - - // Grab a read lock and find the session. - let guard = args.session_manager.get_sessions().await; - if let Some(session) = guard.get(&session_key) { - // If it exists then send the packet, we're done. - Self::session_send_packet_helper(session, packet, args.session_ttl).await - } else { - // If it does not exist, grab a write lock so that we can create it. - // - // NOTE: We must drop the lock guard to release the lock before - // trying to acquire a write lock since these lock aren't reentrant, - // otherwise we will deadlock with our self. - drop(guard); - - // Grab a write lock. - let mut guard = args.session_manager.get_sessions_mut().await; - - // Although we have the write lock now, check whether some other thread - // managed to create the session in-between our dropping the read - // lock and grabbing the write lock. - if let Some(session) = guard.get(&session_key) { - tracing::trace!("Reusing previous session"); - // If the session now exists then we have less work to do, - // simply send the packet. - Self::session_send_packet_helper(session, packet, args.session_ttl).await; - } else { - tracing::trace!("Creating new session"); - // Otherwise, create the session and insert into the map. - let session_args = SessionArgs { - config: args.config.clone(), - metrics: args.session_metrics.clone(), - source: session_key.source.clone(), - downstream_socket: args.socket.clone(), - dest: endpoint.clone(), - ttl: args.session_ttl, - }; - match session_args.into_session().await { - Ok(session) => { - // Insert the session into the map and release the write lock - // immediately since we don't want to block other threads while we send - // the packet. Instead, re-acquire a read lock and send the packet. - guard.insert(session.key(), session); - - // Release the write lock. - drop(guard); - - // Grab a read lock to send the packet. - let guard = args.session_manager.get_sessions().await; - if let Some(session) = guard.get(&session_key) { - Self::session_send_packet_helper(session, packet, args.session_ttl) - .await; - } else { - tracing::warn!( - key = %format!("({}:{})", session_key.source, session_key.dest), - "Could not find session" - ) - } - } - Err(error) => { - tracing::error!(%error, "Failed to ensure session exists"); - } - } - } - } - } - - // A helper function to push a session's packet on its socket. - async fn session_send_packet_helper(session: &Session, packet: &[u8], ttl: Duration) { - match session.send(packet).await { - Ok(_) => { - if let Err(error) = session.update_expiration(ttl) { - tracing::warn!(%error, "Error updating session expiration") - } - } - Err(error) => tracing::error!(%error, "Error sending packet from session"), - }; - } - - /// binds the local configured port with port and address reuse applied. - fn bind(&self, port: u16) -> Result { - let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port); - net::socket_with_reuse(addr.into()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use tokio::time::{timeout, Duration}; - - use crate::{ - config, - endpoint::Endpoint, - test_utils::{available_addr, create_socket, load_test_filters, TestHelper}, - }; - - #[tokio::test] - async fn run_server() { - let mut t = TestHelper::default(); - - let endpoint1 = t.open_socket_and_recv_single_packet().await; - let endpoint2 = t.open_socket_and_recv_single_packet().await; - - let local_addr = available_addr().await; - let config = Server::builder() - .port(local_addr.port()) - .endpoints(vec![ - Endpoint::new(endpoint1.socket.local_addr().unwrap().into()), - Endpoint::new(endpoint2.socket.local_addr().unwrap().into()), - ]) - .build() - .unwrap(); - t.run_server_with_config(config); - - let msg = "hello"; - endpoint1 - .socket - .send_to(msg.as_bytes(), &local_addr) - .await - .unwrap(); - assert_eq!( - msg, - timeout(Duration::from_secs(1), endpoint1.packet_rx) - .await - .expect("should get a packet") - .unwrap() - ); - assert_eq!( - msg, - timeout(Duration::from_secs(1), endpoint2.packet_rx) - .await - .expect("should get a packet") - .unwrap() - ); - } - - #[tokio::test] - async fn run_client() { - let mut t = TestHelper::default(); - - let endpoint = t.open_socket_and_recv_single_packet().await; - - let local_addr = available_addr().await; - let config = Server::builder() - .port(local_addr.port()) - .endpoints(vec![Endpoint::new( - endpoint.socket.local_addr().unwrap().into(), - )]) - .build() - .unwrap(); - t.run_server_with_config(config); - - let msg = "hello"; - endpoint - .socket - .send_to(msg.as_bytes(), &local_addr) - .await - .unwrap(); - assert_eq!( - msg, - timeout(Duration::from_millis(100), endpoint.packet_rx) - .await - .unwrap() - .unwrap() - ); - } - - #[tokio::test] - async fn run_with_filter() { - let mut t = TestHelper::default(); - - load_test_filters(); - let endpoint = t.open_socket_and_recv_single_packet().await; - let local_addr = available_addr().await; - let config = Server::builder() - .port(local_addr.port()) - .filters(vec![config::Filter { - name: "TestFilter".to_string(), - config: None, - }]) - .endpoints(vec![Endpoint::new( - endpoint.socket.local_addr().unwrap().into(), - )]) - .build() - .unwrap(); - t.run_server_with_config(config); - - let msg = "hello"; - endpoint - .socket - .send_to(msg.as_bytes(), &local_addr) - .await - .unwrap(); - - // search for the filter strings. - let result = timeout(Duration::from_millis(100), endpoint.packet_rx) - .await - .unwrap() - .unwrap(); - assert!(result.contains(msg), "'{}' not found in '{}'", msg, result); - assert!(result.contains(":odr:"), ":odr: not found in '{}'", result); - } - - #[tokio::test] - async fn spawn_downstream_receive_workers() { - let t = TestHelper::default(); - - let socket = Arc::new(create_socket().await); - let addr = socket.local_addr().unwrap(); - let (_shutdown_tx, shutdown_rx) = watch::channel(()); - let endpoint = t.open_socket_and_recv_single_packet().await; - let msg = "hello"; - - // we'll test a single DownstreamReceiveWorkerConfig - let config = DownstreamReceiveWorkerConfig { - worker_id: 1, - socket: socket.clone(), - receive_config: ProcessDownstreamReceiveConfig { - config: Arc::new( - Config::builder() - .endpoints(&[endpoint.socket.local_addr().unwrap().into()][..]) - .build() - .unwrap(), - ), - session_metrics: SessionMetrics::new().unwrap(), - session_manager: SessionManager::new(shutdown_rx.clone()), - session_ttl: Duration::from_secs(10), - socket, - }, - shutdown_rx, - }; - - Server::spawn_downstream_receive_workers(vec![config]); - - let socket = create_socket().await; - socket.send_to(msg.as_bytes(), &addr).await.unwrap(); - - assert_eq!( - msg, - timeout(Duration::from_secs(1), endpoint.packet_rx) - .await - .expect("should receive a packet") - .unwrap() - ); - } - - #[tokio::test] - async fn run_recv_from() { - let t = TestHelper::default(); - let (_shutdown_tx, shutdown_rx) = watch::channel(()); - - let msg = "hello"; - let endpoint = t.open_socket_and_recv_single_packet().await; - let session_manager = SessionManager::new(shutdown_rx.clone()); - let local_addr = available_addr().await; - let config = Config::builder() - .port(local_addr.port()) - .endpoints(&[Endpoint::from(endpoint.socket.local_addr().unwrap())][..]) - .build() - .unwrap(); - let server = Server::try_from(config).unwrap(); - - server - .run_recv_from(RunRecvFromArgs { - session_manager: session_manager.clone(), - session_ttl: Duration::from_secs(10), - shutdown_rx, - }) - .await - .unwrap(); - - let socket = create_socket().await; - socket.send_to(msg.as_bytes(), &local_addr).await.unwrap(); - assert_eq!( - msg, - timeout(Duration::from_secs(1), endpoint.packet_rx) - .await - .expect("should receive a packet") - .unwrap() - ); - } -} diff --git a/src/proxy/sessions.rs b/src/proxy/sessions.rs index f9ae6418c7..7d6e8c7877 100644 --- a/src/proxy/sessions.rs +++ b/src/proxy/sessions.rs @@ -14,10 +14,504 @@ * limitations under the License. */ -pub use session::{Session, SessionArgs, SessionKey}; -pub use session_manager::SESSION_TIMEOUT_SECONDS; - pub(crate) mod error; +pub(crate) mod manager; pub(crate) mod metrics; -mod session; -pub(crate) mod session_manager; + +pub use manager::SESSION_TIMEOUT_SECONDS; + +use std::{ + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::{SystemTime, UNIX_EPOCH}, +}; + +use prometheus::HistogramTimer; +use tokio::{ + net::UdpSocket, + select, + sync::watch, + time::{Duration, Instant}, +}; + +use crate::{ + endpoint::{Endpoint, EndpointAddress}, + filters::{Filter, WriteContext}, + proxy::sessions::{error::Error, metrics::Metrics}, + utils::debug, +}; + +type Result = std::result::Result; + +/// Session encapsulates a UDP stream session +pub struct Session { + config: Arc, + metrics: Metrics, + /// created_at is time at which the session was created + created_at: Instant, + /// socket that sends and receives from and to the endpoint address + upstream_socket: Arc, + /// dest is where to send data to + dest: Endpoint, + /// address of original sender + source: EndpointAddress, + /// The time at which the session is considered expired and can be removed. + expiration: Arc, + /// a channel to broadcast on if we are shutting down this Session + shutdown_tx: watch::Sender<()>, +} + +// A (source, destination) address pair that uniquely identifies a session. +#[derive(Clone, Eq, Hash, PartialEq, Debug, PartialOrd, Ord)] +pub struct SessionKey { + pub source: EndpointAddress, + pub dest: EndpointAddress, +} + +impl From<(EndpointAddress, EndpointAddress)> for SessionKey { + fn from((source, dest): (EndpointAddress, EndpointAddress)) -> Self { + SessionKey { source, dest } + } +} + +/// ReceivedPacketContext contains state needed to process a received packet. +struct ReceivedPacketContext<'a> { + packet: &'a [u8], + config: Arc, + endpoint: &'a Endpoint, + source: EndpointAddress, + dest: EndpointAddress, + timer: HistogramTimer, +} + +pub struct SessionArgs { + pub config: Arc, + pub metrics: Metrics, + pub source: EndpointAddress, + pub downstream_socket: Arc, + pub dest: Endpoint, + pub ttl: Duration, +} + +impl SessionArgs { + /// Creates a new Session, and starts the process of receiving udp sockets + /// from its ephemeral port from endpoint(s) + pub async fn into_session(self) -> Result { + Session::new(self).await + } +} + +impl Session { + /// internal constructor for a Session from SessionArgs + #[tracing::instrument(skip_all)] + async fn new(args: SessionArgs) -> Result { + let addr = (std::net::Ipv4Addr::UNSPECIFIED, 0); + let upstream_socket = Arc::new(UdpSocket::bind(addr).await.map_err(Error::BindUdpSocket)?); + upstream_socket + .connect( + args.dest + .address + .to_socket_addr() + .map_err(Error::ToSocketAddr)?, + ) + .await + .map_err(Error::BindUdpSocket)?; + let (shutdown_tx, shutdown_rx) = watch::channel::<()>(()); + + let expiration = Arc::new(AtomicU64::new(0)); + Self::do_update_expiration(&expiration, args.ttl)?; + + let s = Session { + metrics: args.metrics, + config: args.config.clone(), + upstream_socket, + source: args.source, + dest: args.dest, + created_at: Instant::now(), + expiration, + shutdown_tx, + }; + tracing::debug!(source = %s.source, dest = ?s.dest, "Session created"); + + s.metrics.sessions_total.inc(); + s.metrics.active_sessions.inc(); + s.run(args.ttl, args.downstream_socket, shutdown_rx); + Ok(s) + } + + /// run starts processing receiving upstream udp packets + /// and sending them back downstream + fn run( + &self, + ttl: Duration, + downstream_socket: Arc, + mut shutdown_rx: watch::Receiver<()>, + ) { + let source = self.source.clone(); + let expiration = self.expiration.clone(); + let config = self.config.clone(); + let endpoint = self.dest.clone(); + let metrics = self.metrics.clone(); + let upstream_socket = self.upstream_socket.clone(); + + tokio::spawn(async move { + let mut buf: Vec = vec![0; 65535]; + loop { + tracing::debug!(source = %source, dest = ?endpoint, "Awaiting incoming packet"); + + select! { + received = upstream_socket.recv_from(&mut buf) => { + match received { + Err(error) => { + metrics.rx_errors_total.inc(); + tracing::error!(%error, %source, dest = ?endpoint, "Error receiving packet"); + }, + Ok((size, recv_addr)) => { + crate::metrics::PACKETS_SIZE.with_label_values(&[crate::metrics::WRITE_DIRECTION_LABEL]).inc_by(size as f64); + crate::metrics::PACKETS_TOTAL.with_label_values(&[crate::metrics::WRITE_DIRECTION_LABEL]).inc(); + Session::process_recv_packet( + &metrics, + &downstream_socket, + &expiration, + ttl, + ReceivedPacketContext { + config: config.clone(), + packet: &buf[..size], + endpoint: &endpoint, + source: recv_addr.into(), + dest: source.clone(), + timer: crate::metrics::PROCESSING_TIME.with_label_values(&[crate::metrics::WRITE_DIRECTION_LABEL]).start_timer(), + }).await + } + }; + } + _ = shutdown_rx.changed() => { + tracing::debug!(%source, dest = ?endpoint, "Closing Session"); + return; + } + }; + } + }); + } + + /// expiration returns the current expiration Instant value + pub fn expiration(&self) -> u64 { + self.expiration.load(Ordering::Relaxed) + } + + /// key returns the key to be used for this session in a SessionMap + pub fn key(&self) -> SessionKey { + SessionKey { + source: self.source.clone(), + dest: self.dest.address.clone(), + } + } + + /// process_recv_packet processes a packet that is received by this session. + async fn process_recv_packet( + metrics: &Metrics, + downstream_socket: &Arc, + expiration: &Arc, + ttl: Duration, + packet_ctx: ReceivedPacketContext<'_>, + ) { + let ReceivedPacketContext { + packet, + config, + endpoint, + source: from, + dest, + timer, + } = packet_ctx; + + tracing::trace!(%from, dest = %endpoint.address, contents = %debug::bytes_to_string(packet), "received packet from upstream"); + + if let Err(error) = Session::do_update_expiration(expiration, ttl) { + tracing::warn!(%error, "Error updating session expiration") + } + + match config.filters.load().write(WriteContext::new( + endpoint, + from.clone(), + dest.clone(), + packet.to_vec(), + )) { + None => metrics.packets_dropped_total.inc(), + Some(response) => { + let addr = match dest.to_socket_addr() { + Ok(addr) => addr, + Err(error) => { + tracing::error!(%dest, %error, "Error resolving address"); + metrics.packets_dropped_total.inc(); + return; + } + }; + + let packet = response.contents.as_slice(); + tracing::trace!(%from, dest = %addr, contents = %debug::bytes_to_string(packet), "sending packet downstream"); + if let Err(error) = downstream_socket.send_to(packet, addr).await { + metrics.rx_errors_total.inc(); + tracing::error!(%error, "Error sending packet"); + } + } + } + + timer.stop_and_record(); + } + + /// update_expiration set the increments the expiration value by the session timeout + pub fn update_expiration(&self, ttl: Duration) -> Result<()> { + Self::do_update_expiration(&self.expiration, ttl) + } + + /// do_update_expiration increments the expiration value by the session timeout (internal) + fn do_update_expiration(expiration: &Arc, ttl: Duration) -> Result<()> { + let new_expiration_time = SystemTime::now() + .checked_add(ttl) + .ok_or_else(|| { + Error::UpdateSessionExpiration(format!( + "checked_add error: expiration ttl {:?} is out of bounds", + ttl + )) + })? + .duration_since(UNIX_EPOCH) + .map_err(|_| { + Error::UpdateSessionExpiration( + "duration_since was called with time later than the current time".into(), + ) + })? + .as_secs(); + + expiration.store(new_expiration_time, Ordering::Relaxed); + + Ok(()) + } + + /// Sends a packet to the Session's dest. + pub async fn send(&self, buf: &[u8]) -> crate::Result> { + tracing::trace!( + dest_address = %self.dest.address, + contents = %debug::bytes_to_string(buf), + "sending packet upstream"); + + self.do_send(buf) + .await + .map(|size| { + self.metrics.tx_packets_total.inc(); + self.metrics.tx_bytes_total.inc_by(size as u64); + Some(size) + }) + .map_err(|err| { + self.metrics.tx_errors_total.inc(); + eyre::eyre!(err).wrap_err("Error sending to destination.") + }) + } + + /// Sends `buf` to the session's destination address. On success, returns + /// the number of bytes written. + pub async fn do_send(&self, buf: &[u8]) -> crate::Result { + self.upstream_socket + .send(buf) + .await + .map_err(|error| eyre::eyre!(error)) + } +} + +impl Drop for Session { + fn drop(&mut self) { + self.metrics.active_sessions.dec(); + self.metrics + .duration_secs + .observe(self.created_at.elapsed().as_secs() as f64); + + if let Err(error) = self.shutdown_tx.send(()) { + tracing::warn!(%error, "Error sending session shutdown signal"); + } + + tracing::debug!(source = %self.source, dest_address = %self.dest.address, "Session closed"); + } +} + +#[cfg(test)] +mod tests { + use std::{ + str::from_utf8, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::{Duration, SystemTime, UNIX_EPOCH}, + }; + + use super::{Metrics, Session}; + + use prometheus::{Histogram, HistogramOpts}; + use tokio::time::timeout; + + use crate::{ + endpoint::{Endpoint, EndpointAddress}, + proxy::sessions::{ReceivedPacketContext, SessionArgs}, + test_utils::{create_socket, new_test_config, TestHelper}, + }; + + #[tokio::test] + async fn session_send_and_receive() { + let mut t = TestHelper::default(); + let addr = t.run_echo_server().await; + let endpoint = Endpoint::new(addr.clone()); + let socket = Arc::new(create_socket().await); + let msg = "hello"; + + let sess = Session::new(SessionArgs { + config: <_>::default(), + metrics: Metrics::new().unwrap(), + source: addr.clone(), + downstream_socket: socket.clone(), + dest: endpoint, + ttl: Duration::from_secs(20), + }) + .await + .unwrap(); + + let initial_expiration_secs = sess.expiration.load(Ordering::Relaxed); + let now_secs = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let diff = initial_expiration_secs - now_secs; + assert!((15..21).contains(&diff)); + + sess.send(msg.as_bytes()).await.unwrap(); + + let mut buf = vec![0; 1024]; + let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) + .await + .unwrap() + .unwrap(); + let packet = &buf[..size]; + assert_eq!(msg, from_utf8(packet).unwrap()); + assert_eq!(addr.port(), recv_addr.port()); + } + + #[tokio::test] + async fn process_recv_packet() { + crate::test_utils::load_test_filters(); + let histogram = Histogram::with_opts(HistogramOpts::new("test", "test")).unwrap(); + + let socket = Arc::new(create_socket().await); + let endpoint = Endpoint::new("127.0.1.1:80".parse().unwrap()); + let dest: EndpointAddress = socket.local_addr().unwrap().into(); + let expiration = Arc::new(AtomicU64::new( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(), + )); + let initial_expiration = expiration.load(Ordering::Relaxed); + + // first test with no filtering + let msg = "hello"; + Session::process_recv_packet( + &Metrics::new().unwrap(), + &socket, + &expiration, + Duration::from_secs(10), + ReceivedPacketContext { + config: <_>::default(), + packet: msg.as_bytes(), + endpoint: &endpoint, + source: endpoint.address.clone(), + dest: dest.clone(), + timer: histogram.start_timer(), + }, + ) + .await; + + assert!(initial_expiration < expiration.load(Ordering::Relaxed)); + + let mut buf = vec![0; 1024]; + let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) + .await + .expect("Should receive a packet") + .unwrap(); + assert_eq!(msg, from_utf8(&buf[..size]).unwrap()); + assert_eq!(dest.port(), recv_addr.port()); + + let expiration = Arc::new(AtomicU64::new( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(), + )); + let initial_expiration = expiration.load(Ordering::Relaxed); + // add filter + let config = Arc::new(new_test_config()); + Session::process_recv_packet( + &Metrics::new().unwrap(), + &socket, + &expiration, + Duration::from_secs(10), + ReceivedPacketContext { + config, + packet: msg.as_bytes(), + endpoint: &endpoint, + source: endpoint.address.clone(), + dest: dest.clone(), + timer: histogram.start_timer(), + }, + ) + .await; + + assert!(initial_expiration < expiration.load(Ordering::Relaxed)); + let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) + .await + .expect("Should receive a packet") + .unwrap(); + assert_eq!( + format!("{}:our:{}:{}", msg, endpoint.address, dest), + from_utf8(&buf[..size]).unwrap() + ); + assert_eq!(dest.port(), recv_addr.port()); + } + + #[tokio::test] + async fn metrics() { + let t = TestHelper::default(); + let ep = t.open_socket_and_recv_single_packet().await; + let addr: EndpointAddress = ep.socket.local_addr().unwrap().into(); + let endpoint = Endpoint::new(addr.clone()); + let socket = Arc::new(create_socket().await); + + let session = Session::new(SessionArgs { + config: <_>::default(), + metrics: Metrics::new().unwrap(), + source: addr, + downstream_socket: socket, + dest: endpoint, + ttl: Duration::from_secs(10), + }) + .await + .unwrap(); + + assert_eq!(session.metrics.sessions_total.get(), 1); + assert_eq!(session.metrics.active_sessions.get(), 1); + + // send a packet + session.send(b"hello").await.unwrap(); + timeout(Duration::from_secs(1), ep.packet_rx) + .await + .expect("should receive a packet") + .unwrap(); + + assert_eq!(session.metrics.tx_bytes_total.get(), 5); + assert_eq!(session.metrics.tx_packets_total.get(), 1); + + // drop metrics + let metrics = session.metrics.clone(); + drop(session); + assert_eq!(metrics.sessions_total.get(), 1); + assert_eq!(metrics.active_sessions.get(), 0); + } +} diff --git a/src/proxy/sessions/session_manager.rs b/src/proxy/sessions/manager.rs similarity index 98% rename from src/proxy/sessions/session_manager.rs rename to src/proxy/sessions/manager.rs index 0d5033833f..bb1f5386a7 100644 --- a/src/proxy/sessions/session_manager.rs +++ b/src/proxy/sessions/manager.rs @@ -116,9 +116,7 @@ mod tests { use crate::{ endpoint::{Endpoint, EndpointAddress}, - proxy::sessions::{ - metrics::Metrics, session::SessionArgs, session_manager::Sessions, SessionKey, - }, + proxy::sessions::{manager::Sessions, metrics::Metrics, SessionArgs, SessionKey}, test_utils::create_socket, }; diff --git a/src/proxy/sessions/session.rs b/src/proxy/sessions/session.rs deleted file mode 100644 index 45e81a2daf..0000000000 --- a/src/proxy/sessions/session.rs +++ /dev/null @@ -1,511 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use std::{ - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, - }, - time::{SystemTime, UNIX_EPOCH}, -}; - -use prometheus::HistogramTimer; -use tokio::{ - net::UdpSocket, - select, - sync::watch, - time::{Duration, Instant}, -}; - -use crate::{ - endpoint::{Endpoint, EndpointAddress}, - filters::{Filter, WriteContext}, - proxy::sessions::{error::Error, metrics::Metrics}, - utils::debug, -}; - -type Result = std::result::Result; - -/// Session encapsulates a UDP stream session -pub struct Session { - config: Arc, - metrics: Metrics, - /// created_at is time at which the session was created - created_at: Instant, - /// socket that sends and receives from and to the endpoint address - upstream_socket: Arc, - /// dest is where to send data to - dest: Endpoint, - /// address of original sender - source: EndpointAddress, - /// The time at which the session is considered expired and can be removed. - expiration: Arc, - /// a channel to broadcast on if we are shutting down this Session - shutdown_tx: watch::Sender<()>, -} - -// A (source, destination) address pair that uniquely identifies a session. -#[derive(Clone, Eq, Hash, PartialEq, Debug, PartialOrd, Ord)] -pub struct SessionKey { - pub source: EndpointAddress, - pub dest: EndpointAddress, -} - -impl From<(EndpointAddress, EndpointAddress)> for SessionKey { - fn from((source, dest): (EndpointAddress, EndpointAddress)) -> Self { - SessionKey { source, dest } - } -} - -/// ReceivedPacketContext contains state needed to process a received packet. -struct ReceivedPacketContext<'a> { - packet: &'a [u8], - config: Arc, - endpoint: &'a Endpoint, - source: EndpointAddress, - dest: EndpointAddress, - timer: HistogramTimer, -} - -pub struct SessionArgs { - pub config: Arc, - pub metrics: Metrics, - pub source: EndpointAddress, - pub downstream_socket: Arc, - pub dest: Endpoint, - pub ttl: Duration, -} - -impl SessionArgs { - /// Creates a new Session, and starts the process of receiving udp sockets - /// from its ephemeral port from endpoint(s) - pub async fn into_session(self) -> Result { - Session::new(self).await - } -} - -impl Session { - /// internal constructor for a Session from SessionArgs - #[tracing::instrument(skip_all)] - async fn new(args: SessionArgs) -> Result { - let addr = (std::net::Ipv4Addr::UNSPECIFIED, 0); - let upstream_socket = Arc::new(UdpSocket::bind(addr).await.map_err(Error::BindUdpSocket)?); - upstream_socket - .connect( - args.dest - .address - .to_socket_addr() - .map_err(Error::ToSocketAddr)?, - ) - .await - .map_err(Error::BindUdpSocket)?; - let (shutdown_tx, shutdown_rx) = watch::channel::<()>(()); - - let expiration = Arc::new(AtomicU64::new(0)); - Self::do_update_expiration(&expiration, args.ttl)?; - - let s = Session { - metrics: args.metrics, - config: args.config.clone(), - upstream_socket, - source: args.source, - dest: args.dest, - created_at: Instant::now(), - expiration, - shutdown_tx, - }; - tracing::debug!(source = %s.source, dest = ?s.dest, "Session created"); - - s.metrics.sessions_total.inc(); - s.metrics.active_sessions.inc(); - s.run(args.ttl, args.downstream_socket, shutdown_rx); - Ok(s) - } - - /// run starts processing receiving upstream udp packets - /// and sending them back downstream - fn run( - &self, - ttl: Duration, - downstream_socket: Arc, - mut shutdown_rx: watch::Receiver<()>, - ) { - let source = self.source.clone(); - let expiration = self.expiration.clone(); - let config = self.config.clone(); - let endpoint = self.dest.clone(); - let metrics = self.metrics.clone(); - let upstream_socket = self.upstream_socket.clone(); - - tokio::spawn(async move { - let mut buf: Vec = vec![0; 65535]; - loop { - tracing::debug!(source = %source, dest = ?endpoint, "Awaiting incoming packet"); - - select! { - received = upstream_socket.recv_from(&mut buf) => { - match received { - Err(error) => { - metrics.rx_errors_total.inc(); - tracing::error!(%error, %source, dest = ?endpoint, "Error receiving packet"); - }, - Ok((size, recv_addr)) => { - crate::metrics::PACKETS_SIZE.with_label_values(&[crate::metrics::WRITE_DIRECTION_LABEL]).inc_by(size as f64); - crate::metrics::PACKETS_TOTAL.with_label_values(&[crate::metrics::WRITE_DIRECTION_LABEL]).inc(); - Session::process_recv_packet( - &metrics, - &downstream_socket, - &expiration, - ttl, - ReceivedPacketContext { - config: config.clone(), - packet: &buf[..size], - endpoint: &endpoint, - source: recv_addr.into(), - dest: source.clone(), - timer: crate::metrics::PROCESSING_TIME.with_label_values(&[crate::metrics::WRITE_DIRECTION_LABEL]).start_timer(), - }).await - } - }; - } - _ = shutdown_rx.changed() => { - tracing::debug!(%source, dest = ?endpoint, "Closing Session"); - return; - } - }; - } - }); - } - - /// expiration returns the current expiration Instant value - pub fn expiration(&self) -> u64 { - self.expiration.load(Ordering::Relaxed) - } - - /// key returns the key to be used for this session in a SessionMap - pub fn key(&self) -> SessionKey { - SessionKey { - source: self.source.clone(), - dest: self.dest.address.clone(), - } - } - - /// process_recv_packet processes a packet that is received by this session. - async fn process_recv_packet( - metrics: &Metrics, - downstream_socket: &Arc, - expiration: &Arc, - ttl: Duration, - packet_ctx: ReceivedPacketContext<'_>, - ) { - let ReceivedPacketContext { - packet, - config, - endpoint, - source: from, - dest, - timer, - } = packet_ctx; - - tracing::trace!(%from, dest = %endpoint.address, contents = %debug::bytes_to_string(packet), "received packet from upstream"); - - if let Err(error) = Session::do_update_expiration(expiration, ttl) { - tracing::warn!(%error, "Error updating session expiration") - } - - match config.filters.load().write(WriteContext::new( - endpoint, - from.clone(), - dest.clone(), - packet.to_vec(), - )) { - None => metrics.packets_dropped_total.inc(), - Some(response) => { - let addr = match dest.to_socket_addr() { - Ok(addr) => addr, - Err(error) => { - tracing::error!(%dest, %error, "Error resolving address"); - metrics.packets_dropped_total.inc(); - return; - } - }; - - let packet = response.contents.as_slice(); - tracing::trace!(%from, dest = %addr, contents = %debug::bytes_to_string(packet), "sending packet downstream"); - if let Err(error) = downstream_socket.send_to(packet, addr).await { - metrics.rx_errors_total.inc(); - tracing::error!(%error, "Error sending packet"); - } - } - } - - timer.stop_and_record(); - } - - /// update_expiration set the increments the expiration value by the session timeout - pub fn update_expiration(&self, ttl: Duration) -> Result<()> { - Self::do_update_expiration(&self.expiration, ttl) - } - - /// do_update_expiration increments the expiration value by the session timeout (internal) - fn do_update_expiration(expiration: &Arc, ttl: Duration) -> Result<()> { - let new_expiration_time = SystemTime::now() - .checked_add(ttl) - .ok_or_else(|| { - Error::UpdateSessionExpiration(format!( - "checked_add error: expiration ttl {:?} is out of bounds", - ttl - )) - })? - .duration_since(UNIX_EPOCH) - .map_err(|_| { - Error::UpdateSessionExpiration( - "duration_since was called with time later than the current time".into(), - ) - })? - .as_secs(); - - expiration.store(new_expiration_time, Ordering::Relaxed); - - Ok(()) - } - - /// Sends a packet to the Session's dest. - pub async fn send(&self, buf: &[u8]) -> crate::Result> { - tracing::trace!( - dest_address = %self.dest.address, - contents = %debug::bytes_to_string(buf), - "sending packet upstream"); - - self.do_send(buf) - .await - .map(|size| { - self.metrics.tx_packets_total.inc(); - self.metrics.tx_bytes_total.inc_by(size as u64); - Some(size) - }) - .map_err(|err| { - self.metrics.tx_errors_total.inc(); - eyre::eyre!(err).wrap_err("Error sending to destination.") - }) - } - - /// Sends `buf` to the session's destination address. On success, returns - /// the number of bytes written. - pub async fn do_send(&self, buf: &[u8]) -> crate::Result { - self.upstream_socket - .send(buf) - .await - .map_err(|error| eyre::eyre!(error)) - } -} - -impl Drop for Session { - fn drop(&mut self) { - self.metrics.active_sessions.dec(); - self.metrics - .duration_secs - .observe(self.created_at.elapsed().as_secs() as f64); - - if let Err(error) = self.shutdown_tx.send(()) { - tracing::warn!(%error, "Error sending session shutdown signal"); - } - - tracing::debug!(source = %self.source, dest_address = %self.dest.address, "Session closed"); - } -} - -#[cfg(test)] -mod tests { - use std::{ - str::from_utf8, - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, - }, - time::{Duration, SystemTime, UNIX_EPOCH}, - }; - - use super::{Metrics, Session}; - - use prometheus::{Histogram, HistogramOpts}; - use tokio::time::timeout; - - use crate::{ - endpoint::{Endpoint, EndpointAddress}, - proxy::sessions::session::{ReceivedPacketContext, SessionArgs}, - test_utils::{create_socket, new_test_config, TestHelper}, - }; - - #[tokio::test] - async fn session_send_and_receive() { - let mut t = TestHelper::default(); - let addr = t.run_echo_server().await; - let endpoint = Endpoint::new(addr.clone()); - let socket = Arc::new(create_socket().await); - let msg = "hello"; - - let sess = Session::new(SessionArgs { - config: <_>::default(), - metrics: Metrics::new().unwrap(), - source: addr.clone(), - downstream_socket: socket.clone(), - dest: endpoint, - ttl: Duration::from_secs(20), - }) - .await - .unwrap(); - - let initial_expiration_secs = sess.expiration.load(Ordering::Relaxed); - let now_secs = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - let diff = initial_expiration_secs - now_secs; - assert!((15..21).contains(&diff)); - - sess.send(msg.as_bytes()).await.unwrap(); - - let mut buf = vec![0; 1024]; - let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) - .await - .unwrap() - .unwrap(); - let packet = &buf[..size]; - assert_eq!(msg, from_utf8(packet).unwrap()); - assert_eq!(addr.port(), recv_addr.port()); - } - - #[tokio::test] - async fn process_recv_packet() { - crate::test_utils::load_test_filters(); - let histogram = Histogram::with_opts(HistogramOpts::new("test", "test")).unwrap(); - - let socket = Arc::new(create_socket().await); - let endpoint = Endpoint::new("127.0.1.1:80".parse().unwrap()); - let dest: EndpointAddress = socket.local_addr().unwrap().into(); - let expiration = Arc::new(AtomicU64::new( - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(), - )); - let initial_expiration = expiration.load(Ordering::Relaxed); - - // first test with no filtering - let msg = "hello"; - Session::process_recv_packet( - &Metrics::new().unwrap(), - &socket, - &expiration, - Duration::from_secs(10), - ReceivedPacketContext { - config: <_>::default(), - packet: msg.as_bytes(), - endpoint: &endpoint, - source: endpoint.address.clone(), - dest: dest.clone(), - timer: histogram.start_timer(), - }, - ) - .await; - - assert!(initial_expiration < expiration.load(Ordering::Relaxed)); - - let mut buf = vec![0; 1024]; - let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) - .await - .expect("Should receive a packet") - .unwrap(); - assert_eq!(msg, from_utf8(&buf[..size]).unwrap()); - assert_eq!(dest.port(), recv_addr.port()); - - let expiration = Arc::new(AtomicU64::new( - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(), - )); - let initial_expiration = expiration.load(Ordering::Relaxed); - // add filter - let config = Arc::new(new_test_config()); - Session::process_recv_packet( - &Metrics::new().unwrap(), - &socket, - &expiration, - Duration::from_secs(10), - ReceivedPacketContext { - config, - packet: msg.as_bytes(), - endpoint: &endpoint, - source: endpoint.address.clone(), - dest: dest.clone(), - timer: histogram.start_timer(), - }, - ) - .await; - - assert!(initial_expiration < expiration.load(Ordering::Relaxed)); - let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) - .await - .expect("Should receive a packet") - .unwrap(); - assert_eq!( - format!("{}:our:{}:{}", msg, endpoint.address, dest), - from_utf8(&buf[..size]).unwrap() - ); - assert_eq!(dest.port(), recv_addr.port()); - } - - #[tokio::test] - async fn metrics() { - let t = TestHelper::default(); - let ep = t.open_socket_and_recv_single_packet().await; - let addr: EndpointAddress = ep.socket.local_addr().unwrap().into(); - let endpoint = Endpoint::new(addr.clone()); - let socket = Arc::new(create_socket().await); - - let session = Session::new(SessionArgs { - config: <_>::default(), - metrics: Metrics::new().unwrap(), - source: addr, - downstream_socket: socket, - dest: endpoint, - ttl: Duration::from_secs(10), - }) - .await - .unwrap(); - - assert_eq!(session.metrics.sessions_total.get(), 1); - assert_eq!(session.metrics.active_sessions.get(), 1); - - // send a packet - session.send(b"hello").await.unwrap(); - timeout(Duration::from_secs(1), ep.packet_rx) - .await - .expect("should receive a packet") - .unwrap(); - - assert_eq!(session.metrics.tx_bytes_total.get(), 5); - assert_eq!(session.metrics.tx_packets_total.get(), 1); - - // drop metrics - let metrics = session.metrics.clone(); - drop(session); - assert_eq!(metrics.sessions_total.get(), 1); - assert_eq!(metrics.active_sessions.get(), 0); - } -} diff --git a/src/test_utils.rs b/src/test_utils.rs index 18b127e96e..6b18dd8123 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -225,7 +225,7 @@ impl TestHelper { self.run_server(<_>::try_from(config).unwrap()); } - pub fn run_server(&mut self, server: crate::Server) { + pub fn run_server(&mut self, server: crate::Proxy) { let (shutdown_tx, shutdown_rx) = watch::channel::<()>(()); self.server_shutdown_tx.push(Some(shutdown_tx)); tokio::spawn(async move { diff --git a/src/xds.rs b/src/xds.rs index 9e65a5452e..19adc3a22a 100644 --- a/src/xds.rs +++ b/src/xds.rs @@ -193,7 +193,7 @@ mod tests { let (_shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(()); tokio::spawn(server::spawn(xds_config.clone())); tokio::spawn( - crate::Server::try_from(client_config) + crate::Proxy::try_from(client_config) .unwrap() .run(shutdown_rx), ); diff --git a/tests/capture.rs b/tests/capture.rs index 543d0dd57d..9989abbb56 100644 --- a/tests/capture.rs +++ b/tests/capture.rs @@ -33,7 +33,7 @@ async fn token_router() { let mut t = TestHelper::default(); let echo = t.run_echo_server().await; let server_port = 12348; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .filters(vec![ Filter { diff --git a/tests/compress.rs b/tests/compress.rs index 48667dbd04..1c1976f8b0 100644 --- a/tests/compress.rs +++ b/tests/compress.rs @@ -35,7 +35,7 @@ async fn client_and_server() { on_read: DECOMPRESS on_write: COMPRESS "; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_addr.port()) .filters(vec![Filter { name: Compress::factory().name().into(), @@ -53,7 +53,7 @@ on_write: COMPRESS on_read: COMPRESS on_write: DECOMPRESS "; - let client_config = quilkin::Server::builder() + let client_config = quilkin::Config::builder() .port(client_addr.port()) .filters(vec![Filter { name: Compress::factory().name().into(), diff --git a/tests/concatenate_bytes.rs b/tests/concatenate_bytes.rs index 9756d17a48..84e5801aa5 100644 --- a/tests/concatenate_bytes.rs +++ b/tests/concatenate_bytes.rs @@ -35,7 +35,7 @@ bytes: YWJj #abc let echo = t.run_echo_server().await; let server_port = 12346; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .filters(vec![Filter { name: ConcatenateBytes::factory().name().into(), diff --git a/tests/filter_order.rs b/tests/filter_order.rs index 49baf25bbd..f552256296 100644 --- a/tests/filter_order.rs +++ b/tests/filter_order.rs @@ -55,7 +55,7 @@ on_write: DECOMPRESS .await; let server_port = 12346; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .filters(vec![ Filter { diff --git a/tests/filters.rs b/tests/filters.rs index 985227cc46..9adb9ef406 100644 --- a/tests/filters.rs +++ b/tests/filters.rs @@ -37,7 +37,7 @@ async fn test_filter() { // create server configuration let server_port = 12346; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .filters(vec![Filter { name: "TestFilter".to_string(), @@ -52,7 +52,7 @@ async fn test_filter() { // create a local client let client_port = 12347; - let client_config = quilkin::Server::builder() + let client_config = quilkin::Config::builder() .port(client_port) .filters(vec![Filter { name: "TestFilter".to_string(), @@ -112,7 +112,7 @@ async fn debug_filter() { }); // create server configuration let server_port = 12247; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .filters(vec![Filter { name: factory.name().into(), @@ -129,7 +129,7 @@ async fn debug_filter() { // create a local client let client_port = 12248; - let client_config = quilkin::Server::builder() + let client_config = quilkin::Config::builder() .port(client_port) .filters(vec![Filter { name: factory.name().into(), diff --git a/tests/firewall.rs b/tests/firewall.rs index cef15d6f59..b40c72830b 100644 --- a/tests/firewall.rs +++ b/tests/firewall.rs @@ -106,7 +106,7 @@ async fn test(t: &mut TestHelper, server_port: u16, yaml: &str) -> oneshot::Rece .replace("%2", echo.port().to_string().as_str()); tracing::info!(config = yaml.as_str(), "Config"); - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .filters(vec![Filter { name: Firewall::factory().name().into(), diff --git a/tests/health.rs b/tests/health.rs index c675c1da35..05e50f9179 100644 --- a/tests/health.rs +++ b/tests/health.rs @@ -27,7 +27,7 @@ async fn health_server() { // create server configuration let server_port = 12349; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .endpoints(vec!["127.0.0.1:0".parse::().unwrap()]) .admin(Admin { diff --git a/tests/load_balancer.rs b/tests/load_balancer.rs index fe54335621..04843389ec 100644 --- a/tests/load_balancer.rs +++ b/tests/load_balancer.rs @@ -47,7 +47,7 @@ policy: ROUND_ROBIN } let server_port = 12346; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .filters(vec![Filter { name: LoadBalancer::factory().name().into(), diff --git a/tests/local_rate_limit.rs b/tests/local_rate_limit.rs index 1d6d76021c..628ca4a465 100644 --- a/tests/local_rate_limit.rs +++ b/tests/local_rate_limit.rs @@ -36,7 +36,7 @@ period: 1 let echo = t.run_echo_server().await; let server_addr = available_addr().await; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_addr.port()) .filters(vec![Filter { name: LocalRateLimit::factory().name().into(), diff --git a/tests/match.rs b/tests/match.rs index 6a062317af..61d4b01c55 100644 --- a/tests/match.rs +++ b/tests/match.rs @@ -57,7 +57,7 @@ on_read: bytes: YWJj # abc "; let server_port = 12348; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .filters(vec![ Filter { diff --git a/tests/metrics.rs b/tests/metrics.rs index cfe42445ef..c6d85f95a3 100644 --- a/tests/metrics.rs +++ b/tests/metrics.rs @@ -27,7 +27,7 @@ async fn metrics_server() { // create server configuration let server_port = 12346; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .endpoints(vec![Endpoint::new(echo)]) .admin(Admin { @@ -35,11 +35,11 @@ async fn metrics_server() { }) .build() .unwrap(); - t.run_server(quilkin::Server::try_from(server_config).unwrap()); + t.run_server(quilkin::Proxy::try_from(server_config).unwrap()); // create a local client let client_port = 12347; - let client_config = quilkin::Server::builder() + let client_config = quilkin::Config::builder() .port(client_port) .endpoints(vec![Endpoint::new( (IpAddr::V4(Ipv4Addr::LOCALHOST), server_port).into(), diff --git a/tests/no_filter.rs b/tests/no_filter.rs index 6130886835..aefef3736a 100644 --- a/tests/no_filter.rs +++ b/tests/no_filter.rs @@ -30,7 +30,7 @@ async fn echo() { // create server configuration let local_addr = available_addr().await; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(local_addr.port()) .endpoints(vec![Endpoint::new(server1), Endpoint::new(server2)]) .build() diff --git a/tests/token_router.rs b/tests/token_router.rs index f5203ce930..952fcf55fa 100644 --- a/tests/token_router.rs +++ b/tests/token_router.rs @@ -44,7 +44,7 @@ quilkin.dev: - YWJj # abc "; let server_port = 12348; - let server_config = quilkin::Server::builder() + let server_config = quilkin::Config::builder() .port(server_port) .filters(vec![ Filter {