diff --git a/src/cluster/cluster_manager.rs b/src/cluster/cluster_manager.rs index eb3a2f0de8..a6f2af0901 100644 --- a/src/cluster/cluster_manager.rs +++ b/src/cluster/cluster_manager.rs @@ -27,6 +27,7 @@ use std::net::SocketAddr; use std::{fmt, sync::Arc}; use tokio::sync::{mpsc, oneshot, watch}; +use crate::cluster::Endpoint; use crate::config::{EmptyListError, EndPoint, Endpoints, ManagementServer, UpstreamEndpoints}; use crate::xds::ads_client::{AdsClient, ClusterUpdate, ExecutionResult}; @@ -71,7 +72,8 @@ impl ClusterManager { } /// Returns a ClusterManager backed by the fixed set of clusters provided in the config. - pub fn fixed(endpoints: Vec) -> SharedClusterManager { + pub fn fixed(endpoints: Vec) -> SharedClusterManager { + // TODO: Return a result rather than unwrap. Arc::new(RwLock::new(Self::new(Some( Endpoints::new(endpoints) .expect("endpoints list in config should be validated non-empty"), @@ -140,7 +142,7 @@ impl ClusterManager { endpoints .endpoints .into_iter() - .map(|ep| EndPoint::new("N/A".into(), ep.address, vec![])) + .map(|ep| Endpoint::from_address(ep.address)) }) .flatten(); endpoints.extend(cluster_endpoints); diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 55bb39f197..46b4ee1034 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -14,8 +14,9 @@ * limitations under the License. */ +use crate::config::{parse_endpoint_metadata_from_yaml, EndPoint}; use serde_json::value::Value; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; #[cfg(not(doctest))] @@ -31,6 +32,7 @@ pub(crate) mod cluster_manager { #[derive(Clone, Debug, Eq, PartialEq)] pub struct Endpoint { pub address: SocketAddr, + pub tokens: HashSet>, pub metadata: Option, } @@ -52,3 +54,29 @@ pub struct Cluster { } pub type ClusterLocalities = HashMap, LocalityEndpoints>; + +impl Endpoint { + pub fn new(address: SocketAddr, tokens: HashSet>, metadata: Option) -> Endpoint { + Endpoint { + address, + tokens, + metadata, + } + } + + pub fn from_address(address: SocketAddr) -> Endpoint { + Endpoint::new(address, Default::default(), None) + } + + /// Converts an endpoint config into an internal endpoint representation. + pub fn from_config(config: &EndPoint) -> Result { + let (metadata, tokens) = if let Some(metadata) = config.metadata.clone() { + let (metadata, tokens) = parse_endpoint_metadata_from_yaml(metadata)?; + (Some(metadata), tokens) + } else { + (None, Default::default()) + }; + + Ok(Endpoint::new(config.address, tokens, metadata)) + } +} diff --git a/src/config/endpoints.rs b/src/config/endpoints.rs index f08e0ce7ba..77a75bb329 100644 --- a/src/config/endpoints.rs +++ b/src/config/endpoints.rs @@ -14,7 +14,11 @@ * limitations under the License. */ -use crate::config::EndPoint; +pub const ENDPOINT_METADATA_KEY_PREFIX: &str = "quilkin.dev"; +pub const ENDPOINT_METADATA_TOKEN_KEY: &str = "endpoint.tokens"; + +// TODO Move endpoint.rs out of config/ into cluster/ +use crate::cluster::Endpoint; use std::sync::Arc; #[derive(Debug)] @@ -28,7 +32,7 @@ pub struct IndexOutOfRangeError; /// Endpoints represents the set of all known upstream endpoints. #[derive(Clone, Debug, PartialEq)] -pub struct Endpoints(Arc>); +pub struct Endpoints(Arc>); /// UpstreamEndpoints represents a set of endpoints. /// This set is guaranteed to be non-empty - any operation that would @@ -45,7 +49,7 @@ pub struct UpstreamEndpoints { impl Endpoints { /// Returns an [`Endpoints`] backed by the provided list of endpoints. - pub fn new(endpoints: Vec) -> Result { + pub fn new(endpoints: Vec) -> Result { if endpoints.is_empty() { Err(EmptyListError) } else { @@ -98,7 +102,7 @@ impl UpstreamEndpoints { /// Returns an error if the predicate returns `false` for all endpoints. pub fn retain(&mut self, predicate: F) -> Result<(), AllEndpointsRemovedError> where - F: Fn(&EndPoint) -> bool, + F: Fn(&Endpoint) -> bool, { match self.subset.as_mut() { Some(subset) => { @@ -151,7 +155,7 @@ pub struct UpstreamEndpointsIter<'a> { } impl<'a> Iterator for UpstreamEndpointsIter<'a> { - type Item = &'a EndPoint; + type Item = &'a Endpoint; fn next(&mut self) -> Option { match &self.collection.subset { @@ -172,14 +176,11 @@ impl<'a> Iterator for UpstreamEndpointsIter<'a> { #[cfg(test)] mod tests { use super::Endpoints; - use crate::config::{EndPoint, UpstreamEndpoints}; - - fn ep(id: usize) -> EndPoint { - EndPoint::new( - format!("ep-{}", id), - format!("127.0.0.{}:8080", id).parse().unwrap(), - vec![], - ) + use crate::cluster::Endpoint; + use crate::config::UpstreamEndpoints; + + fn ep(id: usize) -> Endpoint { + Endpoint::from_address(format!("127.0.0.{}:8080", id).parse().unwrap()) } #[test] @@ -215,23 +216,25 @@ mod tests { let mut up: UpstreamEndpoints = Endpoints::new(initial_endpoints.clone()).unwrap().into(); - up.retain(|ep| ep.name != "ep-2").unwrap(); + up.retain(|ep| ep.address.to_string().as_str() != "127.0.0.2:8080") + .unwrap(); assert_eq!(up.size(), 3); assert_eq!( vec![ep(1), ep(3), ep(4)], up.iter().cloned().collect::>() ); - up.retain(|ep| ep.name != "ep-3").unwrap(); + up.retain(|ep| ep.address.to_string().as_str() != "127.0.0.3:8080") + .unwrap(); assert_eq!(up.size(), 2); assert_eq!(vec![ep(1), ep(4)], up.iter().cloned().collect::>()); // test an empty result on retain - let result = up.retain(|ep| ep.name == "never"); + let result = up.retain(|_| false); assert!(result.is_err()); let mut up: UpstreamEndpoints = Endpoints::new(initial_endpoints).unwrap().into(); - let result = up.retain(|ep| ep.name == "never"); + let result = up.retain(|_| false); assert!(result.is_err()); } diff --git a/src/config/metadata.rs b/src/config/metadata.rs new file mode 100644 index 0000000000..964f6bd06e --- /dev/null +++ b/src/config/metadata.rs @@ -0,0 +1,202 @@ +use crate::config::{ENDPOINT_METADATA_KEY_PREFIX, ENDPOINT_METADATA_TOKEN_KEY}; +use serde_json::map::Map as JsonMap; +use serde_json::value::Value as JSONValue; +use serde_json::Number as JSONNumber; +use serde_yaml::Value as YamlValue; +use std::collections::HashSet; + +// Returns an empty map if no tokens exist. +pub fn extract_endpoint_tokens( + metadata: &mut JsonMap, +) -> Result>, String> { + let tokens = metadata.remove(ENDPOINT_METADATA_KEY_PREFIX) + .map(|raw_value| { + match raw_value { + JSONValue::Object(mut object) => { + match object.remove(ENDPOINT_METADATA_TOKEN_KEY) { + Some(JSONValue::Array(raw_tokens)) => { + raw_tokens.into_iter().fold(Ok(HashSet::new()), |acc, val| { + let mut tokens = acc?; + + let token = match val { + JSONValue::String(token) => + base64::decode(token) + .map_err(|err| format!( + "key {}.{}: failed to decode token as a base64 string:{}", + ENDPOINT_METADATA_KEY_PREFIX, + ENDPOINT_METADATA_TOKEN_KEY, + err + )), + _ => Err(format!( + "invalid value in token list for key `{}`: value must a base64 string", + ENDPOINT_METADATA_TOKEN_KEY + )) + }; + + tokens.insert(token?); + Ok(tokens) + }) + }, + Some(_) => Err(format!( + "invalid data type for key `{}.{}`: value must be a list of base64 strings", + ENDPOINT_METADATA_KEY_PREFIX, + ENDPOINT_METADATA_TOKEN_KEY + )), + None => Ok(Default::default()), + } + } + _ => Err(format!("invalid data type for key `{}`: value must be an object", ENDPOINT_METADATA_KEY_PREFIX)) + } + }) + .transpose()?; + + Ok(tokens.unwrap_or_default()) +} + +pub fn parse_endpoint_metadata_from_yaml( + yaml: YamlValue, +) -> Result<(JSONValue, HashSet>), String> { + let mapping = if let YamlValue::Mapping(mapping) = yaml { + mapping + } else { + return Err("invalid endpoint metadata: value must be a yaml object".into()); + }; + + let mut map = JsonMap::new(); + for (yaml_key, yaml_value) in mapping { + let key = parse_yaml_key(yaml_key)?; + let value = yaml_to_json_value(key.as_str(), yaml_value)?; + map.insert(key, value); + } + + let tokens = extract_endpoint_tokens(&mut map)?; + + Ok((JSONValue::Object(map), tokens)) +} + +fn yaml_to_json_value(key: &str, yaml: YamlValue) -> Result { + let json_value = match yaml { + YamlValue::Null => JSONValue::Null, + YamlValue::Bool(v) => JSONValue::Bool(v), + YamlValue::Number(v) => match v.as_f64() { + Some(v) => JSONValue::Number( + JSONNumber::from_f64(v) + .ok_or_else(|| format!("invalid f64 `{:?}` provided for key `{}`", v, key))?, + ), + None => return Err(format!("failed to parse key `{}` as f64 number", key)), + }, + YamlValue::String(v) => JSONValue::String(v), + YamlValue::Sequence(v) => { + let mut array = vec![]; + for yaml_value in v { + array.push(yaml_to_json_value(key, yaml_value)?); + } + JSONValue::Array(array) + } + YamlValue::Mapping(v) => { + let mut map = JsonMap::new(); + + for (yaml_key, yaml_value) in v { + let nested_key = parse_yaml_key(yaml_key)?; + let value = + yaml_to_json_value(format!("{}.{}", key, nested_key).as_str(), yaml_value)?; + map.insert(nested_key, value); + } + + JSONValue::Object(map) + } + }; + + Ok(json_value) +} + +fn parse_yaml_key(yaml: YamlValue) -> Result { + match yaml { + YamlValue::String(v) => Ok(v), + v => Err(format!( + "invalid key `{:?}`: only string keys are allowed", + v + )), + } +} +#[cfg(test)] +mod tests { + use crate::config::metadata::{parse_endpoint_metadata_from_yaml, yaml_to_json_value}; + use std::collections::HashSet; + + #[test] + fn yaml_data_types() { + let yaml = " +one: two +three: + four: + - five + - 6 + seven: + eight: true +"; + let yaml_value = serde_yaml::from_str(yaml).unwrap(); + let expected_json = serde_json::json!({ + "one": "two", + "three": { + "four": ["five", 6.0], + "seven": { + "eight": true + } + } + }); + + assert_eq!(yaml_to_json_value("k", yaml_value).unwrap(), expected_json); + } + + #[test] + fn yaml_parse_endpoint_metadata() { + let yaml = " +user: + key1: value1 +quilkin.dev: + endpoint.tokens: + - MXg3aWp5Ng== #1x7ijy6 + - OGdqM3YyaQ== #8gj3v2i +"; + let yaml_value = serde_yaml::from_str(yaml).unwrap(); + let expected_user_metadata = serde_json::json!({ + "user": { + "key1": "value1" + } + }); + + let (user_metadata, tokens) = parse_endpoint_metadata_from_yaml(yaml_value).unwrap(); + assert_eq!(user_metadata, expected_user_metadata); + assert_eq!( + tokens, + vec!["1x7ijy6".into(), "8gj3v2i".into()] + .into_iter() + .collect::>() + ); + } + + #[test] + fn yaml_parse_invalid_endpoint_metadata() { + let not_a_list = " +quilkin.dev: + endpoint.tokens: OGdqM3YyaQ== +"; + let not_a_string_value = " +quilkin.dev: + endpoint.tokens: + - OGdqM3YyaQ== #8gj3v2i + - 300 +"; + let not_a_base64_string = " +quilkin.dev: + endpoint.tokens: + - OGdqM3YyaQ== #8gj3v2i + - 1x7ijy6 +"; + for yaml in vec![not_a_list, not_a_string_value, not_a_base64_string] { + let yaml_value = serde_yaml::from_str(yaml).unwrap(); + assert!(parse_endpoint_metadata_from_yaml(yaml_value).is_err()); + } + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index bbdb363c76..ce6119a723 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -27,13 +27,16 @@ use uuid::Uuid; mod builder; mod endpoints; mod error; +mod metadata; pub use crate::config::endpoints::{ EmptyListError, Endpoints, UpstreamEndpoints, UpstreamEndpointsIter, }; use crate::config::error::ValueInvalidArgs; pub use builder::Builder; +pub use endpoints::{ENDPOINT_METADATA_KEY_PREFIX, ENDPOINT_METADATA_TOKEN_KEY}; pub use error::ValidationError; +pub(crate) use metadata::{extract_endpoint_tokens, parse_endpoint_metadata_from_yaml}; use std::convert::TryInto; base64_serde_type!(Base64Standard, base64::STANDARD); @@ -160,38 +163,20 @@ pub struct Filter { pub config: Option, } -/// ConnectionId is the connection auth token value -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct ConnectionId(#[serde(with = "Base64Standard")] Vec); - -impl From<&str> for ConnectionId { - fn from(s: &str) -> Self { - ConnectionId(s.as_bytes().to_vec()) - } -} - -impl AsRef> for ConnectionId { - fn as_ref(&self) -> &Vec { - &self.0 - } -} - /// A singular endpoint, to pass on UDP packets to. #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub struct EndPoint { - pub name: String, pub address: SocketAddr, - #[serde(default)] - pub connection_ids: Vec, + pub metadata: Option, } impl EndPoint { - pub fn new(name: String, address: SocketAddr, connection_ids: Vec) -> Self { - EndPoint { - name, - address, - connection_ids, - } + pub fn new(address: SocketAddr) -> Self { + EndPoint::with_metadata(address, None) + } + + pub fn with_metadata(address: SocketAddr, metadata: Option) -> Self { + EndPoint { address, metadata } } } @@ -221,18 +206,6 @@ impl Source { return Err(ValidationError::EmptyList("static.endpoints".to_string())); } - if endpoints - .iter() - .map(|ep| ep.name.clone()) - .collect::>() - .len() - != endpoints.len() - { - return Err(ValidationError::NotUnique( - "static.endpoints.name".to_string(), - )); - } - if endpoints .iter() .map(|ep| ep.address) @@ -245,6 +218,18 @@ impl Source { )); } + for ep in endpoints { + if let Some(ref metadata) = ep.metadata { + if let Err(err) = parse_endpoint_metadata_from_yaml(metadata.clone()) { + return Err(ValidationError::ValueInvalid(ValueInvalidArgs { + field: "static.endpoints.metadata".into(), + clarification: Some(err), + examples: None, + })); + } + } + } + Ok(()) } Source::Dynamic { @@ -297,6 +282,7 @@ mod tests { use crate::config::{ Builder, Config, EndPoint, ManagementServer, ProxyMode, Source, ValidationError, }; + use std::collections::HashMap; fn parse_config(yaml: &str) -> Config { Config::from_reader(yaml.as_bytes()).unwrap() @@ -332,11 +318,7 @@ mod tests { .with_port(7000) .with_static( vec![], - vec![EndPoint { - name: "test".into(), - address: "127.0.0.1:25999".parse().unwrap(), - connection_ids: vec![], - }], + vec![EndPoint::new("127.0.0.1:25999".parse().unwrap())], ) .build(); let _ = serde_yaml::to_string(&config).unwrap(); @@ -349,16 +331,8 @@ mod tests { .with_static( vec![], vec![ - EndPoint { - name: String::from("No.1"), - address: "127.0.0.1:26000".parse().unwrap(), - connection_ids: vec!["1234".into(), "5678".into()], - }, - EndPoint { - name: String::from("No.2"), - address: "127.0.0.1:26001".parse().unwrap(), - connection_ids: vec!["1234".into()], - }, + EndPoint::new("127.0.0.1:26000".parse().unwrap()), + EndPoint::new("127.0.0.1:26001".parse().unwrap()), ], ) .build(); @@ -371,8 +345,7 @@ mod tests { version: v1alpha1 static: endpoints: - - name: ep-1 - address: 127.0.0.1:25999 + - address: 127.0.0.1:25999 "; let config = parse_config(yaml); @@ -400,8 +373,7 @@ static: - 27 - true endpoints: - - name: endpoint-1 - address: 127.0.0.1:7001 + - address: 127.0.0.1:7001 "; let config = parse_config(yaml); @@ -434,8 +406,7 @@ proxy: port: 7000 static: endpoints: - - name: ep-1 - address: 127.0.0.1:25999 + - address: 127.0.0.1:25999 "; let config = parse_config(yaml); @@ -460,11 +431,7 @@ static: assert_eq!(config.proxy.mode, ProxyMode::Client); assert_static_endpoints( &config.source, - vec![EndPoint::new( - "ep-1".into(), - "127.0.0.1:25999".parse().unwrap(), - vec![], - )], + vec![EndPoint::new("127.0.0.1:25999".parse().unwrap())], ); } @@ -477,28 +444,40 @@ proxy: mode: SERVER static: endpoints: - - name: Game Server No. 1 - address: 127.0.0.1:26000 - connection_ids: - - MXg3aWp5Ng== #1x7ijy6 - - OGdqM3YyaQ== #8gj3v2i - - name: Game Server No. 2 - address: 127.0.0.1:26001 - connection_ids: - - bmt1eTcweA== #nkuy70x"; + - address: 127.0.0.1:26000 + metadata: + tokens: + - MXg3aWp5Ng== #1x7ijy6 + - OGdqM3YyaQ== #8gj3v2i + - address: 127.0.0.1:26001 + metadata: + tokens: + - bmt1eTcweA== #nkuy70x"; let config = parse_config(yaml); assert_static_endpoints( &config.source, vec![ - EndPoint::new( - "Game Server No. 1".into(), + EndPoint::with_metadata( "127.0.0.1:26000".parse().unwrap(), - vec!["1x7ijy6".into(), "8gj3v2i".into()], + Some( + serde_yaml::to_value( + vec![("tokens", vec!["MXg3aWp5Ng==", "OGdqM3YyaQ=="])] + .into_iter() + .collect::>(), + ) + .unwrap(), + ), ), - EndPoint::new( - String::from("Game Server No. 2"), + EndPoint::with_metadata( "127.0.0.1:26001".parse().unwrap(), - vec!["nkuy70x".into()], + Some( + serde_yaml::to_value( + vec![("tokens", vec!["bmt1eTcweA=="])] + .into_iter() + .collect::>(), + ) + .unwrap(), + ), ), ], ); @@ -630,18 +609,15 @@ static: ); let yaml = " -# Non unique endpoint names. +# Invalid metadata version: v1alpha1 static: endpoints: - - name: a - address: 127.0.0.1:25998 - - name: a - address: 127.0.0.1:25999 + - address: 127.0.0.1:25999 + metadata: + quilkin.dev: + endpoint.tokens: abc "; - assert_eq!( - ValidationError::NotUnique("static.endpoints.name".to_string()).to_string(), - parse_config(yaml).validate().unwrap_err().to_string() - ); + assert!(parse_config(yaml).validate().is_err()); } } diff --git a/src/extensions/filter_chain.rs b/src/extensions/filter_chain.rs index 6d0ae06a1d..439fd136b9 100644 --- a/src/extensions/filter_chain.rs +++ b/src/extensions/filter_chain.rs @@ -118,12 +118,13 @@ mod tests { use std::str::from_utf8; use crate::config; - use crate::config::{Builder, EndPoint, Endpoints, ProxyMode, UpstreamEndpoints}; + use crate::config::{Builder, Endpoints, ProxyMode, UpstreamEndpoints}; use crate::extensions::filters::DebugFactory; use crate::extensions::{default_registry, FilterFactory}; use crate::test_utils::{ep, logger, TestFilter}; use super::*; + use crate::cluster::Endpoint; #[test] fn from_config() { @@ -162,22 +163,14 @@ mod tests { assert!(result.is_err()); } - fn endpoints() -> Vec { + fn endpoints() -> Vec { vec![ - EndPoint { - name: "one".to_string(), - address: "127.0.0.1:80".parse().unwrap(), - connection_ids: vec![], - }, - EndPoint { - name: "two".to_string(), - address: "127.0.0.1:90".parse().unwrap(), - connection_ids: vec![], - }, + Endpoint::from_address("127.0.0.1:80".parse().unwrap()), + Endpoint::from_address("127.0.0.1:90".parse().unwrap()), ] } - fn upstream_endpoints(endpoints: Vec) -> UpstreamEndpoints { + fn upstream_endpoints(endpoints: Vec) -> UpstreamEndpoints { Endpoints::new(endpoints).unwrap().into() } @@ -227,7 +220,7 @@ mod tests { .unwrap() ); assert_eq!( - "hello:our:one:127.0.0.1:80:127.0.0.1:70", + "hello:our:127.0.0.1:80:127.0.0.1:70", from_utf8(response.contents.as_slice()).unwrap() ); } @@ -271,7 +264,7 @@ mod tests { )) .unwrap(); assert_eq!( - "hello:our:one:127.0.0.1:80:127.0.0.1:70:our:one:127.0.0.1:80:127.0.0.1:70", + "hello:our:127.0.0.1:80:127.0.0.1:70:our:127.0.0.1:80:127.0.0.1:70", from_utf8(response.contents.as_slice()).unwrap() ); assert_eq!( diff --git a/src/extensions/filter_registry.rs b/src/extensions/filter_registry.rs index c7fb291a1f..7a4a397861 100644 --- a/src/extensions/filter_registry.rs +++ b/src/extensions/filter_registry.rs @@ -22,7 +22,8 @@ use std::net::SocketAddr; use prometheus::{Error as MetricsError, Registry}; -use crate::config::{EndPoint, UpstreamEndpoints, ValidationError}; +use crate::cluster::Endpoint; +use crate::config::{UpstreamEndpoints, ValidationError}; /// Contains the input arguments to [on_downstream_receive](crate::extensions::filter_registry::Filter::on_downstream_receive) pub struct DownstreamContext { @@ -62,7 +63,7 @@ pub struct DownstreamResponse { /// Contains the input arguments to [on_upstream_receive](crate::extensions::filter_registry::Filter::on_upstream_receive) pub struct UpstreamContext<'a> { /// The upstream endpoint that we're expecting packets from. - pub endpoint: &'a EndPoint, + pub endpoint: &'a Endpoint, /// The source of the received packet. pub from: SocketAddr, /// The destination of the received packet. @@ -132,7 +133,7 @@ impl From for DownstreamResponse { impl UpstreamContext<'_> { /// Creates a new [`UpstreamContext`] pub fn new( - endpoint: &EndPoint, + endpoint: &Endpoint, from: SocketAddr, to: SocketAddr, contents: Vec, @@ -149,7 +150,7 @@ impl UpstreamContext<'_> { /// Creates a new [`UpstreamContext`] from a [`UpstreamResponse`] pub fn with_response( - endpoint: &EndPoint, + endpoint: &Endpoint, from: SocketAddr, to: SocketAddr, response: UpstreamResponse, @@ -308,6 +309,7 @@ mod tests { use crate::test_utils::TestFilterFactory; use super::*; + use crate::cluster::Endpoint; use crate::config::Endpoints; struct TestFilter {} @@ -341,18 +343,12 @@ mod tests { .unwrap(); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); - let endpoint = EndPoint { - name: "".to_string(), - address: addr, - connection_ids: vec![], - }; + let endpoint = Endpoint::from_address(addr); assert!(filter .on_downstream_receive(DownstreamContext::new( - Endpoints::new(vec![EndPoint::new( - "foo".into(), + Endpoints::new(vec![Endpoint::from_address( "127.0.0.1:8080".parse().unwrap(), - vec![] )]) .unwrap() .into(), diff --git a/src/extensions/filters/capture_bytes/mod.rs b/src/extensions/filters/capture_bytes/mod.rs index ed797a11f2..b610e9c73d 100644 --- a/src/extensions/filters/capture_bytes/mod.rs +++ b/src/extensions/filters/capture_bytes/mod.rs @@ -185,10 +185,11 @@ mod tests { use prometheus::Registry; use serde_yaml::{Mapping, Value}; - use crate::config::{EndPoint, Endpoints}; + use crate::config::Endpoints; use crate::test_utils::{assert_filter_on_upstream_receive_no_change, logger}; use super::*; + use crate::cluster::Endpoint; const TOKEN_KEY: &str = "TOKEN"; @@ -263,11 +264,7 @@ mod tests { remove: true, }; let filter = capture_bytes(config); - let endpoints = vec![EndPoint { - name: "e1".to_string(), - address: "127.0.0.1:81".parse().unwrap(), - connection_ids: vec![], - }]; + let endpoints = vec![Endpoint::from_address("127.0.0.1:81".parse().unwrap())]; let response = filter.on_downstream_receive(DownstreamContext::new( Endpoints::new(endpoints).unwrap().into(), "127.0.0.1:80".parse().unwrap(), @@ -322,11 +319,7 @@ mod tests { where F: Filter + ?Sized, { - let endpoints = vec![EndPoint { - name: "e1".to_string(), - address: "127.0.0.1:81".parse().unwrap(), - connection_ids: vec![], - }]; + let endpoints = vec![Endpoint::from_address("127.0.0.1:81".parse().unwrap())]; let response = filter .on_downstream_receive(DownstreamContext::new( Endpoints::new(endpoints).unwrap().into(), diff --git a/src/extensions/filters/concatenate_bytes.rs b/src/extensions/filters/concatenate_bytes.rs index 3c2ba9192d..5498cd8d6a 100644 --- a/src/extensions/filters/concatenate_bytes.rs +++ b/src/extensions/filters/concatenate_bytes.rs @@ -103,11 +103,12 @@ impl Filter for ConcatenateBytes { #[cfg(test)] mod tests { - use crate::config::{EndPoint, Endpoints}; + use crate::config::Endpoints; use crate::test_utils::assert_filter_on_downstream_receive_no_change; use serde_yaml::{Mapping, Value}; use super::*; + use crate::cluster::Endpoint; #[test] fn factory_valid_config() { @@ -207,11 +208,7 @@ mod tests { where F: Filter + ?Sized, { - let endpoints = vec![EndPoint { - name: "e1".to_string(), - address: "127.0.0.1:81".parse().unwrap(), - connection_ids: vec![], - }]; + let endpoints = vec![Endpoint::from_address("127.0.0.1:81".parse().unwrap())]; let response = filter .on_downstream_receive(DownstreamContext::new( Endpoints::new(endpoints.clone()).unwrap().into(), diff --git a/src/extensions/filters/debug.rs b/src/extensions/filters/debug.rs index 01cb0bdb8c..fec5ac9319 100644 --- a/src/extensions/filters/debug.rs +++ b/src/extensions/filters/debug.rs @@ -103,7 +103,7 @@ impl Filter for Debug { } fn on_upstream_receive(&self, ctx: UpstreamContext) -> Option { - info!(self.log, "received endpoint packet"; "endpoint" => ctx.endpoint.name.clone(), + info!(self.log, "received endpoint packet"; "from" => ctx.from, "to" => ctx.to, "contents" => packet_to_string(ctx.contents.clone())); diff --git a/src/extensions/filters/load_balancer/mod.rs b/src/extensions/filters/load_balancer/mod.rs index d5cf403aec..d8bac1326b 100644 --- a/src/extensions/filters/load_balancer/mod.rs +++ b/src/extensions/filters/load_balancer/mod.rs @@ -130,7 +130,8 @@ mod tests { use std::collections::HashSet; use std::net::SocketAddr; - use crate::config::{EndPoint, Endpoints}; + use crate::cluster::Endpoint; + use crate::config::Endpoints; use crate::extensions::filter_registry::DownstreamContext; use crate::extensions::filters::load_balancer::LoadBalancerFilterFactory; use crate::extensions::{CreateFilterArgs, Filter, FilterFactory}; @@ -153,7 +154,7 @@ mod tests { Endpoints::new( input_addresses .iter() - .map(|addr| EndPoint::new("".into(), *addr, vec![])) + .map(|addr| Endpoint::from_address(*addr)) .collect(), ) .unwrap() diff --git a/src/extensions/filters/local_rate_limit/mod.rs b/src/extensions/filters/local_rate_limit/mod.rs index 10562bd8f6..39f1d01f25 100644 --- a/src/extensions/filters/local_rate_limit/mod.rs +++ b/src/extensions/filters/local_rate_limit/mod.rs @@ -197,7 +197,8 @@ mod tests { use prometheus::Registry; use tokio::time; - use crate::config::{EndPoint, Endpoints}; + use crate::cluster::Endpoint; + use crate::config::Endpoints; use crate::extensions::filter_registry::DownstreamContext; use crate::extensions::filters::local_rate_limit::metrics::Metrics; use crate::extensions::filters::local_rate_limit::{Config, RateLimitFilter}; @@ -278,10 +279,8 @@ mod tests { // Check that we're rate limited. assert!(r .on_downstream_receive(DownstreamContext::new( - Endpoints::new(vec![EndPoint::new( - "ep".into(), + Endpoints::new(vec![Endpoint::from_address( "127.0.0.1:8080".parse().unwrap(), - vec![] )]) .unwrap() .into(), @@ -300,10 +299,8 @@ mod tests { let result = r .on_downstream_receive(DownstreamContext::new( - Endpoints::new(vec![EndPoint::new( - "ep".into(), + Endpoints::new(vec![Endpoint::from_address( "127.0.0.1:8080".parse().unwrap(), - vec![], )]) .unwrap() .into(), diff --git a/src/extensions/filters/token_router/mod.rs b/src/extensions/filters/token_router/mod.rs index ce34473c99..2341a9dd36 100644 --- a/src/extensions/filters/token_router/mod.rs +++ b/src/extensions/filters/token_router/mod.rs @@ -99,18 +99,13 @@ impl Filter for TokenRouter { None } Some(value) => match value.downcast_ref::>() { - Some(token) => { - match ctx - .endpoints - .retain(|e| e.connection_ids.iter().any(|id| id.as_ref() == token)) - { - Ok(_) => Some(ctx.into()), - Err(_) => { - self.metrics.packets_dropped_no_endpoint_match.inc(); - None - } + Some(token) => match ctx.endpoints.retain(|e| e.tokens.contains(token)) { + Ok(_) => Some(ctx.into()), + Err(_) => { + self.metrics.packets_dropped_no_endpoint_match.inc(); + None } - } + }, None => { error!(self.log, "Filter configuration issue: retrieved token is not the correct type (Vec)"; "metadata_key" => self.metadata_key.clone()); @@ -133,10 +128,11 @@ mod tests { use prometheus::Registry; use serde_yaml::{Mapping, Value}; - use crate::config::{ConnectionId, EndPoint, Endpoints}; + use crate::config::Endpoints; use crate::test_utils::{assert_filter_on_upstream_receive_no_change, logger}; use super::*; + use crate::cluster::Endpoint; const TOKEN_KEY: &str = "TOKEN"; @@ -236,15 +232,15 @@ mod tests { } fn new_ctx() -> DownstreamContext { - let endpoint1 = EndPoint::new( - "one".into(), + let endpoint1 = Endpoint::new( "127.0.0.1:80".parse().unwrap(), - vec![ConnectionId::from("123")], + vec!["123".into()].into_iter().collect(), + None, ); - let endpoint2 = EndPoint::new( - "two".into(), + let endpoint2 = Endpoint::new( "127.0.0.1:90".parse().unwrap(), - vec![ConnectionId::from("456")], + vec!["456".into()].into_iter().collect(), + None, ); DownstreamContext::new( @@ -262,13 +258,5 @@ mod tests { assert_eq!(b"hello".to_vec(), result.contents); assert_eq!(1, result.endpoints.size()); - assert_eq!( - vec!["one"], - result - .endpoints - .iter() - .map(|i| i.name.clone()) - .collect::>() - ); } } diff --git a/src/proxy/server/error.rs b/src/proxy/server/error.rs index 261ed24df0..611285cd68 100644 --- a/src/proxy/server/error.rs +++ b/src/proxy/server/error.rs @@ -22,6 +22,7 @@ pub enum Error { Initialize(String), Session(SessionError), Bind(tokio::io::Error), + InvalidEndpointConfig(String), } #[derive(Debug)] @@ -35,6 +36,7 @@ impl Display for Error { Error::Initialize(reason) => write!(f, "failed to startup properly: {}", reason), Error::Session(inner) => write!(f, "session error: {}", inner), Error::Bind(inner) => write!(f, "failed to bind to port: {}", inner), + Error::InvalidEndpointConfig(reason) => write!(f, "{}", reason), } } } diff --git a/src/proxy/server/mod.rs b/src/proxy/server/mod.rs index f7ae9410ce..0b6324948c 100644 --- a/src/proxy/server/mod.rs +++ b/src/proxy/server/mod.rs @@ -28,13 +28,14 @@ use tokio::time::{delay_for, Duration, Instant}; use metrics::Metrics as ProxyMetrics; use crate::cluster::cluster_manager::{ClusterManager, SharedClusterManager}; -use crate::config::{Config, EndPoint, Source}; +use crate::config::{Config, Source}; use crate::extensions::{DownstreamContext, Filter, FilterChain}; use crate::proxy::server::error::{Error, RecvFromError}; use crate::proxy::sessions::{Packet, Session, SESSION_TIMEOUT_SECONDS}; use crate::utils::debug; use super::metrics::{start_metrics_server, Metrics}; +use crate::cluster::Endpoint; pub mod error; pub(super) mod metrics; @@ -104,8 +105,17 @@ impl Server { match &self.config.source { Source::Static { filters: _, - endpoints, - } => Ok(ClusterManager::fixed(endpoints.to_vec())), + endpoints: config_endpoints, + } => { + let mut endpoints = Vec::with_capacity(config_endpoints.len()); + for ep in config_endpoints { + // TODO: We should a validated config type so that we don't need to + // handle errors when using its values later on since we know it's validated. + endpoints + .push(Endpoint::from_config(ep).map_err(Error::InvalidEndpointConfig)?); + } + Ok(ClusterManager::fixed(endpoints)) + } Source::Dynamic { filters: _, management_servers, @@ -315,7 +325,7 @@ impl Server { chain: Arc, sessions: SessionMap, from: SocketAddr, - dest: &EndPoint, + dest: &Endpoint, sender: mpsc::Sender, ) -> std::result::Result<(), Box> { { @@ -407,18 +417,7 @@ mod tests { .with_port(local_addr.port()) .with_static( vec![], - vec![ - EndPoint { - name: String::from("e1"), - address: endpoint1.addr, - connection_ids: vec![], - }, - EndPoint { - name: String::from("e2"), - address: endpoint2.addr, - connection_ids: vec![], - }, - ], + vec![EndPoint::new(endpoint1.addr), EndPoint::new(endpoint2.addr)], ) .build(); t.run_server(config); @@ -443,10 +442,7 @@ mod tests { let config = ConfigBuilder::empty() .with_mode(ProxyMode::Client) .with_port(local_addr.port()) - .with_static( - vec![], - vec![EndPoint::new("test".into(), endpoint.addr, vec![])], - ) + .with_static(vec![], vec![EndPoint::new(endpoint.addr)]) .build(); t.run_server(config); @@ -476,7 +472,7 @@ mod tests { name: "TestFilter".to_string(), config: None, }], - vec![EndPoint::new("test".into(), endpoint.addr, vec![])], + vec![EndPoint::new(endpoint.addr)], ) .build(); t.run_server_with_filter_registry(config, registry); @@ -552,10 +548,8 @@ mod tests { log: t.log.clone(), metrics: Metrics::default(), proxy_metrics: ProxyMetrics::new(&Metrics::default().registry).unwrap(), - cluster_manager: ClusterManager::fixed(vec![EndPoint::new( - "".into(), + cluster_manager: ClusterManager::fixed(vec![Endpoint::from_address( endpoint_address, - vec![], )]), chain, sessions: sessions_clone, @@ -632,7 +626,7 @@ mod tests { let server = Builder::from(config).validate().unwrap().build(); server.run_recv_from( - ClusterManager::fixed(vec![EndPoint::new("".into(), endpoint.addr, vec![])]), + ClusterManager::fixed(vec![Endpoint::from_address(endpoint.addr)]), server.filter_chain.clone(), recv, &sessions, @@ -651,11 +645,7 @@ mod tests { let from: SocketAddr = "127.0.0.1:27890".parse().unwrap(); let dest: SocketAddr = "127.0.0.1:27891".parse().unwrap(); let (sender, _) = mpsc::channel::(1); - let endpoint = EndPoint { - name: "endpoint".to_string(), - address: dest, - connection_ids: vec![], - }; + let endpoint = Endpoint::from_address(dest); // gate { @@ -712,11 +702,7 @@ mod tests { let from: SocketAddr = "127.0.0.1:7000".parse().unwrap(); let to: SocketAddr = "127.0.0.1:7001".parse().unwrap(); let (send, _recv) = mpsc::channel::(1); - let endpoint = EndPoint { - name: "endpoint".to_string(), - address: to, - connection_ids: vec![], - }; + let endpoint = Endpoint::from_address(to); Server::ensure_session( &t.log.clone(), @@ -771,11 +757,7 @@ mod tests { let to: SocketAddr = "127.0.0.1:7001".parse().unwrap(); let (send, _recv) = mpsc::channel::(1); let key = (from, to); - let endpoint = EndPoint { - name: "endpoint".to_string(), - address: to, - connection_ids: vec![], - }; + let endpoint = Endpoint::from_address(to); let config = Arc::new(config_with_dummy_endpoint().build()); let server = Builder::from(config).validate().unwrap().build(); diff --git a/src/proxy/sessions/session.rs b/src/proxy/sessions/session.rs index 827408515d..b44cfd935d 100644 --- a/src/proxy/sessions/session.rs +++ b/src/proxy/sessions/session.rs @@ -27,7 +27,7 @@ use tokio::select; use tokio::sync::{mpsc, watch, RwLock}; use tokio::time::{Duration, Instant}; -use crate::config::EndPoint; +use crate::cluster::Endpoint; use crate::extensions::{Filter, FilterChain, UpstreamContext}; use crate::proxy::sessions::error::Error; use crate::proxy::sessions::metrics::Metrics; @@ -47,7 +47,7 @@ pub struct Session { created_at: Instant, send: SendHalf, /// dest is where to send data to - dest: EndPoint, + dest: Endpoint, /// from is the original sender from: SocketAddr, /// session expiration timestamp @@ -62,7 +62,7 @@ pub struct Session { struct ReceivedPacketContext<'a> { packet: &'a [u8], chain: Arc, - endpoint: &'a EndPoint, + endpoint: &'a Endpoint, from: SocketAddr, to: SocketAddr, } @@ -95,10 +95,11 @@ impl Session { metrics: Metrics, chain: Arc, from: SocketAddr, - dest: EndPoint, + dest: Endpoint, sender: mpsc::Sender, ) -> Result { - let log = base.new(o!("source" => "proxy::Session", "from" => from, "dest_name" => dest.name.clone(), "dest_address" => dest.address)); + let log = base + .new(o!("source" => "proxy::Session", "from" => from, "dest_address" => dest.address)); let addr = SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0); let (recv, send) = UdpSocket::bind(addr) .await @@ -213,7 +214,7 @@ impl Session { to, } = packet_ctx; - trace!(log, "Received packet"; "from" => from, "endpoint_name" => &endpoint.name, + trace!(log, "Received packet"; "from" => from, "endpoint_addr" => &endpoint.address, "contents" => debug::bytes_to_string(packet.to_vec())); Session::inc_expiration(expiration).await; @@ -244,7 +245,7 @@ impl Session { /// Sends a packet to the Session's dest. pub async fn send_to(&mut self, buf: &[u8]) -> Result> { - trace!(self.log, "Sending packet"; "dest_name" => &self.dest.name, + trace!(self.log, "Sending packet"; "dest_address" => &self.dest.address, "contents" => debug::bytes_to_string(buf.to_vec())); @@ -270,7 +271,7 @@ impl Session { /// close closes this Session. pub fn close(&self) -> result::Result<(), watch::error::SendError> { - debug!(self.log, "Session closed"; "from" => self.from, "dest_name" => &self.dest.name, "dest_address" => &self.dest.address); + debug!(self.log, "Session closed"; "from" => self.from, "dest_address" => &self.dest.address); self.closer.broadcast(true) } } @@ -307,11 +308,7 @@ mod tests { mut recv, mut send, } = t.create_and_split_socket().await; - let endpoint = EndPoint { - name: "endpoint".to_string(), - address: addr, - connection_ids: vec![], - }; + let endpoint = Endpoint::from_address(addr); let (send_packet, mut recv_packet) = mpsc::channel::(5); let mut sess = Session::new( @@ -371,11 +368,7 @@ mod tests { // without a filter let (sender, _) = mpsc::channel::(1); let ep = t.open_socket_and_recv_single_packet().await; - let endpoint = EndPoint { - name: "endpoint".to_string(), - address: ep.addr, - connection_ids: vec![], - }; + let endpoint = Endpoint::from_address(ep.addr); let mut session = Session::new( &t.log, @@ -402,11 +395,7 @@ mod tests { let ep = t.open_socket_and_recv_single_packet().await; let (send_packet, _) = mpsc::channel::(5); - let endpoint = EndPoint { - name: "endpoint".to_string(), - address: ep.addr, - connection_ids: vec![], - }; + let endpoint = Endpoint::from_address(ep.addr); info!(t.log, ">> creating sessions"); let sess = Session::new( @@ -448,11 +437,7 @@ mod tests { let t = TestHelper::default(); let chain = Arc::new(FilterChain::new(vec![])); - let endpoint = EndPoint { - name: "endpoint".to_string(), - address: "127.0.1.1:80".parse().unwrap(), - connection_ids: vec![], - }; + let endpoint = Endpoint::from_address("127.0.1.1:80".parse().unwrap()); let dest = "127.0.0.1:88".parse().unwrap(); let (sender, mut receiver) = mpsc::channel::(10); let expiration = Arc::new(RwLock::new(Instant::now())); @@ -516,10 +501,7 @@ mod tests { assert!(initial_expiration < *expiration.read().await); let p = receiver.try_recv().unwrap(); assert_eq!( - format!( - "{}:our:{}:{}:{}", - msg, endpoint.name, endpoint.address, dest - ), + format!("{}:our:{}:{}", msg, endpoint.address, dest), from_utf8(p.contents.as_slice()).unwrap() ); assert_eq!(dest, p.dest); @@ -529,11 +511,7 @@ mod tests { async fn session_new_metrics() { let t = TestHelper::default(); let ep = t.open_socket_and_recv_single_packet().await; - let endpoint = EndPoint { - name: "endpoint".to_string(), - address: ep.addr, - connection_ids: vec![], - }; + let endpoint = Endpoint::from_address(ep.addr); let (send_packet, _) = mpsc::channel::(5); let session = Session::new( @@ -574,11 +552,7 @@ mod tests { .unwrap(), Arc::new(FilterChain::new(vec![])), endpoint.addr, - EndPoint { - name: "endpoint".to_string(), - address: endpoint.addr, - connection_ids: vec![], - }, + Endpoint::from_address(endpoint.addr), sender, ) .await @@ -607,11 +581,7 @@ mod tests { .unwrap(), Arc::new(FilterChain::new(vec![])), endpoint.addr, - EndPoint { - name: "endpoint".to_string(), - address: endpoint.addr, - connection_ids: vec![], - }, + Endpoint::from_address(endpoint.addr), send_packet, ) .await diff --git a/src/test_utils.rs b/src/test_utils.rs index 469a00b88b..18af344482 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -25,6 +25,7 @@ use tokio::net::udp::{RecvHalf, SendHalf}; use tokio::net::UdpSocket; use tokio::sync::{mpsc, oneshot, watch}; +use crate::cluster::Endpoint; use crate::config::{Builder as ConfigBuilder, Config, EndPoint, Endpoints, ProxyMode}; use crate::extensions::{ default_registry, CreateFilterArgs, DownstreamContext, DownstreamResponse, Error, Filter, @@ -66,9 +67,8 @@ impl Filter for TestFilter { .and_modify(|e| e.downcast_mut::().unwrap().push_str(":receive")) .or_insert_with(|| Box::new("receive".to_string())); - ctx.contents.append( - &mut format!(":our:{}:{}:{}", ctx.endpoint.name, ctx.from, ctx.to).into_bytes(), - ); + ctx.contents + .append(&mut format!(":our:{}:{}", ctx.from, ctx.to).into_bytes()); Some(ctx.into()) } } @@ -315,11 +315,7 @@ pub fn assert_filter_on_downstream_receive_no_change(filter: &F) where F: Filter, { - let endpoints = vec![EndPoint { - name: "e1".into(), - address: "127.0.0.1:80".parse().unwrap(), - connection_ids: vec![], - }]; + let endpoints = vec![Endpoint::from_address("127.0.0.1:80".parse().unwrap())]; let from = "127.0.0.1:90".parse().unwrap(); let contents = "hello".to_string().into_bytes(); @@ -344,11 +340,7 @@ pub fn assert_filter_on_upstream_receive_no_change(filter: &F) where F: Filter, { - let endpoint = EndPoint { - name: "e1".into(), - address: "127.0.0.1:90".parse().unwrap(), - connection_ids: vec![], - }; + let endpoint = Endpoint::from_address("127.0.0.1:90".parse().unwrap()); let contents = "hello".to_string().into_bytes(); match filter.on_upstream_receive(UpstreamContext::new( @@ -367,20 +359,12 @@ pub fn config_with_dummy_endpoint() -> ConfigBuilder { .with_mode(ProxyMode::Server) .with_static( vec![], - vec![EndPoint::new( - "test".into(), - "127.0.0.1:8080".parse().unwrap(), - vec![], - )], + vec![EndPoint::new("127.0.0.1:8080".parse().unwrap())], ) } /// Creates a dummy endpoint with `id` as a suffix. pub fn ep(id: u8) -> EndPoint { - EndPoint::new( - format!("test-{}", id), - format!("127.0.0.{:?}:8080", id).parse().unwrap(), - vec![], - ) + EndPoint::new(format!("127.0.0.{:?}:8080", id).parse().unwrap()) } #[cfg(test)] diff --git a/src/xds/cluster.rs b/src/xds/cluster.rs index 7ac4f41666..a1f42d2737 100644 --- a/src/xds/cluster.rs +++ b/src/xds/cluster.rs @@ -307,25 +307,27 @@ impl ClusterManager { })?; // Extract any metadata associated with the endpoint. - let metadata = if let Some(metadata) = metadata { - Some(metadata::to_json(metadata).map_err(Error::new)?) + let (metadata, tokens) = if let Some(metadata) = metadata { + let (metadata, tokens) = + metadata::parse_endpoint_metadata(metadata).map_err(Error::new)?; + (Some(metadata), tokens) } else { - None + (None, Default::default()) }; - processed_endpoints.push((address, metadata)); + processed_endpoints.push((address, tokens, metadata)); } let mut endpoints = vec![]; - for ((addr, port), metadata) in processed_endpoints.into_iter() { - endpoints.push(Endpoint { - metadata, + for ((addr, port), tokens, metadata) in processed_endpoints.into_iter() { + endpoints.push(Endpoint::new( // We only support IP addresses so anything else is an error. - address: addr - .parse::() + addr.parse::() .map_err(|err| Error::new(format!("invalid ip address: {}", err))) .map(|ip_addr| SocketAddr::new(ip_addr, port))?, - }); + tokens, + metadata, + )); } existing_endpoints.insert(locality, LocalityEndpoints { endpoints }); @@ -550,10 +552,7 @@ mod tests { .get(&None) .unwrap() .endpoints, - vec![ProxyEndpoint { - address: expected_socket_addr, - metadata: None, - }] + vec![ProxyEndpoint::from_address(expected_socket_addr)] ); assert_eq!( cm.clusters @@ -563,10 +562,9 @@ mod tests { .get(&None) .unwrap() .endpoints, - vec![ProxyEndpoint { - address: "127.0.0.1:2020".parse().unwrap(), - metadata: None, - }] + vec![ProxyEndpoint::from_address( + "127.0.0.1:2020".parse().unwrap(), + )] ); } } diff --git a/src/xds/metadata.rs b/src/xds/metadata.rs index 8045c106fd..bec705f681 100644 --- a/src/xds/metadata.rs +++ b/src/xds/metadata.rs @@ -1,5 +1,6 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashSet}; +use crate::config::extract_endpoint_tokens; use crate::xds::envoy::config::core::v3::Metadata; use prost_types::value::Kind; use prost_types::Value as ProstValue; @@ -7,15 +8,24 @@ use serde_json::map::Map as JsonMap; use serde_json::value::Value as JSONValue; use serde_json::Number as JSONNumber; -/// Converts an XDS Metadata object into an equivalent JSON object. -pub fn to_json(metadata: Metadata) -> Result { +/// Converts an XDS Metadata object into endpoint specific values and JSON values. +pub fn parse_endpoint_metadata( + metadata: Metadata, +) -> Result<(JSONValue, HashSet>), String> { + let mut metadata = to_json_map(metadata)?; + let tokens = extract_endpoint_tokens(&mut metadata)?; + Ok((JSONValue::Object(metadata), tokens)) +} + +/// Converts an XDS Metadata object into an equivalent JSON map. +fn to_json_map(metadata: Metadata) -> Result, String> { let mut map = JsonMap::new(); for (key, prost_struct) in metadata.filter_metadata { map.insert(key, prost_map_to_json_value(prost_struct.fields)?); } - Ok(JSONValue::Object(map)) + Ok(map) } fn prost_kind_to_json_value(key: &str, kind: Kind) -> Result { @@ -60,23 +70,20 @@ fn prost_map_to_json_value(prost_map: BTreeMap) -> Result>() + ); + } + + #[test] + fn prost_invalid_endpoint_metadata() { + let invalid_values = vec![ + // Not a list. + ProstValue { + kind: Some(Kind::StringValue("MXg3aWp5Ng==".into())), + }, + // Not a string value + ProstValue { + kind: Some(Kind::ListValue(ListValue { + values: vec![ + ProstValue { + kind: Some(Kind::StringValue("MXg3aWp5Ng==".into())), + }, + ProstValue { + kind: Some(Kind::NumberValue(12.0)), + }, + ], + })), + }, + // Not a base64 string value + ProstValue { + kind: Some(Kind::ListValue(ListValue { + values: vec![ProstValue { + kind: Some(Kind::StringValue("cat".into())), + }], + })), + }, + ]; + + for invalid in invalid_values { + let metadata = Metadata { + filter_metadata: vec![( + "quilkin.dev".into(), + ProstStruct { + fields: vec![("endpoint.tokens".into(), invalid)] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + assert!(parse_endpoint_metadata(metadata).is_err()); + } } } diff --git a/tests/concatenate_bytes.rs b/tests/concatenate_bytes.rs index ae4a95418d..d4cbd5635a 100644 --- a/tests/concatenate_bytes.rs +++ b/tests/concatenate_bytes.rs @@ -43,11 +43,7 @@ bytes: YWJj #abc name: ConcatBytesFactory::default().name(), config: serde_yaml::from_str(yaml).unwrap(), }], - vec![EndPoint { - name: "server".to_string(), - address: echo, - connection_ids: vec![], - }], + vec![EndPoint::new(echo)], ) .build(); diff --git a/tests/filters.rs b/tests/filters.rs index 60a805f57a..c8d55f6511 100644 --- a/tests/filters.rs +++ b/tests/filters.rs @@ -44,11 +44,7 @@ mod tests { name: "TestFilter".to_string(), config: None, }], - vec![EndPoint { - name: "server".to_string(), - address: echo, - connection_ids: vec![], - }], + vec![EndPoint::new(echo)], ) .build(); assert_eq!(Ok(()), server_config.validate()); @@ -67,11 +63,10 @@ mod tests { name: "TestFilter".to_string(), config: None, }], - vec![EndPoint::new( - "test".into(), - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), server_port), - vec![], - )], + vec![EndPoint::new(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + server_port, + ))], ) .build(); assert_eq!(Ok(()), client_config.validate()); @@ -129,11 +124,7 @@ mod tests { name: factory.name(), config: Some(serde_yaml::Value::Mapping(map)), }], - vec![EndPoint { - name: "server".to_string(), - address: echo, - connection_ids: vec![], - }], + vec![EndPoint::new(echo)], ) .build(); t.run_server(server_config); @@ -149,11 +140,10 @@ mod tests { name: factory.name(), config: Some(serde_yaml::Value::Mapping(map)), }], - vec![EndPoint::new( - "test".into(), - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), server_port), - vec![], - )], + vec![EndPoint::new(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + server_port, + ))], ) .build(); t.run_server(client_config); diff --git a/tests/load_balancer.rs b/tests/load_balancer.rs index 663ad3488e..28265a77b6 100644 --- a/tests/load_balancer.rs +++ b/tests/load_balancer.rs @@ -57,7 +57,7 @@ policy: ROUND_ROBIN echo_addresses .iter() .enumerate() - .map(|(i, addr)| EndPoint::new(format!("server-{}", i), *addr, vec![])) + .map(|(_, addr)| EndPoint::new(*addr)) .collect(), ) .build(); diff --git a/tests/local_rate_limit.rs b/tests/local_rate_limit.rs index 956d455534..3ef5d94ab5 100644 --- a/tests/local_rate_limit.rs +++ b/tests/local_rate_limit.rs @@ -43,11 +43,7 @@ period: 1s name: RateLimitFilterFactory::default().name(), config: serde_yaml::from_str(yaml).unwrap(), }], - vec![EndPoint { - name: "server".to_string(), - address: echo, - connection_ids: vec![], - }], + vec![EndPoint::new(echo)], ) .build(); t.run_server(server_config); diff --git a/tests/metrics.rs b/tests/metrics.rs index e3f2c38a1b..afe1231741 100644 --- a/tests/metrics.rs +++ b/tests/metrics.rs @@ -40,14 +40,7 @@ mod tests { let server_port = 12346; let server_config = ConfigBuilder::empty() .with_port(server_port) - .with_static( - vec![], - vec![EndPoint { - name: "server".to_string(), - address: echo, - connection_ids: vec![], - }], - ) + .with_static(vec![], vec![EndPoint::new(echo)]) .build(); t.run_server_with_metrics(server_config, server_metrics); @@ -57,11 +50,10 @@ mod tests { .with_port(client_port) .with_static( vec![], - vec![EndPoint::new( - "test".into(), - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), server_port), - vec![], - )], + vec![EndPoint::new(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + server_port, + ))], ) .build(); t.run_server(client_config); diff --git a/tests/no_filter.rs b/tests/no_filter.rs index 96c4a1c7ef..b120cabcf9 100644 --- a/tests/no_filter.rs +++ b/tests/no_filter.rs @@ -38,21 +38,7 @@ mod tests { let server_port = 12345; let server_config = ConfigBuilder::empty() .with_port(server_port) - .with_static( - vec![], - vec![ - EndPoint { - name: "server1".to_string(), - address: server1, - connection_ids: vec![], - }, - EndPoint { - name: "server2".to_string(), - address: server2, - connection_ids: vec![], - }, - ], - ) + .with_static(vec![], vec![EndPoint::new(server1), EndPoint::new(server2)]) .build(); assert_eq!(Ok(()), server_config.validate()); @@ -64,11 +50,10 @@ mod tests { .with_port(client_port) .with_static( vec![], - vec![EndPoint::new( - "test".into(), - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), server_port), - vec![], - )], + vec![EndPoint::new(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + server_port, + ))], ) .build(); assert_eq!(Ok(()), client_config.validate()); diff --git a/tests/token_router.rs b/tests/token_router.rs index 7884dea561..88dbe8016e 100644 --- a/tests/token_router.rs +++ b/tests/token_router.rs @@ -18,11 +18,10 @@ mod tests { use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - use slog::debug; use tokio::select; use tokio::time::{delay_for, Duration}; - use quilkin::config::{Builder, ConnectionId, EndPoint, Filter}; + use quilkin::config::{Builder, EndPoint, Filter}; use quilkin::extensions::filters::{CaptureBytesFactory, TokenRouterFactory}; use quilkin::extensions::FilterFactory; use quilkin::test_utils::{logger, TestHelper}; @@ -38,6 +37,11 @@ mod tests { let capture_yaml = " size: 3 remove: true +"; + let endpoint_metadata = " +quilkin.dev: + endpoint.tokens: + - YWJj # abc "; let server_port = 12348; let server_config = Builder::empty() @@ -53,11 +57,10 @@ remove: true config: None, }, ], - vec![EndPoint { - name: "server".to_string(), - address: echo, - connection_ids: vec![ConnectionId::from("abc")], - }], + vec![EndPoint::with_metadata( + echo, + Some(serde_yaml::from_str(endpoint_metadata).unwrap()), + )], ) .build(); server_config.validate().unwrap(); @@ -68,7 +71,6 @@ remove: true let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), server_port); let msg = b"helloabc"; - debug!(log, "sending message"; "content" => format!("{:?}", msg)); send.send_to(msg, &local_addr).await.unwrap(); select! { @@ -82,7 +84,6 @@ remove: true // send an invalid packet let msg = b"helloxyz"; - debug!(log, "sending message"; "content" => format!("{:?}", msg)); send.send_to(msg, &local_addr).await.unwrap(); select! {