diff --git a/mypy/checker.py b/mypy/checker.py index ae829d1157c1..7dee8bfd3b9b 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6,7 +6,7 @@ from typing import ( Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable, - Sequence, Mapping, Generic, AbstractSet + Sequence, Mapping, Generic, AbstractSet, Callable ) from typing_extensions import Final @@ -50,7 +50,8 @@ erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal, try_getting_str_literals_from_type, try_getting_int_literals_from_type, tuple_fallback, is_singleton_type, try_expanding_enum_to_union, - true_only, false_only, function_type, TypeVarExtractor, + true_only, false_only, function_type, TypeVarExtractor, custom_special_method, + is_literal_type_like, ) from mypy import message_registry from mypy.subtypes import ( @@ -3890,20 +3891,64 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM partial_type_maps = [] for operator, expr_indices in simplified_operator_list: - if operator in {'is', 'is not'}: - if_map, else_map = self.refine_identity_comparison_expression( - operands, - operand_types, - expr_indices, - narrowable_operand_index_to_hash.keys(), - ) - elif operator in {'==', '!='}: - if_map, else_map = self.refine_equality_comparison_expression( - operands, - operand_types, - expr_indices, - narrowable_operand_index_to_hash.keys(), - ) + if operator in {'is', 'is not', '==', '!='}: + # is_valid_target: + # Controls which types we're allowed to narrow exprs to. Note that + # we cannot use 'is_literal_type_like' in both cases since doing + # 'x = 10000 + 1; x is 10001' is not always True in all Python + # implementations. + # + # coerce_only_in_literal_context: + # If true, coerce types into literal types only if one or more of + # the provided exprs contains an explicit Literal type. This could + # technically be set to any arbitrary value, but it seems being liberal + # with narrowing when using 'is' and conservative when using '==' seems + # to break the least amount of real-world code. + # + # should_narrow_by_identity: + # Set to 'false' only if the user defines custom __eq__ or __ne__ methods + # that could cause identity-based narrowing to produce invalid results. + if operator in {'is', 'is not'}: + is_valid_target = is_singleton_type # type: Callable[[Type], bool] + coerce_only_in_literal_context = False + should_narrow_by_identity = True + else: + def is_exactly_literal_type(t: Type) -> bool: + return isinstance(get_proper_type(t), LiteralType) + + def has_no_custom_eq_checks(t: Type) -> bool: + return (not custom_special_method(t, '__eq__', check_all=False) + and not custom_special_method(t, '__ne__', check_all=False)) + + is_valid_target = is_exactly_literal_type + coerce_only_in_literal_context = True + + expr_types = [operand_types[i] for i in expr_indices] + should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types)) + + if_map = {} # type: TypeMap + else_map = {} # type: TypeMap + if should_narrow_by_identity: + if_map, else_map = self.refine_identity_comparison_expression( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + is_valid_target, + coerce_only_in_literal_context, + ) + + # Strictly speaking, we should also skip this check if the objects in the expr + # chain have custom __eq__ or __ne__ methods. But we (maybe optimistically) + # assume nobody would actually create a custom objects that considers itself + # equal to None. + if if_map == {} and else_map == {}: + if_map, else_map = self.refine_away_none_in_comparison( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + ) elif operator in {'in', 'not in'}: assert len(expr_indices) == 2 left_index, right_index = expr_indices @@ -3936,7 +3981,7 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM partial_type_maps.append((if_map, else_map)) - return reduce_partial_conditional_maps(partial_type_maps) + return reduce_conditional_maps(partial_type_maps) elif isinstance(node, RefExpr): # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively @@ -4146,8 +4191,10 @@ def refine_identity_comparison_expression(self, operand_types: List[Type], chain_indices: List[int], narrowable_operand_indices: AbstractSet[int], + is_valid_target: Callable[[ProperType], bool], + coerce_only_in_literal_context: bool, ) -> Tuple[TypeMap, TypeMap]: - """Produces conditional type maps refining expressions used in an identity comparison. + """Produce conditional type maps refining expressions by an identity/equality comparison. The 'operands' and 'operand_types' lists should be the full list of operands used in the overall comparison expression. The 'chain_indices' list is the list of indices @@ -4163,30 +4210,45 @@ def refine_identity_comparison_expression(self, The 'narrowable_operand_indices' parameter is the set of all indices we are allowed to refine the types of: that is, all operands that will potentially be a part of the output TypeMaps. + + Although this function could theoretically try setting the types of the operands + in the chains to the meet, doing that causes too many issues in real-world code. + Instead, we use 'is_valid_target' to identify which of the given chain types + we could plausibly use as the refined type for the expressions in the chain. + + Similarly, 'coerce_only_in_literal_context' controls whether we should try coercing + expressions in the chain to a Literal type. Performing this coercion is sometimes + too aggressive of a narrowing, depending on context. """ - singleton = None # type: Optional[ProperType] - possible_singleton_indices = [] + should_coerce = True + if coerce_only_in_literal_context: + should_coerce = any(is_literal_type_like(operand_types[i]) for i in chain_indices) + + target = None # type: Optional[Type] + possible_target_indices = [] for i in chain_indices: - coerced_type = coerce_to_literal(operand_types[i]) - if not is_singleton_type(coerced_type): + expr_type = operand_types[i] + if should_coerce: + expr_type = coerce_to_literal(expr_type) + if not is_valid_target(get_proper_type(expr_type)): continue - if singleton and not is_same_type(singleton, coerced_type): - # We have multiple disjoint singleton types. So the 'if' branch + if target and not is_same_type(target, expr_type): + # We have multiple disjoint target types. So the 'if' branch # must be unreachable. return None, {} - singleton = coerced_type - possible_singleton_indices.append(i) + target = expr_type + possible_target_indices.append(i) - # There's nothing we can currently infer if none of the operands are singleton types, + # There's nothing we can currently infer if none of the operands are valid targets, # so we end early and infer nothing. - if singleton is None: + if target is None: return {}, {} - # If possible, use an unassignable expression as the singleton. - # We skip refining the type of the singleton below, so ideally we'd + # If possible, use an unassignable expression as the target. + # We skip refining the type of the target below, so ideally we'd # want to pick an expression we were going to skip anyways. singleton_index = -1 - for i in possible_singleton_indices: + for i in possible_target_indices: if i not in narrowable_operand_indices: singleton_index = i @@ -4215,20 +4277,21 @@ def refine_identity_comparison_expression(self, # currently will just mark the whole branch as unreachable if either operand is # narrowed to . if singleton_index == -1: - singleton_index = possible_singleton_indices[-1] + singleton_index = possible_target_indices[-1] enum_name = None - if isinstance(singleton, LiteralType) and singleton.is_enum_literal(): - enum_name = singleton.fallback.type.fullname + target = get_proper_type(target) + if isinstance(target, LiteralType) and target.is_enum_literal(): + enum_name = target.fallback.type.fullname - target_type = [TypeRange(singleton, is_upper_bound=False)] + target_type = [TypeRange(target, is_upper_bound=False)] partial_type_maps = [] for i in chain_indices: - # If we try refining a singleton against itself, conditional_type_map + # If we try refining a type against itself, conditional_type_map # will end up assuming that the 'else' branch is unreachable. This is # typically not what we want: generally the user will intend for the - # singleton type to be some fixed 'sentinel' value and will want to refine + # target type to be some fixed 'sentinel' value and will want to refine # the other exprs against this one instead. if i == singleton_index: continue @@ -4244,19 +4307,18 @@ def refine_identity_comparison_expression(self, expr_type = try_expanding_enum_to_union(expr_type, enum_name) partial_type_maps.append(conditional_type_map(expr, expr_type, target_type)) - return reduce_partial_conditional_maps(partial_type_maps) + return reduce_conditional_maps(partial_type_maps) - def refine_equality_comparison_expression(self, - operands: List[Expression], - operand_types: List[Type], - chain_indices: List[int], - narrowable_operand_indices: AbstractSet[int], - ) -> Tuple[TypeMap, TypeMap]: - """Produces conditional type maps refining expressions used in an equality comparison. + def refine_away_none_in_comparison(self, + operands: List[Expression], + operand_types: List[Type], + chain_indices: List[int], + narrowable_operand_indices: AbstractSet[int], + ) -> Tuple[TypeMap, TypeMap]: + """Produces conditional type maps refining away None in an identity/equality chain. - For more details, see the docstring of 'refine_equality_comparison' up above. - The only difference is that this function is for refining equality operations - (e.g. 'a == b == c') instead of identity ('a is b is c'). + For more details about what the different arguments mean, see the + docstring of 'refine_identity_comparison_expression' up above. """ non_optional_types = [] for i in chain_indices: @@ -4749,7 +4811,7 @@ class Foo(Enum): return False parent_type = get_proper_type(parent_type) - member_type = coerce_to_literal(member_type) + member_type = get_proper_type(coerce_to_literal(member_type)) if not isinstance(parent_type, FunctionLike) or not isinstance(member_type, LiteralType): return False @@ -4851,46 +4913,12 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: return result -def or_partial_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: - """Calculate what information we can learn from the truth of (e1 or e2) - in terms of the information that we can learn from the truth of e1 and - the truth of e2. - - Unlike 'or_conditional_maps', we include an expression in the output even - if it exists in only one map: we're assuming both maps are "partial" and - contain information about only some expressions, and so we "or" together - expressions both maps have information on. - """ - - if m1 is None: - return m2 - if m2 is None: - return m1 - # The logic here is a blend between 'and_conditional_maps' - # and 'or_conditional_maps'. We use the high-level logic from the - # former to ensure all expressions make it in the output map, - # but resolve cases where both maps contain info on the same - # expr using the unioning strategy from the latter. - result = m2.copy() - m2_keys = {literal_hash(n2): n2 for n2 in m2} - for n1 in m1: - n2 = m2_keys.get(literal_hash(n1)) - if n2 is None: - result[n1] = m1[n1] - else: - result[n2] = make_simplified_union([m1[n1], result[n2]]) - - return result - - -def reduce_partial_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], - ) -> Tuple[TypeMap, TypeMap]: - """Reduces a list containing pairs of *partial* if/else TypeMaps into a single pair. - - That is, if a expression exists in only one map, we always include it in the output. - We only "and"/"or" together expressions that appear in multiple if/else maps. +def reduce_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], + ) -> Tuple[TypeMap, TypeMap]: + """Reduces a list containing pairs of if/else TypeMaps into a single pair. - So for example, if we had the input: + We "and" together all of the if TypeMaps and "or" together the else TypeMaps. So + for example, if we had the input: [ ({x: TypeIfX, shared: TypeIfShared1}, {x: TypeElseX, shared: TypeElseShared1}), @@ -4901,11 +4929,14 @@ def reduce_partial_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], ( {x: TypeIfX, y: TypeIfY, shared: PseudoIntersection[TypeIfShared1, TypeIfShared2]}, - {x: TypeElseX, y: TypeElseY, shared: Union[TypeElseShared1, TypeElseShared2]}, + {shared: Union[TypeElseShared1, TypeElseShared2]}, ) ...where "PseudoIntersection[X, Y] == Y" because mypy actually doesn't understand intersections yet, so we settle for just arbitrarily picking the right expr's type. + + We only retain the shared expression in the 'else' case because we don't actually know + whether x was refined or y was refined -- only just that one of the two was refined. """ if len(type_maps) == 0: return {}, {} @@ -4914,10 +4945,9 @@ def reduce_partial_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], else: final_if_map, final_else_map = type_maps[0] for if_map, else_map in type_maps[1:]: - # 'and_conditional_maps' does the same thing for both global and partial type maps, - # which is why we don't need to have an 'and_partial_conditional_maps' function. final_if_map = and_conditional_maps(final_if_map, if_map) - final_else_map = or_partial_conditional_maps(final_else_map, else_map) + final_else_map = or_conditional_maps(final_else_map, else_map) + return final_if_map, final_else_map diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7edaf7e2ad89..58dcd0105b53 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -32,7 +32,6 @@ YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode, ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE, - SYMBOL_FUNCBASE_TYPES ) from mypy.literals import literal from mypy import nodes @@ -51,7 +50,7 @@ from mypy import erasetype from mypy.checkmember import analyze_member_access, type_object_type from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals -from mypy.checkstrformat import StringFormatterChecker, custom_special_method +from mypy.checkstrformat import StringFormatterChecker from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars from mypy.util import split_module_names from mypy.typevars import fill_typevars @@ -59,7 +58,8 @@ from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext from mypy.typeops import ( tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound, - function_type, callable_type, try_getting_str_literals + function_type, callable_type, try_getting_str_literals, custom_special_method, + is_literal_type_like, ) import mypy.errorcodes as codes @@ -4266,24 +4266,6 @@ def merge_typevars_in_callables_by_name( return output, variables -def is_literal_type_like(t: Optional[Type]) -> bool: - """Returns 'true' if the given type context is potentially either a LiteralType, - a Union of LiteralType, or something similar. - """ - t = get_proper_type(t) - if t is None: - return False - elif isinstance(t, LiteralType): - return True - elif isinstance(t, UnionType): - return any(is_literal_type_like(item) for item in t.items) - elif isinstance(t, TypeVarType): - return (is_literal_type_like(t.upper_bound) - or any(is_literal_type_like(item) for item in t.values)) - else: - return False - - def try_getting_literal(typ: Type) -> ProperType: """If possible, get a more precise literal type for a given type.""" typ = get_proper_type(typ) @@ -4305,29 +4287,6 @@ def is_expr_literal_type(node: Expression) -> bool: return False -def custom_equality_method(typ: Type) -> bool: - """Does this type have a custom __eq__() method?""" - typ = get_proper_type(typ) - if isinstance(typ, Instance): - method = typ.type.get('__eq__') - if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)): - if method.node.info: - return not method.node.info.fullname.startswith('builtins.') - return False - if isinstance(typ, UnionType): - return any(custom_equality_method(t) for t in typ.items) - if isinstance(typ, TupleType): - return custom_equality_method(tuple_fallback(typ)) - if isinstance(typ, CallableType) and typ.is_type_obj(): - # Look up __eq__ on the metaclass for class objects. - return custom_equality_method(typ.fallback) - if isinstance(typ, AnyType): - # Avoid false positives in uncertain cases. - return True - # TODO: support other types (see ExpressionChecker.has_member())? - return False - - def has_bytes_component(typ: Type, py2: bool = False) -> bool: """Is this one of builtin byte types, or a union that contains it?""" typ = get_proper_type(typ) diff --git a/mypy/checkstrformat.py b/mypy/checkstrformat.py index 6f7647d98846..f89d5d0451b2 100644 --- a/mypy/checkstrformat.py +++ b/mypy/checkstrformat.py @@ -19,12 +19,12 @@ from mypy.types import ( Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType, - CallableType, LiteralType, get_proper_types + LiteralType, get_proper_types ) from mypy.nodes import ( StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr, CallExpr, IndexExpr, MemberExpr, TempNode, ARG_POS, ARG_STAR, ARG_NAMED, ARG_STAR2, - SYMBOL_FUNCBASE_TYPES, Decorator, Var, Node, MypyFile, ExpressionStmt, NameExpr, IntExpr + Node, MypyFile, ExpressionStmt, NameExpr, IntExpr ) import mypy.errorcodes as codes @@ -35,7 +35,7 @@ from mypy import message_registry from mypy.messages import MessageBuilder from mypy.maptype import map_instance_to_supertype -from mypy.typeops import tuple_fallback +from mypy.typeops import custom_special_method from mypy.subtypes import is_subtype from mypy.parse import parse @@ -961,32 +961,3 @@ def has_type_component(typ: Type, fullname: str) -> bool: elif isinstance(typ, UnionType): return any(has_type_component(t, fullname) for t in typ.relevant_items()) return False - - -def custom_special_method(typ: Type, name: str, - check_all: bool = False) -> bool: - """Does this type have a custom special method such as __format__() or __eq__()? - - If check_all is True ensure all items of a union have a custom method, not just some. - """ - typ = get_proper_type(typ) - if isinstance(typ, Instance): - method = typ.type.get(name) - if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)): - if method.node.info: - return not method.node.info.fullname.startswith('builtins.') - return False - if isinstance(typ, UnionType): - if check_all: - return all(custom_special_method(t, name, check_all) for t in typ.items) - return any(custom_special_method(t, name) for t in typ.items) - if isinstance(typ, TupleType): - return custom_special_method(tuple_fallback(typ), name) - if isinstance(typ, CallableType) and typ.is_type_obj(): - # Look up __method__ on the metaclass for class objects. - return custom_special_method(typ.fallback, name) - if isinstance(typ, AnyType): - # Avoid false positives in uncertain cases. - return True - # TODO: support other types (see ExpressionChecker.has_member())? - return False diff --git a/mypy/typeops.py b/mypy/typeops.py index 266a0fa0bb88..828791333f36 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -17,7 +17,7 @@ ) from mypy.nodes import ( FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS, - Expression, StrExpr, Var + Expression, StrExpr, Var, Decorator, SYMBOL_FUNCBASE_TYPES ) from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance, expand_type @@ -564,6 +564,24 @@ def try_getting_literals_from_type(typ: Type, return literals +def is_literal_type_like(t: Optional[Type]) -> bool: + """Returns 'true' if the given type context is potentially either a LiteralType, + a Union of LiteralType, or something similar. + """ + t = get_proper_type(t) + if t is None: + return False + elif isinstance(t, LiteralType): + return True + elif isinstance(t, UnionType): + return any(is_literal_type_like(item) for item in t.items) + elif isinstance(t, TypeVarType): + return (is_literal_type_like(t.upper_bound) + or any(is_literal_type_like(item) for item in t.values)) + else: + return False + + def get_enum_values(typ: Instance) -> List[str]: """Return the list of values for an Enum.""" return [name for name, sym in typ.type.names.items() if isinstance(sym.node, Var)] @@ -640,10 +658,11 @@ class Status(Enum): return typ -def coerce_to_literal(typ: Type) -> ProperType: +def coerce_to_literal(typ: Type) -> Type: """Recursively converts any Instances that have a last_known_value or are instances of enum types with a single value into the corresponding LiteralType. """ + original_type = typ typ = get_proper_type(typ) if isinstance(typ, UnionType): new_items = [coerce_to_literal(item) for item in typ.items] @@ -655,7 +674,7 @@ def coerce_to_literal(typ: Type) -> ProperType: enum_values = get_enum_values(typ) if len(enum_values) == 1: return LiteralType(value=enum_values[0], fallback=typ) - return typ + return original_type def get_type_vars(tp: Type) -> List[TypeVarType]: @@ -674,3 +693,31 @@ def _merge(self, iter: Iterable[List[TypeVarType]]) -> List[TypeVarType]: def visit_type_var(self, t: TypeVarType) -> List[TypeVarType]: return [t] + + +def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool: + """Does this type have a custom special method such as __format__() or __eq__()? + + If check_all is True ensure all items of a union have a custom method, not just some. + """ + typ = get_proper_type(typ) + if isinstance(typ, Instance): + method = typ.type.get(name) + if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)): + if method.node.info: + return not method.node.info.fullname.startswith('builtins.') + return False + if isinstance(typ, UnionType): + if check_all: + return all(custom_special_method(t, name, check_all) for t in typ.items) + return any(custom_special_method(t, name) for t in typ.items) + if isinstance(typ, TupleType): + return custom_special_method(tuple_fallback(typ), name, check_all) + if isinstance(typ, CallableType) and typ.is_type_obj(): + # Look up __method__ on the metaclass for class objects. + return custom_special_method(typ.fallback, name, check_all) + if isinstance(typ, AnyType): + # Avoid false positives in uncertain cases. + return True + # TODO: support other types (see ExpressionChecker.has_member())? + return False diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 9d027f47192f..18130d2d818c 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -978,32 +978,43 @@ class Foo(Enum): x: Foo y: Foo +# We can't narrow anything in the else cases -- what if +# x is Foo.A and y is Foo.B or vice versa, for example? if x is y is Foo.A: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif x is y is Foo.B: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' -reveal_type(x) # N: Revealed type is '__main__.Foo' -reveal_type(y) # N: Revealed type is '__main__.Foo' + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' if x is Foo.A is y: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif x is Foo.B is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' -reveal_type(x) # N: Revealed type is '__main__.Foo' -reveal_type(y) # N: Revealed type is '__main__.Foo' + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' if Foo.A is x is y: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif Foo.B is x is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' -reveal_type(x) # N: Revealed type is '__main__.Foo' -reveal_type(y) # N: Revealed type is '__main__.Foo' + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' [builtins fixtures/primitives.pyi] @@ -1026,8 +1037,10 @@ if x is Foo.A < y is Foo.B: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' - reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + # Note: we can't narrow in this case. What if both x and y + # are Foo.A, for example? + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' reveal_type(x) # N: Revealed type is '__main__.Foo' reveal_type(y) # N: Revealed type is '__main__.Foo' @@ -1109,11 +1122,13 @@ if x0 is x1 is Foo.A is x2 < x3 is Foo.B is x4 is x5: reveal_type(x4) # N: Revealed type is 'Literal[__main__.Foo.B]' reveal_type(x5) # N: Revealed type is 'Literal[__main__.Foo.B]' else: - reveal_type(x0) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' - reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' - reveal_type(x2) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' - - reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' - reveal_type(x4) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' - reveal_type(x5) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' + # We unfortunately can't narrow away anything. For example, + # what if x0 == Foo.A and x1 == Foo.B or vice versa? + reveal_type(x0) # N: Revealed type is '__main__.Foo' + reveal_type(x1) # N: Revealed type is '__main__.Foo' + reveal_type(x2) # N: Revealed type is '__main__.Foo' + + reveal_type(x3) # N: Revealed type is '__main__.Foo' + reveal_type(x4) # N: Revealed type is '__main__.Foo' + reveal_type(x5) # N: Revealed type is '__main__.Foo' [builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index e0a0cb660c80..6c64b241eaaa 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1,3 +1,86 @@ +[case testNarrowingParentWithStrsBasic] +from dataclasses import dataclass +from typing import NamedTuple, Tuple, Union +from typing_extensions import Literal, TypedDict + +class Object1: + key: Literal["A"] + foo: int +class Object2: + key: Literal["B"] + bar: str + +@dataclass +class Dataclass1: + key: Literal["A"] + foo: int +@dataclass +class Dataclass2: + key: Literal["B"] + foo: str + +class NamedTuple1(NamedTuple): + key: Literal["A"] + foo: int +class NamedTuple2(NamedTuple): + key: Literal["B"] + foo: str + +Tuple1 = Tuple[Literal["A"], int] +Tuple2 = Tuple[Literal["B"], str] + +class TypedDict1(TypedDict): + key: Literal["A"] + foo: int +class TypedDict2(TypedDict): + key: Literal["B"] + foo: str + +x1: Union[Object1, Object2] +if x1.key == "A": + reveal_type(x1) # N: Revealed type is '__main__.Object1' + reveal_type(x1.key) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x1) # N: Revealed type is '__main__.Object2' + reveal_type(x1.key) # N: Revealed type is 'Literal['B']' + +x2: Union[Dataclass1, Dataclass2] +if x2.key == "A": + reveal_type(x2) # N: Revealed type is '__main__.Dataclass1' + reveal_type(x2.key) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x2) # N: Revealed type is '__main__.Dataclass2' + reveal_type(x2.key) # N: Revealed type is 'Literal['B']' + +x3: Union[NamedTuple1, NamedTuple2] +if x3.key == "A": + reveal_type(x3) # N: Revealed type is 'Tuple[Literal['A'], builtins.int, fallback=__main__.NamedTuple1]' + reveal_type(x3.key) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x3) # N: Revealed type is 'Tuple[Literal['B'], builtins.str, fallback=__main__.NamedTuple2]' + reveal_type(x3.key) # N: Revealed type is 'Literal['B']' +if x3[0] == "A": + reveal_type(x3) # N: Revealed type is 'Tuple[Literal['A'], builtins.int, fallback=__main__.NamedTuple1]' + reveal_type(x3[0]) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x3) # N: Revealed type is 'Tuple[Literal['B'], builtins.str, fallback=__main__.NamedTuple2]' + reveal_type(x3[0]) # N: Revealed type is 'Literal['B']' + +x4: Union[Tuple1, Tuple2] +if x4[0] == "A": + reveal_type(x4) # N: Revealed type is 'Tuple[Literal['A'], builtins.int]' + reveal_type(x4[0]) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x4) # N: Revealed type is 'Tuple[Literal['B'], builtins.str]' + reveal_type(x4[0]) # N: Revealed type is 'Literal['B']' + +x5: Union[TypedDict1, TypedDict2] +if x5["key"] == "A": + reveal_type(x5) # N: Revealed type is 'TypedDict('__main__.TypedDict1', {'key': Literal['A'], 'foo': builtins.int})' +else: + reveal_type(x5) # N: Revealed type is 'TypedDict('__main__.TypedDict2', {'key': Literal['B'], 'foo': builtins.str})' +[builtins fixtures/primitives.pyi] + [case testNarrowingParentWithEnumsBasic] from enum import Enum from dataclasses import dataclass @@ -184,6 +267,88 @@ if x.key is Key.D: else: reveal_type(x) # N: Revealed type is 'Union[__main__.Object1, __main__.Object2]' +[case testNarrowingTypedDictParentMultipleKeys] +# flags: --warn-unreachable +from typing import Union +from typing_extensions import Literal, TypedDict + +class TypedDict1(TypedDict): + key: Literal['A', 'C'] +class TypedDict2(TypedDict): + key: Literal['B', 'C'] + +x: Union[TypedDict1, TypedDict2] +if x['key'] == 'A': + reveal_type(x) # N: Revealed type is 'TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]})' +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]' + +if x['key'] == 'C': + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]' +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]' + +if x['key'] == 'D': + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingPartialTypedDictParentMultipleKeys] +# flags: --warn-unreachable +from typing import Union +from typing_extensions import Literal, TypedDict + +class TypedDict1(TypedDict, total=False): + key: Literal['A', 'C'] +class TypedDict2(TypedDict, total=False): + key: Literal['B', 'C'] + +x: Union[TypedDict1, TypedDict2] +if x['key'] == 'A': + reveal_type(x) # N: Revealed type is 'TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]})' +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]' + +if x['key'] == 'C': + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]' +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]' + +if x['key'] == 'D': + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is 'Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingNestedTypedDicts] +from typing import Union +from typing_extensions import TypedDict, Literal + +class A(TypedDict): + key: Literal['A'] +class B(TypedDict): + key: Literal['B'] +class C(TypedDict): + key: Literal['C'] + +class X(TypedDict): + inner: Union[A, B] +class Y(TypedDict): + inner: Union[B, C] + +unknown: Union[X, Y] +if unknown['inner']['key'] == 'A': + reveal_type(unknown) # N: Revealed type is 'TypedDict('__main__.X', {'inner': Union[TypedDict('__main__.A', {'key': Literal['A']}), TypedDict('__main__.B', {'key': Literal['B']})]})' + reveal_type(unknown['inner']) # N: Revealed type is 'TypedDict('__main__.A', {'key': Literal['A']})' +if unknown['inner']['key'] == 'B': + reveal_type(unknown) # N: Revealed type is 'Union[TypedDict('__main__.X', {'inner': Union[TypedDict('__main__.A', {'key': Literal['A']}), TypedDict('__main__.B', {'key': Literal['B']})]}), TypedDict('__main__.Y', {'inner': Union[TypedDict('__main__.B', {'key': Literal['B']}), TypedDict('__main__.C', {'key': Literal['C']})]})]' + reveal_type(unknown['inner']) # N: Revealed type is 'TypedDict('__main__.B', {'key': Literal['B']})' +if unknown['inner']['key'] == 'C': + reveal_type(unknown) # N: Revealed type is 'TypedDict('__main__.Y', {'inner': Union[TypedDict('__main__.B', {'key': Literal['B']}), TypedDict('__main__.C', {'key': Literal['C']})]})' + reveal_type(unknown['inner']) # N: Revealed type is 'TypedDict('__main__.C', {'key': Literal['C']})' +[builtins fixtures/primitives.pyi] + [case testNarrowingParentWithMultipleParents] from enum import Enum from typing import Union @@ -445,3 +610,353 @@ if y["model"]["key"] is Key.C: else: reveal_type(y) # N: Revealed type is 'Union[TypedDict('__main__.Parent1', {'model': TypedDict('__main__.Model1', {'key': Literal[__main__.Key.A]}), 'foo': builtins.int}), TypedDict('__main__.Parent2', {'model': TypedDict('__main__.Model2', {'key': Literal[__main__.Key.B]}), 'bar': builtins.str})]' reveal_type(y["model"]) # N: Revealed type is 'Union[TypedDict('__main__.Model1', {'key': Literal[__main__.Key.A]}), TypedDict('__main__.Model2', {'key': Literal[__main__.Key.B]})]' + +[case testNarrowingParentsHierarchyTypedDictWithStr] +# flags: --warn-unreachable +from typing import Union +from typing_extensions import TypedDict, Literal + +class Parent1(TypedDict): + model: Model1 + foo: int + +class Parent2(TypedDict): + model: Model2 + bar: str + +class Model1(TypedDict): + key: Literal['A'] + +class Model2(TypedDict): + key: Literal['B'] + +x: Union[Parent1, Parent2] +if x["model"]["key"] == 'A': + reveal_type(x) # N: Revealed type is 'TypedDict('__main__.Parent1', {'model': TypedDict('__main__.Model1', {'key': Literal['A']}), 'foo': builtins.int})' + reveal_type(x["model"]) # N: Revealed type is 'TypedDict('__main__.Model1', {'key': Literal['A']})' +else: + reveal_type(x) # N: Revealed type is 'TypedDict('__main__.Parent2', {'model': TypedDict('__main__.Model2', {'key': Literal['B']}), 'bar': builtins.str})' + reveal_type(x["model"]) # N: Revealed type is 'TypedDict('__main__.Model2', {'key': Literal['B']})' + +y: Union[Parent1, Parent2] +if y["model"]["key"] == 'C': + reveal_type(y) # E: Statement is unreachable + reveal_type(y["model"]) +else: + reveal_type(y) # N: Revealed type is 'Union[TypedDict('__main__.Parent1', {'model': TypedDict('__main__.Model1', {'key': Literal['A']}), 'foo': builtins.int}), TypedDict('__main__.Parent2', {'model': TypedDict('__main__.Model2', {'key': Literal['B']}), 'bar': builtins.str})]' + reveal_type(y["model"]) # N: Revealed type is 'Union[TypedDict('__main__.Model1', {'key': Literal['A']}), TypedDict('__main__.Model2', {'key': Literal['B']})]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingEqualityFlipFlop] +# flags: --warn-unreachable --strict-equality +from typing_extensions import Literal, Final +from enum import Enum + +class State(Enum): + A = 1 + B = 2 + +class FlipFlopEnum: + def __init__(self) -> None: + self.state = State.A + + def mutate(self) -> None: + self.state = State.B if self.state == State.A else State.A + +class FlipFlopStr: + def __init__(self) -> None: + self.state = "state-1" + + def mutate(self) -> None: + self.state = "state-2" if self.state == "state-1" else "state-1" + +def test1(switch: FlipFlopEnum) -> None: + # Naively, we might assume the 'assert' here would narrow the type to + # Literal[State.A]. However, doing this ends up breaking a fair number of real-world + # code (usually test cases) that looks similar to this function: e.g. checks + # to make sure a field was mutated to some particular value. + # + # And since mypy can't really reason about state mutation, we take a conservative + # approach and avoid narrowing anything here. + + assert switch.state == State.A + reveal_type(switch.state) # N: Revealed type is '__main__.State' + + switch.mutate() + + assert switch.state == State.B + reveal_type(switch.state) # N: Revealed type is '__main__.State' + +def test2(switch: FlipFlopEnum) -> None: + # So strictly speaking, we ought to do the same thing with 'is' comparisons + # for the same reasons as above. But in practice, not too many people seem to + # know that doing 'some_enum is MyEnum.Value' is idiomatic. So in practice, + # this is probably good enough for now. + + assert switch.state is State.A + reveal_type(switch.state) # N: Revealed type is 'Literal[__main__.State.A]' + + switch.mutate() + + assert switch.state is State.B # E: Non-overlapping identity check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]") + reveal_type(switch.state) # E: Statement is unreachable + +def test3(switch: FlipFlopStr) -> None: + # This is the same thing as 'test1', except we try using str literals. + + assert switch.state == "state-1" + reveal_type(switch.state) # N: Revealed type is 'builtins.str' + + switch.mutate() + + assert switch.state == "state-2" + reveal_type(switch.state) # N: Revealed type is 'builtins.str' +[builtins fixtures/primitives.pyi] + +[case testNarrowingEqualityRequiresExplicitStrLiteral] +# flags: --strict-optional +from typing_extensions import Literal, Final + +A_final: Final = "A" +A_literal: Literal["A"] + +# Neither the LHS nor the RHS are explicit literals, so regrettably nothing +# is narrowed here -- see 'testNarrowingEqualityFlipFlop' for an example of +# why more precise inference here is problematic. +x_str: str +if x_str == "A": + reveal_type(x_str) # N: Revealed type is 'builtins.str' +else: + reveal_type(x_str) # N: Revealed type is 'builtins.str' +reveal_type(x_str) # N: Revealed type is 'builtins.str' + +if x_str == A_final: + reveal_type(x_str) # N: Revealed type is 'builtins.str' +else: + reveal_type(x_str) # N: Revealed type is 'builtins.str' +reveal_type(x_str) # N: Revealed type is 'builtins.str' + +# But the RHS is a literal, so we can at least narrow the 'if' case now. +if x_str == A_literal: + reveal_type(x_str) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x_str) # N: Revealed type is 'builtins.str' +reveal_type(x_str) # N: Revealed type is 'builtins.str' + +# But in these two cases, the LHS is a literal/literal-like type. So we +# assume the user *does* want literal-based narrowing and narrow accordingly +# regardless of whether the RHS is an explicit literal or not. +x_union: Literal["A", "B", None] +if x_union == A_final: + reveal_type(x_union) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x_union) # N: Revealed type is 'Union[Literal['B'], None]' +reveal_type(x_union) # N: Revealed type is 'Union[Literal['A'], Literal['B'], None]' + +if x_union == A_literal: + reveal_type(x_union) # N: Revealed type is 'Literal['A']' +else: + reveal_type(x_union) # N: Revealed type is 'Union[Literal['B'], None]' +reveal_type(x_union) # N: Revealed type is 'Union[Literal['A'], Literal['B'], None]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingEqualityRequiresExplicitEnumLiteral] +# flags: --strict-optional +from typing_extensions import Literal, Final +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + +A_final: Final = Foo.A +A_literal: Literal[Foo.A] + +# See comments in testNarrowingEqualityRequiresExplicitStrLiteral and +# testNarrowingEqualityFlipFlop for more on why we can't narrow here. +x1: Foo +if x1 == Foo.A: + reveal_type(x1) # N: Revealed type is '__main__.Foo' +else: + reveal_type(x1) # N: Revealed type is '__main__.Foo' + +x2: Foo +if x2 == A_final: + reveal_type(x2) # N: Revealed type is '__main__.Foo' +else: + reveal_type(x2) # N: Revealed type is '__main__.Foo' + +# But we let this narrow since there's an explicit literal in the RHS. +x3: Foo +if x3 == A_literal: + reveal_type(x3) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x3) # N: Revealed type is 'Literal[__main__.Foo.B]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingEqualityDisabledForCustomEquality] +from typing import Union +from typing_extensions import Literal +from enum import Enum + +class Custom: + def __eq__(self, other: object) -> bool: return True + +class Default: pass + +x1: Union[Custom, Literal[1], Literal[2]] +if x1 == 1: + reveal_type(x1) # N: Revealed type is 'Union[__main__.Custom, Literal[1], Literal[2]]' +else: + reveal_type(x1) # N: Revealed type is 'Union[__main__.Custom, Literal[1], Literal[2]]' + +x2: Union[Default, Literal[1], Literal[2]] +if x2 == 1: + reveal_type(x2) # N: Revealed type is 'Literal[1]' +else: + reveal_type(x2) # N: Revealed type is 'Union[__main__.Default, Literal[2]]' + +class CustomEnum(Enum): + A = 1 + B = 2 + + def __eq__(self, other: object) -> bool: return True + +x3: CustomEnum +key: Literal[CustomEnum.A] +if x3 == key: + reveal_type(x3) # N: Revealed type is '__main__.CustomEnum' +else: + reveal_type(x3) # N: Revealed type is '__main__.CustomEnum' + +# For comparison, this narrows since we bypass __eq__ +if x3 is key: + reveal_type(x3) # N: Revealed type is 'Literal[__main__.CustomEnum.A]' +else: + reveal_type(x3) # N: Revealed type is 'Literal[__main__.CustomEnum.B]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingEqualityDisabledForCustomEqualityChain] +# flags: --strict-optional --strict-equality --warn-unreachable +from typing import Union +from typing_extensions import Literal + +class Custom: + def __eq__(self, other: object) -> bool: return True + +class Default: pass + +x: Literal[1, 2, None] +y: Custom +z: Default + +# We could maybe try doing something clever, but for simplicity we +# treat the whole chain as contaminated and mostly disable narrowing. +# +# The only exception is that we do at least strip away the 'None'. We +# (perhaps optimistically) assume no custom class would be pathological +# enough to declare itself to be equal to None and so permit this narrowing, +# since it's often convenient in practice. +if 1 == x == y: + reveal_type(x) # N: Revealed type is 'Union[Literal[1], Literal[2]]' + reveal_type(y) # N: Revealed type is '__main__.Custom' +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[1], Literal[2], None]' + reveal_type(y) # N: Revealed type is '__main__.Custom' + +# No contamination here +if 1 == x == z: # E: Non-overlapping equality check (left operand type: "Union[Literal[1], Literal[2], None]", right operand type: "Default") + reveal_type(x) # E: Statement is unreachable + reveal_type(z) +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[1], Literal[2], None]' + reveal_type(z) # N: Revealed type is '__main__.Default' +[builtins fixtures/primitives.pyi] + +[case testNarrowingUnreachableCases] +# flags: --strict-optional --strict-equality --warn-unreachable +from typing import Union +from typing_extensions import Literal + +a: Literal[1] +b: Literal[1, 2] +c: Literal[2, 3] + +if a == b == c: + reveal_type(a) # E: Statement is unreachable + reveal_type(b) + reveal_type(c) +else: + reveal_type(a) # N: Revealed type is 'Literal[1]' + reveal_type(b) # N: Revealed type is 'Union[Literal[1], Literal[2]]' + reveal_type(c) # N: Revealed type is 'Union[Literal[2], Literal[3]]' + +if a == a == a: + reveal_type(a) # N: Revealed type is 'Literal[1]' +else: + reveal_type(a) # E: Statement is unreachable + +if a == a == b: + reveal_type(a) # N: Revealed type is 'Literal[1]' + reveal_type(b) # N: Revealed type is 'Literal[1]' +else: + reveal_type(a) # N: Revealed type is 'Literal[1]' + reveal_type(b) # N: Revealed type is 'Literal[2]' + +# In this case, it's ok for 'b' to narrow down to Literal[1] in the else case +# since that's the only way 'b == 2' can be false +if b == 2: + reveal_type(b) # N: Revealed type is 'Literal[2]' +else: + reveal_type(b) # N: Revealed type is 'Literal[1]' + +# But in this case, we can't conclude anything about the else case. This expression +# could end up being either '2 == 2 == 3' or '1 == 2 == 2', which means we can't +# conclude anything. +if b == 2 == c: + reveal_type(b) # N: Revealed type is 'Literal[2]' + reveal_type(c) # N: Revealed type is 'Literal[2]' +else: + reveal_type(b) # N: Revealed type is 'Union[Literal[1], Literal[2]]' + reveal_type(c) # N: Revealed type is 'Union[Literal[2], Literal[3]]' +[builtins fixtures/primitives.pyi] + +[case testNarrowingUnreachableCases2] +# flags: --strict-optional --strict-equality --warn-unreachable +from typing import Union +from typing_extensions import Literal + +a: Literal[1, 2, 3, 4] +b: Literal[1, 2, 3, 4] + +if a == b == 1: + reveal_type(a) # N: Revealed type is 'Literal[1]' + reveal_type(b) # N: Revealed type is 'Literal[1]' +elif a == b == 2: + reveal_type(a) # N: Revealed type is 'Literal[2]' + reveal_type(b) # N: Revealed type is 'Literal[2]' +elif a == b == 3: + reveal_type(a) # N: Revealed type is 'Literal[3]' + reveal_type(b) # N: Revealed type is 'Literal[3]' +elif a == b == 4: + reveal_type(a) # N: Revealed type is 'Literal[4]' + reveal_type(b) # N: Revealed type is 'Literal[4]' +else: + # This branch is reachable if a == 1 and b == 2, for example. + reveal_type(a) # N: Revealed type is 'Union[Literal[1], Literal[2], Literal[3], Literal[4]]' + reveal_type(b) # N: Revealed type is 'Union[Literal[1], Literal[2], Literal[3], Literal[4]]' + +if a == a == 1: + reveal_type(a) # N: Revealed type is 'Literal[1]' +elif a == a == 2: + reveal_type(a) # N: Revealed type is 'Literal[2]' +elif a == a == 3: + reveal_type(a) # N: Revealed type is 'Literal[3]' +elif a == a == 4: + reveal_type(a) # N: Revealed type is 'Literal[4]' +else: + # In contrast, this branch must be unreachable: we assume (maybe naively) + # that 'a' won't be mutated in the middle of the expression. + reveal_type(a) # E: Statement is unreachable + reveal_type(b) +[builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 9c40d550699e..ab11c42173d4 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -489,6 +489,10 @@ if x == '': reveal_type(x) # N: Revealed type is 'builtins.str' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' +if x is '': + reveal_type(x) # N: Revealed type is 'builtins.str' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' [builtins fixtures/ops.pyi] [case testInferEqualsNotOptionalWithUnion] @@ -498,6 +502,10 @@ if x == '': reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +if x is '': + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] [case testInferEqualsNotOptionalWithOverlap] @@ -507,6 +515,10 @@ if x == object(): reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +if x is object(): + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] [case testInferEqualsStillOptionalWithNoOverlap] @@ -516,6 +528,10 @@ if x == 0: reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' +if x is 0: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' [builtins fixtures/ops.pyi] [case testInferEqualsStillOptionalWithBothOptional] @@ -526,6 +542,10 @@ if x == y: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +if x is y: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] [case testInferEqualsNotOptionalWithMultipleArgs]