From 32893029892260ff5fe5a3d3dbf04890059f59d2 Mon Sep 17 00:00:00 2001 From: mdemello Date: Mon, 23 Oct 2023 19:52:15 -0700 Subject: [PATCH] When setting a type from assertIsInstance narrow the original type if possible. If we cannot narrow the type we continue to fall back to just instantiating the asserted types, rather than saying the assert will fail. PiperOrigin-RevId: 575998510 --- pytype/tests/test_test_code.py | 17 ++++++++++++++++- pytype/vm.py | 19 +++++++++++++++++-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/pytype/tests/test_test_code.py b/pytype/tests/test_test_code.py index 45cbd374a..894816712 100644 --- a/pytype/tests/test_test_code.py +++ b/pytype/tests/test_test_code.py @@ -62,8 +62,23 @@ def test_foo(self): assert_type(x, str) """) + def test_narrowed_type_from_assert_isinstance(self): + # assertIsInstance should narrow the original var's bindings if possible. + self.Check(""" + import unittest + from typing import Union + class A: + pass + class B(A): + pass + class FooTest(unittest.TestCase): + def test_foo(self, x: Union[A, B, int]): + self.assertIsInstance(x, A) + assert_type(x, Union[A, B]) + """) + def test_new_type_from_assert_isinstance(self): - # assertIsInstance should create a var with a new type even if it is not in + # assertIsInstance should create a var with a new type if it is not in # the original var's bindings. self.Check(""" import unittest diff --git a/pytype/vm.py b/pytype/vm.py index 4de0cf8f6..cc5f21fd7 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -3021,12 +3021,27 @@ def _narrow(self, state, var, pred): return self._store_new_var_in_local(state, var, out) def _set_type_from_assert_isinstance(self, state, var, class_spec): + """Set type of var from an assertIsInstance(var, class_spec) call.""" # TODO(mdemello): If we want to cast var to typ via an assertion, should - # we check that at least one binding of var is compatible with typ? + # we require that at least one binding of var is compatible with typ? classes = [] abstract_utils.flatten(class_spec, classes) + ret = [] + # First try to narrow `var` based on `classes`. + for c in classes: + m = self.ctx.matcher(state.node).compute_one_match( + var, c, keep_all_views=True, match_all_views=False) + if m.success: + for matched in m.good_matches: + d = matched.view[var] + if isinstance(d.data, abstract.Instance): + ret.append(d.data.cls) + + # If we don't have bindings from `classes` in `var`, instantiate the + # original class spec. + ret = ret or classes instance = self.init_class( - state.node, self.ctx.convert.merge_values(classes)) + state.node, self.ctx.convert.merge_values(ret)) return self._store_new_var_in_local(state, var, instance) def _check_test_assert(self, state, func, args):