Skip to content

Commit

Permalink
Validate responses with Pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
pederhan committed May 27, 2024
1 parent 248a19e commit ad21eaa
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 69 deletions.
10 changes: 5 additions & 5 deletions mreg_cli/api/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Mapping, Self, cast
from typing import Any, Callable, Self, cast

from pydantic import AliasChoices, BaseModel
from pydantic.fields import FieldInfo
Expand All @@ -19,7 +19,7 @@
PatchError,
)
from mreg_cli.outputmanager import OutputManager
from mreg_cli.types import JSONMapping
from mreg_cli.types import JsonMapping
from mreg_cli.utilities.api import (
delete,
get,
Expand Down Expand Up @@ -68,7 +68,7 @@ def validate_patched_model(model: BaseModel, fields: dict[str, Any]) -> None:
"""Validate that model fields were patched correctly."""
aliases = get_model_aliases(model)

validators = {
validators: dict[type, Callable[[Any, Any], bool]] = {
list: _validate_lists,
dict: _validate_dicts,
}
Expand All @@ -83,7 +83,7 @@ def validate_patched_model(model: BaseModel, fields: dict[str, Any]) -> None:
raise PatchError(f"Could not get value for {field_name} in patched object.") from e

# Ensure patched value is the one we tried to set
validator = validators.get(type(nval), _validate_default)
validator = validators.get(type(nval), _validate_default) # type: ignore # dict.get type checking is whatever
if not validator(nval, value):
raise PatchError(
f"Patch failure! Tried to set {key} to {value}, but server returned {nval}."
Expand Down Expand Up @@ -468,7 +468,7 @@ def delete(self) -> bool:
return False

@classmethod
def create(cls, params: JSONMapping, fetch_after_create: bool = True) -> Self | None:
def create(cls, params: JsonMapping, fetch_after_create: bool = True) -> Self | None:
"""Create the object.
Note that several endpoints do not support location headers for created objects,
Expand Down
10 changes: 5 additions & 5 deletions mreg_cli/api/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mreg_cli.api.endpoints import Endpoint
from mreg_cli.exceptions import EntityNotFound, InternalError
from mreg_cli.outputmanager import OutputManager
from mreg_cli.utilities.api import get_list
from mreg_cli.utilities.api import get_list, get_typed


class HistoryResource(str, Enum):
Expand Down Expand Up @@ -116,7 +116,7 @@ def get(cls, name: str, resource: HistoryResource) -> list[HistoryItem]:

params: dict[str, str | int] = {"resource": resource_value, "name": name}

ret = get_list(Endpoint.History, params=params)
ret = get_typed(Endpoint.History, list[dict[str, Any]], params=params)

if len(ret) == 0:
raise EntityNotFound(f"No history found for {name}")
Expand All @@ -128,14 +128,14 @@ def get(cls, name: str, resource: HistoryResource) -> list[HistoryItem]:
"model_id__in": model_ids,
}

ret = get_list(Endpoint.History, params=params)
ret = get_typed(Endpoint.History, list[dict[str, Any]], params=params)

data_relation = resource.relation()

params = {
"data__relation": data_relation,
"data__id__in": model_ids,
}
ret.extend(get_list(Endpoint.History, params=params))
ret.extend(get_typed(Endpoint.History, list[dict[str, Any]], params=params))

return [cls(**i) for i in ret]
return [cls.model_validate(i) for i in ret]
34 changes: 17 additions & 17 deletions mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def resolve_host(self) -> Host | None:
if not data:
return None

return Host(**data)
return Host.model_validate(data)


class WithZone(BaseModel, APIMixin):
Expand All @@ -218,7 +218,7 @@ def resolve_zone(self) -> ForwardZone | None:
if not data:
return None

return ForwardZone(**data)
return ForwardZone.model_validate(data)


class WithTTL(BaseModel):
Expand Down Expand Up @@ -332,7 +332,7 @@ def get_list_by_name_regex(cls, name: str) -> list[Self]:
"""
param, value = convert_wildcard_to_regex(cls.__name_field__, name, True)
data = get_list(cls.endpoint(), params={param: value})
return [cls(**item) for item in data]
return [cls.model_validate(item) for item in data]

def rename(self, new_name: str) -> Self:
"""Rename the resource.
Expand Down Expand Up @@ -576,7 +576,7 @@ def get_list(cls) -> list[Self]:
:returns: A list of all zones.
"""
data = get_list(cls.endpoint())
return [cls(**item) for item in data]
return [cls.model_validate(item) for item in data]

def ensure_delegation_in_zone(self, name: str) -> None:
"""Ensure a delegation is in the zone.
Expand Down Expand Up @@ -907,10 +907,10 @@ def get_from_hostname(cls, hostname: HostT) -> ForwardZoneDelegation | ForwardZo
zoneblob = data.json()

if "delegate" in zoneblob:
return ForwardZoneDelegation(**zoneblob)
return ForwardZoneDelegation.model_validate(zoneblob)

if "zone" in zoneblob:
return ForwardZone(**zoneblob["zone"])
return ForwardZone.model_validate(zoneblob["zone"])

raise UnexpectedDataError(f"Unexpected response from server: {zoneblob}")

Expand Down Expand Up @@ -1226,7 +1226,7 @@ def get_roles_with_atom(cls, name: str) -> list[Role]:
:returns: A list of Role objects.
"""
data = get_list(cls.endpoint(), params={"atoms__name__exact": name})
return [cls(**item) for item in data]
return [cls.model_validate(item) for item in data]

def add_atom(self, atom_name: str) -> bool:
"""Add an atom to the role.
Expand Down Expand Up @@ -1398,7 +1398,7 @@ def get_all(cls) -> list[Label]:
:returns: A list of Label objects.
"""
data = get_list(cls.endpoint(), params={"ordering": "name"})
return [cls(**item) for item in data]
return [cls.model_validate(item) for item in data]

@classmethod
def get_by_id_or_raise(cls, _id: int) -> Label:
Expand Down Expand Up @@ -1573,7 +1573,7 @@ def get_list(cls) -> list[Self]:
:returns: A list of all networks.
"""
data = get_list(cls.endpoint(), limit=None)
return [cls(**item) for item in data]
return [cls.model_validate(item) for item in data]

@staticmethod
def str_to_network(network: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network:
Expand Down Expand Up @@ -1906,7 +1906,7 @@ def is_ipv6(self) -> bool:
def network(self) -> Network:
"""Return the network of the IP address."""
data = get(Endpoint.NetworksByIP.with_id(str(self.ip())))
return Network(**data.json())
return Network.model_validate(data.json())

def vlan(self) -> int | None:
"""Return the VLAN of the IP address."""
Expand Down Expand Up @@ -2039,7 +2039,7 @@ def get_by_name(cls, name: HostT) -> CNAME:
data = get_item_by_key_value(Endpoint.Cnames, "name", name.hostname)
if not data:
raise EntityNotFound(f"CNAME record for {name} not found.")
return CNAME(**data)
return CNAME.model_validate(data)

@classmethod
def get_by_host_and_name(cls, host: HostT | int, name: HostT) -> CNAME:
Expand Down Expand Up @@ -2149,7 +2149,7 @@ def get_by_all(cls, host: int, mx: str, priority: int) -> MX:
)
if not data:
raise EntityNotFound(f"MX record for {mx} not found.")
return MX(**data)
return MX.model_validate(data)

def has_mx_with_priority(self, mx: str, priority: int) -> bool:
"""Return True if the MX record has the given MX and priority.
Expand Down Expand Up @@ -2303,7 +2303,7 @@ def output_multiple(cls, srvs: list[Srv], padding: int = 14) -> None:
host_ids = {srv.host for srv in srvs}

host_data = get_list_in(Endpoint.Hosts, "id", list(host_ids))
hosts = [Host(**host) for host in host_data]
hosts = [Host.model_validate(host) for host in host_data]

host_id_name_map = {host.id: str(host.name) for host in hosts}

Expand Down Expand Up @@ -2428,7 +2428,7 @@ def get_in_range(cls, start: int, end: int) -> list[BacnetID]:
"""
params = {"id__range": f"{start},{end}"}
data = get_list(Endpoint.BacnetID, params=params)
return [BacnetID(**item) for item in data]
return [BacnetID.model_validate(item) for item in data]

@classmethod
def output_multiple(cls, bacnetids: list[BacnetID]):
Expand Down Expand Up @@ -2909,7 +2909,7 @@ def resolve_zone(
data_as_dict = data.json()

if data_as_dict["zone"]:
zone = ForwardZone(**data_as_dict["zone"])
zone = ForwardZone.model_validate(data_as_dict["zone"])
if validate_zone_resolution and zone.id != self.zone:
raise ValidationError(f"Expected zone ID {self.zone} but resovled as {zone.id}.")
return zone
Expand All @@ -2919,7 +2919,7 @@ def resolve_zone(
raise EntityOwnershipMismatch(
f"Host {self.name} is delegated to zone {data_as_dict['delegation']['name']}."
)
return ForwardZoneDelegation(**data_as_dict["delegation"])
return ForwardZoneDelegation.model_validate(data_as_dict["delegation"])

raise EntityNotFound(f"Failed to resolve zone for host {self.name}.")

Expand Down Expand Up @@ -3109,7 +3109,7 @@ def get(cls, params: dict[str, Any] | None = None) -> HostList:
params["ordering"] = "name"

data = get_list(cls.endpoint(), params=params)
return cls(results=[Host(**host) for host in data])
return cls(results=[Host.model_validate(host) for host in data])

@classmethod
def get_by_ip(cls, ip: IP_AddressT) -> HostList:
Expand Down
2 changes: 1 addition & 1 deletion mreg_cli/tokenfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class TokenFile:

def __init__(self, tokens: Optional[list[dict[str, str]]] = None):
"""Initialize the TokenFile instance."""
self.tokens = [Token(**token) for token in tokens] if tokens else []
self.tokens = [Token.model_validate(token) for token in tokens] if tokens else []

@classmethod
def _load_tokens(cls) -> "TokenFile":
Expand Down
26 changes: 23 additions & 3 deletions mreg_cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ipaddress
from collections.abc import Callable
from typing import (
Annotated,
Any,
Literal,
Mapping,
Expand All @@ -18,6 +19,8 @@
Union,
)

from pydantic import ValidationError, ValidationInfo, ValidatorFunctionWrapHandler, WrapValidator
from pydantic_core import PydanticCustomError
from requests.structures import CaseInsensitiveDict

CommandFunc = Callable[[argparse.Namespace], None]
Expand Down Expand Up @@ -58,9 +61,26 @@ class RecordingEntry(TypedDict):
NargsType = int | NargsStr


JSONPrimitive = Union[str, int, float, bool, None]
JSONValue = Union[JSONPrimitive, Mapping[str, "JSONValue"], Sequence["JSONValue"]]
JSONMapping = Mapping[str, JSONValue]
def json_custom_error_validator(
value: Any, handler: ValidatorFunctionWrapHandler, _info: ValidationInfo
) -> Any:
"""Simplify the error message to avoid a gross error stemming from
exhaustive checking of all union options.
""" # noqa: D205
try:
return handler(value)
except ValidationError:
raise PydanticCustomError(
"invalid_json",
"Input is not valid json",
) from None


type Json = Annotated[
Union[Mapping[str, "Json"], Sequence["Json"], str, int, float, bool, None],
WrapValidator(json_custom_error_validator),
]
JsonMapping = Mapping[str, Json]


class Flag:
Expand Down
Loading

0 comments on commit ad21eaa

Please sign in to comment.