diff --git a/pytype/tests/test_test_code.py b/pytype/tests/test_test_code.py index 45cbd374a..894816712 100644 --- a/pytype/tests/test_test_code.py +++ b/pytype/tests/test_test_code.py @@ -62,8 +62,23 @@ def test_foo(self): assert_type(x, str) """) + def test_narrowed_type_from_assert_isinstance(self): + # assertIsInstance should narrow the original var's bindings if possible. + self.Check(""" + import unittest + from typing import Union + class A: + pass + class B(A): + pass + class FooTest(unittest.TestCase): + def test_foo(self, x: Union[A, B, int]): + self.assertIsInstance(x, A) + assert_type(x, Union[A, B]) + """) + def test_new_type_from_assert_isinstance(self): - # assertIsInstance should create a var with a new type even if it is not in + # assertIsInstance should create a var with a new type if it is not in # the original var's bindings. self.Check(""" import unittest diff --git a/pytype/vm.py b/pytype/vm.py index 4de0cf8f6..cc5f21fd7 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -3021,12 +3021,27 @@ def _narrow(self, state, var, pred): return self._store_new_var_in_local(state, var, out) def _set_type_from_assert_isinstance(self, state, var, class_spec): + """Set type of var from an assertIsInstance(var, class_spec) call.""" # TODO(mdemello): If we want to cast var to typ via an assertion, should - # we check that at least one binding of var is compatible with typ? + # we require that at least one binding of var is compatible with typ? classes = [] abstract_utils.flatten(class_spec, classes) + ret = [] + # First try to narrow `var` based on `classes`. + for c in classes: + m = self.ctx.matcher(state.node).compute_one_match( + var, c, keep_all_views=True, match_all_views=False) + if m.success: + for matched in m.good_matches: + d = matched.view[var] + if isinstance(d.data, abstract.Instance): + ret.append(d.data.cls) + + # If we don't have bindings from `classes` in `var`, instantiate the + # original class spec. + ret = ret or classes instance = self.init_class( - state.node, self.ctx.convert.merge_values(classes)) + state.node, self.ctx.convert.merge_values(ret)) return self._store_new_var_in_local(state, var, instance) def _check_test_assert(self, state, func, args):