Skip to content

Commit

Permalink
Make reachability code understand chained comparisons (v2) (python#8148)
Browse files Browse the repository at this point in the history
This pull request is v2 (well, more like v10...) of my attempts to
make our reachability code better understand chained comparisons.

Unlike python#7169, this diff focuses
exclusively on adding support for chained comparison operations and
deliberately does not attempt to change any of the semantics of
how identity and equality operations are performed.

Specifically, mypy currently only examines the first two operands
within a comparison expression when refining types. That means
none of the following expressions behave as expected:

```python
x: MyEnum
y: MyEnum
if x is y is MyEnum.A:
    # x and y are not narrowed at all

if x is MyEnum.A is y:
    # Only x is narrowed to Literal[MyEnum.A]
```

This pull request fixes this so we correctly infer the literal type
for x and y in both conditionals.

Some additional notes:

1. While analyzing our codebase, I found that while comparison
   expressions involving two or more `is` or `==` operators
   were somewhat common, there were almost no comparisons involving
   chains of `!=` or `is not` operators, and no comparisons
   involving "disjoint chains" -- e.g. expressions like
   `a == b < c == b` where there are multiple "disjoint"
   chains of equality comparisons.

   So, this diff is primarily designed to handle the case where
   a comparison expression has just one chain of `is` or `==`.
   For all other cases, I fall back to the more naive strategy
   of evaluating each comparison individually and and-ing the
   inferred types together without attempting to propagate
   any info.

2. I tested this code against one of our internal codebases. This
   ended up making mypy produce 3 or 4 new errors, but they all
   seemed legitimate, as far as I can tell.

3. I plan on submitting a follow-up diff that takes advantage of
   the work done in this diff to complete support for tagged unions
   using any Literal key, as previously promised.

   (I tried adding support for tagged unions in this diff, but
   attempting to simultaneously add support for chained comparisons
   while overhauling the semantics of `==` proved to be a little
   too overwhelming for me. So, baby steps.)
Michael0x2a authored Dec 25, 2019

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 5336f27 commit 9101707
Showing 5 changed files with 956 additions and 65 deletions.
603 changes: 541 additions & 62 deletions mypy/checker.py

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
@@ -1750,6 +1750,13 @@ def __init__(self, operators: List[str], operands: List[Expression]) -> None:
self.operands = operands
self.method_types = []

def pairwise(self) -> Iterator[Tuple[str, Expression, Expression]]:
    """Yield one (operator, left operand, right operand) triple per comparison.

    For the chained comparison "a < b is c == d" this produces, in order,
    ("<", a, b), ("is", b, c) and ("==", c, d).
    """
    # Operator number k (counting from 1) sits between operands k - 1 and k.
    for right_index, op in enumerate(self.operators, start=1):
        left = self.operands[right_index - 1]
        right = self.operands[right_index]
        yield op, left, right

def accept(self, visitor: ExpressionVisitor[T]) -> T:
    """Hand this node to the visitor's comparison-expression handler.

    Returns whatever visitor.visit_comparison_expr(self) returns.
    """
    outcome = visitor.visit_comparison_expr(self)
    return outcome

239 changes: 236 additions & 3 deletions mypy/test/testinfer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
"""Test cases for type inference helper functions."""

from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Dict, Set

from mypy.test.helpers import Suite, assert_equal
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED
from mypy.checker import group_comparison_operands, DisjointDict
from mypy.literals import Key
from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED, NameExpr
from mypy.types import AnyType, TupleType, Type, TypeOfAny
from mypy.test.typefixture import TypeFixture


class MapActualsToFormalsSuite(Suite):
"""Test cases for checkexpr.map_actuals_to_formals."""
"""Test cases for argmap.map_actuals_to_formals."""

def test_basic(self) -> None:
self.assert_map([], [], [])
@@ -223,3 +225,234 @@ def expand_callee_kinds(kinds_and_names: List[Union[int, Tuple[int, str]]]
kinds.append(v)
names.append(None)
return kinds, names


class OperandDisjointDictSuite(Suite):
    """Tests for checker.DisjointDict, the helper used for type inference with operands."""

    def new(self) -> DisjointDict[int, str]:
        # Every test starts from its own empty dictionary.
        return DisjointDict()

    def test_independent_maps(self) -> None:
        # No key set overlaps another, so each mapping stays its own entry.
        disjoint = self.new()
        for keys, values in [
            ({0, 1}, {"group1"}),
            ({2, 3, 4}, {"group2"}),
            ({5, 6, 7}, {"group3"}),
        ]:
            disjoint.add_mapping(keys, values)

        expected = [
            ({0, 1}, {"group1"}),
            ({2, 3, 4}, {"group2"}),
            ({5, 6, 7}, {"group3"}),
        ]
        self.assertEqual(disjoint.items(), expected)

    def test_partial_merging(self) -> None:
        # Overlapping key sets are expected to collapse into two separate entries.
        disjoint = self.new()
        for keys, values in [
            ({0, 1}, {"group1"}),
            ({1, 2}, {"group2"}),
            ({3, 4}, {"group3"}),
            ({5, 0}, {"group4"}),
            ({5, 6}, {"group5"}),
            ({4, 7}, {"group6"}),
        ]:
            disjoint.add_mapping(keys, values)

        expected = [
            ({0, 1, 2, 5, 6}, {"group1", "group2", "group4", "group5"}),
            ({3, 4, 7}, {"group3", "group6"}),
        ]
        self.assertEqual(disjoint.items(), expected)

    def test_full_merging(self) -> None:
        # The final mapping bridges the remaining groups, leaving a single entry.
        disjoint = self.new()
        for keys, values in [
            ({0, 1, 2}, {"a"}),
            ({3, 4, 2}, {"b"}),
            ({10, 11, 12}, {"c"}),
            ({13, 14, 15}, {"d"}),
            ({14, 10, 16}, {"e"}),
            ({0, 10}, {"f"}),
        ]:
            disjoint.add_mapping(keys, values)

        expected = [
            ({0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16}, {"a", "b", "c", "d", "e", "f"}),
        ]
        self.assertEqual(disjoint.items(), expected)

    def test_merge_with_multiple_overlaps(self) -> None:
        # A single new mapping may overlap several existing entries at once.
        disjoint = self.new()
        for keys, values in [
            ({0, 1, 2}, {"a"}),
            ({3, 4, 5}, {"b"}),
            ({1, 2, 4, 5}, {"c"}),
            ({6, 1, 2, 4, 5}, {"d"}),
            ({6, 1, 2, 4, 5}, {"e"}),
        ]:
            disjoint.add_mapping(keys, values)

        expected = [
            ({0, 1, 2, 3, 4, 5, 6}, {"a", "b", "c", "d", "e"}),
        ]
        self.assertEqual(disjoint.items(), expected)


class OperandComparisonGroupingSuite(Suite):
    """Tests for checker.group_comparison_operands."""

    def literal_keymap(self, assignable_operands: Dict[int, NameExpr]) -> Dict[int, Key]:
        # Map each operand index to a stand-in key built from the expression's name.
        return {
            index: ('FakeExpr', expr.name)
            for index, expr in assignable_operands.items()
        }

    def test_basic_cases(self) -> None:
        # Note: the grouping function never inspects the expressions it is handed,
        # so plain NameExprs serve as simple stand-ins.
        x0, x1, x2, x3, x4 = [NameExpr('x{}'.format(i)) for i in range(5)]

        basic_input = [('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('==', x3, x4)]

        ungrouped = [('==', [0, 1]), ('==', [1, 2]), ('<', [2, 3]), ('==', [3, 4])]
        equality_grouped = [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4])]

        # (operators to group by, expected grouping), checked in this order.
        cases = [
            (set(), ungrouped),
            ({'=='}, equality_grouped),
            ({'<'}, ungrouped),
            ({'==', '<'}, equality_grouped),
        ]  # type: List[Tuple[Set[str], List[Tuple[str, List[int]]]]]

        none_assignable = self.literal_keymap({})
        all_assignable = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4})

        # The expectations are the same whether or not the operands are assignable.
        for assignable in [none_assignable, all_assignable]:
            for operators_to_group, expected in cases:
                self.assertEqual(
                    group_comparison_operands(basic_input, assignable, operators_to_group),
                    expected,
                )

    def test_multiple_groups(self) -> None:
        x0, x1, x2, x3, x4, x5 = [NameExpr('x{}'.format(i)) for i in range(6)]

        # (pairwise comparisons, expected grouping) with nothing assignable.
        scenarios = [
            (
                [('==', x0, x1), ('==', x1, x2), ('is', x2, x3), ('is', x3, x4)],
                [('==', [0, 1, 2]), ('is', [2, 3, 4])],
            ),
            (
                [('==', x0, x1), ('==', x1, x2), ('==', x2, x3), ('==', x3, x4)],
                [('==', [0, 1, 2, 3, 4])],
            ),
            (
                [('is', x0, x1), ('==', x1, x2), ('==', x2, x3), ('==', x3, x4)],
                [('is', [0, 1]), ('==', [1, 2, 3, 4])],
            ),
            (
                [('is', x0, x1), ('is', x1, x2), ('<', x2, x3), ('==', x3, x4), ('==', x4, x5)],
                [('is', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4, 5])],
            ),
        ]
        for comparisons, expected in scenarios:
            self.assertEqual(
                group_comparison_operands(
                    comparisons,
                    self.literal_keymap({}),
                    {'==', 'is'},
                ),
                expected,
            )

    def test_multiple_groups_coalescing(self) -> None:
        x0, x1, x2, x3, x4 = [NameExpr('x{}'.format(i)) for i in range(5)]

        nothing_combined = [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4, 5])]
        everything_combined = [('==', [0, 1, 2, 3, 4, 5]), ('<', [2, 3])]

        # Note: the chain ends with 'x4 == x0', so x0 shows up again as operand 5.
        two_groups = [
            ('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('==', x3, x4), ('==', x4, x0),
        ]

        # (assignable operands, expected grouping, failure message)
        scenarios = [
            (
                {0: x0, 1: x1, 2: x2, 3: x3, 4: x4, 5: x0},
                everything_combined,
                "All vars are assignable, everything is combined",
            ),
            (
                {1: x1, 2: x2, 3: x3, 4: x4},
                nothing_combined,
                "x0 is unassignable, so no combining",
            ),
            (
                {0: x0, 1: x1, 3: x3, 5: x0},
                everything_combined,
                "Some vars are unassignable but x0 is, so we combine",
            ),
            (
                {0: x0, 5: x0},
                everything_combined,
                "All vars are unassignable but x0 is, so we combine",
            ),
        ]
        for assignable, expected, message in scenarios:
            self.assertEqual(
                group_comparison_operands(
                    two_groups,
                    self.literal_keymap(assignable),
                    {'=='},
                ),
                expected,
                message,
            )

    def test_multiple_groups_different_operators(self) -> None:
        x0, x1, x2, x3 = [NameExpr('x{}'.format(i)) for i in range(4)]

        # x0 closes the chain, so it is both operand 0 and operand 4.
        groups = [('==', x0, x1), ('==', x1, x2), ('is', x2, x3), ('is', x3, x0)]
        keymap = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x0})

        actual = group_comparison_operands(groups, keymap, {'==', 'is'})
        self.assertEqual(
            actual,
            [('==', [0, 1, 2]), ('is', [2, 3, 4])],
            "Different operators can never be combined"
        )

    def test_single_pair(self) -> None:
        x0 = NameExpr('x0')
        x1 = NameExpr('x1')

        single_comparison = [('==', x0, x1)]
        expected_output = [('==', [0, 1])]

        assignable_combinations = [
            {}, {0: x0}, {1: x1}, {0: x0, 1: x1},
        ]  # type: List[Dict[int, NameExpr]]
        to_group_by = [set(), {'=='}, {'is'}]  # type: List[Set[str]]

        # A lone comparison is reported as-is for every assignability/operator combination.
        for combo in assignable_combinations:
            for operators in to_group_by:
                self.assertEqual(
                    group_comparison_operands(
                        single_comparison,
                        self.literal_keymap(combo),
                        operators,
                    ),
                    expected_output,
                )

    def test_empty_pair_list(self) -> None:
        # This case should never occur in practice -- ComparisonExprs
        # always contain at least one comparison. But in case it does...
        operator_sets = [set(), {'=='}]  # type: List[Set[str]]
        for operators in operator_sets:
            self.assertEqual(group_comparison_operands([], {}, operators), [])
