From f8d2c1aec2ac72272063d2e547ddeb7e584b5a46 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 8 Dec 2023 11:17:26 -0600 Subject: [PATCH] Add support for Python 3.12's `type` aliases This adds support for the new syntactic type aliases added in Python 3.12. A few examples: ``` type NullableStr = str | None type Pair[T] = tuple[T, T] type NullableStrPair = Pair[NullableStr] ``` msgspec now supports these type aliases, *except* in cases where the type alias is recursive. For example, the following type isn't supported: ``` type Link[T] = tuple[T, Link[T] | None] ``` The internal datastructure we use to store type information was not designed to handle recursive types like these; supporting them will require a larger refactor. --- docs/source/supported-types.rst | 32 ++++++++++ msgspec/_core.c | 107 +++++++++++++++++++++++++------- msgspec/inspect.py | 9 +++ tests/test_common.py | 97 +++++++++++++++++++++++++++++ tests/test_convert.py | 26 ++++++-- tests/test_inspect.py | 17 +++++ 6 files changed, 258 insertions(+), 30 deletions(-) diff --git a/docs/source/supported-types.rst b/docs/source/supported-types.rst index 834cc1a7..7abe0af3 100644 --- a/docs/source/supported-types.rst +++ b/docs/source/supported-types.rst @@ -49,6 +49,8 @@ Most combinations of the following types are supported (with a few restrictions) - `typing.Literal` - `typing.NewType` - `typing.Final` +- `typing.TypeAliasType` +- `typing.TypeAlias` - `typing.NamedTuple` / `collections.namedtuple` - `typing.TypedDict` - `typing.Generic` @@ -1170,6 +1172,36 @@ support here is purely to aid static analysis tools like mypy_ or pyright_. File "", line 1, in msgspec.ValidationError: Expected `int`, got `str` +Type Aliases +------------ + +For complex types, sometimes it can be nice to write the type once so you can +reuse it later. + +.. code-block:: python + + Point = tuple[float, float] + +Here ``Point`` is a "type alias" for ``tuple[float, float]`` - ``msgspec`` +will substitute in ``tuple[float, float]`` whenever the ``Point`` type +is used in an annotation. + +``msgspec`` supports the following equivalent forms: + +.. code-block:: python + + # Using variable assignment + Point = tuple[float, float] + + # Using variable assignment, annotated as a `TypeAlias` + Point: TypeAlias = tuple[float, float] + + # Using Python 3.12's new `type` statement. This only works on Python 3.12+ + type Point = tuple[float, float] + +To learn more about Type Aliases, see Python's `Type Alias docs here +`__. + Generic Types ------------- diff --git a/msgspec/_core.c b/msgspec/_core.c index aeacbe26..5e9804c7 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -15,6 +15,12 @@ #include "ryu.h" #include "atof.h" +/* Python version checks */ +#define PY39_PLUS (PY_VERSION_HEX >= 0x03090000) +#define PY310_PLUS (PY_VERSION_HEX >= 0x030a0000) +#define PY311_PLUS (PY_VERSION_HEX >= 0x030b0000) +#define PY312_PLUS (PY_VERSION_HEX >= 0x030c0000) + /* Hint to the compiler not to store `x` in a register since it is likely to * change. Results in much higher performance on GCC, with smaller benefits on * clang */ @@ -36,18 +42,18 @@ ms_popcount(uint64_t i) { \ } #endif -#if PY_VERSION_HEX < 0x03090000 -#define CALL_ONE_ARG(f, a) PyObject_CallFunctionObjArgs(f, a, NULL) -#define CALL_NO_ARGS(f) PyObject_CallFunctionObjArgs(f, NULL) -#define CALL_METHOD_ONE_ARG(o, n, a) PyObject_CallMethodObjArgs(o, n, a, NULL) -#define CALL_METHOD_NO_ARGS(o, n) PyObject_CallMethodObjArgs(o, n, NULL) -#define SET_SIZE(obj, size) (((PyVarObject *)obj)->ob_size = size) -#else +#if PY39_PLUS #define CALL_ONE_ARG(f, a) PyObject_CallOneArg(f, a) #define CALL_NO_ARGS(f) PyObject_CallNoArgs(f) #define CALL_METHOD_ONE_ARG(o, n, a) PyObject_CallMethodOneArg(o, n, a) #define CALL_METHOD_NO_ARGS(o, n) PyObject_CallMethodNoArgs(o, n) #define SET_SIZE(obj, size) Py_SET_SIZE(obj, size) +#else +#define CALL_ONE_ARG(f, a) PyObject_CallFunctionObjArgs(f, a, NULL) +#define CALL_NO_ARGS(f) PyObject_CallFunctionObjArgs(f, NULL) +#define CALL_METHOD_ONE_ARG(o, n, a) PyObject_CallMethodObjArgs(o, n, a, NULL) +#define CALL_METHOD_NO_ARGS(o, n) PyObject_CallMethodObjArgs(o, n, NULL) +#define SET_SIZE(obj, size) (((PyVarObject *)obj)->ob_size = size) #endif #define DIV_ROUND_CLOSEST(n, d) ((((n) < 0) == ((d) < 0)) ? (((n) + (d)/2)/(d)) : (((n) - (d)/2)/(d))) @@ -157,7 +163,7 @@ fast_long_extract_parts(PyObject *vv, bool *neg, uint64_t *scale) { uint64_t prev, x = 0; bool negative; -#if PY_VERSION_HEX >= 0x030c0000 +#if PY312_PLUS /* CPython 3.12 changed int internal representation */ int sign = 1 - (v->long_value.lv_tag & _PyLong_SIGN_MASK); negative = sign == -1; @@ -405,6 +411,9 @@ typedef struct { PyObject *str___dataclass_fields__; PyObject *str___attrs_attrs__; PyObject *str___supertype__; +#if PY312_PLUS + PyObject *str___value__; +#endif PyObject *str___bound__; PyObject *str___constraints__; PyObject *str_int; @@ -427,8 +436,11 @@ typedef struct { PyObject *get_typeddict_info; PyObject *get_dataclass_info; PyObject *rebuild; -#if PY_VERSION_HEX >= 0x030a00f0 +#if PY310_PLUS PyObject *types_uniontype; +#endif +#if PY312_PLUS + PyObject *typing_typealiastype; #endif PyObject *astimezone; PyObject *re_compile; @@ -2122,7 +2134,7 @@ PyTypeObject NoDefault_Type = { .tp_basicsize = 0 }; -#if PY_VERSION_HEX >= 0x030c0000 +#if PY312_PLUS PyObject _NoDefault_Object = { _PyObject_EXTRA_INIT { _Py_IMMORTAL_REFCNT }, @@ -2226,7 +2238,7 @@ PyTypeObject Unset_Type = { .tp_basicsize = 0 }; -#if PY_VERSION_HEX >= 0x030c0000 +#if PY312_PLUS PyObject _Unset_Object = { _PyObject_EXTRA_INIT { _Py_IMMORTAL_REFCNT }, @@ -4459,6 +4471,21 @@ typenode_origin_args_metadata( t = temp; continue; } + /* Check for parametrized TypeAliasType if Python 3.12+ */ + #if PY312_PLUS + if (Py_TYPE(origin) == (PyTypeObject *)(state->mod->typing_typealiastype)) { + PyObject *value = PyObject_GetAttr(origin, state->mod->str___value__); + if (value == NULL) goto error; + PyObject *temp = PyObject_GetItem(value, args); + Py_DECREF(value); + if (temp == NULL) goto error; + Py_CLEAR(args); + Py_CLEAR(origin); + Py_DECREF(t); + t = temp; + continue; + } + #endif } else { /* Custom non-parametrized generics won't have __args__ @@ -4487,14 +4514,23 @@ typenode_origin_args_metadata( t = supertype; continue; } - else { - PyErr_Clear(); - break; + PyErr_Clear(); + + /* Check for TypeAliasType if Python 3.12+ */ + #if PY312_PLUS + if (Py_TYPE(t) == (PyTypeObject *)(state->mod->typing_typealiastype)) { + PyObject *value = PyObject_GetAttr(t, state->mod->str___value__); + if (value == NULL) goto error; + Py_DECREF(t); + t = value; + continue; } + #endif + break; } } - #if PY_VERSION_HEX >= 0x030a00f0 + #if PY310_PLUS if (Py_TYPE(t) == (PyTypeObject *)(state->mod->types_uniontype)) { /* Handle types.UnionType unions (`int | float | ...`) */ args = PyObject_GetAttr(t, state->mod->str___args__); @@ -4692,6 +4728,10 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) { } } else if (origin == state->mod->typing_union) { + if (Py_EnterRecursiveCall(" while analyzing a type")) { + out = -1; + goto done; + } for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(args); i++) { PyObject *arg = PyTuple_GET_ITEM(args, i); /* Ignore UnsetType in unions */ @@ -4699,6 +4739,7 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) { out = typenode_collect_type(state, arg); if (out < 0) break; } + Py_LeaveRecursiveCall(); } else if (origin == state->mod->typing_literal) { if (state->literals == NULL) { @@ -4761,6 +4802,8 @@ TypeNode_Convert(PyObject *obj) { state.mod = msgspec_get_global_state(); state.context = obj; + if (Py_EnterRecursiveCall(" while analyzing a type")) return NULL; + /* Traverse `obj` to collect all type annotations at this level */ if (typenode_collect_type(&state, obj) < 0) goto done; /* Handle structs in a second pass */ @@ -4773,6 +4816,7 @@ TypeNode_Convert(PyObject *obj) { out = typenode_from_collect_state(&state); done: typenode_collect_clear_state(&state); + Py_LeaveRecursiveCall(); return out; } @@ -9717,14 +9761,14 @@ ms_encode_err_type_unsupported(PyTypeObject *type) { *************************************************************************/ #define MS_HAS_TZINFO(o) (((_PyDateTime_BaseTZInfo *)(o))->hastzinfo) -#if PY_VERSION_HEX < 0x030a00f0 +#if PY310_PLUS +#define MS_DATE_GET_TZINFO(o) PyDateTime_DATE_GET_TZINFO(o) +#define MS_TIME_GET_TZINFO(o) PyDateTime_TIME_GET_TZINFO(o) +#else #define MS_DATE_GET_TZINFO(o) (MS_HAS_TZINFO(o) ? \ ((PyDateTime_DateTime *)(o))->tzinfo : Py_None) #define MS_TIME_GET_TZINFO(o) (MS_HAS_TZINFO(o) ? \ ((PyDateTime_Time *)(o))->tzinfo : Py_None) -#else -#define MS_DATE_GET_TZINFO(o) PyDateTime_DATE_GET_TZINFO(o) -#define MS_TIME_GET_TZINFO(o) PyDateTime_TIME_GET_TZINFO(o) #endif #ifndef TIMEZONE_CACHE_SIZE @@ -15472,7 +15516,7 @@ static struct PyMethodDef Decoder_methods[] = { "decode", (PyCFunction) Decoder_decode, METH_FASTCALL, Decoder_decode__doc__, }, -#if PY_VERSION_HEX >= 0x03090000 +#if PY39_PLUS {"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS}, #endif {NULL, NULL} /* sentinel */ @@ -18512,7 +18556,7 @@ static struct PyMethodDef JSONDecoder_methods[] = { "decode_lines", (PyCFunction) JSONDecoder_decode_lines, METH_FASTCALL, JSONDecoder_decode_lines__doc__, }, -#if PY_VERSION_HEX >= 0x03090000 +#if PY39_PLUS {"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS}, #endif {NULL, NULL} /* sentinel */ @@ -21029,6 +21073,9 @@ msgspec_clear(PyObject *m) Py_CLEAR(st->str___dataclass_fields__); Py_CLEAR(st->str___attrs_attrs__); Py_CLEAR(st->str___supertype__); +#if PY312_PLUS + Py_CLEAR(st->str___value__); +#endif Py_CLEAR(st->str___bound__); Py_CLEAR(st->str___constraints__); Py_CLEAR(st->str_int); @@ -21051,8 +21098,11 @@ msgspec_clear(PyObject *m) Py_CLEAR(st->get_typeddict_info); Py_CLEAR(st->get_dataclass_info); Py_CLEAR(st->rebuild); -#if PY_VERSION_HEX >= 0x030a00f0 +#if PY310_PLUS Py_CLEAR(st->types_uniontype); +#endif +#if PY312_PLUS + Py_CLEAR(st->typing_typealiastype); #endif Py_CLEAR(st->astimezone); Py_CLEAR(st->re_compile); @@ -21118,8 +21168,11 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg) Py_VISIT(st->get_typeddict_info); Py_VISIT(st->get_dataclass_info); Py_VISIT(st->rebuild); -#if PY_VERSION_HEX >= 0x030a00f0 +#if PY310_PLUS Py_VISIT(st->types_uniontype); +#endif +#if PY312_PLUS + Py_VISIT(st->typing_typealiastype); #endif Py_VISIT(st->astimezone); Py_VISIT(st->re_compile); @@ -21315,6 +21368,9 @@ PyInit__core(void) SET_REF(typing_final, "Final"); SET_REF(typing_generic, "Generic"); SET_REF(typing_generic_alias, "_GenericAlias"); +#if PY312_PLUS + SET_REF(typing_typealiastype, "TypeAliasType"); +#endif Py_DECREF(temp_module); temp_module = PyImport_ImportModule("msgspec._utils"); @@ -21328,7 +21384,7 @@ PyInit__core(void) SET_REF(rebuild, "rebuild"); Py_DECREF(temp_module); -#if PY_VERSION_HEX >= 0x030a00f0 +#if PY310_PLUS temp_module = PyImport_ImportModule("types"); if (temp_module == NULL) return NULL; SET_REF(types_uniontype, "UnionType"); @@ -21411,6 +21467,9 @@ PyInit__core(void) CACHED_STRING(str___dataclass_fields__, "__dataclass_fields__"); CACHED_STRING(str___attrs_attrs__, "__attrs_attrs__"); CACHED_STRING(str___supertype__, "__supertype__"); +#if PY312_PLUS + CACHED_STRING(str___value__, "__value__"); +#endif CACHED_STRING(str___bound__, "__bound__"); CACHED_STRING(str___constraints__, "__constraints__"); CACHED_STRING(str_int, "int"); diff --git a/msgspec/inspect.py b/msgspec/inspect.py index 664f24c6..3838a106 100644 --- a/msgspec/inspect.py +++ b/msgspec/inspect.py @@ -20,6 +20,11 @@ except Exception: _types_UnionType = type("UnionType", (), {}) # type: ignore +try: + from typing import TypeAliasType as _TypeAliasType # type: ignore +except Exception: + _TypeAliasType = type("TypeAliasType", (), {}) # type: ignore + import msgspec from msgspec import NODEFAULT, UNSET, UnsetType as _UnsetType @@ -628,6 +633,8 @@ def _origin_args_metadata(t): t = origin elif origin == Final: t = t.__args__[0] + elif type(origin) is _TypeAliasType: + t = origin.__value__[t.__args__] else: args = getattr(t, "__args__", None) origin = _CONCRETE_TYPES.get(origin, origin) @@ -636,6 +643,8 @@ def _origin_args_metadata(t): supertype = getattr(t, "__supertype__", None) if supertype is not None: t = supertype + elif type(t) is _TypeAliasType: + t = t.__value__ else: origin = t args = None diff --git a/tests/test_common.py b/tests/test_common.py index 9886da0d..2f4731dc 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -47,10 +47,12 @@ PY39 = sys.version_info[:2] >= (3, 9) PY310 = sys.version_info[:2] >= (3, 10) PY311 = sys.version_info[:2] >= (3, 11) +PY312 = sys.version_info[:2] >= (3, 12) py39_plus = pytest.mark.skipif(not PY39, reason="3.9+ only") py310_plus = pytest.mark.skipif(not PY310, reason="3.10+ only") py311_plus = pytest.mark.skipif(not PY311, reason="3.11+ only") +py312_plus = pytest.mark.skipif(not PY312, reason="3.12+ only") T = TypeVar("T") @@ -3697,6 +3699,101 @@ def test_decode_annotated_newtype_annotated(self, proto, Annotated): dec.decode(proto.encode(bad)) +class TestTypeAlias: + @py312_plus + def test_simple(self, proto): + with temp_module("type Ex = str | None") as mod: + dec = proto.Decoder(mod.Ex) + assert dec.decode(proto.encode("test")) == "test" + assert dec.decode(proto.encode(None)) is None + with pytest.raises(ValidationError): + dec.decode(proto.encode(1)) + + @py312_plus + def test_generic(self, proto): + with temp_module("type Pair[T] = tuple[T, T]") as mod: + dec = proto.Decoder(mod.Pair) + assert dec.decode(proto.encode((1, 2))) == (1, 2) + for bad in [1, [1, 2, 3]]: + with pytest.raises(ValidationError): + dec.decode(proto.encode(bad)) + + @py312_plus + def test_parametrized_generic(self, proto): + with temp_module("type Pair[T] = tuple[T, T]") as mod: + dec = proto.Decoder(mod.Pair[int]) + assert dec.decode(proto.encode((1, 2))) == (1, 2) + for bad in [1, [1, 2, 3], [1, "a"]]: + with pytest.raises(ValidationError): + dec.decode(proto.encode(bad)) + + @py312_plus + def test_typealias_wrapping_typealias(self, proto): + src = """ + type Pair[T] = tuple[T, T] + type Pairs[T] = list[Pair[T]] + """ + with temp_module(src) as mod: + dec = proto.Decoder(mod.Pairs) + for good in [[], [(1, 2), (3, 4)]]: + assert dec.decode(proto.encode(good)) == good + for bad in [1, [1], [(1, 2, 3)]]: + with pytest.raises(ValidationError): + dec.decode(proto.encode(bad)) + + dec = proto.Decoder(mod.Pairs[int]) + for good in [[], [(1, 2)], [(1, 2), (3, 4)]]: + assert dec.decode(proto.encode(good)) == good + for bad in [1, [1], [(1, "a")]]: + with pytest.raises(ValidationError): + dec.decode(proto.encode(bad)) + + @py312_plus + def test_typealias_with_constraints(self, proto): + src = """ + import msgspec + from typing import Annotated + type Key = Annotated[str, msgspec.Meta(max_length=4)] + """ + with temp_module(src) as mod: + dec = proto.Decoder(mod.Key) + for good in ["", "abc", "abcd"]: + assert dec.decode(proto.encode(good)) == good + for bad in [1, "abcde"]: + with pytest.raises(ValidationError): + dec.decode(proto.encode(bad)) + + @py312_plus + def test_typealias_parametrized_generic_too_many_parameters(self): + with temp_module("type Pair[T] = tuple[T, T]") as mod: + with pytest.raises(TypeError): + msgspec.json.Decoder(mod.Pair[int, int]) + + @py312_plus + @pytest.mark.parametrize( + "src", + [ + "type Ex = Ex | None", + "type Ex = tuple[Ex, int]", + "type Ex[T] = tuple[T, Ex[T]]", + "type Temp[T] = tuple[T, Temp[T]]; Ex = Temp[int]", + "type Temp[T] = tuple[T, Ex[T]]; type Ex[T] = tuple[Temp[T], T];", + ], + ) + def test_recursive_typealias_errors(self, src): + """Eventually we should support this, but for now just test that it + errors cleanly""" + with temp_module(src) as mod: + with pytest.raises(RecursionError): + msgspec.json.Decoder(mod.Ex) + + @py312_plus + def test_typealias_invalid_type(self): + with temp_module("type Ex = int | complex") as mod: + with pytest.raises(TypeError): + msgspec.json.Decoder(mod.Ex) + + class TestDecimal: def test_encoder_decimal_format(self, proto): assert proto.Encoder().decimal_format == "string" diff --git a/tests/test_convert.py b/tests/test_convert.py index 833fe6b0..02d09143 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -1032,11 +1032,17 @@ def test_sequence_cyclic_recursion(self, kind): for _ in range(depth): typ = FrozenSet[typ] - msg = [] - msg.append(msg) + class Cache(Struct): + value: typ + + msgspec.json.Decoder(Cache) + + arr = [] + arr.append(arr) + msg = {"value": arr} with pytest.raises(RecursionError): with max_call_depth(5): - assert convert(msg, typ) + convert(msg, Cache) @pytest.mark.parametrize("out_type", [list, tuple, set, frozenset]) @uses_annotated @@ -1246,11 +1252,19 @@ def test_dict_cyclic_recursion(self, dictcls): typ = Dict[str, int] for _ in range(depth): typ = Dict[str, typ] - msg = dictcls() - msg["x"] = msg + + class Cache(Struct): + value: typ + + msgspec.json.Decoder(Cache) + + map = dictcls() + map["x"] = map + msg = {"value": map} + with pytest.raises(RecursionError): with max_call_depth(5): - assert convert(msg, typ) + convert(msg, Cache) @uses_annotated def test_dict_constrs(self, dictcls): diff --git a/tests/test_inspect.py b/tests/test_inspect.py index 7e3d00cc..f1a77cae 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -43,6 +43,9 @@ PY39 = sys.version_info[:2] >= (3, 9) +PY312 = sys.version_info[:2] >= (3, 12) + +py312_plus = pytest.mark.skipif(not PY312, reason="3.12+ only") T = TypeVar("T") @@ -204,6 +207,20 @@ def test_newtype(): ) +@py312_plus +@pytest.mark.parametrize( + "src, typ", + [ + ("type Ex = str | None", Union[str, None]), + ("type Ex[T] = tuple[T, int]", Tuple[Any, int]), + ("type Temp[T] = tuple[T, int]; Ex = Temp[str]", Tuple[str, int]), + ], +) +def test_typealias(src, typ): + with temp_module(src) as mod: + assert mi.type_info(mod.Ex) == mi.type_info(typ) + + def test_final(Annotated): cases = [ (int, mi.IntType()),