diff --git a/pytype/abstract/function.py b/pytype/abstract/function.py index dcdffc941..3c9263c97 100644 --- a/pytype/abstract/function.py +++ b/pytype/abstract/function.py @@ -1170,20 +1170,26 @@ def handle_typeguard(node, ret: _ReturnType, first_arg, ctx, func_name=None): ) return None target = frame.lookup_name(target_name) - # Forward all the target's bindings to the current node, so we don't have - # visibility problems later. - target.PasteVariable(target, node) - old_data = set(target.data) + + # Forward all the target's visible bindings to the current node. We're going + # to add new bindings soon, which would otherwise hide the old bindings, kinda + # like assigning the variable to a new value. + for b in target.Bindings(node): + target.PasteBinding(b, node) + + # Add missing bindings to the target variable. + old_data = set(target.Data(node)) new_instance = ret.instantiate_parameter(node, abstract_utils.T) + new_data = set(new_instance.data) for b in new_instance.bindings: - if b.data not in target.data: + if b.data not in old_data: target.PasteBinding(b, node) # Create a boolean return variable with True bindings for values that # originate from the type guard type and False for the rest. typeguard_return = ctx.program.NewVariable() - for b in target.bindings: - boolvals = {b.data not in old_data} | {b.data in new_instance.data} + for b in target.Bindings(node): + boolvals = {b.data not in old_data} | {b.data in new_data} for v in boolvals: typeguard_return.AddBinding(ctx.convert.bool_values[v], {b}, node) return typeguard_return diff --git a/pytype/tests/test_typeguard.py b/pytype/tests/test_typeguard.py index e0a90d1d9..0126787ca 100644 --- a/pytype/tests/test_typeguard.py +++ b/pytype/tests/test_typeguard.py @@ -461,6 +461,76 @@ def f(x) -> TypeGuard[tuple[str, int]]: assert_type(e2, int) """) + def test_only_use_visible_bindings(self): + with self.DepTree([( + "foo.pyi", + """ + from typing import TypeGuard + class Foo: ... + def isfoo(x: object) -> TypeGuard[Foo]: ... + """, + )]): + self.Check(""" + import foo + + value = 1 + del value # add "Deleted" binding for `value` + value = 2 + if foo.isfoo(value): + print(value) # "Deleted" binding should not be visible here + """) + + def test_dont_hide_previous_bindings(self): + with self.DepTree([( + "foo.pyi", + """ + from typing import TypeGuard + class Foo: ... + class Bar: ... + class Baz: ... + def isbar(x: object) -> TypeGuard[Bar]: ... + def isbaz(x: object) -> TypeGuard[Baz]: ... + """, + )]): + errors = self.CheckWithErrors(""" + import foo + + def takes_foo(x: foo.Foo): + pass + + def test(x: foo.Foo): + is_bar = foo.isbar(x) + is_baz = foo.isbaz(x) + takes_foo(x) + reveal_type(x) # reveal-type[e] + """) + # This documents a slightly incorrect type inference. Arguably `x` should + # just be `Foo` here. + self.assertErrorSequences( + errors, {"e": "Union[foo.Bar, foo.Baz, foo.Foo]"} + ) + + def test_type_guard_matches_input_type(self): + with self.DepTree([( + "foo.pyi", + """ + from typing import TypeGuard + class Foo: ... + def isfoo(x: object) -> TypeGuard[Foo]: ... + """, + )]): + self.Check(""" + import foo + + class Test: + def __init__(self, value: foo.Foo): + if not foo.isfoo(value): + raise ValueError('"value" must be a Foo') + self.value = value + + print(Test(foo.Foo()).value) # .value must be defined here + """) + if __name__ == "__main__": test_base.main()