150 changes: 150 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
@@ -967,3 +967,153 @@ class A:
self.b = Enum("x", [("foo", "bar")]) # E: Enum type as attribute is not supported

reveal_type(A().b) # N: Revealed type is 'Any'

[case testEnumReachabilityWithChaining]
from enum import Enum

class Foo(Enum):
A = 1
B = 2

x: Foo
y: Foo

if x is y is Foo.A:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'

if x is Foo.A is y:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'

if Foo.A is x is y:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'

[builtins fixtures/primitives.pyi]

[case testEnumReachabilityWithChainingDisjoint]
# flags: --warn-unreachable
from enum import Enum

class Foo(Enum):
A = 1
B = 2

# Used to divide up a chained comparison into multiple identity groups
def __lt__(self, other: object) -> bool: return True

x: Foo
y: Foo

# No conflict
if x is Foo.A < y is Foo.B:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
else:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'

# The standard output when we end up inferring two disjoint facts about the same expr
if x is Foo.A and x is Foo.B:
reveal_type(x) # E: Statement is unreachable
else:
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'

# ..and we get the same result if we have two disjoint groups within the same comp expr
if x is Foo.A < x is Foo.B:
reveal_type(x) # E: Statement is unreachable
else:
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'
[builtins fixtures/primitives.pyi]

