From cbc680e1f1f666b2fc4b3313bfa37c5d2b4cc829 Mon Sep 17 00:00:00 2001 From: Mark Mandel Date: Mon, 18 Oct 2021 16:42:28 -0700 Subject: [PATCH] Review updates, including refactoring with PortRange::new() --- src/filters/firewall.rs | 84 ++++++++++--------- src/filters/firewall/config.rs | 140 ++++++++++++++++---------------- src/filters/firewall/metrics.rs | 35 ++++---- 3 files changed, 134 insertions(+), 125 deletions(-) diff --git a/src/filters/firewall.rs b/src/filters/firewall.rs index 2c800c4bf6..3a32c3d241 100644 --- a/src/filters/firewall.rs +++ b/src/filters/firewall.rs @@ -14,17 +14,20 @@ * limitations under the License. */ -crate::include_proto!("quilkin.extensions.filters.firewall.v1alpha1"); +//! Filter for allowing/blocking traffic by IP and port. -use self::quilkin::extensions::filters::firewall::v1alpha1::Firewall as ProtoConfig; -use crate::filters::firewall::config::{Action, Config, Rule}; -use crate::filters::firewall::metrics::Metrics; -use crate::filters::{ - CreateFilterArgs, DynFilterFactory, Error, Filter, FilterFactory, FilterInstance, ReadContext, - ReadResponse, WriteContext, WriteResponse, -}; use slog::{debug, o, Logger}; +use crate::filters::firewall::{ + config::{Action, Config, Rule}, + metrics::Metrics, +}; +use crate::filters::prelude::*; + +use self::quilkin::extensions::filters::firewall::v1alpha1::Firewall as ProtoConfig; + +crate::include_proto!("quilkin.extensions.filters.firewall.v1alpha1"); + mod config; mod metrics; @@ -87,12 +90,12 @@ impl Filter for Firewall { return match rule.action { Action::Allow => { debug!(self.log, "Allow"; "event" => "read", "from" => ctx.from.to_string()); - self.metrics.packets_allowed_on_read.inc(); + self.metrics.packets_allowed_read.inc(); Some(ctx.into()) } Action::Deny => { debug!(self.log, "Deny"; "event" => "read", "from" => ctx.from ); - self.metrics.packets_denied_on_read.inc(); + self.metrics.packets_denied_read.inc(); None } }; @@ -100,7 +103,7 @@ impl Filter for Firewall { } debug!(self.log, "default: Deny"; "event" => "read", "from" => ctx.from.to_string()); - self.metrics.packets_denied_on_read.inc(); + self.metrics.packets_denied_read.inc(); None } @@ -110,12 +113,12 @@ impl Filter for Firewall { return match rule.action { Action::Allow => { debug!(self.log, "Allow"; "event" => "write", "from" => ctx.from.to_string()); - self.metrics.packets_allowed_on_write.inc(); + self.metrics.packets_allowed_write.inc(); Some(ctx.into()) } Action::Deny => { debug!(self.log, "Deny"; "event" => "write", "from" => ctx.from ); - self.metrics.packets_denied_on_write.inc(); + self.metrics.packets_denied_write.inc(); None } }; @@ -123,16 +126,18 @@ impl Filter for Firewall { } debug!(self.log, "default: Deny"; "event" => "write", "from" => ctx.from.to_string()); - self.metrics.packets_denied_on_write.inc(); + self.metrics.packets_denied_write.inc(); None } } #[cfg(test)] mod tests { + use prometheus::Registry; + use std::net::Ipv4Addr; + use crate::endpoint::{Endpoint, Endpoints, UpstreamEndpoints}; use crate::filters::firewall::config::PortRange; use crate::test_utils::logger; - use prometheus::Registry; use super::*; @@ -144,35 +149,36 @@ mod tests { on_read: vec![Rule { action: Action::Allow, source: "192.168.75.0/24".parse().unwrap(), - ports: vec![PortRange { min: 10, max: 100 }], + ports: vec![PortRange::new(10, 100).unwrap()], }], on_write: vec![], }; + let local_ip = [192, 168, 75, 20]; let ctx = ReadContext::new( UpstreamEndpoints::from( - Endpoints::new(vec![Endpoint::new("127.0.0.1:8080".parse().unwrap())]).unwrap(), + Endpoints::new(vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())]).unwrap(), ), - "192.168.75.20:80".parse().unwrap(), + (local_ip, 80).into(), vec![], ); assert!(firewall.read(ctx).is_some()); - assert_eq!(1, firewall.metrics.packets_allowed_on_read.get()); - assert_eq!(0, firewall.metrics.packets_denied_on_read.get()); + assert_eq!(1, firewall.metrics.packets_allowed_read.get()); + assert_eq!(0, firewall.metrics.packets_denied_read.get()); let ctx = ReadContext::new( UpstreamEndpoints::from( - Endpoints::new(vec![Endpoint::new("127.0.0.1:8080".parse().unwrap())]).unwrap(), + Endpoints::new(vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())]).unwrap(), ), - "192.168.75.20:2000".parse().unwrap(), + (local_ip, 2000).into(), vec![], ); assert!(firewall.read(ctx).is_none()); - assert_eq!(1, firewall.metrics.packets_allowed_on_read.get()); - assert_eq!(1, firewall.metrics.packets_denied_on_read.get()); + assert_eq!(1, firewall.metrics.packets_allowed_read.get()); + assert_eq!(1, firewall.metrics.packets_denied_read.get()); - assert_eq!(0, firewall.metrics.packets_allowed_on_write.get()); - assert_eq!(0, firewall.metrics.packets_denied_on_write.get()); + assert_eq!(0, firewall.metrics.packets_allowed_write.get()); + assert_eq!(0, firewall.metrics.packets_denied_write.get()); } #[test] @@ -184,32 +190,34 @@ mod tests { on_write: vec![Rule { action: Action::Allow, source: "192.168.75.0/24".parse().unwrap(), - ports: vec![PortRange { min: 10, max: 100 }], + ports: vec![PortRange::new(10, 100).unwrap()], }], }; - let endpoint = Endpoint::new("127.0.0.1:80".parse().unwrap()); + let endpoint = Endpoint::new((Ipv4Addr::LOCALHOST, 80).into()); + let local_addr = (Ipv4Addr::LOCALHOST, 8081).into(); + let ctx = WriteContext::new( &endpoint, - "192.168.75.20:80".parse().unwrap(), - "127.0.0.1:8081".parse().unwrap(), + ([192, 168, 75, 20], 80).into(), + local_addr, vec![], ); assert!(firewall.write(ctx).is_some()); - assert_eq!(1, firewall.metrics.packets_allowed_on_write.get()); - assert_eq!(0, firewall.metrics.packets_denied_on_write.get()); + assert_eq!(1, firewall.metrics.packets_allowed_write.get()); + assert_eq!(0, firewall.metrics.packets_denied_write.get()); let ctx = WriteContext::new( &endpoint, - "192.168.77.20:80".parse().unwrap(), - "127.0.0.1:8081".parse().unwrap(), + ([192, 168, 77, 20], 80).into(), + local_addr, vec![], ); assert!(!firewall.write(ctx).is_some()); - assert_eq!(1, firewall.metrics.packets_allowed_on_write.get()); - assert_eq!(1, firewall.metrics.packets_denied_on_write.get()); + assert_eq!(1, firewall.metrics.packets_allowed_write.get()); + assert_eq!(1, firewall.metrics.packets_denied_write.get()); - assert_eq!(0, firewall.metrics.packets_allowed_on_read.get()); - assert_eq!(0, firewall.metrics.packets_denied_on_read.get()); + assert_eq!(0, firewall.metrics.packets_allowed_read.get()); + assert_eq!(0, firewall.metrics.packets_denied_read.get()); } } diff --git a/src/filters/firewall/config.rs b/src/filters/firewall/config.rs index ade82d44a8..f15ace0257 100644 --- a/src/filters/firewall/config.rs +++ b/src/filters/firewall/config.rs @@ -14,10 +14,7 @@ * limitations under the License. */ -use std::convert::TryFrom; -use std::fmt; -use std::fmt::Formatter; -use std::net::SocketAddr; +use std::{convert::TryFrom, fmt, fmt::Formatter, net::SocketAddr, ops::Range}; use ipnetwork::IpNetwork; use serde::de::{self, Visitor}; @@ -25,10 +22,10 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::{filters::ConvertProtoConfigError, map_proto_enum}; -use super::quilkin::extensions::filters::firewall::v1alpha1::firewall::{ - Action as ProtoAction, PortRange as ProtoPortRange, Rule as ProtoRule, +use super::quilkin::extensions::filters::firewall::v1alpha1::{ + firewall::{Action as ProtoAction, PortRange as ProtoPortRange, Rule as ProtoRule}, + Firewall as ProtoConfig, }; -use super::quilkin::extensions::filters::firewall::v1alpha1::Firewall as ProtoConfig; /// Represents how a [Firewall] filter is configured for read and write /// operations. @@ -41,10 +38,10 @@ pub struct Config { #[derive(Clone, Deserialize, Debug, PartialEq, Serialize)] pub enum Action { - /// Matching details will allow packets through. + /// Matching rules will allow packets through. #[serde(rename = "ALLOW")] Allow, - /// Matching details will block packets. + /// Matching rules will block packets. #[serde(rename = "DENY")] Deny, } @@ -73,26 +70,39 @@ impl Rule { self.ports .iter() - .any(|range| range.contains(address.port())) + .any(|range| range.contains(&address.port())) } } -#[derive(Clone, Debug, PartialEq)] -pub struct PortRange { - pub min: u16, - pub max: u16, +/// InitializeError is returned with an error message if the +/// [`ClusterManager`] fails to initialize properly. +#[derive(Debug, thiserror::Error)] +pub enum PortRangeError { + #[error("invalid port range: min {min:?} is greater than or equal to max {max:?}")] + InvalidRange { min: u16, max: u16 }, } +/// The range of ports that are configured against a [Rule]. +#[derive(Clone, Debug, PartialEq)] +pub struct PortRange(Range); + impl PortRange { + pub fn new(min: u16, max: u16) -> Result { + if min >= max { + return Err(PortRangeError::InvalidRange { min, max }); + } + + Ok(Self(Range { + start: min, + end: max, + })) + } /// Does this range contain a specific port value? - /// - /// # Arguments - /// - /// * `port`: + /// min is inclusive, max is exclusive. /// /// returns: bool - pub fn contains(&self, port: u16) -> bool { - port >= self.min && port <= self.max + pub fn contains(&self, port: &u16) -> bool { + self.0.contains(port) } } @@ -103,11 +113,11 @@ impl Serialize for PortRange { where S: Serializer, { - if self.max == self.min { - return serializer.serialize_str(self.min.to_string().as_str()); + if self.0.start == (self.0.end - 1) { + return serializer.serialize_str(self.0.start.to_string().as_str()); } - let range = format!("{}-{}", self.min, self.max); + let range = format!("{}-{}", self.0.start, self.0.end); serializer.serialize_str(range.as_str()) } } @@ -136,23 +146,12 @@ impl<'de> Deserialize<'de> for PortRange { match v.split_once('-') { None => { let value = v.parse::().map_err(de::Error::custom)?; - Ok(PortRange { - min: value, - max: value, - }) + PortRange::new(value, value + 1).map_err(de::Error::custom) } Some(split) => { - let min = split.0.parse::().map_err(de::Error::custom)?; - let max = split.1.parse::().map_err(de::Error::custom)?; - - if min > max { - return Err(de::Error::custom(format!( - "min ({}) cannot be bigger than max ({})", - min, max - ))); - } - - Ok(PortRange { min, max }) + let start = split.0.parse::().map_err(de::Error::custom)?; + let end = split.1.parse::().map_err(de::Error::custom)?; + PortRange::new(start, end).map_err(de::Error::custom) } } } @@ -181,14 +180,9 @@ impl TryFrom for Config { ) })?; - if min > max { - return Err(ConvertProtoConfigError::new( - format!("min port ({}) is greater than the max port ({})", min, max), - Some("ports".into()), - )); - }; - - Ok(PortRange { min, max }) + PortRange::new(min, max).map_err(|err| { + ConvertProtoConfigError::new(format!("{}", err), Some("ports".into())) + }) } fn convert_rule(rule: &ProtoRule) -> Result { @@ -261,27 +255,32 @@ on_write: assert_eq!(rule1.action, Action::Allow); assert_eq!(rule1.source, "192.168.51.0/24".parse().unwrap()); assert_eq!(2, rule1.ports.len()); - assert_eq!(10, rule1.ports[0].min); - assert_eq!(10, rule1.ports[0].max); - assert_eq!(1000, rule1.ports[1].min); - assert_eq!(7000, rule1.ports[1].max); + assert_eq!(10, rule1.ports[0].0.start); + assert_eq!(11, rule1.ports[0].0.end); + assert_eq!(1000, rule1.ports[1].0.start); + assert_eq!(7000, rule1.ports[1].0.end); let rule2 = config.on_write[0].clone(); assert_eq!(rule2.action, Action::Deny); assert_eq!(rule2.source, "192.168.75.0/24".parse().unwrap()); assert_eq!(1, rule2.ports.len()); - assert_eq!(7000, rule2.ports[0].min); - assert_eq!(7000, rule2.ports[0].max); + assert_eq!(7000, rule2.ports[0].0.start); + assert_eq!(7001, rule2.ports[0].0.end); } #[test] fn portrange_contains() { - let range = PortRange { min: 10, max: 100 }; - assert!(range.contains(10)); - assert!(range.contains(100)); - assert!(range.contains(50)); - assert!(!range.contains(200)); - assert!(!range.contains(5)); + let range = PortRange::new(10, 100).unwrap(); + assert!(range.contains(&10)); + assert!(!range.contains(&100)); + assert!(range.contains(&50)); + assert!(!range.contains(&200)); + assert!(!range.contains(&5)); + + // single value + let single = PortRange::new(10, 11).unwrap(); + assert!(single.contains(&10)); + assert!(!single.contains(&11)); } #[test] @@ -295,7 +294,7 @@ on_write: on_write: vec![ProtoRule { action: ProtoAction::Deny as i32, source: "192.168.124.0/24".into(), - ports: vec![ProtoPortRange { min: 50, max: 50 }], + ports: vec![ProtoPortRange { min: 50, max: 51 }], }], }; @@ -305,15 +304,15 @@ on_write: assert_eq!(rule1.action, Action::Allow); assert_eq!(rule1.source, "192.168.75.0/24".parse().unwrap()); assert_eq!(1, rule1.ports.len()); - assert_eq!(10, rule1.ports[0].min); - assert_eq!(100, rule1.ports[0].max); + assert_eq!(10, rule1.ports[0].0.start); + assert_eq!(100, rule1.ports[0].0.end); let rule2 = config.on_write[0].clone(); assert_eq!(rule2.action, Action::Deny); assert_eq!(rule2.source, "192.168.124.0/24".parse().unwrap()); assert_eq!(1, rule2.ports.len()); - assert_eq!(50, rule2.ports[0].min); - assert_eq!(50, rule2.ports[0].max); + assert_eq!(50, rule2.ports[0].0.start); + assert_eq!(51, rule2.ports[0].0.end); } #[test] @@ -321,15 +320,16 @@ on_write: let rule = Rule { action: Action::Allow, source: "192.168.75.0/24".parse().unwrap(), - ports: vec![PortRange { min: 10, max: 100 }], + ports: vec![PortRange::new(10, 100).unwrap()], }; - assert!(rule.contains("192.168.75.10:50".parse().unwrap())); - assert!(rule.contains("192.168.75.10:100".parse().unwrap())); - assert!(rule.contains("192.168.75.10:10".parse().unwrap())); + let ip = [192, 168, 75, 10]; + assert!(rule.contains((ip, 50).into())); + assert!(rule.contains((ip, 99).into())); + assert!(rule.contains((ip, 10).into())); - assert!(!rule.contains("192.168.75.10:5".parse().unwrap())); - assert!(!rule.contains("192.168.75.10:1000".parse().unwrap())); - assert!(!rule.contains("192.168.76.10:40".parse().unwrap())); + assert!(!rule.contains((ip, 5).into())); + assert!(!rule.contains((ip, 1000).into())); + assert!(!rule.contains(([192, 168, 76, 10], 40).into())); } } diff --git a/src/filters/firewall/metrics.rs b/src/filters/firewall/metrics.rs index 496397a0eb..8f620c564a 100644 --- a/src/filters/firewall/metrics.rs +++ b/src/filters/firewall/metrics.rs @@ -14,22 +14,27 @@ * limitations under the License. */ -use prometheus::core::{AtomicU64, GenericCounter}; -use prometheus::{IntCounterVec, Registry, Result as MetricsResult}; +use prometheus::{ + core::{AtomicU64, GenericCounter}, + IntCounterVec, Registry, Result as MetricsResult, +}; use crate::metrics::{filter_opts, CollectorExt}; +const READ: &str = "read"; +const WRITE: &str = "write"; + /// Register and manage metrics for this filter pub(super) struct Metrics { - pub(super) packets_denied_on_read: GenericCounter, - pub(super) packets_denied_on_write: GenericCounter, - pub(super) packets_allowed_on_read: GenericCounter, - pub(super) packets_allowed_on_write: GenericCounter, + pub(super) packets_denied_read: GenericCounter, + pub(super) packets_denied_write: GenericCounter, + pub(super) packets_allowed_read: GenericCounter, + pub(super) packets_allowed_write: GenericCounter, } impl Metrics { pub(super) fn new(registry: &Registry) -> MetricsResult { - let event_labels = vec!["events"]; + let event_labels = &["event"]; let deny_metric = IntCounterVec::new( filter_opts( @@ -37,7 +42,7 @@ impl Metrics { "Firewall", "Total number of packets denied. Labels: event.", ), - &event_labels, + event_labels, )? .register_if_not_exists(registry)?; @@ -47,19 +52,15 @@ impl Metrics { "Firewall", "Total number of packets allowed. Labels: event.", ), - &event_labels, + event_labels, )? .register_if_not_exists(registry)?; Ok(Metrics { - packets_denied_on_read: deny_metric - .get_metric_with_label_values(vec!["on_read"].as_slice())?, - packets_denied_on_write: deny_metric - .get_metric_with_label_values(vec!["on_write"].as_slice())?, - packets_allowed_on_read: allow_metric - .get_metric_with_label_values(vec!["on_read"].as_slice())?, - packets_allowed_on_write: allow_metric - .get_metric_with_label_values(vec!["on_write"].as_slice())?, + packets_denied_read: deny_metric.get_metric_with_label_values(&[READ])?, + packets_denied_write: deny_metric.get_metric_with_label_values(&[WRITE])?, + packets_allowed_read: allow_metric.get_metric_with_label_values(&[READ])?, + packets_allowed_write: allow_metric.get_metric_with_label_values(&[WRITE])?, }) } }