From 6aed48e054b83bb39cfe277314fc85f3948415c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthieu=20T=C3=A2che?= Date: Fri, 27 Oct 2023 14:56:51 +0200 Subject: [PATCH] feat: update CLI to support AntaCatalog --- anta/catalog.py | 15 ++++++--------- anta/cli/__init__.py | 3 ++- anta/cli/check/commands.py | 14 +++----------- anta/cli/nrfu/utils.py | 2 +- anta/cli/utils.py | 12 ++++++++++++ anta/models.py | 23 ++++++++++++++--------- anta/runner.py | 11 ++++------- 7 files changed, 42 insertions(+), 38 deletions(-) diff --git a/anta/catalog.py b/anta/catalog.py index f83d12330..293242960 100644 --- a/anta/catalog.py +++ b/anta/catalog.py @@ -14,8 +14,8 @@ from pydantic import BaseModel, RootModel, model_serializer, model_validator from pydantic.types import ImportString from yaml import safe_load -from anta.device import AntaDevice +from anta.device import AntaDevice from anta.models import AntaTest logger = logging.getLogger(__name__) @@ -89,6 +89,7 @@ def check_tests(cls, data: Any) -> Any: are actually defined in their respective Python module and instantiate Input instances with provided value to validate test inputs. """ + def flatten_modules(data: dict[str, Any], package: str | None = None) -> dict[ModuleType, list[Any]]: """ Allow the user to provide a data structure with nested Python modules. @@ -194,15 +195,11 @@ def check(self: AntaCatalog) -> None: """ if self._data is not None: self.file = AntaCatalogFile(**self._data) - else: - logger.critical("Catalog file has not been parsed thus cannot be checked") - # TODO: custom exception - raise Exception() - if self._tests: - logger.warning(f'Overriding AntaCatalog data from file {self.filename}') + if self._tests: + logger.warning(f"Overriding AntaCatalog data from file {self.filename}") self._tests = [] - for tests in self.file.root.values(): - self._tests.extend(tests) + for tests in self.file.root.values(): + self._tests.extend(tests) def get_tests_by_tags(self, tags: list[str], strict: bool = False) -> list[AntaTestDefinition]: """ diff --git a/anta/cli/__init__.py b/anta/cli/__init__.py index 6cf076ea4..f2e259470 100644 --- a/anta/cli/__init__.py +++ b/anta/cli/__init__.py @@ -21,7 +21,7 @@ from anta.cli.exec import commands as exec_commands from anta.cli.get import commands as get_commands from anta.cli.nrfu import commands as nrfu_commands -from anta.cli.utils import AliasedGroup, IgnoreRequiredWithHelp, parse_catalog, parse_inventory +from anta.cli.utils import AliasedGroup, IgnoreRequiredWithHelp, check_catalog, parse_catalog, parse_inventory from anta.logger import setup_logging from anta.result_manager import ResultManager @@ -150,6 +150,7 @@ def anta( ) def _nrfu(ctx: click.Context, catalog: AntaCatalog) -> None: """Run NRFU against inventory devices""" + check_catalog(ctx, catalog) ctx.obj["catalog"] = catalog ctx.obj["result_manager"] = ResultManager() diff --git a/anta/cli/check/commands.py b/anta/cli/check/commands.py index 2ce0bab3f..ca4339625 100644 --- a/anta/cli/check/commands.py +++ b/anta/cli/check/commands.py @@ -11,13 +11,11 @@ import logging import click -from pydantic import ValidationError from rich.pretty import pretty_repr from anta.catalog import AntaCatalog from anta.cli.console import console -from anta.cli.utils import parse_catalog -from anta.tools.misc import anta_log_exception +from anta.cli.utils import check_catalog, parse_catalog logger = logging.getLogger(__name__) @@ -38,11 +36,5 @@ def catalog(ctx: click.Context, catalog: AntaCatalog) -> None: Check that the catalog is valid """ logger.info(f"Checking syntax of catalog {ctx.obj['catalog_path']}") - try: - catalog.check() - console.print(f"[bold][green]Catalog {ctx.obj['catalog_path']} is valid") - console.print(pretty_repr(catalog.file)) - except ValidationError as e: - console.print(f"[bold][red]Catalog {ctx.obj['catalog_path']} is invalid") - anta_log_exception(e) - ctx.exit(1) + check_catalog(ctx, catalog) + console.print(pretty_repr(catalog.file)) diff --git a/anta/cli/nrfu/utils.py b/anta/cli/nrfu/utils.py index 7275530d0..493651d40 100644 --- a/anta/cli/nrfu/utils.py +++ b/anta/cli/nrfu/utils.py @@ -26,7 +26,7 @@ def print_settings(context: click.Context, report_template: pathlib.Path | None = None, report_output: pathlib.Path | None = None) -> None: """Print ANTA settings before running tests""" - message = f"Running ANTA tests:\n- {context.obj['inventory']}\n- Tests catalog contains {len(context.obj['catalog'])} tests" + message = f"Running ANTA tests:\n- {context.obj['inventory']}\n- Tests catalog contains {len(context.obj['catalog'].tests)} tests" if report_template: message += f"\n- Report template: {report_template}" if report_output: diff --git a/anta/cli/utils.py b/anta/cli/utils.py index 2af3fad66..120f8df34 100644 --- a/anta/cli/utils.py +++ b/anta/cli/utils.py @@ -14,8 +14,10 @@ from typing import TYPE_CHECKING, Any import click +from pydantic import ValidationError from anta.catalog import AntaCatalog +from anta.cli.console import console from anta.inventory import AntaInventory from anta.tools.misc import anta_log_exception @@ -43,6 +45,16 @@ class ExitCode(enum.IntEnum): USAGE_ERROR = 4 +def check_catalog(ctx: click.Context, catalog: AntaCatalog) -> None: + try: + catalog.check() + console.print(f"[bold][green]Catalog {catalog.filename} is valid") + except ValidationError as e: + console.print(f"[bold][red]Catalog {catalog.filename} is invalid") + anta_log_exception(e) + ctx.exit(1) + + def parse_inventory(ctx: click.Context, path: Path) -> AntaInventory: """ Helper function parse an ANTA inventory YAML file diff --git a/anta/models.py b/anta/models.py index 13823496a..3f7f172a1 100644 --- a/anta/models.py +++ b/anta/models.py @@ -317,7 +317,7 @@ class Filters(BaseModel): def __init__( self, device: AntaDevice, - inputs: Optional[dict[str, Any]], + inputs: dict[str, Any] | AntaTest.Input | None, eos_data: Optional[list[dict[Any, Any] | str]] = None, ): """AntaTest Constructor @@ -337,19 +337,24 @@ def __init__( if self.result.result == "unset": self._init_commands(eos_data) - def _init_inputs(self, inputs: Optional[dict[str, Any]]) -> None: + def _init_inputs(self, inputs: dict[str, Any] | AntaTest.Input | None) -> None: """Instantiate the `inputs` instance attribute with an `AntaTest.Input` instance to validate test inputs from defined model. Overwrite result fields based on `ResultOverwrite` input definition. Any input validation error will set this test result status as 'error'.""" - try: - self.inputs = self.Input(**inputs) if inputs is not None else self.Input() - except ValidationError as e: - message = f"{self.__module__}.{self.__class__.__name__}: Inputs are not valid\n{e}" - self.logger.error(message) - self.result.is_error(message=message, exception=e) - return + if inputs is None: + self.inputs = self.Input() + elif isinstance(inputs, AntaTest.Input): + self.inputs = inputs + elif isinstance(inputs, dict): + try: + self.inputs = self.Input(**inputs) + except ValidationError as e: + message = f"{self.__module__}.{self.__class__.__name__}: Inputs are not valid\n{e}" + self.logger.error(message) + self.result.is_error(message=message, exception=e) + return if res_ow := self.inputs.result_overwrite: if res_ow.categories: self.result.categories = res_ow.categories diff --git a/anta/runner.py b/anta/runner.py index cd5918a8d..4ad236385 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -48,6 +48,7 @@ async def main( any: ResultManager object gets updated with the test results. """ + catalog.check() if not catalog.tests: logger.info("The list of tests is empty, exiting") return @@ -71,15 +72,11 @@ async def main( coros = [] - for device, test in itertools.product(devices, catalog.tests): - test_class = test[0] - test_inputs = test[1] - test_filters = test[1].get("filters", None) if test[1] is not None else None - test_tags = test_filters.get("tags", []) if test_filters is not None else [] - if len(test_tags) == 0 or filter_tags(tags_cli=tags, tags_device=device.tags, tags_test=test_tags): + for device, test_definition in itertools.product(devices, catalog.tests): + if len(test_definition.inputs.filters.tags) == 0 or filter_tags(tags_cli=tags, tags_device=device.tags, tags_test=test_definition.inputs.filters.tags): try: # Instantiate AntaTest object - test_instance = test_class(device=device, inputs=test_inputs) + test_instance = test_definition.test(device=device, inputs=test_definition.inputs) coros.append(test_instance.test(eos_data=None)) except Exception as e: # pylint: disable=broad-exception-caught message = "Error when creating ANTA tests"