-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split the models.py into smaller files.
- Loading branch information
Showing
3 changed files
with
356 additions
and
339 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,285 @@ | ||
"""Abstract models for the API.""" | ||
|
||
|
||
from abc import ABC, abstractmethod | ||
from datetime import datetime | ||
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union, cast | ||
|
||
from pydantic import BaseModel | ||
from pydantic.fields import AliasChoices, FieldInfo | ||
|
||
from mreg_cli.api.endpoints import Endpoint | ||
from mreg_cli.log import cli_warning | ||
from mreg_cli.outputmanager import OutputManager | ||
from mreg_cli.utilities.api import delete, get, get_item_by_key_value, get_list, patch, post | ||
|
||
BMT = TypeVar("BMT", bound="BaseModel") | ||
|
||
|
||
def get_field_aliases(field_info: FieldInfo) -> Set[str]: | ||
"""Get all aliases for a Pydantic field.""" | ||
aliases: set[str] = set() | ||
|
||
if field_info.alias: | ||
aliases.add(field_info.alias) | ||
|
||
if field_info.validation_alias: | ||
if isinstance(field_info.validation_alias, str): | ||
aliases.add(field_info.validation_alias) | ||
elif isinstance(field_info.validation_alias, AliasChoices): | ||
for choice in field_info.validation_alias.choices: | ||
if isinstance(choice, str): | ||
aliases.add(choice) | ||
return aliases | ||
|
||
|
||
def get_model_aliases(model: BaseModel) -> Dict[str, str]: | ||
"""Get a mapping of aliases to field names for a Pydantic model. | ||
Includes field names, alias, and validation alias(es). | ||
""" | ||
fields = {} # type: Dict[str, str] | ||
|
||
for field_name, field_info in model.model_fields.items(): | ||
aliases = get_field_aliases(field_info) | ||
if model.model_config.get("populate_by_name"): | ||
aliases.add(field_name) | ||
# Assign aliases to field name in mapping | ||
for alias in aliases: | ||
fields[alias] = field_name | ||
|
||
return fields | ||
|
||
|
||
class FrozenModel(BaseModel): | ||
"""Model for an immutable object.""" | ||
|
||
def __setattr__(self, name: str, value: Any): | ||
"""Raise an exception when trying to set an attribute.""" | ||
raise AttributeError("Cannot set attribute on a frozen object") | ||
|
||
def __delattr__(self, name: str): | ||
"""Raise an exception when trying to delete an attribute.""" | ||
raise AttributeError("Cannot delete attribute on a frozen object") | ||
|
||
class Config: | ||
"""Pydantic configuration. | ||
Set the class to frozen to make it immutable and thus hashable. | ||
""" | ||
|
||
frozen = True | ||
|
||
|
||
class FrozenModelWithTimestamps(FrozenModel): | ||
"""Model with created_at and updated_at fields.""" | ||
|
||
created_at: datetime | ||
updated_at: datetime | ||
|
||
def output_timestamps(self, padding: int = 14) -> None: | ||
"""Output the created and updated timestamps to the console.""" | ||
output_manager = OutputManager() | ||
output_manager.add_line(f"{'Created:':<{padding}}{self.created_at:%c}") | ||
output_manager.add_line(f"{'Updated:':<{padding}}{self.updated_at:%c}") | ||
|
||
|
||
class APIMixin(Generic[BMT], ABC): | ||
"""A mixin for API-related methods.""" | ||
|
||
id: int # noqa: A003 | ||
|
||
def id_for_endpoint(self) -> Union[int, str]: | ||
"""Return the appropriate id for the object for its endpoint. | ||
:returns: The correct identifier for the endpoint. | ||
""" | ||
field = self.endpoint().external_id_field() | ||
return getattr(self, field) | ||
|
||
@classmethod | ||
def field_for_endpoint(cls) -> str: | ||
"""Return the appropriate field for the object for its endpoint. | ||
:param field: The field to return. | ||
:returns: The correct field for the endpoint. | ||
""" | ||
return cls.endpoint().external_id_field() | ||
|
||
@classmethod | ||
@abstractmethod | ||
def endpoint(cls) -> Endpoint: | ||
"""Return the endpoint for the method.""" | ||
raise NotImplementedError("You must define an endpoint.") | ||
|
||
@classmethod | ||
def get(cls, _id: int) -> Optional[BMT]: | ||
"""Get an object. | ||
This function is at its base a wrapper around the get_by_id function, | ||
but it can be overridden to provide more specific functionality. | ||
:param _id: The ID of the object. | ||
:returns: The object if found, None otherwise. | ||
""" | ||
return cls.get_by_id(_id) | ||
|
||
@classmethod | ||
def get_by_id(cls, _id: int) -> Optional[BMT]: | ||
"""Get an object by its ID. | ||
Note that for Hosts, the ID is the name of the host. | ||
:param _id: The ID of the object. | ||
:returns: The object if found, None otherwise. | ||
""" | ||
endpoint = cls.endpoint() | ||
|
||
# Some endpoints do not use the ID field as the endpoint identifier, | ||
# and in these cases we need to search for the ID... Lovely. | ||
if endpoint.requires_search_for_id(): | ||
data = get_item_by_key_value(cls.endpoint(), "id", str(_id)) | ||
else: | ||
data = get(cls.endpoint().with_id(_id), ok404=True) | ||
if not data: | ||
return None | ||
data = data.json() | ||
|
||
if not data: | ||
return None | ||
|
||
return cast(BMT, cls(**data)) | ||
|
||
@classmethod | ||
def get_by_field(cls, field: str, value: str) -> Optional[BMT]: | ||
"""Get an object by a field. | ||
Note that some endpoints do not use the ID field for lookups. We do some | ||
magic mapping via endpoint introspection to perform the following mapping for | ||
classes and their endpoint "id" fields: | ||
- Hosts -> name | ||
- Networks -> network | ||
This implies that doing a get_by_field("name", value) on Hosts will *not* | ||
result in a search, but a direct lookup at ../endpoint/name which is what | ||
the mreg server expects for Hosts (and similar for Network). | ||
:param field: The field to search by. | ||
:param value: The value to search for. | ||
:returns: The object if found, None otherwise. | ||
""" | ||
endpoint = cls.endpoint() | ||
|
||
if endpoint.requires_search_for_id() and field == endpoint.external_id_field(): | ||
data = get(endpoint.with_id(value), ok404=True) | ||
if not data: | ||
return None | ||
data = data.json() | ||
else: | ||
data = get_item_by_key_value(cls.endpoint(), field, value, ok404=True) | ||
|
||
if not data: | ||
return None | ||
|
||
return cast(BMT, cls(**data)) | ||
|
||
@classmethod | ||
def get_list_by_field( | ||
cls, field: str, value: Union[str, int], ordering: Optional[str] = None | ||
) -> List[BMT]: | ||
"""Get a list of objects by a field. | ||
:param field: The field to search by. | ||
:param value: The value to search for. | ||
:param ordering: The ordering to use when fetching the list. | ||
:returns: A list of objects if found, an empty list otherwise. | ||
""" | ||
params = {field: value} | ||
if ordering: | ||
params["ordering"] = ordering | ||
|
||
data = get_list(cls.endpoint(), params=params) | ||
return [cast(BMT, cls(**item)) for item in data] | ||
|
||
def refetch(self) -> BMT: | ||
"""Fetch an updated version of the object. | ||
Note that the caller (self) of this method will remain unchanged and can contain | ||
outdated information. The returned object will be the updated version. | ||
:returns: The fetched object. | ||
""" | ||
obj = self.__class__.get_by_id(self.id) | ||
if not obj: | ||
cli_warning(f"Could not refresh {self.__class__.__name__} with ID {self.id}.") | ||
|
||
return obj | ||
|
||
def patch(self, fields: Dict[str, Any]) -> BMT: | ||
"""Patch the object with the given values. | ||
:param kwargs: The values to patch. | ||
:returns: The object refetched from the server. | ||
""" | ||
patch(self.endpoint().with_id(self.id), **fields) | ||
|
||
new_object = self.refetch() | ||
|
||
aliases = get_model_aliases(new_object) | ||
for key, value in fields.items(): | ||
field_name = aliases.get(key) | ||
if field_name is None: | ||
cli_warning(f"Unknown field {key} in patch request.") | ||
try: | ||
nval = getattr(new_object, field_name) | ||
except AttributeError: | ||
cli_warning(f"Could not get value for {field_name} in patched object.") | ||
if str(nval) != str(value): | ||
cli_warning( | ||
# Should this reference `field_name` instead of `key`? | ||
f"Patch failure! Tried to set {key} to {value}, but server returned {nval}." | ||
) | ||
|
||
return new_object | ||
|
||
def delete(self) -> bool: | ||
"""Delete the object. | ||
:returns: True if the object was deleted, False otherwise. | ||
""" | ||
response = delete(self.endpoint().with_id(self.id_for_endpoint())) | ||
|
||
if response and response.ok: | ||
return True | ||
|
||
return False | ||
|
||
@classmethod | ||
def create(cls, kwargs: Dict[str, Union[str, None]]) -> Union[None, BMT]: | ||
"""Create the object. | ||
:returns: The object if created, None otherwise. | ||
""" | ||
response = post(cls.endpoint(), params=None, **kwargs) | ||
|
||
if response and response.ok: | ||
location = response.headers.get("Location") | ||
if location: | ||
obj = None | ||
if cls.endpoint() is Endpoint.Hosts: | ||
obj = cls.get_by_field("name", location.split("/")[-1]) | ||
else: | ||
obj = cls.get_by_id(int(location.split("/")[-1])) | ||
|
||
if obj: | ||
return obj | ||
|
||
cli_warning(f"Could not fetch object from location {location}.") | ||
|
||
else: | ||
cli_warning("No location header in response.") | ||
|
||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
"""Fields for models of the API.""" | ||
|
||
import ipaddress | ||
import re | ||
|
||
from pydantic import validator | ||
|
||
from mreg_cli.api.abstracts import FrozenModel | ||
from mreg_cli.types import IP_AddressT | ||
|
||
_mac_regex = re.compile(r"^([0-9A-Fa-f]{2}[.:-]){5}([0-9A-Fa-f]{2})$") | ||
|
||
|
||
class MACAddressField(FrozenModel): | ||
"""Represents a MAC address.""" | ||
|
||
address: str | ||
|
||
@validator("address", pre=True) | ||
def validate_and_format_mac(cls, v: str) -> str: | ||
"""Validate and normalize MAC address to 'aa:bb:cc:dd:ee:ff' format. | ||
:param v: The input MAC address string. | ||
:raises ValueError: If the input does not match the expected MAC address pattern. | ||
:returns: The normalized MAC address. | ||
""" | ||
# Validate input format | ||
if not _mac_regex.match(v): | ||
raise ValueError("Invalid MAC address format") | ||
|
||
# Normalize MAC address | ||
v = re.sub(r"[.:-]", "", v).lower() | ||
return ":".join(v[i : i + 2] for i in range(0, 12, 2)) | ||
|
||
def __str__(self) -> str: | ||
"""Return the MAC address as a string.""" | ||
return self.address | ||
|
||
|
||
class IPAddressField(FrozenModel): | ||
"""Represents an IP address, automatically determines if it's IPv4 or IPv6.""" | ||
|
||
address: IP_AddressT | ||
|
||
@validator("address", pre=True) | ||
def parse_ip_address(cls, value: str) -> IP_AddressT: | ||
"""Parse and validate the IP address.""" | ||
try: | ||
return ipaddress.ip_address(value) | ||
except ValueError as e: | ||
raise ValueError(f"Invalid IP address '{value}'.") from e | ||
|
||
def is_ipv4(self) -> bool: | ||
"""Check if the IP address is IPv4.""" | ||
return isinstance(self.address, ipaddress.IPv4Address) | ||
|
||
def is_ipv6(self) -> bool: | ||
"""Check if the IP address is IPv6.""" | ||
return isinstance(self.address, ipaddress.IPv6Address) | ||
|
||
def __str__(self) -> str: | ||
"""Return the IP address as a string.""" | ||
return str(self.address) | ||
|
||
def __hash__(self): | ||
"""Return a hash of the IP address.""" | ||
return hash(self.address) |
Oops, something went wrong.