From 9923d81c30da1352613f7d07f1de1f33cc0aba9d Mon Sep 17 00:00:00 2001 From: pederhan Date: Thu, 25 Jul 2024 15:23:03 +0200 Subject: [PATCH 1/2] Improve type safety of validators --- mreg_cli/api/fields.py | 15 +++++++++++---- mreg_cli/api/history.py | 14 ++++++++------ mreg_cli/api/models.py | 37 ++++++++++++++----------------------- mreg_cli/utilities/api.py | 3 ++- 4 files changed, 35 insertions(+), 34 deletions(-) diff --git a/mreg_cli/api/fields.py b/mreg_cli/api/fields.py index 35cadf99..4f91763a 100644 --- a/mreg_cli/api/fields.py +++ b/mreg_cli/api/fields.py @@ -20,7 +20,7 @@ class MACAddressField(FrozenModel): address: str - @field_validator("address", mode="before") + @field_validator("address", mode="after") @classmethod def validate_and_format_mac(cls, v: str) -> str: """Validate and normalize MAC address to 'aa:bb:cc:dd:ee:ff' format. @@ -49,8 +49,13 @@ class IPAddressField(FrozenModel): @classmethod def from_string(cls, address: str) -> IPAddressField: - """Create an IPAddressField from a string.""" - return cls(address=address) # type: ignore # validator handles this + """Create an IPAddressField from a string. + + Shortcut for creating an IPAddressField from a string, + without having to convince the type checker that we can + pass in a string to the address field each time. + """ + return cls(address=address) # pyright: ignore[reportArgumentType] # validator handles this @field_validator("address", mode="before") @classmethod @@ -69,11 +74,13 @@ def is_ipv6(self) -> bool: """Check if the IP address is IPv6.""" return isinstance(self.address, ipaddress.IPv6Address) + @staticmethod def is_valid(value: str) -> bool: + """Check if the value is a valid IP address.""" try: ipaddress.ip_address(value) return True - except: + except ValueError: return False def __str__(self) -> str: diff --git a/mreg_cli/api/history.py b/mreg_cli/api/history.py index 52aa23a7..9df97044 100644 --- a/mreg_cli/api/history.py +++ b/mreg_cli/api/history.py @@ -63,12 +63,14 @@ class HistoryItem(BaseModel): data: dict[str, Any] @field_validator("data", mode="before") - def parse_json_data(cls, v: Any) -> dict[str, Any]: - """Ensure that data is always treated as a dictionary.""" + def parse_json_data(cls, v: Any) -> Any: + """Ensure that non-dict values are treated as JSON.""" if isinstance(v, dict): - return v # type: ignore - else: + return v # pyright: ignore[reportUnknownVariableType] + try: return json.loads(v) + except json.JSONDecodeError as e: + raise ValueError("Failed to parse history data as JSON") from e def clean_timestamp(self) -> str: """Clean up the timestamp for output.""" @@ -89,8 +91,8 @@ def msg(self, basename: str) -> str: rel = self.data["relation"][:-1] cls = str(self.resource) if "." in cls: - cls = cls[cls.rindex(".")+1:] - cls = cls.replace("HostPolicy_","") + cls = cls[cls.rindex(".") + 1 :] + cls = cls.replace("HostPolicy_", "") cls = cls.lower() msg = f"{rel} {self.data['name']} {direction} {cls} {self.name}" elif action == "create": diff --git a/mreg_cli/api/models.py b/mreg_cli/api/models.py index 1c2c79ea..04dcf3c6 100644 --- a/mreg_cli/api/models.py +++ b/mreg_cli/api/models.py @@ -397,7 +397,7 @@ class Permission(FrozenModelWithTimestamps, APIMixin): @field_validator("range", mode="before") @classmethod - def validate_ip_or_network(cls, value: str) -> IP_NetworkT: + def validate_ip_or_network(cls, value: Any) -> IP_NetworkT: """Validate and convert the input to a network.""" try: return ipaddress.ip_network(value) @@ -1862,21 +1862,19 @@ class IPAddress(FrozenModelWithTimestamps, WithHost, APIMixin): @field_validator("macaddress", mode="before") @classmethod - def create_valid_macadress_or_none(cls, v: str) -> MACAddressField | None: + def create_valid_macadress_or_none(cls, v: Any) -> MACAddressField | None: """Create macaddress or convert empty strings to None.""" if v: return MACAddressField(address=v) - return None - @model_validator(mode="before") + @field_validator("ipaddress", mode="before") @classmethod - def convert_ip_address(cls, values: Any): - """Convert ipaddress string to IPAddressField if necessary.""" - ip_address = values.get("ipaddress") - if isinstance(ip_address, str): - values["ipaddress"] = {"address": ip_address} - return values + def create_valid_ipaddress(cls, v: Any) -> IPAddressField: + """Create macaddress or convert empty strings to None.""" + if isinstance(v, str): + return IPAddressField.from_string(v) + return v # let Pydantic handle it @classmethod def get_by_ip(cls, ip: IP_AddressT) -> list[Self]: @@ -2039,7 +2037,7 @@ class CNAME(FrozenModelWithTimestamps, WithHost, WithZone, WithTTL, APIMixin): @field_validator("name", mode="before") @classmethod - def validate_name(cls, value: str) -> HostT: + def validate_name(cls, value: Any) -> HostT: """Validate the hostname.""" return HostT(hostname=value) @@ -2502,17 +2500,16 @@ class Host(FrozenModelWithTimestamps, WithTTL, WithHistory, APIMixin): @field_validator("name", mode="before") @classmethod - def validate_name(cls, value: str) -> HostT: + def validate_name(cls, value: Any) -> HostT: """Validate the hostname.""" return HostT(hostname=value) @field_validator("bacnetid", mode="before") @classmethod - def convert_bacnetid(cls, v: dict[str, int] | None) -> int | None: - """Convert json id field to int or None.""" - if v and "id" in v: - return v["id"] - + def convert_bacnetid(cls, v: Any) -> Any: + """Use nested ID value in bacnetid value.""" + if isinstance(v, dict): + return v.get("id") # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] return None @classmethod @@ -3133,12 +3130,6 @@ def get_by_ip(cls, ip: IP_AddressT) -> HostList: """ return cls.get(params={"ipaddresses__ipaddress": str(ip), "ordering": "name"}) - @field_validator("results", mode="before") - @classmethod - def check_results(cls, v: list[dict[str, str]]) -> list[dict[str, str]]: - """Check that the results are valid.""" - return v - def __len__(self): """Return the number of results.""" return len(self.results) diff --git a/mreg_cli/utilities/api.py b/mreg_cli/utilities/api.py index c0bb324d..842d5a24 100644 --- a/mreg_cli/utilities/api.py +++ b/mreg_cli/utilities/api.py @@ -428,8 +428,9 @@ class PaginatedResponse(BaseModel): def _none_count_is_0(cls, v: Any) -> Any: """Ensure `count` is never `None`.""" # Django count doesn't seem to be guaranteed to be an integer. - # Ensures here that None is treated as 0. # https://github.com/django/django/blob/bcbc4b9b8a4a47c8e045b060a9860a5c038192de/django/core/paginator.py#L105-L111 + # Theoretically any callable can be passed to the "count" attribute of the paginator. + # Ensures here that None (and any falsey value) is treated as 0. return v or 0 @classmethod From dcfdfd880710104e005db4d891edfc0ece42689b Mon Sep 17 00:00:00 2001 From: pederhan Date: Fri, 26 Jul 2024 10:32:16 +0200 Subject: [PATCH 2/2] Remove unused import --- mreg_cli/api/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mreg_cli/api/models.py b/mreg_cli/api/models.py index 04dcf3c6..bd9113ba 100644 --- a/mreg_cli/api/models.py +++ b/mreg_cli/api/models.py @@ -14,7 +14,6 @@ Field, computed_field, field_validator, - model_validator, ) from typing_extensions import Unpack