diff --git a/anta/cli/exec/utils.py b/anta/cli/exec/utils.py index 070fcc7b4..3e3ab511f 100644 --- a/anta/cli/exec/utils.py +++ b/anta/cli/exec/utils.py @@ -18,6 +18,7 @@ from click.exceptions import UsageError from httpx import ConnectError, HTTPError +from anta.custom_types import REGEXP_PATH_MARKERS from anta.device import AntaDevice, AsyncEOSDevice from anta.models import AntaCommand @@ -60,7 +61,7 @@ async def collect_commands( async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "text"]) -> None: outdir = Path() / root_dir / dev.name / outformat outdir.mkdir(parents=True, exist_ok=True) - safe_command = re.sub(r"[\\\/\s]", "_", command) + safe_command = re.sub(rf"{REGEXP_PATH_MARKERS}", "_", command) c = AntaCommand(command=command, ofmt=outformat) await dev.collect(c) if not c.collected: diff --git a/anta/custom_types.py b/anta/custom_types.py index 48711c2ce..a0a0631d0 100644 --- a/anta/custom_types.py +++ b/anta/custom_types.py @@ -9,6 +9,31 @@ from pydantic import Field from pydantic.functional_validators import AfterValidator, BeforeValidator +# Regular Expression definition +# TODO: make this configurable - with an env var maybe? +REGEXP_EOS_BLACKLIST_CMDS = [r"^reload.*", r"^conf\w*\s*(terminal|session)*", r"^wr\w*\s*\w+"] +"""List of regular expressions to blacklist from eos commands.""" +REGEXP_PATH_MARKERS = r"[\\\/\s]" +"""Match directory path from string.""" +REGEXP_INTERFACE_ID = r"\d+(\/\d+)*(\.\d+)?" +"""Match Interface ID lilke 1/1.1.""" +REGEXP_TYPE_EOS_INTERFACE = r"^(Dps|Ethernet|Fabric|Loopback|Management|Port-Channel|Tunnel|Vlan|Vxlan)[0-9]+(\/[0-9]+)*(\.[0-9]+)?$" +"""Match EOS interface types like Ethernet1/1, Vlan1, Loopback1, etc.""" +REGEXP_TYPE_VXLAN_SRC_INTERFACE = r"^(Loopback)([0-9]|[1-9][0-9]{1,2}|[1-7][0-9]{3}|8[01][0-9]{2}|819[01])$" +"""Match Vxlan source interface like Loopback10.""" +REGEXP_TYPE_HOSTNAME = r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$" +"""Match hostname like `my-hostname`, `my-hostname-1`, `my-hostname-1-2`.""" + +# Regexp BGP AFI/SAFI +REGEXP_BGP_L2VPN_AFI = r"\b(l2[\s\-]?vpn[\s\-]?evpn)\b" +"""Match L2VPN EVPN AFI.""" +REGEXP_BGP_IPV4_MPLS_LABELS = r"\b(ipv4[\s\-]?mpls[\s\-]?label(s)?)\b" +"""Match IPv4 MPLS Labels.""" +REGEX_BGP_IPV4_MPLS_VPN = r"\b(ipv4[\s\-]?mpls[\s\-]?vpn)\b" +"""Match IPv4 MPLS VPN.""" +REGEX_BGP_IPV4_UNICAST = r"\b(ipv4[\s\-]?uni[\s\-]?cast)\b" +"""Match IPv4 Unicast.""" + def aaa_group_prefix(v: str) -> str: """Prefix the AAA method with 'group' if it is known.""" @@ -24,7 +49,7 @@ def interface_autocomplete(v: str) -> str: - `po` will be changed to `Port-Channel` - `lo` will be changed to `Loopback` """ - intf_id_re = re.compile(r"\d+(\/\d+)*(\.\d+)?") + intf_id_re = re.compile(REGEXP_INTERFACE_ID) m = intf_id_re.search(v) if m is None: msg = f"Could not parse interface ID in interface '{v}'" @@ -46,7 +71,7 @@ def interface_case_sensitivity(v: str) -> str: - loopback -> Loopback """ - if isinstance(v, str) and len(v) > 0 and not v[0].isupper(): + if isinstance(v, str) and v != "" and not v[0].isupper(): return f"{v[0].upper()}{v[1:]}" return v @@ -63,10 +88,10 @@ def bgp_multiprotocol_capabilities_abbreviations(value: str) -> str: """ patterns = { - r"\b(l2[\s\-]?vpn[\s\-]?evpn)\b": "l2VpnEvpn", - r"\bipv4[\s_-]?mpls[\s_-]?label(s)?\b": "ipv4MplsLabels", - r"\bipv4[\s_-]?mpls[\s_-]?vpn\b": "ipv4MplsVpn", - r"\bipv4[\s_-]?uni[\s_-]?cast\b": "ipv4Unicast", + REGEXP_BGP_L2VPN_AFI: "l2VpnEvpn", + REGEXP_BGP_IPV4_MPLS_LABELS: "ipv4MplsLabels", + REGEX_BGP_IPV4_MPLS_VPN: "ipv4MplsVpn", + REGEX_BGP_IPV4_UNICAST: "ipv4Unicast", } for pattern, replacement in patterns.items(): @@ -97,7 +122,7 @@ def validate_regex(value: str) -> str: Vni = Annotated[int, Field(ge=1, le=16777215)] Interface = Annotated[ str, - Field(pattern=r"^(Dps|Ethernet|Fabric|Loopback|Management|Port-Channel|Tunnel|Vlan|Vxlan)[0-9]+(\/[0-9]+)*(\.[0-9]+)?$"), + Field(pattern=REGEXP_TYPE_EOS_INTERFACE), BeforeValidator(interface_autocomplete), BeforeValidator(interface_case_sensitivity), ] @@ -109,7 +134,7 @@ def validate_regex(value: str) -> str: ] VxlanSrcIntf = Annotated[ str, - Field(pattern=r"^(Loopback)([0-9]|[1-9][0-9]{1,2}|[1-7][0-9]{3}|8[01][0-9]{2}|819[01])$"), + Field(pattern=REGEXP_TYPE_VXLAN_SRC_INTERFACE), BeforeValidator(interface_autocomplete), BeforeValidator(interface_case_sensitivity), ] @@ -139,6 +164,6 @@ def validate_regex(value: str) -> str: Percent = Annotated[float, Field(ge=0.0, le=100.0)] PositiveInteger = Annotated[int, Field(ge=0)] Revision = Annotated[int, Field(ge=1, le=99)] -Hostname = Annotated[str, Field(pattern=r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$")] +Hostname = Annotated[str, Field(pattern=REGEXP_TYPE_HOSTNAME)] Port = Annotated[int, Field(ge=1, le=65535)] RegexString = Annotated[str, AfterValidator(validate_regex)] diff --git a/anta/models.py b/anta/models.py index 20338e7f6..b1a473439 100644 --- a/anta/models.py +++ b/anta/models.py @@ -16,7 +16,7 @@ from pydantic import BaseModel, ConfigDict, ValidationError, create_model from anta import GITHUB_SUGGESTION -from anta.custom_types import Revision +from anta.custom_types import REGEXP_EOS_BLACKLIST_CMDS, Revision from anta.logger import anta_log_exception, exc_to_str from anta.result_manager.models import TestResult @@ -32,9 +32,6 @@ # This would imply overhead to define classes # https://stackoverflow.com/questions/74103528/type-hinting-an-instance-of-a-nested-class -# TODO: make this configurable - with an env var maybe? -BLACKLIST_REGEX = [r"^reload.*", r"^conf\w*\s*(terminal|session)*", r"^wr\w*\s*\w+"] - logger = logging.getLogger(__name__) @@ -515,12 +512,12 @@ def blocked(self) -> bool: """Check if CLI commands contain a blocked keyword.""" state = False for command in self.instance_commands: - for pattern in BLACKLIST_REGEX: + for pattern in REGEXP_EOS_BLACKLIST_CMDS: if re.match(pattern, command.command): self.logger.error( "Command <%s> is blocked for security reason matching %s", command.command, - BLACKLIST_REGEX, + REGEXP_EOS_BLACKLIST_CMDS, ) self.result.is_error(f"<{command.command}> is blocked for security reason") state = True diff --git a/tests/units/test_custom_types.py b/tests/units/test_custom_types.py index 7f6d17ce3..8119849a6 100644 --- a/tests/units/test_custom_types.py +++ b/tests/units/test_custom_types.py @@ -10,9 +10,203 @@ from __future__ import annotations +import re + import pytest -from anta.custom_types import bgp_multiprotocol_capabilities_abbreviations, interface_autocomplete +from anta.custom_types import ( + REGEX_BGP_IPV4_MPLS_VPN, + REGEX_BGP_IPV4_UNICAST, + REGEXP_BGP_IPV4_MPLS_LABELS, + REGEXP_BGP_L2VPN_AFI, + REGEXP_EOS_BLACKLIST_CMDS, + REGEXP_INTERFACE_ID, + REGEXP_PATH_MARKERS, + REGEXP_TYPE_EOS_INTERFACE, + REGEXP_TYPE_HOSTNAME, + REGEXP_TYPE_VXLAN_SRC_INTERFACE, + aaa_group_prefix, + bgp_multiprotocol_capabilities_abbreviations, + interface_autocomplete, + interface_case_sensitivity, +) + +# ------------------------------------------------------------------------------ +# TEST custom_types.py regular expressions +# ------------------------------------------------------------------------------ + + +def test_regexp_path_markers() -> None: + """Test REGEXP_PATH_MARKERS.""" + # Test strings that should match the pattern + assert re.search(REGEXP_PATH_MARKERS, "show/bgp/interfaces") is not None + assert re.search(REGEXP_PATH_MARKERS, "show\\bgp") is not None + assert re.search(REGEXP_PATH_MARKERS, "show bgp") is not None + + # Test strings that should not match the pattern + assert re.search(REGEXP_PATH_MARKERS, "aaaa") is None + assert re.search(REGEXP_PATH_MARKERS, "11111") is None + assert re.search(REGEXP_PATH_MARKERS, ".[]?<>") is None + + +def test_regexp_bgp_l2vpn_afi() -> None: + """Test REGEXP_BGP_L2VPN_AFI.""" + # Test strings that should match the pattern + assert re.search(REGEXP_BGP_L2VPN_AFI, "l2vpn-evpn") is not None + assert re.search(REGEXP_BGP_L2VPN_AFI, "l2 vpn evpn") is not None + assert re.search(REGEXP_BGP_L2VPN_AFI, "l2-vpn evpn") is not None + assert re.search(REGEXP_BGP_L2VPN_AFI, "l2vpn evpn") is not None + assert re.search(REGEXP_BGP_L2VPN_AFI, "l2vpnevpn") is not None + assert re.search(REGEXP_BGP_L2VPN_AFI, "l2 vpnevpn") is not None + + # Test strings that should not match the pattern + assert re.search(REGEXP_BGP_L2VPN_AFI, "al2vpn evpn") is None + assert re.search(REGEXP_BGP_L2VPN_AFI, "l2vpn-evpna") is None + + +def test_regexp_bgp_ipv4_mpls_labels() -> None: + """Test REGEXP_BGP_IPV4_MPLS_LABELS.""" + assert re.search(REGEXP_BGP_IPV4_MPLS_LABELS, "ipv4-mpls-label") is not None + assert re.search(REGEXP_BGP_IPV4_MPLS_LABELS, "ipv4 mpls labels") is not None + assert re.search(REGEXP_BGP_IPV4_MPLS_LABELS, "ipv4Mplslabel") is None + + +def test_regex_bgp_ipv4_mpls_vpn() -> None: + """Test REGEX_BGP_IPV4_MPLS_VPN.""" + assert re.search(REGEX_BGP_IPV4_MPLS_VPN, "ipv4-mpls-vpn") is not None + assert re.search(REGEX_BGP_IPV4_MPLS_VPN, "ipv4_mplsvpn") is None + + +def test_regex_bgp_ipv4_unicast() -> None: + """Test REGEX_BGP_IPV4_UNICAST.""" + assert re.search(REGEX_BGP_IPV4_UNICAST, "ipv4-uni-cast") is not None + assert re.search(REGEX_BGP_IPV4_UNICAST, "ipv4+unicast") is None + + +def test_regexp_type_interface_id() -> None: + """Test REGEXP_INTERFACE_ID.""" + intf_id_re = re.compile(f"{REGEXP_INTERFACE_ID}") + + # Test strings that should match the pattern + assert intf_id_re.search("123") is not None + assert intf_id_re.search("123/456") is not None + assert intf_id_re.search("123.456") is not None + assert intf_id_re.search("123/456.789") is not None + + +def test_regexp_type_eos_interface() -> None: + """Test REGEXP_TYPE_EOS_INTERFACE.""" + # Test strings that should match the pattern + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Ethernet0") is not None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Vlan100") is not None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Port-Channel1/0") is not None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Loopback0.1") is not None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Management0/0/0") is not None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Tunnel1") is not None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Vxlan1") is not None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Fabric1") is not None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Dps1") is not None + + # Test strings that should not match the pattern + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Ethernet") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Vlan") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Port-Channel") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Loopback.") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Management/") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Tunnel") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Vxlan") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Fabric") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Dps") is None + + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Ethernet1/a") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Port-Channel-100") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Loopback.10") is None + assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Management/10") is None + + +def test_regexp_type_vxlan_src_interface() -> None: + """Test REGEXP_TYPE_VXLAN_SRC_INTERFACE.""" + # Test strings that should match the pattern + assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback0") is not None + assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback1") is not None + assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback99") is not None + assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback100") is not None + assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback8190") is not None + assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback8199") is not None + + # Test strings that should not match the pattern + assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback") is None + assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback9001") is None + assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback9000") is None + + +def test_regexp_type_hostname() -> None: + """Test REGEXP_TYPE_HOSTNAME.""" + # Test strings that should match the pattern + assert re.match(REGEXP_TYPE_HOSTNAME, "hostname") is not None + assert re.match(REGEXP_TYPE_HOSTNAME, "hostname.com") is not None + assert re.match(REGEXP_TYPE_HOSTNAME, "host-name.com") is not None + assert re.match(REGEXP_TYPE_HOSTNAME, "host.name.com") is not None + assert re.match(REGEXP_TYPE_HOSTNAME, "host-name1.com") is not None + + # Test strings that should not match the pattern + assert re.match(REGEXP_TYPE_HOSTNAME, "-hostname.com") is None + assert re.match(REGEXP_TYPE_HOSTNAME, ".hostname.com") is None + assert re.match(REGEXP_TYPE_HOSTNAME, "hostname-.com") is None + assert re.match(REGEXP_TYPE_HOSTNAME, "hostname..com") is None + + +@pytest.mark.parametrize( + ("test_string", "expected"), + [ + ("reload", True), # matches "^reload.*" + ("reload now", True), # matches "^reload.*" + ("configure terminal", True), # matches "^conf\w*\s*(terminal|session)*" + ("conf t", True), # matches "^conf\w*\s*(terminal|session)*" + ("write memory", True), # matches "^wr\w*\s*\w+" + ("wr mem", True), # matches "^wr\w*\s*\w+" + ("show running-config", False), # does not match any regex + ("no shutdown", False), # does not match any regex + ("", False), # empty string does not match any regex + ], +) +def test_regexp_eos_blacklist_cmds(test_string: str, expected: bool) -> None: + """Test REGEXP_EOS_BLACKLIST_CMDS.""" + + def matches_any_regex(string: str, regex_list: list[str]) -> bool: + """ + Check if a string matches at least one regular expression in a list. + + :param string: The string to check. + :param regex_list: A list of regular expressions. + :return: True if the string matches at least one regular expression, False otherwise. + """ + return any(re.match(regex, string) for regex in regex_list) + + assert matches_any_regex(test_string, REGEXP_EOS_BLACKLIST_CMDS) == expected + + +# ------------------------------------------------------------------------------ +# TEST custom_types.py functions +# ------------------------------------------------------------------------------ + + +def test_interface_autocomplete_success() -> None: + """Test interface_autocomplete with valid inputs.""" + assert interface_autocomplete("et1") == "Ethernet1" + assert interface_autocomplete("et1/1") == "Ethernet1/1" + assert interface_autocomplete("et1.1") == "Ethernet1.1" + assert interface_autocomplete("et1/1.1") == "Ethernet1/1.1" + assert interface_autocomplete("eth2") == "Ethernet2" + assert interface_autocomplete("po3") == "Port-Channel3" + assert interface_autocomplete("lo4") == "Loopback4" + + +def test_interface_autocomplete_no_alias() -> None: + """Test interface_autocomplete with inputs that don't have aliases.""" + assert interface_autocomplete("GigabitEthernet1") == "GigabitEthernet1" + assert interface_autocomplete("Vlan10") == "Vlan10" + assert interface_autocomplete("Tunnel100") == "Tunnel100" def test_interface_autocomplete_failure() -> None: @@ -34,3 +228,37 @@ def test_interface_autocomplete_failure() -> None: def test_bgp_multiprotocol_capabilities_abbreviationsh(str_input: str, expected_output: str) -> None: """Test bgp_multiprotocol_capabilities_abbreviations.""" assert bgp_multiprotocol_capabilities_abbreviations(str_input) == expected_output + + +def test_aaa_group_prefix_known_method() -> None: + """Test aaa_group_prefix with a known method.""" + assert aaa_group_prefix("local") == "local" + assert aaa_group_prefix("none") == "none" + assert aaa_group_prefix("logging") == "logging" + + +def test_aaa_group_prefix_unknown_method() -> None: + """Test aaa_group_prefix with an unknown method.""" + assert aaa_group_prefix("demo") == "group demo" + assert aaa_group_prefix("group1") == "group group1" + + +def test_interface_case_sensitivity_lowercase() -> None: + """Test interface_case_sensitivity with lowercase inputs.""" + assert interface_case_sensitivity("ethernet") == "Ethernet" + assert interface_case_sensitivity("vlan") == "Vlan" + assert interface_case_sensitivity("loopback") == "Loopback" + + +def test_interface_case_sensitivity_mixed_case() -> None: + """Test interface_case_sensitivity with mixed case inputs.""" + assert interface_case_sensitivity("Ethernet") == "Ethernet" + assert interface_case_sensitivity("Vlan") == "Vlan" + assert interface_case_sensitivity("Loopback") == "Loopback" + + +def test_interface_case_sensitivity_uppercase() -> None: + """Test interface_case_sensitivity with uppercase inputs.""" + assert interface_case_sensitivity("ETHERNET") == "ETHERNET" + assert interface_case_sensitivity("VLAN") == "VLAN" + assert interface_case_sensitivity("LOOPBACK") == "LOOPBACK"