Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Extract regular expression to support regexp unit tests #686

Merged
merged 4 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion anta/cli/exec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from click.exceptions import UsageError
from httpx import ConnectError, HTTPError

from anta.custom_types import REGEXP_PATH_MARKERS
from anta.device import AntaDevice, AsyncEOSDevice
from anta.models import AntaCommand

Expand Down Expand Up @@ -60,7 +61,7 @@ async def collect_commands(
async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "text"]) -> None:
outdir = Path() / root_dir / dev.name / outformat
outdir.mkdir(parents=True, exist_ok=True)
safe_command = re.sub(r"[\\\/\s]", "_", command)
safe_command = re.sub(rf"{REGEXP_PATH_MARKERS}", "_", command)
c = AntaCommand(command=command, ofmt=outformat)
await dev.collect(c)
if not c.collected:
Expand Down
43 changes: 34 additions & 9 deletions anta/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,31 @@
from pydantic import Field
from pydantic.functional_validators import AfterValidator, BeforeValidator

# Regular Expression definition
# TODO: make this configurable - with an env var maybe?
REGEXP_EOS_BLACKLIST_CMDS = [r"^reload.*", r"^conf\w*\s*(terminal|session)*", r"^wr\w*\s*\w+"]
"""List of regular expressions to blacklist from eos commands."""
REGEXP_PATH_MARKERS = r"[\\\/\s]"
"""Match directory path from string."""
REGEXP_INTERFACE_ID = r"\d+(\/\d+)*(\.\d+)?"
"""Match Interface ID lilke 1/1.1."""
REGEXP_TYPE_EOS_INTERFACE = r"^(Dps|Ethernet|Fabric|Loopback|Management|Port-Channel|Tunnel|Vlan|Vxlan)[0-9]+(\/[0-9]+)*(\.[0-9]+)?$"
"""Match EOS interface types like Ethernet1/1, Vlan1, Loopback1, etc."""
REGEXP_TYPE_VXLAN_SRC_INTERFACE = r"^(Loopback)([0-9]|[1-9][0-9]{1,2}|[1-7][0-9]{3}|8[01][0-9]{2}|819[01])$"
"""Match Vxlan source interface like Loopback10."""
REGEXP_TYPE_HOSTNAME = r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$"
"""Match hostname like `my-hostname`, `my-hostname-1`, `my-hostname-1-2`."""

# Regexp BGP AFI/SAFI
REGEXP_BGP_L2VPN_AFI = r"\b(l2[\s\-]?vpn[\s\-]?evpn)\b"
"""Match L2VPN EVPN AFI."""
REGEXP_BGP_IPV4_MPLS_LABELS = r"\b(ipv4[\s\-]?mpls[\s\-]?label(s)?)\b"
"""Match IPv4 MPLS Labels."""
REGEX_BGP_IPV4_MPLS_VPN = r"\b(ipv4[\s\-]?mpls[\s\-]?vpn)\b"
"""Match IPv4 MPLS VPN."""
REGEX_BGP_IPV4_UNICAST = r"\b(ipv4[\s\-]?uni[\s\-]?cast)\b"
"""Match IPv4 Unicast."""


def aaa_group_prefix(v: str) -> str:
"""Prefix the AAA method with 'group' if it is known."""
Expand All @@ -24,7 +49,7 @@ def interface_autocomplete(v: str) -> str:
- `po` will be changed to `Port-Channel`
- `lo` will be changed to `Loopback`
"""
intf_id_re = re.compile(r"\d+(\/\d+)*(\.\d+)?")
intf_id_re = re.compile(REGEXP_INTERFACE_ID)
m = intf_id_re.search(v)
if m is None:
msg = f"Could not parse interface ID in interface '{v}'"
Expand All @@ -46,7 +71,7 @@ def interface_case_sensitivity(v: str) -> str:
- loopback -> Loopback

"""
if isinstance(v, str) and len(v) > 0 and not v[0].isupper():
if isinstance(v, str) and v != "" and not v[0].isupper():
return f"{v[0].upper()}{v[1:]}"
return v

Expand All @@ -63,10 +88,10 @@ def bgp_multiprotocol_capabilities_abbreviations(value: str) -> str:

