Skip to content

Commit

Permalink
Validate GET responses with Pydantic (#241)
Browse files Browse the repository at this point in the history
* Validate responses with Pydantic

* Specify part of statement `type: ignore` applies to

* Replace all `get_list` calls with `get_typed`

* Use `TypeAliasType` from typing-extensions

* Revert accidental find and replace change

* HACK: Use `cast()` in `get_list_unique()`

* Add source url comment to recursive JSON type

* Add helper functions for response validation

* Remove `ResponseLike`

* Raise MREG ValidationError on failed validation

* Validate response in `get_list_unique`

* Improve exception handling for failed validation

* Add `HistoryResource` from #232

Also directly validates results to `HistoryItem` instead of going through a dict.
  • Loading branch information
pederhan authored May 29, 2024
1 parent 7bbc3da commit e1c333e
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 170 deletions.
21 changes: 11 additions & 10 deletions mreg_cli/api/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Mapping, Self, cast
from typing import Any, Callable, Self, cast

from pydantic import AliasChoices, BaseModel
from pydantic.fields import FieldInfo
Expand All @@ -19,13 +19,13 @@
PatchError,
)
from mreg_cli.outputmanager import OutputManager
from mreg_cli.types import JSONMapping
from mreg_cli.types import JsonMapping
from mreg_cli.utilities.api import (
delete,
get,
get_item_by_key_value,
get_list,
get_list_unique,
get_typed,
patch,
post,
)
Expand Down Expand Up @@ -68,7 +68,7 @@ def validate_patched_model(model: BaseModel, fields: dict[str, Any]) -> None:
"""Validate that model fields were patched correctly."""
aliases = get_model_aliases(model)