[case testEnumReachabilityWithChainingDirectConflict]
# flags: --warn-unreachable
from enum import Enum
from typing_extensions import Literal, Final

class Foo(Enum):
A = 1
B = 2
C = 3

x: Foo
if x is Foo.A is Foo.B:
reveal_type(x) # E: Statement is unreachable
else:
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'

literal_a: Literal[Foo.A]
literal_b: Literal[Foo.B]
if x is literal_a is literal_b:
reveal_type(x) # E: Statement is unreachable
else:
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'

final_a: Final = Foo.A
final_b: Final = Foo.B
if x is final_a is final_b:
reveal_type(x) # E: Statement is unreachable
else:
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'

[builtins fixtures/primitives.pyi]

[case testEnumReachabilityWithChainingBigDisjoints]
# flags: --warn-unreachable
from enum import Enum
from typing_extensions import Literal, Final

class Foo(Enum):
A = 1
B = 2
C = 3

def __lt__(self, other: object) -> bool: return True

x0: Foo
x1: Foo
x2: Foo
x3: Foo
x4: Foo
x5: Foo

if x0 is x1 is Foo.A is x2 < x3 is Foo.B is x4 is x5:
reveal_type(x0) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(x2) # N: Revealed type is 'Literal[__main__.Foo.A]'

