Skip to content

Commit

Permalink
gh-92261: Disallow iteration of Union (and other special forms) (GH-9…
Browse files Browse the repository at this point in the history
  • Loading branch information
mrahtz authored May 8, 2022
1 parent 788ef54 commit 4739997
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 5 deletions.
20 changes: 20 additions & 0 deletions Lib/test/test_genericalias.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,5 +487,25 @@ def test_del_iter(self):
del iter_x


class TypeIterationTests(unittest.TestCase):
_UNITERABLE_TYPES = (list, tuple)

def test_cannot_iterate(self):
for test_type in self._UNITERABLE_TYPES:
with self.subTest(type=test_type):
expected_error_regex = "object is not iterable"
with self.assertRaisesRegex(TypeError, expected_error_regex):
iter(test_type)
with self.assertRaisesRegex(TypeError, expected_error_regex):
list(test_type)
with self.assertRaisesRegex(TypeError, expected_error_regex):
for _ in test_type:
pass

def test_is_not_instance_of_iterable(self):
for type_to_test in self._UNITERABLE_TYPES:
self.assertNotIsInstance(type_to_test, Iterable)


if __name__ == "__main__":
unittest.main()
31 changes: 31 additions & 0 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7348,6 +7348,37 @@ def test_all_exported_names(self):
self.assertSetEqual(computed_all, actual_all)


class TypeIterationTests(BaseTestCase):
_UNITERABLE_TYPES = (
Any,
Union,
Union[str, int],
Union[str, T],
List,
Tuple,
Callable,
Callable[..., T],
Callable[[T], str],
Annotated,
Annotated[T, ''],
)

def test_cannot_iterate(self):
expected_error_regex = "object is not iterable"
for test_type in self._UNITERABLE_TYPES:
with self.subTest(type=test_type):
with self.assertRaisesRegex(TypeError, expected_error_regex):
iter(test_type)
with self.assertRaisesRegex(TypeError, expected_error_regex):
list(test_type)
with self.assertRaisesRegex(TypeError, expected_error_regex):
for _ in test_type:
pass

def test_is_not_instance_of_iterable(self):
for type_to_test in self._UNITERABLE_TYPES:
self.assertNotIsInstance(type_to_test, collections.abc.Iterable)


if __name__ == '__main__':
main()
25 changes: 20 additions & 5 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,24 @@ def __deepcopy__(self, memo):
return self


class _NotIterable:
"""Mixin to prevent iteration, without being compatible with Iterable.
That is, we could do:
def __iter__(self): raise TypeError()
But this would make users of this mixin duck type-compatible with
collections.abc.Iterable - isinstance(foo, Iterable) would be True.
Luckily, we can instead prevent iteration by setting __iter__ to None, which
is treated specially.
"""

__iter__ = None


# Internal indicator of special typing constructs.
# See __doc__ instance attribute for specific docs.
class _SpecialForm(_Final, _root=True):
class _SpecialForm(_Final, _NotIterable, _root=True):
__slots__ = ('_name', '__doc__', '_getitem')

def __init__(self, getitem):
Expand Down Expand Up @@ -1498,7 +1513,7 @@ def __iter__(self):
# 1 for List and 2 for Dict. It may be -1 if variable number of
# parameters are accepted (needs custom __getitem__).

class _SpecialGenericAlias(_BaseGenericAlias, _root=True):
class _SpecialGenericAlias(_NotIterable, _BaseGenericAlias, _root=True):
def __init__(self, origin, nparams, *, inst=True, name=None):
if name is None:
name = origin.__name__
Expand Down Expand Up @@ -1541,7 +1556,7 @@ def __or__(self, right):
def __ror__(self, left):
return Union[left, self]

class _CallableGenericAlias(_GenericAlias, _root=True):
class _CallableGenericAlias(_NotIterable, _GenericAlias, _root=True):
def __repr__(self):
assert self._name == 'Callable'
args = self.__args__
Expand Down Expand Up @@ -1606,7 +1621,7 @@ def __getitem__(self, params):
return self.copy_with(params)


class _UnionGenericAlias(_GenericAlias, _root=True):
class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True):
def copy_with(self, params):
return Union[params]

Expand Down Expand Up @@ -2046,7 +2061,7 @@ def _proto_hook(other):
cls.__init__ = _no_init_or_replace_init


class _AnnotatedAlias(_GenericAlias, _root=True):
class _AnnotatedAlias(_NotIterable, _GenericAlias, _root=True):
"""Runtime representation of an annotated type.
At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix hang when trying to iterate over a ``typing.Union``.

0 comments on commit 4739997

Please sign in to comment.