diff --git a/.coveragerc b/.coveragerc index c9cf8c5..f621d2e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -17,7 +17,7 @@ omit = *migrations* exclude_lines = # Lines matching these regexes don't need to be covered # https://coverage.readthedocs.io/en/coverage-5.5/excluding.html?highlight=exclude_lines#advanced-exclusion - + # this is the default but must be explicitly specified since # we are overriding exclude_lines pragma: no cover diff --git a/dev-requirements.txt b/dev-requirements.txt index 375a12e..5240495 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -235,7 +235,9 @@ importlib-metadata==4.6.1 \ importlib-resources==5.2.2 \ --hash=sha256:2480d8e07d1890056cb53c96e3de44fead9c62f2ba949b0f2e4c4345f4afa977 \ --hash=sha256:a65882a4d0fe5fbf702273456ba2ce74fe44892c25e42e057aca526b702a6d4b - # via virtualenv + # via + # -r test-requirements.in + # virtualenv incremental==21.3.0 \ --hash=sha256:02f5de5aff48f6b9f665d99d48bfc7ec03b6e3943210de7cfc88856d755d6f57 \ --hash=sha256:92014aebc6a20b78a8084cdd5645eeaa7f74b8933f70fa3ada2cfbd1e3b54321 @@ -600,6 +602,10 @@ typed-ast==1.4.3 \ # astroid # black # mypy +typeguard==2.12.1 \ + --hash=sha256:c2af8b9bdd7657f4bd27b45336e7930171aead796711bc4cfc99b4731bb9d051 \ + --hash=sha256:cc15ef2704c9909ef9c80e19c62fb8468c01f75aad12f651922acf4dbe822e02 + # via -r requirements.in typing-extensions==3.10.0.0 \ --hash=sha256:0ac0f89795dd19de6b97debb0c6af1c70987fd80a2d62d1958f7e56fcc31b497 \ --hash=sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342 \ diff --git a/docs/reference/fields.rst b/docs/reference/fields.rst new file mode 100644 index 0000000..b0b85de --- /dev/null +++ b/docs/reference/fields.rst @@ -0,0 +1,124 @@ +.. + TODO: figure out proper location for this stuff including making it public + if in desert + + +desert._fields module +===================== + +Tagged Unions +------------- + +Serializing and deserializing data of uncertain type can be tricky. +Cases where the type is hinted with :class:`typing.Union` create such cases. +Some solutions have a list of possible types and use the first one that works. +This can be used in some cases where the data for different types is sufficiently unique as to only work with a single type. +In more general cases you have difficulties where data of one type ends up getting processed by another type that is similar, but not the same. + +In an effort to reduce the heuristics involved in serialization and deserialization an explicit tag can be added to identify the type. +That is the basic feature of this group of utilities related to tagged unions. + +A tag indicating the object's type can be applied in various ways. +Presently three forms are implemented: adjacently tagged, internally tagged, and externally tagged. +Adjacently tagged is the most explicit form and is the recommended default. +You can write your own helper functions to implement your own tagging form if needed and still make use of the rest of the mechanisms implemented here. +A code example follows the forms below and provides related reference. + +- A class definition and bare serialized object for reference + + .. literalinclude:: ../../tests/test_fields.py + :language: python + :start-after: # start cat_class_example + :end-before: # end cat_class_example + + .. literalinclude:: ../../tests/example/untagged.json + :language: json + + +- Adjacently tagged + + .. literalinclude:: ../../tests/example/adjacent.json + :language: json + +- Internally tagged + + .. literalinclude:: ../../tests/example/internal.json + :language: json + +- Externally tagged + + .. literalinclude:: ../../tests/example/external.json + :language: json + +The code below is an actual test from the Desert test suite that provides an example usage of the tools that will be covered in detail below. + +.. literalinclude:: ../../tests/test_fields.py + :language: python + :start-after: # start tagged_union_example + :end-before: # end tagged_union_example + + +Fields +...... + +A :class:`marshmallow.fields.Field` is needed to describe the serialization. +This role is filled by :class:`desert._fields.TaggedUnionField`. +Several helpers at different levels are included to generate field instances that support each of the tagging schemes shown above. +:ref:`Registries ` are used to collect and hold the information needed to make the choices the field needs. +The helpers below create :class:`desert._fields.TaggedUnionField` instances that are backed by the passed registry. + +.. autofunction:: desert._fields.adjacently_tagged_union_from_registry +.. autofunction:: desert._fields.internally_tagged_union_from_registry +.. autofunction:: desert._fields.externally_tagged_union_from_registry + +.. autoclass:: desert._fields.TaggedUnionField + :members: + :undoc-members: + :show-inheritance: + +The fields can be created from :class:`desert._fields.FromObjectProtocol` and :class:`desert._fields.FromTagProtocol` instead of registries, if need. + +.. autofunction:: desert._fields.adjacently_tagged_union +.. autofunction:: desert._fields.internally_tagged_union +.. autofunction:: desert._fields.externally_tagged_union + +.. autoclass:: desert._fields.FromObjectProtocol + :members: __call__ + :undoc-members: + :show-inheritance: + +.. autoclass:: desert._fields.FromTagProtocol + :members: __call__ + :undoc-members: + :show-inheritance: + + +.. _tagged_union_registries: + +Registries +.......... + +Since unions are inherently about handling multiple types, fields that handle unions must be able to make decisions about multiple types. +Registries are not required to leverage other pieces of union support if you are developing their logic yourself. +If you are using the builtin mechanisms then a registry will be needed to define the relationships between tags, fields, and object types. + +.. + TODO: sure seems like the user shouldn't need to call Nested() themselves + +The registry's :meth:`desert._fields.FieldRegistryProtocol.register` method will primarily be used. +As an example, you might register a custom class ``Cat`` by providing a hint of ``Cat``, a tag of ``"cat"``, and a field such as ``marshmallow.fields.Nested(desert.schema(Cat))``. + +.. autoclass:: desert._fields.FieldRegistryProtocol + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: desert._fields.TypeAndHintFieldRegistry + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: desert._fields.HintTagField + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/reference/index.rst b/docs/reference/index.rst index a981732..9f266bc 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -5,3 +5,4 @@ Reference :glob: desert* + fields diff --git a/docs/requirements.txt b/docs/requirements.txt index 99b8ca6..0782025 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,7 @@ sphinx~=4.1 # >= 1.0.0rc1 for https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 sphinx-rtd-theme >= 1.0.0rc1 +# < 0.17 for /home/docs/checkouts/readthedocs.org/user_builds/desert/checkouts/94/src/desert/_fields.py:docstring of desert._fields.externally_tagged_union_from_registry:4: WARNING: circular inclusion in "include" directive: snippets/tag_forms/external.rst < snippets/tag_forms/internal.rst < snippets/tag_forms/adjacent.rst < snippets/tag_forms/external.rst < reference/fields.rst +docutils < 0.17 sphinx-autodoc-typehints -e .[dev] diff --git a/requirements.in b/requirements.in index db05712..11c75b2 100644 --- a/requirements.in +++ b/requirements.in @@ -2,3 +2,5 @@ marshmallow>=3.0 attrs typing_inspect dataclasses; python_version < "3.7" +#https://github.com/Stewori/pytypes/archive/b7271ec654d3553894febc6e0d8ad1b0e1ac570a.zip +typeguard diff --git a/requirements.txt b/requirements.txt index 5c9f441..f2c6f78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,10 @@ mypy-extensions==0.4.3 \ --hash=sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d \ --hash=sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8 # via typing-inspect +typeguard==2.12.1 \ + --hash=sha256:c2af8b9bdd7657f4bd27b45336e7930171aead796711bc4cfc99b4731bb9d051 \ + --hash=sha256:cc15ef2704c9909ef9c80e19c62fb8468c01f75aad12f651922acf4dbe822e02 + # via -r requirements.in typing-extensions==3.10.0.0 \ --hash=sha256:0ac0f89795dd19de6b97debb0c6af1c70987fd80a2d62d1958f7e56fcc31b497 \ --hash=sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342 \ diff --git a/src/desert/_fields.py b/src/desert/_fields.py new file mode 100644 index 0000000..8bb59ae --- /dev/null +++ b/src/desert/_fields.py @@ -0,0 +1,445 @@ +import functools +import typing as t + +import attr +import marshmallow.fields +import typeguard +import typing_extensions +import typing_inspect + +import desert._util +import desert.exceptions + + +T = t.TypeVar("T") + + +# TODO: there must be a better name +@attr.s(frozen=True, auto_attribs=True) +class HintTagField: + """Serializing and deserializing a given piece of data requires a group of + information. A type hint that matches the data to be serialized, a Marshmallow + field that knows how to serialize and deserialize the data, and a string tag to + label the serialized data for deserialization. This is that group... There + must be a better name. + """ + + hint: t.Any + tag: str + field: marshmallow.fields.Field + + +class FieldRegistryProtocol(typing_extensions.Protocol): + """This protocol encourages registries to provide a common interface. The actual + implementation of the mapping from objects to be serialized to their Marshmallow + fields, and likewise from the serialized data, can take any form. + """ + + def register( + self, + hint: t.Any, + tag: str, + field: marshmallow.fields.Field, + ) -> None: + """Inform the registry of the relationship between the passed hint, tag, and + field. + """ + ... + + @property + def from_object(self) -> "FromObjectProtocol": + """This is a funny way of writing that the registry's `.from_object()` method + should satisfy :class:`FromObjectProtocol`. + """ + ... + + @property + def from_tag(self) -> "FromTagProtocol": + """This is a funny way of writing that the registry's `.from_tag()` method + should satisfy :class:`FromTagProtocol`. + """ + ... + + +check_field_registry_protocol = desert._util.ProtocolChecker[FieldRegistryProtocol]() + + +# @attr.s(auto_attribs=True) +# class TypeDictFieldRegistry: +# the_dict: t.Dict[ +# t.Union[type, str], +# HintTagField, +# ] = attr.ib(factory=dict) +# +# def register( +# self, +# hint: t.Any, +# tag: str, +# field: marshmallow.fields.Field, +# ) -> None: +# # TODO: just disabling for now to show more interesting test results +# # if any(key in self.the_dict for key in [cls, tag]): +# # raise Exception() +# +# type_tag_field = HintTagField(hint=hint, tag=tag, field=field) +# +# self.the_dict[hint] = type_tag_field +# self.the_dict[tag] = type_tag_field +# +# # # TODO: this type hinting... doesn't help much as it could return +# # # another cls +# # def __call__(self, tag: str, field: marshmallow.fields) -> t.Callable[[T], T]: +# # return lambda cls: self.register(cls=cls, tag=tag, field=field) +# +# def from_object(self, value: t.Any) -> HintTagField: +# return self.the_dict[type(value)] +# +# def from_tag(self, tag: str) -> HintTagField: +# return self.the_dict[tag] + + +@check_field_registry_protocol +@attr.s(auto_attribs=True) +class TypeAndHintFieldRegistry: + """This registry uses type and type hint checks to decide what field to use for + serialization. The deserialization field is chosen directly from the tag. + """ + + by_tag: t.Dict[str, HintTagField] = attr.ib(factory=dict) + + # TODO: but type bans from-scratch metatypes... and protocols + def register( + self, + hint: t.Any, + tag: str, + field: marshmallow.fields.Field, + ) -> None: + if tag in self.by_tag: + raise desert.exceptions.TagAlreadyRegistered(tag=tag) + + type_tag_field = HintTagField(hint=hint, tag=tag, field=field) + + self.by_tag[tag] = type_tag_field + + def from_object(self, value: object) -> HintTagField: + scores = {} + + # for type_tag_field in self.the_list: + for type_tag_field in self.by_tag.values(): + score = 0 + + # if pytypes.is_of_type(value, type_tag_field.hint): + try: + typeguard.check_type( + argname="", + value=value, + expected_type=type_tag_field.hint, + ) + except TypeError: + pass + else: + score += 2 + + try: + if isinstance(value, type_tag_field.hint): + score += 3 + except TypeError: + pass + + if score > 0: + # Only use this to disambiguate between already selected options such + # as ["a", "b"] matching both t.List[str] and t.Sequence[str]. + # This only works properly on 3.7+. + if type(value) == typing_inspect.get_origin(type_tag_field.hint): + score += 1 + + scores[type_tag_field] = score + + high_score = max(scores.values()) + + if high_score == 0: + raise desert.exceptions.NoMatchingHintFound( + hints=[ttf.hint for ttf in self.by_tag.values()], value=value + ) + + potential = [ttf for ttf, score in scores.items() if score == high_score] + + if len(potential) != 1: + raise desert.exceptions.MultipleMatchingHintsFound( + hints=[ttf.hint for ttf in potential], value=value + ) + + [type_tag_field] = potential + + return type_tag_field + + def from_tag(self, tag: str) -> HintTagField: + return self.by_tag[tag] + + +@attr.s(auto_attribs=True) +class TaggedValue: + tag: str + value: object + + +class FromObjectProtocol(typing_extensions.Protocol): + def __call__(self, value: object) -> HintTagField: + ... + + +class FromTagProtocol(typing_extensions.Protocol): + def __call__(self, tag: str) -> HintTagField: + ... + + +class FromTaggedProtocol(typing_extensions.Protocol): + def __call__(self, item: t.Any) -> TaggedValue: + ... + + +class ToTaggedProtocol(typing_extensions.Protocol): + def __call__(self, tag: str, value: t.Any) -> object: + ... + + +class TaggedUnionField(marshmallow.fields.Field): + """A Marshmallow field to handle unions where the data may not always be of a + single type. Usually this field would not be created directly but rather by + using helper functions to fill out the needed functions in a consistent manner. + + Helpers are provided both to directly create various forms of this field as well + as to create the same from a :class:`FieldRegistryProtocol`. + + - From a registry + + - :func:`adjacently_tagged_union_from_registry` + - :func:`internally_tagged_union_from_registry` + - :func:`externally_tagged_union_from_registry` + + - Direct + + - :func:`adjacently_tagged_union` + - :func:`internally_tagged_union` + - :func:`externally_tagged_union` + """ + + def __init__( + self, + *, + from_object: FromObjectProtocol, + from_tag: FromTagProtocol, + from_tagged: FromTaggedProtocol, + to_tagged: ToTaggedProtocol, + # object results in the super() call complaining about types + # https://github.com/python/mypy/issues/5382 + **kwargs: t.Any, + ) -> None: + super().__init__(**kwargs) + + self.from_object = from_object + self.from_tag = from_tag + self.from_tagged = from_tagged + self.to_tagged = to_tagged + + def _deserialize( + self, + value: object, + attr: t.Optional[str], + data: t.Optional[t.Mapping[str, object]], + # object results in the super() call complaining about types + # https://github.com/python/mypy/issues/5382 + **kwargs: t.Any, + ) -> object: + tagged_value = self.from_tagged(item=value) + + type_tag_field = self.from_tag(tagged_value.tag) + field = type_tag_field.field + + return field.deserialize(tagged_value.value) + + def _serialize( + self, + value: object, + attr: str, + obj: object, + # object results in the super() call complaining about types + # https://github.com/python/mypy/issues/5382 + **kwargs: t.Any, + ) -> object: + type_tag_field = self.from_object(value) + field = type_tag_field.field + tag = type_tag_field.tag + serialized_value = field.serialize(attr, obj) + + return self.to_tagged(tag=tag, value=serialized_value) + + +default_tagged_type_key = "#type" +default_tagged_value_key = "#value" + + +def from_externally_tagged(item: t.Mapping[str, object]) -> TaggedValue: + """Process externally tagged data into a :class:`TaggedValue`.""" + + [[tag, serialized_value]] = item.items() + + return TaggedValue(tag=tag, value=serialized_value) + + +def to_externally_tagged(tag: str, value: object) -> t.Dict[str, object]: + """Process untagged data to the externally tagged form.""" + + return {tag: value} + + +def externally_tagged_union( + from_object: FromObjectProtocol, + from_tag: FromTagProtocol, +) -> TaggedUnionField: + """Create a :class:`TaggedUnionField` that supports the externally tagged form.""" + + # TODO: allow the pass through kwargs to the field + + return TaggedUnionField( + from_object=from_object, + from_tag=from_tag, + from_tagged=from_externally_tagged, + to_tagged=to_externally_tagged, + ) + + +def externally_tagged_union_from_registry( + registry: FieldRegistryProtocol, +) -> TaggedUnionField: + """Use a :class:`FieldRegistryProtocol` to create a :class:`TaggedUnionField` that supports + the externally tagged form. Externally tagged data has the following form. + + .. literalinclude:: ../../tests/example/external.json + :language: json + """ + + return externally_tagged_union( + from_object=registry.from_object, + from_tag=registry.from_tag, + ) + + +def from_internally_tagged(item: t.Mapping[str, object], type_key: str) -> TaggedValue: + """Process internally tagged data into a :class:`TaggedValue`.""" + + # it just kind of has to be a string... + type_string: str = item[type_key] # type: ignore[assignment] + + return TaggedValue( + tag=type_string, + value={k: v for k, v in item.items() if k != type_key}, + ) + + +def to_internally_tagged( + tag: str, + value: t.Mapping[str, object], + type_key: str, +) -> t.Mapping[str, object]: + """Process untagged data to the internally tagged form.""" + + if type_key in value: + raise desert.exceptions.TypeKeyCollision(type_key=type_key, value=value) + + return {type_key: tag, **value} + + +def internally_tagged_union( + from_object: FromObjectProtocol, + from_tag: FromTagProtocol, + type_key: str = default_tagged_type_key, +) -> TaggedUnionField: + """Create a :class:`TaggedUnionField` that supports the internally tagged form.""" + + return TaggedUnionField( + from_object=from_object, + from_tag=from_tag, + from_tagged=functools.partial(from_internally_tagged, type_key=type_key), + to_tagged=functools.partial(to_internally_tagged, type_key=type_key), + ) + + +def internally_tagged_union_from_registry( + registry: FieldRegistryProtocol, + type_key: str = default_tagged_type_key, +) -> TaggedUnionField: + """Use a :class:`FieldRegistryProtocol` to create a :class:`TaggedUnionField` that supports + the internally tagged form. Internally tagged data has the following form. + + .. literalinclude:: ../../tests/example/internal.json + :language: json + """ + + return internally_tagged_union( + from_object=registry.from_object, + from_tag=registry.from_tag, + type_key=type_key, + ) + + +def from_adjacently_tagged( + item: t.Dict[str, object], type_key: str, value_key: str +) -> TaggedValue: + """Process adjacently tagged data into a :class:`TaggedValue`.""" + + tag: str = item.pop(type_key) # type: ignore[assignment] + serialized_value = item.pop(value_key) + + if len(item) > 0: + raise Exception() + + return TaggedValue(tag=tag, value=serialized_value) + + +def to_adjacently_tagged( + tag: str, value: object, type_key: str, value_key: str +) -> t.Dict[str, object]: + """Process untagged data to the adjacently tagged form.""" + + return {type_key: tag, value_key: value} + + +def adjacently_tagged_union( + from_object: FromObjectProtocol, + from_tag: FromTagProtocol, + type_key: str = default_tagged_type_key, + value_key: str = default_tagged_value_key, +) -> TaggedUnionField: + """Create a :class:`TaggedUnionField` that supports the adjacently tagged form.""" + + return TaggedUnionField( + from_object=from_object, + from_tag=from_tag, + from_tagged=functools.partial( + from_adjacently_tagged, type_key=type_key, value_key=value_key + ), + to_tagged=functools.partial( + to_adjacently_tagged, type_key=type_key, value_key=value_key + ), + ) + + +def adjacently_tagged_union_from_registry( + registry: FieldRegistryProtocol, + type_key: str = default_tagged_type_key, + value_key: str = default_tagged_value_key, +) -> TaggedUnionField: + """Use a :class:`FieldRegistryProtocol` to create a :class:`TaggedUnionField` that supports + the adjacently tagged form. Adjacently tagged data has the following form. + + .. literalinclude:: ../../tests/example/adjacent.json + :language: json + """ + + return adjacently_tagged_union( + from_object=registry.from_object, + from_tag=registry.from_tag, + type_key=type_key, + value_key=value_key, + ) diff --git a/src/desert/_util.py b/src/desert/_util.py new file mode 100644 index 0000000..73370b4 --- /dev/null +++ b/src/desert/_util.py @@ -0,0 +1,26 @@ +import typing + + +T = typing.TypeVar("T") + + +class ProtocolChecker(typing.Generic[T]): + """Instances of this class can be used as decorators that will result in type hint + checks to verifying that other classes implement a given protocol. Generally you + would create a single instance where you define each protocol and then use that + instance as the decorator. Note that this usage is, at least in part, due to + Python not supporting type parameter specification in the ``@`` decorator + expression. + .. code-block:: python + import typing + class MyProtocol(typing.Protocol): + def a_method(self): ... + check_my_protocol = qtrio._util.ProtocolChecker[MyProtocol]() + @check_my_protocol + class AClass: + def a_method(self): + return 42092 + """ + + def __call__(self, cls: typing.Type[T]) -> typing.Type[T]: + return cls diff --git a/src/desert/exceptions.py b/src/desert/exceptions.py index 3aac255..6682c9e 100644 --- a/src/desert/exceptions.py +++ b/src/desert/exceptions.py @@ -1,4 +1,5 @@ import dataclasses +import typing as t import attr @@ -7,9 +8,47 @@ class DesertException(Exception): """Top-level exception for desert.""" +class MultipleMatchingHintsFound(DesertException): + """Raised when a union finds multiple hints that equally match the data to be + serialized. + """ + + def __init__(self, hints: t.Any, value: object): + hint_list = ", ".join(str(hint) for hint in hints) + super().__init__( + f"Multiple matching type hints found in union for {value!r}. Candidates: {hint_list}" + ) + + +class NoMatchingHintFound(DesertException): + """Raised when a union is unable to find a valid hint for the data to be + serialized. + """ + + def __init__(self, hints: t.Any, value: object): + hint_list = ", ".join(str(hint) for hint in hints) + super().__init__( + f"No matching type hints found in union for {value!r}. Considered: {hint_list}" + ) + + class NotAnAttrsClassOrDataclass(DesertException): """Raised for dataclass operations on non-dataclasses.""" +class TagAlreadyRegistered(DesertException): + """Raised when registering a tag that has already been registered.""" + + def __init__(self, tag: str): + super().__init__(f"Tag already registered: {tag!r}") + + +class TypeKeyCollision(DesertException): + """Raised when a tag key collides with a data value.""" + + def __init__(self, type_key: str, value: object): + super().__init__(f"Type key {type_key!r} collided with attribute in: {value!r}") + + class UnknownType(DesertException): """Raised for a type with unknown serialization equivalent.""" diff --git a/test-requirements.in b/test-requirements.in index 1b5ab61..2213c53 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -1,5 +1,6 @@ coverage cuvner +importlib_resources marshmallow-enum marshmallow-union pytest diff --git a/test-requirements.txt b/test-requirements.txt index b787a98..33afc4a 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -100,7 +100,6 @@ importlib-metadata==4.6.1 \ --hash=sha256:9f55f560e116f8643ecf2922d9cd3e1c7e8d52e683178fecd9d08f6aa357e11e # via # -r test-requirements.in - # backports.entry-points-selectable # click # pluggy # pytest @@ -109,7 +108,9 @@ importlib-metadata==4.6.1 \ importlib-resources==5.2.2 \ --hash=sha256:2480d8e07d1890056cb53c96e3de44fead9c62f2ba949b0f2e4c4345f4afa977 \ --hash=sha256:a65882a4d0fe5fbf702273456ba2ce74fe44892c25e42e057aca526b702a6d4b - # via virtualenv + # via + # -r test-requirements.in + # virtualenv incremental==21.3.0 \ --hash=sha256:02f5de5aff48f6b9f665d99d48bfc7ec03b6e3943210de7cfc88856d755d6f57 \ --hash=sha256:92014aebc6a20b78a8084cdd5645eeaa7f74b8933f70fa3ada2cfbd1e3b54321 @@ -204,6 +205,10 @@ tox==3.24.0 \ --hash=sha256:67636634df6569e450c4bc18fdfd8b84d7903b3902d5c65416eb6735f3d4afb8 \ --hash=sha256:c990028355f0d0b681e3db9baa89dd9f839a6e999c320029339f6a6b36160591 # via -r test-requirements.in +typeguard==2.12.1 \ + --hash=sha256:c2af8b9bdd7657f4bd27b45336e7930171aead796711bc4cfc99b4731bb9d051 \ + --hash=sha256:cc15ef2704c9909ef9c80e19c62fb8468c01f75aad12f651922acf4dbe822e02 + # via -r requirements.in typing-extensions==3.10.0.0 \ --hash=sha256:0ac0f89795dd19de6b97debb0c6af1c70987fd80a2d62d1958f7e56fcc31b497 \ --hash=sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342 \ diff --git a/tests/example/__init__.py b/tests/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/example/adjacent.json b/tests/example/adjacent.json new file mode 100644 index 0000000..a67c923 --- /dev/null +++ b/tests/example/adjacent.json @@ -0,0 +1,7 @@ +{ + "#type": "cat", + "#value": { + "name": "Max", + "color": "tuxedo" + } +} diff --git a/tests/example/external.json b/tests/example/external.json new file mode 100644 index 0000000..bcf8c5d --- /dev/null +++ b/tests/example/external.json @@ -0,0 +1,6 @@ +{ + "cat": { + "name": "Max", + "color": "tuxedo" + } +} diff --git a/tests/example/internal.json b/tests/example/internal.json new file mode 100644 index 0000000..38d3cd3 --- /dev/null +++ b/tests/example/internal.json @@ -0,0 +1,5 @@ +{ + "#type": "cat", + "name": "Max", + "color": "tuxedo" +} diff --git a/tests/example/untagged.json b/tests/example/untagged.json new file mode 100644 index 0000000..a0707b5 --- /dev/null +++ b/tests/example/untagged.json @@ -0,0 +1,4 @@ +{ + "name": "Max", + "color": "tuxedo" +} diff --git a/tests/test_fields.py b/tests/test_fields.py new file mode 100644 index 0000000..de1e4a7 --- /dev/null +++ b/tests/test_fields.py @@ -0,0 +1,598 @@ +import abc +import collections.abc +import dataclasses +import decimal +import json +import sys +import typing as t + +# https://github.com/pytest-dev/pytest/issues/7469 +import _pytest.fixtures +import attr +import importlib_resources +import marshmallow +import pytest +import typing_extensions + +import desert._fields +import desert.exceptions +import tests.example + + +# TODO: test that field constructor doesn't tromple Field parameters + +_NOTHING = object() + + +@attr.frozen +class ExampleData: + to_serialize: t.Any + serialized: t.Any + deserialized: t.Any + tag: str + field: marshmallow.fields.Field + # TODO: can we be more specific? + hint: t.Any + requires_origin: bool = False + + @classmethod + def build( + cls, + hint: object, + to_serialize: object, + tag: str, + field: marshmallow.fields.Field, + requires_origin: bool = False, + serialized: object = _NOTHING, + deserialized: object = _NOTHING, + ) -> "ExampleData": + if serialized is _NOTHING: + serialized = to_serialize + + if deserialized is _NOTHING: + deserialized = to_serialize + + return cls( + hint=hint, + to_serialize=to_serialize, + serialized=serialized, + deserialized=deserialized, + tag=tag, + field=field, + requires_origin=requires_origin, + ) + + +basic_example_data_list = [ + ExampleData.build( + hint=float, + to_serialize=3.7, + tag="float_tag", + field=marshmallow.fields.Float(), + ), + ExampleData.build( + hint=str, + to_serialize="29", + tag="str_tag", + field=marshmallow.fields.String(), + ), + ExampleData.build( + hint=decimal.Decimal, + to_serialize=decimal.Decimal("4.2"), + serialized="4.2", + tag="decimal_tag", + field=marshmallow.fields.Decimal(as_string=True), + ), + ExampleData.build( + hint=t.List[int], + to_serialize=[1, 2, 3], + tag="integer_list_tag", + field=marshmallow.fields.List(marshmallow.fields.Integer()), + ), + ExampleData.build( + hint=t.List[str], + to_serialize=["abc", "2", "mno"], + tag="string_list_tag", + field=marshmallow.fields.List(marshmallow.fields.String()), + requires_origin=True, + ), + ExampleData.build( + hint=t.Sequence[str], + to_serialize=("def", "13"), + serialized=["def", "13"], + deserialized=["def", "13"], + tag="string_sequence_tag", + field=marshmallow.fields.List(marshmallow.fields.String()), + ), +] + + +@attr.frozen +class CustomExampleClass: + a: int + b: str + + +custom_example_data_list = [ + ExampleData.build( + hint=CustomExampleClass, + to_serialize=CustomExampleClass(a=1, b="b"), + serialized={"a": 1, "b": "b"}, + tag="custom_example_class", + field=marshmallow.fields.Nested(desert.schema(CustomExampleClass)), + ), +] + + +all_example_data_list = basic_example_data_list + custom_example_data_list + + +@pytest.fixture( + name="example_data", + params=all_example_data_list, + ids=[str(example) for example in all_example_data_list], +) +def _example_data(request: _pytest.fixtures.SubRequest) -> ExampleData: + return request.param # type: ignore[no-any-return] + + +@pytest.fixture( + name="custom_example_data", + params=custom_example_data_list, + ids=[str(example) for example in custom_example_data_list], +) +def _custom_example_data(request: _pytest.fixtures.SubRequest) -> ExampleData: + return request.param # type: ignore[no-any-return] + + +# def build_type_dict_registry(examples): +# registry = desert._fields.TypeDictFieldRegistry() +# +# for example in examples: +# registry.register( +# hint=example.hint, +# tag=example.tag, +# field=example.field, +# ) +# +# return registry + + +# class NonStringSequence(abc.ABC): +# @classmethod +# def __subclasshook__(cls, maybe_subclass): +# return isinstance(maybe_subclass, collections.abc.Sequence) and not isinstance( +# maybe_subclass, str +# ) + + +def build_order_isinstance_registry( + examples: t.List[ExampleData], +) -> desert._fields.TypeAndHintFieldRegistry: + registry = desert._fields.TypeAndHintFieldRegistry() + + # registry.register( + # hint=t.List[], + # tag="sequence_abc", + # field=marshmallow.fields.List(marshmallow.fields.String()), + # ) + + for example in examples: + registry.register( + hint=example.hint, + tag=example.tag, + field=example.field, + ) + + return registry + + +registries = [ + # build_type_dict_registry(example_data_list), + build_order_isinstance_registry(all_example_data_list), +] +registry_ids = [type(registry).__name__ for registry in registries] + + +@pytest.fixture( + name="registry", + params=registries, + ids=registry_ids, +) +def _registry( + request: _pytest.fixtures.SubRequest, +) -> desert._fields.TypeAndHintFieldRegistry: + return request.param # type: ignore[no-any-return] + + +def test_registry_raises_for_no_match( + registry: desert._fields.FieldRegistryProtocol, +) -> None: + class C: + pass + + c = C() + + with pytest.raises(desert.exceptions.NoMatchingHintFound): + registry.from_object(value=c) + + +def test_registry_raises_for_multiple_matches() -> None: + registry = desert._fields.TypeAndHintFieldRegistry() + + registry.register( + hint=t.Sequence, + tag="sequence", + field=marshmallow.fields.List(marshmallow.fields.Field()), + ) + + registry.register( + hint=t.Collection, + tag="collection", + field=marshmallow.fields.List(marshmallow.fields.Field()), + ) + + with pytest.raises(desert.exceptions.MultipleMatchingHintsFound): + registry.from_object(value=[]) + + +@pytest.fixture(name="externally_tagged_field", params=[False, True]) +def _externally_tagged_field( + request: _pytest.fixtures.SubRequest, + registry: desert._fields.FieldRegistryProtocol, +) -> desert._fields.TaggedUnionField: + field: desert._fields.TaggedUnionField + + if request.param: + field = desert._fields.externally_tagged_union_from_registry(registry=registry) + else: + field = desert._fields.externally_tagged_union( + from_object=registry.from_object, + from_tag=registry.from_tag, + ) + + return field + + +def test_externally_tagged_deserialize( + example_data: ExampleData, externally_tagged_field: desert._fields.TaggedUnionField +) -> None: + serialized = {example_data.tag: example_data.serialized} + + deserialized = externally_tagged_field.deserialize(serialized) + + expected = example_data.deserialized + + assert (type(deserialized) == type(expected)) and (deserialized == expected) + + +def test_externally_tagged_deserialize_extra_key_raises( + example_data: ExampleData, + externally_tagged_field: desert._fields.TaggedUnionField, +) -> None: + serialized = { + example_data.tag: { + "#value": example_data.serialized, + "extra": 29, + }, + } + + with pytest.raises(expected_exception=Exception): + externally_tagged_field.deserialize(serialized) + + +def test_externally_tagged_serialize( + example_data: ExampleData, + externally_tagged_field: desert._fields.TaggedUnionField, +) -> None: + if example_data.requires_origin and sys.version_info < (3, 7): + pytest.xfail() + + obj = {"key": example_data.to_serialize} + + serialized = externally_tagged_field.serialize("key", obj) + + assert serialized == {example_data.tag: example_data.serialized} + + +@pytest.fixture(name="internally_tagged_field", params=[False, True]) +def _internally_tagged_field( + request: _pytest.fixtures.SubRequest, + registry: desert._fields.FieldRegistryProtocol, +) -> desert._fields.TaggedUnionField: + field: desert._fields.TaggedUnionField + + if request.param: + field = desert._fields.internally_tagged_union_from_registry(registry=registry) + else: + field = desert._fields.internally_tagged_union( + from_object=registry.from_object, + from_tag=registry.from_tag, + ) + + return field + + +def test_to_internally_tagged_raises_for_tag_collision() -> None: + with pytest.raises(desert.exceptions.TypeKeyCollision): + desert._fields.to_internally_tagged( + tag="C", value={"collide": True}, type_key="collide" + ) + + +def test_internally_tagged_deserialize( + custom_example_data: ExampleData, + internally_tagged_field: desert._fields.TaggedUnionField, +) -> None: + serialized = {"#type": custom_example_data.tag, **custom_example_data.serialized} + + deserialized = internally_tagged_field.deserialize(serialized) + + expected = custom_example_data.deserialized + + assert (type(deserialized) == type(expected)) and (deserialized == expected) + + +def test_internally_tagged_serialize( + custom_example_data: ExampleData, + internally_tagged_field: desert._fields.TaggedUnionField, +) -> None: + obj = {"key": custom_example_data.to_serialize} + + serialized = internally_tagged_field.serialize("key", obj) + + assert serialized == { + "#type": custom_example_data.tag, + **custom_example_data.serialized, + } + + +@pytest.fixture(name="adjacently_tagged_field", params=[False, True]) +def _adjacently_tagged_field( + request: _pytest.fixtures.SubRequest, + registry: desert._fields.FieldRegistryProtocol, +) -> desert._fields.TaggedUnionField: + field: desert._fields.TaggedUnionField + + if request.param: + field = desert._fields.adjacently_tagged_union_from_registry(registry=registry) + else: + field = desert._fields.adjacently_tagged_union( + from_object=registry.from_object, + from_tag=registry.from_tag, + ) + + return field + + +def test_adjacently_tagged_deserialize( + example_data: ExampleData, + adjacently_tagged_field: desert._fields.TaggedUnionField, +) -> None: + serialized = {"#type": example_data.tag, "#value": example_data.serialized} + + deserialized = adjacently_tagged_field.deserialize(serialized) + + expected = example_data.deserialized + + assert (type(deserialized) == type(expected)) and (deserialized == expected) + + +def test_adjacently_tagged_deserialize_extra_key_raises( + example_data: ExampleData, + adjacently_tagged_field: desert._fields.TaggedUnionField, +) -> None: + serialized = { + "#type": example_data.tag, + "#value": example_data.serialized, + "extra": 29, + } + + with pytest.raises(expected_exception=Exception): + adjacently_tagged_field.deserialize(serialized) + + +def test_adjacently_tagged_serialize( + example_data: ExampleData, + adjacently_tagged_field: desert._fields.TaggedUnionField, +) -> None: + if example_data.requires_origin and sys.version_info < (3, 7): + pytest.xfail() + + obj = {"key": example_data.to_serialize} + + serialized = adjacently_tagged_field.serialize("key", obj) + + assert serialized == {"#type": example_data.tag, "#value": example_data.serialized} + + +@pytest.mark.parametrize( + argnames=["type_string", "value"], argvalues=[["str", "3"], ["int", 7]] +) +def test_actual_example(type_string: str, value: t.Union[int, str]) -> None: + registry = desert._fields.TypeAndHintFieldRegistry() + registry.register(hint=str, tag="str", field=marshmallow.fields.String()) + registry.register(hint=int, tag="int", field=marshmallow.fields.Integer()) + + field = desert._fields.adjacently_tagged_union_from_registry(registry=registry) + + @attr.frozen + class C: + # TODO: desert.ib() shouldn't be needed for many cases + union: t.Union[str, int] = desert.ib(marshmallow_field=field) + + schema = desert.schema(C) + + objects = C(union=value) + marshalled = {"union": {"#type": type_string, "#value": value}} + serialized = json.dumps(marshalled) + + assert schema.dumps(objects) == serialized + assert schema.loads(serialized) == objects + + +def test_raises_for_tag_reregistration() -> None: + registry = desert._fields.TypeAndHintFieldRegistry() + registry.register(hint=str, tag="duplicate_tag", field=marshmallow.fields.String()) + + with pytest.raises(desert.exceptions.TagAlreadyRegistered): + registry.register( + hint=int, tag="duplicate_tag", field=marshmallow.fields.Integer() + ) + + +# start cat_class_example +@dataclasses.dataclass +class Cat: + name: str + color: str + # end cat_class_example + + +def test_untagged_serializes_like_snippet() -> None: + cat = Cat(name="Max", color="tuxedo") + + reference = importlib_resources.read_text(tests.example, "untagged.json").strip() + + schema = desert.schema(Cat, meta={"ordered": True}) + dumped = schema.dumps(cat, indent=4) + + assert dumped == reference + + +# Marshmallow fields expect to serialize an attribute, not an object directly. +# This class gives us somewhere to stick the object of interest to make the field +# happy. +@attr.frozen +class CatCarrier: + an_object: Cat + + +class FromRegistryProtocol(typing_extensions.Protocol): + def __call__( + self, registry: desert._fields.FieldRegistryProtocol + ) -> desert._fields.TaggedUnionField: + ... + + +@attr.frozen +class ResourceAndRegistryFunction: + resource_name: str + from_registry_function: FromRegistryProtocol + + +@pytest.fixture( + name="resource_and_registry_function", + params=[ + ResourceAndRegistryFunction( + resource_name="adjacent.json", + from_registry_function=desert._fields.adjacently_tagged_union_from_registry, + ), + ResourceAndRegistryFunction( + resource_name="internal.json", + from_registry_function=desert._fields.internally_tagged_union_from_registry, + ), + ResourceAndRegistryFunction( + resource_name="external.json", + from_registry_function=desert._fields.externally_tagged_union_from_registry, + ), + ], +) +def resource_and_registry_function_fixture( + request: _pytest.fixtures.SubRequest, +) -> ResourceAndRegistryFunction: + return request.param # type: ignore[no-any-return] + + +def test_tagged_serializes_like_snippet( + resource_and_registry_function: ResourceAndRegistryFunction, +) -> None: + cat = Cat(name="Max", color="tuxedo") + + registry = desert._fields.TypeAndHintFieldRegistry() + registry.register( + hint=Cat, + tag="cat", + field=marshmallow.fields.Nested(desert.schema(Cat, meta={"ordered": True})), + ) + + reference = importlib_resources.read_text( + tests.example, resource_and_registry_function.resource_name + ).strip() + + field = resource_and_registry_function.from_registry_function(registry=registry) + marshalled = field.serialize(attr="an_object", obj=CatCarrier(an_object=cat)) + dumped = json.dumps(marshalled, indent=4) + + assert dumped == reference + + +def test_tagged_deserializes_from_snippet( + resource_and_registry_function: ResourceAndRegistryFunction, +) -> None: + registry = desert._fields.TypeAndHintFieldRegistry() + registry.register( + hint=Cat, + tag="cat", + field=marshmallow.fields.Nested(desert.schema(Cat, meta={"ordered": True})), + ) + + reference = importlib_resources.read_text( + tests.example, resource_and_registry_function.resource_name + ).strip() + + field = resource_and_registry_function.from_registry_function(registry=registry) + deserialized_cat = field.deserialize(value=json.loads(reference)) + + assert deserialized_cat == Cat(name="Max", color="tuxedo") + + +# start tagged_union_example +def test_tagged_union_example() -> None: + @dataclasses.dataclass + class Dog: + name: str + color: str + + registry = desert._fields.TypeAndHintFieldRegistry() + registry.register( + hint=Cat, + tag="cat", + field=marshmallow.fields.Nested(desert.schema(Cat, meta={"ordered": True})), + ) + registry.register( + hint=Dog, + tag="dog", + field=marshmallow.fields.Nested(desert.schema(Dog, meta={"ordered": True})), + ) + + field = desert._fields.adjacently_tagged_union_from_registry(registry=registry) + + @dataclasses.dataclass + class CatOrDog: + union: t.Union[Cat, Dog] = desert.field(marshmallow_field=field) + + schema = desert.schema(CatOrDog) + + with_a_cat = CatOrDog(union=Cat(name="Max", color="tuxedo")) + with_a_dog = CatOrDog(union=Dog(name="Bubbles", color="black spots on white")) + + marshalled_cat = { + "union": {"#type": "cat", "#value": {"name": "Max", "color": "tuxedo"}} + } + marshalled_dog = { + "union": { + "#type": "dog", + "#value": {"name": "Bubbles", "color": "black spots on white"}, + } + } + + dumped_cat = json.dumps(marshalled_cat) + dumped_dog = json.dumps(marshalled_dog) + + assert dumped_cat == schema.dumps(with_a_cat) + assert dumped_dog == schema.dumps(with_a_dog) + + assert with_a_cat == schema.loads(dumped_cat) + assert with_a_dog == schema.loads(dumped_dog) + # end tagged_union_example