Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for narrowing Literals using equality #8151

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 122 additions & 92 deletions mypy/checker.py

Large diffs are not rendered by default.

47 changes: 3 additions & 44 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode,
ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE,
SYMBOL_FUNCBASE_TYPES
)
from mypy.literals import literal
from mypy import nodes
Expand All @@ -51,15 +50,16 @@
from mypy import erasetype
from mypy.checkmember import analyze_member_access, type_object_type
from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals
from mypy.checkstrformat import StringFormatterChecker, custom_special_method
from mypy.checkstrformat import StringFormatterChecker
from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars
from mypy.util import split_module_names
from mypy.typevars import fill_typevars
from mypy.visitor import ExpressionVisitor
from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext
from mypy.typeops import (
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
function_type, callable_type, try_getting_str_literals
function_type, callable_type, try_getting_str_literals, custom_special_method,
is_literal_type_like,
)
import mypy.errorcodes as codes

Expand Down Expand Up @@ -4266,24 +4266,6 @@ def merge_typevars_in_callables_by_name(
return output, variables


def is_literal_type_like(t: Optional[Type]) -> bool:
"""Returns 'true' if the given type context is potentially either a LiteralType,
a Union of LiteralType, or something similar.
"""
t = get_proper_type(t)
if t is None:
return False
elif isinstance(t, LiteralType):
return True
elif isinstance(t, UnionType):
return any(is_literal_type_like(item) for item in t.items)
elif isinstance(t, TypeVarType):
return (is_literal_type_like(t.upper_bound)
or any(is_literal_type_like(item) for item in t.values))
else:
return False


def try_getting_literal(typ: Type) -> ProperType:
"""If possible, get a more precise literal type for a given type."""
typ = get_proper_type(typ)
Expand All @@ -4305,29 +4287,6 @@ def is_expr_literal_type(node: Expression) -> bool:
return False


def custom_equality_method(typ: Type) -> bool:
"""Does this type have a custom __eq__() method?"""
typ = get_proper_type(typ)
if isinstance(typ, Instance):
method = typ.type.get('__eq__')
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
if method.node.info:
return not method.node.info.fullname.startswith('builtins.')
return False
if isinstance(typ, UnionType):
return any(custom_equality_method(t) for t in typ.items)
if isinstance(typ, TupleType):
return custom_equality_method(tuple_fallback(typ))
if isinstance(typ, CallableType) and typ.is_type_obj():
# Look up __eq__ on the metaclass for class objects.
return custom_equality_method(typ.fallback)
if isinstance(typ, AnyType):
# Avoid false positives in uncertain cases.
return True
# TODO: support other types (see ExpressionChecker.has_member())?
return False


def has_bytes_component(typ: Type, py2: bool = False) -> bool:
"""Is this one of builtin byte types, or a union that contains it?"""
typ = get_proper_type(typ)
Expand Down
35 changes: 3 additions & 32 deletions mypy/checkstrformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

from mypy.types import (
Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType,
CallableType, LiteralType, get_proper_types
LiteralType, get_proper_types
)
from mypy.nodes import (
StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr, CallExpr,
IndexExpr, MemberExpr, TempNode, ARG_POS, ARG_STAR, ARG_NAMED, ARG_STAR2,
SYMBOL_FUNCBASE_TYPES, Decorator, Var, Node, MypyFile, ExpressionStmt, NameExpr, IntExpr
Node, MypyFile, ExpressionStmt, NameExpr, IntExpr
)
import mypy.errorcodes as codes

Expand All @@ -35,7 +35,7 @@
from mypy import message_registry
from mypy.messages import MessageBuilder
from mypy.maptype import map_instance_to_supertype
from mypy.typeops import tuple_fallback
from mypy.typeops import custom_special_method
from mypy.subtypes import is_subtype
from mypy.parse import parse

Expand Down Expand Up @@ -961,32 +961,3 @@ def has_type_component(typ: Type, fullname: str) -> bool:
elif isinstance(typ, UnionType):
return any(has_type_component(t, fullname) for t in typ.relevant_items())
return False


def custom_special_method(typ: Type, name: str,
check_all: bool = False) -> bool:
"""Does this type have a custom special method such as __format__() or __eq__()?

If check_all is True ensure all items of a union have a custom method, not just some.
"""
typ = get_proper_type(typ)
if isinstance(typ, Instance):
method = typ.type.get(name)
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
if method.node.info:
return not method.node.info.fullname.startswith('builtins.')
return False
if isinstance(typ, UnionType):
if check_all:
return all(custom_special_method(t, name, check_all) for t in typ.items)
return any(custom_special_method(t, name) for t in typ.items)
if isinstance(typ, TupleType):
return custom_special_method(tuple_fallback(typ), name)
if isinstance(typ, CallableType) and typ.is_type_obj():
# Look up __method__ on the metaclass for class objects.
return custom_special_method(typ.fallback, name)
if isinstance(typ, AnyType):
# Avoid false positives in uncertain cases.
return True
# TODO: support other types (see ExpressionChecker.has_member())?
return False
53 changes: 50 additions & 3 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from mypy.nodes import (
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS,
Expression, StrExpr, Var
Expression, StrExpr, Var, Decorator, SYMBOL_FUNCBASE_TYPES
)
from mypy.maptype import map_instance_to_supertype
from mypy.expandtype import expand_type_by_instance, expand_type
Expand Down Expand Up @@ -564,6 +564,24 @@ def try_getting_literals_from_type(typ: Type,
return literals


def is_literal_type_like(t: Optional[Type]) -> bool:
"""Returns 'true' if the given type context is potentially either a LiteralType,
a Union of LiteralType, or something similar.
"""
t = get_proper_type(t)
if t is None:
return False
elif isinstance(t, LiteralType):
return True
elif isinstance(t, UnionType):
return any(is_literal_type_like(item) for item in t.items)
elif isinstance(t, TypeVarType):
return (is_literal_type_like(t.upper_bound)
or any(is_literal_type_like(item) for item in t.values))
else:
return False


def get_enum_values(typ: Instance) -> List[str]:
"""Return the list of values for an Enum."""
return [name for name, sym in typ.type.names.items() if isinstance(sym.node, Var)]
Expand Down Expand Up @@ -640,10 +658,11 @@ class Status(Enum):
return typ


def coerce_to_literal(typ: Type) -> ProperType:
def coerce_to_literal(typ: Type) -> Type:
"""Recursively converts any Instances that have a last_known_value or are
instances of enum types with a single value into the corresponding LiteralType.
"""
original_type = typ
typ = get_proper_type(typ)
if isinstance(typ, UnionType):
new_items = [coerce_to_literal(item) for item in typ.items]
Expand All @@ -655,7 +674,7 @@ def coerce_to_literal(typ: Type) -> ProperType:
enum_values = get_enum_values(typ)
if len(enum_values) == 1:
return LiteralType(value=enum_values[0], fallback=typ)
return typ
return original_type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is right, thanks!



def get_type_vars(tp: Type) -> List[TypeVarType]:
Expand All @@ -674,3 +693,31 @@ def _merge(self, iter: Iterable[List[TypeVarType]]) -> List[TypeVarType]:

def visit_type_var(self, t: TypeVarType) -> List[TypeVarType]:
return [t]


def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool:
"""Does this type have a custom special method such as __format__() or __eq__()?

If check_all is True ensure all items of a union have a custom method, not just some.
"""
typ = get_proper_type(typ)
if isinstance(typ, Instance):
method = typ.type.get(name)
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
if method.node.info:
return not method.node.info.fullname.startswith('builtins.')
return False
if isinstance(typ, UnionType):
if check_all:
return all(custom_special_method(t, name, check_all) for t in typ.items)
return any(custom_special_method(t, name) for t in typ.items)
if isinstance(typ, TupleType):
return custom_special_method(tuple_fallback(typ), name, check_all)
if isinstance(typ, CallableType) and typ.is_type_obj():
# Look up __method__ on the metaclass for class objects.
return custom_special_method(typ.fallback, name, check_all)
if isinstance(typ, AnyType):
# Avoid false positives in uncertain cases.
return True
# TODO: support other types (see ExpressionChecker.has_member())?
return False
69 changes: 42 additions & 27 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -978,32 +978,43 @@ class Foo(Enum):
x: Foo
y: Foo

# We can't narrow anything in the else cases -- what if
# x is Foo.A and y is Foo.B or vice versa, for example?
if x is y is Foo.A:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif x is y is Foo.B:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
else:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'

if x is Foo.A is y:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif x is Foo.B is y:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
else:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'

if Foo.A is x is y:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif Foo.B is x is y:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
else:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'

[builtins fixtures/primitives.pyi]

Expand All @@ -1026,8 +1037,10 @@ if x is Foo.A < y is Foo.B:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
else:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
# Note: we can't narrow in this case. What if both x and y
# are Foo.A, for example?
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'

Expand Down Expand Up @@ -1109,11 +1122,13 @@ if x0 is x1 is Foo.A is x2 < x3 is Foo.B is x4 is x5:
reveal_type(x4) # N: Revealed type is 'Literal[__main__.Foo.B]'
reveal_type(x5) # N: Revealed type is 'Literal[__main__.Foo.B]'
else:
reveal_type(x0) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(x2) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'

reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]'
reveal_type(x4) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]'
reveal_type(x5) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]'
# We unfortunately can't narrow away anything. For example,
# what if x0 == Foo.A and x1 == Foo.B or vice versa?
reveal_type(x0) # N: Revealed type is '__main__.Foo'
reveal_type(x1) # N: Revealed type is '__main__.Foo'
reveal_type(x2) # N: Revealed type is '__main__.Foo'

reveal_type(x3) # N: Revealed type is '__main__.Foo'
reveal_type(x4) # N: Revealed type is '__main__.Foo'
reveal_type(x5) # N: Revealed type is '__main__.Foo'
[builtins fixtures/primitives.pyi]
Loading