From e4ecb91ea868c20c022daf1edbd76426bfce32a2 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 24 Apr 2023 14:00:44 -0600 Subject: [PATCH 01/29] Add test case --- test-data/unit/check-optional.test | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 754c6b52ff19..07bbca04ff22 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1040,3 +1040,21 @@ x: Optional[List[int]] if 3 in x: pass +[case testNestedFunction] +from typing import Optional + +def f1(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return reveal_type(x) + nested() + +def f2(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + if int(): + x = None + nested() From 05e395857c6bd0c38e69fc326a307f76b2acb501 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 24 Apr 2023 14:13:33 -0600 Subject: [PATCH 02/29] WIP debugging stuff --- mypy/binder.py | 5 ++++- mypy/checker.py | 5 +++++ test-data/unit/check-optional.test | 4 +++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index d822aecec2f3..1a9db453ed20 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -51,6 +51,9 @@ def __init__(self, id: int, conditional_frame: bool = False) -> None: # need this field. self.suppress_unreachable_warnings = False + def __repr__(self) -> str: + return f"Frame({self.id}, {self.types}, {self.unreachable}, {self.conditional_frame})" + Assigns = DefaultDict[Expression, List[Tuple[Type, Optional[Type]]]] @@ -63,7 +66,7 @@ class ConditionalTypeBinder: ``` class A: - a = None # type: Union[int, str] + a: Union[int, str] = None x = A() lst = [x] reveal_type(x.a) # Union[int, str] diff --git a/mypy/checker.py b/mypy/checker.py index 07dfb7de08f1..7de07f79a91c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1060,7 +1060,12 @@ def check_func_def( expanded = self.expand_typevars(defn, typ) for item, typ in expanded: old_binder = self.binder + #print([1], old_binder.frames) self.binder = ConditionalTypeBinder() + for f in old_binder.frames: + for k, v in f.types.items(): + if len(k) == 2 and k[0] == 'Var': + print(k[1], v) with self.binder.top_frame_context(): defn.expanded.append(item) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 07bbca04ff22..e13479af148a 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1040,7 +1040,7 @@ x: Optional[List[int]] if 3 in x: pass -[case testNestedFunction] +[case testNarrowedVariableInNestedFunction] from typing import Optional def f1(x: Optional[str]) -> None: @@ -1050,6 +1050,7 @@ def f1(x: Optional[str]) -> None: return reveal_type(x) nested() +""" def f2(x: Optional[str]) -> None: if x is None: x = "a" @@ -1058,3 +1059,4 @@ def f2(x: Optional[str]) -> None: if int(): x = None nested() +""" From f3c87cacca3a9161662b4dff91e5aea7304f04a2 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 24 Apr 2023 15:11:20 -0600 Subject: [PATCH 03/29] Propagate narrowed types to nested functions in some cases --- mypy/checker.py | 48 ++++++++++++++++++++++++++---- mypy/scope.py | 8 +++++ test-data/unit/check-optional.test | 4 +-- 3 files changed, 52 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 7de07f79a91c..89425fc10aca 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -45,6 +45,7 @@ from mypy.maptype import map_instance_to_supertype from mypy.meet import is_overlapping_erased_types, is_overlapping_types from mypy.message_registry import ErrorMessage +from mypy.traverser import TraverserVisitor from mypy.messages import ( SUGGESTED_TEST_FIXTURES, MessageBuilder, @@ -1060,12 +1061,7 @@ def check_func_def( expanded = self.expand_typevars(defn, typ) for item, typ in expanded: old_binder = self.binder - #print([1], old_binder.frames) self.binder = ConditionalTypeBinder() - for f in old_binder.frames: - for k, v in f.types.items(): - if len(k) == 2 and k[0] == 'Var': - print(k[1], v) with self.binder.top_frame_context(): defn.expanded.append(item) @@ -1212,6 +1208,15 @@ def check_func_def( # Type check body in a new scope. with self.binder.top_frame_context(): + for f in old_binder.frames: + for k, v in f.types.items(): + if ( + len(k) == 2 + and k[0] == "Var" + and not self.is_var_redefined_in_outer_context(k[1], defn.line) + ): + f = self.binder.push_frame() + f.types[k] = v with self.scope.push_function(defn): # We suppress reachability warnings when we use TypeVars with value # restrictions: we only want to report a warning if a certain statement is @@ -1315,6 +1320,16 @@ def check_func_def( self.binder = old_binder + def is_var_redefined_in_outer_context(self, v: Var, after_line: int) -> bool: + outer = self.tscope.outer_function() + if outer is None: + # Top-level function -- outer context is top level, and we can't reason about + # globals + return True + if isinstance(outer, FuncDef): + return find_last_var_assignment_line(outer.body, v) >= after_line + return False + def check_unbound_return_typevar(self, typ: CallableType) -> None: """Fails when the return typevar is not defined in arguments.""" if isinstance(typ.ret_type, TypeVarType) and typ.ret_type in typ.variables: @@ -7634,3 +7649,26 @@ def collapse_walrus(e: Expression) -> Expression: if isinstance(e, AssignmentExpr): return e.target return e + + +def find_last_var_assignment_line(n: Node, v: Var) -> int: + v = VarAssignVisitor(v) + n.accept(v) + return v.last_line + + +class VarAssignVisitor(TraverserVisitor): + def __init__(self, v: Var) -> None: + self.last_line = -1 + self.lvalue = False + self.var_node = v + + def visit_assignment_stmt(self, s: AssignmentStmt) -> None: + self.lvalue = True + for lv in s.lvalues: + lv.accept(self) + self.lvalue = False + + def visit_name_expr(self, e: NameExpr) -> None: + if self.lvalue and e.node is self.var_node: + self.last_line = max(self.last_line, e.line) diff --git a/mypy/scope.py b/mypy/scope.py index 19a690df8220..e797e095c826 100644 --- a/mypy/scope.py +++ b/mypy/scope.py @@ -21,6 +21,7 @@ def __init__(self) -> None: self.module: str | None = None self.classes: list[TypeInfo] = [] self.function: FuncBase | None = None + self.functions: list[FuncBase] = [] # Number of nested scopes ignored (that don't get their own separate targets) self.ignored = 0 @@ -65,12 +66,14 @@ def module_scope(self, prefix: str) -> Iterator[None]: @contextmanager def function_scope(self, fdef: FuncBase) -> Iterator[None]: + self.functions.append(fdef) if not self.function: self.function = fdef else: # Nested functions are part of the topmost function target. self.ignored += 1 yield + self.functions.pop() if self.ignored: # Leave a scope that's included in the enclosing target. self.ignored -= 1 @@ -78,6 +81,11 @@ def function_scope(self, fdef: FuncBase) -> Iterator[None]: assert self.function self.function = None + def outer_function(self) -> FuncBase | None: + if len(self.functions) > 1: + return self.functions[-2] + return None + def enter_class(self, info: TypeInfo) -> None: """Enter a class target scope.""" if not self.function: diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index e13479af148a..7a7b69cceba6 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1047,10 +1047,9 @@ def f1(x: Optional[str]) -> None: if x is None: x = "a" def nested() -> str: - return reveal_type(x) + return reveal_type(x) # N: Revealed type is "builtins.str" nested() -""" def f2(x: Optional[str]) -> None: if x is None: x = "a" @@ -1059,4 +1058,3 @@ def f2(x: Optional[str]) -> None: if int(): x = None nested() -""" From 01e044d25da273b3e80378c87f1c78857e369923 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 24 Apr 2023 15:14:01 -0600 Subject: [PATCH 04/29] Add test cases --- test-data/unit/check-optional.test | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 7a7b69cceba6..0cbb694aee50 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1043,14 +1043,20 @@ if 3 in x: [case testNarrowedVariableInNestedFunction] from typing import Optional -def f1(x: Optional[str]) -> None: +def can_narrow(x: Optional[str]) -> None: if x is None: x = "a" def nested() -> str: return reveal_type(x) # N: Revealed type is "builtins.str" nested() -def f2(x: Optional[str]) -> None: +def can_narrow_lambda(x: Optional[str]) -> None: + if x is None: + x = "a" + nested = lambda: x + reveal_type(nested()) # N: Revealed type is "builtins.str" + +def cannot_narrow_if_reassigned(x: Optional[str]) -> None: if x is None: x = "a" def nested() -> str: @@ -1058,3 +1064,8 @@ def f2(x: Optional[str]) -> None: if int(): x = None nested() + +x: Optional[str] = "x" + +def cannot_narrow_top_level() -> None: + reveal_type(x) # N: Revealed type is "Union[builtins.str, None]" From eecc8775c07639dae4e3966c685d07e76bfc16d3 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 24 Apr 2023 17:38:17 -0600 Subject: [PATCH 05/29] Support for and with statements --- mypy/checker.py | 16 +++++++++++++++ test-data/unit/check-optional.test | 33 +++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 89425fc10aca..91b38223e069 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7672,3 +7672,19 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: def visit_name_expr(self, e: NameExpr) -> None: if self.lvalue and e.node is self.var_node: self.last_line = max(self.last_line, e.line) + + def visit_with_stmt(self, s: WithStmt) -> None: + self.lvalue = True + for lv in s.target: + if lv is not None: + lv.accept(self) + self.lvalue = False + s.body.accept(self) + + def visit_for_stmt(self, s: ForStmt) -> None: + self.lvalue = True + s.index.accept(self) + self.lvalue = False + s.body.accept(self) + if s.else_body: + s.else_body.accept(self) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 0cbb694aee50..d0446f64957f 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1040,7 +1040,7 @@ x: Optional[List[int]] if 3 in x: pass -[case testNarrowedVariableInNestedFunction] +[case testNarrowedVariableInNestedFunctionBasic] from typing import Optional def can_narrow(x: Optional[str]) -> None: @@ -1069,3 +1069,34 @@ x: Optional[str] = "x" def cannot_narrow_top_level() -> None: reveal_type(x) # N: Revealed type is "Union[builtins.str, None]" + +[case testNarrowedVariableInNestedFunction2] +from typing import Optional + +class C: + a: Optional[str] + +def attribute_narrowing(c: C) -> None: + # This case is not supported at the moment. + c.a = "x" + def nested() -> str: + return c.a # E: Incompatible return value type (got "Optional[str]", expected "str") + nested() + +def assignment_in_for(x: Optional[str]) -> None: + if x is None: + x = "e" + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + for x in ["x"]: + pass + +def foo(): pass + +def assignment_in_with(x: Optional[str]) -> None: + if x is None: + x = "e" + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + with foo() as x: + pass From 030457e0b996fbfe8dc5ac4f74e5acb8b7e1ae8e Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 24 Apr 2023 17:44:12 -0600 Subject: [PATCH 06/29] Test more --- test-data/unit/check-optional.test | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index d0446f64957f..1012d796fe23 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1100,3 +1100,13 @@ def assignment_in_with(x: Optional[str]) -> None: return x # E: Incompatible return value type (got "Optional[str]", expected "str") with foo() as x: pass + +g: Optional[str] + +def assign_to_global() -> None: + global g + g = "x" + # This is unsafe, but we don't generate an error, for convenience. Besides, + # this is probably a very rare case. + def nested() -> str: + return g From 7968859055545cdf51d74cde1c795778e3599fb6 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 24 Apr 2023 18:00:29 -0600 Subject: [PATCH 07/29] Fix nested functions --- mypy/checker.py | 10 ++++++---- mypy/scope.py | 6 ++---- test-data/unit/check-optional.test | 14 ++++++++++++++ 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 91b38223e069..a202911827fd 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1321,13 +1321,15 @@ def check_func_def( self.binder = old_binder def is_var_redefined_in_outer_context(self, v: Var, after_line: int) -> bool: - outer = self.tscope.outer_function() - if outer is None: + outers = self.tscope.outer_functions() + if not outers: # Top-level function -- outer context is top level, and we can't reason about # globals return True - if isinstance(outer, FuncDef): - return find_last_var_assignment_line(outer.body, v) >= after_line + for outer in outers: + if isinstance(outer, FuncDef): + if find_last_var_assignment_line(outer.body, v) >= after_line: + return True return False def check_unbound_return_typevar(self, typ: CallableType) -> None: diff --git a/mypy/scope.py b/mypy/scope.py index e797e095c826..021dd9a7d8a5 100644 --- a/mypy/scope.py +++ b/mypy/scope.py @@ -81,10 +81,8 @@ def function_scope(self, fdef: FuncBase) -> Iterator[None]: assert self.function self.function = None - def outer_function(self) -> FuncBase | None: - if len(self.functions) > 1: - return self.functions[-2] - return None + def outer_functions(self) -> list[FuncBase]: + return self.functions[:-1] def enter_class(self, info: TypeInfo) -> None: """Enter a class target scope.""" diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 1012d796fe23..594d31e09654 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1110,3 +1110,17 @@ def assign_to_global() -> None: # this is probably a very rare case. def nested() -> str: return g + +def assign_to_nonlocal(x: Optional[str]) -> None: + def nested() -> str: + nonlocal x + + if x is None: + x = "a" + + def nested2() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + + return nested2() + nested() + x = None From 7603f2435e2f3a5c54a03827ddf0a120c2a4106a Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 24 Apr 2023 18:12:05 -0600 Subject: [PATCH 08/29] Add unit tests --- test-data/unit/check-optional.test | 37 ++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 594d31e09654..3cb6a2e93c36 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1124,3 +1124,40 @@ def assign_to_nonlocal(x: Optional[str]) -> None: return nested2() nested() x = None + +def dec(f): + return f + +@dec +def decorated_outer(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x + nested() + +@dec +def decorated_outer_bad(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + x = None + nested() + +def decorated_inner(x: Optional[str]) -> None: + if x is None: + x = "a" + @dec + def nested() -> str: + return x + nested() + +def decorated_inner_bad(x: Optional[str]) -> None: + if x is None: + x = "a" + @dec + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + x = None + nested() From a81856ac64c62313703109c84f0141625c47be8c Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 24 Apr 2023 18:15:02 -0600 Subject: [PATCH 09/29] More testing --- test-data/unit/check-optional.test | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 3cb6a2e93c36..b1a4eff79534 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1071,7 +1071,7 @@ def cannot_narrow_top_level() -> None: reveal_type(x) # N: Revealed type is "Union[builtins.str, None]" [case testNarrowedVariableInNestedFunction2] -from typing import Optional +from typing import Optional, overload class C: a: Optional[str] @@ -1161,3 +1161,26 @@ def decorated_inner_bad(x: Optional[str]) -> None: return x # E: Incompatible return value type (got "Optional[str]", expected "str") x = None nested() + +@overload +def overloaded_outer(x: None) -> None: ... +@overload +def overloaded_outer(x: str) -> None: ... +def overloaded_outer(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x + nested() + +@overload +def overloaded_outer_bad(x: None) -> None: ... +@overload +def overloaded_outer_bad(x: str) -> None: ... +def overloaded_outer_bad(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + x = None + nested() From c503b3c4b415b4d7e37b2095aa9bbf2e8209500f Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 08:50:55 -0600 Subject: [PATCH 10/29] Check if a match statement assigns to a variable --- mypy/checker.py | 12 +++++ mypy/fastparse.py | 3 +- test-data/unit/check-python310.test | 68 +++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index a202911827fd..a015077b6019 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7690,3 +7690,15 @@ def visit_for_stmt(self, s: ForStmt) -> None: s.body.accept(self) if s.else_body: s.else_body.accept(self) + + def visit_as_pattern(self, p: AsPattern) -> None: + if p.name is not None: + self.lvalue = True + p.name.accept(self) + self.lvalue = False + + def visit_starred_pattern(self, p: StarredPattern) -> None: + if p.capture is not None: + self.lvalue = True + p.capture.accept(self) + self.lvalue = False diff --git a/mypy/fastparse.py b/mypy/fastparse.py index b619bbe1368c..2c7e16e963e5 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -1767,7 +1767,8 @@ def visit_MatchStar(self, n: MatchStar) -> StarredPattern: if n.name is None: node = StarredPattern(None) else: - node = StarredPattern(NameExpr(n.name)) + name = self.set_line(NameExpr(n.name), n) + node = StarredPattern(name) return self.set_line(node, n) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index fb83dda7ffab..65e46b604862 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1863,3 +1863,71 @@ def f() -> None: match y: case A(): reveal_type(y.a) # N: Revealed type is "builtins.int" + +[case testNarrowedVariableInNestedModifiedInMatch] +# flags: --strict-optional +from typing import Optional + +def match_stmt_error1(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + match object(): + case str(x): + pass + nested() + +def match_stmt_ok1(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x + match object(): + case str(y): + pass + nested() + +def match_stmt_error2(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + match [None]: + case [x]: + pass + nested() + +def match_stmt_error3(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + match {'a': None}: + case {'a': x}: + pass + nested() + +def match_stmt_error4(x: Optional[list[str]]) -> None: + if x is None: + x = ["a"] + def nested() -> list[str]: + return x # E: Incompatible return value type (got "Optional[List[str]]", expected "List[str]") + match ["a"]: + case [*x]: + pass + nested() + +class C: + a: str + +def match_stmt_error5(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + match C(): + case C(a=x): + pass + nested() +[builtins fixtures/tuple.pyi] From a601098c02a35bb885e348e0019b16b6291c2f61 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 09:31:17 -0600 Subject: [PATCH 11/29] Support walrus expression --- mypy/checker.py | 8 ++++++++ test-data/unit/check-python38.test | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index a015077b6019..ac158fc35a03 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7691,7 +7691,15 @@ def visit_for_stmt(self, s: ForStmt) -> None: if s.else_body: s.else_body.accept(self) + def visit_assignment_expr(self, e: AssignmentExpr) -> None: + self.lvalue = True + e.target.accept(self) + self.lvalue = False + e.value.accept(self) + def visit_as_pattern(self, p: AsPattern) -> None: + if p.pattern is not None: + p.pattern.accept(self) if p.name is not None: self.lvalue = True p.name.accept(self) diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index b9f9f2173ae1..79e67d3c999c 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -775,3 +775,25 @@ main:9: note: Revealed type is "builtins.int" class C: [(j := i) for i in [1, 2, 3]] # E: Assignment expression within a comprehension cannot be used in a class body [builtins fixtures/list.pyi] + +[case testNarrowedVariableInNestedModifiedInWalrus] +# flags: --strict-optional +from typing import Optional + +def walrus_with_nested_error(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + if x := None: + pass + nested() + +def walrus_with_nested_ok(x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return x + if y := x: + pass + nested() From fe051312a90de3c86e46139d46822620fd738b40 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 09:35:43 -0600 Subject: [PATCH 12/29] Fix self check and isort --- mypy/checker.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index ac158fc35a03..9f6c54f6ffd6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -45,7 +45,6 @@ from mypy.maptype import map_instance_to_supertype from mypy.meet import is_overlapping_erased_types, is_overlapping_types from mypy.message_registry import ErrorMessage -from mypy.traverser import TraverserVisitor from mypy.messages import ( SUGGESTED_TEST_FIXTURES, MessageBuilder, @@ -135,6 +134,7 @@ is_final_node, ) from mypy.options import Options +from mypy.patterns import AsPattern, StarredPattern from mypy.plugin import CheckerPluginInterface, Plugin from mypy.scope import Scope from mypy.semanal import is_trivial_body, refers_to_fullname, set_callable_name @@ -152,7 +152,7 @@ restrict_subtype_away, unify_generic_callable, ) -from mypy.traverser import all_return_statements, has_return_statement +from mypy.traverser import TraverserVisitor, all_return_statements, has_return_statement from mypy.treetransform import TransformVisitor from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type from mypy.typeops import ( @@ -7654,9 +7654,9 @@ def collapse_walrus(e: Expression) -> Expression: def find_last_var_assignment_line(n: Node, v: Var) -> int: - v = VarAssignVisitor(v) - n.accept(v) - return v.last_line + visitor = VarAssignVisitor(v) + n.accept(visitor) + return visitor.last_line class VarAssignVisitor(TraverserVisitor): From 9d0b478d26e0c15f7d47c33bb1ccdd420b297a22 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 10:01:17 -0600 Subject: [PATCH 13/29] Add docstrings --- mypy/checker.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 9f6c54f6ffd6..ce24c334f6fe 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1321,6 +1321,11 @@ def check_func_def( self.binder = old_binder def is_var_redefined_in_outer_context(self, v: Var, after_line: int) -> bool: + """Can the variable be assigned to at module top level or outer function? + + Note that this doesn't do a full CFG analysis but uses a line number based + heuristic that isn't correct in some (rare) cases. + """ outers = self.tscope.outer_functions() if not outers: # Top-level function -- outer context is top level, and we can't reason about @@ -7654,6 +7659,12 @@ def collapse_walrus(e: Expression) -> Expression: def find_last_var_assignment_line(n: Node, v: Var) -> int: + """Find the highest line number of a potential assignment to variable within node. + + This supports local and global variables. + + Return -1 if no assignment was found. + """ visitor = VarAssignVisitor(v) n.accept(visitor) return visitor.last_line From 639a9a125fee6bdc9f66c0de99975639ac37f979 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 10:01:57 -0600 Subject: [PATCH 14/29] Minor tweak to test --- test-data/unit/check-optional.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index b1a4eff79534..5aa708ccdeb9 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1070,7 +1070,7 @@ x: Optional[str] = "x" def cannot_narrow_top_level() -> None: reveal_type(x) # N: Revealed type is "Union[builtins.str, None]" -[case testNarrowedVariableInNestedFunction2] +[case testNarrowedVariableInNestedFunctionMore] from typing import Optional, overload class C: From 6dbf9f13911b851f99b7da3e0daf4654656604c4 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 10:34:15 -0600 Subject: [PATCH 15/29] Refactoring and comment updates --- mypy/checker.py | 15 ++++++++------- test-data/unit/check-optional.test | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index ce24c334f6fe..3a83d2bf5a1c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1208,15 +1208,16 @@ def check_func_def( # Type check body in a new scope. with self.binder.top_frame_context(): - for f in old_binder.frames: - for k, v in f.types.items(): + # Copy type narrowings from an outer function when it seems safe. + for frame in old_binder.frames: + for key, narrowed_type in frame.types.items(): if ( - len(k) == 2 - and k[0] == "Var" - and not self.is_var_redefined_in_outer_context(k[1], defn.line) + len(key) == 2 + and key[0] == "Var" + and not self.is_var_redefined_in_outer_context(key[1], defn.line) ): - f = self.binder.push_frame() - f.types[k] = v + new_frame = self.binder.push_frame() + new_frame.types[key] = narrowed_type with self.scope.push_function(defn): # We suppress reachability warnings when we use TypeVars with value # restrictions: we only want to report a warning if a certain statement is diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 5aa708ccdeb9..0b07870ac4f9 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1077,7 +1077,7 @@ class C: a: Optional[str] def attribute_narrowing(c: C) -> None: - # This case is not supported at the moment. + # This case is not supported, since we can't keep track of assignments to attributes. c.a = "x" def nested() -> str: return c.a # E: Incompatible return value type (got "Optional[str]", expected "str") From d463240398d1091cd4e5db9ae43836cb60c10e15 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 10:38:05 -0600 Subject: [PATCH 16/29] More refactoring --- mypy/checker.py | 9 ++++----- mypy/literals.py | 12 +++++++++++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 3a83d2bf5a1c..e240555723ca 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -41,7 +41,7 @@ from mypy.errors import Errors, ErrorWatcher, report_internal_error from mypy.expandtype import expand_self_type, expand_type, expand_type_by_instance from mypy.join import join_types -from mypy.literals import Key, literal, literal_hash +from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash from mypy.maptype import map_instance_to_supertype from mypy.meet import is_overlapping_erased_types, is_overlapping_types from mypy.message_registry import ErrorMessage @@ -1211,10 +1211,9 @@ def check_func_def( # Copy type narrowings from an outer function when it seems safe. for frame in old_binder.frames: for key, narrowed_type in frame.types.items(): - if ( - len(key) == 2 - and key[0] == "Var" - and not self.is_var_redefined_in_outer_context(key[1], defn.line) + key_var = extract_var_from_literal_hash(key) + if key_var is not None and not self.is_var_redefined_in_outer_context( + key_var, defn.line ): new_frame = self.binder.push_frame() new_frame.types[key] = narrowed_type diff --git a/mypy/literals.py b/mypy/literals.py index 9d91cf728b06..dcf9bccdd01d 100644 --- a/mypy/literals.py +++ b/mypy/literals.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Iterable, Optional, Tuple +from typing import Any, Iterable, Optional, Tuple, cast from typing_extensions import Final, TypeAlias as _TypeAlias from mypy.nodes import ( @@ -139,6 +139,16 @@ def literal_hash(e: Expression) -> Key | None: return e.accept(_hasher) +def extract_var_from_literal_hash(key: Key) -> Var | None: + """If key refers to a Var node, return it. + + Return None otherwise. + """ + if len(key) == 2 and key[0] == "Var": + return cast(Var, key[1]) + return None + + class _Hasher(ExpressionVisitor[Optional[Key]]): def visit_int_expr(self, e: IntExpr) -> Key: return ("Literal", e.value) From 77b67dfca03dc480b514d17e94cb708f93aa1068 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 10:42:26 -0600 Subject: [PATCH 17/29] Test narrowing multiple variables --- test-data/unit/check-optional.test | 36 +++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 0b07870ac4f9..2197ac8a2f59 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1070,7 +1070,7 @@ x: Optional[str] = "x" def cannot_narrow_top_level() -> None: reveal_type(x) # N: Revealed type is "Union[builtins.str, None]" -[case testNarrowedVariableInNestedFunctionMore] +[case testNarrowedVariableInNestedFunctionMore1] from typing import Optional, overload class C: @@ -1184,3 +1184,37 @@ def overloaded_outer_bad(x: Optional[str]) -> None: return x # E: Incompatible return value type (got "Optional[str]", expected "str") x = None nested() + +[case testNarrowedVariableInNestedFunctionMore2] +from typing import Optional + +def narrow_multiple(x: Optional[str], y: Optional[int]) -> None: + z: Optional[str] = x + if x is None: + x = "" + if y is None: + y = 1 + if int(): + if z is None: + z = "" + def nested() -> None: + a: str = x + b: int = y + c: str = z + nested() + +def narrow_multiple_partial(x: Optional[str], y: Optional[int]) -> None: + z: Optional[str] = x + if x is None: + x = "" + if isinstance(y, int): + if z is None: + z = "" + def nested() -> None: + a: str = x + b: int = y + c: str = z # E: Incompatible types in assignment (expression has type "Optional[str]", variable has type "str") + z = None + nested() + +[builtins fixtures/isinstance.pyi] From 339ae64a2580b3c445d18a1bf4b2bbef513b561e Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 10:58:04 -0600 Subject: [PATCH 18/29] Add another test case --- test-data/unit/check-optional.test | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 2197ac8a2f59..4df350cc4064 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1217,4 +1217,15 @@ def narrow_multiple_partial(x: Optional[str], y: Optional[int]) -> None: z = None nested() +def multiple_nested_functions(x: Optional[str], y: Optional[str]) -> None: + if x is None: + x = "" + def nested1() -> str: + return x + if y is None: + y = "" + def nested2() -> str: + a: str = y + return x + [builtins fixtures/isinstance.pyi] From 56c3f3a399117bef04397a46f5c14a68a245ecc3 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 11:05:02 -0600 Subject: [PATCH 19/29] Don't leak frames + add assert to find frame leaks --- mypy/binder.py | 1 + mypy/checker.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index 1a9db453ed20..37c0b6bb9006 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -449,6 +449,7 @@ def top_frame_context(self) -> Iterator[Frame]: assert len(self.frames) == 1 yield self.push_frame() self.pop_frame(True, 0) + assert len(self.frames) == 1 def get_declaration(expr: BindableExpression) -> Type | None: diff --git a/mypy/checker.py b/mypy/checker.py index e240555723ca..231523bd71a6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1208,14 +1208,19 @@ def check_func_def( # Type check body in a new scope. with self.binder.top_frame_context(): - # Copy type narrowings from an outer function when it seems safe. + # Copy some type narrowings from an outer function when it seems safe enough + # (i.e. we can't find an assignment that might change the type of the + # variable afterwards). + new_frame = None for frame in old_binder.frames: for key, narrowed_type in frame.types.items(): key_var = extract_var_from_literal_hash(key) if key_var is not None and not self.is_var_redefined_in_outer_context( key_var, defn.line ): - new_frame = self.binder.push_frame() + # It seems safe to propagate the type narrowing to a nested scope. + if new_frame is None: + new_frame = self.binder.push_frame() new_frame.types[key] = narrowed_type with self.scope.push_function(defn): # We suppress reachability warnings when we use TypeVars with value @@ -1228,6 +1233,8 @@ def check_func_def( self.binder.suppress_unreachable_warnings() self.accept(item.body) unreachable = self.binder.is_unreachable() + if new_frame is not None: + self.binder.pop_frame(True, 0) if not unreachable: if defn.is_generator or is_named_instance( From 24faaa53468a8a1c39ec746de9f34009f359ca34 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 11:11:08 -0600 Subject: [PATCH 20/29] Test method --- test-data/unit/check-optional.test | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 4df350cc4064..563ad01f71a8 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1050,6 +1050,14 @@ def can_narrow(x: Optional[str]) -> None: return reveal_type(x) # N: Revealed type is "builtins.str" nested() +class C: + def can_narrow_in_method(self, x: Optional[str]) -> None: + if x is None: + x = "a" + def nested() -> str: + return reveal_type(x) # N: Revealed type is "builtins.str" + nested() + def can_narrow_lambda(x: Optional[str]) -> None: if x is None: x = "a" From ecc2d348bd86a5f66edeaf2533bf20c8cac1c957 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 11:49:18 -0600 Subject: [PATCH 21/29] Fix mypyc build --- mypy/checker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 231523bd71a6..931b5316ac80 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -26,7 +26,7 @@ import mypy.checkexpr from mypy import errorcodes as codes, message_registry, nodes, operators -from mypy.binder import ConditionalTypeBinder, get_declaration +from mypy.binder import ConditionalTypeBinder, Frame, get_declaration from mypy.checkmember import ( MemberContext, analyze_decorator_or_funcbase_access, @@ -1211,7 +1211,7 @@ def check_func_def( # Copy some type narrowings from an outer function when it seems safe enough # (i.e. we can't find an assignment that might change the type of the # variable afterwards). - new_frame = None + new_frame: Frame | None = None for frame in old_binder.frames: for key, narrowed_type in frame.types.items(): key_var = extract_var_from_literal_hash(key) From 67098aa33f3ab0f8c6f84708629e69a826cc53df Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 12:37:58 -0600 Subject: [PATCH 22/29] Update mypy/literals.py Co-authored-by: Alex Waygood --- mypy/literals.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mypy/literals.py b/mypy/literals.py index dcf9bccdd01d..ce46e4702310 100644 --- a/mypy/literals.py +++ b/mypy/literals.py @@ -145,7 +145,8 @@ def extract_var_from_literal_hash(key: Key) -> Var | None: Return None otherwise. """ if len(key) == 2 and key[0] == "Var": - return cast(Var, key[1]) + assert isinstance(key[1], Var) + return key[1] return None From 25cafbd13f01573c0475e3dc9c739c63e0129824 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 12:38:49 -0600 Subject: [PATCH 23/29] Fix tests when using compiled mypy --- mypy/literals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/literals.py b/mypy/literals.py index ce46e4702310..9c2cda96d1a7 100644 --- a/mypy/literals.py +++ b/mypy/literals.py @@ -144,7 +144,7 @@ def extract_var_from_literal_hash(key: Key) -> Var | None: Return None otherwise. """ - if len(key) == 2 and key[0] == "Var": + if len(key) == 2 and key[0] == "Var" and key[1] is not None: assert isinstance(key[1], Var) return key[1] return None From 40f3b53248415c18945f75c062eda9ccd03c6535 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 12:43:01 -0600 Subject: [PATCH 24/29] Fix unused import --- mypy/literals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/literals.py b/mypy/literals.py index 9c2cda96d1a7..a3e0415fbda0 100644 --- a/mypy/literals.py +++ b/mypy/literals.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Iterable, Optional, Tuple, cast +from typing import Any, Iterable, Optional, Tuple from typing_extensions import Final, TypeAlias as _TypeAlias from mypy.nodes import ( From 4bfb5f1bf3b8e6e9f8d15c84d985e6f6563ecefe Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 25 Apr 2023 12:45:58 -0600 Subject: [PATCH 25/29] Actually fix compiled mypy --- mypy/literals.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mypy/literals.py b/mypy/literals.py index a3e0415fbda0..53ba559c56bb 100644 --- a/mypy/literals.py +++ b/mypy/literals.py @@ -144,8 +144,7 @@ def extract_var_from_literal_hash(key: Key) -> Var | None: Return None otherwise. """ - if len(key) == 2 and key[0] == "Var" and key[1] is not None: - assert isinstance(key[1], Var) + if len(key) == 2 and key[0] == "Var" and isinstance(key[1], Var): return key[1] return None From 0c4dce4871e88101333ce3bb4173768c9d6b5cac Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 26 Apr 2023 12:08:53 -0600 Subject: [PATCH 26/29] Add more tests for reading narrowed variable after nested function --- test-data/unit/check-optional.test | 8 ++++++++ test-data/unit/check-python310.test | 6 ++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 563ad01f71a8..b233846ab678 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1050,12 +1050,20 @@ def can_narrow(x: Optional[str]) -> None: return reveal_type(x) # N: Revealed type is "builtins.str" nested() +def foo(a): pass + class C: def can_narrow_in_method(self, x: Optional[str]) -> None: if x is None: x = "a" def nested() -> str: return reveal_type(x) # N: Revealed type is "builtins.str" + # Reading the variable is fine + y = x + with foo(x): + foo(x) + for a in foo(x): + foo(x) nested() def can_narrow_lambda(x: Optional[str]) -> None: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 65e46b604862..15454fc3e216 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1878,14 +1878,16 @@ def match_stmt_error1(x: Optional[str]) -> None: pass nested() +def foo(x): pass + def match_stmt_ok1(x: Optional[str]) -> None: if x is None: x = "a" def nested() -> str: return x - match object(): + match foo(x): case str(y): - pass + z = x nested() def match_stmt_error2(x: Optional[str]) -> None: From 4fe08c225d37aad755d5edafedb8a45d9e8f875e Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 26 Apr 2023 12:16:03 -0600 Subject: [PATCH 27/29] Fix dealing with member and index expressions after nested function --- mypy/checker.py | 12 ++++++++++++ test-data/unit/check-optional.test | 15 +++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 931b5316ac80..25661ae75277 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7693,6 +7693,18 @@ def visit_name_expr(self, e: NameExpr) -> None: if self.lvalue and e.node is self.var_node: self.last_line = max(self.last_line, e.line) + def visit_member_expr(self, e: MemberExpr) -> None: + old_lvalue = self.lvalue + self.lvalue = False + super().visit_member_expr(e) + self.lvalue = old_lvalue + + def visit_index_expr(self, e: IndexExpr) -> None: + old_lvalue = self.lvalue + self.lvalue = False + super().visit_index_expr(e) + self.lvalue = old_lvalue + def visit_with_stmt(self, s: WithStmt) -> None: self.lvalue = True for lv in s.target: diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index b233846ab678..3da776ab0cfe 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1244,4 +1244,19 @@ def multiple_nested_functions(x: Optional[str], y: Optional[str]) -> None: a: str = y return x +class C: + a: str + def __setitem__(self, key, value): pass + +def narrowed_variable_used_in_lvalue_but_not_assigned(c: Optional[C]) -> None: + if c is None: + c = C() + def nested() -> C: + return c + c.a = "x" + c[1] = 2 + cc = C() + cc[c] = 3 + nested() + [builtins fixtures/isinstance.pyi] From c323e063d5ce38d7aedd0a57f60f215091411d8f Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 2 May 2023 17:24:28 +0100 Subject: [PATCH 28/29] Fix test case --- test-data/unit/check-optional.test | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 3da776ab0cfe..0c406d24692a 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1084,7 +1084,15 @@ def cannot_narrow_if_reassigned(x: Optional[str]) -> None: x: Optional[str] = "x" def cannot_narrow_top_level() -> None: - reveal_type(x) # N: Revealed type is "Union[builtins.str, None]" + global x + if x is None: + x = "a" + def nested() -> str: + # This should perhaps not be narrowed, since the nested function could outlive + # the outer function, and since other functions could also assign to x, but + # this seems like a minor issue. + return x + nested() [case testNarrowedVariableInNestedFunctionMore1] from typing import Optional, overload From c35090529a26a368ed628e53d9eb182ce6c3893e Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 3 May 2023 10:41:14 +0100 Subject: [PATCH 29/29] More tests --- test-data/unit/check-optional.test | 63 +++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 0c406d24692a..8ccce5115aba 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -1083,7 +1083,7 @@ def cannot_narrow_if_reassigned(x: Optional[str]) -> None: x: Optional[str] = "x" -def cannot_narrow_top_level() -> None: +def narrow_global_in_func() -> None: global x if x is None: x = "a" @@ -1094,6 +1094,13 @@ def cannot_narrow_top_level() -> None: return x nested() +x = "y" + +def narrowing_global_at_top_level_not_propagated() -> str: + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + [case testNarrowedVariableInNestedFunctionMore1] from typing import Optional, overload @@ -1267,4 +1274,58 @@ def narrowed_variable_used_in_lvalue_but_not_assigned(c: Optional[C]) -> None: cc[c] = 3 nested() +def narrow_with_multi_lvalues_1(x: Optional[str]) -> None: + if x is None: + x = "" + + def nested() -> str: + return x + + y = z = None + +def narrow_with_multi_lvalue_2(x: Optional[str]) -> None: + if x is None: + x = "" + + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + + x = y = None + +def narrow_with_multi_lvalue_3(x: Optional[str]) -> None: + if x is None: + x = "" + + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + + y = x = None + +def narrow_with_multi_assign_1(x: Optional[str]) -> None: + if x is None: + x = "" + + def nested() -> str: + return x + + y, z = None, None + +def narrow_with_multi_assign_2(x: Optional[str]) -> None: + if x is None: + x = "" + + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + + x, y = None, None + +def narrow_with_multi_assign_3(x: Optional[str]) -> None: + if x is None: + x = "" + + def nested() -> str: + return x # E: Incompatible return value type (got "Optional[str]", expected "str") + + y, x = None, None + [builtins fixtures/isinstance.pyi]