From 2f6b6bae18d956394a783618cc8d59e2f000944a Mon Sep 17 00:00:00 2001 From: Mark Mandel Date: Thu, 3 Feb 2022 09:20:47 -0800 Subject: [PATCH 1/2] Benchmark comparing read and write throughput (#479) Wanted to be able to highlight if we had bottlenecks in performance on read vs write operations on the proxy. This adds an extra benchmark to throughput.rs called "readwrite" and follows a similar pattern as the overall throughput benchmark, with both direct and proxies traffic utilised as extra comparison values. Work on #410 --- benches/throughput.rs | 172 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 162 insertions(+), 10 deletions(-) diff --git a/benches/throughput.rs b/benches/throughput.rs index 98e56e38e7..466042763e 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -1,8 +1,13 @@ -use std::net::UdpSocket; +use std::net::{Ipv4Addr, SocketAddr, UdpSocket}; +use std::sync::{atomic, mpsc, Arc}; +use std::thread::sleep; +use std::time; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use once_cell::sync::Lazy; +use quilkin::config::Admin; + const MESSAGE_SIZE: usize = 0xffff; const DEFAULT_MESSAGE: [u8; 0xffff] = [0xff; 0xffff]; const BENCH_LOOP_ADDR: &str = "127.0.0.1:8002"; @@ -19,16 +24,19 @@ const PACKETS: &[&[u8]] = &[ &[0xffu8; 1500], ]; -static SERVER_INIT: Lazy<()> = Lazy::new(|| { - std::thread::spawn(|| { +/// Run and instance of quilkin that sends and received data +/// from the given address. +fn run_quilkin(port: u16, endpoint: SocketAddr) { + std::thread::spawn(move || { let runtime = tokio::runtime::Runtime::new().unwrap(); let config = quilkin::config::Builder::empty() - .with_port(8000) + .with_port(port) + .with_admin(Admin { + address: "[::]:0".parse().unwrap(), + }) .with_static( vec![], - vec![quilkin::endpoint::Endpoint::new( - FEEDBACK_LOOP_ADDR.parse().unwrap(), - )], + vec![quilkin::endpoint::Endpoint::new(endpoint.into())], ) .build(); let server = quilkin::Builder::from(std::sync::Arc::new(config)) @@ -41,6 +49,10 @@ static SERVER_INIT: Lazy<()> = Lazy::new(|| { server.run(shutdown_rx).await.unwrap(); }); }); +} + +static THROUGHPUT_SERVER_INIT: Lazy<()> = Lazy::new(|| { + run_quilkin(8000, FEEDBACK_LOOP_ADDR.parse().unwrap()); }); static FEEDBACK_LOOP: Lazy<()> = Lazy::new(|| { @@ -61,9 +73,9 @@ static FEEDBACK_LOOP: Lazy<()> = Lazy::new(|| { }); }); -fn criterion_benchmark(c: &mut Criterion) { +fn throughput_benchmark(c: &mut Criterion) { Lazy::force(&FEEDBACK_LOOP); - Lazy::force(&SERVER_INIT); + Lazy::force(&THROUGHPUT_SERVER_INIT); // Sleep to give the servers some time to warm-up. std::thread::sleep(std::time::Duration::from_millis(500)); let socket = UdpSocket::bind(BENCH_LOOP_ADDR).unwrap(); @@ -98,5 +110,145 @@ fn criterion_benchmark(c: &mut Criterion) { group.finish(); } -criterion_group!(benches, criterion_benchmark); +const WRITE_LOOP_ADDR: &str = "127.0.0.1:8003"; +const READ_LOOP_ADDR: &str = "127.0.0.1:8004"; + +const READ_QUILKIN_PORT: u16 = 9001; +static READ_SERVER_INIT: Lazy<()> = Lazy::new(|| { + run_quilkin(READ_QUILKIN_PORT, READ_LOOP_ADDR.parse().unwrap()); +}); + +const WRITE_QUILKIN_PORT: u16 = 9002; +static WRITE_SERVER_INIT: Lazy<()> = Lazy::new(|| { + run_quilkin(WRITE_QUILKIN_PORT, WRITE_LOOP_ADDR.parse().unwrap()); +}); + +/// Binds a socket to `addr`, and waits for an initial packet to be sent to it to establish +/// a connection. After which any `Vec` sent to the returned channel will result in that +/// data being send via that connection - thereby skipping the proxy `read` operation. +fn write_feedback(addr: SocketAddr) -> mpsc::Sender> { + let (write_tx, write_rx) = mpsc::channel::>(); + std::thread::spawn(move || { + let socket = UdpSocket::bind(addr).unwrap(); + let mut packet = [0; MESSAGE_SIZE]; + let (_, source) = socket.recv_from(&mut packet).unwrap(); + while let Ok(packet) = write_rx.recv() { + socket.send_to(packet.as_slice(), source).unwrap(); + } + }); + write_tx +} + +fn readwrite_benchmark(c: &mut Criterion) { + Lazy::force(&READ_SERVER_INIT); + + // start a feedback server for read operations, that sends a response through a channel, + // thereby skipping a proxy connection on the return. + let (read_tx, read_rx) = mpsc::channel::>(); + std::thread::spawn(move || { + let socket = UdpSocket::bind(READ_LOOP_ADDR).unwrap(); + let mut packet = [0; MESSAGE_SIZE]; + loop { + let (length, _) = socket.recv_from(&mut packet).unwrap(); + let packet = &packet[..length]; + assert_eq!(packet, &DEFAULT_MESSAGE[..length]); + + if read_tx.send(packet.to_vec()).is_err() { + return; + } + } + }); + + // start a feedback server for a direct write benchmark. + let direct_write_addr = (Ipv4Addr::LOCALHOST, 9004).into(); + let direct_write_tx = write_feedback(direct_write_addr); + + // start a feedback server for a quilkin write benchmark. + let quilkin_write_addr = (Ipv4Addr::LOCALHOST, WRITE_QUILKIN_PORT); + let quilkin_write_tx = write_feedback(WRITE_LOOP_ADDR.parse().unwrap()); + Lazy::force(&WRITE_SERVER_INIT); + + // Sleep to give the servers some time to warm-up. + std::thread::sleep(std::time::Duration::from_millis(500)); + + let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).unwrap(); + + // prime the direct write connection + socket.send_to(PACKETS[0], direct_write_addr).unwrap(); + + // we need to send packets at least once a minute, otherwise the endpoint session expires. + // So setting up a ping packet for the write test. + // TODO(markmandel): If we ever make session timeout configurable, we can remove this. + let ping_socket = socket.try_clone().unwrap(); + let stop = Arc::new(atomic::AtomicBool::default()); + let ping_stop = stop.clone(); + std::thread::spawn(move || { + while !ping_stop.load(atomic::Ordering::Relaxed) { + ping_socket.send_to(PACKETS[0], quilkin_write_addr).unwrap(); + sleep(time::Duration::from_secs(30)); + } + }); + + let mut group = c.benchmark_group("readwrite"); + + for message in PACKETS { + group.sample_size(NUMBER_OF_PACKETS); + group.sampling_mode(criterion::SamplingMode::Flat); + group.throughput(criterion::Throughput::Bytes(message.len() as u64)); + + // direct read + group.bench_with_input( + BenchmarkId::new("direct-read", format!("{} bytes", message.len())), + &message, + |b, message| { + b.iter(|| { + socket.send_to(message, READ_LOOP_ADDR).unwrap(); + read_rx.recv().unwrap(); + }) + }, + ); + // quilkin read + let addr = (Ipv4Addr::LOCALHOST, READ_QUILKIN_PORT); + group.bench_with_input( + BenchmarkId::new("quilkin-read", format!("{} bytes", message.len())), + &message, + |b, message| { + b.iter(|| { + socket.send_to(message, addr).unwrap(); + read_rx.recv().unwrap(); + }) + }, + ); + + // direct write + let mut packet = [0; MESSAGE_SIZE]; + group.bench_with_input( + BenchmarkId::new("direct-write", format!("{} bytes", message.len())), + &message, + |b, message| { + b.iter(|| { + direct_write_tx.send(message.to_vec()).unwrap(); + socket.recv(&mut packet).unwrap(); + }) + }, + ); + + // quilkin write + let mut packet = [0; MESSAGE_SIZE]; + group.bench_with_input( + BenchmarkId::new("quilkin-write", format!("{} bytes", message.len())), + &message, + |b, message| { + b.iter(|| { + quilkin_write_tx.send(message.to_vec()).unwrap(); + socket.recv(&mut packet).unwrap(); + }) + }, + ); + } + + stop.store(true, atomic::Ordering::Relaxed); +} + +criterion_group!(benches, readwrite_benchmark, throughput_benchmark); criterion_main!(benches); From 2fdc8ac779056baddc28c6687c423838b8263d9b Mon Sep 17 00:00:00 2001 From: XAMPPRocky <4464295+XAMPPRocky@users.noreply.github.com> Date: Fri, 4 Feb 2022 07:55:14 +0100 Subject: [PATCH 2/2] Add matches filter (#447) --- Cargo.toml | 2 +- build.rs | 5 +- docs/src/filters/matches.md | 43 ++ docs/src/filters/writing_custom_filters.md | 4 +- .../filters/matches/v1alpha1/matches.proto | 49 ++ src/config/config_type.rs | 33 +- src/config/error.rs | 2 +- src/filters.rs | 1 + src/filters/capture_bytes.rs | 12 +- src/filters/chain.rs | 8 +- src/filters/compress.rs | 8 +- src/filters/debug.rs | 10 +- src/filters/error.rs | 4 +- src/filters/factory.rs | 48 +- src/filters/load_balancer.rs | 5 +- src/filters/local_rate_limit.rs | 3 +- src/filters/matches.rs | 452 ++++++++++++++++++ src/filters/registry.rs | 6 +- src/filters/set.rs | 11 +- src/filters/token_router.rs | 15 +- src/metadata.rs | 84 +++- src/proxy/config_dump.rs | 6 +- src/proxy/health.rs | 2 +- src/xds/listener.rs | 7 +- tests/matches.rs | 128 +++++ 25 files changed, 885 insertions(+), 63 deletions(-) create mode 100644 docs/src/filters/matches.md create mode 100644 proto/quilkin/extensions/filters/matches/v1alpha1/matches.proto create mode 100644 src/filters/matches.rs create mode 100644 tests/matches.rs diff --git a/Cargo.toml b/Cargo.toml index fa1a51f0bd..6bbbbd8c92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ quilkin-macros = { version = "0.3.0-dev", path = "./macros" } # Crates.io base64 = "0.13.0" base64-serde = "0.6.1" -bytes = "1.1.0" +bytes = { version = "1.1.0", features = ["serde"] } clap = { version = "3", features = ["cargo"] } dashmap = "4.0.2" either = "1.6.1" diff --git a/build.rs b/build.rs index 72a86e4cfe..63b01e49b7 100644 --- a/build.rs +++ b/build.rs @@ -28,14 +28,15 @@ fn main() -> Result<(), Box> { "proto/data-plane-api/envoy/type/metadata/v3/metadata.proto", "proto/data-plane-api/envoy/type/tracing/v3/custom_tag.proto", "proto/udpa/xds/core/v3/resource_name.proto", - "proto/quilkin/extensions/filters/debug/v1alpha1/debug.proto", "proto/quilkin/extensions/filters/capture_bytes/v1alpha1/capture_bytes.proto", "proto/quilkin/extensions/filters/compress/v1alpha1/compress.proto", "proto/quilkin/extensions/filters/concatenate_bytes/v1alpha1/concatenate_bytes.proto", + "proto/quilkin/extensions/filters/debug/v1alpha1/debug.proto", + "proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto", "proto/quilkin/extensions/filters/load_balancer/v1alpha1/load_balancer.proto", "proto/quilkin/extensions/filters/local_rate_limit/v1alpha1/local_rate_limit.proto", + "proto/quilkin/extensions/filters/matches/v1alpha1/matches.proto", "proto/quilkin/extensions/filters/token_router/v1alpha1/token_router.proto", - "proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto", ] .iter() .map(|name| std::env::current_dir().unwrap().join(name)) diff --git a/docs/src/filters/matches.md b/docs/src/filters/matches.md new file mode 100644 index 0000000000..0456b953b2 --- /dev/null +++ b/docs/src/filters/matches.md @@ -0,0 +1,43 @@ +# Matches + +The `Matches` filter's job is to provide a mechanism to change behaviour based +on dynamic metadata. This filter behaves similarly to the `match` expression +in Rust or `switch` statements in other languages. + +#### Filter name +```text +quilkin.extensions.filters.matches.v1alpha1.Matches +``` + +### Configuration Examples +```rust +# let yaml = " +version: v1alpha1 +static: + endpoints: + - address: 127.0.0.1:26000 + - address: 127.0.0.1:26001 + filters: + - name: quilkin.extensions.filters.capture_bytes.v1alpha1.CaptureBytes + config: + strategy: PREFIX + metadataKey: myapp.com/token + size: 3 + remove: false + - name: quilkin.extensions.filters.matches.v1alpha1.Matches + config: + on_read: + metadataKey: myapp.com/token + branches: + - value: abc + filter: quilkin.extensions.filters.concatenate_bytes.v1alpha1.ConcatenateBytes + config: + on_read: APPEND + bytes: eHl6 # "xyz" +# "; +# let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); +# assert_eq!(config.source.get_static_filters().unwrap().len(), 1); +# quilkin::Builder::from(std::sync::Arc::new(config)).validate().unwrap(); +``` + +View the [Matches](../../api/quilkin/filters/matches/struct.Config.html) filter documentation for more details. diff --git a/docs/src/filters/writing_custom_filters.md b/docs/src/filters/writing_custom_filters.md index 8a7cfe24be..568d3a95a4 100644 --- a/docs/src/filters/writing_custom_filters.md +++ b/docs/src/filters/writing_custom_filters.md @@ -238,7 +238,7 @@ impl FilterFactory for GreetFilterFactory { fn create_filter(&self, args: CreateFilterArgs) -> Result { let config = match args.config.unwrap() { ConfigType::Static(config) => { - serde_yaml::from_str::(serde_yaml::to_string(config).unwrap().as_str()) + serde_yaml::from_str::(serde_yaml::to_string(&config).unwrap().as_str()) .unwrap() } ConfigType::Dynamic(_) => unimplemented!("dynamic config is not yet supported for this filter"), @@ -263,7 +263,7 @@ has a [Dynamic][ConfigType::dynamic] variant. ```rust,ignore let config = match args.config.unwrap() { ConfigType::Static(config) => { - serde_yaml::from_str::(serde_yaml::to_string(config).unwrap().as_str()) + serde_yaml::from_str::(serde_yaml::to_string(&config).unwrap().as_str()) .unwrap() } ConfigType::Dynamic(_) => unimplemented!("dynamic config is not yet supported for this filter"), diff --git a/proto/quilkin/extensions/filters/matches/v1alpha1/matches.proto b/proto/quilkin/extensions/filters/matches/v1alpha1/matches.proto new file mode 100644 index 0000000000..2710dc24f5 --- /dev/null +++ b/proto/quilkin/extensions/filters/matches/v1alpha1/matches.proto @@ -0,0 +1,49 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package quilkin.extensions.filters.matches.v1alpha1; + +import "google/protobuf/wrappers.proto"; +import "google/protobuf/struct.proto"; +import "google/protobuf/any.proto"; + +message Matches { + message Branch { + google.protobuf.Value value = 1; + google.protobuf.StringValue filter = 2; + optional google.protobuf.Any config = 3; + } + + message FallthroughFilter { + google.protobuf.StringValue filter = 2; + optional google.protobuf.Any config = 3; + } + + message DirectionalConfig { + google.protobuf.StringValue metadata_key = 1; + repeated Branch branches = 2; + oneof fallthrough { + google.protobuf.NullValue pass = 3; + google.protobuf.NullValue drop = 4; + FallthroughFilter filter = 5; + } + } + + optional DirectionalConfig on_read = 1; + optional DirectionalConfig on_write = 2; +} diff --git a/src/config/config_type.rs b/src/config/config_type.rs index 09788a9559..db42137884 100644 --- a/src/config/config_type.rs +++ b/src/config/config_type.rs @@ -22,15 +22,15 @@ use crate::filters::{ConvertProtoConfigError, Error}; /// The configuration of a [`Filter`][crate::filters::Filter] from either a /// static or dynamic source. -#[derive(Debug)] -pub enum ConfigType<'a> { +#[derive(Clone, Debug, PartialEq)] +pub enum ConfigType { /// Static configuration from YAML. - Static(&'a serde_yaml::Value), + Static(serde_yaml::Value), /// Dynamic configuration from Protobuf. Dynamic(prost_types::Any), } -impl ConfigType<'_> { +impl ConfigType { /// Deserializes takes two type arguments `Static` and `Dynamic` representing /// the types of a static and dynamic configuration respectively. /// @@ -57,7 +57,7 @@ impl ConfigType<'_> { + TryFrom, { match self { - ConfigType::Static(config) => serde_yaml::to_string(config) + ConfigType::Static(ref config) => serde_yaml::to_string(config) .and_then(|raw_config| serde_yaml::from_str::(raw_config.as_str())) .map_err(|err| { Error::DeserializeFailed(format!( @@ -95,6 +95,29 @@ impl ConfigType<'_> { } } +impl<'de> serde::Deserialize<'de> for ConfigType { + fn deserialize(de: D) -> Result + where + D: serde::Deserializer<'de>, + { + serde_yaml::Value::deserialize(de).map(ConfigType::Static) + } +} + +impl<'de> serde::Serialize for ConfigType { + fn serialize(&self, ser: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::Static(value) => value.serialize(ser), + Self::Dynamic(_) => Err(serde::ser::Error::custom( + "Protobuf configs can't be serialized.", + )), + } + } +} + #[cfg(test)] mod tests { use super::ConfigType; diff --git a/src/config/error.rs b/src/config/error.rs index 22e011b1d8..1e717ba5d0 100644 --- a/src/config/error.rs +++ b/src/config/error.rs @@ -24,7 +24,7 @@ pub struct ValueInvalidArgs { } /// Validation failure for a Config -#[derive(Debug, PartialEq)] +#[derive(Debug)] pub enum ValidationError { NotUnique(String), EmptyList(String), diff --git a/src/filters.rs b/src/filters.rs index 01d6f5415e..74f2aa7ca6 100644 --- a/src/filters.rs +++ b/src/filters.rs @@ -33,6 +33,7 @@ pub mod debug; pub mod firewall; pub mod load_balancer; pub mod local_rate_limit; +pub mod matches; pub mod metadata; pub mod token_router; diff --git a/src/filters/capture_bytes.rs b/src/filters/capture_bytes.rs index b96ae89fc2..60c5be458d 100644 --- a/src/filters/capture_bytes.rs +++ b/src/filters/capture_bytes.rs @@ -121,7 +121,8 @@ mod tests { use super::capture::{Capture, Prefix, Suffix}; use crate::filters::{ - metadata::CAPTURED_BYTES, CreateFilterArgs, Filter, FilterFactory, ReadContext, + metadata::CAPTURED_BYTES, CreateFilterArgs, Filter, FilterFactory, FilterRegistry, + ReadContext, }; const TOKEN_KEY: &str = "TOKEN"; @@ -147,8 +148,9 @@ mod tests { let filter = factory .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), Registry::default(), - Some(&Value::Mapping(map)), + Some(Value::Mapping(map)), )) .unwrap() .filter; @@ -162,8 +164,9 @@ mod tests { map.insert(Value::String("size".into()), Value::Number(3.into())); let filter = factory .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), Registry::default(), - Some(&Value::Mapping(map)), + Some(Value::Mapping(map)), )) .unwrap() .filter; @@ -177,8 +180,9 @@ mod tests { map.insert(Value::String("size".into()), Value::String("WRONG".into())); let result = factory.create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), Registry::default(), - Some(&Value::Mapping(map)), + Some(Value::Mapping(map)), )); assert!(result.is_err(), "Should be an error"); } diff --git a/src/filters/chain.rs b/src/filters/chain.rs index b0babdbb27..c8644e5a77 100644 --- a/src/filters/chain.rs +++ b/src/filters/chain.rs @@ -119,8 +119,12 @@ impl FilterChain { for filter_config in filter_configs { match filter_registry.get( &filter_config.name, - CreateFilterArgs::fixed(metrics_registry.clone(), filter_config.config.as_ref()) - .with_metrics_registry(metrics_registry.clone()), + CreateFilterArgs::fixed( + filter_registry.clone(), + metrics_registry.clone(), + filter_config.config, + ) + .with_metrics_registry(metrics_registry.clone()), ) { Ok(filter) => filters.push((filter_config.name, filter)), Err(err) => { diff --git a/src/filters/compress.rs b/src/filters/compress.rs index 27a172659c..a691f67354 100644 --- a/src/filters/compress.rs +++ b/src/filters/compress.rs @@ -181,7 +181,7 @@ mod tests { use crate::endpoint::{Endpoint, Endpoints, UpstreamEndpoints}; use crate::filters::{ compress::{compressor::Snappy, Compressor}, - CreateFilterArgs, Filter, FilterFactory, ReadContext, WriteContext, + CreateFilterArgs, Filter, FilterFactory, FilterRegistry, ReadContext, WriteContext, }; use super::quilkin::extensions::filters::compress::v1alpha1::{ @@ -293,8 +293,9 @@ mod tests { ); let filter = factory .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), Registry::default(), - Some(&Value::Mapping(map)), + Some(Value::Mapping(map)), )) .expect("should create a filter") .filter; @@ -315,7 +316,8 @@ mod tests { Value::String("COMPRESS".into()), ); let config = Value::Mapping(map); - let args = CreateFilterArgs::fixed(Registry::default(), Some(&config)); + let args = + CreateFilterArgs::fixed(FilterRegistry::default(), Registry::default(), Some(config)); let filter = factory .create_filter(args) diff --git a/src/filters/debug.rs b/src/filters/debug.rs index 2ad86a40c9..08d453034d 100644 --- a/src/filters/debug.rs +++ b/src/filters/debug.rs @@ -115,6 +115,7 @@ impl TryFrom for Config { #[cfg(test)] mod tests { + use crate::filters::FilterRegistry; use crate::test_utils::{assert_filter_read_no_change, assert_write_no_change}; use serde_yaml::Mapping; use serde_yaml::Value; @@ -148,8 +149,9 @@ mod tests { map.insert(Value::from("id"), Value::from("name")); assert!(factory .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), Registry::default(), - Some(&Value::Mapping(map)), + Some(Value::Mapping(map)), )) .is_ok()); } @@ -162,8 +164,9 @@ mod tests { map.insert(Value::from("id"), Value::from("name")); assert!(factory .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), Registry::default(), - Some(&Value::Mapping(map)), + Some(Value::Mapping(map)), )) .is_ok()); } @@ -176,8 +179,9 @@ mod tests { map.insert(Value::from("id"), Value::Sequence(vec![])); assert!(factory .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), Registry::default(), - Some(&Value::Mapping(map)) + Some(Value::Mapping(map)) )) .is_err()); } diff --git a/src/filters/error.rs b/src/filters/error.rs index f8d47648da..dc55c6e384 100644 --- a/src/filters/error.rs +++ b/src/filters/error.rs @@ -66,9 +66,9 @@ pub struct ConvertProtoConfigError { } impl ConvertProtoConfigError { - pub fn new(reason: impl Into, field: Option) -> Self { + pub fn new(reason: impl std::fmt::Display, field: Option) -> Self { Self { - reason: reason.into(), + reason: reason.to_string(), field, } } diff --git a/src/filters/factory.rs b/src/filters/factory.rs index d36ecd6aa2..cb929d2a0c 100644 --- a/src/filters/factory.rs +++ b/src/filters/factory.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use crate::{ config::ConfigType, - filters::{Error, Filter}, + filters::{Error, Filter, FilterRegistry}, }; /// An owned pointer to a dynamic [`FilterFactory`] instance. @@ -63,45 +63,61 @@ pub trait FilterFactory: Sync + Send { /// Returns the [`ConfigType`] from the provided Option, otherwise it returns /// Error::MissingConfig if the Option is None. - fn require_config<'a, 'b>( - &'a self, - config: Option>, - ) -> Result, Error> { + fn require_config(&self, config: Option) -> Result { config.ok_or_else(|| Error::MissingConfig(self.name())) } } /// Arguments needed to create a new filter. -pub struct CreateFilterArgs<'a> { +pub struct CreateFilterArgs { /// Configuration for the filter. - pub config: Option>, + pub config: Option, + /// Used if the filter needs to reference or use other filters. + pub filter_registry: FilterRegistry, /// metrics_registry is used to register filter metrics collectors. pub metrics_registry: Registry, } -impl CreateFilterArgs<'_> { +impl CreateFilterArgs { + /// Create a new instance of [`CreateFilterArgs`]. + pub fn new( + filter_registry: FilterRegistry, + metrics_registry: Registry, + config: Option, + ) -> CreateFilterArgs { + Self { + config, + filter_registry, + metrics_registry, + } + } + /// Creates a new instance of [`CreateFilterArgs`] using a /// fixed [`ConfigType`]. pub fn fixed( + filter_registry: FilterRegistry, metrics_registry: Registry, - config: Option<&serde_yaml::Value>, + config: Option, ) -> CreateFilterArgs { - CreateFilterArgs { - config: config.map(ConfigType::Static), + Self::new( + filter_registry, metrics_registry, - } + config.map(ConfigType::Static), + ) } /// Creates a new instance of [`CreateFilterArgs`] using a /// dynamic [`ConfigType`]. pub fn dynamic( + filter_registry: FilterRegistry, metrics_registry: Registry, config: Option, - ) -> CreateFilterArgs<'static> { - CreateFilterArgs { - config: config.map(ConfigType::Dynamic), + ) -> CreateFilterArgs { + CreateFilterArgs::new( + filter_registry, metrics_registry, - } + config.map(ConfigType::Dynamic), + ) } /// Consumes `self` and returns a new instance of [`Self`] using diff --git a/src/filters/load_balancer.rs b/src/filters/load_balancer.rs index dc981c9911..3d36ae95df 100644 --- a/src/filters/load_balancer.rs +++ b/src/filters/load_balancer.rs @@ -72,7 +72,7 @@ mod tests { endpoint::{Endpoint, EndpointAddress, Endpoints}, filters::{ load_balancer::LoadBalancerFilterFactory, CreateFilterArgs, Filter, FilterFactory, - ReadContext, + FilterRegistry, ReadContext, }, }; use prometheus::Registry; @@ -81,8 +81,9 @@ mod tests { let factory = LoadBalancerFilterFactory; factory .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), Registry::default(), - Some(&serde_yaml::from_str(config).unwrap()), + Some(serde_yaml::from_str(config).unwrap()), )) .unwrap() .filter diff --git a/src/filters/local_rate_limit.rs b/src/filters/local_rate_limit.rs index cbcc3c6a0f..1b59a5bdd1 100644 --- a/src/filters/local_rate_limit.rs +++ b/src/filters/local_rate_limit.rs @@ -276,8 +276,9 @@ period: 0 "; let err = factory .create_filter(CreateFilterArgs { - config: Some(ConfigType::Static(&serde_yaml::from_str(config).unwrap())), + config: Some(ConfigType::Static(serde_yaml::from_str(config).unwrap())), metrics_registry: Default::default(), + filter_registry: crate::filters::FilterRegistry::default(), }) .err() .unwrap(); diff --git a/src/filters/matches.rs b/src/filters/matches.rs new file mode 100644 index 0000000000..f579060f02 --- /dev/null +++ b/src/filters/matches.rs @@ -0,0 +1,452 @@ +crate::include_proto!("quilkin.extensions.filters.matches.v1alpha1"); + +use self::quilkin::extensions::filters::matches::v1alpha1 as proto; +use crate::{ + config::ConfigType, + filters::{prelude::*, registry::FilterRegistry}, + metadata::Value, +}; + +pub const NAME: &str = "quilkin.extensions.filters.matches.v1alpha1.Matches"; + +/// Creates a new factory for generating match filters. +pub fn factory() -> DynFilterFactory { + Box::from(MatchesFactory::new()) +} + +struct FilterConfig { + metadata_key: String, + branches: Vec<(Value, FilterInstance)>, + fallthrough: FallthroughInstance, +} + +impl FilterConfig { + fn new( + config: DirectionalConfig, + filter_registry: FilterRegistry, + metrics_registry: prometheus::Registry, + ) -> Result { + let map_to_instance = |filter: &String, config_type: Option| { + let args = CreateFilterArgs::new( + filter_registry.clone(), + metrics_registry.clone(), + config_type, + ) + .with_metrics_registry(metrics_registry.clone()); + + filter_registry.get(filter, args) + }; + + let branches = config + .branches + .iter() + .map(|branch| { + map_to_instance(&branch.filter, branch.config.clone()) + .map(|instance| (branch.value.clone(), instance)) + }) + .collect::>()?; + + Ok(Self { + metadata_key: config.metadata_key, + branches, + fallthrough: match config.fallthrough { + Fallthrough::Pass => FallthroughInstance::Pass, + Fallthrough::Drop => FallthroughInstance::Drop, + Fallthrough::Filter { filter, config } => { + map_to_instance(&filter, config).map(FallthroughInstance::Filter)? + } + }, + }) + } +} + +pub enum FallthroughInstance { + Pass, + Drop, + Filter(FilterInstance), +} + +struct Matches { + on_read_filters: Option, + on_write_filters: Option, +} + +impl Matches { + fn new( + config: Config, + filter_registry: FilterRegistry, + metrics_registry: prometheus::Registry, + ) -> Result { + let on_read_filters = config + .on_read + .map(|config| { + FilterConfig::new(config, filter_registry.clone(), metrics_registry.clone()) + }) + .transpose()?; + + let on_write_filters = config + .on_write + .map(|config| { + FilterConfig::new(config, filter_registry.clone(), metrics_registry.clone()) + }) + .transpose()?; + + if on_read_filters.is_none() && on_write_filters.is_none() { + return Err(Error::MissingConfig(NAME)); + } + + Ok(Self { + on_read_filters, + on_write_filters, + }) + } +} + +fn match_filter<'config, Ctx, R>( + config: &'config Option, + ctx: Ctx, + get_metadata: impl for<'ctx> Fn(&'ctx Ctx, &'config String) -> Option<&'ctx Value>, + and_then: impl Fn(Ctx, &'config FilterInstance) -> Option, +) -> Option +where + Ctx: Into, +{ + match config { + Some(config) => { + let value = (get_metadata)(&ctx, &config.metadata_key)?; + + match config.branches.iter().find(|(key, _)| key == value) { + Some((_, instance)) => (and_then)(ctx, instance), + None => match &config.fallthrough { + FallthroughInstance::Pass => Some(ctx.into()), + FallthroughInstance::Drop => None, + FallthroughInstance::Filter(instance) => (and_then)(ctx, instance), + }, + } + } + None => Some(ctx.into()), + } +} + +impl Filter for Matches { + #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] + fn read(&self, ctx: ReadContext) -> Option { + match_filter( + &self.on_read_filters, + ctx, + |ctx, metadata_key| ctx.metadata.get(metadata_key), + |ctx, instance| instance.filter.read(ctx), + ) + } + + #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] + fn write(&self, ctx: WriteContext) -> Option { + match_filter( + &self.on_write_filters, + ctx, + |ctx, metadata_key| ctx.metadata.get(metadata_key), + |ctx, instance| instance.filter.write(ctx), + ) + } +} + +struct MatchesFactory; + +impl MatchesFactory { + pub fn new() -> Self { + MatchesFactory + } +} + +impl FilterFactory for MatchesFactory { + fn name(&self) -> &'static str { + NAME + } + + fn create_filter(&self, args: CreateFilterArgs) -> Result { + let (config_json, config) = self + .require_config(args.config)? + .deserialize::(self.name())?; + + let filter = Matches::new(config, args.filter_registry, args.metrics_registry)?; + Ok(FilterInstance::new( + config_json, + Box::new(filter) as Box, + )) + } +} + +/// Configuration for the [`factory`]. +#[derive(Debug, serde::Deserialize, serde::Serialize, PartialEq)] +#[serde(deny_unknown_fields)] +pub struct Config { + /// Configuration for [`Filter::read`]. + pub on_read: Option, + /// Configuration for [`Filter::write`]. + pub on_write: Option, +} + +impl TryFrom for Config { + type Error = ConvertProtoConfigError; + + fn try_from(value: proto::Matches) -> Result { + Ok(Self { + on_read: value + .on_read + .map(proto::matches::DirectionalConfig::try_into) + .transpose() + .map_err(|error: eyre::Report| { + ConvertProtoConfigError::new(error, Some("on_read".into())) + })?, + on_write: value + .on_write + .map(proto::matches::DirectionalConfig::try_into) + .transpose() + .map_err(|error: eyre::Report| { + ConvertProtoConfigError::new(error, Some("on_write".into())) + })?, + }) + } +} + +impl TryFrom for DirectionalConfig { + type Error = eyre::Report; + + fn try_from(value: proto::matches::DirectionalConfig) -> Result { + Ok(Self { + metadata_key: value.metadata_key.ok_or_else(|| { + ConvertProtoConfigError::new("Missing", Some("metadata_key".into())) + })?, + branches: value + .branches + .into_iter() + .map(proto::matches::Branch::try_into) + .collect::>()?, + fallthrough: value + .fallthrough + .ok_or_else(|| ConvertProtoConfigError::new("Missing", Some("fallthrough".into())))? + .try_into()?, + }) + } +} + +/// Configuration for a specific direction. +#[derive(Debug, serde::Deserialize, serde::Serialize, PartialEq)] +pub struct DirectionalConfig { + /// The key for the metadata to compare against. + #[serde(rename = "metadataKey")] + pub metadata_key: String, + /// List of filters to compare and potentially run if any match. + pub branches: Vec, + /// The behaviour for when none of the `branches` match. + #[serde(default)] + pub fallthrough: Fallthrough, +} + +/// A specific match branch. The filter is run when `value` matches the value +/// defined in `metadata_key`. +#[derive(Debug, serde::Deserialize, serde::Serialize, PartialEq)] +pub struct Branch { + /// The value to compare against the dynamic metadata. + pub value: crate::metadata::Value, + /// The identifier of the filter to run on successful matches. + pub filter: String, + /// The configuration for the filter, if any. + pub config: Option, +} + +impl TryFrom for Branch { + type Error = eyre::Report; + + fn try_from(branch: proto::matches::Branch) -> Result { + Ok(Self { + value: branch + .value + .ok_or_else(|| ConvertProtoConfigError::new("Missing", Some("value".into())))? + .try_into()?, + filter: branch + .filter + .ok_or_else(|| ConvertProtoConfigError::new("Missing", Some("filter".into())))?, + config: branch.config.map(ConfigType::Dynamic), + }) + } +} + +/// The behaviour when the none of branches match. +#[derive(Debug, PartialEq)] +pub enum Fallthrough { + /// The packet will be passed onto the next filter. + Pass, + /// The packet will be dropped. **Default behaviour** + Drop, + /// The filter specified in `filter` will be called. + Filter { + /// The identifier for the filter to run. + filter: String, + /// The configuration for the filter to run, if any. + config: Option, + }, +} + +impl Default for Fallthrough { + fn default() -> Self { + Self::Drop + } +} + +impl TryFrom for Fallthrough { + type Error = eyre::Report; + + fn try_from( + branch: proto::matches::directional_config::Fallthrough, + ) -> Result { + use proto::matches::directional_config::Fallthrough as ProtoFallthrough; + + Ok(match branch { + ProtoFallthrough::Pass(_) => Self::Pass, + ProtoFallthrough::Drop(_) => Self::Drop, + ProtoFallthrough::Filter(filter) => Self::Filter { + filter: filter.filter.ok_or_else(|| { + eyre::eyre!("missing `filter` field in Fallthrough configuration") + })?, + config: filter.config.map(ConfigType::Dynamic), + }, + }) + } +} + +impl serde::Serialize for Fallthrough { + fn serialize(&self, ser: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::Pass => ser.serialize_str("PASS"), + Self::Drop => ser.serialize_str("DROP"), + Self::Filter { filter, config } => { + use serde::ser::SerializeMap; + + let mut map = ser.serialize_map(Some(2))?; + + map.serialize_entry("filter", filter)?; + map.serialize_entry("config", config)?; + + map.end() + } + } + } +} + +impl<'de> serde::Deserialize<'de> for Fallthrough { + fn deserialize(de: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct FallthroughVisitor; + + impl<'de> serde::de::Visitor<'de> for FallthroughVisitor { + type Value = Fallthrough; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { + f.write_str("`pass`, `drop`, or an object containing a `filter` field and optionally `config` field") + } + + fn visit_borrowed_str(self, string: &'de str) -> Result + where + E: serde::de::Error, + { + self.visit_str(string) + } + + fn visit_string(self, string: String) -> Result + where + E: serde::de::Error, + { + self.visit_str(&string) + } + + fn visit_str(self, string: &str) -> Result + where + E: serde::de::Error, + { + match &*string.to_lowercase() { + "pass" => Ok(Fallthrough::Pass), + "drop" => Ok(Fallthrough::Drop), + _ => Err(serde::de::Error::custom("invalid fallthrough type.")), + } + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + const CONFIG_FIELD: &str = "config"; + const FILTER_FIELD: &str = "filter"; + let mut config = None; + let mut filter = None; + loop { + match map.next_key::()?.as_deref() { + Some(CONFIG_FIELD) => { + if config.replace(map.next_value()?).is_some() { + return Err(serde::de::Error::duplicate_field(CONFIG_FIELD)); + } + } + Some(FILTER_FIELD) => { + if filter.replace(map.next_value()?).is_some() { + return Err(serde::de::Error::duplicate_field(FILTER_FIELD)); + } + } + Some(field) => { + return Err(serde::de::Error::unknown_field( + field, + &[CONFIG_FIELD, FILTER_FIELD], + )) + } + None => break, + } + } + + Ok(Fallthrough::Filter { + filter: filter.ok_or_else(|| serde::de::Error::missing_field("filter"))?, + config, + }) + } + } + + de.deserialize_any(FallthroughVisitor) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serde() { + let matches_yaml = " +on_read: + metadataKey: quilkin.dev/captured_bytes + branches: + - value: abc + filter: quilkin.extensions.filters.debug.v1alpha1.Debug + "; + + let config = serde_yaml::from_str::(matches_yaml).unwrap(); + + assert_eq!( + config, + Config { + on_read: Some(DirectionalConfig { + metadata_key: "quilkin.dev/captured_bytes".into(), + branches: vec![Branch { + value: String::from("abc").into(), + filter: "quilkin.extensions.filters.debug.v1alpha1.Debug".into(), + config: None, + }], + fallthrough: Fallthrough::Drop, + }), + on_write: None, + } + ) + } +} diff --git a/src/filters/registry.rs b/src/filters/registry.rs index 48f449760a..c5596a641b 100644 --- a/src/filters/registry.rs +++ b/src/filters/registry.rs @@ -82,7 +82,7 @@ mod tests { match reg.get( &String::from("not.found"), - CreateFilterArgs::fixed(Registry::default(), None), + CreateFilterArgs::fixed(reg.clone(), Registry::default(), None), ) { Ok(_) => unreachable!("should not be filter"), Err(err) => assert_eq!(Error::NotFound("not.found".to_string()), err), @@ -91,14 +91,14 @@ mod tests { assert!(reg .get( &String::from("TestFilter"), - CreateFilterArgs::fixed(Registry::default(), None) + CreateFilterArgs::fixed(reg.clone(), Registry::default(), None) ) .is_ok()); let filter = reg .get( &String::from("TestFilter"), - CreateFilterArgs::fixed(Registry::default(), None), + CreateFilterArgs::fixed(reg.clone(), Registry::default(), None), ) .unwrap() .filter; diff --git a/src/filters/set.rs b/src/filters/set.rs index 45d33699a6..7752fe7496 100644 --- a/src/filters/set.rs +++ b/src/filters/set.rs @@ -52,14 +52,15 @@ impl FilterSet { pub fn default_with(filters: impl IntoIterator) -> Self { Self::with( std::array::IntoIter::new([ - filters::debug::factory(), - filters::local_rate_limit::factory(), - filters::concatenate_bytes::factory(), - filters::load_balancer::factory(), filters::capture_bytes::factory(), - filters::token_router::factory(), filters::compress::factory(), + filters::concatenate_bytes::factory(), + filters::debug::factory(), filters::firewall::factory(), + filters::load_balancer::factory(), + filters::local_rate_limit::factory(), + filters::matches::factory(), + filters::token_router::factory(), ]) .chain(filters), ) diff --git a/src/filters/token_router.rs b/src/filters/token_router.rs index ad52972bc7..74b2e7aca1 100644 --- a/src/filters/token_router.rs +++ b/src/filters/token_router.rs @@ -190,7 +190,8 @@ mod tests { default_metadata_key, Config, Metrics, ProtoConfig, TokenRouter, TokenRouterFactory, }; use crate::filters::{ - metadata::CAPTURED_BYTES, CreateFilterArgs, Filter, FilterFactory, ReadContext, + metadata::CAPTURED_BYTES, CreateFilterArgs, Filter, FilterFactory, FilterRegistry, + ReadContext, }; const TOKEN_KEY: &str = "TOKEN"; @@ -244,8 +245,9 @@ mod tests { let filter = factory .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), Registry::default(), - Some(&serde_yaml::Value::Mapping(map)), + Some(serde_yaml::Value::Mapping(map)), )) .unwrap() .filter; @@ -264,8 +266,9 @@ mod tests { let filter = factory .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), Registry::default(), - Some(&serde_yaml::Value::Mapping(map)), + Some(serde_yaml::Value::Mapping(map)), )) .unwrap() .filter; @@ -282,7 +285,11 @@ mod tests { let factory = TokenRouterFactory::new(); let filter = factory - .create_filter(CreateFilterArgs::fixed(Registry::default(), None)) + .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), + Registry::default(), + None, + )) .unwrap() .filter; let mut ctx = new_ctx(); diff --git a/src/metadata.rs b/src/metadata.rs index ef1fee1772..c261a0a083 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -23,7 +23,8 @@ pub type DynamicMetadata = HashMap, Value>; pub const KEY: &str = "quilkin.dev"; -#[derive(Clone, Debug, PartialEq, PartialOrd)] +#[derive(Clone, Debug, PartialOrd, serde::Serialize, serde::Deserialize, Eq, Ord)] +#[serde(untagged)] pub enum Value { Bool(bool), Number(u64), @@ -61,6 +62,87 @@ impl Value { } } +/// Convenience macro for generating From implementations. +macro_rules! from_value { + (($name:ident) { $($typ:ty => $ex:expr),+ $(,)? }) => { + $( + impl From<$typ> for Value { + fn from($name: $typ) -> Self { + $ex + } + } + )+ + } +} + +from_value! { + (value) { + bool => Self::Bool(value), + u64 => Self::Number(value), + Vec => Self::List(value), + String => Self::String(value), + &str => Self::String(value.into()), + bytes::Bytes => Self::Bytes(value), + } +} + +impl From<[u8; N]> for Value { + fn from(value: [u8; N]) -> Self { + Self::Bytes(bytes::Bytes::copy_from_slice(&value)) + } +} + +impl From<&[u8; N]> for Value { + fn from(value: &[u8; N]) -> Self { + Self::Bytes(bytes::Bytes::copy_from_slice(value)) + } +} + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Bool(a), Self::Bool(b)) => a == b, + (Self::Bool(_), _) => false, + (Self::Number(a), Self::Number(b)) => a == b, + (Self::Number(_), _) => false, + (Self::List(a), Self::List(b)) => a == b, + (Self::List(_), _) => false, + (Self::String(a), Self::String(b)) => a == b, + (Self::Bytes(a), Self::Bytes(b)) => a == b, + (Self::String(a), Self::Bytes(b)) | (Self::Bytes(b), Self::String(a)) => a == b, + (Self::String(_), _) => false, + (Self::Bytes(_), _) => false, + } + } +} + +impl TryFrom for Value { + type Error = eyre::Report; + + fn try_from(value: prost_types::Value) -> Result { + use prost_types::value::Kind; + + let value = match value.kind { + Some(value) => value, + None => return Err(eyre::eyre!("unexpected missing value")), + }; + + match value { + Kind::NullValue(_) => Err(eyre::eyre!("unexpected missing value")), + Kind::NumberValue(number) => Ok(Self::Number(number as u64)), + Kind::StringValue(string) => Ok(Self::String(string)), + Kind::BoolValue(value) => Ok(Self::Bool(value)), + Kind::ListValue(list) => Ok(Self::List( + list.values + .into_iter() + .map(prost_types::Value::try_into) + .collect::>()?, + )), + Kind::StructValue(_) => Err(eyre::eyre!("unexpected struct value")), + } + } +} + /// Represents a view into the metadata object attached to another object. `T` /// represents metadata known to Quilkin under `quilkin.dev` (available under /// the [`KEY`] constant.) diff --git a/src/proxy/config_dump.rs b/src/proxy/config_dump.rs index fb3fd2fc37..fafc6f2e2c 100644 --- a/src/proxy/config_dump.rs +++ b/src/proxy/config_dump.rs @@ -111,8 +111,7 @@ mod tests { use super::handle_request; use crate::cluster::cluster_manager::ClusterManager; use crate::endpoint::{Endpoint, Endpoints}; - use crate::filters::manager::FilterManager; - use crate::filters::{CreateFilterArgs, FilterChain}; + use crate::filters::{manager::FilterManager, CreateFilterArgs, FilterChain, FilterRegistry}; use prometheus::Registry; use std::sync::Arc; @@ -124,11 +123,12 @@ mod tests { Endpoints::new(vec![Endpoint::new(([127, 0, 0, 1], 8080).into())]).unwrap(), ) .unwrap(); - let debug_config = &serde_yaml::from_str("id: hello").unwrap(); + let debug_config = serde_yaml::from_str("id: hello").unwrap(); let debug_factory = crate::filters::debug::factory(); let debug_filter = debug_factory .create_filter(CreateFilterArgs::fixed( + FilterRegistry::default(), registry.clone(), Some(debug_config), )) diff --git a/src/proxy/health.rs b/src/proxy/health.rs index 3ecdf6e07f..84bef99f2e 100644 --- a/src/proxy/health.rs +++ b/src/proxy/health.rs @@ -34,7 +34,7 @@ impl Health { let healthy = health.healthy.clone(); let default_hook = panic::take_hook(); panic::set_hook(Box::new(move |panic_info| { - tracing::error!("Panic has occurred. Moving to Unhealthy"); + tracing::error!(%panic_info, "Panic has occurred. Moving to Unhealthy"); healthy.swap(false, Relaxed); default_hook(panic_info); })); diff --git a/src/xds/listener.rs b/src/xds/listener.rs index 8819ae694c..0c3ed2015b 100644 --- a/src/xds/listener.rs +++ b/src/xds/listener.rs @@ -147,8 +147,11 @@ impl ListenerManager { }) .transpose()?; - let create_filter_args = - CreateFilterArgs::dynamic(self.metrics_registry.clone(), config); + let create_filter_args = CreateFilterArgs::dynamic( + self.filter_registry.clone(), + self.metrics_registry.clone(), + config, + ); let name = filter.name; let filter = self.filter_registry.get(&name, create_filter_args)?; diff --git a/tests/matches.rs b/tests/matches.rs new file mode 100644 index 0000000000..d54d64bf0d --- /dev/null +++ b/tests/matches.rs @@ -0,0 +1,128 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +use tokio::time::{timeout, Duration}; + +use quilkin::{ + config::{Builder, Filter}, + endpoint::Endpoint, + filters::{capture_bytes, matches}, + test_utils::TestHelper, +}; + +#[tokio::test] +async fn matches() { + let mut t = TestHelper::default(); + let echo = t.run_echo_server().await; + + let capture_yaml = " +size: 3 +remove: true +"; + + let matches_yaml = " +on_read: + metadataKey: quilkin.dev/captured_bytes + fallthrough: + filter: quilkin.extensions.filters.concatenate_bytes.v1alpha1.ConcatenateBytes + config: + on_read: APPEND + bytes: ZGVm + branches: + - value: abc + filter: quilkin.extensions.filters.concatenate_bytes.v1alpha1.ConcatenateBytes + config: + on_read: APPEND + bytes: eHl6 # xyz + - value: xyz + filter: quilkin.extensions.filters.concatenate_bytes.v1alpha1.ConcatenateBytes + config: + on_read: APPEND + bytes: YWJj # abc +"; + let server_port = 12348; + let server_config = Builder::empty() + .with_port(server_port) + .with_static( + vec![ + Filter { + name: capture_bytes::factory().name().into(), + config: serde_yaml::from_str(capture_yaml).unwrap(), + }, + Filter { + name: matches::factory().name().into(), + config: serde_yaml::from_str(matches_yaml).unwrap(), + }, + ], + vec![Endpoint::new(echo)], + ) + .build(); + t.run_server_with_config(server_config); + + let (mut recv_chan, socket) = t.open_socket_and_recv_multiple_packets().await; + + let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), server_port); + + // abc packet + let msg = b"helloabc"; + socket.send_to(msg, &local_addr).await.unwrap(); + + assert_eq!( + "helloxyz", + timeout(Duration::from_secs(5), recv_chan.recv()) + .await + .expect("should have received a packet") + .unwrap() + ); + + // send an xyz packet + let msg = b"helloxyz"; + socket.send_to(msg, &local_addr).await.unwrap(); + + assert_eq!( + "helloabc", + timeout(Duration::from_secs(5), recv_chan.recv()) + .await + .expect("should have received a packet") + .unwrap() + ); + + // fallthrough packet + let msg = b"hellodef"; + socket.send_to(msg, &local_addr).await.unwrap(); + + assert_eq!( + "hellodef", + timeout(Duration::from_secs(5), recv_chan.recv()) + .await + .expect("should have received a packet") + .unwrap() + ); + + // second fallthrough packet + let msg = b"hellofgh"; + socket.send_to(msg, &local_addr).await.unwrap(); + + assert_eq!( + "hellodef", + timeout(Duration::from_secs(5), recv_chan.recv()) + .await + .expect("should have received a packet") + .unwrap() + ); +}