Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(anta): Optimize AntaTemplate render performance. #654

Merged
merged 23 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 65 additions & 21 deletions anta/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import hashlib
import inspect
import logging
import re
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -44,6 +45,7 @@ class AntaParamsBaseModel(BaseModel):

model_config = ConfigDict(extra="forbid")

# TODO: is this still needed?
if not TYPE_CHECKING:
# Following pydantic declaration and keeping __getattr__ only when TYPE_CHECKING is false.
# Disabling 1 Dynamically typed expressions (typing.Any) are disallowed in `__getattr__
Expand All @@ -56,7 +58,35 @@ def __getattr__(self, item: str) -> Any:
return None


class AntaTemplate(BaseModel):
class SingletonArgs(type):
"""SingletonArgs class.

Used as metaclass for AntaTemplates to create only one instance of each AntaTemplate with a given set of input arguments.

https://gist.github.com/wowkin2/3af15bfbf197a14a2b0b2488a1e8c787
"""

_instances: ClassVar[dict[str, SingletonArgs]] = {}
_init: ClassVar[dict[SingletonArgs, str]] = {}

def __init__(cls, name: str, bases: list[type], dct: dict[str, Any]) -> None: # noqa: ARG003
"""Initialize the singleton.

TODO
"""
# pylint: disable=unused-argument
cls._init[cls] = dct.get("__init__")

def __call__(cls, *args: Any, **kwargs: Any) -> SingletonArgs:
"""__call__ function."""
init = cls._init[cls]
key = (cls, inspect.Signature.bind(inspect.Signature(init), None, *args, **kwargs)) if init is not None else cls
if key not in cls.instances:
cls._instances[key] = super().__call__(*args, **kwargs)
return cls._instances[key]


class AntaTemplate:
"""Class to define a command template as Python f-string.

Can render a command from parameters.
Expand All @@ -71,11 +101,37 @@ class AntaTemplate(BaseModel):

"""

template: str
version: Literal[1, "latest"] = "latest"
revision: Revision | None = None
ofmt: Literal["json", "text"] = "json"
use_cache: bool = True
# pylint: disable=too-few-public-methods

__metaclass__ = SingletonArgs
gmuloc marked this conversation as resolved.
Show resolved Hide resolved

def __init__( # noqa: PLR0913
self,
template: str,
version: Literal[1, "latest"] = "latest",
revision: Revision | None = None,
ofmt: Literal["json", "text"] = "json",
*,
use_cache: bool = True,
) -> None:
# pylint: disable=too-many-arguments
self.template = template
self.version = version
self.revision = revision
self.ofmt = ofmt
self.use_cache = use_cache

# Create the model only once per Template in the Singleton instance
field_names = [fname for _, fname, _, _ in Formatter().parse(self.template) if fname]
# Extracting the type from the params based on the expected field_names from the template
# All strings for now..
fields: dict[str, Any] = {key: (str | int | bool | Any, ...) for key in field_names}
# Accepting ParamsSchema as non lowercase variable
self.params_schema = create_model(
"ParamsSchema",
__base__=AntaParamsBaseModel,
**fields,
)

def render(self, **params: str | int | bool) -> AntaCommand:
"""Render an AntaCommand from an AntaTemplate instance.
Expand All @@ -93,25 +149,14 @@ def render(self, **params: str | int | bool) -> AntaCommand:
AntaTemplate instance.

"""
# Create params schema on the fly
field_names = [fname for _, fname, _, _ in Formatter().parse(self.template) if fname]
# Extracting the type from the params based on the expected field_names from the template
fields: dict[str, Any] = {key: (type(params.get(key)), ...) for key in field_names}
# Accepting ParamsSchema as non lowercase variable
ParamsSchema = create_model( # noqa: N806
"ParamsSchema",
__base__=AntaParamsBaseModel,
**fields,
)

try:
return AntaCommand(
command=self.template.format(**params),
ofmt=self.ofmt,
version=self.version,
revision=self.revision,
template=self,
params=ParamsSchema(**params),
params=self.params_schema(**params),
use_cache=self.use_cache,
)
except KeyError as e:
Expand Down Expand Up @@ -146,6 +191,8 @@ class AntaCommand(BaseModel):

"""

model_config = ConfigDict(arbitrary_types_allowed=True)

command: str
version: Literal[1, "latest"] = "latest"
revision: Revision | None = None
Expand Down Expand Up @@ -590,9 +637,6 @@ async def wrapper(
self.result.is_error(message=exc_to_str(e))

# TODO: find a correct way to time test execution
# msg = f"Executing test {self.name} on device {self.device.name} took {t.time}" # noqa: ERA001
# self.logger.debug(msg) # noqa: ERA001

AntaTest.update_progress()
return self.result

Expand Down
6 changes: 3 additions & 3 deletions anta/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from anta import GITHUB_SUGGESTION
from anta.logger import anta_log_exception, exc_to_str
from anta.models import AntaTest
from anta.tools import Catchtime
from anta.tools import Catchtime, cprofile

if TYPE_CHECKING:
from collections.abc import Coroutine
Expand Down Expand Up @@ -125,7 +125,7 @@ def prepare_tests(

Returns
-------
defaultdict[AntaDevice, set[AntaTestDefinition]] | None: A mapping of devices to the tests to run or None if there are no tests to run.
A mapping of devices to the tests to run or None if there are no tests to run.
"""
# Build indexes for the catalog. If `tests` is set, filter the indexes based on these tests
catalog.build_indexes(filtered_tests=tests)
Expand Down Expand Up @@ -186,7 +186,7 @@ def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinitio
anta_log_exception(e, message, logger)
return coros


@cprofile
async def main( # noqa: PLR0913
manager: ResultManager,
inventory: AntaInventory,
Expand Down
60 changes: 59 additions & 1 deletion anta/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@

from __future__ import annotations

import cProfile
import os
import pstats
from functools import wraps
from time import perf_counter
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast

from anta.logger import format_td

Expand All @@ -20,6 +24,8 @@
else:
from typing_extensions import Self

F = TypeVar("F", bound=Callable[..., Any])


def get_failed_logs(expected_output: dict[Any, Any], actual_output: dict[Any, Any]) -> str:
"""Get the failed log for a test.
Expand Down Expand Up @@ -288,3 +294,55 @@ def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException
self.time = format_td(self.raw_time, 3)
if self.logger and self.message:
self.logger.info("%s completed in: %s.", self.message, self.time)


def cprofile(sort_by: str = "cumtime") -> Callable[[F], F]:
"""Profile a function with cProfile.

profile is conditionally enabled based on the presence of ANTA_CPROFILE environment variable.
Expect to decorate an async function.

Args:
----
sort_by (str): The criterion to sort the profiling results. Default is 'cumtime'.

Returns
-------
Callable: The decorated function with conditional profiling.
"""

def decorator(func: F) -> F:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
"""Enable cProfile or not.

If `ANTA_CPROFILE` is set, cProfile is enabled and dumps the stats to the file.

Args:
----
*args: Arbitrary positional arguments.
**kwargs: Arbitrary keyword arguments.

Returns
-------
The result of the function call.
"""
cprofile_file = os.environ.get("ANTA_CPROFILE")

if cprofile_file is not None:
profiler = cProfile.Profile()
profiler.enable()

try:
result = await func(*args, **kwargs)
finally:
if cprofile_file is not None:
profiler.disable()
stats = pstats.Stats(profiler).sort_stats(sort_by)
stats.dump_stats(cprofile_file)

return result

return cast(F, wrapper)

return decorator
Loading