"""
patterns = {
r"\b(l2[\s\-]?vpn[\s\-]?evpn)\b": "l2VpnEvpn",
r"\bipv4[\s_-]?mpls[\s_-]?label(s)?\b": "ipv4MplsLabels",
r"\bipv4[\s_-]?mpls[\s_-]?vpn\b": "ipv4MplsVpn",
r"\bipv4[\s_-]?uni[\s_-]?cast\b": "ipv4Unicast",
REGEXP_BGP_L2VPN_AFI: "l2VpnEvpn",
REGEXP_BGP_IPV4_MPLS_LABELS: "ipv4MplsLabels",
REGEX_BGP_IPV4_MPLS_VPN: "ipv4MplsVpn",
REGEX_BGP_IPV4_UNICAST: "ipv4Unicast",
}

for pattern, replacement in patterns.items():
Expand Down Expand Up @@ -97,7 +122,7 @@ def validate_regex(value: str) -> str:
Vni = Annotated[int, Field(ge=1, le=16777215)]
Interface = Annotated[
str,
Field(pattern=r"^(Dps|Ethernet|Fabric|Loopback|Management|Port-Channel|Tunnel|Vlan|Vxlan)[0-9]+(\/[0-9]+)*(\.[0-9]+)?$"),
Field(pattern=REGEXP_TYPE_EOS_INTERFACE),
BeforeValidator(interface_autocomplete),
BeforeValidator(interface_case_sensitivity),
]
Expand All @@ -109,7 +134,7 @@ def validate_regex(value: str) -> str:
]
VxlanSrcIntf = Annotated[
str,
Field(pattern=r"^(Loopback)([0-9]|[1-9][0-9]{1,2}|[1-7][0-9]{3}|8[01][0-9]{2}|819[01])$"),
Field(pattern=REGEXP_TYPE_VXLAN_SRC_INTERFACE),
BeforeValidator(interface_autocomplete),
BeforeValidator(interface_case_sensitivity),
]
Expand Down Expand Up @@ -139,6 +164,6 @@ def validate_regex(value: str) -> str:
Percent = Annotated[float, Field(ge=0.0, le=100.0)]
PositiveInteger = Annotated[int, Field(ge=0)]
Revision = Annotated[int, Field(ge=1, le=99)]
Hostname = Annotated[str, Field(pattern=r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$")]
Hostname = Annotated[str, Field(pattern=REGEXP_TYPE_HOSTNAME)]
Port = Annotated[int, Field(ge=1, le=65535)]
RegexString = Annotated[str, AfterValidator(validate_regex)]
9 changes: 3 additions & 6 deletions anta/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import BaseModel, ConfigDict, ValidationError, create_model

from anta import GITHUB_SUGGESTION
from anta.custom_types import Revision
from anta.custom_types import REGEXP_EOS_BLACKLIST_CMDS, Revision
from anta.logger import anta_log_exception, exc_to_str
from anta.result_manager.models import TestResult

Expand All @@ -32,9 +32,6 @@
# This would imply overhead to define classes
# https://stackoverflow.com/questions/74103528/type-hinting-an-instance-of-a-nested-class

# TODO: make this configurable - with an env var maybe?
BLACKLIST_REGEX = [r"^reload.*", r"^conf\w*\s*(terminal|session)*", r"^wr\w*\s*\w+"]

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -515,12 +512,12 @@ def blocked(self) -> bool:
"""Check if CLI commands contain a blocked keyword."""
state = False
for command in self.instance_commands:
for pattern in BLACKLIST_REGEX:
for pattern in REGEXP_EOS_BLACKLIST_CMDS:
if re.match(pattern, command.command):
self.logger.error(
"Command <%s> is blocked for security reason matching %s",
command.command,
BLACKLIST_REGEX,
REGEXP_EOS_BLACKLIST_CMDS,
)
self.result.is_error(f"<{command.command}> is blocked for security reason")
state = True
Expand Down
230 changes: 229 additions & 1 deletion tests/units/test_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,203 @@

from __future__ import annotations

import re

import pytest

from anta.custom_types import bgp_multiprotocol_capabilities_abbreviations, interface_autocomplete
from anta.custom_types import (
REGEX_BGP_IPV4_MPLS_VPN,
REGEX_BGP_IPV4_UNICAST,
REGEXP_BGP_IPV4_MPLS_LABELS,
REGEXP_BGP_L2VPN_AFI,
REGEXP_EOS_BLACKLIST_CMDS,
REGEXP_INTERFACE_ID,
REGEXP_PATH_MARKERS,
REGEXP_TYPE_EOS_INTERFACE,
REGEXP_TYPE_HOSTNAME,
REGEXP_TYPE_VXLAN_SRC_INTERFACE,
aaa_group_prefix,
bgp_multiprotocol_capabilities_abbreviations,
interface_autocomplete,
interface_case_sensitivity,
)

# ------------------------------------------------------------------------------
# TEST custom_types.py regular expressions
# ------------------------------------------------------------------------------


def test_regexp_path_markers() -> None:
"""Test REGEXP_PATH_MARKERS."""
# Test strings that should match the pattern
assert re.search(REGEXP_PATH_MARKERS, "show/bgp/interfaces") is not None
assert re.search(REGEXP_PATH_MARKERS, "show\\bgp") is not None
assert re.search(REGEXP_PATH_MARKERS, "show bgp") is not None

# Test strings that should not match the pattern
assert re.search(REGEXP_PATH_MARKERS, "aaaa") is None
assert re.search(REGEXP_PATH_MARKERS, "11111") is None
assert re.search(REGEXP_PATH_MARKERS, ".[]?<>") is None


def test_regexp_bgp_l2vpn_afi() -> None:
"""Test REGEXP_BGP_L2VPN_AFI."""
# Test strings that should match the pattern
assert re.search(REGEXP_BGP_L2VPN_AFI, "l2vpn-evpn") is not None
assert re.search(REGEXP_BGP_L2VPN_AFI, "l2 vpn evpn") is not None
assert re.search(REGEXP_BGP_L2VPN_AFI, "l2-vpn evpn") is not None
assert re.search(REGEXP_BGP_L2VPN_AFI, "l2vpn evpn") is not None
assert re.search(REGEXP_BGP_L2VPN_AFI, "l2vpnevpn") is not None
assert re.search(REGEXP_BGP_L2VPN_AFI, "l2 vpnevpn") is not None

# Test strings that should not match the pattern
assert re.search(REGEXP_BGP_L2VPN_AFI, "al2vpn evpn") is None
assert re.search(REGEXP_BGP_L2VPN_AFI, "l2vpn-evpna") is None


def test_regexp_bgp_ipv4_mpls_labels() -> None:
"""Test REGEXP_BGP_IPV4_MPLS_LABELS."""
assert re.search(REGEXP_BGP_IPV4_MPLS_LABELS, "ipv4-mpls-label") is not None
assert re.search(REGEXP_BGP_IPV4_MPLS_LABELS, "ipv4 mpls labels") is not None
assert re.search(REGEXP_BGP_IPV4_MPLS_LABELS, "ipv4Mplslabel") is None


def test_regex_bgp_ipv4_mpls_vpn() -> None:
"""Test REGEX_BGP_IPV4_MPLS_VPN."""
assert re.search(REGEX_BGP_IPV4_MPLS_VPN, "ipv4-mpls-vpn") is not None
assert re.search(REGEX_BGP_IPV4_MPLS_VPN, "ipv4_mplsvpn") is None


def test_regex_bgp_ipv4_unicast() -> None:
"""Test REGEX_BGP_IPV4_UNICAST."""
assert re.search(REGEX_BGP_IPV4_UNICAST, "ipv4-uni-cast") is not None
assert re.search(REGEX_BGP_IPV4_UNICAST, "ipv4+unicast") is None


def test_regexp_type_interface_id() -> None:
"""Test REGEXP_INTERFACE_ID."""
intf_id_re = re.compile(f"{REGEXP_INTERFACE_ID}")

# Test strings that should match the pattern
assert intf_id_re.search("123") is not None
assert intf_id_re.search("123/456") is not None
assert intf_id_re.search("123.456") is not None
assert intf_id_re.search("123/456.789") is not None


def test_regexp_type_eos_interface() -> None:
"""Test REGEXP_TYPE_EOS_INTERFACE."""
# Test strings that should match the pattern
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Ethernet0") is not None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Vlan100") is not None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Port-Channel1/0") is not None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Loopback0.1") is not None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Management0/0/0") is not None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Tunnel1") is not None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Vxlan1") is not None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Fabric1") is not None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Dps1") is not None

# Test strings that should not match the pattern
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Ethernet") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Vlan") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Port-Channel") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Loopback.") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Management/") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Tunnel") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Vxlan") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Fabric") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Dps") is None

assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Ethernet1/a") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Port-Channel-100") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Loopback.10") is None
assert re.match(REGEXP_TYPE_EOS_INTERFACE, "Management/10") is None


def test_regexp_type_vxlan_src_interface() -> None:
"""Test REGEXP_TYPE_VXLAN_SRC_INTERFACE."""
# Test strings that should match the pattern
assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback0") is not None
assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback1") is not None
assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback99") is not None
assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback100") is not None
assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback8190") is not None
assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback8199") is not None

# Test strings that should not match the pattern
assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback") is None
assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback9001") is None
assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, "Loopback9000") is None


def test_regexp_type_hostname() -> None:
"""Test REGEXP_TYPE_HOSTNAME."""
# Test strings that should match the pattern
assert re.match(REGEXP_TYPE_HOSTNAME, "hostname") is not None
assert re.match(REGEXP_TYPE_HOSTNAME, "hostname.com") is not None
assert re.match(REGEXP_TYPE_HOSTNAME, "host-name.com") is not None
assert re.match(REGEXP_TYPE_HOSTNAME, "host.name.com") is not None
assert re.match(REGEXP_TYPE_HOSTNAME, "host-name1.com") is not None

# Test strings that should not match the pattern
assert re.match(REGEXP_TYPE_HOSTNAME, "-hostname.com") is None
assert re.match(REGEXP_TYPE_HOSTNAME, ".hostname.com") is None
assert re.match(REGEXP_TYPE_HOSTNAME, "hostname-.com") is None
assert re.match(REGEXP_TYPE_HOSTNAME, "hostname..com") is None


@pytest.mark.parametrize(
("test_string", "expected"),
[
("reload", True), # matches "^reload.*"
("reload now", True), # matches "^reload.*"
("configure terminal", True), # matches "^conf\w*\s*(terminal|session)*"
("conf t", True), # matches "^conf\w*\s*(terminal|session)*"
("write memory", True), # matches "^wr\w*\s*\w+"
("wr mem", True), # matches "^wr\w*\s*\w+"
("show running-config", False), # does not match any regex
("no shutdown", False), # does not match any regex
("", False), # empty string does not match any regex
],
)
def test_regexp_eos_blacklist_cmds(test_string: str, expected: bool) -> None:
"""Test REGEXP_EOS_BLACKLIST_CMDS."""

def matches_any_regex(string: str, regex_list: list[str]) -> bool:
"""
Check if a string matches at least one regular expression in a list.

