From fef299f3dee471b64839947ed2e71a4ecef63e2b Mon Sep 17 00:00:00 2001 From: Mark Mandel Date: Mon, 27 Sep 2021 16:49:34 -0700 Subject: [PATCH 1/2] Code: Firewall filter Code implementation of a Firewall filter that will allow/deny packets based on their from address on both read and write. Documentation to come next to finish off the below two tickets. Work on #158 Work on #343 --- Cargo.toml | 1 + build.rs | 1 + .../filters/firewall/v1alpha1/firewall.proto | 41 +++ src/filters.rs | 1 + src/filters/firewall.rs | 215 +++++++++++ src/filters/firewall/config.rs | 335 ++++++++++++++++++ src/filters/firewall/metrics.rs | 65 ++++ src/filters/set.rs | 1 + tests/firewall.rs | 121 +++++++ 9 files changed, 781 insertions(+) create mode 100644 proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto create mode 100644 src/filters/firewall.rs create mode 100644 src/filters/firewall/config.rs create mode 100644 src/filters/firewall/metrics.rs create mode 100644 tests/firewall.rs diff --git a/Cargo.toml b/Cargo.toml index b5f6241a3d..c450d25bf8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,7 @@ uuid = { version = "0.8.2", default-features = false, features = ["v4"] } thiserror = "1.0.30" eyre = "0.6.5" stable-eyre = "0.2.2" +ipnetwork = "0.18.0" [target.'cfg(target_os = "linux")'.dependencies] sys-info = "0.9.0" diff --git a/build.rs b/build.rs index 52beaf217c..72a86e4cfe 100644 --- a/build.rs +++ b/build.rs @@ -35,6 +35,7 @@ fn main() -> Result<(), Box> { "proto/quilkin/extensions/filters/load_balancer/v1alpha1/load_balancer.proto", "proto/quilkin/extensions/filters/local_rate_limit/v1alpha1/local_rate_limit.proto", "proto/quilkin/extensions/filters/token_router/v1alpha1/token_router.proto", + "proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto", ] .iter() .map(|name| std::env::current_dir().unwrap().join(name)) diff --git a/proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto b/proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto new file mode 100644 index 0000000000..9df512e75b --- /dev/null +++ b/proto/quilkin/extensions/filters/firewall/v1alpha1/firewall.proto @@ -0,0 +1,41 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package quilkin.extensions.filters.firewall.v1alpha1; + +message Firewall { + enum Action { + Allow = 0; + Deny = 1; + } + + message PortRange { + uint32 min = 1; + uint32 max = 2; + } + + message Rule { + Action action = 1; + string source = 2; + repeated PortRange ports = 3; + } + + repeated Rule on_read = 1; + repeated Rule on_write = 2; +} + diff --git a/src/filters.rs b/src/filters.rs index fe4054ecdd..01d6f5415e 100644 --- a/src/filters.rs +++ b/src/filters.rs @@ -30,6 +30,7 @@ pub mod capture_bytes; pub mod compress; pub mod concatenate_bytes; pub mod debug; +pub mod firewall; pub mod load_balancer; pub mod local_rate_limit; pub mod metadata; diff --git a/src/filters/firewall.rs b/src/filters/firewall.rs new file mode 100644 index 0000000000..2c800c4bf6 --- /dev/null +++ b/src/filters/firewall.rs @@ -0,0 +1,215 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +crate::include_proto!("quilkin.extensions.filters.firewall.v1alpha1"); + +use self::quilkin::extensions::filters::firewall::v1alpha1::Firewall as ProtoConfig; +use crate::filters::firewall::config::{Action, Config, Rule}; +use crate::filters::firewall::metrics::Metrics; +use crate::filters::{ + CreateFilterArgs, DynFilterFactory, Error, Filter, FilterFactory, FilterInstance, ReadContext, + ReadResponse, WriteContext, WriteResponse, +}; +use slog::{debug, o, Logger}; + +mod config; +mod metrics; + +pub const NAME: &str = "quilkin.extensions.filters.compress.v1alpha1.Firewall"; + +pub fn factory(base: &Logger) -> DynFilterFactory { + Box::from(FirewallFactory::new(base)) +} + +struct FirewallFactory { + log: Logger, +} + +impl FirewallFactory { + pub fn new(base: &Logger) -> Self { + Self { log: base.clone() } + } +} + +impl FilterFactory for FirewallFactory { + fn name(&self) -> &'static str { + NAME + } + + fn create_filter(&self, args: CreateFilterArgs) -> Result { + let (config_json, config) = self + .require_config(args.config)? + .deserialize::(self.name())?; + + let filter = Firewall::new(&self.log, config, Metrics::new(&args.metrics_registry)?); + Ok(FilterInstance::new( + config_json, + Box::new(filter) as Box, + )) + } +} + +struct Firewall { + log: Logger, + metrics: Metrics, + on_read: Vec, + on_write: Vec, +} + +impl Firewall { + fn new(base: &Logger, config: Config, metrics: Metrics) -> Self { + Self { + log: base.new(o!("source" => "extensions::Firewall")), + metrics, + on_read: config.on_read, + on_write: config.on_write, + } + } +} + +impl Filter for Firewall { + fn read(&self, ctx: ReadContext) -> Option { + for rule in &self.on_read { + if rule.contains(ctx.from) { + return match rule.action { + Action::Allow => { + debug!(self.log, "Allow"; "event" => "read", "from" => ctx.from.to_string()); + self.metrics.packets_allowed_on_read.inc(); + Some(ctx.into()) + } + Action::Deny => { + debug!(self.log, "Deny"; "event" => "read", "from" => ctx.from ); + self.metrics.packets_denied_on_read.inc(); + None + } + }; + } + } + + debug!(self.log, "default: Deny"; "event" => "read", "from" => ctx.from.to_string()); + self.metrics.packets_denied_on_read.inc(); + None + } + + fn write(&self, ctx: WriteContext) -> Option { + for rule in &self.on_write { + if rule.contains(ctx.from) { + return match rule.action { + Action::Allow => { + debug!(self.log, "Allow"; "event" => "write", "from" => ctx.from.to_string()); + self.metrics.packets_allowed_on_write.inc(); + Some(ctx.into()) + } + Action::Deny => { + debug!(self.log, "Deny"; "event" => "write", "from" => ctx.from ); + self.metrics.packets_denied_on_write.inc(); + None + } + }; + } + } + + debug!(self.log, "default: Deny"; "event" => "write", "from" => ctx.from.to_string()); + self.metrics.packets_denied_on_write.inc(); + None + } +} +#[cfg(test)] +mod tests { + use crate::endpoint::{Endpoint, Endpoints, UpstreamEndpoints}; + use crate::filters::firewall::config::PortRange; + use crate::test_utils::logger; + use prometheus::Registry; + + use super::*; + + #[test] + fn read() { + let firewall = Firewall { + log: logger(), + metrics: Metrics::new(&Registry::default()).unwrap(), + on_read: vec![Rule { + action: Action::Allow, + source: "192.168.75.0/24".parse().unwrap(), + ports: vec![PortRange { min: 10, max: 100 }], + }], + on_write: vec![], + }; + + let ctx = ReadContext::new( + UpstreamEndpoints::from( + Endpoints::new(vec![Endpoint::new("127.0.0.1:8080".parse().unwrap())]).unwrap(), + ), + "192.168.75.20:80".parse().unwrap(), + vec![], + ); + assert!(firewall.read(ctx).is_some()); + assert_eq!(1, firewall.metrics.packets_allowed_on_read.get()); + assert_eq!(0, firewall.metrics.packets_denied_on_read.get()); + + let ctx = ReadContext::new( + UpstreamEndpoints::from( + Endpoints::new(vec![Endpoint::new("127.0.0.1:8080".parse().unwrap())]).unwrap(), + ), + "192.168.75.20:2000".parse().unwrap(), + vec![], + ); + assert!(firewall.read(ctx).is_none()); + assert_eq!(1, firewall.metrics.packets_allowed_on_read.get()); + assert_eq!(1, firewall.metrics.packets_denied_on_read.get()); + + assert_eq!(0, firewall.metrics.packets_allowed_on_write.get()); + assert_eq!(0, firewall.metrics.packets_denied_on_write.get()); + } + + #[test] + fn write() { + let firewall = Firewall { + log: logger(), + metrics: Metrics::new(&Registry::default()).unwrap(), + on_read: vec![], + on_write: vec![Rule { + action: Action::Allow, + source: "192.168.75.0/24".parse().unwrap(), + ports: vec![PortRange { min: 10, max: 100 }], + }], + }; + + let endpoint = Endpoint::new("127.0.0.1:80".parse().unwrap()); + let ctx = WriteContext::new( + &endpoint, + "192.168.75.20:80".parse().unwrap(), + "127.0.0.1:8081".parse().unwrap(), + vec![], + ); + assert!(firewall.write(ctx).is_some()); + assert_eq!(1, firewall.metrics.packets_allowed_on_write.get()); + assert_eq!(0, firewall.metrics.packets_denied_on_write.get()); + + let ctx = WriteContext::new( + &endpoint, + "192.168.77.20:80".parse().unwrap(), + "127.0.0.1:8081".parse().unwrap(), + vec![], + ); + assert!(!firewall.write(ctx).is_some()); + assert_eq!(1, firewall.metrics.packets_allowed_on_write.get()); + assert_eq!(1, firewall.metrics.packets_denied_on_write.get()); + + assert_eq!(0, firewall.metrics.packets_allowed_on_read.get()); + assert_eq!(0, firewall.metrics.packets_denied_on_read.get()); + } +} diff --git a/src/filters/firewall/config.rs b/src/filters/firewall/config.rs new file mode 100644 index 0000000000..ade82d44a8 --- /dev/null +++ b/src/filters/firewall/config.rs @@ -0,0 +1,335 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::convert::TryFrom; +use std::fmt; +use std::fmt::Formatter; +use std::net::SocketAddr; + +use ipnetwork::IpNetwork; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::{filters::ConvertProtoConfigError, map_proto_enum}; + +use super::quilkin::extensions::filters::firewall::v1alpha1::firewall::{ + Action as ProtoAction, PortRange as ProtoPortRange, Rule as ProtoRule, +}; +use super::quilkin::extensions::filters::firewall::v1alpha1::Firewall as ProtoConfig; + +/// Represents how a [Firewall] filter is configured for read and write +/// operations. +#[derive(Clone, Deserialize, Debug, PartialEq, Serialize)] +#[non_exhaustive] +pub struct Config { + pub on_read: Vec, + pub on_write: Vec, +} + +#[derive(Clone, Deserialize, Debug, PartialEq, Serialize)] +pub enum Action { + /// Matching details will allow packets through. + #[serde(rename = "ALLOW")] + Allow, + /// Matching details will block packets. + #[serde(rename = "DENY")] + Deny, +} + +#[derive(Clone, Deserialize, Debug, PartialEq, Serialize)] +pub struct Rule { + pub action: Action, + /// ipv4 or ipv6 CIDR address. + pub source: IpNetwork, + pub ports: Vec, +} + +impl Rule { + /// Returns of the SocketAddress matches the provided CIDR address as well + /// as at least one of the port ranges in the Rule. + /// # Arguments + /// + /// * `address`: An ipv4 or ipv6 address and port. + /// + /// returns: bool + /// + pub fn contains(&self, address: SocketAddr) -> bool { + if !self.source.contains(address.ip()) { + return false; + } + + self.ports + .iter() + .any(|range| range.contains(address.port())) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct PortRange { + pub min: u16, + pub max: u16, +} + +impl PortRange { + /// Does this range contain a specific port value? + /// + /// # Arguments + /// + /// * `port`: + /// + /// returns: bool + pub fn contains(&self, port: u16) -> bool { + port >= self.min && port <= self.max + } +} + +impl Serialize for PortRange { + /// Serialise the [PortRange] into a single digit if min and max are the same + /// otherwise, serialise it to "min-max". + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + if self.max == self.min { + return serializer.serialize_str(self.min.to_string().as_str()); + } + + let range = format!("{}-{}", self.min, self.max); + serializer.serialize_str(range.as_str()) + } +} + +impl<'de> Deserialize<'de> for PortRange { + /// Port ranges can be specified in yaml as either "10" as as single value + /// or as "10-20" as a range, between a minimum and a maximum. + /// This deserializes either format into a [PortRange]. + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct PortRangeVisitor; + + impl<'de> Visitor<'de> for PortRangeVisitor { + type Value = PortRange; + + fn expecting(&self, f: &mut Formatter) -> Result<(), fmt::Error> { + f.write_str("A port range in the format of '10' or '10-20'") + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + match v.split_once('-') { + None => { + let value = v.parse::().map_err(de::Error::custom)?; + Ok(PortRange { + min: value, + max: value, + }) + } + Some(split) => { + let min = split.0.parse::().map_err(de::Error::custom)?; + let max = split.1.parse::().map_err(de::Error::custom)?; + + if min > max { + return Err(de::Error::custom(format!( + "min ({}) cannot be bigger than max ({})", + min, max + ))); + } + + Ok(PortRange { min, max }) + } + } + } + } + + deserializer.deserialize_str(PortRangeVisitor) + } +} + +impl TryFrom for Config { + type Error = ConvertProtoConfigError; + + fn try_from(p: ProtoConfig) -> Result { + fn convert_port(range: &ProtoPortRange) -> Result { + let min = u16::try_from(range.min).map_err(|err| { + ConvertProtoConfigError::new( + format!("min too large: {}", err), + Some("port.min".into()), + ) + })?; + + let max = u16::try_from(range.max).map_err(|err| { + ConvertProtoConfigError::new( + format!("max too large: {}", err), + Some("port.max".into()), + ) + })?; + + if min > max { + return Err(ConvertProtoConfigError::new( + format!("min port ({}) is greater than the max port ({})", min, max), + Some("ports".into()), + )); + }; + + Ok(PortRange { min, max }) + } + + fn convert_rule(rule: &ProtoRule) -> Result { + let action = map_proto_enum!( + value = rule.action, + field = "policy", + proto_enum_type = ProtoAction, + target_enum_type = Action, + variants = [Allow, Deny] + )?; + + let source = IpNetwork::try_from(rule.source.as_str()).map_err(|err| { + ConvertProtoConfigError::new( + format!("invalid source: {:?}", err), + Some("source".into()), + ) + })?; + + let ports = rule + .ports + .iter() + .map(convert_port) + .collect::, ConvertProtoConfigError>>()?; + + Ok(Rule { + action, + source, + ports, + }) + } + + Ok(Config { + on_read: p + .on_read + .iter() + .map(convert_rule) + .collect::, ConvertProtoConfigError>>()?, + on_write: p + .on_write + .iter() + .map(convert_rule) + .collect::, ConvertProtoConfigError>>()?, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn deserialize_yaml() { + let yaml = " +on_read: + - action: ALLOW + source: 192.168.51.0/24 + ports: + - 10 + - 1000-7000 +on_write: + - action: DENY + source: 192.168.75.0/24 + ports: + - 7000 + "; + + let config: Config = serde_yaml::from_str(yaml).unwrap(); + + let rule1 = config.on_read[0].clone(); + assert_eq!(rule1.action, Action::Allow); + assert_eq!(rule1.source, "192.168.51.0/24".parse().unwrap()); + assert_eq!(2, rule1.ports.len()); + assert_eq!(10, rule1.ports[0].min); + assert_eq!(10, rule1.ports[0].max); + assert_eq!(1000, rule1.ports[1].min); + assert_eq!(7000, rule1.ports[1].max); + + let rule2 = config.on_write[0].clone(); + assert_eq!(rule2.action, Action::Deny); + assert_eq!(rule2.source, "192.168.75.0/24".parse().unwrap()); + assert_eq!(1, rule2.ports.len()); + assert_eq!(7000, rule2.ports[0].min); + assert_eq!(7000, rule2.ports[0].max); + } + + #[test] + fn portrange_contains() { + let range = PortRange { min: 10, max: 100 }; + assert!(range.contains(10)); + assert!(range.contains(100)); + assert!(range.contains(50)); + assert!(!range.contains(200)); + assert!(!range.contains(5)); + } + + #[test] + fn convert() { + let proto_config = ProtoConfig { + on_read: vec![ProtoRule { + action: ProtoAction::Allow as i32, + source: "192.168.75.0/24".into(), + ports: vec![ProtoPortRange { min: 10, max: 100 }], + }], + on_write: vec![ProtoRule { + action: ProtoAction::Deny as i32, + source: "192.168.124.0/24".into(), + ports: vec![ProtoPortRange { min: 50, max: 50 }], + }], + }; + + let config = Config::try_from(proto_config).unwrap(); + + let rule1 = config.on_read[0].clone(); + assert_eq!(rule1.action, Action::Allow); + assert_eq!(rule1.source, "192.168.75.0/24".parse().unwrap()); + assert_eq!(1, rule1.ports.len()); + assert_eq!(10, rule1.ports[0].min); + assert_eq!(100, rule1.ports[0].max); + + let rule2 = config.on_write[0].clone(); + assert_eq!(rule2.action, Action::Deny); + assert_eq!(rule2.source, "192.168.124.0/24".parse().unwrap()); + assert_eq!(1, rule2.ports.len()); + assert_eq!(50, rule2.ports[0].min); + assert_eq!(50, rule2.ports[0].max); + } + + #[test] + fn rule_contains() { + let rule = Rule { + action: Action::Allow, + source: "192.168.75.0/24".parse().unwrap(), + ports: vec![PortRange { min: 10, max: 100 }], + }; + + assert!(rule.contains("192.168.75.10:50".parse().unwrap())); + assert!(rule.contains("192.168.75.10:100".parse().unwrap())); + assert!(rule.contains("192.168.75.10:10".parse().unwrap())); + + assert!(!rule.contains("192.168.75.10:5".parse().unwrap())); + assert!(!rule.contains("192.168.75.10:1000".parse().unwrap())); + assert!(!rule.contains("192.168.76.10:40".parse().unwrap())); + } +} diff --git a/src/filters/firewall/metrics.rs b/src/filters/firewall/metrics.rs new file mode 100644 index 0000000000..496397a0eb --- /dev/null +++ b/src/filters/firewall/metrics.rs @@ -0,0 +1,65 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use prometheus::core::{AtomicU64, GenericCounter}; +use prometheus::{IntCounterVec, Registry, Result as MetricsResult}; + +use crate::metrics::{filter_opts, CollectorExt}; + +/// Register and manage metrics for this filter +pub(super) struct Metrics { + pub(super) packets_denied_on_read: GenericCounter, + pub(super) packets_denied_on_write: GenericCounter, + pub(super) packets_allowed_on_read: GenericCounter, + pub(super) packets_allowed_on_write: GenericCounter, +} + +impl Metrics { + pub(super) fn new(registry: &Registry) -> MetricsResult { + let event_labels = vec!["events"]; + + let deny_metric = IntCounterVec::new( + filter_opts( + "packets_denied_total", + "Firewall", + "Total number of packets denied. Labels: event.", + ), + &event_labels, + )? + .register_if_not_exists(registry)?; + + let allow_metric = IntCounterVec::new( + filter_opts( + "packets_allowed_total", + "Firewall", + "Total number of packets allowed. Labels: event.", + ), + &event_labels, + )? + .register_if_not_exists(registry)?; + + Ok(Metrics { + packets_denied_on_read: deny_metric + .get_metric_with_label_values(vec!["on_read"].as_slice())?, + packets_denied_on_write: deny_metric + .get_metric_with_label_values(vec!["on_write"].as_slice())?, + packets_allowed_on_read: allow_metric + .get_metric_with_label_values(vec!["on_read"].as_slice())?, + packets_allowed_on_write: allow_metric + .get_metric_with_label_values(vec!["on_write"].as_slice())?, + }) + } +} diff --git a/src/filters/set.rs b/src/filters/set.rs index 42fe5e61c2..50fa1e903a 100644 --- a/src/filters/set.rs +++ b/src/filters/set.rs @@ -64,6 +64,7 @@ impl FilterSet { filters::capture_bytes::factory(base), filters::token_router::factory(base), filters::compress::factory(base), + filters::firewall::factory(base), ]) .chain(filters), ) diff --git a/tests/firewall.rs b/tests/firewall.rs new file mode 100644 index 0000000000..e305d1ba7a --- /dev/null +++ b/tests/firewall.rs @@ -0,0 +1,121 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use quilkin::config::{Builder, Filter}; +use quilkin::endpoint::Endpoint; +use quilkin::filters::firewall; +use quilkin::test_utils::TestHelper; +use slog::info; +use std::net::SocketAddr; +use tokio::sync::oneshot::Receiver; +use tokio::time::{timeout, Duration}; + +#[tokio::test] +async fn firewall_allow() { + let mut t = TestHelper::default(); + let yaml = " +on_read: + - action: ALLOW + source: 127.0.0.1/32 + ports: + - %1 +on_write: + - action: ALLOW + source: 127.0.0.0/24 + ports: + - %2 +"; + let recv = test(&mut t, 12354, yaml).await; + + assert_eq!( + "hello", + timeout(Duration::from_secs(5), recv) + .await + .expect("should have received a packet") + .unwrap() + ); +} + +#[tokio::test] +async fn firewall_read_deny() { + let mut t = TestHelper::default(); + let yaml = " +on_read: + - action: DENY + source: 127.0.0.1/32 + ports: + - %1 +on_write: + - action: ALLOW + source: 127.0.0.0/24 + ports: + - %2 +"; + let recv = test(&mut t, 12355, yaml).await; + + let result = timeout(Duration::from_secs(3), recv).await; + assert!(result.is_err(), "should not have received a packet"); +} + +#[tokio::test] +async fn firewall_write_deny() { + let mut t = TestHelper::default(); + let yaml = " +on_read: + - action: ALLOW + source: 127.0.0.1/32 + ports: + - %1 +on_write: + - action: DENY + source: 127.0.0.0/24 + ports: + - %2 +"; + let recv = test(&mut t, 12356, yaml).await; + + let result = timeout(Duration::from_secs(3), recv).await; + assert!(result.is_err(), "should not have received a packet"); +} + +async fn test(t: &mut TestHelper, server_port: u16, yaml: &str) -> Receiver { + let echo = t.run_echo_server().await; + + let recv = t.open_socket_and_recv_single_packet().await; + let client_addr = recv.socket.local_addr().unwrap(); + let yaml = yaml + .replace("%1", client_addr.port().to_string().as_str()) + .replace("%2", echo.port().to_string().as_str()); + info!(t.log, "Config"; "config" => yaml.as_str()); + + let server_config = Builder::empty() + .with_port(server_port) + .with_static( + vec![Filter { + name: firewall::factory(&t.log).name().into(), + config: serde_yaml::from_str(yaml.as_str()).unwrap(), + }], + vec![Endpoint::new(echo)], + ) + .build(); + t.run_server_with_config(server_config); + + let local_addr: SocketAddr = format!("127.0.0.1:{}", server_port).parse().unwrap(); + info!(t.log, "Sending hello"; "from" => client_addr, "address" => local_addr); + recv.socket.send_to(b"hello", &local_addr).await.unwrap(); + + recv.packet_rx +} From 4301ad5b3d5086934d4857b3375bf20a051d5a54 Mon Sep 17 00:00:00 2001 From: Mark Mandel Date: Mon, 18 Oct 2021 16:42:28 -0700 Subject: [PATCH 2/2] Review updates, including refactoring with PortRange::new() --- src/filters/firewall.rs | 83 ++++++++------- src/filters/firewall/config.rs | 173 +++++++++++++++++--------------- src/filters/firewall/metrics.rs | 35 +++---- tests/firewall.rs | 2 +- 4 files changed, 158 insertions(+), 135 deletions(-) diff --git a/src/filters/firewall.rs b/src/filters/firewall.rs index 2c800c4bf6..3afe7a5fee 100644 --- a/src/filters/firewall.rs +++ b/src/filters/firewall.rs @@ -14,20 +14,22 @@ * limitations under the License. */ -crate::include_proto!("quilkin.extensions.filters.firewall.v1alpha1"); +//! Filter for allowing/blocking traffic by IP and port. -use self::quilkin::extensions::filters::firewall::v1alpha1::Firewall as ProtoConfig; -use crate::filters::firewall::config::{Action, Config, Rule}; -use crate::filters::firewall::metrics::Metrics; -use crate::filters::{ - CreateFilterArgs, DynFilterFactory, Error, Filter, FilterFactory, FilterInstance, ReadContext, - ReadResponse, WriteContext, WriteResponse, -}; use slog::{debug, o, Logger}; +use crate::filters::firewall::metrics::Metrics; +use crate::filters::prelude::*; + +use self::quilkin::extensions::filters::firewall::v1alpha1::Firewall as ProtoConfig; + +crate::include_proto!("quilkin.extensions.filters.firewall.v1alpha1"); + mod config; mod metrics; +pub use config::{Action, Config, PortRange, PortRangeError, Rule}; + pub const NAME: &str = "quilkin.extensions.filters.compress.v1alpha1.Firewall"; pub fn factory(base: &Logger) -> DynFilterFactory { @@ -87,12 +89,12 @@ impl Filter for Firewall { return match rule.action { Action::Allow => { debug!(self.log, "Allow"; "event" => "read", "from" => ctx.from.to_string()); - self.metrics.packets_allowed_on_read.inc(); + self.metrics.packets_allowed_read.inc(); Some(ctx.into()) } Action::Deny => { debug!(self.log, "Deny"; "event" => "read", "from" => ctx.from ); - self.metrics.packets_denied_on_read.inc(); + self.metrics.packets_denied_read.inc(); None } }; @@ -100,7 +102,7 @@ impl Filter for Firewall { } debug!(self.log, "default: Deny"; "event" => "read", "from" => ctx.from.to_string()); - self.metrics.packets_denied_on_read.inc(); + self.metrics.packets_denied_read.inc(); None } @@ -110,12 +112,12 @@ impl Filter for Firewall { return match rule.action { Action::Allow => { debug!(self.log, "Allow"; "event" => "write", "from" => ctx.from.to_string()); - self.metrics.packets_allowed_on_write.inc(); + self.metrics.packets_allowed_write.inc(); Some(ctx.into()) } Action::Deny => { debug!(self.log, "Deny"; "event" => "write", "from" => ctx.from ); - self.metrics.packets_denied_on_write.inc(); + self.metrics.packets_denied_write.inc(); None } }; @@ -123,16 +125,18 @@ impl Filter for Firewall { } debug!(self.log, "default: Deny"; "event" => "write", "from" => ctx.from.to_string()); - self.metrics.packets_denied_on_write.inc(); + self.metrics.packets_denied_write.inc(); None } } #[cfg(test)] mod tests { + use prometheus::Registry; + use std::net::Ipv4Addr; + use crate::endpoint::{Endpoint, Endpoints, UpstreamEndpoints}; use crate::filters::firewall::config::PortRange; use crate::test_utils::logger; - use prometheus::Registry; use super::*; @@ -144,35 +148,36 @@ mod tests { on_read: vec![Rule { action: Action::Allow, source: "192.168.75.0/24".parse().unwrap(), - ports: vec![PortRange { min: 10, max: 100 }], + ports: vec![PortRange::new(10, 100).unwrap()], }], on_write: vec![], }; + let local_ip = [192, 168, 75, 20]; let ctx = ReadContext::new( UpstreamEndpoints::from( - Endpoints::new(vec![Endpoint::new("127.0.0.1:8080".parse().unwrap())]).unwrap(), + Endpoints::new(vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())]).unwrap(), ), - "192.168.75.20:80".parse().unwrap(), + (local_ip, 80).into(), vec![], ); assert!(firewall.read(ctx).is_some()); - assert_eq!(1, firewall.metrics.packets_allowed_on_read.get()); - assert_eq!(0, firewall.metrics.packets_denied_on_read.get()); + assert_eq!(1, firewall.metrics.packets_allowed_read.get()); + assert_eq!(0, firewall.metrics.packets_denied_read.get()); let ctx = ReadContext::new( UpstreamEndpoints::from( - Endpoints::new(vec![Endpoint::new("127.0.0.1:8080".parse().unwrap())]).unwrap(), + Endpoints::new(vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())]).unwrap(), ), - "192.168.75.20:2000".parse().unwrap(), + (local_ip, 2000).into(), vec![], ); assert!(firewall.read(ctx).is_none()); - assert_eq!(1, firewall.metrics.packets_allowed_on_read.get()); - assert_eq!(1, firewall.metrics.packets_denied_on_read.get()); + assert_eq!(1, firewall.metrics.packets_allowed_read.get()); + assert_eq!(1, firewall.metrics.packets_denied_read.get()); - assert_eq!(0, firewall.metrics.packets_allowed_on_write.get()); - assert_eq!(0, firewall.metrics.packets_denied_on_write.get()); + assert_eq!(0, firewall.metrics.packets_allowed_write.get()); + assert_eq!(0, firewall.metrics.packets_denied_write.get()); } #[test] @@ -184,32 +189,34 @@ mod tests { on_write: vec![Rule { action: Action::Allow, source: "192.168.75.0/24".parse().unwrap(), - ports: vec![PortRange { min: 10, max: 100 }], + ports: vec![PortRange::new(10, 100).unwrap()], }], }; - let endpoint = Endpoint::new("127.0.0.1:80".parse().unwrap()); + let endpoint = Endpoint::new((Ipv4Addr::LOCALHOST, 80).into()); + let local_addr = (Ipv4Addr::LOCALHOST, 8081).into(); + let ctx = WriteContext::new( &endpoint, - "192.168.75.20:80".parse().unwrap(), - "127.0.0.1:8081".parse().unwrap(), + ([192, 168, 75, 20], 80).into(), + local_addr, vec![], ); assert!(firewall.write(ctx).is_some()); - assert_eq!(1, firewall.metrics.packets_allowed_on_write.get()); - assert_eq!(0, firewall.metrics.packets_denied_on_write.get()); + assert_eq!(1, firewall.metrics.packets_allowed_write.get()); + assert_eq!(0, firewall.metrics.packets_denied_write.get()); let ctx = WriteContext::new( &endpoint, - "192.168.77.20:80".parse().unwrap(), - "127.0.0.1:8081".parse().unwrap(), + ([192, 168, 77, 20], 80).into(), + local_addr, vec![], ); assert!(!firewall.write(ctx).is_some()); - assert_eq!(1, firewall.metrics.packets_allowed_on_write.get()); - assert_eq!(1, firewall.metrics.packets_denied_on_write.get()); + assert_eq!(1, firewall.metrics.packets_allowed_write.get()); + assert_eq!(1, firewall.metrics.packets_denied_write.get()); - assert_eq!(0, firewall.metrics.packets_allowed_on_read.get()); - assert_eq!(0, firewall.metrics.packets_denied_on_read.get()); + assert_eq!(0, firewall.metrics.packets_allowed_read.get()); + assert_eq!(0, firewall.metrics.packets_denied_read.get()); } } diff --git a/src/filters/firewall/config.rs b/src/filters/firewall/config.rs index ade82d44a8..b40f3d03d8 100644 --- a/src/filters/firewall/config.rs +++ b/src/filters/firewall/config.rs @@ -14,10 +14,7 @@ * limitations under the License. */ -use std::convert::TryFrom; -use std::fmt; -use std::fmt::Formatter; -use std::net::SocketAddr; +use std::{convert::TryFrom, fmt, fmt::Formatter, net::SocketAddr, ops::Range}; use ipnetwork::IpNetwork; use serde::de::{self, Visitor}; @@ -25,12 +22,12 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::{filters::ConvertProtoConfigError, map_proto_enum}; -use super::quilkin::extensions::filters::firewall::v1alpha1::firewall::{ - Action as ProtoAction, PortRange as ProtoPortRange, Rule as ProtoRule, +use super::quilkin::extensions::filters::firewall::v1alpha1::{ + firewall::{Action as ProtoAction, PortRange as ProtoPortRange, Rule as ProtoRule}, + Firewall as ProtoConfig, }; -use super::quilkin::extensions::filters::firewall::v1alpha1::Firewall as ProtoConfig; -/// Represents how a [Firewall] filter is configured for read and write +/// Represents how a Firewall filter is configured for read and write /// operations. #[derive(Clone, Deserialize, Debug, PartialEq, Serialize)] #[non_exhaustive] @@ -39,16 +36,18 @@ pub struct Config { pub on_write: Vec, } +/// Whether or not a matching [Rule] should Allow or Deny access #[derive(Clone, Deserialize, Debug, PartialEq, Serialize)] pub enum Action { - /// Matching details will allow packets through. + /// Matching rules will allow packets through. #[serde(rename = "ALLOW")] Allow, - /// Matching details will block packets. + /// Matching rules will block packets. #[serde(rename = "DENY")] Deny, } +/// Combination of CIDR range, port range and action to take. #[derive(Clone, Deserialize, Debug, PartialEq, Serialize)] pub struct Rule { pub action: Action, @@ -58,14 +57,28 @@ pub struct Rule { } impl Rule { - /// Returns of the SocketAddress matches the provided CIDR address as well - /// as at least one of the port ranges in the Rule. - /// # Arguments + /// Returns `true` if `address` matches the provided CIDR address as well + /// as at least one of the port ranges in the [Rule]. /// - /// * `address`: An ipv4 or ipv6 address and port. + /// # Examples + /// ``` + /// use quilkin::filters::firewall::{Action, PortRange}; /// - /// returns: bool + /// let rule = quilkin::filters::firewall::Rule { + /// action: Action::Allow, + /// source: "192.168.75.0/24".parse().unwrap(), + /// ports: vec![PortRange::new(10, 100).unwrap()], + /// }; /// + /// let ip = [192, 168, 75, 10]; + /// assert!(rule.contains((ip, 50).into())); + /// assert!(rule.contains((ip, 99).into())); + /// assert!(rule.contains((ip, 10).into())); + /// + /// assert!(!rule.contains((ip, 5).into())); + /// assert!(!rule.contains((ip, 1000).into())); + /// assert!(!rule.contains(([192, 168, 76, 10], 40).into())); + /// ``` pub fn contains(&self, address: SocketAddr) -> bool { if !self.source.contains(address.ip()) { return false; @@ -73,26 +86,38 @@ impl Rule { self.ports .iter() - .any(|range| range.contains(address.port())) + .any(|range| range.contains(&address.port())) } } -#[derive(Clone, Debug, PartialEq)] -pub struct PortRange { - pub min: u16, - pub max: u16, +/// Invalid min and max values for a [PortRange]. +#[derive(Debug, thiserror::Error)] +pub enum PortRangeError { + #[error("invalid port range: min {min:?} is greater than or equal to max {max:?}")] + InvalidRange { min: u16, max: u16 }, } +/// Range of matching ports that are configured against a [Rule]. +#[derive(Clone, Debug, PartialEq)] +pub struct PortRange(Range); + impl PortRange { - /// Does this range contain a specific port value? - /// - /// # Arguments - /// - /// * `port`: - /// - /// returns: bool - pub fn contains(&self, port: u16) -> bool { - port >= self.min && port <= self.max + /// Creates a new [PortRange], where min is inclusive, max is exclusive. + /// [Result] will be a [PortRangeError] if `min >= max`. + pub fn new(min: u16, max: u16) -> Result { + if min >= max { + return Err(PortRangeError::InvalidRange { min, max }); + } + + Ok(Self(Range { + start: min, + end: max, + })) + } + + /// Returns true if the range contain the given `port`. + pub fn contains(&self, port: &u16) -> bool { + self.0.contains(port) } } @@ -103,11 +128,11 @@ impl Serialize for PortRange { where S: Serializer, { - if self.max == self.min { - return serializer.serialize_str(self.min.to_string().as_str()); + if self.0.start == (self.0.end - 1) { + return serializer.serialize_str(self.0.start.to_string().as_str()); } - let range = format!("{}-{}", self.min, self.max); + let range = format!("{}-{}", self.0.start, self.0.end); serializer.serialize_str(range.as_str()) } } @@ -136,23 +161,12 @@ impl<'de> Deserialize<'de> for PortRange { match v.split_once('-') { None => { let value = v.parse::().map_err(de::Error::custom)?; - Ok(PortRange { - min: value, - max: value, - }) + PortRange::new(value, value + 1).map_err(de::Error::custom) } Some(split) => { - let min = split.0.parse::().map_err(de::Error::custom)?; - let max = split.1.parse::().map_err(de::Error::custom)?; - - if min > max { - return Err(de::Error::custom(format!( - "min ({}) cannot be bigger than max ({})", - min, max - ))); - } - - Ok(PortRange { min, max }) + let start = split.0.parse::().map_err(de::Error::custom)?; + let end = split.1.parse::().map_err(de::Error::custom)?; + PortRange::new(start, end).map_err(de::Error::custom) } } } @@ -181,14 +195,9 @@ impl TryFrom for Config { ) })?; - if min > max { - return Err(ConvertProtoConfigError::new( - format!("min port ({}) is greater than the max port ({})", min, max), - Some("ports".into()), - )); - }; - - Ok(PortRange { min, max }) + PortRange::new(min, max).map_err(|err| { + ConvertProtoConfigError::new(format!("{}", err), Some("ports".into())) + }) } fn convert_rule(rule: &ProtoRule) -> Result { @@ -261,27 +270,32 @@ on_write: assert_eq!(rule1.action, Action::Allow); assert_eq!(rule1.source, "192.168.51.0/24".parse().unwrap()); assert_eq!(2, rule1.ports.len()); - assert_eq!(10, rule1.ports[0].min); - assert_eq!(10, rule1.ports[0].max); - assert_eq!(1000, rule1.ports[1].min); - assert_eq!(7000, rule1.ports[1].max); + assert_eq!(10, rule1.ports[0].0.start); + assert_eq!(11, rule1.ports[0].0.end); + assert_eq!(1000, rule1.ports[1].0.start); + assert_eq!(7000, rule1.ports[1].0.end); let rule2 = config.on_write[0].clone(); assert_eq!(rule2.action, Action::Deny); assert_eq!(rule2.source, "192.168.75.0/24".parse().unwrap()); assert_eq!(1, rule2.ports.len()); - assert_eq!(7000, rule2.ports[0].min); - assert_eq!(7000, rule2.ports[0].max); + assert_eq!(7000, rule2.ports[0].0.start); + assert_eq!(7001, rule2.ports[0].0.end); } #[test] fn portrange_contains() { - let range = PortRange { min: 10, max: 100 }; - assert!(range.contains(10)); - assert!(range.contains(100)); - assert!(range.contains(50)); - assert!(!range.contains(200)); - assert!(!range.contains(5)); + let range = PortRange::new(10, 100).unwrap(); + assert!(range.contains(&10)); + assert!(!range.contains(&100)); + assert!(range.contains(&50)); + assert!(!range.contains(&200)); + assert!(!range.contains(&5)); + + // single value + let single = PortRange::new(10, 11).unwrap(); + assert!(single.contains(&10)); + assert!(!single.contains(&11)); } #[test] @@ -295,7 +309,7 @@ on_write: on_write: vec![ProtoRule { action: ProtoAction::Deny as i32, source: "192.168.124.0/24".into(), - ports: vec![ProtoPortRange { min: 50, max: 50 }], + ports: vec![ProtoPortRange { min: 50, max: 51 }], }], }; @@ -305,15 +319,15 @@ on_write: assert_eq!(rule1.action, Action::Allow); assert_eq!(rule1.source, "192.168.75.0/24".parse().unwrap()); assert_eq!(1, rule1.ports.len()); - assert_eq!(10, rule1.ports[0].min); - assert_eq!(100, rule1.ports[0].max); + assert_eq!(10, rule1.ports[0].0.start); + assert_eq!(100, rule1.ports[0].0.end); let rule2 = config.on_write[0].clone(); assert_eq!(rule2.action, Action::Deny); assert_eq!(rule2.source, "192.168.124.0/24".parse().unwrap()); assert_eq!(1, rule2.ports.len()); - assert_eq!(50, rule2.ports[0].min); - assert_eq!(50, rule2.ports[0].max); + assert_eq!(50, rule2.ports[0].0.start); + assert_eq!(51, rule2.ports[0].0.end); } #[test] @@ -321,15 +335,16 @@ on_write: let rule = Rule { action: Action::Allow, source: "192.168.75.0/24".parse().unwrap(), - ports: vec![PortRange { min: 10, max: 100 }], + ports: vec![PortRange::new(10, 100).unwrap()], }; - assert!(rule.contains("192.168.75.10:50".parse().unwrap())); - assert!(rule.contains("192.168.75.10:100".parse().unwrap())); - assert!(rule.contains("192.168.75.10:10".parse().unwrap())); + let ip = [192, 168, 75, 10]; + assert!(rule.contains((ip, 50).into())); + assert!(rule.contains((ip, 99).into())); + assert!(rule.contains((ip, 10).into())); - assert!(!rule.contains("192.168.75.10:5".parse().unwrap())); - assert!(!rule.contains("192.168.75.10:1000".parse().unwrap())); - assert!(!rule.contains("192.168.76.10:40".parse().unwrap())); + assert!(!rule.contains((ip, 5).into())); + assert!(!rule.contains((ip, 1000).into())); + assert!(!rule.contains(([192, 168, 76, 10], 40).into())); } } diff --git a/src/filters/firewall/metrics.rs b/src/filters/firewall/metrics.rs index 496397a0eb..8f620c564a 100644 --- a/src/filters/firewall/metrics.rs +++ b/src/filters/firewall/metrics.rs @@ -14,22 +14,27 @@ * limitations under the License. */ -use prometheus::core::{AtomicU64, GenericCounter}; -use prometheus::{IntCounterVec, Registry, Result as MetricsResult}; +use prometheus::{ + core::{AtomicU64, GenericCounter}, + IntCounterVec, Registry, Result as MetricsResult, +}; use crate::metrics::{filter_opts, CollectorExt}; +const READ: &str = "read"; +const WRITE: &str = "write"; + /// Register and manage metrics for this filter pub(super) struct Metrics { - pub(super) packets_denied_on_read: GenericCounter, - pub(super) packets_denied_on_write: GenericCounter, - pub(super) packets_allowed_on_read: GenericCounter, - pub(super) packets_allowed_on_write: GenericCounter, + pub(super) packets_denied_read: GenericCounter, + pub(super) packets_denied_write: GenericCounter, + pub(super) packets_allowed_read: GenericCounter, + pub(super) packets_allowed_write: GenericCounter, } impl Metrics { pub(super) fn new(registry: &Registry) -> MetricsResult { - let event_labels = vec!["events"]; + let event_labels = &["event"]; let deny_metric = IntCounterVec::new( filter_opts( @@ -37,7 +42,7 @@ impl Metrics { "Firewall", "Total number of packets denied. Labels: event.", ), - &event_labels, + event_labels, )? .register_if_not_exists(registry)?; @@ -47,19 +52,15 @@ impl Metrics { "Firewall", "Total number of packets allowed. Labels: event.", ), - &event_labels, + event_labels, )? .register_if_not_exists(registry)?; Ok(Metrics { - packets_denied_on_read: deny_metric - .get_metric_with_label_values(vec!["on_read"].as_slice())?, - packets_denied_on_write: deny_metric - .get_metric_with_label_values(vec!["on_write"].as_slice())?, - packets_allowed_on_read: allow_metric - .get_metric_with_label_values(vec!["on_read"].as_slice())?, - packets_allowed_on_write: allow_metric - .get_metric_with_label_values(vec!["on_write"].as_slice())?, + packets_denied_read: deny_metric.get_metric_with_label_values(&[READ])?, + packets_denied_write: deny_metric.get_metric_with_label_values(&[WRITE])?, + packets_allowed_read: allow_metric.get_metric_with_label_values(&[READ])?, + packets_allowed_write: allow_metric.get_metric_with_label_values(&[WRITE])?, }) } } diff --git a/tests/firewall.rs b/tests/firewall.rs index e305d1ba7a..b66996e69b 100644 --- a/tests/firewall.rs +++ b/tests/firewall.rs @@ -113,7 +113,7 @@ async fn test(t: &mut TestHelper, server_port: u16, yaml: &str) -> Receiver client_addr, "address" => local_addr); recv.socket.send_to(b"hello", &local_addr).await.unwrap();