Skip to content

Commit

Permalink
Refactor type narrowing further (#18043)
Browse files Browse the repository at this point in the history
Move a big chunk of code to a helper function.
  • Loading branch information
JukkaL authored Oct 25, 2024
1 parent 4b8e7df commit 9365fbf
Showing 1 changed file with 118 additions and 114 deletions.
232 changes: 118 additions & 114 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5983,121 +5983,10 @@ def find_isinstance_check_helper(
),
)
elif isinstance(node, ComparisonExpr):
# Step 1: Obtain the types of each operand and whether or not we can
# narrow their types. (For example, we shouldn't try narrowing the
# types of literal string or enum expressions).

operands = [collapse_walrus(x) for x in node.operands]
operand_types = []
narrowable_operand_index_to_hash = {}
for i, expr in enumerate(operands):
if not self.has_type(expr):
return {}, {}
expr_type = self.lookup_type(expr)
operand_types.append(expr_type)

if (
literal(expr) == LITERAL_TYPE
and not is_literal_none(expr)
and not self.is_literal_enum(expr)
):
h = literal_hash(expr)
if h is not None:
narrowable_operand_index_to_hash[i] = h

# Step 2: Group operands chained by either the 'is' or '==' operands
# together. For all other operands, we keep them in groups of size 2.
# So the expression:
#
# x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
#
# ...is converted into the simplified operator list:
#
# [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
# ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
#
# We group identity/equality expressions so we can propagate information
# we discover about one operand across the entire chain. We don't bother
# handling 'is not' and '!=' chains in a special way: those are very rare
# in practice.

simplified_operator_list = group_comparison_operands(
node.pairwise(), narrowable_operand_index_to_hash, {"==", "is"}
)

# Step 3: Analyze each group and infer more precise type maps for each
# assignable operand, if possible. We combine these type maps together
# in the final step.

partial_type_maps = []
for operator, expr_indices in simplified_operator_list:
if operator in {"is", "is not", "==", "!="}:
if_map, else_map = self.equality_type_narrowing_helper(
node,
operator,
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash,
)
elif operator in {"in", "not in"}:
assert len(expr_indices) == 2
left_index, right_index = expr_indices
item_type = operand_types[left_index]
iterable_type = operand_types[right_index]

if_map, else_map = {}, {}

if left_index in narrowable_operand_index_to_hash:
# We only try and narrow away 'None' for now
if is_overlapping_none(item_type):
collection_item_type = get_proper_type(
builtin_item_type(iterable_type)
)
if (
collection_item_type is not None
and not is_overlapping_none(collection_item_type)
and not (
isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == "builtins.object"
)
and is_overlapping_erased_types(item_type, collection_item_type)
):
if_map[operands[left_index]] = remove_optional(item_type)

if right_index in narrowable_operand_index_to_hash:
if_type, else_type = self.conditional_types_for_iterable(
item_type, iterable_type
)
expr = operands[right_index]
if if_type is None:
if_map = None
else:
if_map[expr] = if_type
if else_type is None:
else_map = None
else:
else_map[expr] = else_type

else:
if_map = {}
else_map = {}

if operator in {"is not", "!=", "not in"}:
if_map, else_map = else_map, if_map

partial_type_maps.append((if_map, else_map))

# If we have found non-trivial restrictions from the regular comparisons,
# then return soon. Otherwise try to infer restrictions involving `len(x)`.
# TODO: support regular and len() narrowing in the same chain.
if any(m != ({}, {}) for m in partial_type_maps):
return reduce_conditional_maps(partial_type_maps)
else:
# Use meet for `and` maps to get correct results for chained checks
# like `if 1 < len(x) < 4: ...`
return reduce_conditional_maps(self.find_tuple_len_narrowing(node), use_meet=True)
return self.comparison_type_narrowing_helper(node)
elif isinstance(node, AssignmentExpr):
if_map: dict[Expression, Type] | None
else_map: dict[Expression, Type] | None
if_map = {}
else_map = {}

Expand Down Expand Up @@ -6184,6 +6073,121 @@ def find_isinstance_check_helper(
else_map = {node: else_type} if not isinstance(else_type, UninhabitedType) else None
return if_map, else_map

def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMap, TypeMap]:
"""Infer type narrowing from a comparison expression."""
# Step 1: Obtain the types of each operand and whether or not we can
# narrow their types. (For example, we shouldn't try narrowing the
# types of literal string or enum expressions).

operands = [collapse_walrus(x) for x in node.operands]
operand_types = []
narrowable_operand_index_to_hash = {}
for i, expr in enumerate(operands):
if not self.has_type(expr):
return {}, {}
expr_type = self.lookup_type(expr)
operand_types.append(expr_type)

if (
literal(expr) == LITERAL_TYPE
and not is_literal_none(expr)
and not self.is_literal_enum(expr)
):
h = literal_hash(expr)
if h is not None:
narrowable_operand_index_to_hash[i] = h

# Step 2: Group operands chained by either the 'is' or '==' operands
# together. For all other operands, we keep them in groups of size 2.
# So the expression:
#
# x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
#
# ...is converted into the simplified operator list:
#
# [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
# ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
#
# We group identity/equality expressions so we can propagate information
# we discover about one operand across the entire chain. We don't bother
# handling 'is not' and '!=' chains in a special way: those are very rare
# in practice.

simplified_operator_list = group_comparison_operands(
node.pairwise(), narrowable_operand_index_to_hash, {"==", "is"}
)

# Step 3: Analyze each group and infer more precise type maps for each
# assignable operand, if possible. We combine these type maps together
# in the final step.

partial_type_maps = []
for operator, expr_indices in simplified_operator_list:
if operator in {"is", "is not", "==", "!="}:
if_map, else_map = self.equality_type_narrowing_helper(
node,
operator,
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash,
)
elif operator in {"in", "not in"}:
assert len(expr_indices) == 2
left_index, right_index = expr_indices
item_type = operand_types[left_index]
iterable_type = operand_types[right_index]

if_map, else_map = {}, {}

if left_index in narrowable_operand_index_to_hash:
# We only try and narrow away 'None' for now
if is_overlapping_none(item_type):
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
if (
collection_item_type is not None
and not is_overlapping_none(collection_item_type)
and not (
isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == "builtins.object"
)
and is_overlapping_erased_types(item_type, collection_item_type)
):
if_map[operands[left_index]] = remove_optional(item_type)

if right_index in narrowable_operand_index_to_hash:
if_type, else_type = self.conditional_types_for_iterable(
item_type, iterable_type
)
expr = operands[right_index]
if if_type is None:
if_map = None
else:
if_map[expr] = if_type
if else_type is None:
else_map = None
else:
else_map[expr] = else_type

else:
if_map = {}
else_map = {}

if operator in {"is not", "!=", "not in"}:
if_map, else_map = else_map, if_map

partial_type_maps.append((if_map, else_map))

# If we have found non-trivial restrictions from the regular comparisons,
# then return soon. Otherwise try to infer restrictions involving `len(x)`.
# TODO: support regular and len() narrowing in the same chain.
if any(m != ({}, {}) for m in partial_type_maps):
return reduce_conditional_maps(partial_type_maps)
else:
# Use meet for `and` maps to get correct results for chained checks
# like `if 1 < len(x) < 4: ...`
return reduce_conditional_maps(self.find_tuple_len_narrowing(node), use_meet=True)

def equality_type_narrowing_helper(
self,
node: ComparisonExpr,
Expand Down

0 comments on commit 9365fbf

Please sign in to comment.