From 9365fbfb8635bff93d749ed5321561e789c68cd3 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Fri, 25 Oct 2024 16:33:29 +0100 Subject: [PATCH] Refactor type narrowing further (#18043) Move a big chunk of code to a helper function. --- mypy/checker.py | 232 ++++++++++++++++++++++++------------------------ 1 file changed, 118 insertions(+), 114 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 070d695fb9c2..dbc997ba33c6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5983,121 +5983,10 @@ def find_isinstance_check_helper( ), ) elif isinstance(node, ComparisonExpr): - # Step 1: Obtain the types of each operand and whether or not we can - # narrow their types. (For example, we shouldn't try narrowing the - # types of literal string or enum expressions). - - operands = [collapse_walrus(x) for x in node.operands] - operand_types = [] - narrowable_operand_index_to_hash = {} - for i, expr in enumerate(operands): - if not self.has_type(expr): - return {}, {} - expr_type = self.lookup_type(expr) - operand_types.append(expr_type) - - if ( - literal(expr) == LITERAL_TYPE - and not is_literal_none(expr) - and not self.is_literal_enum(expr) - ): - h = literal_hash(expr) - if h is not None: - narrowable_operand_index_to_hash[i] = h - - # Step 2: Group operands chained by either the 'is' or '==' operands - # together. For all other operands, we keep them in groups of size 2. - # So the expression: - # - # x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8 - # - # ...is converted into the simplified operator list: - # - # [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]), - # ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])] - # - # We group identity/equality expressions so we can propagate information - # we discover about one operand across the entire chain. We don't bother - # handling 'is not' and '!=' chains in a special way: those are very rare - # in practice. - - simplified_operator_list = group_comparison_operands( - node.pairwise(), narrowable_operand_index_to_hash, {"==", "is"} - ) - - # Step 3: Analyze each group and infer more precise type maps for each - # assignable operand, if possible. We combine these type maps together - # in the final step. - - partial_type_maps = [] - for operator, expr_indices in simplified_operator_list: - if operator in {"is", "is not", "==", "!="}: - if_map, else_map = self.equality_type_narrowing_helper( - node, - operator, - operands, - operand_types, - expr_indices, - narrowable_operand_index_to_hash, - ) - elif operator in {"in", "not in"}: - assert len(expr_indices) == 2 - left_index, right_index = expr_indices - item_type = operand_types[left_index] - iterable_type = operand_types[right_index] - - if_map, else_map = {}, {} - - if left_index in narrowable_operand_index_to_hash: - # We only try and narrow away 'None' for now - if is_overlapping_none(item_type): - collection_item_type = get_proper_type( - builtin_item_type(iterable_type) - ) - if ( - collection_item_type is not None - and not is_overlapping_none(collection_item_type) - and not ( - isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == "builtins.object" - ) - and is_overlapping_erased_types(item_type, collection_item_type) - ): - if_map[operands[left_index]] = remove_optional(item_type) - - if right_index in narrowable_operand_index_to_hash: - if_type, else_type = self.conditional_types_for_iterable( - item_type, iterable_type - ) - expr = operands[right_index] - if if_type is None: - if_map = None - else: - if_map[expr] = if_type - if else_type is None: - else_map = None - else: - else_map[expr] = else_type - - else: - if_map = {} - else_map = {} - - if operator in {"is not", "!=", "not in"}: - if_map, else_map = else_map, if_map - - partial_type_maps.append((if_map, else_map)) - - # If we have found non-trivial restrictions from the regular comparisons, - # then return soon. Otherwise try to infer restrictions involving `len(x)`. - # TODO: support regular and len() narrowing in the same chain. - if any(m != ({}, {}) for m in partial_type_maps): - return reduce_conditional_maps(partial_type_maps) - else: - # Use meet for `and` maps to get correct results for chained checks - # like `if 1 < len(x) < 4: ...` - return reduce_conditional_maps(self.find_tuple_len_narrowing(node), use_meet=True) + return self.comparison_type_narrowing_helper(node) elif isinstance(node, AssignmentExpr): + if_map: dict[Expression, Type] | None + else_map: dict[Expression, Type] | None if_map = {} else_map = {} @@ -6184,6 +6073,121 @@ def find_isinstance_check_helper( else_map = {node: else_type} if not isinstance(else_type, UninhabitedType) else None return if_map, else_map + def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMap, TypeMap]: + """Infer type narrowing from a comparison expression.""" + # Step 1: Obtain the types of each operand and whether or not we can + # narrow their types. (For example, we shouldn't try narrowing the + # types of literal string or enum expressions). + + operands = [collapse_walrus(x) for x in node.operands] + operand_types = [] + narrowable_operand_index_to_hash = {} + for i, expr in enumerate(operands): + if not self.has_type(expr): + return {}, {} + expr_type = self.lookup_type(expr) + operand_types.append(expr_type) + + if ( + literal(expr) == LITERAL_TYPE + and not is_literal_none(expr) + and not self.is_literal_enum(expr) + ): + h = literal_hash(expr) + if h is not None: + narrowable_operand_index_to_hash[i] = h + + # Step 2: Group operands chained by either the 'is' or '==' operands + # together. For all other operands, we keep them in groups of size 2. + # So the expression: + # + # x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8 + # + # ...is converted into the simplified operator list: + # + # [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]), + # ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])] + # + # We group identity/equality expressions so we can propagate information + # we discover about one operand across the entire chain. We don't bother + # handling 'is not' and '!=' chains in a special way: those are very rare + # in practice. + + simplified_operator_list = group_comparison_operands( + node.pairwise(), narrowable_operand_index_to_hash, {"==", "is"} + ) + + # Step 3: Analyze each group and infer more precise type maps for each + # assignable operand, if possible. We combine these type maps together + # in the final step. + + partial_type_maps = [] + for operator, expr_indices in simplified_operator_list: + if operator in {"is", "is not", "==", "!="}: + if_map, else_map = self.equality_type_narrowing_helper( + node, + operator, + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash, + ) + elif operator in {"in", "not in"}: + assert len(expr_indices) == 2 + left_index, right_index = expr_indices + item_type = operand_types[left_index] + iterable_type = operand_types[right_index] + + if_map, else_map = {}, {} + + if left_index in narrowable_operand_index_to_hash: + # We only try and narrow away 'None' for now + if is_overlapping_none(item_type): + collection_item_type = get_proper_type(builtin_item_type(iterable_type)) + if ( + collection_item_type is not None + and not is_overlapping_none(collection_item_type) + and not ( + isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == "builtins.object" + ) + and is_overlapping_erased_types(item_type, collection_item_type) + ): + if_map[operands[left_index]] = remove_optional(item_type) + + if right_index in narrowable_operand_index_to_hash: + if_type, else_type = self.conditional_types_for_iterable( + item_type, iterable_type + ) + expr = operands[right_index] + if if_type is None: + if_map = None + else: + if_map[expr] = if_type + if else_type is None: + else_map = None + else: + else_map[expr] = else_type + + else: + if_map = {} + else_map = {} + + if operator in {"is not", "!=", "not in"}: + if_map, else_map = else_map, if_map + + partial_type_maps.append((if_map, else_map)) + + # If we have found non-trivial restrictions from the regular comparisons, + # then return soon. Otherwise try to infer restrictions involving `len(x)`. + # TODO: support regular and len() narrowing in the same chain. + if any(m != ({}, {}) for m in partial_type_maps): + return reduce_conditional_maps(partial_type_maps) + else: + # Use meet for `and` maps to get correct results for chained checks + # like `if 1 < len(x) < 4: ...` + return reduce_conditional_maps(self.find_tuple_len_narrowing(node), use_meet=True) + def equality_type_narrowing_helper( self, node: ComparisonExpr,