diff --git a/Cargo.toml b/Cargo.toml index aecf2a8806..b4ddd1e394 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,7 @@ uuid = { version = "0.8.2", default-features = false, features = ["v4"] } thiserror = "1.0.30" eyre = "0.6.5" stable-eyre = "0.2.2" +ipnetwork = "0.18.0" [target.'cfg(target_os = "linux")'.dependencies] sys-info = "0.9.0" diff --git a/build.rs b/build.rs index 52beaf217c..72a86e4cfe 100644 --- a/build.rs +++ b/build.rs @@ -35,6 +35,7 @@ fn main() -> Result<(), Box> { "proto/quilkin/extensions/filters/load_balancer/v1alpha1/load_balancer.proto", "proto/quilkin/extensions/filters/local_rate_limit/v1alpha1/local_rate_limit.proto", "proto/quilkin/extensions/filters/token_router/v1alpha1/token_router.proto", + "proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto", ] .iter() .map(|name| std::env::current_dir().unwrap().join(name)) diff --git a/proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto b/proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto new file mode 100644 index 0000000000..9df512e75b --- /dev/null +++ b/proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto @@ -0,0 +1,41 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package quilkin.extensions.filters.firewall.v1alpha1; + +message Firewall { + enum Action { + Allow = 0; + Deny = 1; + } + + message PortRange { + uint32 min = 1; + uint32 max = 2; + } + + message Rule { + Action action = 1; + string source = 2; + repeated PortRange ports = 3; + } + + repeated Rule on_read = 1; + repeated Rule on_write = 2; +} + diff --git a/src/filters.rs b/src/filters.rs index fe4054ecdd..01d6f5415e 100644 --- a/src/filters.rs +++ b/src/filters.rs @@ -30,6 +30,7 @@ pub mod capture_bytes; pub mod compress; pub mod concatenate_bytes; pub mod debug; +pub mod firewall; pub mod load_balancer; pub mod local_rate_limit; pub mod metadata; diff --git a/src/filters/firewall.rs b/src/filters/firewall.rs new file mode 100644 index 0000000000..3afe7a5fee --- /dev/null +++ b/src/filters/firewall.rs @@ -0,0 +1,222 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Filter for allowing/blocking traffic by IP and port. + +use slog::{debug, o, Logger}; + +use crate::filters::firewall::metrics::Metrics; +use crate::filters::prelude::*; + +use self::quilkin::extensions::filters::firewall::v1alpha1::Firewall as ProtoConfig; + +crate::include_proto!("quilkin.extensions.filters.firewall.v1alpha1"); + +mod config; +mod metrics; + +pub use config::{Action, Config, PortRange, PortRangeError, Rule}; + +pub const NAME: &str = "quilkin.extensions.filters.compress.v1alpha1.Firewall"; + +pub fn factory(base: &Logger) -> DynFilterFactory { + Box::from(FirewallFactory::new(base)) +} + +struct FirewallFactory { + log: Logger, +} + +impl FirewallFactory { + pub fn new(base: &Logger) -> Self { + Self { log: base.clone() } + } +} + +impl FilterFactory for FirewallFactory { + fn name(&self) -> &'static str { + NAME + } + + fn create_filter(&self, args: CreateFilterArgs) -> Result { + let (config_json, config) = self + .require_config(args.config)? + .deserialize::(self.name())?; + + let filter = Firewall::new(&self.log, config, Metrics::new(&args.metrics_registry)?); + Ok(FilterInstance::new( + config_json, + Box::new(filter) as Box, + )) + } +} + +struct Firewall { + log: Logger, + metrics: Metrics, + on_read: Vec, + on_write: Vec, +} + +impl Firewall { + fn new(base: &Logger, config: Config, metrics: Metrics) -> Self { + Self { + log: base.new(o!("source" => "extensions::Firewall")), + metrics, + on_read: config.on_read, + on_write: config.on_write, + } + } +} + +impl Filter for Firewall { + fn read(&self, ctx: ReadContext) -> Option { + for rule in &self.on_read { + if rule.contains(ctx.from) { + return match rule.action { + Action::Allow => { + debug!(self.log, "Allow"; "event" => "read", "from" => ctx.from.to_string()); + self.metrics.packets_allowed_read.inc(); + Some(ctx.into()) + } + Action::Deny => { + debug!(self.log, "Deny"; "event" => "read", "from" => ctx.from ); + self.metrics.packets_denied_read.inc(); + None + } + }; + } + } + + debug!(self.log, "default: Deny"; "event" => "read", "from" => ctx.from.to_string()); + self.metrics.packets_denied_read.inc(); + None + } + + fn write(&self, ctx: WriteContext) -> Option { + for rule in &self.on_write { + if rule.contains(ctx.from) { + return match rule.action { + Action::Allow => { + debug!(self.log, "Allow"; "event" => "write", "from" => ctx.from.to_string()); + self.metrics.packets_allowed_write.inc(); + Some(ctx.into()) + } + Action::Deny => { + debug!(self.log, "Deny"; "event" => "write", "from" => ctx.from ); + self.metrics.packets_denied_write.inc(); + None + } + }; + } + } + + debug!(self.log, "default: Deny"; "event" => "write", "from" => ctx.from.to_string()); + self.metrics.packets_denied_write.inc(); + None + } +} +#[cfg(test)] +mod tests { + use prometheus::Registry; + use std::net::Ipv4Addr; + + use crate::endpoint::{Endpoint, Endpoints, UpstreamEndpoints}; + use crate::filters::firewall::config::PortRange; + use crate::test_utils::logger; + + use super::*; + + #[test] + fn read() { + let firewall = Firewall { + log: logger(), + metrics: Metrics::new(&Registry::default()).unwrap(), + on_read: vec![Rule { + action: Action::Allow, + source: "192.168.75.0/24".parse().unwrap(), + ports: vec![PortRange::new(10, 100).unwrap()], + }], + on_write: vec![], + }; + + let local_ip = [192, 168, 75, 20]; + let ctx = ReadContext::new( + UpstreamEndpoints::from( + Endpoints::new(vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())]).unwrap(), + ), + (local_ip, 80).into(), + vec![], + ); + assert!(firewall.read(ctx).is_some()); + assert_eq!(1, firewall.metrics.packets_allowed_read.get()); + assert_eq!(0, firewall.metrics.packets_denied_read.get()); + + let ctx = ReadContext::new( + UpstreamEndpoints::from( + Endpoints::new(vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())]).unwrap(), + ), + (local_ip, 2000).into(), + vec![], + ); + assert!(firewall.read(ctx).is_none()); + assert_eq!(1, firewall.metrics.packets_allowed_read.get()); + assert_eq!(1, firewall.metrics.packets_denied_read.get()); + + assert_eq!(0, firewall.metrics.packets_allowed_write.get()); + assert_eq!(0, firewall.metrics.packets_denied_write.get()); + } + + #[test] + fn write() { + let firewall = Firewall { + log: logger(), + metrics: Metrics::new(&Registry::default()).unwrap(), + on_read: vec![], + on_write: vec![Rule { + action: Action::Allow, + source: "192.168.75.0/24".parse().unwrap(), + ports: vec![PortRange::new(10, 100).unwrap()], + }], + }; + + let endpoint = Endpoint::new((Ipv4Addr::LOCALHOST, 80).into()); + let local_addr = (Ipv4Addr::LOCALHOST, 8081).into(); + + let ctx = WriteContext::new( + &endpoint, + ([192, 168, 75, 20], 80).into(), + local_addr, + vec![], + ); + assert!(firewall.write(ctx).is_some()); + assert_eq!(1, firewall.metrics.packets_allowed_write.get()); + assert_eq!(0, firewall.metrics.packets_denied_write.get()); + + let ctx = WriteContext::new( + &endpoint, + ([192, 168, 77, 20], 80).into(), + local_addr, + vec![], + ); + assert!(!firewall.write(ctx).is_some()); + assert_eq!(1, firewall.metrics.packets_allowed_write.get()); + assert_eq!(1, firewall.metrics.packets_denied_write.get()); + + assert_eq!(0, firewall.metrics.packets_allowed_read.get()); + assert_eq!(0, firewall.metrics.packets_denied_read.get()); + } +} diff --git a/src/filters/firewall/config.rs b/src/filters/firewall/config.rs new file mode 100644 index 0000000000..b40f3d03d8 --- /dev/null +++ b/src/filters/firewall/config.rs @@ -0,0 +1,350 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::{convert::TryFrom, fmt, fmt::Formatter, net::SocketAddr, ops::Range}; + +use ipnetwork::IpNetwork; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::{filters::ConvertProtoConfigError, map_proto_enum}; + +use super::quilkin::extensions::filters::firewall::v1alpha1::{ + firewall::{Action as ProtoAction, PortRange as ProtoPortRange, Rule as ProtoRule}, + Firewall as ProtoConfig, +}; + +/// Represents how a Firewall filter is configured for read and write +/// operations. +#[derive(Clone, Deserialize, Debug, PartialEq, Serialize)] +#[non_exhaustive] +pub struct Config { + pub on_read: Vec, + pub on_write: Vec, +} + +/// Whether or not a matching [Rule] should Allow or Deny access +#[derive(Clone, Deserialize, Debug, PartialEq, Serialize)] +pub enum Action { + /// Matching rules will allow packets through. + #[serde(rename = "ALLOW")] + Allow, + /// Matching rules will block packets. + #[serde(rename = "DENY")] + Deny, +} + +/// Combination of CIDR range, port range and action to take. +#[derive(Clone, Deserialize, Debug, PartialEq, Serialize)] +pub struct Rule { + pub action: Action, + /// ipv4 or ipv6 CIDR address. + pub source: IpNetwork, + pub ports: Vec, +} + +impl Rule { + /// Returns `true` if `address` matches the provided CIDR address as well + /// as at least one of the port ranges in the [Rule]. + /// + /// # Examples + /// ``` + /// use quilkin::filters::firewall::{Action, PortRange}; + /// + /// let rule = quilkin::filters::firewall::Rule { + /// action: Action::Allow, + /// source: "192.168.75.0/24".parse().unwrap(), + /// ports: vec![PortRange::new(10, 100).unwrap()], + /// }; + /// + /// let ip = [192, 168, 75, 10]; + /// assert!(rule.contains((ip, 50).into())); + /// assert!(rule.contains((ip, 99).into())); + /// assert!(rule.contains((ip, 10).into())); + /// + /// assert!(!rule.contains((ip, 5).into())); + /// assert!(!rule.contains((ip, 1000).into())); + /// assert!(!rule.contains(([192, 168, 76, 10], 40).into())); + /// ``` + pub fn contains(&self, address: SocketAddr) -> bool { + if !self.source.contains(address.ip()) { + return false; + } + + self.ports + .iter() + .any(|range| range.contains(&address.port())) + } +} + +/// Invalid min and max values for a [PortRange]. +#[derive(Debug, thiserror::Error)] +pub enum PortRangeError { + #[error("invalid port range: min {min:?} is greater than or equal to max {max:?}")] + InvalidRange { min: u16, max: u16 }, +} + +/// Range of matching ports that are configured against a [Rule]. +#[derive(Clone, Debug, PartialEq)] +pub struct PortRange(Range); + +impl PortRange { + /// Creates a new [PortRange], where min is inclusive, max is exclusive. + /// [Result] will be a [PortRangeError] if `min >= max`. + pub fn new(min: u16, max: u16) -> Result { + if min >= max { + return Err(PortRangeError::InvalidRange { min, max }); + } + + Ok(Self(Range { + start: min, + end: max, + })) + } + + /// Returns true if the range contain the given `port`. + pub fn contains(&self, port: &u16) -> bool { + self.0.contains(port) + } +} + +impl Serialize for PortRange { + /// Serialise the [PortRange] into a single digit if min and max are the same + /// otherwise, serialise it to "min-max". + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + if self.0.start == (self.0.end - 1) { + return serializer.serialize_str(self.0.start.to_string().as_str()); + } + + let range = format!("{}-{}", self.0.start, self.0.end); + serializer.serialize_str(range.as_str()) + } +} + +impl<'de> Deserialize<'de> for PortRange { + /// Port ranges can be specified in yaml as either "10" as as single value + /// or as "10-20" as a range, between a minimum and a maximum. + /// This deserializes either format into a [PortRange]. + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct PortRangeVisitor; + + impl<'de> Visitor<'de> for PortRangeVisitor { + type Value = PortRange; + + fn expecting(&self, f: &mut Formatter) -> Result<(), fmt::Error> { + f.write_str("A port range in the format of '10' or '10-20'") + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + match v.split_once('-') { + None => { + let value = v.parse::().map_err(de::Error::custom)?; + PortRange::new(value, value + 1).map_err(de::Error::custom) + } + Some(split) => { + let start = split.0.parse::().map_err(de::Error::custom)?; + let end = split.1.parse::().map_err(de::Error::custom)?; + PortRange::new(start, end).map_err(de::Error::custom) + } + } + } + } + + deserializer.deserialize_str(PortRangeVisitor) + } +} + +impl TryFrom for Config { + type Error = ConvertProtoConfigError; + + fn try_from(p: ProtoConfig) -> Result { + fn convert_port(range: &ProtoPortRange) -> Result { + let min = u16::try_from(range.min).map_err(|err| { + ConvertProtoConfigError::new( + format!("min too large: {}", err), + Some("port.min".into()), + ) + })?; + + let max = u16::try_from(range.max).map_err(|err| { + ConvertProtoConfigError::new( + format!("max too large: {}", err), + Some("port.max".into()), + ) + })?; + + PortRange::new(min, max).map_err(|err| { + ConvertProtoConfigError::new(format!("{}", err), Some("ports".into())) + }) + } + + fn convert_rule(rule: &ProtoRule) -> Result { + let action = map_proto_enum!( + value = rule.action, + field = "policy", + proto_enum_type = ProtoAction, + target_enum_type = Action, + variants = [Allow, Deny] + )?; + + let source = IpNetwork::try_from(rule.source.as_str()).map_err(|err| { + ConvertProtoConfigError::new( + format!("invalid source: {:?}", err), + Some("source".into()), + ) + })?; + + let ports = rule + .ports + .iter() + .map(convert_port) + .collect::, ConvertProtoConfigError>>()?; + + Ok(Rule { + action, + source, + ports, + }) + } + + Ok(Config { + on_read: p + .on_read + .iter() + .map(convert_rule) + .collect::, ConvertProtoConfigError>>()?, + on_write: p + .on_write + .iter() + .map(convert_rule) + .collect::, ConvertProtoConfigError>>()?, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn deserialize_yaml() { + let yaml = " +on_read: + - action: ALLOW + source: 192.168.51.0/24 + ports: + - 10 + - 1000-7000 +on_write: + - action: DENY + source: 192.168.75.0/24 + ports: + - 7000 + "; + + let config: Config = serde_yaml::from_str(yaml).unwrap(); + + let rule1 = config.on_read[0].clone(); + assert_eq!(rule1.action, Action::Allow); + assert_eq!(rule1.source, "192.168.51.0/24".parse().unwrap()); + assert_eq!(2, rule1.ports.len()); + assert_eq!(10, rule1.ports[0].0.start); + assert_eq!(11, rule1.ports[0].0.end); + assert_eq!(1000, rule1.ports[1].0.start); + assert_eq!(7000, rule1.ports[1].0.end); + + let rule2 = config.on_write[0].clone(); + assert_eq!(rule2.action, Action::Deny); + assert_eq!(rule2.source, "192.168.75.0/24".parse().unwrap()); + assert_eq!(1, rule2.ports.len()); + assert_eq!(7000, rule2.ports[0].0.start); + assert_eq!(7001, rule2.ports[0].0.end); + } + + #[test] + fn portrange_contains() { + let range = PortRange::new(10, 100).unwrap(); + assert!(range.contains(&10)); + assert!(!range.contains(&100)); + assert!(range.contains(&50)); + assert!(!range.contains(&200)); + assert!(!range.contains(&5)); + + // single value + let single = PortRange::new(10, 11).unwrap(); + assert!(single.contains(&10)); + assert!(!single.contains(&11)); + } + + #[test] + fn convert() { + let proto_config = ProtoConfig { + on_read: vec![ProtoRule { + action: ProtoAction::Allow as i32, + source: "192.168.75.0/24".into(), + ports: vec![ProtoPortRange { min: 10, max: 100 }], + }], + on_write: vec![ProtoRule { + action: ProtoAction::Deny as i32, + source: "192.168.124.0/24".into(), + ports: vec![ProtoPortRange { min: 50, max: 51 }], + }], + }; + + let config = Config::try_from(proto_config).unwrap(); + + let rule1 = config.on_read[0].clone(); + assert_eq!(rule1.action, Action::Allow); + assert_eq!(rule1.source, "192.168.75.0/24".parse().unwrap()); + assert_eq!(1, rule1.ports.len()); + assert_eq!(10, rule1.ports[0].0.start); + assert_eq!(100, rule1.ports[0].0.end); + + let rule2 = config.on_write[0].clone(); + assert_eq!(rule2.action, Action::Deny); + assert_eq!(rule2.source, "192.168.124.0/24".parse().unwrap()); + assert_eq!(1, rule2.ports.len()); + assert_eq!(50, rule2.ports[0].0.start); + assert_eq!(51, rule2.ports[0].0.end); + } + + #[test] + fn rule_contains() { + let rule = Rule { + action: Action::Allow, + source: "192.168.75.0/24".parse().unwrap(), + ports: vec![PortRange::new(10, 100).unwrap()], + }; + + let ip = [192, 168, 75, 10]; + assert!(rule.contains((ip, 50).into())); + assert!(rule.contains((ip, 99).into())); + assert!(rule.contains((ip, 10).into())); + + assert!(!rule.contains((ip, 5).into())); + assert!(!rule.contains((ip, 1000).into())); + assert!(!rule.contains(([192, 168, 76, 10], 40).into())); + } +} diff --git a/src/filters/firewall/metrics.rs b/src/filters/firewall/metrics.rs new file mode 100644 index 0000000000..8f620c564a --- /dev/null +++ b/src/filters/firewall/metrics.rs @@ -0,0 +1,66 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use prometheus::{ + core::{AtomicU64, GenericCounter}, + IntCounterVec, Registry, Result as MetricsResult, +}; + +use crate::metrics::{filter_opts, CollectorExt}; + +const READ: &str = "read"; +const WRITE: &str = "write"; + +/// Register and manage metrics for this filter +pub(super) struct Metrics { + pub(super) packets_denied_read: GenericCounter, + pub(super) packets_denied_write: GenericCounter, + pub(super) packets_allowed_read: GenericCounter, + pub(super) packets_allowed_write: GenericCounter, +} + +impl Metrics { + pub(super) fn new(registry: &Registry) -> MetricsResult { + let event_labels = &["event"]; + + let deny_metric = IntCounterVec::new( + filter_opts( + "packets_denied_total", + "Firewall", + "Total number of packets denied. Labels: event.", + ), + event_labels, + )? + .register_if_not_exists(registry)?; + + let allow_metric = IntCounterVec::new( + filter_opts( + "packets_allowed_total", + "Firewall", + "Total number of packets allowed. Labels: event.", + ), + event_labels, + )? + .register_if_not_exists(registry)?; + + Ok(Metrics { + packets_denied_read: deny_metric.get_metric_with_label_values(&[READ])?, + packets_denied_write: deny_metric.get_metric_with_label_values(&[WRITE])?, + packets_allowed_read: allow_metric.get_metric_with_label_values(&[READ])?, + packets_allowed_write: allow_metric.get_metric_with_label_values(&[WRITE])?, + }) + } +} diff --git a/src/filters/set.rs b/src/filters/set.rs index 42fe5e61c2..50fa1e903a 100644 --- a/src/filters/set.rs +++ b/src/filters/set.rs @@ -64,6 +64,7 @@ impl FilterSet { filters::capture_bytes::factory(base), filters::token_router::factory(base), filters::compress::factory(base), + filters::firewall::factory(base), ]) .chain(filters), ) diff --git a/tests/firewall.rs b/tests/firewall.rs new file mode 100644 index 0000000000..b66996e69b --- /dev/null +++ b/tests/firewall.rs @@ -0,0 +1,121 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use quilkin::config::{Builder, Filter}; +use quilkin::endpoint::Endpoint; +use quilkin::filters::firewall; +use quilkin::test_utils::TestHelper; +use slog::info; +use std::net::SocketAddr; +use tokio::sync::oneshot::Receiver; +use tokio::time::{timeout, Duration}; + +#[tokio::test] +async fn firewall_allow() { + let mut t = TestHelper::default(); + let yaml = " +on_read: + - action: ALLOW + source: 127.0.0.1/32 + ports: + - %1 +on_write: + - action: ALLOW + source: 127.0.0.0/24 + ports: + - %2 +"; + let recv = test(&mut t, 12354, yaml).await; + + assert_eq!( + "hello", + timeout(Duration::from_secs(5), recv) + .await + .expect("should have received a packet") + .unwrap() + ); +} + +#[tokio::test] +async fn firewall_read_deny() { + let mut t = TestHelper::default(); + let yaml = " +on_read: + - action: DENY + source: 127.0.0.1/32 + ports: + - %1 +on_write: + - action: ALLOW + source: 127.0.0.0/24 + ports: + - %2 +"; + let recv = test(&mut t, 12355, yaml).await; + + let result = timeout(Duration::from_secs(3), recv).await; + assert!(result.is_err(), "should not have received a packet"); +} + +#[tokio::test] +async fn firewall_write_deny() { + let mut t = TestHelper::default(); + let yaml = " +on_read: + - action: ALLOW + source: 127.0.0.1/32 + ports: + - %1 +on_write: + - action: DENY + source: 127.0.0.0/24 + ports: + - %2 +"; + let recv = test(&mut t, 12356, yaml).await; + + let result = timeout(Duration::from_secs(3), recv).await; + assert!(result.is_err(), "should not have received a packet"); +} + +async fn test(t: &mut TestHelper, server_port: u16, yaml: &str) -> Receiver { + let echo = t.run_echo_server().await; + + let recv = t.open_socket_and_recv_single_packet().await; + let client_addr = recv.socket.local_addr().unwrap(); + let yaml = yaml + .replace("%1", client_addr.port().to_string().as_str()) + .replace("%2", echo.port().to_string().as_str()); + info!(t.log, "Config"; "config" => yaml.as_str()); + + let server_config = Builder::empty() + .with_port(server_port) + .with_static( + vec![Filter { + name: firewall::factory(&t.log).name().into(), + config: serde_yaml::from_str(yaml.as_str()).unwrap(), + }], + vec![Endpoint::new(echo)], + ) + .build(); + t.run_server_with_config(server_config); + + let local_addr: SocketAddr = (std::net::Ipv4Addr::LOCALHOST, server_port).into(); + info!(t.log, "Sending hello"; "from" => client_addr, "address" => local_addr); + recv.socket.send_to(b"hello", &local_addr).await.unwrap(); + + recv.packet_rx +}