From 9300105d9fef48304a21462c15971e6ca67c009b Mon Sep 17 00:00:00 2001 From: thatmattlove Date: Thu, 23 Dec 2021 00:00:35 -0700 Subject: [PATCH] #142: Start multiple query target implementation --- hyperglass/models/api/query.py | 4 ++- hyperglass/models/directive.py | 58 +++++++++++++++++++++++++--------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/hyperglass/models/api/query.py b/hyperglass/models/api/query.py index fd8e880a..eb33913b 100644 --- a/hyperglass/models/api/query.py +++ b/hyperglass/models/api/query.py @@ -22,13 +22,15 @@ (TEXT := use_state("params").web.text) +QueryTarget = constr(strip_whitespace=True, min_length=1) + class Query(BaseModel): """Validation model for input query parameters.""" query_location: StrictStr # Device `name` field query_type: StrictStr # Directive `id` field - query_target: constr(strip_whitespace=True, min_length=1) + query_target: t.Union[t.List[QueryTarget], QueryTarget] class Config: """Pydantic model configuration.""" diff --git a/hyperglass/models/directive.py b/hyperglass/models/directive.py index e3a967a1..1ab3b8ef 100644 --- a/hyperglass/models/directive.py +++ b/hyperglass/models/directive.py @@ -18,6 +18,10 @@ from .main import MultiModel, HyperglassModel, HyperglassUniqueModel from .fields import Action +if t.TYPE_CHECKING: + # Project + from hyperglass.models.api.query import QueryTarget + IPv4PrefixLength = conint(ge=0, le=32) IPv6PrefixLength = conint(ge=0, le=128) IPNetwork = t.Union[IPv4Network, IPv6Network] @@ -82,7 +86,7 @@ def validate_commands(cls, value: t.Union[str, t.List[str]]) -> t.List[str]: return [value] return value - def validate_target(self, target: str) -> bool: + def validate_target(self, target: str, *, multiple: bool) -> bool: """Validate a query target (Placeholder signature).""" raise NotImplementedError( f"{self._validation} rule does not implement a 'validate_target()' method" @@ -119,8 +123,11 @@ def in_range(self, target: IPNetwork) -> bool: return False - def validate_target(self, target: str) -> bool: + def validate_target(self, target: "QueryTarget", *, multiple: bool) -> bool: """Validate an IP address target against this rule's conditions.""" + if isinstance(target, t.List): + self._passed = False + raise InputValidationError("Target must be a single value") try: # Attempt to use IP object factory to create an IP address object valid_target = ip_network(target) @@ -181,23 +188,43 @@ class RuleWithPattern(Rule): _validation: RuleValidation = PrivateAttr("pattern") condition: StrictStr - def validate_target(self, target: str) -> str: + def validate_target(self, target: "QueryTarget", *, multiple: bool) -> str: # noqa: C901 """Validate a string target against configured regex patterns.""" - if self.condition == "*": - pattern = re.compile(".+", re.IGNORECASE) - else: - pattern = re.compile(self.condition, re.IGNORECASE) + def validate_single_value(value: str) -> t.Union[bool, BaseException]: + if self.condition == "*": + pattern = re.compile(".+", re.IGNORECASE) + else: + pattern = re.compile(self.condition, re.IGNORECASE) + is_match = pattern.match(value) - is_match = pattern.match(target) - if is_match and self.action == "permit": + if is_match and self.action == "permit": + return True + elif is_match and self.action == "deny": + return InputValidationError(target=value, error="Denied") + return False + + if isinstance(target, t.List) and multiple: + for result in (validate_single_value(v) for v in target): + if isinstance(result, BaseException): + self._passed = False + raise result + elif result is False: + self._passed = False + return result self._passed = True return True - elif is_match and self.action == "deny": - self._passed = False - raise InputValidationError(target=target, error="Denied") - return False + elif isinstance(target, t.List) and not multiple: + raise InputValidationError("Target must be a single value") + + result = validate_single_value(target) + + if isinstance(result, BaseException): + self._passed = False + raise result + self._passed = result + return result class RuleWithoutValidation(Rule): @@ -206,7 +233,7 @@ class RuleWithoutValidation(Rule): _validation: RuleValidation = PrivateAttr(None) condition: None - def validate_target(self, target: str) -> t.Literal[True]: + def validate_target(self, target: str, *, multiple: bool) -> t.Literal[True]: """Don't validate a target. Always returns `True`.""" self._passed = True return True @@ -229,11 +256,12 @@ class Directive(HyperglassUniqueModel, unique_by=("id", "table_output")): disable_builtins: StrictBool = False table_output: t.Optional[StrictStr] groups: t.List[StrictStr] = [] + multiple: StrictBool = False def validate_target(self, target: str) -> bool: """Validate a target against all configured rules.""" for rule in self.rules: - valid = rule.validate_target(target) + valid = rule.validate_target(target, multiple=self.multiple) if valid is True: return True continue