reveal_type(x3) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(x4) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(x5) # N: Revealed type is 'Literal[__main__.Foo.B]'
else:
reveal_type(x0) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(x2) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'

reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]'
reveal_type(x4) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]'
reveal_type(x5) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]'
[builtins fixtures/primitives.pyi]
22 changes: 22 additions & 0 deletions test-data/unit/check-optional.test
Original file line number Diff line number Diff line change
@@ -528,6 +528,28 @@ else:
reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]'
[builtins fixtures/ops.pyi]

[case testInferEqualsNotOptionalWithMultipleArgs]
from typing import Optional
x: Optional[int]
y: Optional[int]
if x == y == 1:
reveal_type(x) # N: Revealed type is 'builtins.int'
reveal_type(y) # N: Revealed type is 'builtins.int'
else:
reveal_type(x) # N: Revealed type is 'Union[builtins.int, None]'
reveal_type(y) # N: Revealed type is 'Union[builtins.int, None]'

class A: pass
a: Optional[A]
b: Optional[A]
if a == b == object():
reveal_type(a) # N: Revealed type is '__main__.A'
reveal_type(b) # N: Revealed type is '__main__.A'
else:
reveal_type(a) # N: Revealed type is 'Union[__main__.A, None]'
reveal_type(b) # N: Revealed type is 'Union[__main__.A, None]'
[builtins fixtures/ops.pyi]

[case testWarnNoReturnWorksWithStrictOptional]
# flags: --warn-no-return
def f() -> None:

0 comments on commit 9101707

Please sign in to comment.