diff --git a/docs/src/filters/load_balancer.md b/docs/src/filters/load_balancer.md index 48e70d9398..f2ab383ccf 100644 --- a/docs/src/filters/load_balancer.md +++ b/docs/src/filters/load_balancer.md @@ -28,7 +28,7 @@ static: ``` The load balancing policy (the strategy to use to select what endpoint to send traffic to) is configurable. -In the example above, packets will be distributed by selecting endpoints in turn, in round robin fashion +In the example above, packets will be distributed by selecting endpoints in turn, in round robin fashion. ### Configuration Options @@ -41,6 +41,7 @@ properties: enum: - ROUND_ROBIN # Send packets by selecting endpoints in turn. - RANDOM # Send packets by randomly selecting endpoints. + - HASH # Send packets by hashing the source IP and port. default: ROUND_ROBIN ``` diff --git a/proto/quilkin/extensions/filters/load_balancer/v1alpha1/load_balancer.proto b/proto/quilkin/extensions/filters/load_balancer/v1alpha1/load_balancer.proto index 8a66178389..efe1a2f043 100644 --- a/proto/quilkin/extensions/filters/load_balancer/v1alpha1/load_balancer.proto +++ b/proto/quilkin/extensions/filters/load_balancer/v1alpha1/load_balancer.proto @@ -22,6 +22,7 @@ message LoadBalancer { enum Policy { RoundRobin = 0; Random = 1; + Hash = 2; } message PolicyValue { diff --git a/src/filters/load_balancer.rs b/src/filters/load_balancer.rs index e2bb67ede2..93c199a539 100644 --- a/src/filters/load_balancer.rs +++ b/src/filters/load_balancer.rs @@ -38,7 +38,8 @@ struct LoadBalancer { impl Filter for LoadBalancer { fn read(&self, mut ctx: ReadContext) -> Option { - self.endpoint_chooser.choose_endpoints(&mut ctx.endpoints); + self.endpoint_chooser + .choose_endpoints(&mut ctx.endpoints, ctx.from); Some(ctx.into()) } } @@ -88,6 +89,7 @@ mod tests { fn get_response_addresses( filter: &dyn Filter, input_addresses: &[SocketAddr], + source: SocketAddr, ) -> Vec { filter .read(ReadContext::new( @@ -99,7 +101,7 @@ mod tests { ) .unwrap() .into(), - "127.0.0.1:8080".parse().unwrap(), + source, vec![], )) .unwrap() @@ -129,7 +131,11 @@ policy: ROUND_ROBIN assert_eq!( expected_sequence, (0..addresses.len()) - .map(|_| get_response_addresses(filter.as_ref(), &addresses)) + .map(|_| get_response_addresses( + filter.as_ref(), + &addresses, + "127.0.0.1:8080".parse().unwrap() + )) .collect::>() ); } @@ -152,7 +158,13 @@ policy: RANDOM let mut result_sequences = vec![]; for _ in 0..10 { let sequence = (0..addresses.len()) - .map(|_| get_response_addresses(filter.as_ref(), &addresses)) + .map(|_| { + get_response_addresses( + filter.as_ref(), + &addresses, + "127.0.0.1:8080".parse().unwrap(), + ) + }) .collect::>(); result_sequences.push(sequence); } @@ -176,4 +188,112 @@ policy: RANDOM "the same sequence of addresses were chosen for random load balancer" ); } + + #[test] + fn hash_load_balancer_policy() { + let addresses = vec![ + "127.0.0.1:8080".parse().unwrap(), + "127.0.0.2:8080".parse().unwrap(), + "127.0.0.3:8080".parse().unwrap(), + ]; + let source_ips = vec!["127.1.1.1", "127.2.2.2", "127.3.3.3"]; + let source_ports = vec!["11111", "22222", "33333", "44444", "55555"]; + + let yaml = " +policy: HASH +"; + let filter = create_filter(yaml); + + // Run a few selection rounds through the addresses. + let mut result_sequences = vec![]; + for _ in 0..10 { + let sequence = (0..addresses.len()) + .map(|_| { + get_response_addresses( + filter.as_ref(), + &addresses, + "127.0.0.1:8080".parse().unwrap(), + ) + }) + .collect::>(); + result_sequences.push(sequence); + } + + // Verify that all packets went the same way + assert_eq!( + 1, + result_sequences + .into_iter() + .flatten() + .flatten() + .collect::>() + .len(), + ); + + // Run a few selection rounds through the address + // this time vary the port for a single IP + let mut result_sequences = vec![]; + for port in &source_ports { + let sequence = (0..addresses.len()) + .map(|_| { + get_response_addresses( + filter.as_ref(), + &addresses, + format!("127.0.0.1:{}", port).parse().unwrap(), + ) + }) + .collect::>(); + result_sequences.push(sequence); + } + + // Verify that more than 1 path was picked + assert_ne!( + 1, + result_sequences + .into_iter() + .flatten() + .flatten() + .collect::>() + .len(), + ); + + // Run a few selection rounds through the addresses + // This time vary the source IP and port + let mut result_sequences = vec![]; + for ip in source_ips { + for port in &source_ports { + let sequence = (0..addresses.len()) + .map(|_| { + get_response_addresses( + filter.as_ref(), + &addresses, + format!("{}:{}", ip, port).parse().unwrap(), + ) + }) + .collect::>(); + result_sequences.push(sequence); + } + } + + // Check that every address was chosen at least once. + assert_eq!( + addresses.into_iter().collect::>(), + result_sequences + .clone() + .into_iter() + .flatten() + .flatten() + .collect::>(), + ); + + // Check that there is at least one different sequence of addresses. + assert!( + &result_sequences[1..] + .iter() + .any(|seq| seq != &result_sequences[0]), + "the same sequence of addresses were chosen for hash load balancer" + ); + + // + } } diff --git a/src/filters/load_balancer/config.rs b/src/filters/load_balancer/config.rs index 5ff4e12b08..970afad005 100644 --- a/src/filters/load_balancer/config.rs +++ b/src/filters/load_balancer/config.rs @@ -21,7 +21,9 @@ use std::convert::TryFrom; use serde::{Deserialize, Serialize}; use self::quilkin::extensions::filters::load_balancer::v1alpha1::load_balancer::Policy as ProtoPolicy; -use super::endpoint_chooser::{EndpointChooser, RandomEndpointChooser, RoundRobinEndpointChooser}; +use super::endpoint_chooser::{ + EndpointChooser, HashEndpointChooser, RandomEndpointChooser, RoundRobinEndpointChooser, +}; use crate::{filters::ConvertProtoConfigError, map_proto_enum}; pub use self::quilkin::extensions::filters::load_balancer::v1alpha1::LoadBalancer as ProtoConfig; @@ -46,7 +48,7 @@ impl TryFrom for Config { field = "policy", proto_enum_type = ProtoPolicy, target_enum_type = Policy, - variants = [RoundRobin, Random] + variants = [RoundRobin, Random, Hash] ) }) .transpose()? @@ -65,6 +67,9 @@ pub enum Policy { /// Send packets to endpoints chosen at random. #[serde(rename = "RANDOM")] Random, + /// Send packets to endpoints based on hash of source IP and port. + #[serde(rename = "HASH")] + Hash, } impl Policy { @@ -72,6 +77,7 @@ impl Policy { match self { Policy::RoundRobin => Box::new(RoundRobinEndpointChooser::new()), Policy::Random => Box::new(RandomEndpointChooser), + Policy::Hash => Box::new(HashEndpointChooser), } } } @@ -118,6 +124,17 @@ mod tests { policy: Policy::RoundRobin, }), ), + ( + "HashPolicy", + ProtoConfig { + policy: Some(PolicyValue { + value: ProtoPolicy::Hash as i32, + }), + }, + Some(Config { + policy: Policy::Hash, + }), + ), ( "should fail when invalid policy is provided", ProtoConfig { diff --git a/src/filters/load_balancer/endpoint_chooser.rs b/src/filters/load_balancer/endpoint_chooser.rs index 177e421b60..5f3e2643f6 100644 --- a/src/filters/load_balancer/endpoint_chooser.rs +++ b/src/filters/load_balancer/endpoint_chooser.rs @@ -18,12 +18,16 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use rand::{thread_rng, Rng}; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::net::SocketAddr; + use crate::endpoint::UpstreamEndpoints; /// EndpointChooser chooses from a set of endpoints that a proxy is connected to. pub trait EndpointChooser: Send + Sync { /// choose_endpoints asks for the next endpoint(s) to use. - fn choose_endpoints(&self, endpoints: &mut UpstreamEndpoints); + fn choose_endpoints(&self, endpoints: &mut UpstreamEndpoints, from: SocketAddr); } /// RoundRobinEndpointChooser chooses endpoints in round-robin order. @@ -40,7 +44,7 @@ impl RoundRobinEndpointChooser { } impl EndpointChooser for RoundRobinEndpointChooser { - fn choose_endpoints(&self, endpoints: &mut UpstreamEndpoints) { + fn choose_endpoints(&self, endpoints: &mut UpstreamEndpoints, _from: SocketAddr) { let count = self.next_endpoint.fetch_add(1, Ordering::Relaxed); // Note: Unwrap is safe here because the index is guaranteed to be in range. let num_endpoints = endpoints.size(); @@ -53,10 +57,23 @@ impl EndpointChooser for RoundRobinEndpointChooser { pub struct RandomEndpointChooser; impl EndpointChooser for RandomEndpointChooser { - fn choose_endpoints(&self, endpoints: &mut UpstreamEndpoints) { + fn choose_endpoints(&self, endpoints: &mut UpstreamEndpoints, _from: SocketAddr) { // Note: Unwrap is safe here because the index is guaranteed to be in range. let idx = (&mut thread_rng()).gen_range(0..endpoints.size()); endpoints.keep(idx) .expect("BUG: unwrap should have been safe because index into endpoints list should be in range"); } } + +/// HashEndpointChooser chooses endpoints based on a hash of source IP and port. +pub struct HashEndpointChooser; + +impl EndpointChooser for HashEndpointChooser { + fn choose_endpoints(&self, endpoints: &mut UpstreamEndpoints, from: SocketAddr) { + let num_endpoints = endpoints.size(); + let mut hasher = DefaultHasher::new(); + from.hash(&mut hasher); + endpoints.keep(hasher.finish() as usize % num_endpoints) + .expect("BUG: unwrap should have been safe because index into endpoints list should be in range"); + } +}