:param string: The string to check.
:param regex_list: A list of regular expressions.
:return: True if the string matches at least one regular expression, False otherwise.
"""
return any(re.match(regex, string) for regex in regex_list)

assert matches_any_regex(test_string, REGEXP_EOS_BLACKLIST_CMDS) == expected


# ------------------------------------------------------------------------------
# TEST custom_types.py functions
# ------------------------------------------------------------------------------


def test_interface_autocomplete_success() -> None:
"""Test interface_autocomplete with valid inputs."""
assert interface_autocomplete("et1") == "Ethernet1"
assert interface_autocomplete("et1/1") == "Ethernet1/1"
assert interface_autocomplete("et1.1") == "Ethernet1.1"
assert interface_autocomplete("et1/1.1") == "Ethernet1/1.1"
assert interface_autocomplete("eth2") == "Ethernet2"
assert interface_autocomplete("po3") == "Port-Channel3"
assert interface_autocomplete("lo4") == "Loopback4"


def test_interface_autocomplete_no_alias() -> None:
"""Test interface_autocomplete with inputs that don't have aliases."""
assert interface_autocomplete("GigabitEthernet1") == "GigabitEthernet1"
assert interface_autocomplete("Vlan10") == "Vlan10"
assert interface_autocomplete("Tunnel100") == "Tunnel100"


