Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type safety of validators #286

Merged
merged 2 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions mreg_cli/api/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MACAddressField(FrozenModel):

address: str

@field_validator("address", mode="before")
@field_validator("address", mode="after")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By using "after", we let Pydantic ensure the value is a string before we proceed with validating and manipulating it.

@classmethod
def validate_and_format_mac(cls, v: str) -> str:
"""Validate and normalize MAC address to 'aa:bb:cc:dd:ee:ff' format.
Expand Down Expand Up @@ -49,8 +49,13 @@ class IPAddressField(FrozenModel):

@classmethod
def from_string(cls, address: str) -> IPAddressField:
"""Create an IPAddressField from a string."""
return cls(address=address) # type: ignore # validator handles this
"""Create an IPAddressField from a string.

Shortcut for creating an IPAddressField from a string,
without having to convince the type checker that we can
pass in a string to the address field each time.
"""
return cls(address=address) # pyright: ignore[reportArgumentType] # validator handles this

@field_validator("address", mode="before")
@classmethod
Expand All @@ -69,11 +74,13 @@ def is_ipv6(self) -> bool:
"""Check if the IP address is IPv6."""
return isinstance(self.address, ipaddress.IPv6Address)

@staticmethod
def is_valid(value: str) -> bool:
"""Check if the value is a valid IP address."""
try:
ipaddress.ip_address(value)
return True
except:
except ValueError:
return False

def __str__(self) -> str:
Expand Down
14 changes: 8 additions & 6 deletions mreg_cli/api/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ class HistoryItem(BaseModel):
data: dict[str, Any]

@field_validator("data", mode="before")
def parse_json_data(cls, v: Any) -> dict[str, Any]:
"""Ensure that data is always treated as a dictionary."""
def parse_json_data(cls, v: Any) -> Any:
"""Ensure that non-dict values are treated as JSON."""
if isinstance(v, dict):
return v # type: ignore
else:
return v # pyright: ignore[reportUnknownVariableType]
try:
return json.loads(v)
except json.JSONDecodeError as e:
raise ValueError("Failed to parse history data as JSON") from e
Comment on lines +69 to +73
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pydantic will not handle failed JSON parsing in the validator gracefully, so we need to guard it. We should probably log this somehow.


def clean_timestamp(self) -> str:
"""Clean up the timestamp for output."""
Expand All @@ -89,8 +91,8 @@ def msg(self, basename: str) -> str:
rel = self.data["relation"][:-1]
cls = str(self.resource)
if "." in cls:
cls = cls[cls.rindex(".")+1:]
cls = cls.replace("HostPolicy_","")
cls = cls[cls.rindex(".") + 1 :]
cls = cls.replace("HostPolicy_", "")
cls = cls.lower()
msg = f"{rel} {self.data['name']} {direction} {cls} {self.name}"
elif action == "create":
Expand Down
37 changes: 14 additions & 23 deletions mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ class Permission(FrozenModelWithTimestamps, APIMixin):

@field_validator("range", mode="before")
@classmethod
def validate_ip_or_network(cls, value: str) -> IP_NetworkT:
def validate_ip_or_network(cls, value: Any) -> IP_NetworkT:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not guaranteed to be a string with mode="before"

"""Validate and convert the input to a network."""
try:
return ipaddress.ip_network(value)
Expand Down Expand Up @@ -1862,21 +1862,19 @@ class IPAddress(FrozenModelWithTimestamps, WithHost, APIMixin):

@field_validator("macaddress", mode="before")
@classmethod
def create_valid_macadress_or_none(cls, v: str) -> MACAddressField | None:
def create_valid_macadress_or_none(cls, v: Any) -> MACAddressField | None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not guaranteed to be a string with mode="before"

"""Create macaddress or convert empty strings to None."""
if v:
return MACAddressField(address=v)

return None

@model_validator(mode="before")
@field_validator("ipaddress", mode="before")
@classmethod
def convert_ip_address(cls, values: Any):
"""Convert ipaddress string to IPAddressField if necessary."""
ip_address = values.get("ipaddress")
if isinstance(ip_address, str):
values["ipaddress"] = {"address": ip_address}
return values
def create_valid_ipaddress(cls, v: Any) -> IPAddressField:
"""Create macaddress or convert empty strings to None."""
if isinstance(v, str):
return IPAddressField.from_string(v)
return v # let Pydantic handle it
Comment on lines 1869 to +1876
Copy link
Member Author

@pederhan pederhan Jul 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A field validator was a better fit here, since we can delegate the validation to IPAddressField.from_string() directly instead of constructing an intermediate dict representation of the address.


@classmethod
def get_by_ip(cls, ip: IP_AddressT) -> list[Self]:
Expand Down Expand Up @@ -2039,7 +2037,7 @@ class CNAME(FrozenModelWithTimestamps, WithHost, WithZone, WithTTL, APIMixin):

@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value: str) -> HostT:
def validate_name(cls, value: Any) -> HostT:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not guaranteed to be a string with mode="before"

"""Validate the hostname."""
return HostT(hostname=value)

Expand Down Expand Up @@ -2502,17 +2500,16 @@ class Host(FrozenModelWithTimestamps, WithTTL, WithHistory, APIMixin):

@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value: str) -> HostT:
def validate_name(cls, value: Any) -> HostT:
"""Validate the hostname."""
return HostT(hostname=value)

@field_validator("bacnetid", mode="before")
@classmethod
def convert_bacnetid(cls, v: dict[str, int] | None) -> int | None:
"""Convert json id field to int or None."""
if v and "id" in v:
return v["id"]

def convert_bacnetid(cls, v: Any) -> Any:
"""Use nested ID value in bacnetid value."""
if isinstance(v, dict):
return v.get("id") # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
return None
Comment on lines 2505 to 2512
Copy link
Member Author

@pederhan pederhan Jul 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not guaranteed to be a dict, check first. The validation of the value we retrieve from the dict is left to Pydantic.


@classmethod
Expand Down Expand Up @@ -3133,12 +3130,6 @@ def get_by_ip(cls, ip: IP_AddressT) -> HostList:
"""
return cls.get(params={"ipaddresses__ipaddress": str(ip), "ordering": "name"})

@field_validator("results", mode="before")
@classmethod
def check_results(cls, v: list[dict[str, str]]) -> list[dict[str, str]]:
"""Check that the results are valid."""
return v

Comment on lines -3136 to -3141
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did nothing.

def __len__(self):
"""Return the number of results."""
return len(self.results)
Expand Down
3 changes: 2 additions & 1 deletion mreg_cli/utilities/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,9 @@ class PaginatedResponse(BaseModel):
def _none_count_is_0(cls, v: Any) -> Any:
"""Ensure `count` is never `None`."""
# Django count doesn't seem to be guaranteed to be an integer.
# Ensures here that None is treated as 0.
# https://github.com/django/django/blob/bcbc4b9b8a4a47c8e045b060a9860a5c038192de/django/core/paginator.py#L105-L111
# Theoretically any callable can be passed to the "count" attribute of the paginator.
# Ensures here that None (and any falsey value) is treated as 0.
return v or 0

@classmethod
Expand Down
Loading