Skip to content

Commit

Permalink
Improve generic typing further
Browse files Browse the repository at this point in the history
  • Loading branch information
Daverball committed Apr 14, 2024
1 parent 52890d7 commit a65bf26
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 19 deletions.
6 changes: 3 additions & 3 deletions src/itsdangerous/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class _CompactJSON:
"""Wrapper around json module that strips whitespace."""

@staticmethod
def loads(s: str | bytes) -> t.Any:
return _json.loads(s)
def loads(payload: str | bytes) -> t.Any:
return _json.loads(payload)

@staticmethod
def dumps(obj: t.Any, *args: t.Any, **kwargs: t.Any) -> str:
def dumps(obj: t.Any, **kwargs: t.Any) -> str:
kwargs.setdefault("ensure_ascii", False)
kwargs.setdefault("separators", (",", ":"))
return _json.dumps(obj, **kwargs)
101 changes: 86 additions & 15 deletions src/itsdangerous/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,39 @@
from .signer import _make_keys_list
from .signer import Signer


class _PDataSerializer(t.Protocol[t.AnyStr]):
def loads(self, s: t.AnyStr) -> t.Any: ...
def dumps(self, obj: t.Any, *args: t.Any, **kwargs: t.Any) -> t.AnyStr: ...


def is_text_serializer(serializer: _PDataSerializer[t.Any]) -> bool:
if t.TYPE_CHECKING:
import typing_extensions as te

# Usually we want this to either be str or bytes, but to avoid users having
# to manually set the bound type, we want to fall back to the old behavior
# of returning a union of both possibilities if structural matching fails
_SerializedT = te.TypeVar(
"_SerializedT", bound=t.Union[str, bytes], default=t.Union[str, bytes]
)
else:
# so this TypeVar is still available at runtime, albeit without the extra argument
_SerializedT = t.TypeVar("_SerializedT", bound=t.Union[str, bytes])


class _PDataSerializer(t.Protocol[_SerializedT]):
def loads(self, payload: _SerializedT, /) -> t.Any: ...
# we would like to use a gradual signature here, so we can support serializers
# with dumps methods that have additional required keyword arguments, in the
# meantime we provide a fallback overload for serializers that don't quite match
# this more strict Protocol
def dumps(self, obj: t.Any, /) -> _SerializedT: ...


# we want to replace this with te.TypeIs, as soon as it's well supported and part
# of typing_extensions.
def is_text_serializer(
serializer: _PDataSerializer[t.Any],
) -> te.TypeGuard[_PDataSerializer[str]]:
"""Checks whether a serializer generates text or binary."""
return isinstance(serializer.dumps({}), str)


class Serializer(t.Generic[t.AnyStr]):
class Serializer(t.Generic[_SerializedT]):
"""A serializer wraps a :class:`~itsdangerous.signer.Signer` to
enable serializing and securely signing data other than bytes. It
can unsign to verify that the data hasn't been changed.
Expand Down Expand Up @@ -76,7 +97,7 @@ class Serializer(t.Generic[t.AnyStr]):
#: The default serialization module to use to serialize data to a
#: string internally. The default is :mod:`json`, but can be changed
#: to any object that provides ``dumps`` and ``loads`` methods.
default_serializer: _PDataSerializer[t.Any] = json # pyright: ignore
default_serializer: _PDataSerializer[t.Any] = json

#: The default ``Signer`` class to instantiate when signing data.
#: The default is :class:`itsdangerous.signer.Signer`.
Expand All @@ -94,7 +115,56 @@ def __init__(
self: Serializer[str],
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
salt: str | bytes | None = b"itsdangerous",
serializer: None = None,
serializer: None | _PDataSerializer[str] = None,
serializer_kwargs: dict[str, t.Any] | None = None,
signer: type[Signer] | None = None,
signer_kwargs: dict[str, t.Any] | None = None,
fallback_signers: list[
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
]
| None = None,
): ...

@t.overload
def __init__(
self: Serializer[bytes],
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
salt: str | bytes | None,
serializer: _PDataSerializer[bytes],
serializer_kwargs: dict[str, t.Any] | None = None,
signer: type[Signer] | None = None,
signer_kwargs: dict[str, t.Any] | None = None,
fallback_signers: list[
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
]
| None = None,
): ...

@t.overload
def __init__(
self: Serializer[bytes],
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
salt: str | bytes | None = b"itsdangerous",
*,
serializer: _PDataSerializer[bytes],
serializer_kwargs: dict[str, t.Any] | None = None,
signer: type[Signer] | None = None,
signer_kwargs: dict[str, t.Any] | None = None,
fallback_signers: list[
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
]
| None = None,
): ...

# if structural matching for _PDataSerializer fails we're currently more lenient
# and rely on TypeVar defaults to fall us back to str | bytes. Eventually we may
# want to consider tightening this back up again
@t.overload
def __init__(
self,
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
salt: str | bytes | None,
serializer: t.Any,
serializer_kwargs: dict[str, t.Any] | None = None,
signer: type[Signer] | None = None,
signer_kwargs: dict[str, t.Any] | None = None,
Expand All @@ -106,10 +176,11 @@ def __init__(

@t.overload
def __init__(
self: Serializer[t.AnyStr],
self,
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
salt: str | bytes | None = b"itsdangerous",
serializer: _PDataSerializer[t.AnyStr] = ...,
*,
serializer: t.Any,
serializer_kwargs: dict[str, t.Any] | None = None,
signer: type[Signer] | None = None,
signer_kwargs: dict[str, t.Any] | None = None,
Expand All @@ -123,7 +194,7 @@ def __init__(
self,
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
salt: str | bytes | None = b"itsdangerous",
serializer: _PDataSerializer[t.AnyStr] | None = None,
serializer: t.Any | None = None,
serializer_kwargs: dict[str, t.Any] | None = None,
signer: type[Signer] | None = None,
signer_kwargs: dict[str, t.Any] | None = None,
Expand All @@ -148,7 +219,7 @@ def __init__(
if serializer is None:
serializer = self.default_serializer

self.serializer: _PDataSerializer[t.AnyStr] = serializer
self.serializer: _PDataSerializer[_SerializedT] = serializer
self.is_text_serializer: bool = is_text_serializer(serializer)

if signer is None:
Expand Down Expand Up @@ -238,7 +309,7 @@ def iter_unsigners(self, salt: str | bytes | None = None) -> cabc.Iterator[Signe
for secret_key in self.secret_keys:
yield fallback(secret_key, salt=salt, **kwargs)

def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> t.AnyStr:
def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> _SerializedT:
"""Returns a signed string serialized with the internal
serializer. The return value can be either a byte or unicode
string depending on the format of the internal serializer.
Expand Down
3 changes: 2 additions & 1 deletion src/itsdangerous/timed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .exc import BadSignature
from .exc import BadTimeSignature
from .exc import SignatureExpired
from .serializer import _SerializedT
from .serializer import Serializer
from .signer import Signer

Expand Down Expand Up @@ -166,7 +167,7 @@ def validate(self, signed_value: str | bytes, max_age: int | None = None) -> boo
return False


class TimedSerializer(Serializer[t.AnyStr]):
class TimedSerializer(Serializer[_SerializedT]):
"""Uses :class:`TimestampSigner` instead of the default
:class:`.Signer`.
"""
Expand Down

0 comments on commit a65bf26

Please sign in to comment.