Skip to content

Commit

Permalink
refactor: Witchcraft
Browse files Browse the repository at this point in the history
  • Loading branch information
gmuloc authored and mtache committed Apr 25, 2024
1 parent c2a74a1 commit f0ba2ec
Showing 1 changed file with 65 additions and 18 deletions.
83 changes: 65 additions & 18 deletions anta/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import hashlib
import inspect
import logging
import re
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -44,6 +45,7 @@ class AntaParamsBaseModel(BaseModel):

model_config = ConfigDict(extra="forbid")

# TODO: is this still needed?
if not TYPE_CHECKING:
# Following pydantic declaration and keeping __getattr__ only when TYPE_CHECKING is false.
# Disabling 1 Dynamically typed expressions (typing.Any) are disallowed in `__getattr__
Expand All @@ -56,7 +58,35 @@ def __getattr__(self, item: str) -> Any:
return None


class AntaTemplate(BaseModel):
class SingletonArgs(type):
"""SingletonArgs class.
Used as metaclass for AntaTemplates to create only one instance of each AntaTemplate with a given set of input arguments.
https://gist.github.com/wowkin2/3af15bfbf197a14a2b0b2488a1e8c787
"""

_instances: ClassVar[dict[str, SingletonArgs]] = {}
_init: ClassVar[dict[SingletonArgs, str]] = {}

def __init__(cls, name: str, bases: list[type], dct: dict[str, Any]) -> None: # noqa: ARG003
"""Initialize the singleton.
TODO
"""
# pylint: disable=unused-argument
cls._init[cls] = dct.get("__init__")

def __call__(cls, *args: Any, **kwargs: Any) -> SingletonArgs:
"""__call__ function."""
init = cls._init[cls]
key = (cls, inspect.Signature.bind(inspect.Signature(init), None, *args, **kwargs)) if init is not None else cls
if key not in cls.instances:
cls._instances[key] = super().__call__(*args, **kwargs)
return cls._instances[key]


class AntaTemplate:
"""Class to define a command template as Python f-string.
Can render a command from parameters.
Expand All @@ -71,11 +101,37 @@ class AntaTemplate(BaseModel):
"""

template: str
version: Literal[1, "latest"] = "latest"
revision: Revision | None = None
ofmt: Literal["json", "text"] = "json"
use_cache: bool = True
# pylint: disable=too-few-public-methods

__metaclass__ = SingletonArgs

def __init__( # noqa: PLR0913
self,
template: str,
version: Literal[1, "latest"] = "latest",
revision: Revision | None = None,
ofmt: Literal["json", "text"] = "json",
*,
use_cache: bool = True,
) -> None:
# pylint: disable=too-many-arguments
self.template = template
self.version = version
self.revision = revision
self.ofmt = ofmt
self.use_cache = use_cache

# Create the model only once per Template in the Singleton instance
field_names = [fname for _, fname, _, _ in Formatter().parse(self.template) if fname]
# Extracting the type from the params based on the expected field_names from the template
# All strings for now..
fields: dict[str, Any] = {key: (str | int | bool | Any, ...) for key in field_names}
# Accepting ParamsSchema as non lowercase variable
self.params_schema = create_model(
"ParamsSchema",
__base__=AntaParamsBaseModel,
**fields,
)

def render(self, **params: str | int | bool) -> AntaCommand:
"""Render an AntaCommand from an AntaTemplate instance.
Expand All @@ -93,25 +149,14 @@ def render(self, **params: str | int | bool) -> AntaCommand:
AntaTemplate instance.
"""
# Create params schema on the fly
field_names = [fname for _, fname, _, _ in Formatter().parse(self.template) if fname]
# Extracting the type from the params based on the expected field_names from the template
fields: dict[str, Any] = {key: (type(params.get(key)), ...) for key in field_names}
# Accepting ParamsSchema as non lowercase variable
ParamsSchema = create_model( # noqa: N806
"ParamsSchema",
__base__=AntaParamsBaseModel,
**fields,
)

try:
return AntaCommand(
command=self.template.format(**params),
ofmt=self.ofmt,
version=self.version,
revision=self.revision,
template=self,
params=ParamsSchema(**params),
params=self.params_schema(**params),
use_cache=self.use_cache,
)
except KeyError as e:
Expand Down Expand Up @@ -146,6 +191,8 @@ class AntaCommand(BaseModel):
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

command: str
version: Literal[1, "latest"] = "latest"
revision: Revision | None = None
Expand Down

0 comments on commit f0ba2ec

Please sign in to comment.