From e1c333ebb2f4d992e4601f29204010db69d33673 Mon Sep 17 00:00:00 2001 From: Peder Hovdan Andresen <107681714+pederhan@users.noreply.github.com> Date: Wed, 29 May 2024 14:02:51 +0200 Subject: [PATCH] Validate GET responses with Pydantic (#241) * Validate responses with Pydantic * Specify part of statement `type: ignore` applies to * Replace all `get_list` calls with `get_typed` * Use `TypeAliasType` from typing-extensions * Revert accidental find and replace change * HACK: Use `cast()` in `get_list_unique()` * Add source url comment to recursive JSON type * Add helper functions for response validation * Remove `ResponseLike` * Raise MREG ValidationError on failed validation * Validate response in `get_list_unique` * Improve exception handling for failed validation * Add `HistoryResource` from #232 Also directly validates results to `HistoryItem` instead of going through a dict. --- mreg_cli/api/abstracts.py | 21 ++--- mreg_cli/api/history.py | 71 ++++++++-------- mreg_cli/api/models.py | 54 ++++++------ mreg_cli/types.py | 66 +++++++-------- mreg_cli/utilities/api.py | 170 +++++++++++++++++++++++++------------- 5 files changed, 212 insertions(+), 170 deletions(-) diff --git a/mreg_cli/api/abstracts.py b/mreg_cli/api/abstracts.py index fb17337a..1dd7da74 100644 --- a/mreg_cli/api/abstracts.py +++ b/mreg_cli/api/abstracts.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Mapping, Self, cast +from typing import Any, Callable, Self, cast from pydantic import AliasChoices, BaseModel from pydantic.fields import FieldInfo @@ -19,13 +19,13 @@ PatchError, ) from mreg_cli.outputmanager import OutputManager -from mreg_cli.types import JSONMapping +from mreg_cli.types import JsonMapping from mreg_cli.utilities.api import ( delete, get, get_item_by_key_value, - get_list, get_list_unique, + get_typed, patch, post, ) @@ -68,7 +68,7 @@ def validate_patched_model(model: BaseModel, fields: dict[str, Any]) -> None: """Validate that model fields were patched correctly.""" aliases = get_model_aliases(model) - validators = { + validators: dict[type, Callable[[Any, Any], bool]] = { list: _validate_lists, dict: _validate_dicts, } @@ -83,7 +83,10 @@ def validate_patched_model(model: BaseModel, fields: dict[str, Any]) -> None: raise PatchError(f"Could not get value for {field_name} in patched object.") from e # Ensure patched value is the one we tried to set - validator = validators.get(type(nval), _validate_default) + validator = validators.get( + type(nval), # type: ignore # dict.get call with unknown type (Any) is fine + _validate_default, + ) if not validator(nval, value): raise PatchError( f"Patch failure! Tried to set {key} to {value}, but server returned {nval}." @@ -318,8 +321,7 @@ def get_list_by_field( if ordering: params["ordering"] = ordering - data = get_list(cls.endpoint(), params=params, limit=limit) - return [cls(**item) for item in data] + return get_typed(cls.endpoint(), list[cls], params=params, limit=limit) @classmethod def get_by_query( @@ -336,8 +338,7 @@ def get_by_query( if ordering: query["ordering"] = ordering - data = get_list(cls.endpoint().with_query(query), limit=limit) - return [cls(**item) for item in data] + return get_typed(cls.endpoint().with_query(query), list[cls], limit=limit) @classmethod def get_by_query_unique_or_raise( @@ -468,7 +469,7 @@ def delete(self) -> bool: return False @classmethod - def create(cls, params: JSONMapping, fetch_after_create: bool = True) -> Self | None: + def create(cls, params: JsonMapping, fetch_after_create: bool = True) -> Self | None: """Create the object. Note that several endpoints do not support location headers for created objects, diff --git a/mreg_cli/api/history.py b/mreg_cli/api/history.py index 8c6f14e2..b45792bc 100644 --- a/mreg_cli/api/history.py +++ b/mreg_cli/api/history.py @@ -5,36 +5,47 @@ import datetime import json from enum import Enum -from typing import Any +from typing import Any, Self from pydantic import BaseModel, Field, field_validator from mreg_cli.api.endpoints import Endpoint from mreg_cli.exceptions import EntityNotFound, InternalError from mreg_cli.outputmanager import OutputManager -from mreg_cli.utilities.api import get_list +from mreg_cli.utilities.api import get_typed class HistoryResource(str, Enum): - """History resources.""" + """History resources for the API. - Host = "host" - Group = "group" - HostPolicy_Role = "hostpolicy_role" - HostPolicy_Atom = "hostpolicy_atom" + Names represent resource names. + Values represent resource relations. + + Access resource names and relation with the `resource()` and `relation()` methods. + """ + + Host = "hosts" + Group = "groups" + HostPolicy_Role = "roles" + HostPolicy_Atom = "atoms" + + @classmethod + def _missing_(cls, value: Any) -> HistoryResource: + v = str(value).lower() + for resource in cls: + if resource.value == v: + return resource + elif resource.name.lower() == v: + return resource + raise ValueError(f"Unknown resource {value}") def relation(self) -> str: - """Provide the relation of the resource.""" - if self == HistoryResource.Host: - return "hosts" - if self == HistoryResource.Group: - return "groups" - if self == HistoryResource.HostPolicy_Role: - return "roles" - if self == HistoryResource.HostPolicy_Atom: - return "atoms" + """Get the resource relation.""" + return self.value - raise ValueError(f"Unknown resource {self}") + def resource(self) -> str: + """Get the resource name.""" + return self.name.lower() class HistoryItem(BaseModel): @@ -110,32 +121,24 @@ def output_multiple(cls, basename: str, items: list[HistoryItem]) -> None: item.output(basename) @classmethod - def get(cls, name: str, resource: HistoryResource) -> list[HistoryItem]: + def get(cls, name: str, resource: HistoryResource) -> list[Self]: """Get history items for a resource.""" - resource_value = resource.value - - params: dict[str, str | int] = {"resource": resource_value, "name": name} - - ret = get_list(Endpoint.History, params=params) - + params: dict[str, str] = {"resource": resource.resource(), "name": name} + ret = get_typed(Endpoint.History, list[cls], params=params) if len(ret) == 0: raise EntityNotFound(f"No history found for {name}") - model_ids = ",".join({str(i["model_id"]) for i in ret}) - + model_ids = ",".join({str(i.mid) for i in ret}) params = { - "resource": resource_value, + "resource": resource.resource(), "model_id__in": model_ids, } - - ret = get_list(Endpoint.History, params=params) - - data_relation = resource.relation() + ret = get_typed(Endpoint.History, list[cls], params=params) params = { - "data__relation": data_relation, + "data__relation": resource.relation(), "data__id__in": model_ids, } - ret.extend(get_list(Endpoint.History, params=params)) + ret.extend(get_typed(Endpoint.History, list[cls], params=params)) - return [cls(**i) for i in ret] + return ret diff --git a/mreg_cli/api/models.py b/mreg_cli/api/models.py index 04e6693d..69a8f0a9 100644 --- a/mreg_cli/api/models.py +++ b/mreg_cli/api/models.py @@ -47,7 +47,6 @@ delete, get, get_item_by_key_value, - get_list, get_list_in, get_list_unique, get_typed, @@ -196,7 +195,7 @@ def resolve_host(self) -> Host | None: if not data: return None - return Host(**data) + return Host.model_validate(data) class WithZone(BaseModel, APIMixin): @@ -218,7 +217,7 @@ def resolve_zone(self) -> ForwardZone | None: if not data: return None - return ForwardZone(**data) + return ForwardZone.model_validate(data) class WithTTL(BaseModel): @@ -331,8 +330,7 @@ def get_list_by_name_regex(cls, name: str) -> list[Self]: :returns: A list of resource objects. """ param, value = convert_wildcard_to_regex(cls.__name_field__, name, True) - data = get_list(cls.endpoint(), params={param: value}) - return [cls(**item) for item in data] + return get_typed(cls.endpoint(), list[cls], params={param: value}) def rename(self, new_name: str) -> Self: """Rename the resource. @@ -575,8 +573,7 @@ def get_list(cls) -> list[Self]: :returns: A list of all zones. """ - data = get_list(cls.endpoint()) - return [cls(**item) for item in data] + return get_typed(cls.endpoint(), list[cls]) def ensure_delegation_in_zone(self, name: str) -> None: """Ensure a delegation is in the zone. @@ -826,8 +823,7 @@ def get_delegations(self) -> list[ForwardZoneDelegation | ReverseZoneDelegation] :returns: The delegation object. """ cls = Delegation.type_by_zone(self) - data = get_list(cls.endpoint().with_params(self.name)) - return [cls.model_validate(d) for d in data] + return get_typed(cls.endpoint().with_params(self.name), list[cls]) def delete_delegation(self, name: str) -> bool: """Delete a delegation from the zone. @@ -907,10 +903,10 @@ def get_from_hostname(cls, hostname: HostT) -> ForwardZoneDelegation | ForwardZo zoneblob = data.json() if "delegate" in zoneblob: - return ForwardZoneDelegation(**zoneblob) + return ForwardZoneDelegation.model_validate(zoneblob) if "zone" in zoneblob: - return ForwardZone(**zoneblob["zone"]) + return ForwardZone.model_validate(zoneblob["zone"]) raise UnexpectedDataError(f"Unexpected response from server: {zoneblob}") @@ -1219,14 +1215,13 @@ class RoleTableRow(BaseModel): ) @classmethod - def get_roles_with_atom(cls, name: str) -> list[Role]: + def get_roles_with_atom(cls, name: str) -> list[Self]: """Get all roles with a specific atom. :param atom: Name of the atom to search for. :returns: A list of Role objects. """ - data = get_list(cls.endpoint(), params={"atoms__name__exact": name}) - return [cls(**item) for item in data] + return get_typed(cls.endpoint(), list[cls], params={"atoms__name__exact": name}) def add_atom(self, atom_name: str) -> bool: """Add an atom to the role. @@ -1392,16 +1387,15 @@ def endpoint(cls) -> Endpoint: return Endpoint.Labels @classmethod - def get_all(cls) -> list[Label]: + def get_all(cls) -> list[Self]: """Get all labels. :returns: A list of Label objects. """ - data = get_list(cls.endpoint(), params={"ordering": "name"}) - return [cls(**item) for item in data] + return get_typed(cls.endpoint(), list[cls], params={"ordering": "name"}) @classmethod - def get_by_id_or_raise(cls, _id: int) -> Label: + def get_by_id_or_raise(cls, _id: int) -> Self: """Get a Label by ID. :param _id: The Label ID to search for. @@ -1572,8 +1566,7 @@ def get_list(cls) -> list[Self]: :returns: A list of all networks. """ - data = get_list(cls.endpoint(), limit=None) - return [cls(**item) for item in data] + return get_typed(cls.endpoint(), list[cls], limit=None) @staticmethod def str_to_network(network: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: @@ -1906,7 +1899,7 @@ def is_ipv6(self) -> bool: def network(self) -> Network: """Return the network of the IP address.""" data = get(Endpoint.NetworksByIP.with_id(str(self.ip()))) - return Network(**data.json()) + return Network.model_validate(data.json()) def vlan(self) -> int | None: """Return the VLAN of the IP address.""" @@ -2039,7 +2032,7 @@ def get_by_name(cls, name: HostT) -> CNAME: data = get_item_by_key_value(Endpoint.Cnames, "name", name.hostname) if not data: raise EntityNotFound(f"CNAME record for {name} not found.") - return CNAME(**data) + return CNAME.model_validate(data) @classmethod def get_by_host_and_name(cls, host: HostT | int, name: HostT) -> CNAME: @@ -2149,7 +2142,7 @@ def get_by_all(cls, host: int, mx: str, priority: int) -> MX: ) if not data: raise EntityNotFound(f"MX record for {mx} not found.") - return MX(**data) + return MX.model_validate(data) def has_mx_with_priority(self, mx: str, priority: int) -> bool: """Return True if the MX record has the given MX and priority. @@ -2303,7 +2296,7 @@ def output_multiple(cls, srvs: list[Srv], padding: int = 14) -> None: host_ids = {srv.host for srv in srvs} host_data = get_list_in(Endpoint.Hosts, "id", list(host_ids)) - hosts = [Host(**host) for host in host_data] + hosts = [Host.model_validate(host) for host in host_data] host_id_name_map = {host.id: str(host.name) for host in hosts} @@ -2419,7 +2412,7 @@ def endpoint(cls) -> Endpoint: return Endpoint.BacnetID @classmethod - def get_in_range(cls, start: int, end: int) -> list[BacnetID]: + def get_in_range(cls, start: int, end: int) -> list[Self]: """Get Bacnet IDs in a range. :param start: The start of the range. @@ -2427,8 +2420,7 @@ def get_in_range(cls, start: int, end: int) -> list[BacnetID]: :returns: List of BacnetID objects in the range. """ params = {"id__range": f"{start},{end}"} - data = get_list(Endpoint.BacnetID, params=params) - return [BacnetID(**item) for item in data] + return get_typed(Endpoint.BacnetID, list[cls], params=params) @classmethod def output_multiple(cls, bacnetids: list[BacnetID]): @@ -2909,7 +2901,7 @@ def resolve_zone( data_as_dict = data.json() if data_as_dict["zone"]: - zone = ForwardZone(**data_as_dict["zone"]) + zone = ForwardZone.model_validate(data_as_dict["zone"]) if validate_zone_resolution and zone.id != self.zone: raise ValidationError(f"Expected zone ID {self.zone} but resovled as {zone.id}.") return zone @@ -2919,7 +2911,7 @@ def resolve_zone( raise EntityOwnershipMismatch( f"Host {self.name} is delegated to zone {data_as_dict['delegation']['name']}." ) - return ForwardZoneDelegation(**data_as_dict["delegation"]) + return ForwardZoneDelegation.model_validate(data_as_dict["delegation"]) raise EntityNotFound(f"Failed to resolve zone for host {self.name}.") @@ -3108,8 +3100,8 @@ def get(cls, params: dict[str, Any] | None = None) -> HostList: if "ordering" not in params: params["ordering"] = "name" - data = get_list(cls.endpoint(), params=params) - return cls(results=[Host(**host) for host in data]) + hosts = get_typed(cls.endpoint(), list[Host], params=params) + return cls(results=hosts) @classmethod def get_by_ip(cls, ip: IP_AddressT) -> HostList: diff --git a/mreg_cli/types.py b/mreg_cli/types.py index 3c717c46..f2eb6805 100644 --- a/mreg_cli/types.py +++ b/mreg_cli/types.py @@ -6,11 +6,11 @@ import ipaddress from collections.abc import Callable from typing import ( + Annotated, Any, Literal, Mapping, NamedTuple, - Protocol, Sequence, TypeAlias, TypedDict, @@ -18,7 +18,9 @@ Union, ) -from requests.structures import CaseInsensitiveDict +from pydantic import ValidationError, ValidationInfo, ValidatorFunctionWrapHandler, WrapValidator +from pydantic_core import PydanticCustomError +from typing_extensions import TypeAliasType CommandFunc = Callable[[argparse.Namespace], None] @@ -58,9 +60,30 @@ class RecordingEntry(TypedDict): NargsType = int | NargsStr -JSONPrimitive = Union[str, int, float, bool, None] -JSONValue = Union[JSONPrimitive, Mapping[str, "JSONValue"], Sequence["JSONValue"]] -JSONMapping = Mapping[str, JSONValue] +# Source: https://docs.pydantic.dev/2.7/concepts/types/#named-recursive-types +def json_custom_error_validator( + value: Any, handler: ValidatorFunctionWrapHandler, _info: ValidationInfo +) -> Any: + """Simplify the error message to avoid a gross error stemming from + exhaustive checking of all union options. + """ # noqa: D205 + try: + return handler(value) + except ValidationError: + raise PydanticCustomError( + "invalid_json", + "Input is not valid json", + ) from None + + +Json = TypeAliasType( + "Json", + Annotated[ + Union[Mapping[str, "Json"], Sequence["Json"], str, int, float, bool, None], + WrapValidator(json_custom_error_validator), + ], +) +JsonMapping = Mapping[str, Json] class Flag: @@ -104,36 +127,3 @@ class Command(NamedTuple): # Config DefaultType = TypeVar("DefaultType") - - -class ResponseLike(Protocol): - """Interface for objects that resemble a requests.Response object.""" - - @property - def ok(self) -> bool: - """Return True if the response was successful.""" - ... - - @property - def status_code(self) -> int: - """Return the HTTP status code.""" - ... - - @property - def reason(self) -> str: - """Return the HTTP status reason.""" - ... - - @property - def headers(self) -> CaseInsensitiveDict[str]: - """Return the dictionary of response headers.""" - ... - - @property - def text(self) -> str: - """Return the response body as text.""" - ... - - def json(self, **kwargs: Any) -> Any: - """Return the response body as JSON.""" - ... diff --git a/mreg_cli/utilities/api.py b/mreg_cli/utilities/api.py index 2be1720e..1a0a1784 100644 --- a/mreg_cli/utilities/api.py +++ b/mreg_cli/utilities/api.py @@ -11,20 +11,33 @@ import os import re import sys -from typing import Any, Literal, NoReturn, TypeVar, cast, get_origin, overload +from typing import ( + Any, + Literal, + NoReturn, + TypeVar, + cast, + get_origin, + overload, +) from urllib.parse import urljoin from uuid import uuid4 import requests from prompt_toolkit import prompt -from pydantic import TypeAdapter +from pydantic import ( + BaseModel, + TypeAdapter, + field_validator, +) +from requests import Response from mreg_cli.config import MregCliConfig -from mreg_cli.exceptions import CliError, LoginFailedError +from mreg_cli.exceptions import CliError, LoginFailedError, ValidationError from mreg_cli.log import cli_error, cli_warning from mreg_cli.outputmanager import OutputManager from mreg_cli.tokenfile import TokenFile -from mreg_cli.types import ResponseLike +from mreg_cli.types import Json, JsonMapping session = requests.Session() session.headers.update({"User-Agent": "mreg-cli"}) @@ -178,7 +191,7 @@ def auth_and_update_token(username: str, password: str) -> None: TokenFile.set_entry(username, MregCliConfig().get_url(), token) -def result_check(result: ResponseLike, operation_type: str, url: str) -> None: +def result_check(result: Response, operation_type: str, url: str) -> None: """Check the result of a request.""" if not result.ok: message = f'{operation_type} "{url}": {result.status_code}: {result.reason}' @@ -198,7 +211,7 @@ def _request_wrapper( ok404: bool = False, first: bool = True, **data: Any, -) -> ResponseLike | None: +) -> Response | None: """Wrap request calls to MREG for logging and token management.""" if params is None: params = {} @@ -225,24 +238,22 @@ def _request_wrapper( @overload -def get(path: str, params: dict[str, Any] | None, ok404: Literal[True]) -> ResponseLike | None: ... +def get(path: str, params: dict[str, Any] | None, ok404: Literal[True]) -> Response | None: ... @overload -def get(path: str, params: dict[str, Any] | None, ok404: Literal[False]) -> ResponseLike: ... +def get(path: str, params: dict[str, Any] | None, ok404: Literal[False]) -> Response: ... @overload -def get(path: str, params: dict[str, Any] | None = ..., *, ok404: bool) -> ResponseLike | None: ... +def get(path: str, params: dict[str, Any] | None = ..., *, ok404: bool) -> Response | None: ... @overload -def get(path: str, params: dict[str, Any] | None = ...) -> ResponseLike: ... +def get(path: str, params: dict[str, Any] | None = ...) -> Response: ... -def get( - path: str, params: dict[str, Any] | None = None, ok404: bool = False -) -> ResponseLike | None: +def get(path: str, params: dict[str, Any] | None = None, ok404: bool = False) -> Response | None: """Make a standard get request.""" if params is None: params = {} @@ -254,7 +265,7 @@ def get_list( params: dict[str, Any] | None = None, ok404: bool = False, limit: int | None = 500, -) -> list[dict[str, Any]]: +) -> list[Json]: """Make a get request that produces a list. Will iterate over paginated results and return result as list. If the number of hits is @@ -270,12 +281,7 @@ def get_list( :returns: A list of dictionaries. """ - ret = get_list_generic(path, params, ok404, limit, expect_one_result=False) - - if not isinstance(ret, list): - raise CliError(f"Expected a list of results, got {type(ret)}.") - - return ret + return get_list_generic(path, params, ok404, limit, expect_one_result=False) def get_list_in( @@ -283,7 +289,7 @@ def get_list_in( search_field: str, search_values: list[int], ok404: bool = False, -) -> list[dict[str, Any]]: +) -> list[Json]: """Get a list of items by a key value pair. :param path: The path to the API endpoint. @@ -305,7 +311,7 @@ def get_item_by_key_value( search_field: str, search_value: str, ok404: bool = False, -) -> None | dict[str, Any]: +) -> None | JsonMapping: """Get an item by a key value pair. :param path: The path to the API endpoint. @@ -324,7 +330,7 @@ def get_list_unique( path: str, params: dict[str, str], ok404: bool = False, -) -> None | dict[str, Any]: +) -> None | JsonMapping: """Do a get request that returns a single result from a search. :param path: The path to the API endpoint. @@ -336,14 +342,69 @@ def get_list_unique( :returns: A single dictionary, or None if no result was found and ok404 is True. """ ret = get_list_generic(path, params, ok404, expect_one_result=True) - - if not isinstance(ret, dict): - raise CliError(f"Expected a single result, got {type(ret)}.") - if not ret: return None - return ret + try: + validator = TypeAdapter(JsonMapping) + return validator.validate_python(ret) + except ValueError as e: + raise ValidationError(f"Failed to validate response from {path}: {e}") from e + + +class PaginatedResponse(BaseModel): + """Paginated response data from the API.""" + + count: int + next: str | None + previous: str | None + results: list[Json] + + @field_validator("count", mode="before") + @classmethod + def _none_count_is_0(cls, v: Any) -> Any: + """Ensure `count` is never `None`.""" + # Django count doesn't seem to be guaranteed to be an integer. + # Ensures here that None is treated as 0. + # https://github.com/django/django/blob/bcbc4b9b8a4a47c8e045b060a9860a5c038192de/django/core/paginator.py#L105-L111 + return v or 0 + + @classmethod + def from_response(cls, response: Response) -> PaginatedResponse: + """Create a PaginatedResponse from a Response.""" + return cls.model_validate_json(response.text) + + +ListResponse = TypeAdapter(list[Json]) +"""JSON list (array) response adapter.""" + + +# TODO: Provide better validation error introspection +def validate_list_response(response: Response) -> list[Json]: + """Parse and validate that a response contains a JSON array. + + :param response: The response to validate. + :raises ValidationError: If the response does not contain a valid JSON array. + :returns: Parsed response data as a list of Python objects. + """ + try: + return ListResponse.validate_json(response.text) + # NOTE: ValueError catches custom Pydantic errors too + except ValueError as e: + raise ValidationError(f"{response.url} did not return a valid JSON array") from e + + +def validate_paginated_response(response: Response) -> PaginatedResponse: + """Validate and parse that a response contains paginated JSON data. + + :param response: The response to validate. + :raises ValidationError: If the response does not contain valid paginated JSON. + :returns: Parsed response data as a PaginatedResponse object. + """ + try: + return PaginatedResponse.from_response(response) + except ValueError as e: + raise ValidationError(f"{response.url} did not return valid paginated JSON") from e @overload @@ -353,7 +414,7 @@ def get_list_generic( ok404: bool = False, limit: int | None = 500, expect_one_result: Literal[True] = True, -) -> dict[str, Any]: ... +) -> Json: ... @overload @@ -363,7 +424,7 @@ def get_list_generic( ok404: bool = False, limit: int | None = 500, expect_one_result: Literal[False] = False, -) -> list[dict[str, Any]]: ... +) -> list[Json]: ... def get_list_generic( @@ -372,7 +433,7 @@ def get_list_generic( ok404: bool = False, limit: int | None = 500, expect_one_result: bool | None = False, -) -> dict[str, Any] | list[dict[str, Any]]: +) -> Json | list[Json]: """Make a get request that produces a list. Will iterate over paginated results and return result as list. If the number of hits is @@ -394,8 +455,8 @@ def get_list_generic( """ def _check_expect_one_result( - ret: list[dict[str, Any]], - ) -> dict[str, Any] | list[dict[str, Any]]: + ret: list[Json], + ) -> Json | list[Json]: if expect_one_result: if len(ret) == 0: return {} @@ -409,39 +470,34 @@ def _check_expect_one_result( if params is None: params = {} - ret: list[dict[str, Any]] = [] + response = get(path, params) - # Get the first page to check the number of hits, and raise an exception if it is too high. - get_params = params.copy() - # get_params["page_size"] = 1 - resp = get(path, get_params).json() + # Non-paginated results, return them directly + if "count" not in response.text: + return validate_list_response(response) - if isinstance(resp, list): - # If list, assume it contains dicts - return cast(list[dict[str, Any]], resp) - elif not isinstance(resp, dict): - raise CliError(f"Expected a dict or list from {path!r}, got {type(resp)!r}.") - else: - resp = cast(dict[str, Any], resp) + resp = validate_paginated_response(response) - if limit and resp.get("count", 0) > abs(limit): - cli_warning(f"Too many hits ({resp['count']}), please refine your search criteria.") + if limit and resp.count > abs(limit): + cli_warning(f"Too many hits ({resp.count}), please refine your search criteria.") # Short circuit if there are no more pages. This means that there are no more results to # be had so we can return the results we already have. - if "next" in resp and not resp["next"]: - return _check_expect_one_result(resp["results"]) + if not resp.next: + return _check_expect_one_result(resp.results) + # Iterate over all pages and collect the results + ret: list[Json] = [] while True: resp = get(path, params=params, ok404=ok404) if resp is None: return _check_expect_one_result(ret) - result = resp.json() + result = validate_paginated_response(resp) - ret.extend(result["results"]) + ret.extend(result.results) - if "next" in result and result["next"]: - path = result["next"] + if result.next: + path = result.next else: return _check_expect_one_result(ret) @@ -462,7 +518,7 @@ def get_typed( :param params: The parameters to pass to the API endpoint. :param limit: The maximum number of hits to allow for paginated responses. - :raises ValidationError: If the response cannot be deserialized into the given type. + :raises pydantic.ValidationError: If the response cannot be deserialized into the given type. :returns: An instance of `type_` populated with data from the response. """ @@ -475,21 +531,21 @@ def get_typed( return adapter.validate_json(resp.text) -def post(path: str, params: dict[str, Any] | None = None, **kwargs: Any) -> ResponseLike | None: +def post(path: str, params: dict[str, Any] | None = None, **kwargs: Any) -> Response | None: """Use requests to make a post request. Assumes that all kwargs are data fields.""" if params is None: params = {} return _request_wrapper("post", path, params=params, **kwargs) -def patch(path: str, params: dict[str, Any] | None = None, **kwargs: Any) -> ResponseLike | None: +def patch(path: str, params: dict[str, Any] | None = None, **kwargs: Any) -> Response | None: """Use requests to make a patch request. Assumes that all kwargs are data fields.""" if params is None: params = {} return _request_wrapper("patch", path, params=params, **kwargs) -def delete(path: str, params: dict[str, Any] | None = None) -> ResponseLike | None: +def delete(path: str, params: dict[str, Any] | None = None) -> Response | None: """Use requests to make a delete request.""" if params is None: params = {}