Skip to content

Commit

Permalink
Upgrade models.py to Pydantic V2 semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
pederhan committed Apr 23, 2024
1 parent eb37282 commit 62eb19a
Showing 1 changed file with 53 additions and 37 deletions.
90 changes: 53 additions & 37 deletions mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,17 @@
from datetime import datetime
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union, cast

from pydantic import AliasChoices, BaseModel, Field, root_validator, validator
from pydantic import (
AliasChoices,
BaseModel,
BeforeValidator,
ConfigDict,
Field,
field_validator,
model_validator,
)
from pydantic.fields import FieldInfo
from typing_extensions import Annotated

from mreg_cli.api.endpoints import Endpoint
from mreg_cli.config import MregCliConfig
Expand Down Expand Up @@ -69,7 +78,8 @@ class HostT(BaseModel):

hostname: str

@validator("hostname")
@field_validator("hostname")
@classmethod
def validate_hostname(cls, value: str) -> str:
"""Validate the hostname."""
value = value.lower()
Expand Down Expand Up @@ -112,13 +122,7 @@ def __delattr__(self, name: str):
"""Raise an exception when trying to delete an attribute."""
raise AttributeError("Cannot delete attribute on a frozen object")

class Config:
"""Pydantic configuration.
Set the class to frozen to make it immutable and thus hashable.
"""

frozen = True
model_config = ConfigDict(frozen=True)


class FrozenModelWithTimestamps(FrozenModel):
Expand Down Expand Up @@ -445,6 +449,18 @@ def is_delegated(self) -> bool:
return True


def _extract_name(value: Dict[str, Any]) -> str:
"""Extract the name from the dictionary.
:param v: Dictionary containing the name.
:returns: Extracted name as a string.
"""
return value["name"]


NameList = List[Annotated[str, BeforeValidator(_extract_name)]]


class Role(FrozenModelWithTimestamps, APIMixin["Role"]):
"""Model for a role.
Expand All @@ -454,13 +470,14 @@ class Role(FrozenModelWithTimestamps, APIMixin["Role"]):

id: int # noqa: A003
created_at: datetime = Field(..., validation_alias=AliasChoices("create_date", "created_at"))
hosts: List[str]
atoms: List[str]
hosts: NameList
atoms: NameList
description: str
name: str
labels: List[int]

@validator("created_at", pre=True)
@field_validator("created_at", mode="before")
@classmethod
def validate_created_at(cls, value: str) -> datetime:
"""Validate and convert the created_at field to datetime.
Expand All @@ -469,15 +486,6 @@ def validate_created_at(cls, value: str) -> datetime:
"""
return datetime.fromisoformat(value)

@validator("hosts", "atoms", pre=True, each_item=True)
def extract_name(cls, v: Dict[str, str]) -> str:
"""Extract the name from the dictionary.
:param v: Dictionary containing the name.
:returns: Extracted name as a string.
"""
return v["name"]

@classmethod
def endpoint(cls) -> Endpoint:
"""Return the endpoint for the class."""
Expand Down Expand Up @@ -512,7 +520,7 @@ class Network(FrozenModelWithTimestamps, APIMixin["Network"]):
excluded_ranges: List[str]
network: str # for now
description: str
vlan: Optional[int]
vlan: Optional[int] = None
dns_delegated: bool
category: str
location: str
Expand Down Expand Up @@ -559,7 +567,8 @@ class MACAddressField(FrozenModel):

address: str

@validator("address", pre=True)
@field_validator("address", mode="before")
@classmethod
def validate_and_format_mac(cls, v: str) -> str:
"""Validate and normalize MAC address to 'aa:bb:cc:dd:ee:ff' format.
Expand All @@ -585,7 +594,8 @@ class IPAddressField(FrozenModel):

address: IP_AddressT

@validator("address", pre=True)
@field_validator("address", mode="before")
@classmethod
def parse_ip_address(cls, value: str) -> IP_AddressT:
"""Parse and validate the IP address."""
try:
Expand Down Expand Up @@ -617,15 +627,17 @@ class IPAddress(FrozenModelWithTimestamps, WithHost, APIMixin["IPAddress"]):
macaddress: Optional[MACAddressField] = None
ipaddress: IPAddressField

@validator("macaddress", pre=True, allow_reuse=True)
def create_valid_macadress_or_none(cls, v: str):
@field_validator("macaddress", mode="before")
@classmethod
def create_valid_macadress_or_none(cls, v: str) -> MACAddressField | None:
"""Create macaddress or convert empty strings to None."""
if v:
return MACAddressField(address=v)

return None

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def convert_ip_address(cls, values: Any):
"""Convert ipaddress string to IPAddressField if necessary."""
ip_address = values.get("ipaddress")
Expand Down Expand Up @@ -753,7 +765,8 @@ class CNAME(FrozenModelWithTimestamps, WithHost, WithZone, APIMixin["CNAME"]):
name: HostT
ttl: Optional[int] = None

@validator("name", pre=True)
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value: str) -> HostT:
"""Validate the hostname."""
return HostT(hostname=value)
Expand Down Expand Up @@ -846,9 +859,9 @@ class NAPTR(FrozenModelWithTimestamps, WithHost, APIMixin["NAPTR"]):
id: int # noqa: A003
preference: int
order: int
flag: Optional[str]
service: Optional[str]
regex: Optional[str]
flag: Optional[str] = None
service: Optional[str] = None
regex: Optional[str] = None
replacement: str

def output(self, padding: int = 14) -> None:
Expand Down Expand Up @@ -907,7 +920,7 @@ class Srv(FrozenModelWithTimestamps, WithHost, WithZone, APIMixin["Srv"]):
priority: int
weight: int
port: int
ttl: Optional[int]
ttl: Optional[int] = None

@classmethod
def endpoint(cls) -> Endpoint:
Expand Down Expand Up @@ -1072,13 +1085,15 @@ class Host(FrozenModelWithTimestamps, APIMixin["Host"]):
# Note, we do not use WithZone here as this is optional and we resolve it differently.
zone: Optional[int] = None

@validator("name", pre=True)
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value: str) -> HostT:
"""Validate the hostname."""
return HostT(hostname=value)

@validator("comment", pre=True, allow_reuse=True)
def empty_string_to_none(cls, v: str):
@field_validator("comment", mode="before")
@classmethod
def empty_string_to_none(cls, v: str) -> str | None:
"""Convert empty strings to None."""
return v or None

Expand Down Expand Up @@ -1442,8 +1457,9 @@ def get(cls, params: Optional[Dict[str, Any]] = None) -> "HostList":
data = get_list(cls.endpoint(), params=params)
return cls(results=[Host(**host) for host in data])

@validator("results", pre=True)
def check_results(cls, v: List[Dict[str, str]]):
@field_validator("results", mode="before")
@classmethod
def check_results(cls, v: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""Check that the results are valid."""
return v

Expand Down

0 comments on commit 62eb19a

Please sign in to comment.