From 0855efeadb6ad199c31fcc5b0706c884b6fccfb5 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sun, 18 Jul 2021 12:10:19 +0300 Subject: [PATCH] bpo-44633: Fix parameter substitution of the union type with wrong types. (GH-27218) A TypeError is now raised instead of returning NotImplemented. (cherry picked from commit 3ea5332a4365bdd771286b3e9692495116e9ceef) Co-authored-by: Serhiy Storchaka --- Lib/test/test_types.py | 6 +++ .../2021-07-17-21-04-04.bpo-44633.5-zKeI.rst | 2 + Objects/unionobject.c | 44 ++++++++++++------- 3 files changed, 36 insertions(+), 16 deletions(-) create mode 100644 Misc/NEWS.d/next/Core and Builtins/2021-07-17-21-04-04.bpo-44633.5-zKeI.rst diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index b2e11308f3e80d..6ed74714dc578e 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -772,6 +772,12 @@ def test_union_parameter_chaining(self): self.assertEqual((list[T] | list[S])[int, T], list[int] | list[T]) self.assertEqual((list[T] | list[S])[int, int], list[int]) + def test_union_parameter_substitution_errors(self): + T = typing.TypeVar("T") + x = int | T + with self.assertRaises(TypeError): + x[42] + def test_or_type_operator_with_forward(self): T = typing.TypeVar('T') ForwardAfter = T | 'Forward' diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-07-17-21-04-04.bpo-44633.5-zKeI.rst b/Misc/NEWS.d/next/Core and Builtins/2021-07-17-21-04-04.bpo-44633.5-zKeI.rst new file mode 100644 index 00000000000000..507a68b65d4c37 --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2021-07-17-21-04-04.bpo-44633.5-zKeI.rst @@ -0,0 +1,2 @@ +Parameter substitution of the union type with wrong types now raises +``TypeError`` instead of returning ``NotImplemented``. diff --git a/Objects/unionobject.c b/Objects/unionobject.c index c744c8746cb4d7..c0c9a24bcc204a 100644 --- a/Objects/unionobject.c +++ b/Objects/unionobject.c @@ -302,10 +302,22 @@ is_unionable(PyObject *obj) PyObject * _Py_union_type_or(PyObject* self, PyObject* other) { + int r = is_unionable(self); + if (r > 0) { + r = is_unionable(other); + } + if (r < 0) { + return NULL; + } + if (!r) { + Py_RETURN_NOTIMPLEMENTED; + } + PyObject *tuple = PyTuple_Pack(2, self, other); if (tuple == NULL) { return NULL; } + PyObject *new_union = make_union(tuple); Py_DECREF(tuple); return new_union; @@ -434,6 +446,21 @@ union_getitem(PyObject *self, PyObject *item) return NULL; } + // Check arguments are unionable. + Py_ssize_t nargs = PyTuple_GET_SIZE(newargs); + for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) { + PyObject *arg = PyTuple_GET_ITEM(newargs, iarg); + int is_arg_unionable = is_unionable(arg); + if (is_arg_unionable <= 0) { + Py_DECREF(newargs); + if (is_arg_unionable == 0) { + PyErr_Format(PyExc_TypeError, + "Each union argument must be a type, got %.100R", arg); + } + return NULL; + } + } + PyObject *res = make_union(newargs); Py_DECREF(newargs); @@ -495,21 +522,6 @@ make_union(PyObject *args) { assert(PyTuple_CheckExact(args)); - unionobject* result = NULL; - - // Check arguments are unionable. - Py_ssize_t nargs = PyTuple_GET_SIZE(args); - for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) { - PyObject *arg = PyTuple_GET_ITEM(args, iarg); - int is_arg_unionable = is_unionable(arg); - if (is_arg_unionable < 0) { - return NULL; - } - if (!is_arg_unionable) { - Py_RETURN_NOTIMPLEMENTED; - } - } - args = dedup_and_flatten_args(args); if (args == NULL) { return NULL; @@ -521,7 +533,7 @@ make_union(PyObject *args) return result1; } - result = PyObject_GC_New(unionobject, &_PyUnion_Type); + unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type); if (result == NULL) { Py_DECREF(args); return NULL;