diff --git a/src/proxy/server.rs b/src/proxy/server.rs index 080ad2a505..d6e1fd616a 100644 --- a/src/proxy/server.rs +++ b/src/proxy/server.rs @@ -329,16 +329,14 @@ mod tests { use std::sync::Arc; use slog::info; - use tokio::sync::{mpsc, oneshot, RwLock}; + use tokio::sync::{mpsc, RwLock}; use tokio::time; use tokio::time::{Duration, Instant}; use crate::config; use crate::config::{Config, ConnectionConfig, EndPoint, Local}; use crate::proxy::sessions::{Packet, SESSION_TIMEOUT_SECONDS}; - use crate::test_utils::{ - ephemeral_socket, logger, recv_udp, recv_udp_done, TestFilter, TestFilterFactory, - }; + use crate::test_utils::{SplitSocket, TestFilter, TestFilterFactory, TestHelper}; use super::*; use crate::extensions::FilterRegistry; @@ -346,18 +344,13 @@ mod tests { #[tokio::test] async fn run_server() { - let socket1 = ephemeral_socket().await; - let endpoint1 = socket1.local_addr().unwrap(); - let socket2 = ephemeral_socket().await; - let endpoint2 = socket2.local_addr().unwrap(); - let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 12358); + let mut t = TestHelper::default(); - let (recv1, mut send) = socket1.split(); - let (recv2, _) = socket2.split(); - let (done1, wait1) = oneshot::channel::(); - let (done2, wait2) = oneshot::channel::(); + let mut endpoint1 = t.open_socket_and_recv_single_packet().await; + let endpoint2 = t.open_socket_and_recv_single_packet().await; - let config = Arc::new(Config { + let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 12358); + let config = Config { local: Local { port: local_addr.port(), }, @@ -366,77 +359,68 @@ mod tests { endpoints: vec![ EndPoint { name: String::from("e1"), - address: endpoint1.clone(), + address: endpoint1.addr.clone(), connection_ids: vec![], }, EndPoint { name: String::from("e2"), - address: endpoint2.clone(), + address: endpoint2.addr.clone(), connection_ids: vec![], }, ], }, - }); - - let server = Builder::from(config).validate().unwrap().build(); - let (close, stop) = oneshot::channel::<()>(); - tokio::spawn(async move { - server.run(stop).await.unwrap(); - }); + }; + t.run_server(config); let msg = "hello"; - recv_udp_done(recv1, done1); - recv_udp_done(recv2, done2); - send.send_to(msg.as_bytes(), &local_addr).await.unwrap(); - assert_eq!(msg, wait1.await.unwrap()); - assert_eq!(msg, wait2.await.unwrap()); - close.send(()).unwrap(); + endpoint1 + .send + .send_to(msg.as_bytes(), &local_addr) + .await + .unwrap(); + assert_eq!(msg, endpoint1.packet_rx.await.unwrap()); + assert_eq!(msg, endpoint2.packet_rx.await.unwrap()); } #[tokio::test] async fn run_client() { - let socket = ephemeral_socket().await; - let endpoint_addr = socket.local_addr().unwrap(); - let (recv, mut send) = socket.split(); + let mut t = TestHelper::default(); + + let mut endpoint = t.open_socket_and_recv_single_packet().await; + let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 12357); - let (done, wait) = oneshot::channel::(); - let config = Arc::new(Config { + let config = Config { local: Local { port: local_addr.port(), }, filters: vec![], connections: ConnectionConfig::Client { - addresses: vec![endpoint_addr], + addresses: vec![endpoint.addr], connection_id: "".into(), lb_policy: None, }, - }); - - let (close, stop) = oneshot::channel::<()>(); - let server = Builder::from(config).validate().unwrap().build(); - tokio::spawn(async move { - server.run(stop).await.unwrap(); - }); + }; + t.run_server(config); let msg = "hello"; - recv_udp_done(recv, done); - send.send_to(msg.as_bytes(), &local_addr).await.unwrap(); - assert_eq!(msg, wait.await.unwrap()); - - close.send(()).unwrap(); + endpoint + .send + .send_to(msg.as_bytes(), &local_addr) + .await + .unwrap(); + assert_eq!(msg, endpoint.packet_rx.await.unwrap()); } #[tokio::test] async fn run_with_filter() { + let mut t = TestHelper::default(); + let mut registry = FilterRegistry::default(); registry.insert(TestFilterFactory {}); - let socket = ephemeral_socket().await; - let endpoint_addr = socket.local_addr().unwrap(); - let (recv, mut send) = socket.split(); + let mut endpoint = t.open_socket_and_recv_single_packet().await; let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 12367); - let (done, wait) = oneshot::channel::(); - let config = Arc::new(Config { + let config = Config { local: Local { port: local_addr.port(), }, @@ -445,29 +429,23 @@ mod tests { config: None, }], connections: ConnectionConfig::Client { - addresses: vec![endpoint_addr], + addresses: vec![endpoint.addr], connection_id: "".into(), lb_policy: None, }, - }); - - let (close, stop) = oneshot::channel::<()>(); - let server = Builder::from(config) - .with_filter_registry(registry) - .validate() - .unwrap() - .build(); - tokio::spawn(async move { - server.run(stop).await.unwrap(); - }); + }; + t.run_server_with_filter_registry(config, registry); let msg = "hello"; - recv_udp_done(recv, done); - send.send_to(msg.as_bytes(), &local_addr).await.unwrap(); + endpoint + .send + .send_to(msg.as_bytes(), &local_addr) + .await + .unwrap(); // since we don't know what the session ephemeral port is, we'll just // search for the filter strings. - let result = wait.await.unwrap(); + let result = endpoint.packet_rx.await.unwrap(); assert!( result.contains(msg), format!("'{}' not found in '{}'", msg, result) @@ -476,8 +454,6 @@ mod tests { result.contains(":odr:"), format!(":odr: not found in '{}'", result) ); - - close.send(()).unwrap(); } #[tokio::test] @@ -508,36 +484,36 @@ mod tests { session_len: usize, } - async fn test( - name: String, - log: &Logger, - chain: Arc, - expected: Expected, - ) -> Result { - info!(log, "Test"; "name" => name); + async fn test(name: String, chain: Arc, expected: Expected) -> Result { + let t = TestHelper::default(); + + info!(t.log, "Test"; "name" => name); let msg = "hello".to_string(); - let (local_addr, wait) = recv_udp().await; + let endpoint = t.open_socket_and_recv_single_packet().await; let lb_policy = Arc::new(LoadBalancerPolicy::new(&ConnectionConfig::Client { - addresses: vec![local_addr], + addresses: vec![endpoint.addr], connection_id: "".into(), lb_policy: None, })); - let receive_socket = ephemeral_socket().await; - let receive_addr = receive_socket.local_addr().unwrap(); - let (mut recv, mut send) = receive_socket.split(); + + let SplitSocket { + addr: receive_addr, + mut recv, + mut send, + } = t.create_and_split_socket().await; + let sessions: SessionMap = Arc::new(RwLock::new(HashMap::new())); let (send_packets, mut recv_packets) = mpsc::channel::(1); let sessions_clone = sessions.clone(); - let log_clone = log.clone(); let time_increment = 10; time::advance(Duration::from_secs(time_increment)).await; tokio::spawn(async move { Server::recv_from( - &log_clone, + &t.log, &Metrics::default(), lb_policy, chain, @@ -550,7 +526,7 @@ mod tests { send.send_to(msg.as_bytes(), &receive_addr).await.unwrap(); - let result = wait.await.unwrap(); + let result = endpoint.packet_rx.await.unwrap(); recv_packets.close(); let map = sessions.read().await; @@ -559,7 +535,7 @@ mod tests { // need to switch to 127.0.0.1, as the request comes locally let mut receive_addr_local = receive_addr.clone(); receive_addr_local.set_ip("127.0.0.1".parse().unwrap()); - let build_key = (receive_addr_local, local_addr); + let build_key = (receive_addr_local, endpoint.addr); assert!(map.contains_key(&build_key)); let session = map.get(&build_key).unwrap().lock().await; assert_eq!( @@ -577,22 +553,13 @@ mod tests { } } - let log = logger(); - let chain = Arc::new(FilterChain::new(vec![])); - let result = test( - "no filter".to_string(), - &log, - chain, - Expected { session_len: 1 }, - ) - .await; + let result = test("no filter".to_string(), chain, Expected { session_len: 1 }).await; assert_eq!("hello", result.msg); let chain = Arc::new(FilterChain::new(vec![Box::new(TestFilter {})])); let result = test( "test filter".to_string(), - &log, chain, Expected { session_len: 2 }, ) @@ -608,26 +575,28 @@ mod tests { #[tokio::test] async fn run_recv_from() { + let t = TestHelper::default(); + let msg = "hello"; - let (local_addr, wait) = recv_udp().await; + let endpoint = t.open_socket_and_recv_single_packet().await; let lb_policy = Arc::new(LoadBalancerPolicy::new(&ConnectionConfig::Client { - addresses: vec![local_addr], + addresses: vec![endpoint.addr], connection_id: "".into(), lb_policy: None, })); - let socket = ephemeral_socket().await; - let addr = socket.local_addr().unwrap(); - let (recv, mut send) = socket.split(); + let SplitSocket { + addr, + recv, + mut send, + } = t.create_and_split_socket().await; let sessions: SessionMap = Arc::new(RwLock::new(HashMap::new())); let (send_packets, mut recv_packets) = mpsc::channel::(1); let config = Arc::new(Config { - local: Local { - port: local_addr.port(), - }, + local: Local { port: 0 }, filters: vec![], connections: ConnectionConfig::Client { - addresses: vec![local_addr], + addresses: vec![], connection_id: "".into(), lb_policy: None, }, @@ -643,17 +612,17 @@ mod tests { ); send.send_to(msg.as_bytes(), &addr).await.unwrap(); - assert_eq!(msg, wait.await.unwrap()); + assert_eq!(msg, endpoint.packet_rx.await.unwrap()); recv_packets.close(); } #[tokio::test] async fn ensure_session() { - let log = logger(); + let t = TestHelper::default(); let map: SessionMap = Arc::new(RwLock::new(HashMap::new())); let from: SocketAddr = "127.0.0.1:27890".parse().unwrap(); let dest: SocketAddr = "127.0.0.1:27891".parse().unwrap(); - let (sender, mut recv) = mpsc::channel::(1); + let (sender, _) = mpsc::channel::(1); let endpoint = EndPoint { name: "endpoint".to_string(), address: dest, @@ -665,7 +634,7 @@ mod tests { assert!(map.read().await.is_empty()); } Server::ensure_session( - &log, + &t.log, &Metrics::default(), Arc::new(FilterChain::new(vec![])), map.clone(), @@ -683,51 +652,42 @@ mod tests { let sess = rmap.get(&key).unwrap().lock().await; assert_eq!(key, sess.key()); assert_eq!(1, rmap.keys().len()); - - recv.close(); } #[tokio::test] async fn run_receive_packet() { + let t = TestHelper::default(); + let msg = "hello"; // without a filter - let socket = ephemeral_socket().await; - let local_addr = socket.local_addr().unwrap(); - - let (recv_socket, send_socket) = socket.split(); let (mut send_packet, recv_packet) = mpsc::channel::(5); - let (done, wait) = oneshot::channel::(); - - recv_udp_done(recv_socket, done); - - if let Err(err) = send_packet - .send(Packet::new(local_addr, msg.as_bytes().to_vec())) + let endpoint = t.open_socket_and_recv_single_packet().await; + if let Err(_) = send_packet + .send(Packet::new(endpoint.addr, msg.as_bytes().to_vec())) .await { - assert!(false, err) + unreachable!("failed to send packet over channel"); } let config = Arc::new(Config { - local: Local { - port: local_addr.port(), - }, + local: Local { port: 0 }, filters: vec![], connections: ConnectionConfig::Client { - addresses: vec![local_addr], + addresses: vec![], connection_id: "".into(), lb_policy: None, }, }); let server = Builder::from(config).validate().unwrap().build(); - server.run_receive_packet(send_socket, recv_packet); - assert_eq!(msg, wait.await.unwrap()); + server.run_receive_packet(endpoint.send, recv_packet); + assert_eq!(msg, endpoint.packet_rx.await.unwrap()); } #[tokio::test] async fn prune_sessions() { time::pause(); - let log = logger(); + let t = TestHelper::default(); let sessions: SessionMap = Arc::new(RwLock::new(HashMap::new())); let from: SocketAddr = "127.0.0.1:7000".parse().unwrap(); let to: SocketAddr = "127.0.0.1:7001".parse().unwrap(); @@ -739,7 +699,7 @@ mod tests { }; Server::ensure_session( - &log, + &t.log.clone(), &Metrics::default(), Arc::new(FilterChain::new(vec![])), sessions.clone(), @@ -761,7 +721,7 @@ mod tests { // session map should be the same since, we haven't passed expiry time::advance(Duration::new(SESSION_TIMEOUT_SECONDS / 2, 0)).await; - Server::prune_sessions(&log, sessions.clone()).await; + Server::prune_sessions(&t.log, sessions.clone()).await; { let map = sessions.read().await; assert!(map.contains_key(&key)); @@ -769,7 +729,7 @@ mod tests { } time::advance(Duration::new(2 * SESSION_TIMEOUT_SECONDS, 0)).await; - Server::prune_sessions(&log, sessions.clone()).await; + Server::prune_sessions(&t.log, sessions.clone()).await; { let map = sessions.read().await; assert!( @@ -778,14 +738,14 @@ mod tests { ); assert_eq!(0, map.len(), "len should be 0, bit is {}", map.len()); } - info!(log, "test complete"); + info!(t.log, "test complete"); time::resume(); } #[tokio::test] async fn run_prune_sessions() { time::pause(); - let log = logger(); + let t = TestHelper::default(); let sessions: SessionMap = Arc::new(RwLock::new(HashMap::new())); let from: SocketAddr = "127.0.0.1:7000".parse().unwrap(); let to: SocketAddr = "127.0.0.1:7001".parse().unwrap(); @@ -809,7 +769,7 @@ mod tests { let server = Builder::from(config).validate().unwrap().build(); server.run_prune_sessions(&sessions); Server::ensure_session( - &log, + &t.log, &Metrics::default(), Arc::new(FilterChain::new(vec![])), sessions.clone(), diff --git a/src/proxy/sessions/session.rs b/src/proxy/sessions/session.rs index 13c90959a7..5809716669 100644 --- a/src/proxy/sessions/session.rs +++ b/src/proxy/sessions/session.rs @@ -281,7 +281,7 @@ mod tests { use tokio::time; use tokio::time::delay_for; - use crate::test_utils::{ephemeral_socket, logger, recv_udp, TestFilter}; + use crate::test_utils::{SplitSocket, TestFilter, TestHelper}; use super::*; @@ -289,26 +289,24 @@ mod tests { async fn session_new() { time::pause(); - let log = logger(); - let mut socket = ephemeral_socket().await; - let local_addr = socket.local_addr().unwrap(); + let t = TestHelper::default(); + let SplitSocket { + addr, + mut recv, + mut send, + } = t.create_and_split_socket().await; let endpoint = EndPoint { name: "endpoint".to_string(), - address: local_addr, + address: addr, connection_ids: vec![], }; let (send_packet, mut recv_packet) = mpsc::channel::(5); let mut sess = Session::new( - &log, - Metrics::new( - &Registry::default(), - local_addr.to_string(), - local_addr.to_string(), - ) - .unwrap(), + &t.log, + Metrics::new(&Registry::default(), addr.to_string(), addr.to_string()).unwrap(), Arc::new(FilterChain::new(vec![])), - local_addr, + addr, endpoint, send_packet, ) @@ -328,9 +326,9 @@ mod tests { // echo the packet back again tokio::spawn(async move { let mut buf = vec![0; 1024]; - let (size, recv_addr) = socket.recv_from(&mut buf).await.unwrap(); + let (size, recv_addr) = recv.recv_from(&mut buf).await.unwrap(); assert_eq!("hello", from_utf8(&buf[..size]).unwrap()); - socket.send_to(&buf[..size], recv_addr).await.unwrap(); + send.send_to(&buf[..size], &recv_addr).await.unwrap(); }); sess.send_to("hello".as_bytes()).await.unwrap(); @@ -340,13 +338,13 @@ mod tests { .await .expect("Should receive a packet 'hello'"); assert_eq!(String::from("hello").into_bytes(), packet.contents); - assert_eq!(local_addr, packet.dest); + assert_eq!(addr, packet.dest); let current_expiration = sess.expiration.read().await.clone(); assert!(Instant::now() < current_expiration); let diff = current_expiration.duration_since(initial_expiration); - info!(log, "difference during test"; "duration" => format!("{:?}", diff)); + info!(t.log, "difference during test"; "duration" => format!("{:?}", diff)); assert!(diff.as_secs() >= time_increment); sess.close().unwrap(); @@ -355,66 +353,66 @@ mod tests { #[tokio::test] async fn session_send_to() { - let log = logger(); + let t = TestHelper::default(); let msg = "hello"; // without a filter let (sender, _) = mpsc::channel::(1); - let (local_addr, wait) = recv_udp().await; + let ep = t.open_socket_and_recv_single_packet().await; let endpoint = EndPoint { name: "endpoint".to_string(), - address: local_addr, + address: ep.addr, connection_ids: vec![], }; let mut session = Session::new( - &log, + &t.log, Metrics::new( &Registry::default(), - local_addr.to_string(), - local_addr.to_string(), + ep.addr.to_string(), + ep.addr.to_string(), ) .unwrap(), Arc::new(FilterChain::new(vec![])), - local_addr, + ep.addr, endpoint.clone(), sender, ) .await .unwrap(); session.send_to(msg.as_bytes()).await.unwrap(); - assert_eq!(msg, wait.await.unwrap()); + assert_eq!(msg, ep.packet_rx.await.unwrap()); } #[tokio::test] async fn session_close() { - let log = logger(); - let socket = ephemeral_socket().await; - let local_addr = socket.local_addr().unwrap(); + let t = TestHelper::default(); + + let ep = t.open_socket_and_recv_single_packet().await; let (send_packet, _) = mpsc::channel::(5); let endpoint = EndPoint { name: "endpoint".to_string(), - address: local_addr, + address: ep.addr, connection_ids: vec![], }; - info!(log, ">> creating sessions"); + info!(t.log, ">> creating sessions"); let sess = Session::new( - &log, + &t.log, Metrics::new( &Registry::default(), - local_addr.to_string(), - local_addr.to_string(), + ep.addr.to_string(), + ep.addr.to_string(), ) .unwrap(), Arc::new(FilterChain::new(vec![])), - local_addr, + ep.addr, endpoint, send_packet, ) .await .unwrap(); - info!(log, ">> session created and running"); + info!(t.log, ">> session created and running"); assert!(!sess.is_closed(), "session should not be closed"); sess.close().unwrap(); @@ -422,7 +420,7 @@ mod tests { // Poll the state to wait for the change, because everything is async for _ in 1..1000 { let is_closed = sess.is_closed(); - info!(log, "session closed?"; "closed" => is_closed); + info!(t.log, "session closed?"; "closed" => is_closed); if is_closed { break; } @@ -435,7 +433,8 @@ mod tests { #[tokio::test] async fn process_recv_packet() { - let log = logger(); + let t = TestHelper::default(); + let chain = Arc::new(FilterChain::new(vec![])); let endpoint = EndPoint { name: "endpoint".to_string(), @@ -453,7 +452,7 @@ mod tests { // first test with no filtering let msg = "hello"; Session::process_recv_packet( - &log, + &t.log, &Metrics::new( &Registry::default(), "127.0.1.1:80".parse().unwrap(), @@ -483,7 +482,7 @@ mod tests { // add filter let chain = Arc::new(FilterChain::new(vec![Box::new(TestFilter {})])); Session::process_recv_packet( - &log, + &t.log, &Metrics::new( &Registry::default(), "127.0.1.1:80".parse().unwrap(), @@ -516,26 +515,25 @@ mod tests { #[tokio::test] async fn session_new_metrics() { - let log = logger(); - let socket = ephemeral_socket().await; - let local_addr = socket.local_addr().unwrap(); + let t = TestHelper::default(); + let ep = t.open_socket_and_recv_single_packet().await; let endpoint = EndPoint { name: "endpoint".to_string(), - address: local_addr, + address: ep.addr, connection_ids: vec![], }; let (send_packet, _) = mpsc::channel::(5); let session = Session::new( - &log, + &t.log, Metrics::new( &Registry::default(), - local_addr.to_string(), - local_addr.to_string(), + ep.addr.to_string(), + ep.addr.to_string(), ) .unwrap(), Arc::new(FilterChain::new(vec![])), - local_addr, + ep.addr, endpoint, send_packet, ) @@ -549,22 +547,24 @@ mod tests { #[tokio::test] async fn send_to_metrics() { + let t = TestHelper::default(); + let (sender, _) = mpsc::channel::(1); - let (local_addr, wait) = recv_udp().await; + let endpoint = t.open_socket_and_recv_single_packet().await; let mut session = Session::new( - &logger(), + &t.log, Metrics::new( &Registry::default(), - local_addr.to_string(), - local_addr.to_string(), + endpoint.addr.to_string(), + endpoint.addr.to_string(), ) .unwrap(), Arc::new(FilterChain::new(vec![])), - local_addr, + endpoint.addr, EndPoint { name: "endpoint".to_string(), - address: local_addr, + address: endpoint.addr, connection_ids: vec![], }, sender, @@ -572,7 +572,7 @@ mod tests { .await .unwrap(); session.send_to(b"hello").await.unwrap(); - wait.await.unwrap(); + endpoint.packet_rx.await.unwrap(); assert_eq!(session.metrics.tx_bytes_total.get(), 5); assert_eq!(session.metrics.tx_packets_total.get(), 1); @@ -581,27 +581,25 @@ mod tests { #[tokio::test] async fn session_drop_metrics() { - let log = logger(); - let socket = ephemeral_socket().await; - let local_addr = socket.local_addr().unwrap(); - let endpoint = EndPoint { - name: "endpoint".to_string(), - address: local_addr, - connection_ids: vec![], - }; + let t = TestHelper::default(); let (send_packet, _) = mpsc::channel::(5); + let endpoint = t.open_socket_and_recv_single_packet().await; let session = Session::new( - &log, + &t.log, Metrics::new( &Registry::default(), - local_addr.to_string(), - local_addr.to_string(), + endpoint.addr.to_string(), + endpoint.addr.to_string(), ) .unwrap(), Arc::new(FilterChain::new(vec![])), - local_addr, - endpoint, + endpoint.addr, + EndPoint { + name: "endpoint".to_string(), + address: endpoint.addr, + connection_ids: vec![], + }, send_packet, ) .await diff --git a/src/test_utils.rs b/src/test_utils.rs index 31e160d0b5..52f299031f 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -23,12 +23,12 @@ use slog::{o, warn, Drain, Logger}; use slog_term::{FullFormat, PlainSyncDecorator}; use tokio::net::udp::{RecvHalf, SendHalf}; use tokio::net::UdpSocket; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, watch}; use crate::config::{Config, EndPoint}; use crate::extensions::{ - CreateFilterArgs, DownstreamContext, DownstreamResponse, Error, Filter, FilterFactory, - FilterRegistry, UpstreamContext, UpstreamResponse, + default_registry, CreateFilterArgs, DownstreamContext, DownstreamResponse, Error, Filter, + FilterFactory, FilterRegistry, UpstreamContext, UpstreamResponse, }; use crate::proxy::{Builder, Metrics}; @@ -83,97 +83,217 @@ pub fn logger() -> Logger { Logger::root(drain, o!()) } -/// recv_udp waits for a UDP packet to be received on SocketAddr, and sends -/// that value to the oneshot channel so it can be tested. -pub async fn recv_udp() -> (SocketAddr, oneshot::Receiver) { - let socket = ephemeral_socket().await; - let local_addr = socket.local_addr().unwrap(); - let (recv, _) = socket.split(); - let (done, wait) = oneshot::channel::(); - recv_udp_done(recv, done); - (local_addr, wait) +pub struct TestHelper { + pub log: Logger, + /// Channel to subscribe to, and trigger the shutdown of created resources. + shutdown_ch: Option<(watch::Sender<()>, watch::Receiver<()>)>, + server_shutdown_tx: Vec>>, } -/// ephemeral_socket provides a socket bound to an ephemeral port -pub async fn ephemeral_socket() -> UdpSocket { - let addr = SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0); - UdpSocket::bind(addr).await.unwrap() +/// Returned from [creating a socket](TestHelper::open_socket_and_recv_single_packet) +pub struct OpenSocketRecvPacket { + /// The local address that the opened socket is bound to. + pub addr: SocketAddr, + /// The sender side, after splitting the opened socket. + pub send: SendHalf, + /// A channel on which the received packet will be forwarded. + pub packet_rx: oneshot::Receiver, } -/// recv_udp_done will send the String value of the receiving UDP packet to the passed in oneshot channel. -pub fn recv_udp_done(mut recv: RecvHalf, done: oneshot::Sender) { - tokio::spawn(async move { - let mut buf = vec![0; 1024]; - let size = recv.recv(&mut buf).await.unwrap(); - done.send(from_utf8(&buf[..size]).unwrap().to_string()) - .unwrap(); - }); +/// Returned from [creating a socket](TestHelper::create_and_split_socket) +pub struct SplitSocket { + /// The local address that the opened socket is bound to. + pub addr: SocketAddr, + /// The receiver side, after splitting the opened socket. + pub recv: RecvHalf, + /// The sender side, after splitting the opened socket. + pub send: SendHalf, } -// recv_multiple_packets enables you to send multiple packets through SendHalf -// and will return any received packets back to the Receiver. -pub async fn recv_multiple_packets(logger: &Logger) -> (mpsc::Receiver, SendHalf) { - let (mut send_chan, recv_chan) = mpsc::channel::(10); - let (mut recv, send) = ephemeral_socket().await.split(); - // a channel, so we can wait for packets coming back. - let logger = logger.clone(); - tokio::spawn(async move { - let mut buf = vec![0; 1024]; - loop { - let (size, _) = recv.recv_from(&mut buf).await.unwrap(); - let str = from_utf8(&buf[..size]).unwrap().to_string(); - match send_chan.send(str).await { - Ok(_) => {} - Err(err) => { - warn!(logger, "recv_multiple_packets: recv_chan dropped"; "error" => %err); - break; - } - }; +impl Drop for TestHelper { + fn drop(&mut self) { + let log = self.log.clone(); + for shutdown_tx in self + .server_shutdown_tx + .iter_mut() + .map(|tx| tx.take()) + .flatten() + { + shutdown_tx + .send(()) + .map_err(|err| { + warn!( + log, + "failed to send server shutdown over channel: {:?}", err + ) + }) + .ok(); } - }); - (recv_chan, send) -} -// echo_server runs a udp echo server, and returns the ephemeral addr -// that it is running on. -pub async fn echo_server() -> SocketAddr { - let mut socket = ephemeral_socket().await; - let addr = socket.local_addr().unwrap(); - tokio::spawn(async move { - loop { - let mut buf = vec![0; 1024]; - let (size, sender) = socket.recv_from(&mut buf).await.unwrap(); - socket.send_to(&buf[..size], sender).await.unwrap(); + if let Some((shutdown_tx, _)) = self.shutdown_ch.take() { + shutdown_tx.broadcast(()).unwrap(); } - }); - addr + } } -// run_proxy creates a instance of the Server proxy and runs it, returning a cancel function -pub fn run_proxy(registry: FilterRegistry, config: Config) -> Box { - run_proxy_with_metrics(registry, config, Metrics::default()) +impl Default for TestHelper { + fn default() -> Self { + TestHelper { + log: logger(), + shutdown_ch: None, + server_shutdown_tx: vec![], + } + } } -// run_proxy_with_metrics creates a instance of the Server proxy and -// runs it, returning a cancel function -pub fn run_proxy_with_metrics( - registry: FilterRegistry, - config: Config, - metrics: Metrics, -) -> Box { - let (close, stop) = oneshot::channel::<()>(); - let proxy = Builder::from(Arc::new(config)) - .with_filter_registry(registry) - .with_metrics(metrics) - .validate() - .unwrap() - .build(); - // run the proxy - tokio::spawn(async move { - proxy.run(stop).await.unwrap(); - }); - - Box::new(|| close.send(()).unwrap()) +impl TestHelper { + /// Creates a [`Server`] and runs it. The server is shutdown once `self` + /// goes out of scope. + pub fn run_server(&mut self, config: Config) { + self.run_server_with_filter_registry(config, default_registry(&self.log)) + } + + pub fn run_server_with_filter_registry( + &mut self, + config: Config, + filter_registry: FilterRegistry, + ) { + self.run_server_with_arguments(config, filter_registry, Metrics::default()) + } + + pub fn run_server_with_metrics(&mut self, config: Config, metrics: Metrics) { + self.run_server_with_arguments(config, default_registry(&self.log), metrics) + } + + /// Opens a new socket bound to an ephemeral port + pub async fn create_socket(&self) -> UdpSocket { + let addr = SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0); + UdpSocket::bind(addr).await.unwrap() + } + + /// Helper function to opens a new socket and split it immediately. + pub async fn create_and_split_socket(&self) -> SplitSocket { + let socket = self.create_socket().await; + let addr = socket.local_addr().unwrap(); + let (recv, send) = socket.split(); + SplitSocket { addr, recv, send } + } + + /// Opens a socket, listening for a packet. Once a packet is received, it + /// is forwarded over the returned channel. + pub async fn open_socket_and_recv_single_packet(&self) -> OpenSocketRecvPacket { + let socket = self.create_socket().await; + let addr = socket.local_addr().unwrap(); + let (mut recv, send) = socket.split(); + let (packet_tx, packet_rx) = oneshot::channel::(); + tokio::spawn(async move { + let mut buf = vec![0; 1024]; + let size = recv.recv(&mut buf).await.unwrap(); + packet_tx + .send(from_utf8(&buf[..size]).unwrap().to_string()) + .unwrap(); + }); + OpenSocketRecvPacket { + addr, + send, + packet_rx, + } + } + + /// Opens a socket, listening for packets. Received packets are forwarded over the + /// returned channel. + pub async fn open_socket_and_recv_multiple_packets( + &mut self, + ) -> (mpsc::Receiver, SendHalf) { + let (mut packet_tx, packet_rx) = mpsc::channel::(10); + let (mut socket_recv, socket_send) = self.create_socket().await.split(); + let log = self.log.clone(); + let mut shutdown_rx = self.get_shutdown_subscriber().await; + tokio::spawn(async move { + let mut buf = vec![0; 1024]; + loop { + tokio::select! { + received = socket_recv.recv_from(&mut buf) => { + let (size, _) = received.unwrap(); + let str = from_utf8(&buf[..size]).unwrap().to_string(); + match packet_tx.send(str).await { + Ok(_) => {} + Err(err) => { + warn!(log, "recv_multiple_packets: recv_chan dropped"; "error" => %err); + return; + } + }; + }, + _ = shutdown_rx.recv() => { + return; + } + } + } + }); + (packet_rx, socket_send) + } + + /// Runs a simple UDP server that echos back payloads. + /// Returns the server's address. + pub async fn run_echo_server(&mut self) -> SocketAddr { + let mut socket = self.create_socket().await; + let addr = socket.local_addr().unwrap(); + let mut shutdown = self.get_shutdown_subscriber().await; + tokio::spawn(async move { + loop { + let mut buf = vec![0; 1024]; + tokio::select! { + recvd = socket.recv_from(&mut buf) => { + let (size, sender) = recvd.unwrap(); + socket.send_to(&buf[..size], sender).await.unwrap(); + }, + _ = shutdown.recv() => { + return; + } + } + } + }); + addr + } + + /// Create and run a server. + fn run_server_with_arguments( + &mut self, + config: Config, + filter_registry: FilterRegistry, + metrics: Metrics, + ) { + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + self.server_shutdown_tx.push(Some(shutdown_tx)); + tokio::spawn(async move { + Builder::from(Arc::new(config)) + .with_filter_registry(filter_registry) + .with_metrics(metrics) + .validate() + .unwrap() + .build() + .run(shutdown_rx) + .await + .unwrap(); + }); + } + + /// Returns a receiver subscribed to the helper's shutdown event. + async fn get_shutdown_subscriber(&mut self) -> watch::Receiver<()> { + // If this is the first call, then we set up the channel first. + match self.shutdown_ch { + Some((_, ref rx)) => rx.clone(), + None => { + let mut ch = watch::channel(()); + // Remove the init value from the channel so that we can later + // shutdown as soon as we receive any value from the channel. + let _ = ch.1.recv().await; + let recv = ch.1.clone(); + self.shutdown_ch = Some(ch); + recv + } + } + } } /// assert that on_downstream_receive makes no changes @@ -227,16 +347,19 @@ where #[cfg(test)] mod tests { - use super::*; + use crate::test_utils::TestHelper; #[tokio::test] async fn test_echo_server() { - let echo_addr = echo_server().await; - let (recv, mut send) = ephemeral_socket().await.split(); - let (done, wait) = oneshot::channel::(); + let mut t = TestHelper::default(); + let echo_addr = t.run_echo_server().await; + let mut endpoint = t.open_socket_and_recv_single_packet().await; let msg = "hello"; - recv_udp_done(recv, done); - send.send_to(msg.as_bytes(), &echo_addr).await.unwrap(); - assert_eq!(msg, wait.await.unwrap()); + endpoint + .send + .send_to(msg.as_bytes(), &echo_addr) + .await + .unwrap(); + assert_eq!(msg, endpoint.packet_rx.await.unwrap()); } } diff --git a/tests/filters.rs b/tests/filters.rs index b19462d60e..a8193b574e 100644 --- a/tests/filters.rs +++ b/tests/filters.rs @@ -26,16 +26,14 @@ mod tests { use quilkin::config::{Config, ConnectionConfig, EndPoint, Filter, Local}; use quilkin::extensions::filters::DebugFilterFactory; use quilkin::extensions::{default_registry, FilterFactory}; - use quilkin::test_utils::{ - echo_server, logger, recv_multiple_packets, run_proxy, TestFilterFactory, - }; + use quilkin::test_utils::{TestFilterFactory, TestHelper}; #[tokio::test] async fn test_filter() { - let base_logger = logger(); + let mut t = TestHelper::default(); - // create two echo servers as endpoints - let echo = echo_server().await; + // create an echo server as an endpoint. + let echo = t.run_echo_server().await; // create server configuration let server_port = 12346; @@ -55,9 +53,10 @@ mod tests { }; assert_eq!(Ok(()), server_config.validate()); - let mut registry = default_registry(&base_logger); + // Run server proxy. + let mut registry = default_registry(&t.log); registry.insert(TestFilterFactory {}); - let close_server = run_proxy(registry, server_config); + t.run_server_with_filter_registry(server_config, registry); // create a local client let client_port = 12347; @@ -78,16 +77,17 @@ mod tests { }; assert_eq!(Ok(()), client_config.validate()); - let mut registry = default_registry(&base_logger); + // Run client proxy. + let mut registry = default_registry(&t.log); registry.insert(TestFilterFactory {}); - let close_client = run_proxy(registry, client_config); + t.run_server_with_filter_registry(client_config, registry); // let's send the packet - let (mut recv_chan, mut send) = recv_multiple_packets(&base_logger).await; + let (mut recv_chan, mut send) = t.open_socket_and_recv_multiple_packets().await; // game_client let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), client_port); - info!(base_logger, "Sending hello"; "addr" => local_addr); + info!(t.log, "Sending hello"; "addr" => local_addr); send.send_to("hello".as_bytes(), &local_addr).await.unwrap(); let result = recv_chan.recv().await.unwrap(); @@ -106,19 +106,17 @@ mod tests { "Should be 2 on_upstream_receive calls in {}", result ); - - close_server(); - close_client(); } #[tokio::test] async fn debug_filter() { - let base_logger = logger(); + let mut t = TestHelper::default(); + // handy for grabbing the configuration name - let factory = DebugFilterFactory::new(&base_logger); + let factory = DebugFilterFactory::new(&t.log); - // create two echo servers as endpoints - let echo = echo_server().await; + // create an echo server as an endpoint. + let echo = t.run_echo_server().await; // filter config let mut map = Mapping::new(); @@ -139,7 +137,7 @@ mod tests { }], }, }; - let close_server = run_proxy(default_registry(&base_logger), server_config); + t.run_server(server_config); let mut map = Mapping::new(); map.insert(Value::from("id"), Value::from("client")); @@ -160,20 +158,17 @@ mod tests { lb_policy: None, }, }; - let close_client = run_proxy(default_registry(&base_logger), client_config); + t.run_server(client_config); // let's send the packet - let (mut recv_chan, mut send) = recv_multiple_packets(&base_logger).await; + let (mut recv_chan, mut send) = t.open_socket_and_recv_multiple_packets().await; // game client let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), client_port); - info!(base_logger, "Sending hello"; "addr" => local_addr); + info!(t.log, "Sending hello"; "addr" => local_addr); send.send_to("hello".as_bytes(), &local_addr).await.unwrap(); // since the debug filter doesn't change the data, it should be exactly the same assert_eq!("hello", recv_chan.recv().await.unwrap()); - - close_server(); - close_client(); } } diff --git a/tests/local_rate_limit.rs b/tests/local_rate_limit.rs index 300b80f4f5..2252f2ae8b 100644 --- a/tests/local_rate_limit.rs +++ b/tests/local_rate_limit.rs @@ -22,18 +22,18 @@ mod tests { use quilkin::config::{Config, ConnectionConfig, EndPoint, Filter, Local}; use quilkin::extensions::filters::RateLimitFilterFactory; - use quilkin::extensions::{default_registry, FilterFactory}; - use quilkin::test_utils::{echo_server, logger, recv_multiple_packets, run_proxy}; + use quilkin::extensions::FilterFactory; + use quilkin::test_utils::TestHelper; #[tokio::test] async fn local_rate_limit_filter() { - let base_logger = logger(); + let mut t = TestHelper::default(); let yaml = " max_packets: 2 period: 1s "; - let echo = echo_server().await; + let echo = t.run_echo_server().await; let server_port = 12346; let server_config = Config { @@ -50,10 +50,9 @@ period: 1s }], }, }; + t.run_server(server_config); - let close_server = run_proxy(default_registry(&base_logger), server_config); - - let (mut recv_chan, mut send) = recv_multiple_packets(&base_logger).await; + let (mut recv_chan, mut send) = t.open_socket_and_recv_multiple_packets().await; let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), server_port); @@ -71,7 +70,5 @@ period: 1s tokio::time::delay_for(std::time::Duration::from_millis(100)).await; // Check that we do not get any response. assert!(recv_chan.try_recv().is_err()); - - close_server(); } } diff --git a/tests/metrics.rs b/tests/metrics.rs index 7b9ed6267e..8c70f18ce1 100644 --- a/tests/metrics.rs +++ b/tests/metrics.rs @@ -25,19 +25,16 @@ mod tests { use slog::info; use quilkin::config::{Config, ConnectionConfig, EndPoint, Local}; - use quilkin::extensions::FilterRegistry; use quilkin::proxy::Metrics; - use quilkin::test_utils::{ - echo_server, logger, recv_multiple_packets, run_proxy, run_proxy_with_metrics, - }; + use quilkin::test_utils::TestHelper; #[tokio::test] async fn metrics_server() { - let base_logger = logger(); + let mut t = TestHelper::default(); let server_metrics = Metrics::new(Some("[::]:9092".parse().unwrap()), Registry::default()); - // create two echo servers as endpoints - let echo = echo_server().await; + // create an echo server as an endpoint. + let echo = t.run_echo_server().await; // create server configuration let server_port = 12346; @@ -53,8 +50,7 @@ mod tests { }, }; - let close_server = - run_proxy_with_metrics(FilterRegistry::default(), server_config, server_metrics); + t.run_server_with_metrics(server_config, server_metrics); // create a local client let client_port = 12347; @@ -70,14 +66,14 @@ mod tests { lb_policy: None, }, }; - let close_client = run_proxy(FilterRegistry::default(), client_config); + t.run_server(client_config); // let's send the packet - let (mut recv_chan, mut send) = recv_multiple_packets(&base_logger).await; + let (mut recv_chan, mut send) = t.open_socket_and_recv_multiple_packets().await; // game_client let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), client_port); - info!(base_logger, "Sending hello"; "addr" => local_addr); + info!(t.log, "Sending hello"; "addr" => local_addr); send.send_to("hello".as_bytes(), &local_addr).await.unwrap(); let _ = recv_chan.recv().await.unwrap(); @@ -99,8 +95,5 @@ mod tests { let upstream = (&c[2]).parse::().unwrap(); assert_ne!(downstream, upstream); } - - close_server(); - close_client(); } } diff --git a/tests/no_filter.rs b/tests/no_filter.rs index 31f1f331e3..f771946a66 100644 --- a/tests/no_filter.rs +++ b/tests/no_filter.rs @@ -24,16 +24,15 @@ mod tests { use tokio::time::{delay_for, Duration}; use quilkin::config::{Config, ConnectionConfig, EndPoint, Local}; - use quilkin::extensions::default_registry; - use quilkin::test_utils::{echo_server, logger, recv_multiple_packets, run_proxy}; + use quilkin::test_utils::TestHelper; #[tokio::test] async fn echo() { - let base_logger = logger(); + let mut t = TestHelper::default(); // create two echo servers as endpoints - let server1 = echo_server().await; - let server2 = echo_server().await; + let server1 = t.run_echo_server().await; + let server2 = t.run_echo_server().await; // create server configuration let server_port = 12345; @@ -57,7 +56,7 @@ mod tests { }; assert_eq!(Ok(()), server_config.validate()); - let close_server = run_proxy(default_registry(&base_logger), server_config); + t.run_server(server_config); // create a local client let client_port = 12344; @@ -75,10 +74,10 @@ mod tests { }; assert_eq!(Ok(()), client_config.validate()); - let close_client = run_proxy(default_registry(&base_logger), client_config); + t.run_server(client_config); // let's send the packet - let (mut recv_chan, mut send) = recv_multiple_packets(&base_logger).await; + let (mut recv_chan, mut send) = t.open_socket_and_recv_multiple_packets().await; // game_client let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), client_port); @@ -90,11 +89,9 @@ mod tests { // should only be two returned items select! { res = recv_chan.recv() => { - assert!(false, format!("Should not receive a third packet: {}", res.unwrap())); + unreachable!("Should not receive a third packet: {}", res.unwrap()); } _ = delay_for(Duration::from_secs(2)) => {} }; - close_server(); - close_client(); } }