Skip to content

Commit

Permalink
When setting a type from assertIsInstance narrow the original type if…
Browse files Browse the repository at this point in the history
… possible.

If we cannot narrow the type we continue to fall back to just instantiating the
asserted types, rather than saying the assert will fail.

PiperOrigin-RevId: 575998510
  • Loading branch information
martindemello authored and rchen152 committed Oct 24, 2023
1 parent de0b10c commit 3289302
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
17 changes: 16 additions & 1 deletion pytype/tests/test_test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,23 @@ def test_foo(self):
assert_type(x, str)
""")

def test_narrowed_type_from_assert_isinstance(self):
# assertIsInstance should narrow the original var's bindings if possible.
self.Check("""
import unittest
from typing import Union
class A:
pass
class B(A):
pass
class FooTest(unittest.TestCase):
def test_foo(self, x: Union[A, B, int]):
self.assertIsInstance(x, A)
assert_type(x, Union[A, B])
""")

def test_new_type_from_assert_isinstance(self):
# assertIsInstance should create a var with a new type even if it is not in
# assertIsInstance should create a var with a new type if it is not in
# the original var's bindings.
self.Check("""
import unittest
Expand Down
19 changes: 17 additions & 2 deletions pytype/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3021,12 +3021,27 @@ def _narrow(self, state, var, pred):
return self._store_new_var_in_local(state, var, out)

def _set_type_from_assert_isinstance(self, state, var, class_spec):
"""Set type of var from an assertIsInstance(var, class_spec) call."""
# TODO(mdemello): If we want to cast var to typ via an assertion, should
# we check that at least one binding of var is compatible with typ?
# we require that at least one binding of var is compatible with typ?
classes = []
abstract_utils.flatten(class_spec, classes)
ret = []
# First try to narrow `var` based on `classes`.
for c in classes:
m = self.ctx.matcher(state.node).compute_one_match(
var, c, keep_all_views=True, match_all_views=False)
if m.success:
for matched in m.good_matches:
d = matched.view[var]
if isinstance(d.data, abstract.Instance):
ret.append(d.data.cls)

# If we don't have bindings from `classes` in `var`, instantiate the
# original class spec.
ret = ret or classes
instance = self.init_class(
state.node, self.ctx.convert.merge_values(classes))
state.node, self.ctx.convert.merge_values(ret))
return self._store_new_var_in_local(state, var, instance)

def _check_test_assert(self, state, func, args):
Expand Down

0 comments on commit 3289302

Please sign in to comment.