diff --git a/docs/source/supported-types.rst b/docs/source/supported-types.rst index 8ea6740f..75394921 100644 --- a/docs/source/supported-types.rst +++ b/docs/source/supported-types.rst @@ -44,6 +44,7 @@ Most combinations of the following types are supported (with a few restrictions) - `typing.Union` - `typing.Literal` - `typing.NewType` +- `typing.Final` - `typing.NamedTuple` / `collections.namedtuple` - `typing.TypedDict` diff --git a/msgspec/_core.c b/msgspec/_core.c index 02dafc17..dfd5463c 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -372,6 +372,7 @@ typedef struct { PyObject *typing_any; PyObject *typing_literal; PyObject *typing_classvar; + PyObject *typing_final; PyObject *typing_generic_alias; PyObject *typing_annotated_alias; PyObject *concrete_types; @@ -4123,44 +4124,92 @@ typenode_origin_args_metadata( PyObject *t = obj; Py_INCREF(t); - /* First strip out meta "wrapper" types (Annotated, NewType) */ + /* First strip out meta "wrapper" types (Annotated, NewType, Final) */ while (true) { - if (Py_TYPE(t) == (PyTypeObject *)(state->mod->typing_annotated_alias)) { - /* Handle Annotated */ - PyObject *origin = PyObject_GetAttr(t, state->mod->str___origin__); - if (origin == NULL) { - Py_CLEAR(t); - return NULL; - } + assert(t != NULL && origin == NULL && args == NULL); - PyObject *metadata = PyObject_GetAttr(t, state->mod->str___metadata__); - if (metadata == NULL) { - Py_DECREF(origin); + /* Before inspecting attributes, try looking up the object in the + * abstract -> concrete mapping. If present, this is an unparametrized + * collection of some form. This helps avoid compatibility issues in + * Python 3.8, where unparametrized collections still have __args__. */ + origin = PyDict_GetItem(state->mod->concrete_types, t); + if (origin != NULL) { + Py_INCREF(origin); + break; + } + + /* If `t` is a type instance, no need to inspect further */ + if (PyType_CheckExact(t)) { + /* t is a concrete type object. */ + break; + } + + origin = PyObject_GetAttr(t, state->mod->str___origin__); + if (origin != NULL) { + if (Py_TYPE(t) == (PyTypeObject *)(state->mod->typing_annotated_alias)) { + /* Handle typing.Annotated[...] */ + PyObject *metadata = PyObject_GetAttr(t, state->mod->str___metadata__); + if (metadata == NULL) goto error; + for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(metadata); i++) { + PyObject *annot = PyTuple_GET_ITEM(metadata, i); + if (Py_TYPE(annot) == &Meta_Type) { + if (constraints_update(constraints, (Meta *)annot, obj) < 0) { + Py_DECREF(metadata); + goto error; + } + } + } + Py_DECREF(metadata); Py_DECREF(t); - return NULL; + t = origin; + origin = NULL; + continue; } - - for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(metadata); i++) { - PyObject *annot = PyTuple_GET_ITEM(metadata, i); - if (Py_TYPE(annot) == &Meta_Type) { - if (constraints_update(constraints, (Meta *)annot, obj) < 0) { - Py_DECREF(metadata); - Py_DECREF(origin); + else { + args = PyObject_GetAttr(t, state->mod->str___args__); + if (args != NULL) { + if (!PyTuple_Check(args)) { + PyErr_SetString(PyExc_TypeError, "__args__ must be a tuple"); + goto error; + } + if (origin == state->mod->typing_final) { + /* Handle typing.Final[...] */ + PyObject *temp = PyTuple_GetItem(args, 0); + if (temp == NULL) goto error; + Py_CLEAR(args); + Py_CLEAR(origin); Py_DECREF(t); - return NULL; + Py_INCREF(temp); + t = temp; + continue; } } + else { + /* Custom non-parametrized generics won't have __args__ + * set. Ignore __args__ error */ + PyErr_Clear(); + } + /* Lookup __origin__ in the mapping, in case it's a supported + * abstract type. Equal to `origin = mapping.get(origin, origin)` */ + PyObject *temp = PyDict_GetItem(state->mod->concrete_types, origin); + if (temp != NULL) { + Py_DECREF(origin); + Py_INCREF(temp); + origin = temp; + } + break; } - Py_DECREF(metadata); - Py_DECREF(t); - t = origin; } else { - /* Handle NewType */ + PyErr_Clear(); + + /* Check for NewType */ PyObject *supertype = PyObject_GetAttr(t, state->mod->str___supertype__); if (supertype != NULL) { + /* It's a newtype, use the wrapped type and loop again */ Py_DECREF(t); t = supertype; + continue; } else { PyErr_Clear(); @@ -4169,59 +4218,25 @@ typenode_origin_args_metadata( } } - /* At this point `t` is a concrete type. Next check for generic types, - * extracting `__origin__` and `__args__`. This lets us normalize how - * we check for collection types later */ - if ((origin = PyDict_GetItem(state->mod->concrete_types, t)) != NULL) { - Py_INCREF(origin); - } #if PY_VERSION_HEX >= 0x030a00f0 - else if (Py_TYPE(t) == (PyTypeObject *)(state->mod->types_uniontype)) { + if (Py_TYPE(t) == (PyTypeObject *)(state->mod->types_uniontype)) { + /* Handle types.UnionType unions (`int | float | ...`) */ args = PyObject_GetAttr(t, state->mod->str___args__); - if (args == NULL) { - Py_DECREF(t); - return NULL; - } + if (args == NULL) goto error; origin = state->mod->typing_union; Py_INCREF(origin); } #endif - else { - origin = PyObject_GetAttr(t, state->mod->str___origin__); - if (origin == NULL) { - /* Not a generic */ - PyErr_Clear(); - } - else { - /* Lookup __origin__ in the mapping, in case it's a supported - * abstract type */ - PyObject *temp = PyDict_GetItem(state->mod->concrete_types, origin); - if (temp != NULL) { - Py_DECREF(origin); - Py_INCREF(temp); - origin = temp; - } - args = PyObject_GetAttr(t, state->mod->str___args__); - if (args == NULL) { - /* Custom non-parametrized generics won't have __args__ set. - * Ignore __args__ error */ - PyErr_Clear(); - } - else { - if (!PyTuple_Check(args)) { - PyErr_SetString(PyExc_TypeError, "__args__ must be a tuple"); - Py_DECREF(t); - Py_DECREF(origin); - Py_DECREF(args); - return NULL; - } - } - } - } *out_origin = origin; *out_args = args; return t; + +error: + Py_XDECREF(t); + Py_XDECREF(origin); + Py_XDECREF(args); + return NULL; } static int @@ -10443,7 +10458,7 @@ mpack_encode_struct(EncoderState *self, PyObject *obj) actual_len--; } else { - if (mpack_encode_str(self, key) < 0) goto cleanup; + if (mpack_encode_str(self, key) < 0) goto cleanup; if (mpack_encode(self, val) < 0) goto cleanup; } } @@ -10458,7 +10473,7 @@ mpack_encode_struct(EncoderState *self, PyObject *obj) actual_len--; } else { - if (mpack_encode_str(self, key) < 0) goto cleanup; + if (mpack_encode_str(self, key) < 0) goto cleanup; if (mpack_encode(self, val) < 0) goto cleanup; } } @@ -18606,6 +18621,7 @@ msgspec_clear(PyObject *m) Py_CLEAR(st->typing_any); Py_CLEAR(st->typing_literal); Py_CLEAR(st->typing_classvar); + Py_CLEAR(st->typing_final); Py_CLEAR(st->typing_generic_alias); Py_CLEAR(st->typing_annotated_alias); Py_CLEAR(st->concrete_types); @@ -18686,6 +18702,7 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg) Py_VISIT(st->typing_any); Py_VISIT(st->typing_literal); Py_VISIT(st->typing_classvar); + Py_VISIT(st->typing_final); Py_VISIT(st->typing_generic_alias); Py_VISIT(st->typing_annotated_alias); Py_VISIT(st->concrete_types); @@ -18893,6 +18910,7 @@ PyInit__core(void) SET_REF(typing_any, "Any"); SET_REF(typing_literal, "Literal"); SET_REF(typing_classvar, "ClassVar"); + SET_REF(typing_final, "Final"); SET_REF(typing_generic_alias, "_GenericAlias"); Py_DECREF(temp_module); diff --git a/msgspec/_utils.py b/msgspec/_utils.py index d4c41c05..ea63394b 100644 --- a/msgspec/_utils.py +++ b/msgspec/_utils.py @@ -34,50 +34,34 @@ def get_type_hints(obj): # A mapping from a type annotation (or annotation __origin__) to the concrete -# python type that msgspec will use when decoding. Note that non-collection -# types don't strict need to be in this mapping. Common ones are added to avoid -# an unnecessary `getattr(t, "__origin__", None)` call on them. -# THIS IS PRIVATE FOR A REASON. DON'T MUCK WITH THIS. +# python type that msgspec will use when decoding. THIS IS PRIVATE FOR A +# REASON. DON'T MUCK WITH THIS. _CONCRETE_TYPES = { - t: t - for t in [ - None, - bool, - int, - float, - str, - bytes, - bytearray, - list, - tuple, - set, - frozenset, - dict, - ] + list: list, + tuple: tuple, + set: set, + frozenset: frozenset, + dict: dict, + typing.List: list, + typing.Tuple: tuple, + typing.Set: set, + typing.FrozenSet: frozenset, + typing.Dict: dict, + typing.Collection: list, + typing.MutableSequence: list, + typing.Sequence: list, + typing.MutableMapping: dict, + typing.Mapping: dict, + typing.MutableSet: set, + typing.AbstractSet: set, + collections.abc.Collection: list, + collections.abc.MutableSequence: list, + collections.abc.Sequence: list, + collections.abc.MutableSet: set, + collections.abc.Set: set, + collections.abc.MutableMapping: dict, + collections.abc.Mapping: dict, } -_CONCRETE_TYPES.update( - { - typing.List: list, - typing.Tuple: tuple, - typing.Set: set, - typing.FrozenSet: frozenset, - typing.Dict: dict, - typing.Collection: list, - typing.MutableSequence: list, - typing.Sequence: list, - typing.MutableMapping: dict, - typing.Mapping: dict, - typing.MutableSet: set, - typing.AbstractSet: set, - collections.abc.Collection: list, - collections.abc.MutableSequence: list, - collections.abc.Sequence: list, - collections.abc.MutableSet: set, - collections.abc.Set: set, - collections.abc.MutableMapping: dict, - collections.abc.Mapping: dict, - } -) def get_typeddict_hints(obj): diff --git a/msgspec/inspect.py b/msgspec/inspect.py index 9ad6c100..683ed350 100644 --- a/msgspec/inspect.py +++ b/msgspec/inspect.py @@ -5,7 +5,7 @@ import enum import uuid from collections.abc import Iterable -from typing import Any, Literal, Tuple, Type as typing_Type, Union +from typing import Any, Final, Literal, Tuple, Type as typing_Type, Union try: from types import UnionType as _types_UnionType @@ -611,34 +611,38 @@ def type_info(type: Any, *, protocol: Literal[None, "msgpack", "json"] = None) - # Implementation details def _origin_args_metadata(t): - # Strip Annotated and NewType wrappers until we hit a concrete base type + # Strip wrappers (Annotated, NewType, Final) until we hit a concrete type metadata = [] while True: - supertype = getattr(t, "__supertype__", None) - if supertype is not None: - t = supertype - elif type(t) is _AnnotatedAlias: - metadata.extend(m for m in t.__metadata__ if type(m) is msgspec.Meta) - t = t.__origin__ - else: + origin = _CONCRETE_TYPES.get(t) + if origin is not None: + args = None break - if type(t) is _types_UnionType: - args = t.__args__ - t = Union - else: - try: - t = _CONCRETE_TYPES[t] - args = None - except Exception: - try: - origin = t.__origin__ - except AttributeError: - args = None + origin = getattr(t, "__origin__", None) + if origin is not None: + if type(t) is _AnnotatedAlias: + metadata.extend(m for m in t.__metadata__ if type(m) is msgspec.Meta) + t = origin + elif origin == Final: + t = t.__args__[0] else: args = getattr(t, "__args__", None) - t = _CONCRETE_TYPES.get(origin, origin) - return t, args, tuple(metadata) + origin = _CONCRETE_TYPES.get(origin, origin) + break + else: + supertype = getattr(t, "__supertype__", None) + if supertype is not None: + t = supertype + else: + origin = t + args = None + break + + if type(origin) is _types_UnionType: + args = origin.__args__ + origin = Union + return origin, args, tuple(metadata) def _is_struct(t): diff --git a/tests/basic_typing_examples.py b/tests/basic_typing_examples.py index 09079372..cf5d9bfe 100644 --- a/tests/basic_typing_examples.py +++ b/tests/basic_typing_examples.py @@ -3,7 +3,7 @@ import datetime import pickle -from typing import Any, Dict, List, Type, Union +from typing import Any, Dict, Final, List, Type, Union import msgspec @@ -100,6 +100,18 @@ class Test(Base, kw_only=True): Test(b"foo", "test", a=1, b=[1, 2, 3]) +def check_struct_final_fields() -> None: + """Test that type checkers support `Final` fields for + dataclass_transform""" + class Test(msgspec.Struct): + x: Final[int] = 0 + + t = Test() + t2 = Test(x=1) + reveal_type(t.x) # assert "int" in typ + reveal_type(t2.x) # assert "int" in typ + + def check_struct_repr_omit_defaults() -> None: class Test(msgspec.Struct, repr_omit_defaults=True): x: int diff --git a/tests/test_common.py b/tests/test_common.py index 5feb98ff..95b42b02 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -15,6 +15,7 @@ from typing import ( Deque, Dict, + Final, List, Literal, NamedTuple, @@ -2817,3 +2818,27 @@ class Ex(msgspec.Struct, omit_defaults=True): res = proto.encode(x) sol = proto.encode(y) assert res == sol + + +class TestFinal: + def test_decode_final(self, proto): + dec = proto.Decoder(Final[int]) + + assert dec.decode(proto.encode(1)) == 1 + with pytest.raises(msgspec.ValidationError): + dec.decode(proto.encode("bad")) + + def test_decode_final_annotated(self, proto, Annotated): + dec = proto.Decoder(Final[Annotated[int, msgspec.Meta(ge=0)]]) + + assert dec.decode(proto.encode(1)) == 1 + with pytest.raises(msgspec.ValidationError): + dec.decode(proto.encode(-1)) + + def test_decode_final_newtype(self, proto): + UserId = NewType("UserId", int) + dec = proto.Decoder(Final[UserId]) + + assert dec.decode(proto.encode(1)) == 1 + with pytest.raises(msgspec.ValidationError): + dec.decode(proto.encode("bad")) diff --git a/tests/test_inspect.py b/tests/test_inspect.py index f41a23f4..faa4ff6c 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -12,6 +12,7 @@ from typing import ( Any, Dict, + Final, FrozenSet, List, Literal, @@ -186,6 +187,22 @@ def test_newtype(): ) +@pytest.mark.parametrize( + "typ, sol", + [ + (int, mi.IntType()), + (Annotated[int, Meta(ge=0)], mi.IntType(ge=0)), + (NewType("UserId", Annotated[int, Meta(ge=0)]), mi.IntType(ge=0)), + ], +) +def test_final(typ, sol): + class Ex(msgspec.Struct): + x: Final[typ] + + info = mi.type_info(Ex) + assert info.fields[0].type == sol + + def test_custom(): assert mi.type_info(complex) == mi.CustomType(complex)