Skip to content

Commit

Permalink
Do not consider bare TypeVar not overlapping with None for reachabili…
Browse files Browse the repository at this point in the history
…ty analysis (python#18138)

Fixes python#18126.

Simply allowing such intersection was insufficient: existing binder
logic widened the type to `T | None` after the `is None` check.

This PR extends the binder logic to prevent constructing a union type
when all conditional branches are reachable and contain no assignments:
checking `if isinstance(something, Something)` does not change the type
of `something` after the end of the `if` block.
  • Loading branch information
sterliakov authored Nov 21, 2024
1 parent e840275 commit 08340c2
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 42 deletions.
55 changes: 39 additions & 16 deletions mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections import defaultdict
from contextlib import contextmanager
from typing import DefaultDict, Iterator, List, Optional, Tuple, Union, cast
from typing import DefaultDict, Iterator, List, NamedTuple, Optional, Tuple, Union
from typing_extensions import TypeAlias as _TypeAlias

from mypy.erasetype import remove_instance_last_known_values
Expand Down Expand Up @@ -30,6 +30,11 @@
BindableExpression: _TypeAlias = Union[IndexExpr, MemberExpr, NameExpr]


class CurrentType(NamedTuple):
type: Type
from_assignment: bool


class Frame:
"""A Frame represents a specific point in the execution of a program.
It carries information about the current types of expressions at
Expand All @@ -44,7 +49,7 @@ class Frame:

def __init__(self, id: int, conditional_frame: bool = False) -> None:
self.id = id
self.types: dict[Key, Type] = {}
self.types: dict[Key, CurrentType] = {}
self.unreachable = False
self.conditional_frame = conditional_frame
self.suppress_unreachable_warnings = False
Expand Down Expand Up @@ -132,18 +137,18 @@ def push_frame(self, conditional_frame: bool = False) -> Frame:
self.options_on_return.append([])
return f

def _put(self, key: Key, type: Type, index: int = -1) -> None:
self.frames[index].types[key] = type
def _put(self, key: Key, type: Type, from_assignment: bool, index: int = -1) -> None:
self.frames[index].types[key] = CurrentType(type, from_assignment)

def _get(self, key: Key, index: int = -1) -> Type | None:
def _get(self, key: Key, index: int = -1) -> CurrentType | None:
if index < 0:
index += len(self.frames)
for i in range(index, -1, -1):
if key in self.frames[i].types:
return self.frames[i].types[key]
return None

def put(self, expr: Expression, typ: Type) -> None:
def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> None:
if not isinstance(expr, (IndexExpr, MemberExpr, NameExpr)):
return
if not literal(expr):
Expand All @@ -153,7 +158,7 @@ def put(self, expr: Expression, typ: Type) -> None:
if key not in self.declarations:
self.declarations[key] = get_declaration(expr)
self._add_dependencies(key)
self._put(key, typ)
self._put(key, typ, from_assignment)

def unreachable(self) -> None:
self.frames[-1].unreachable = True
Expand All @@ -164,7 +169,10 @@ def suppress_unreachable_warnings(self) -> None:
def get(self, expr: Expression) -> Type | None:
key = literal_hash(expr)
assert key is not None, "Internal error: binder tried to get non-literal"
return self._get(key)
found = self._get(key)
if found is None:
return None
return found.type

def is_unreachable(self) -> bool:
# TODO: Copy the value of unreachable into new frames to avoid
Expand Down Expand Up @@ -193,7 +201,7 @@ def update_from_options(self, frames: list[Frame]) -> bool:
If a key is declared as AnyType, only update it if all the
options are the same.
"""

all_reachable = all(not f.unreachable for f in frames)
frames = [f for f in frames if not f.unreachable]
changed = False
keys = {key for f in frames for key in f.types}
Expand All @@ -207,17 +215,30 @@ def update_from_options(self, frames: list[Frame]) -> bool:
# know anything about key in at least one possible frame.
continue

type = resulting_values[0]
assert type is not None
if all_reachable and all(
x is not None and not x.from_assignment for x in resulting_values
):
# Do not synthesize a new type if we encountered a conditional block
# (if, while or match-case) without assignments.
# See check-isinstance.test::testNoneCheckDoesNotMakeTypeVarOptional
# This is a safe assumption: the fact that we checked something with `is`
# or `isinstance` does not change the type of the value.
continue

current_type = resulting_values[0]
assert current_type is not None
type = current_type.type
declaration_type = get_proper_type(self.declarations.get(key))
if isinstance(declaration_type, AnyType):
# At this point resulting values can't contain None, see continue above
if not all(is_same_type(type, cast(Type, t)) for t in resulting_values[1:]):
if not all(
t is not None and is_same_type(type, t.type) for t in resulting_values[1:]
):
type = AnyType(TypeOfAny.from_another_any, source_any=declaration_type)
else:
for other in resulting_values[1:]:
assert other is not None
type = join_simple(self.declarations[key], type, other)
type = join_simple(self.declarations[key], type, other.type)
# Try simplifying resulting type for unions involving variadic tuples.
# Technically, everything is still valid without this step, but if we do
# not do this, this may create long unions after exiting an if check like:
Expand All @@ -236,8 +257,8 @@ def update_from_options(self, frames: list[Frame]) -> bool:
)
if simplified == self.declarations[key]:
type = simplified
if current_value is None or not is_same_type(type, current_value):
self._put(key, type)
if current_value is None or not is_same_type(type, current_value[0]):
self._put(key, type, from_assignment=True)
changed = True

