Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

'in' can narrow TypedDict unions #13838

Merged
merged 23 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ca87360
'in' can narrow TypedDict unions
ikonst Oct 8, 2022
2a73325
take care of TypeVar
ikonst Oct 15, 2022
9671a69
take care of totality
ikonst Oct 15, 2022
0b3701b
add missing assert_type fixture
ikonst Oct 15, 2022
1f4224a
rename tests slightly
ikonst Oct 15, 2022
e67c6f8
fix unintentional changes to typing-typeddict
ikonst Oct 15, 2022
88e2c9f
add testOperatorContainsNarrowsTypedDicts_unionWithList
ikonst Oct 15, 2022
f799344
respect final-ity of TypedDict
ikonst Oct 15, 2022
20a2e66
remove bogus addition
ikonst Oct 16, 2022
b4dc248
Merge branch 'master' into typed-dict-in-type-narrowing
ikonst Oct 17, 2022
88f03f6
use less risky pattern
ikonst Oct 19, 2022
123656c
Merge branch 'typed-dict-in-type-narrowing' of https://github.com/iko…
ikonst Oct 19, 2022
a47bc04
'D1 | None' -> 'D1 | list [str]'
ikonst Oct 19, 2022
f72a634
remove unused stuff in tests
ikonst Oct 19, 2022
bf4364f
test totality through both total= and (Not)Required
ikonst Oct 19, 2022
062fcb1
add spam in d_final test
ikonst Oct 19, 2022
a57f580
Merge branch 'master' into typed-dict-in-type-narrowing
ikonst Oct 24, 2022
2807feb
Merge branch 'master' into typed-dict-in-type-narrowing
ikonst Nov 2, 2022
520df60
update tests per hauntsaninja's code review
ikonst Nov 3, 2022
620da98
Merge branch 'typed-dict-in-type-narrowing' of https://github.com/iko…
ikonst Nov 3, 2022
fb564eb
clarify we don't narrow the left operand
ikonst Nov 4, 2022
178483e
no more recursion + similar naming to conditional_types...
ikonst Nov 4, 2022
26c4d04
Merge remote-tracking branch 'origin/master' into typed-dict-in-type-…
ikonst Nov 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 72 additions & 19 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5017,6 +5017,45 @@ def conditional_callable_type_map(

return None, {}

def conditional_types_for_iterable(
self, item_type: Type, iterable_type: Type
) -> tuple[Type | None, Type | None]:
"""
Narrows the type of `iterable_type` based on the type of `item_type`.
For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s).
"""
if_types: list[Type] = []
else_types: list[Type] = []

iterable_type = get_proper_type(iterable_type)
if isinstance(iterable_type, UnionType):
possible_iterable_types = get_proper_types(iterable_type.relevant_items())
else:
possible_iterable_types = [iterable_type]

item_str_literals = try_getting_str_literals_from_type(item_type)

for possible_iterable_type in possible_iterable_types:
if item_str_literals and isinstance(possible_iterable_type, TypedDictType):
for key in item_str_literals:
if key in possible_iterable_type.required_keys:
if_types.append(possible_iterable_type)
elif (
key in possible_iterable_type.items or not possible_iterable_type.is_final
):
if_types.append(possible_iterable_type)
else_types.append(possible_iterable_type)
else:
else_types.append(possible_iterable_type)
else:
if_types.append(possible_iterable_type)
else_types.append(possible_iterable_type)

return (
UnionType.make_union(if_types) if if_types else None,
UnionType.make_union(else_types) if else_types else None,
)

def _is_truthy_type(self, t: ProperType) -> bool:
return (
(
Expand Down Expand Up @@ -5324,28 +5363,42 @@ def has_no_custom_eq_checks(t: Type) -> bool:
elif operator in {"in", "not in"}:
assert len(expr_indices) == 2
left_index, right_index = expr_indices
if left_index not in narrowable_operand_index_to_hash:
continue

item_type = operand_types[left_index]
collection_type = operand_types[right_index]
iterable_type = operand_types[right_index]

# We only try and narrow away 'None' for now
if not is_optional(item_type):
continue
if_map, else_map = {}, {}

if left_index in narrowable_operand_index_to_hash:
# We only try and narrow away 'None' for now
if is_optional(item_type):
collection_item_type = get_proper_type(
builtin_item_type(iterable_type)
)
if (
collection_item_type is not None
and not is_optional(collection_item_type)
and not (
isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == "builtins.object"
)
and is_overlapping_erased_types(item_type, collection_item_type)
):
if_map[operands[left_index]] = remove_optional(item_type)

if right_index in narrowable_operand_index_to_hash:
if_type, else_type = self.conditional_types_for_iterable(
item_type, iterable_type
)
expr = operands[right_index]
if if_type is None:
if_map = None
else:
if_map[expr] = if_type
if else_type is None:
else_map = None
else:
else_map[expr] = else_type

collection_item_type = get_proper_type(builtin_item_type(collection_type))
if collection_item_type is None or is_optional(collection_item_type):
continue
if (
isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == "builtins.object"
):
continue
if is_overlapping_erased_types(item_type, collection_item_type):
if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
else:
continue
else:
if_map = {}
else_map = {}
Expand Down
4 changes: 4 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2290,6 +2290,10 @@ def deserialize(cls, data: JsonDict) -> TypedDictType:
Instance.deserialize(data["fallback"]),
)

@property
def is_final(self) -> bool:
return self.fallback.type.is_final

def is_anonymous(self) -> bool:
return self.fallback.type.fullname in TPDICT_FB_NAMES

Expand Down
185 changes: 185 additions & 0 deletions test-data/unit/check-typeddict.test
Original file line number Diff line number Diff line change
Expand Up @@ -2025,6 +2025,191 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value"
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testOperatorContainsNarrowsTypedDicts_unionWithList]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test case seems completely duplicated in testOperatorContainsNarrowsTypedDicts_total, let's remove this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll remove the relevant part from testOperatorContainsNarrowsTypedDicts_total, as to keep test names descriptive.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from __future__ import annotations
from typing import assert_type, TypedDict, Union
from typing_extensions import final

@final
class D(TypedDict):
foo: int


d_or_list: D | list[str]

if 'foo' in d_or_list:
assert_type(d_or_list, Union[D, list[str]])
elif 'bar' in d_or_list:
assert_type(d_or_list, list[str])
else:
assert_type(d_or_list, list[str])

[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testOperatorContainsNarrowsTypedDicts_total]
from __future__ import annotations
from typing import assert_type, Literal, TypedDict, TypeVar, Union
from typing_extensions import final

@final
class D1(TypedDict):
foo: int


@final
class D2(TypedDict):
bar: int


d: D1 | D2

if 'foo' in d:
assert_type(d, D1)
else:
assert_type(d, D2)

foo_or_bar: Literal['foo', 'bar']
if foo_or_bar in d:
assert_type(d, Union[D1, D2])
else:
assert_type(d, Union[D1, D2])

foo_or_invalid: Literal['foo', 'invalid']
if foo_or_invalid in d:
assert_type(d, D1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in theory this could narrow foo_or_invalid as well, want to add an assert_type for that behaviour too?

Copy link
Contributor Author

@ikonst ikonst Nov 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That'd have to be implemented. And I think it could be pretty neat, but would require a rework: I'd have to pass in the left expression, and return type maps (not "if_type" and "else_type"). Maybe in a follow-up?

Copy link
Contributor Author

@ikonst ikonst Nov 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My current implementation is to make a form of "tagged TypeDicts" work, but I suspect a more generalized form of type narrowing should be possible. However, before we tackle the more contrived x in y case, we should probably do it for x == y, so that this example would see narrowing applied.

(I might be naive, though, and this might've been tried before and proven impossible.)

P.S. PyRight is similarly limited in this

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, to be clear my suggestion wasn't to implement this / I don't think it's a terribly important feature. I was just saying it's worth adding an assert_type in this test to make this (lack of) behaviour clear to the reader.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, done: fb564eb

# won't narrow 'foo_or_invalid'
assert_type(foo_or_invalid, Literal['foo', 'invalid'])
else:
assert_type(d, Union[D1, D2])
# won't narrow 'foo_or_invalid'
assert_type(foo_or_invalid, Literal['foo', 'invalid'])

TD = TypeVar('TD', D1, D2)

def f(arg: TD) -> None:
value: int
if 'foo' in arg:
assert_type(arg['foo'], int)
else:
assert_type(arg['bar'], int)


[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testOperatorContainsNarrowsTypedDicts_final]
# flags: --warn-unreachable
from __future__ import annotations
from typing import assert_type, TypedDict, Union
from typing_extensions import final

@final
class DFinal(TypedDict):
foo: int


class DNotFinal(TypedDict):
bar: int


d_not_final: DNotFinal

if 'bar' in d_not_final:
assert_type(d_not_final, DNotFinal)
else:
spam = 'ham' # E: Statement is unreachable

if 'spam' in d_not_final:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will happen for if 'spam' in d_final? I assume it should be marked as unreachable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test, thanks. (Yes, it does that.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert_type(d_not_final, DNotFinal)
else:
assert_type(d_not_final, DNotFinal)

d_final: DFinal

if 'spam' in d_final:
spam = 'ham' # E: Statement is unreachable
else:
assert_type(d_final, DFinal)

d_union: DFinal | DNotFinal

if 'foo' in d_union:
assert_type(d_union, Union[DFinal, DNotFinal])
else:
assert_type(d_union, DNotFinal)

[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testOperatorContainsNarrowsTypedDicts_partialThroughTotalFalse]
from __future__ import annotations
from typing import assert_type, Literal, TypedDict, Union
from typing_extensions import final

@final
class DTotal(TypedDict):
required_key: int


@final
class DNotTotal(TypedDict, total=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also test NotRequired and Required types

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do it. I assumed another test assured it's equivalent, but I can explicitly test it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional_key: int


d: DTotal | DNotTotal

if 'required_key' in d:
assert_type(d, DTotal)
else:
assert_type(d, DNotTotal)

if 'optional_key' in d:
assert_type(d, DNotTotal)
else:
assert_type(d, Union[DTotal, DNotTotal])

key: Literal['optional_key', 'required_key']
if key in d:
assert_type(d, Union[DTotal, DNotTotal])
else:
assert_type(d, Union[DTotal, DNotTotal])

[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testOperatorContainsNarrowsTypedDicts_partialThroughNotRequired]
from __future__ import annotations
from typing import assert_type, Required, NotRequired, TypedDict, Union
from typing_extensions import final

@final
class D1(TypedDict):
required_key: Required[int]
optional_key: NotRequired[int]


@final
class D2(TypedDict):
abc: int
xyz: int


d: D1 | D2

if 'required_key' in d:
assert_type(d, D1)
else:
assert_type(d, D2)

if 'optional_key' in d:
assert_type(d, D1)
else:
assert_type(d, Union[D1, D2])

[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testCannotSubclassFinalTypedDict]
from typing import TypedDict
from typing_extensions import final
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/typing-typeddict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from abc import ABCMeta

cast = 0
assert_type = 0
overload = 0
Any = 0
Union = 0
Expand Down