validators = {
validators: dict[type, Callable[[Any, Any], bool]] = {
list: _validate_lists,
dict: _validate_dicts,
}
Expand All @@ -83,7 +83,10 @@ def validate_patched_model(model: BaseModel, fields: dict[str, Any]) -> None:
raise PatchError(f"Could not get value for {field_name} in patched object.") from e

# Ensure patched value is the one we tried to set
validator = validators.get(type(nval), _validate_default)
validator = validators.get(
type(nval), # type: ignore # dict.get call with unknown type (Any) is fine
_validate_default,
)
if not validator(nval, value):
raise PatchError(
f"Patch failure! Tried to set {key} to {value}, but server returned {nval}."
Expand Down Expand Up @@ -318,8 +321,7 @@ def get_list_by_field(
if ordering:
params["ordering"] = ordering

data = get_list(cls.endpoint(), params=params, limit=limit)
return [cls(**item) for item in data]
return get_typed(cls.endpoint(), list[cls], params=params, limit=limit)

@classmethod
def get_by_query(
Expand All @@ -336,8 +338,7 @@ def get_by_query(
if ordering:
query["ordering"] = ordering

data = get_list(cls.endpoint().with_query(query), limit=limit)
return [cls(**item) for item in data]
return get_typed(cls.endpoint().with_query(query), list[cls], limit=limit)

@classmethod
def get_by_query_unique_or_raise(
Expand Down Expand Up @@ -468,7 +469,7 @@ def delete(self) -> bool:
return False

@classmethod
def create(cls, params: JSONMapping, fetch_after_create: bool = True) -> Self | None:
def create(cls, params: JsonMapping, fetch_after_create: bool = True) -> Self | None:
"""Create the object.
Note that several endpoints do not support location headers for created objects,
Expand Down
71 changes: 37 additions & 34 deletions mreg_cli/api/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,47 @@
import datetime
import json
from enum import Enum
from typing import Any
from typing import Any, Self

from pydantic import BaseModel, Field, field_validator

from mreg_cli.api.endpoints import Endpoint
from mreg_cli.exceptions import EntityNotFound, InternalError
from mreg_cli.outputmanager import OutputManager
from mreg_cli.utilities.api import get_list
from mreg_cli.utilities.api import get_typed


class HistoryResource(str, Enum):
"""History resources."""
"""History resources for the API.
Host = "host"
Group = "group"
HostPolicy_Role = "hostpolicy_role"
HostPolicy_Atom = "hostpolicy_atom"
Names represent resource names.
Values represent resource relations.
Access resource names and relation with the `resource()` and `relation()` methods.
"""

Host = "hosts"
Group = "groups"
HostPolicy_Role = "roles"
HostPolicy_Atom = "atoms"

@classmethod
def _missing_(cls, value: Any) -> HistoryResource:
v = str(value).lower()
for resource in cls:
if resource.value == v:
return resource
elif resource.name.lower() == v:
return resource
raise ValueError(f"Unknown resource {value}")

def relation(self) -> str:
"""Provide the relation of the resource."""
if self == HistoryResource.Host:
return "hosts"
if self == HistoryResource.Group:
return "groups"
if self == HistoryResource.HostPolicy_Role:
return "roles"
if self == HistoryResource.HostPolicy_Atom:
return "atoms"
"""Get the resource relation."""
return self.value

raise ValueError(f"Unknown resource {self}")
def resource(self) -> str:
"""Get the resource name."""
return self.name.lower()


class HistoryItem(BaseModel):
Expand Down Expand Up @@ -110,32 +121,24 @@ def output_multiple(cls, basename: str, items: list[HistoryItem]) -> None:
item.output(basename)

@classmethod
def get(cls, name: str, resource: HistoryResource) -> list[HistoryItem]:
def get(cls, name: str, resource: HistoryResource) -> list[Self]:
"""Get history items for a resource."""
resource_value = resource.value

params: dict[str, str | int] = {"resource": resource_value, "name": name}

ret = get_list(Endpoint.History, params=params)

params: dict[str, str] = {"resource": resource.resource(), "name": name}
ret = get_typed(Endpoint.History, list[cls], params=params)
if len(ret) == 0:
raise EntityNotFound(f"No history found for {name}")

model_ids = ",".join({str(i["model_id"]) for i in ret})

model_ids = ",".join({str(i.mid) for i in ret})
params = {
"resource": resource_value,
"resource": resource.resource(),
"model_id__in": model_ids,
}

ret = get_list(Endpoint.History, params=params)

data_relation = resource.relation()
ret = get_typed(Endpoint.History, list[cls], params=params)

params = {
"data__relation": data_relation,
"data__relation": resource.relation(),
"data__id__in": model_ids,
}
ret.extend(get_list(Endpoint.History, params=params))
ret.extend(get_typed(Endpoint.History, list[cls], params=params))

return [cls(**i) for i in ret]
return ret
54 changes: 23 additions & 31 deletions mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
delete,
get,
get_item_by_key_value,
get_list,
get_list_in,
get_list_unique,
get_typed,
Expand Down Expand Up @@ -196,7 +195,7 @@ def resolve_host(self) -> Host | None:
if not data:
return None

return Host(**data)
return Host.model_validate(data)


class WithZone(BaseModel, APIMixin):
Expand All @@ -218,7 +217,7 @@ def resolve_zone(self) -> ForwardZone | None:
if not data:
return None

return ForwardZone(**data)
return ForwardZone.model_validate(data)


class WithTTL(BaseModel):
Expand Down Expand Up @@ -331,8 +330,7 @@ def get_list_by_name_regex(cls, name: str) -> list[Self]:
:returns: A list of resource objects.
"""
param, value = convert_wildcard_to_regex(cls.__name_field__, name, True)
data = get_list(cls.endpoint(), params={param: value})
return [cls(**item) for item in data]
return get_typed(cls.endpoint(), list[cls], params={param: value})

def rename(self, new_name: str) -> Self:
"""Rename the resource.
Expand Down Expand Up @@ -575,8 +573,7 @@ def get_list(cls) -> list[Self]:
:returns: A list of all zones.
"""
data = get_list(cls.endpoint())
return [cls(**item) for item in data]
return get_typed(cls.endpoint(), list[cls])

def ensure_delegation_in_zone(self, name: str) -> None:
"""Ensure a delegation is in the zone.
Expand Down Expand Up @@ -826,8 +823,7 @@ def get_delegations(self) -> list[ForwardZoneDelegation | ReverseZoneDelegation]
:returns: The delegation object.
"""
cls = Delegation.type_by_zone(self)
data = get_list(cls.endpoint().with_params(self.name))
return [cls.model_validate(d) for d in data]
return get_typed(cls.endpoint().with_params(self.name), list[cls])

def delete_delegation(self, name: str) -> bool:
"""Delete a delegation from the zone.
Expand Down Expand Up @@ -907,10 +903,10 @@ def get_from_hostname(cls, hostname: HostT) -> ForwardZoneDelegation | ForwardZo
zoneblob = data.json()

if "delegate" in zoneblob:
return ForwardZoneDelegation(**zoneblob)
return ForwardZoneDelegation.model_validate(zoneblob)

if "zone" in zoneblob:
return ForwardZone(**zoneblob["zone"])
return ForwardZone.model_validate(zoneblob["zone"])

raise UnexpectedDataError(f"Unexpected response from server: {zoneblob}")

Expand Down Expand Up @@ -1219,14 +1215,13 @@ class RoleTableRow(BaseModel):
)

@classmethod
def get_roles_with_atom(cls, name: str) -> list[Role]:
def get_roles_with_atom(cls, name: str) -> list[Self]:
"""Get all roles with a specific atom.
:param atom: Name of the atom to search for.
:returns: A list of Role objects.
"""
data = get_list(cls.endpoint(), params={"atoms__name__exact": name})
return [cls(**item) for item in data]
return get_typed(cls.endpoint(), list[cls], params={"atoms__name__exact": name})

def add_atom(self, atom_name: str) -> bool:
"""Add an atom to the role.
Expand Down Expand Up @@ -1392,16 +1387,15 @@ def endpoint(cls) -> Endpoint:
return Endpoint.Labels

@classmethod
def get_all(cls) -> list[Label]:
def get_all(cls) -> list[Self]:
"""Get all labels.
:returns: A list of Label objects.
"""
data = get_list(cls.endpoint(), params={"ordering": "name"})
return [cls(**item) for item in data]
return get_typed(cls.endpoint(), list[cls], params={"ordering": "name"})

@classmethod
def get_by_id_or_raise(cls, _id: int) -> Label:
def get_by_id_or_raise(cls, _id: int) -> Self:
"""Get a Label by ID.
:param _id: The Label ID to search for.
Expand Down Expand Up @@ -1572,8 +1566,7 @@ def get_list(cls) -> list[Self]:
:returns: A list of all networks.
"""
data = get_list(cls.endpoint(), limit=None)
return [cls(**item) for item in data]
return get_typed(cls.endpoint(), list[cls], limit=None)

@staticmethod
def str_to_network(network: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network:
Expand Down Expand Up @@ -1906,7 +1899,7 @@ def is_ipv6(self) -> bool:
def network(self) -> Network:
"""Return the network of the IP address."""
data = get(Endpoint.NetworksByIP.with_id(str(self.ip())))
return Network(**data.json())
return Network.model_validate(data.json())

def vlan(self) -> int | None:
"""Return the VLAN of the IP address."""
Expand Down Expand Up @@ -2039,7 +2032,7 @@ def get_by_name(cls, name: HostT) -> CNAME:
data = get_item_by_key_value(Endpoint.Cnames, "name", name.hostname)
if not data:
raise EntityNotFound(f"CNAME record for {name} not found.")
return CNAME(**data)
return CNAME.model_validate(data)

@classmethod
def get_by_host_and_name(cls, host: HostT | int, name: HostT) -> CNAME:
Expand Down Expand Up @@ -2149,7 +2142,7 @@ def get_by_all(cls, host: int, mx: str, priority: int) -> MX:
)
if not data:
raise EntityNotFound(f"MX record for {mx} not found.")
return MX(**data)
return MX.model_validate(data)

def has_mx_with_priority(self, mx: str, priority: int) -> bool:
"""Return True if the MX record has the given MX and priority.
Expand Down Expand Up @@ -2303,7 +2296,7 @@ def output_multiple(cls, srvs: list[Srv], padding: int = 14) -> None:
host_ids = {srv.host for srv in srvs}

host_data = get_list_in(Endpoint.Hosts, "id", list(host_ids))
hosts = [Host(**host) for host in host_data]
hosts = [Host.model_validate(host) for host in host_data]

host_id_name_map = {host.id: str(host.name) for host in hosts}

Expand Down Expand Up @@ -2419,16 +2412,15 @@ def endpoint(cls) -> Endpoint:
return Endpoint.BacnetID

@classmethod
def get_in_range(cls, start: int, end: int) -> list[BacnetID]:
def get_in_range(cls, start: int, end: int) -> list[Self]:
"""Get Bacnet IDs in a range.
:param start: The start of the range.
:param end: The end of the range.
:returns: List of BacnetID objects in the range.
"""
params = {"id__range": f"{start},{end}"}
data = get_list(Endpoint.BacnetID, params=params)
return [BacnetID(**item) for item in data]
return get_typed(Endpoint.BacnetID, list[cls], params=params)

@classmethod
def output_multiple(cls, bacnetids: list[BacnetID]):
Expand Down Expand Up @@ -2909,7 +2901,7 @@ def resolve_zone(
data_as_dict = data.json()

if data_as_dict["zone"]:
zone = ForwardZone(**data_as_dict["zone"])
zone = ForwardZone.model_validate(data_as_dict["zone"])
if validate_zone_resolution and zone.id != self.zone:
raise ValidationError(f"Expected zone ID {self.zone} but resovled as {zone.id}.")
return zone
Expand All @@ -2919,7 +2911,7 @@ def resolve_zone(
raise EntityOwnershipMismatch(
f"Host {self.name} is delegated to zone {data_as_dict['delegation']['name']}."
)
return ForwardZoneDelegation(**data_as_dict["delegation"])
return ForwardZoneDelegation.model_validate(data_as_dict["delegation"])

raise EntityNotFound(f"Failed to resolve zone for host {self.name}.")

Expand Down Expand Up @@ -3108,8 +3100,8 @@ def get(cls, params: dict[str, Any] | None = None) -> HostList:
if "ordering" not in params:
params["ordering"] = "name"

data = get_list(cls.endpoint(), params=params)
return cls(results=[Host(**host) for host in data])
hosts = get_typed(cls.endpoint(), list[Host], params=params)
return cls(results=hosts)

@classmethod
def get_by_ip(cls, ip: IP_AddressT) -> HostList:
Expand Down
Loading

0 comments on commit e1c333e

Please sign in to comment.