self.frames[-1].unreachable = not frames
Expand Down Expand Up @@ -374,7 +395,9 @@ def most_recent_enclosing_type(self, expr: BindableExpression, type: Type) -> Ty
key = literal_hash(expr)
assert key is not None
enclosers = [get_declaration(expr)] + [
f.types[key] for f in self.frames if key in f.types and is_subtype(type, f.types[key])
f.types[key].type
for f in self.frames
if key in f.types and is_subtype(type, f.types[key][0])
]
return enclosers[-1]

Expand Down
27 changes: 14 additions & 13 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4725,11 +4725,11 @@ def visit_if_stmt(self, s: IfStmt) -> None:

# XXX Issue a warning if condition is always False?
with self.binder.frame_context(can_skip=True, fall_through=2):
self.push_type_map(if_map)
self.push_type_map(if_map, from_assignment=False)
self.accept(b)

# XXX Issue a warning if condition is always True?
self.push_type_map(else_map)
self.push_type_map(else_map, from_assignment=False)

with self.binder.frame_context(can_skip=False, fall_through=2):
if s.else_body:
Expand Down Expand Up @@ -5310,18 +5310,21 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
if b.is_unreachable or isinstance(
get_proper_type(pattern_type.type), UninhabitedType
):
self.push_type_map(None)
self.push_type_map(None, from_assignment=False)
else_map: TypeMap = {}
else:
pattern_map, else_map = conditional_types_to_typemaps(
named_subject, pattern_type.type, pattern_type.rest_type
)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
self.push_type_map(pattern_map)
self.push_type_map(pattern_map, from_assignment=False)
if pattern_map:
for expr, typ in pattern_map.items():
self.push_type_map(self._get_recursive_sub_patterns_map(expr, typ))
self.push_type_map(pattern_type.captures)
self.push_type_map(
self._get_recursive_sub_patterns_map(expr, typ),
from_assignment=False,
)
self.push_type_map(pattern_type.captures, from_assignment=False)
if g is not None:
with self.binder.frame_context(can_skip=False, fall_through=3):
gt = get_proper_type(self.expr_checker.accept(g))
Expand All @@ -5347,11 +5350,11 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
continue
type_map[named_subject] = type_map[expr]

self.push_type_map(guard_map)
self.push_type_map(guard_map, from_assignment=False)
self.accept(b)
else:
self.accept(b)
self.push_type_map(else_map)
self.push_type_map(else_map, from_assignment=False)

# This is needed due to a quirk in frame_context. Without it types will stay narrowed
# after the match.
Expand Down Expand Up @@ -7372,12 +7375,12 @@ def iterable_item_type(
def function_type(self, func: FuncBase) -> FunctionLike:
return function_type(func, self.named_type("builtins.function"))

def push_type_map(self, type_map: TypeMap) -> None:
def push_type_map(self, type_map: TypeMap, *, from_assignment: bool = True) -> None:
if type_map is None:
self.binder.unreachable()
else:
for expr, type in type_map.items():
self.binder.put(expr, type)
self.binder.put(expr, type, from_assignment=from_assignment)

def infer_issubclass_maps(self, node: CallExpr, expr: Expression) -> tuple[TypeMap, TypeMap]:
"""Infer type restrictions for an expression in issubclass call."""
Expand Down Expand Up @@ -7750,9 +7753,7 @@ def conditional_types(
) and is_proper_subtype(current_type, proposed_type, ignore_promotions=True):
# Expression is always of one of the types in proposed_type_ranges
return default, UninhabitedType()
elif not is_overlapping_types(
current_type, proposed_type, prohibit_none_typevar_overlap=True, ignore_promotions=True
):
elif not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
# Expression is never of any type in proposed_type_ranges
return UninhabitedType(), default
else:
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ elif x is Foo.C:
reveal_type(x) # N: Revealed type is "Literal[__main__.Foo.C]"
else:
reveal_type(x) # No output here: this branch is unreachable
reveal_type(x) # N: Revealed type is "__main__.Foo"
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[__main__.Foo.C]]"