def test_interface_autocomplete_failure() -> None:
Expand All @@ -34,3 +228,37 @@ def test_interface_autocomplete_failure() -> None:
def test_bgp_multiprotocol_capabilities_abbreviationsh(str_input: str, expected_output: str) -> None:
"""Test bgp_multiprotocol_capabilities_abbreviations."""
assert bgp_multiprotocol_capabilities_abbreviations(str_input) == expected_output


def test_aaa_group_prefix_known_method() -> None:
"""Test aaa_group_prefix with a known method."""
assert aaa_group_prefix("local") == "local"
assert aaa_group_prefix("none") == "none"
assert aaa_group_prefix("logging") == "logging"


def test_aaa_group_prefix_unknown_method() -> None:
"""Test aaa_group_prefix with an unknown method."""
assert aaa_group_prefix("demo") == "group demo"
assert aaa_group_prefix("group1") == "group group1"


def test_interface_case_sensitivity_lowercase() -> None:
"""Test interface_case_sensitivity with lowercase inputs."""
assert interface_case_sensitivity("ethernet") == "Ethernet"
assert interface_case_sensitivity("vlan") == "Vlan"
assert interface_case_sensitivity("loopback") == "Loopback"


def test_interface_case_sensitivity_mixed_case() -> None:
"""Test interface_case_sensitivity with mixed case inputs."""
assert interface_case_sensitivity("Ethernet") == "Ethernet"
assert interface_case_sensitivity("Vlan") == "Vlan"
assert interface_case_sensitivity("Loopback") == "Loopback"


def test_interface_case_sensitivity_uppercase() -> None:
"""Test interface_case_sensitivity with uppercase inputs."""
assert interface_case_sensitivity("ETHERNET") == "ETHERNET"
assert interface_case_sensitivity("VLAN") == "VLAN"
assert interface_case_sensitivity("LOOPBACK") == "LOOPBACK"
Loading