From f82a019e18d2810c72b9f98fb2181987ddc29fc9 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Sun, 15 Dec 2019 13:59:01 -0800 Subject: [PATCH 1/4] Add support for narrowing Literals using equality This pull request (finally) adds support for narrowing expressions using Literal types by equality, instead of just identity. For example, the following "tagged union" pattern is now supported: ```python class Foo(TypedDict): key: Literal["A"] blah: int class Bar(TypedDict): key: Literal["B"] something: str x: Union[Foo, Bar] if x.key == "A": reveal_type(x) # Revealed type is 'Foo' else: reveal_type(x) # Revealed type is 'Bar' ``` Previously, this was possible to do only with Enum Literals and the `is` operator, which is perhaps not very intuitive. The main limitation with this pull request is that it'll perform narrowing only if either the LHS or RHS contains an explicit Literal type somewhere. If this limitation is not present, we end up breaking a decent amount of real-world code -- mostly tests -- that do something like this: ```python def some_test_case() -> None: worker = Worker() # Without the limitation, we narrow 'worker.state' to # Literal['ready'] in this assert... assert worker.state == 'ready' worker.start() # ...which subsequently causes this second assert to narrow # worker.state to , causing the last line to be # unreachable. assert worker.state == 'running' worker.query() ``` I tried for several weeks to find a more intelligent way around this problem, but everything I tried ended up being either insufficient or super-hacky, so I gave up and went for this brute-force solution. The other main limitation is that we perform narrowing only if both the LHS and RHS do not define custom `__eq__` or `__ne__` methods, but this seems like a more reasonable one to me. Resolves https://github.com/python/mypy/issues/7944. --- mypy/checker.py | 157 +++++++++----- mypy/checkexpr.py | 47 +---- mypy/checkstrformat.py | 35 +--- mypy/typeops.py | 53 ++++- test-data/unit/check-narrowing.test | 305 ++++++++++++++++++++++++++++ test-data/unit/check-optional.test | 20 ++ 6 files changed, 491 insertions(+), 126 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index ae829d1157c1..fa07ac4259c0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6,7 +6,7 @@ from typing import ( Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable, - Sequence, Mapping, Generic, AbstractSet + Sequence, Mapping, Generic, AbstractSet, Callable ) from typing_extensions import Final @@ -50,7 +50,8 @@ erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal, try_getting_str_literals_from_type, try_getting_int_literals_from_type, tuple_fallback, is_singleton_type, try_expanding_enum_to_union, - true_only, false_only, function_type, TypeVarExtractor, + true_only, false_only, function_type, TypeVarExtractor, custom_special_method, + is_literal_type_like, ) from mypy import message_registry from mypy.subtypes import ( @@ -3890,20 +3891,59 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM partial_type_maps = [] for operator, expr_indices in simplified_operator_list: - if operator in {'is', 'is not'}: - if_map, else_map = self.refine_identity_comparison_expression( - operands, - operand_types, - expr_indices, - narrowable_operand_index_to_hash.keys(), - ) - elif operator in {'==', '!='}: - if_map, else_map = self.refine_equality_comparison_expression( - operands, - operand_types, - expr_indices, - narrowable_operand_index_to_hash.keys(), - ) + if operator in {'is', 'is not', '==', '!='}: + # is_valid_target: + # Controls which types we're allowed to narrow exprs to. Note that + # we cannot use 'is_literal_type_like' in both cases since doing + # 'x = 10000 + 1; x is 10001' is not always True in all Python impls. + # + # coerce_only_in_literal_context: + # If true, coerce types into literal types only if one or more of + # the provided exprs contains an explicit Literal type. This could + # technically be set to any arbitrary value, but it seems being liberal + # with narrowing when using 'is' and conservative when using '==' seems + # to break the least amount of real-world code. + # + # should_narrow_by_identity: + # Set to 'false' only if the user defines custom __eq__ or __ne__ methods + # that could cause identity-based narrowing to produce invalid results. + if operator in {'is', 'is not'}: + is_valid_target = is_singleton_type # type: Callable[[Type], bool] + coerce_only_in_literal_context = False + should_narrow_by_identity = True + else: + is_valid_target = is_exactly_literal_type + coerce_only_in_literal_context = True + + def has_no_custom_eq_checks(t: Type) -> bool: + return not custom_special_method(t, '__eq__', check_all=False) \ + and not custom_special_method(t, '__ne__', check_all=False) + expr_types = [operand_types[i] for i in expr_indices] + should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types)) + + if_map = {} # type: TypeMap + else_map = {} # type: TypeMap + if should_narrow_by_identity: + if_map, else_map = self.refine_identity_comparison_expression( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + is_valid_target, + coerce_only_in_literal_context, + ) + + # Strictly speaking, we should also skip this check if the objects in the expr + # chain have custom __eq__ or __ne__ methods. But we (maybe optimistically) + # assume nobody would actually create a custom objects that considers itself + # equal to None. + if if_map == {} and else_map == {}: + if_map, else_map = self.refine_away_none_in_comparison( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + ) elif operator in {'in', 'not in'}: assert len(expr_indices) == 2 left_index, right_index = expr_indices @@ -4146,8 +4186,10 @@ def refine_identity_comparison_expression(self, operand_types: List[Type], chain_indices: List[int], narrowable_operand_indices: AbstractSet[int], + is_valid_target: Callable[[ProperType], bool], + coerce_only_in_literal_context: bool, ) -> Tuple[TypeMap, TypeMap]: - """Produces conditional type maps refining expressions used in an identity comparison. + """Produces conditional type maps refining exprs used in an identity/equality comparison. The 'operands' and 'operand_types' lists should be the full list of operands used in the overall comparison expression. The 'chain_indices' list is the list of indices @@ -4163,30 +4205,45 @@ def refine_identity_comparison_expression(self, The 'narrowable_operand_indices' parameter is the set of all indices we are allowed to refine the types of: that is, all operands that will potentially be a part of the output TypeMaps. + + Although this function could theoretically try setting the types of the operands + in the chains to the meet, doing that causes too many issues in real-world code. + Instead, we use 'is_valid_target' to identify which of the given chain types + we could plausibly use as the refined type for the expressions in the chain. + + Similarly, 'coerce_only_in_literal_context' controls whether we should try coercing + expressions in the chain to a Literal type. Performing this coercion is sometimes + too aggressive of a narrowing, depending on context. """ - singleton = None # type: Optional[ProperType] - possible_singleton_indices = [] + should_coerce = True + if coerce_only_in_literal_context: + should_coerce = any(is_literal_type_like(operand_types[i]) for i in chain_indices) + + target = None # type: Optional[Type] + possible_target_indices = [] for i in chain_indices: - coerced_type = coerce_to_literal(operand_types[i]) - if not is_singleton_type(coerced_type): + expr_type = operand_types[i] + if should_coerce: + expr_type = coerce_to_literal(expr_type) + if not is_valid_target(get_proper_type(expr_type)): continue - if singleton and not is_same_type(singleton, coerced_type): - # We have multiple disjoint singleton types. So the 'if' branch + if target and not is_same_type(target, expr_type): + # We have multiple disjoint target types. So the 'if' branch # must be unreachable. return None, {} - singleton = coerced_type - possible_singleton_indices.append(i) + target = expr_type + possible_target_indices.append(i) - # There's nothing we can currently infer if none of the operands are singleton types, + # There's nothing we can currently infer if none of the operands are valid targets, # so we end early and infer nothing. - if singleton is None: + if target is None: return {}, {} - # If possible, use an unassignable expression as the singleton. - # We skip refining the type of the singleton below, so ideally we'd + # If possible, use an unassignable expression as the target. + # We skip refining the type of the target below, so ideally we'd # want to pick an expression we were going to skip anyways. singleton_index = -1 - for i in possible_singleton_indices: + for i in possible_target_indices: if i not in narrowable_operand_indices: singleton_index = i @@ -4215,20 +4272,21 @@ def refine_identity_comparison_expression(self, # currently will just mark the whole branch as unreachable if either operand is # narrowed to . if singleton_index == -1: - singleton_index = possible_singleton_indices[-1] + singleton_index = possible_target_indices[-1] enum_name = None - if isinstance(singleton, LiteralType) and singleton.is_enum_literal(): - enum_name = singleton.fallback.type.fullname + target = get_proper_type(target) + if isinstance(target, LiteralType) and target.is_enum_literal(): + enum_name = target.fallback.type.fullname - target_type = [TypeRange(singleton, is_upper_bound=False)] + target_type = [TypeRange(target, is_upper_bound=False)] partial_type_maps = [] for i in chain_indices: - # If we try refining a singleton against itself, conditional_type_map + # If we try refining a type against itself, conditional_type_map # will end up assuming that the 'else' branch is unreachable. This is # typically not what we want: generally the user will intend for the - # singleton type to be some fixed 'sentinel' value and will want to refine + # target type to be some fixed 'sentinel' value and will want to refine # the other exprs against this one instead. if i == singleton_index: continue @@ -4246,17 +4304,16 @@ def refine_identity_comparison_expression(self, return reduce_partial_conditional_maps(partial_type_maps) - def refine_equality_comparison_expression(self, - operands: List[Expression], - operand_types: List[Type], - chain_indices: List[int], - narrowable_operand_indices: AbstractSet[int], - ) -> Tuple[TypeMap, TypeMap]: - """Produces conditional type maps refining expressions used in an equality comparison. + def refine_away_none_in_comparison(self, + operands: List[Expression], + operand_types: List[Type], + chain_indices: List[int], + narrowable_operand_indices: AbstractSet[int], + ) -> Tuple[TypeMap, TypeMap]: + """Produces conditional type maps refining away None in an identity/equality chain. - For more details, see the docstring of 'refine_equality_comparison' up above. - The only difference is that this function is for refining equality operations - (e.g. 'a == b == c') instead of identity ('a is b is c'). + For more details about what the different arguments mean, see the + docstring of 'refine_identity_comparison_expression' up above. """ non_optional_types = [] for i in chain_indices: @@ -4749,7 +4806,7 @@ class Foo(Enum): return False parent_type = get_proper_type(parent_type) - member_type = coerce_to_literal(member_type) + member_type = get_proper_type(coerce_to_literal(member_type)) if not isinstance(parent_type, FunctionLike) or not isinstance(member_type, LiteralType): return False @@ -5540,3 +5597,9 @@ def has_bool_item(typ: ProperType) -> bool: return any(is_named_instance(item, 'builtins.bool') for item in typ.items) return False + + +# TODO: why can't we define this as an inline function? +# Does mypyc not support them? +def is_exactly_literal_type(t: Type) -> bool: + return isinstance(get_proper_type(t), LiteralType) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7edaf7e2ad89..58dcd0105b53 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -32,7 +32,6 @@ YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode, ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE, - SYMBOL_FUNCBASE_TYPES ) from mypy.literals import literal from mypy import nodes @@ -51,7 +50,7 @@ from mypy import erasetype from mypy.checkmember import analyze_member_access, type_object_type from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals -from mypy.checkstrformat import StringFormatterChecker, custom_special_method +from mypy.checkstrformat import StringFormatterChecker from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars from mypy.util import split_module_names from mypy.typevars import fill_typevars @@ -59,7 +58,8 @@ from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext from mypy.typeops import ( tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound, - function_type, callable_type, try_getting_str_literals + function_type, callable_type, try_getting_str_literals, custom_special_method, + is_literal_type_like, ) import mypy.errorcodes as codes @@ -4266,24 +4266,6 @@ def merge_typevars_in_callables_by_name( return output, variables -def is_literal_type_like(t: Optional[Type]) -> bool: - """Returns 'true' if the given type context is potentially either a LiteralType, - a Union of LiteralType, or something similar. - """ - t = get_proper_type(t) - if t is None: - return False - elif isinstance(t, LiteralType): - return True - elif isinstance(t, UnionType): - return any(is_literal_type_like(item) for item in t.items) - elif isinstance(t, TypeVarType): - return (is_literal_type_like(t.upper_bound) - or any(is_literal_type_like(item) for item in t.values)) - else: - return False - - def try_getting_literal(typ: Type) -> ProperType: """If possible, get a more precise literal type for a given type.""" typ = get_proper_type(typ) @@ -4305,29 +4287,6 @@ def is_expr_literal_type(node: Expression) -> bool: return False -def custom_equality_method(typ: Type) -> bool: - """Does this type have a custom __eq__() method?""" - typ = get_proper_type(typ) - if isinstance(typ, Instance): - method = typ.type.get('__eq__') - if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)): - if method.node.info: - return not method.node.info.fullname.startswith('builtins.') - return False - if isinstance(typ, UnionType): - return any(custom_equality_method(t) for t in typ.items) - if isinstance(typ, TupleType): - return custom_equality_method(tuple_fallback(typ)) - if isinstance(typ, CallableType) and typ.is_type_obj(): - # Look up __eq__ on the metaclass for class objects. - return custom_equality_method(typ.fallback) - if isinstance(typ, AnyType): - # Avoid false positives in uncertain cases. - return True - # TODO: support other types (see ExpressionChecker.has_member())? - return False - - def has_bytes_component(typ: Type, py2: bool = False) -> bool: """Is this one of builtin byte types, or a union that contains it?""" typ = get_proper_type(typ) diff --git a/mypy/checkstrformat.py b/mypy/checkstrformat.py index 6f7647d98846..f89d5d0451b2 100644 --- a/mypy/checkstrformat.py +++ b/mypy/checkstrformat.py @@ -19,12 +19,12 @@ from mypy.types import ( Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType, - CallableType, LiteralType, get_proper_types + LiteralType, get_proper_types ) from mypy.nodes import ( StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr, CallExpr, IndexExpr, MemberExpr, TempNode, ARG_POS, ARG_STAR, ARG_NAMED, ARG_STAR2, - SYMBOL_FUNCBASE_TYPES, Decorator, Var, Node, MypyFile, ExpressionStmt, NameExpr, IntExpr + Node, MypyFile, ExpressionStmt, NameExpr, IntExpr ) import mypy.errorcodes as codes @@ -35,7 +35,7 @@ from mypy import message_registry from mypy.messages import MessageBuilder from mypy.maptype import map_instance_to_supertype -from mypy.typeops import tuple_fallback +from mypy.typeops import custom_special_method from mypy.subtypes import is_subtype from mypy.parse import parse @@ -961,32 +961,3 @@ def has_type_component(typ: Type, fullname: str) -> bool: elif isinstance(typ, UnionType): return any(has_type_component(t, fullname) for t in typ.relevant_items()) return False - - -def custom_special_method(typ: Type, name: str, - check_all: bool = False) -> bool: - """Does this type have a custom special method such as __format__() or __eq__()? - - If check_all is True ensure all items of a union have a custom method, not just some. - """ - typ = get_proper_type(typ) - if isinstance(typ, Instance): - method = typ.type.get(name) - if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)): - if method.node.info: - return not method.node.info.fullname.startswith('builtins.') - return False - if isinstance(typ, UnionType): - if check_all: - return all(custom_special_method(t, name, check_all) for t in typ.items) - return any(custom_special_method(t, name) for t in typ.items) - if isinstance(typ, TupleType): - return custom_special_method(tuple_fallback(typ), name) - if isinstance(typ, CallableType) and typ.is_type_obj(): - # Look up __method__ on the metaclass for class objects. - return custom_special_method(typ.fallback, name) - if isinstance(typ, AnyType): - # Avoid false positives in uncertain cases. - return True - # TODO: support other types (see ExpressionChecker.has_member())? - return False diff --git a/mypy/typeops.py b/mypy/typeops.py index 266a0fa0bb88..828791333f36 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -17,7 +17,7 @@ ) from mypy.nodes import ( FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS, - Expression, StrExpr, Var + Expression, StrExpr, Var, Decorator, SYMBOL_FUNCBASE_TYPES ) from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance, expand_type @@ -564,6 +564,24 @@ def try_getting_literals_from_type(typ: Type, return literals +def is_literal_type_like(t: Optional[Type]) -> bool: + """Returns 'true' if the given type context is potentially either a LiteralType, + a Union of LiteralType, or something similar. + """ + t = get_proper_type(t) + if t is None: + return False + elif isinstance(t, LiteralType): + return True + elif isinstance(t, UnionType): + return any(is_literal_type_like(item) for item in t.items) + elif isinstance(t, TypeVarType): + return (is_literal_type_like(t.upper_bound) + or any(is_literal_type_like(item) for item in t.values)) + else: + return False + + def get_enum_values(typ: Instance) -> List[str]: """Return the list of values for an Enum.""" return [name for name, sym in typ.type.names.items() if isinstance(sym.node, Var)] @@ -640,10 +658,11 @@ class Status(Enum): return typ -def coerce_to_literal(typ: Type) -> ProperType: +def coerce_to_literal(typ: Type) -> Type: """Recursively converts any Instances that have a last_known_value or are instances of enum types with a single value into the corresponding LiteralType. """ + original_type = typ typ = get_proper_type(typ) if isinstance(typ, UnionType): new_items = [coerce_to_literal(item) for item in typ.items] @@ -655,7 +674,7 @@ def coerce_to_literal(typ: Type) -> ProperType: enum_values = get_enum_values(typ) if len(enum_values) == 1: return LiteralType(value=enum_values[0], fallback=typ) - return typ + return original_type def get_type_vars(tp: Type) -> List[TypeVarType]: @@ -674,3 +693,31 @@ def _merge(self, iter: Iterable[List[TypeVarType]]) -> List[TypeVarType]: def visit_type_var(self, t: TypeVarType) -> List[TypeVarType]: return [t] + + +def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool: + """Does this type have a custom special method such as __format__() or __eq__()? + + If check_all is True ensure all items of a union have a custom method, not just some. + """ + typ = get_proper_type(typ) + if isinstance(typ, Instance): + method = typ.type.get(name) + if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)): + if method.node.info: + return not method.node.info.fullname.startswith('builtins.') + return False + if isinstance(typ, UnionType): + if check_all: + return all(custom_special_method(t, name, check_all) for t in typ.items) + return any(custom_special_method(t, name) for t in typ.items) + if isinstance(typ, TupleType): + return custom_special_method(tuple_fallback(typ), name, check_all) + if isinstance(typ, CallableType) and typ.is_type_obj(): + # Look up __method__ on the metaclass for class objects. + return custom_special_method(typ.fallback, name, check_all) + if isinstance(typ, AnyType): + # Avoid false positives in uncertain cases. + return True + # TODO: support other types (see ExpressionChecker.has_member())? + return False diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index e0a0cb660c80..269938e8309f 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1,3 +1,86 @@ +[case testNarrowingParentWithStrsBasic] +from dataclasses import dataclass +from typing import NamedTuple, Tuple, Union +from typing_extensions import Literal, TypedDict + +class Object1: + key: Literal["A"] + foo: int +class Object2: + key: Literal["B"] + bar: str + +@dataclass +class Dataclass1: + key: Literal["A"] + foo: int +@dataclass +class Dataclass2: + key: Literal["B"] + foo: str + +class NamedTuple1(NamedTuple): + key: Literal["A"] + foo: int +class NamedTuple2(NamedTuple): + key: Literal["B"] + foo: str + +Tuple1 = Tuple[Literal["A"], int] +Tuple2 = Tuple[Literal["B"], str] + +class TypedDict1(TypedDict): + key: Literal["A"] + foo: int +class TypedDict2(TypedDict): + key: Literal["B"] + foo: str + +x1: Union[Object1, Object2] +if x1.key == "A": + reveal_type(x1) # N: Revealed type is '__main__.Object1' + reveal_type(x1.key) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x1) # N: Revealed type is '__main__.Object2' + reveal_type(x1.key) # N: Revealed type is 'Literal['B']' + +x2: Union[Dataclass1, Dataclass2] +if x2.key == "A": + reveal_type(x2) # N: Revealed type is '__main__.Dataclass1' + reveal_type(x2.key) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x2) # N: Revealed type is '__main__.Dataclass2' + reveal_type(x2.key) # N: Revealed type is 'Literal['B']' + +x3: Union[NamedTuple1, NamedTuple2] +if x3.key == "A": + reveal_type(x3) # N: Revealed type is 'Tuple[Literal['A'], builtins.int, fallback=__main__.NamedTuple1]' + reveal_type(x3.key) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x3) # N: Revealed type is 'Tuple[Literal['B'], builtins.str, fallback=__main__.NamedTuple2]' + reveal_type(x3.key) # N: Revealed type is 'Literal['B']' +if x3[0] == "A": + reveal_type(x3) # N: Revealed type is 'Tuple[Literal['A'], builtins.int, fallback=__main__.NamedTuple1]' + reveal_type(x3[0]) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x3) # N: Revealed type is 'Tuple[Literal['B'], builtins.str, fallback=__main__.NamedTuple2]' + reveal_type(x3[0]) # N: Revealed type is 'Literal['B']' + +x4: Union[Tuple1, Tuple2] +if x4[0] == "A": + reveal_type(x4) # N: Revealed type is 'Tuple[Literal['A'], builtins.int]' + reveal_type(x4[0]) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x4) # N: Revealed type is 'Tuple[Literal['B'], builtins.str]' + reveal_type(x4[0]) # N: Revealed type is 'Literal['B']' + +x5: Union[TypedDict1, TypedDict2] +if x5["key"] == "A": + reveal_type(x5) # N: Revealed type is 'TypedDict('__main__.TypedDict1', {'key': Literal['A'], 'foo': builtins.int})' +else: + reveal_type(x5) # N: Revealed type is 'TypedDict('__main__.TypedDict2', {'key': Literal['B'], 'foo': builtins.str})' +[builtins fixtures/primitives.pyi] + [case testNarrowingParentWithEnumsBasic] from enum import Enum from dataclasses import dataclass @@ -445,3 +528,225 @@ if y["model"]["key"] is Key.C: else: reveal_type(y) # N: Revealed type is 'Union[TypedDict('__main__.Parent1', {'model': TypedDict('__main__.Model1', {'key': Literal[__main__.Key.A]}), 'foo': builtins.int}), TypedDict('__main__.Parent2', {'model': TypedDict('__main__.Model2', {'key': Literal[__main__.Key.B]}), 'bar': builtins.str})]' reveal_type(y["model"]) # N: Revealed type is 'Union[TypedDict('__main__.Model1', {'key': Literal[__main__.Key.A]}), TypedDict('__main__.Model2', {'key': Literal[__main__.Key.B]})]' + +[case testNarrowingEqualityFlipFlop] +# flags: --warn-unreachable --strict-equality +from typing_extensions import Literal, Final +from enum import Enum + +class State(Enum): + A = 1 + B = 2 + +class FlipFlopEnum: + def __init__(self) -> None: + self.state = State.A + + def mutate(self) -> None: + self.state = State.B if self.state == State.A else State.A + +class FlipFlopStr: + def __init__(self) -> None: + self.state = "state-1" + + def mutate(self) -> None: + self.state = "state-2" if self.state == "state-1" else "state-1" + +def test1(switch: FlipFlopEnum) -> None: + # Naively, we might assume the 'assert' here would narrow the type to + # Literal[State.A]. However, doing this ends up breaking a fair number of real-world + # code (usually test cases) that looks similar to this function: e.g. checks + # to make sure a field was mutated to some particular value. + # + # And since mypy can't really reason about state mutation, we take a conservative + # approach and avoid narrowing anything here. + + assert switch.state == State.A + reveal_type(switch.state) # N: Revealed type is '__main__.State' + + switch.mutate() + + assert switch.state == State.B + reveal_type(switch.state) # N: Revealed type is '__main__.State' + +def test2(switch: FlipFlopEnum) -> None: + # So strictly speaking, we ought to do the same thing with 'is' comparisons + # for the same reasons as above. But in practice, not too many people seem to + # know that doing 'some_enum is MyEnum.Value' is idiomatic. So in practice, + # this is probably good enough for now. + + assert switch.state is State.A + reveal_type(switch.state) # N: Revealed type is 'Literal[__main__.State.A]' + + switch.mutate() + + assert switch.state is State.B # E: Non-overlapping identity check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]") + reveal_type(switch.state) # E: Statement is unreachable + +def test3(switch: FlipFlopStr) -> None: + # This is the same thing as 'test1', except we try using str literals. + + assert switch.state == "state-1" + reveal_type(switch.state) # N: Revealed type is 'builtins.str' + + switch.mutate() + + assert switch.state == "state-2" + reveal_type(switch.state) # N: Revealed type is 'builtins.str' +[builtins fixtures/primitives.pyi] + +[case testNarrowingEqualityRequiresExplicitStrLiteral] +# flags: --strict-optional +from typing_extensions import Literal, Final + +A_final: Final = "A" +A_literal: Literal["A"] + +# Neither the LHS nor the RHS are explicit literals, so regrettably nothing +# is narrowed here -- see 'testNarrowingEqualityFlipFlop' for an example of +# why more precise inference here is problematic. +x_str: str +if x_str == A_final: + reveal_type(x_str) # N: Revealed type is 'builtins.str' +else: + reveal_type(x_str) # N: Revealed type is 'builtins.str' +reveal_type(x_str) # N: Revealed type is 'builtins.str' + +# But the RHS is a literal, so we can at least narrow the 'if' case now. +if x_str == A_literal: + reveal_type(x_str) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x_str) # N: Revealed type is 'builtins.str' +reveal_type(x_str) # N: Revealed type is 'builtins.str' + +# But in these two cases, the LHS is a literal/literal-like type. So we +# assume the user *does* want literal-based narrowing and narrow accordingly +# regardless of whether the RHS is an explicit literal or not. +x_union: Literal["A", "B", None] +if x_union == A_final: + reveal_type(x_union) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x_union) # N: Revealed type is 'Union[Literal['B'], None]' +reveal_type(x_union) # N: Revealed type is 'Union[Literal['A'], Literal['B'], None]' + +if x_union == A_literal: + reveal_type(x_union) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x_union) # N: Revealed type is 'Union[Literal['B'], None]' +reveal_type(x_union) # N: Revealed type is 'Union[Literal['A'], Literal['B'], None]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingEqualityRequiresExplicitEnumLiteral] +# flags: --strict-optional +from typing_extensions import Literal, Final +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + +A_final: Final = Foo.A +A_literal: Literal[Foo.A] + +# See comments in testNarrowingEqualityRequiresExplicitStrLiteral and +# testNarrowingEqualityFlipFlop for more on why we can't narrow here. +x1: Foo +if x1 == Foo.A: + reveal_type(x1) # N: Revealed type is '__main__.Foo' +else: + reveal_type(x1) # N: Revealed type is '__main__.Foo' + +x2: Foo +if x2 == A_final: + reveal_type(x2) # N: Revealed type is '__main__.Foo' +else: + reveal_type(x2) # N: Revealed type is '__main__.Foo' + +# But we let this narrow since there's an explicit literal in the RHS. +x3: Foo +if x3 == A_literal: + reveal_type(x3) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x3) # N: Revealed type is 'Literal[__main__.Foo.B]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingEqualityDisabledForCustomEquality] +from typing import Union +from typing_extensions import Literal +from enum import Enum + +class Custom: + def __eq__(self, other: object) -> bool: return True + +class Default: pass + +x1: Union[Custom, Literal[1], Literal[2]] +if x1 == 1: + reveal_type(x1) # N: Revealed type is 'Union[__main__.Custom, Literal[1], Literal[2]]' +else: + reveal_type(x1) # N: Revealed type is 'Union[__main__.Custom, Literal[1], Literal[2]]' + +x2: Union[Default, Literal[1], Literal[2]] +if x2 == 1: + reveal_type(x2) # N: Revealed type is 'Literal[1]' +else: + reveal_type(x2) # N: Revealed type is 'Union[__main__.Default, Literal[2]]' + +class CustomEnum(Enum): + A = 1 + B = 2 + + def __eq__(self, other: object) -> bool: return True + +x3: CustomEnum +key: Literal[CustomEnum.A] +if x3 == key: + reveal_type(x3) # N: Revealed type is '__main__.CustomEnum' +else: + reveal_type(x3) # N: Revealed type is '__main__.CustomEnum' + +# For comparison, this narrows since we bypass __eq__ +if x3 is key: + reveal_type(x3) # N: Revealed type is 'Literal[__main__.CustomEnum.A]' +else: + reveal_type(x3) # N: Revealed type is 'Literal[__main__.CustomEnum.B]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingEqualityDisabledForCustomEqualityChain] +# flags: --strict-optional --strict-equality --warn-unreachable +from typing import Union +from typing_extensions import Literal + +class Custom: + def __eq__(self, other: object) -> bool: return True + +class Default: pass + +x: Literal[1, 2, None] +y: Custom +z: Default + +# We could maybe try doing something clever, but for simplicity we +# treat the whole chain as contaminated and mostly disable narrowing. +# +# The only exception is that we do at least strip away the 'None'. We +# (perhaps optimistically) assume no custom class would be pathological +# enough to declare itself to be equal to None and so permit this narrowing, +# since it's often convenient in practice. +if 1 == x == y: + reveal_type(x) # N: Revealed type is 'Union[Literal[1], Literal[2]]' + reveal_type(y) # N: Revealed type is '__main__.Custom' +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[1], Literal[2], None]' + reveal_type(y) # N: Revealed type is '__main__.Custom' + +# No contamination here +if 1 == x == z: # E: Non-overlapping equality check (left operand type: "Union[Literal[1], Literal[2], None]", right operand type: "Default") + reveal_type(x) # E: Statement is unreachable + reveal_type(z) +else: + # TODO: Why do we narrow away 'Literal[1]' here? + # Even if the equality comparison is bogus, we should try and do better here. + reveal_type(x) # N: Revealed type is 'Union[Literal[2], None]' + reveal_type(z) # N: Revealed type is '__main__.Default' +[builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 9c40d550699e..ab11c42173d4 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -489,6 +489,10 @@ if x == '': reveal_type(x) # N: Revealed type is 'builtins.str' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' +if x is '': + reveal_type(x) # N: Revealed type is 'builtins.str' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' [builtins fixtures/ops.pyi] [case testInferEqualsNotOptionalWithUnion] @@ -498,6 +502,10 @@ if x == '': reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +if x is '': + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] [case testInferEqualsNotOptionalWithOverlap] @@ -507,6 +515,10 @@ if x == object(): reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +if x is object(): + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] [case testInferEqualsStillOptionalWithNoOverlap] @@ -516,6 +528,10 @@ if x == 0: reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' +if x is 0: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' [builtins fixtures/ops.pyi] [case testInferEqualsStillOptionalWithBothOptional] @@ -526,6 +542,10 @@ if x == y: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +if x is y: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] [case testInferEqualsNotOptionalWithMultipleArgs] From 96c0292ed8d223690ee6f94666a7380c93ba2975 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Tue, 7 Jan 2020 10:30:44 -0800 Subject: [PATCH 2/4] Respond to code review; remove concept of 'partial' TypeMaps --- mypy/checker.py | 69 +++------ test-data/unit/check-enum.test | 57 +++++--- test-data/unit/check-narrowing.test | 216 +++++++++++++++++++++++++++- 3 files changed, 268 insertions(+), 74 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index fa07ac4259c0..69f78817f34a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3895,7 +3895,8 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM # is_valid_target: # Controls which types we're allowed to narrow exprs to. Note that # we cannot use 'is_literal_type_like' in both cases since doing - # 'x = 10000 + 1; x is 10001' is not always True in all Python impls. + # 'x = 10000 + 1; x is 10001' is not always True in all Python + # implementations. # # coerce_only_in_literal_context: # If true, coerce types into literal types only if one or more of @@ -3916,12 +3917,12 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM coerce_only_in_literal_context = True def has_no_custom_eq_checks(t: Type) -> bool: - return not custom_special_method(t, '__eq__', check_all=False) \ - and not custom_special_method(t, '__ne__', check_all=False) + return (not custom_special_method(t, '__eq__', check_all=False) + and not custom_special_method(t, '__ne__', check_all=False)) expr_types = [operand_types[i] for i in expr_indices] should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types)) - if_map = {} # type: TypeMap + if_map = {} # type: TypeMap else_map = {} # type: TypeMap if should_narrow_by_identity: if_map, else_map = self.refine_identity_comparison_expression( @@ -3976,7 +3977,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: partial_type_maps.append((if_map, else_map)) - return reduce_partial_conditional_maps(partial_type_maps) + return reduce_conditional_maps(partial_type_maps) elif isinstance(node, RefExpr): # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively @@ -4189,7 +4190,7 @@ def refine_identity_comparison_expression(self, is_valid_target: Callable[[ProperType], bool], coerce_only_in_literal_context: bool, ) -> Tuple[TypeMap, TypeMap]: - """Produces conditional type maps refining exprs used in an identity/equality comparison. + """Produce conditional type maps refining expressions by an identity/equality comparison. The 'operands' and 'operand_types' lists should be the full list of operands used in the overall comparison expression. The 'chain_indices' list is the list of indices @@ -4302,7 +4303,7 @@ def refine_identity_comparison_expression(self, expr_type = try_expanding_enum_to_union(expr_type, enum_name) partial_type_maps.append(conditional_type_map(expr, expr_type, target_type)) - return reduce_partial_conditional_maps(partial_type_maps) + return reduce_conditional_maps(partial_type_maps) def refine_away_none_in_comparison(self, operands: List[Expression], @@ -4908,46 +4909,12 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: return result -def or_partial_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: - """Calculate what information we can learn from the truth of (e1 or e2) - in terms of the information that we can learn from the truth of e1 and - the truth of e2. +def reduce_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], + ) -> Tuple[TypeMap, TypeMap]: + """Reduces a list containing pairs of if/else TypeMaps into a single pair. - Unlike 'or_conditional_maps', we include an expression in the output even - if it exists in only one map: we're assuming both maps are "partial" and - contain information about only some expressions, and so we "or" together - expressions both maps have information on. - """ - - if m1 is None: - return m2 - if m2 is None: - return m1 - # The logic here is a blend between 'and_conditional_maps' - # and 'or_conditional_maps'. We use the high-level logic from the - # former to ensure all expressions make it in the output map, - # but resolve cases where both maps contain info on the same - # expr using the unioning strategy from the latter. - result = m2.copy() - m2_keys = {literal_hash(n2): n2 for n2 in m2} - for n1 in m1: - n2 = m2_keys.get(literal_hash(n1)) - if n2 is None: - result[n1] = m1[n1] - else: - result[n2] = make_simplified_union([m1[n1], result[n2]]) - - return result - - -def reduce_partial_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], - ) -> Tuple[TypeMap, TypeMap]: - """Reduces a list containing pairs of *partial* if/else TypeMaps into a single pair. - - That is, if a expression exists in only one map, we always include it in the output. - We only "and"/"or" together expressions that appear in multiple if/else maps. - - So for example, if we had the input: + We "and" together all of the if TypeMaps and "or" together the else TypeMaps. So + for example, if we had the input: [ ({x: TypeIfX, shared: TypeIfShared1}, {x: TypeElseX, shared: TypeElseShared1}), @@ -4958,11 +4925,14 @@ def reduce_partial_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], ( {x: TypeIfX, y: TypeIfY, shared: PseudoIntersection[TypeIfShared1, TypeIfShared2]}, - {x: TypeElseX, y: TypeElseY, shared: Union[TypeElseShared1, TypeElseShared2]}, + {shared: Union[TypeElseShared1, TypeElseShared2]}, ) ...where "PseudoIntersection[X, Y] == Y" because mypy actually doesn't understand intersections yet, so we settle for just arbitrarily picking the right expr's type. + + We only retain the shared expression in the 'else' case because we don't actually know + whether x was refined or y was refined -- only just that one of the two was refined. """ if len(type_maps) == 0: return {}, {} @@ -4971,10 +4941,9 @@ def reduce_partial_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], else: final_if_map, final_else_map = type_maps[0] for if_map, else_map in type_maps[1:]: - # 'and_conditional_maps' does the same thing for both global and partial type maps, - # which is why we don't need to have an 'and_partial_conditional_maps' function. final_if_map = and_conditional_maps(final_if_map, if_map) - final_else_map = or_partial_conditional_maps(final_else_map, else_map) + final_else_map = or_conditional_maps(final_else_map, else_map) + return final_if_map, final_else_map diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 9d027f47192f..6f4da262d02e 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -978,30 +978,41 @@ class Foo(Enum): x: Foo y: Foo +# We can't narrow anything in the else cases -- what if +# x is Foo.A and y is Foo.B or vice versa, for example? if x is y is Foo.A: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif x is y is Foo.B: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' reveal_type(x) # N: Revealed type is '__main__.Foo' reveal_type(y) # N: Revealed type is '__main__.Foo' if x is Foo.A is y: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif x is Foo.B is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' reveal_type(x) # N: Revealed type is '__main__.Foo' reveal_type(y) # N: Revealed type is '__main__.Foo' if Foo.A is x is y: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif Foo.B is x is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' reveal_type(x) # N: Revealed type is '__main__.Foo' reveal_type(y) # N: Revealed type is '__main__.Foo' @@ -1026,8 +1037,10 @@ if x is Foo.A < y is Foo.B: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + # Note: we can't narrow in this case. What if both x and y + # are Foo.A, for example? + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' reveal_type(x) # N: Revealed type is '__main__.Foo' reveal_type(y) # N: Revealed type is '__main__.Foo' @@ -1109,11 +1122,13 @@ if x0 is x1 is Foo.A is x2 < x3 is Foo.B is x4 is x5: reveal_type(x4) # N: Revealed type is 'Literal[__main__.Foo.B]' reveal_type(x5) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x0) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' - reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' - reveal_type(x2) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' - - reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' - reveal_type(x4) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' - reveal_type(x5) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' + # We unfortunately can't narrow away anything. For example, + # what if x0 == Foo.A and x1 == Foo.B or vice versa? + reveal_type(x0) # N: Revealed type is '__main__.Foo' + reveal_type(x1) # N: Revealed type is '__main__.Foo' + reveal_type(x2) # N: Revealed type is '__main__.Foo' + + reveal_type(x3) # N: Revealed type is '__main__.Foo' + reveal_type(x4) # N: Revealed type is '__main__.Foo' + reveal_type(x5) # N: Revealed type is '__main__.Foo' [builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 269938e8309f..6c64b241eaaa 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -267,6 +267,88 @@ if x.key is Key.D: else: reveal_type(x) # N: Revealed type is 'Union[__main__.Object1, __main__.Object2]' +[case testNarrowingTypedDictParentMultipleKeys] +# flags: --warn-unreachable +from typing import Union +from typing_extensions import Literal, TypedDict + +class TypedDict1(TypedDict): + key: Literal['A', 'C'] +class TypedDict2(TypedDict): + key: Literal['B', 'C'] + +x: Union[TypedDict1, TypedDict2] +if x['key'] == 'A': + reveal_type(x) # N: Revealed type is 'TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]})' +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]' + +if x['key'] == 'C': + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]' +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]' + +if x['key'] == 'D': + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingPartialTypedDictParentMultipleKeys] +# flags: --warn-unreachable +from typing import Union +from typing_extensions import Literal, TypedDict + +class TypedDict1(TypedDict, total=False): + key: Literal['A', 'C'] +class TypedDict2(TypedDict, total=False): + key: Literal['B', 'C'] + +x: Union[TypedDict1, TypedDict2] +if x['key'] == 'A': + reveal_type(x) # N: Revealed type is 'TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]})' +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]' + +if x['key'] == 'C': + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]' +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]' + +if x['key'] == 'D': + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingNestedTypedDicts] +from typing import Union +from typing_extensions import TypedDict, Literal + +class A(TypedDict): + key: Literal['A'] +class B(TypedDict): + key: Literal['B'] +class C(TypedDict): + key: Literal['C'] + +class X(TypedDict): + inner: Union[A, B] +class Y(TypedDict): + inner: Union[B, C] + +unknown: Union[X, Y] +if unknown['inner']['key'] == 'A': + reveal_type(unknown) # N: Revealed type is 'TypedDict('__main__.X', {'inner': Union[TypedDict('__main__.A', {'key': Literal['A']}), TypedDict('__main__.B', {'key': Literal['B']})]})' + reveal_type(unknown['inner']) # N: Revealed type is 'TypedDict('__main__.A', {'key': Literal['A']})' +if unknown['inner']['key'] == 'B': + reveal_type(unknown) # N: Revealed type is 'Union[TypedDict('__main__.X', {'inner': Union[TypedDict('__main__.A', {'key': Literal['A']}), TypedDict('__main__.B', {'key': Literal['B']})]}), TypedDict('__main__.Y', {'inner': Union[TypedDict('__main__.B', {'key': Literal['B']}), TypedDict('__main__.C', {'key': Literal['C']})]})]' + reveal_type(unknown['inner']) # N: Revealed type is 'TypedDict('__main__.B', {'key': Literal['B']})' +if unknown['inner']['key'] == 'C': + reveal_type(unknown) # N: Revealed type is 'TypedDict('__main__.Y', {'inner': Union[TypedDict('__main__.B', {'key': Literal['B']}), TypedDict('__main__.C', {'key': Literal['C']})]})' + reveal_type(unknown['inner']) # N: Revealed type is 'TypedDict('__main__.C', {'key': Literal['C']})' +[builtins fixtures/primitives.pyi] + [case testNarrowingParentWithMultipleParents] from enum import Enum from typing import Union @@ -529,6 +611,42 @@ else: reveal_type(y) # N: Revealed type is 'Union[TypedDict('__main__.Parent1', {'model': TypedDict('__main__.Model1', {'key': Literal[__main__.Key.A]}), 'foo': builtins.int}), TypedDict('__main__.Parent2', {'model': TypedDict('__main__.Model2', {'key': Literal[__main__.Key.B]}), 'bar': builtins.str})]' reveal_type(y["model"]) # N: Revealed type is 'Union[TypedDict('__main__.Model1', {'key': Literal[__main__.Key.A]}), TypedDict('__main__.Model2', {'key': Literal[__main__.Key.B]})]' +[case testNarrowingParentsHierarchyTypedDictWithStr] +# flags: --warn-unreachable +from typing import Union +from typing_extensions import TypedDict, Literal + +class Parent1(TypedDict): + model: Model1 + foo: int + +class Parent2(TypedDict): + model: Model2 + bar: str + +class Model1(TypedDict): + key: Literal['A'] + +class Model2(TypedDict): + key: Literal['B'] + +x: Union[Parent1, Parent2] +if x["model"]["key"] == 'A': + reveal_type(x) # N: Revealed type is 'TypedDict('__main__.Parent1', {'model': TypedDict('__main__.Model1', {'key': Literal['A']}), 'foo': builtins.int})' + reveal_type(x["model"]) # N: Revealed type is 'TypedDict('__main__.Model1', {'key': Literal['A']})' +else: + reveal_type(x) # N: Revealed type is 'TypedDict('__main__.Parent2', {'model': TypedDict('__main__.Model2', {'key': Literal['B']}), 'bar': builtins.str})' + reveal_type(x["model"]) # N: Revealed type is 'TypedDict('__main__.Model2', {'key': Literal['B']})' + +y: Union[Parent1, Parent2] +if y["model"]["key"] == 'C': + reveal_type(y) # E: Statement is unreachable + reveal_type(y["model"]) +else: + reveal_type(y) # N: Revealed type is 'Union[TypedDict('__main__.Parent1', {'model': TypedDict('__main__.Model1', {'key': Literal['A']}), 'foo': builtins.int}), TypedDict('__main__.Parent2', {'model': TypedDict('__main__.Model2', {'key': Literal['B']}), 'bar': builtins.str})]' + reveal_type(y["model"]) # N: Revealed type is 'Union[TypedDict('__main__.Model1', {'key': Literal['A']}), TypedDict('__main__.Model2', {'key': Literal['B']})]' +[builtins fixtures/primitives.pyi] + [case testNarrowingEqualityFlipFlop] # flags: --warn-unreachable --strict-equality from typing_extensions import Literal, Final @@ -606,6 +724,12 @@ A_literal: Literal["A"] # is narrowed here -- see 'testNarrowingEqualityFlipFlop' for an example of # why more precise inference here is problematic. x_str: str +if x_str == "A": + reveal_type(x_str) # N: Revealed type is 'builtins.str' +else: + reveal_type(x_str) # N: Revealed type is 'builtins.str' +reveal_type(x_str) # N: Revealed type is 'builtins.str' + if x_str == A_final: reveal_type(x_str) # N: Revealed type is 'builtins.str' else: @@ -745,8 +869,94 @@ if 1 == x == z: # E: Non-overlapping equality check (left operand type: "Un reveal_type(x) # E: Statement is unreachable reveal_type(z) else: - # TODO: Why do we narrow away 'Literal[1]' here? - # Even if the equality comparison is bogus, we should try and do better here. - reveal_type(x) # N: Revealed type is 'Union[Literal[2], None]' + reveal_type(x) # N: Revealed type is 'Union[Literal[1], Literal[2], None]' reveal_type(z) # N: Revealed type is '__main__.Default' [builtins fixtures/primitives.pyi] + +[case testNarrowingUnreachableCases] +# flags: --strict-optional --strict-equality --warn-unreachable +from typing import Union +from typing_extensions import Literal + +a: Literal[1] +b: Literal[1, 2] +c: Literal[2, 3] + +if a == b == c: + reveal_type(a) # E: Statement is unreachable + reveal_type(b) + reveal_type(c) +else: + reveal_type(a) # N: Revealed type is 'Literal[1]' + reveal_type(b) # N: Revealed type is 'Union[Literal[1], Literal[2]]' + reveal_type(c) # N: Revealed type is 'Union[Literal[2], Literal[3]]' + +if a == a == a: + reveal_type(a) # N: Revealed type is 'Literal[1]' +else: + reveal_type(a) # E: Statement is unreachable + +if a == a == b: + reveal_type(a) # N: Revealed type is 'Literal[1]' + reveal_type(b) # N: Revealed type is 'Literal[1]' +else: + reveal_type(a) # N: Revealed type is 'Literal[1]' + reveal_type(b) # N: Revealed type is 'Literal[2]' + +# In this case, it's ok for 'b' to narrow down to Literal[1] in the else case +# since that's the only way 'b == 2' can be false +if b == 2: + reveal_type(b) # N: Revealed type is 'Literal[2]' +else: + reveal_type(b) # N: Revealed type is 'Literal[1]' + +# But in this case, we can't conclude anything about the else case. This expression +# could end up being either '2 == 2 == 3' or '1 == 2 == 2', which means we can't +# conclude anything. +if b == 2 == c: + reveal_type(b) # N: Revealed type is 'Literal[2]' + reveal_type(c) # N: Revealed type is 'Literal[2]' +else: + reveal_type(b) # N: Revealed type is 'Union[Literal[1], Literal[2]]' + reveal_type(c) # N: Revealed type is 'Union[Literal[2], Literal[3]]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingUnreachableCases2] +# flags: --strict-optional --strict-equality --warn-unreachable +from typing import Union +from typing_extensions import Literal + +a: Literal[1, 2, 3, 4] +b: Literal[1, 2, 3, 4] + +if a == b == 1: + reveal_type(a) # N: Revealed type is 'Literal[1]' + reveal_type(b) # N: Revealed type is 'Literal[1]' +elif a == b == 2: + reveal_type(a) # N: Revealed type is 'Literal[2]' + reveal_type(b) # N: Revealed type is 'Literal[2]' +elif a == b == 3: + reveal_type(a) # N: Revealed type is 'Literal[3]' + reveal_type(b) # N: Revealed type is 'Literal[3]' +elif a == b == 4: + reveal_type(a) # N: Revealed type is 'Literal[4]' + reveal_type(b) # N: Revealed type is 'Literal[4]' +else: + # This branch is reachable if a == 1 and b == 2, for example. + reveal_type(a) # N: Revealed type is 'Union[Literal[1], Literal[2], Literal[3], Literal[4]]' + reveal_type(b) # N: Revealed type is 'Union[Literal[1], Literal[2], Literal[3], Literal[4]]' + +if a == a == 1: + reveal_type(a) # N: Revealed type is 'Literal[1]' +elif a == a == 2: + reveal_type(a) # N: Revealed type is 'Literal[2]' +elif a == a == 3: + reveal_type(a) # N: Revealed type is 'Literal[3]' +elif a == a == 4: + reveal_type(a) # N: Revealed type is 'Literal[4]' +else: + # In contrast, this branch must be unreachable: we assume (maybe naively) + # that 'a' won't be mutated in the middle of the expression. + reveal_type(a) # E: Statement is unreachable + reveal_type(b) +[builtins fixtures/primitives.pyi] From 7b08bd995c04bb5a8529ec50854f28e6f98e26b1 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Wed, 8 Jan 2020 00:27:58 -0800 Subject: [PATCH 3/4] Remove todo about inline functions --- mypy/checker.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 69f78817f34a..7dee8bfd3b9b 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3913,12 +3913,16 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM coerce_only_in_literal_context = False should_narrow_by_identity = True else: - is_valid_target = is_exactly_literal_type - coerce_only_in_literal_context = True + def is_exactly_literal_type(t: Type) -> bool: + return isinstance(get_proper_type(t), LiteralType) def has_no_custom_eq_checks(t: Type) -> bool: return (not custom_special_method(t, '__eq__', check_all=False) - and not custom_special_method(t, '__ne__', check_all=False)) + and not custom_special_method(t, '__ne__', check_all=False)) + + is_valid_target = is_exactly_literal_type + coerce_only_in_literal_context = True + expr_types = [operand_types[i] for i in expr_indices] should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types)) @@ -5566,9 +5570,3 @@ def has_bool_item(typ: ProperType) -> bool: return any(is_named_instance(item, 'builtins.bool') for item in typ.items) return False - - -# TODO: why can't we define this as an inline function? -# Does mypyc not support them? -def is_exactly_literal_type(t: Type) -> bool: - return isinstance(get_proper_type(t), LiteralType) From e2ea84c9f71eb4cf8aed5cb2eb1aa52867c9ecf7 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Wed, 8 Jan 2020 07:50:40 -0800 Subject: [PATCH 4/4] Fix test formatting --- test-data/unit/check-enum.test | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 6f4da262d02e..18130d2d818c 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -987,10 +987,10 @@ elif x is y is Foo.B: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is '__main__.Foo' - reveal_type(y) # N: Revealed type is '__main__.Foo' -reveal_type(x) # N: Revealed type is '__main__.Foo' -reveal_type(y) # N: Revealed type is '__main__.Foo' + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' if x is Foo.A is y: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -999,10 +999,10 @@ elif x is Foo.B is y: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is '__main__.Foo' - reveal_type(y) # N: Revealed type is '__main__.Foo' -reveal_type(x) # N: Revealed type is '__main__.Foo' -reveal_type(y) # N: Revealed type is '__main__.Foo' + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' if Foo.A is x is y: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -1011,10 +1011,10 @@ elif Foo.B is x is y: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is '__main__.Foo' - reveal_type(y) # N: Revealed type is '__main__.Foo' -reveal_type(x) # N: Revealed type is '__main__.Foo' -reveal_type(y) # N: Revealed type is '__main__.Foo' + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' [builtins fixtures/primitives.pyi]