if Foo.A is x:
reveal_type(x) # N: Revealed type is "Literal[__main__.Foo.A]"
Expand All @@ -825,7 +825,7 @@ elif Foo.C is x:
reveal_type(x) # N: Revealed type is "Literal[__main__.Foo.C]"
else:
reveal_type(x) # No output here: this branch is unreachable
reveal_type(x) # N: Revealed type is "__main__.Foo"
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[__main__.Foo.C]]"

y: Foo
if y is Foo.A:
Expand Down
17 changes: 9 additions & 8 deletions test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -2207,23 +2207,24 @@ def foo2(x: Optional[str]) -> None:
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/isinstance.pyi]

[case testNoneCheckDoesNotNarrowWhenUsingTypeVars]

# Note: this test (and the following one) are testing checker.conditional_type_map:
# if you set the 'prohibit_none_typevar_overlap' keyword argument to False when calling
# 'is_overlapping_types', the binder will incorrectly infer that 'out' has a type of
# Union[T, None] after the if statement.

[case testNoneCheckDoesNotMakeTypeVarOptional]
from typing import TypeVar

T = TypeVar('T')

def foo(x: T) -> T:
def foo_if(x: T) -> T:
out = None
out = x
if out is None:
pass
return out

def foo_while(x: T) -> T:
out = None
out = x
while out is None:
pass
return out
[builtins fixtures/isinstance.pyi]

[case testNoneCheckDoesNotNarrowWhenUsingTypeVarsNoStrictOptional]
Expand Down
19 changes: 19 additions & 0 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -2333,3 +2333,22 @@ def f(x: C) -> None:

f(C(5))
[builtins fixtures/primitives.pyi]

[case testNarrowingTypeVarNone]
# flags: --warn-unreachable

# https://github.com/python/mypy/issues/18126
from typing import TypeVar

T = TypeVar("T")

def fn_if(arg: T) -> None:
if arg is None:
return None
return None

def fn_while(arg: T) -> None:
while arg is None:
return None
return None
[builtins fixtures/primitives.pyi]
30 changes: 30 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -2409,3 +2409,33 @@ def f(x: T) -> None:
case _:
accept_seq_int(x) # E: Argument 1 to "accept_seq_int" has incompatible type "T"; expected "Sequence[int]"
[builtins fixtures/tuple.pyi]

[case testNarrowingTypeVarMatch]
# flags: --warn-unreachable

# https://github.com/python/mypy/issues/18126
from typing import TypeVar

T = TypeVar("T")

def fn_case(arg: T) -> None:
match arg:
case None:
return None
return None
[builtins fixtures/primitives.pyi]

[case testNoneCheckDoesNotMakeTypeVarOptionalMatch]
from typing import TypeVar

T = TypeVar('T')

def foo(x: T) -> T:
out = None
out = x
match out:
case None:
pass
return out

[builtins fixtures/isinstance.pyi]
6 changes: 3 additions & 3 deletions test-data/unit/check-type-promotion.test
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ else:
reveal_type(x) # N: Revealed type is "builtins.complex"

# Note we make type precise, since type promotions are involved
reveal_type(x) # N: Revealed type is "Union[builtins.complex, builtins.int, builtins.float]"
reveal_type(x) # N: Revealed type is "builtins.complex"
[builtins fixtures/primitives.pyi]

[case testIntersectionUsingPromotion3]
Expand Down Expand Up @@ -127,7 +127,7 @@ if isinstance(x, int):
reveal_type(x) # N: Revealed type is "builtins.int"
else:
reveal_type(x) # N: Revealed type is "Union[builtins.float, builtins.complex]"
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.complex]"
reveal_type(x) # N: Revealed type is "Union[builtins.float, builtins.complex]"
[builtins fixtures/primitives.pyi]

[case testIntersectionUsingPromotion6]
Expand All @@ -139,7 +139,7 @@ if isinstance(x, int):
reveal_type(x) # N: Revealed type is "builtins.int"
else:
reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.complex]"
reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.int, builtins.complex]"
reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.complex]"
[builtins fixtures/primitives.pyi]

[case testIntersectionUsingPromotion7]
Expand Down

0 comments on commit 08340c2

Please sign in to comment.