Skip to content

Commit

Permalink
Support functional API for Enum.
Browse files Browse the repository at this point in the history
Fixes #2306.
  • Loading branch information
Guido van Rossum committed Feb 3, 2017
1 parent ed7d0c0 commit d6c543b
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 5 deletions.
5 changes: 4 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
YieldFromExpr, NamedTupleExpr, TypedDictExpr, SetComprehension,
DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr,
RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase,
AwaitExpr, PromoteExpr, Node,
AwaitExpr, PromoteExpr, Node, EnumCallExpr,
ARG_POS, MDEF,
CONTRAVARIANT, COVARIANT)
from mypy import nodes
Expand Down Expand Up @@ -2244,6 +2244,9 @@ def visit_newtype_expr(self, e: NewTypeExpr) -> Type:
def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type:
return self.expr_checker.visit_namedtuple_expr(e)

def visit_enum_call_expr(self, e: EnumCallExpr) -> Type:
return self.expr_checker.visit_enum_call_expr(e)

def visit_typeddict_expr(self, e: TypedDictExpr) -> Type:
return self.expr_checker.visit_typeddict_expr(e)

Expand Down
26 changes: 24 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr,
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
TypeAliasExpr, BackquoteExpr, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF,
UNBOUND_TVAR, BOUND_TVAR,
TypeAliasExpr, BackquoteExpr, EnumCallExpr,
ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, UNBOUND_TVAR, BOUND_TVAR,
)
from mypy import nodes
import mypy.checker
Expand Down Expand Up @@ -343,6 +343,12 @@ def check_call(self, callee: Type, args: List[Expression],
"""
arg_messages = arg_messages or self.msg
if isinstance(callee, CallableType):
if (isinstance(callable_node, RefExpr)
and callable_node.fullname in ('enum.Enum', 'enum.IntEnum',
'enum.Flag', 'enum.IntFlag')):
# An Enum() call that failed SemanticAnalyzer.check_enum_call().
return callee.ret_type, callee

if callee.is_concrete_type_obj() and callee.type_object().is_abstract:
type = callee.type_object()
self.msg.cannot_instantiate_abstract_class(
Expand Down Expand Up @@ -2156,6 +2162,22 @@ def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type:
# TODO: Perhaps return a type object type?
return AnyType()

def visit_enum_call_expr(self, e: EnumCallExpr) -> Type:
for name, value in zip(e.items, e.values):
if value is not None:
typ = self.accept(value)
if not isinstance(typ, AnyType):
var = e.info.names[name].node
if isinstance(var, Var):
# Inline TypeCheker.set_inferred_type(),
# without the lvalue. (This doesn't really do
# much, since the value attribute is defined
# to have type Any in the typeshed stub.)
var.type = typ
var.is_inferred = True
# TODO: Perhaps return a type object type?
return AnyType()

def visit_typeddict_expr(self, e: TypedDictExpr) -> Type:
# TODO: Perhaps return a type object type?
return AnyType()
Expand Down
19 changes: 19 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,6 +1792,25 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_typeddict_expr(self)


class EnumCallExpr(Expression):
"""Named tuple expression Enum('name', 'val1 val2 ...')."""

# The class representation of this enumerated type
info = None # type: TypeInfo
# The item names (for debugging)
items = None # type: List[str]
values = None # type: List[Optional[Expression]]

def __init__(self, info: 'TypeInfo', items: List[str],
values: List[Optional[Expression]]) -> None:
self.info = info
self.items = items
self.values = values

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_enum_call_expr(self)


class PromoteExpr(Expression):
"""Ducktype class decorator expression _promote(...)."""

Expand Down
136 changes: 135 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SymbolNode,
SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr,
YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr,
IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode,
IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr,
COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES,
)
from mypy.typevars import has_no_typevars, fill_typevars
Expand Down Expand Up @@ -1195,6 +1195,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
self.process_typevar_declaration(s)
self.process_namedtuple_definition(s)
self.process_typeddict_definition(s)
self.process_enum_call(s)

if (len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and
s.lvalues[0].name == '__all__' and s.lvalues[0].kind == GDEF and
Expand Down Expand Up @@ -1991,6 +1992,139 @@ def build_typeddict_typeinfo(self, name: str, items: List[str],

return info

def process_enum_call(self, s: AssignmentStmt) -> None:
"""Check if s defines an Enum; if yes, store the definition in symbol table."""
if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr):
return
lvalue = s.lvalues[0]
name = lvalue.name
enum_call = self.check_enum_call(s.rvalue, name)
if enum_call is None:
return
# Yes, it's a valid Enum definition. Add it to the symbol table.
node = self.lookup(name, s)
if node:
node.kind = GDEF # TODO locally defined Enum
node.node = enum_call

def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]:
"""Check if a call defines an Enum.
Example:
A = enum.Enum('A', 'foo bar')
is equivalent to:
class A(enum.Enum):
foo = 1
bar = 2
"""
if not isinstance(node, CallExpr):
return None
call = node
callee = call.callee
if not isinstance(callee, RefExpr):
return None
fullname = callee.fullname
if fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'):
return None
items, values, ok = self.parse_enum_call_args(call, fullname.split('.')[-1])
if not ok:
# Error. Construct dummy return value.
return self.build_enum_call_typeinfo('Enum', [], fullname)
name = cast(StrExpr, call.args[0]).value
if name != var_name or self.is_func_scope():
# Give it a unique name derived from the line number.
name += '@' + str(call.line)
info = self.build_enum_call_typeinfo(name, items, fullname)
# Store it as a global just in case it would remain anonymous.
# (Or in the nearest class if there is one.)
stnode = SymbolTableNode(GDEF, info, self.cur_mod_id)
if self.type:
self.type.names[name] = stnode
else:
self.globals[name] = stnode
call.analyzed = EnumCallExpr(info, items, values)
call.analyzed.set_line(call.line, call.column)
return info

def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) -> TypeInfo:
base = self.named_type_or_none(fullname)
assert base is not None
info = self.basic_new_typeinfo(name, base)
info.is_enum = True
for item in items:
var = Var(item)
var.info = info
var.is_property = True
info.names[item] = SymbolTableNode(MDEF, var)
return info

def parse_enum_call_args(self, call: CallExpr,
class_name: str) -> Tuple[List[str],
List[Optional[Expression]], bool]:
args = call.args
if len(args) < 2:
return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call)
if len(args) > 2:
return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call)
if call.arg_kinds != [ARG_POS, ARG_POS]:
return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call)
if not isinstance(args[0], (StrExpr, UnicodeExpr)):
return self.fail_enum_call_arg(
"%s() expects a string literal as the first argument" % class_name, call)
items = []
values = [] # type: List[Optional[Expression]]
if isinstance(args[1], (StrExpr, UnicodeExpr)):
fields = args[1].value
for field in fields.replace(',', ' ').split():
items.append(field)
elif isinstance(args[1], (TupleExpr, ListExpr)):
seq_items = args[1].items
if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items):
items = [cast(StrExpr, seq_item).value for seq_item in seq_items]
elif all(isinstance(seq_item, (TupleExpr, ListExpr))
and len(seq_item.items) == 2
and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr))
for seq_item in seq_items):
for seq_item in seq_items:
assert isinstance(seq_item, (TupleExpr, ListExpr))
name, value = seq_item.items
assert isinstance(name, (StrExpr, UnicodeExpr))
items.append(name.value)
values.append(value)
else:
return self.fail_enum_call_arg(
"%s() with tuple or list of (name, value) pairs not yet supported" %
class_name,
call)
elif isinstance(args[1], DictExpr):
for key, value in args[1].items:
if not isinstance(key, (StrExpr, UnicodeExpr)):
return self.fail_enum_call_arg(
"%s() with dict literal requires string literals" % class_name, call)
items.append(key.value)
values.append(value)
else:
# TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}?
return self.fail_enum_call_arg(
"%s() expects a string, tuple, list or dict literal as the second argument" %
class_name,
call)
if len(items) == 0:
return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call)
if not values:
values = [None] * len(items)
assert len(items) == len(values)
return items, values, True

def fail_enum_call_arg(self, message: str,
context: Context) -> Tuple[List[str],
List[Optional[Expression]], bool]:
self.fail(message, context)
return [], [], False

def visit_decorator(self, dec: Decorator) -> None:
for d in dec.decorators:
d.accept(self)
Expand Down
3 changes: 3 additions & 0 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,9 @@ def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> str:
o.info.name(),
o.info.tuple_type)

def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> str:
return 'EnumCallExpr:{}({}, {})'.format(o.line, o.info.name(), o.items)

def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> str:
return 'TypedDictExpr:{}({})'.format(o.line,
o.info.name())
Expand Down
5 changes: 4 additions & 1 deletion mypy/treetransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ComparisonExpr, TempNode, StarExpr, Statement, Expression,
YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SetComprehension,
DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr,
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr,
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, EnumCallExpr,
)
from mypy.types import Type, FunctionLike
from mypy.traverser import TraverserVisitor
Expand Down Expand Up @@ -483,6 +483,9 @@ def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr:
def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr:
return NamedTupleExpr(node.info)

def visit_enum_call_expr(self, node: EnumCallExpr) -> EnumCallExpr:
return EnumCallExpr(node.info, node.items, node.values)

def visit_typeddict_expr(self, node: TypedDictExpr) -> Node:
return TypedDictExpr(node.info)

Expand Down
7 changes: 7 additions & 0 deletions mypy/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T:
def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
pass

@abstractmethod
def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T:
pass

@abstractmethod
def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T:
pass
Expand Down Expand Up @@ -392,6 +396,9 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T:
def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
pass

def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T:
pass

def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T:
pass

Expand Down
86 changes: 86 additions & 0 deletions test-data/unit/pythoneval-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,89 @@ class E(N, Enum):
def f(x: E) -> None: pass

f(E.X)

[case testFunctionalEnumString]
from enum import Enum, IntEnum
E = Enum('E', 'foo bar')
I = IntEnum('I', ' bar, baz ')
reveal_type(E.foo)
reveal_type(E.bar.value)
reveal_type(I.bar)
reveal_type(I.baz.value)
[out]
_program.py:4: error: Revealed type is '_testFunctionalEnumString.E'
_program.py:5: error: Revealed type is 'Any'
_program.py:6: error: Revealed type is '_testFunctionalEnumString.I'
_program.py:7: error: Revealed type is 'builtins.int'

[case testFunctionalEnumListOfStrings]
from enum import Enum, IntEnum
E = Enum('E', ('foo', 'bar'))
F = Enum('F', ['bar', 'baz'])
reveal_type(E.foo)
reveal_type(F.baz)
[out]
_program.py:4: error: Revealed type is '_testFunctionalEnumListOfStrings.E'
_program.py:5: error: Revealed type is '_testFunctionalEnumListOfStrings.F'

[case testFunctionalEnumListOfPairs]
from enum import Enum, IntEnum
E = Enum('E', [('foo', 1), ['bar', 2]])
F = Enum('F', (['bar', 1], ('baz', 2)))
reveal_type(E.foo)
reveal_type(F.baz)
reveal_type(E.foo.value)
reveal_type(F.bar.name)
[out]
_program.py:4: error: Revealed type is '_testFunctionalEnumListOfPairs.E'
_program.py:5: error: Revealed type is '_testFunctionalEnumListOfPairs.F'
_program.py:6: error: Revealed type is 'Any'
_program.py:7: error: Revealed type is 'builtins.str'

[case testFunctionalEnumDict]
from enum import Enum, IntEnum
E = Enum('E', {'foo': 1, 'bar': 2})
F = IntEnum('F', {'bar': 1, 'baz': 2})
reveal_type(E.foo)
reveal_type(F.baz)
reveal_type(E.foo.value)
reveal_type(F.bar.name)
[out]
_program.py:4: error: Revealed type is '_testFunctionalEnumDict.E'
_program.py:5: error: Revealed type is '_testFunctionalEnumDict.F'
_program.py:6: error: Revealed type is 'Any'
_program.py:7: error: Revealed type is 'builtins.str'

[case testFunctionalEnumErrors]
from enum import Enum, IntEnum
A = Enum('A')
B = Enum('B', 42)
C = Enum('C', 'a b', 'x')
D = Enum('D', foo)
bar = 'x y z'
E = Enum('E', bar)
I = IntEnum('I')
J = IntEnum('I', 42)
K = IntEnum('I', 'p q', 'z')
L = Enum('L', ' ')
M = Enum('M', ())
N = IntEnum('M', [])
P = Enum('P', [42])
Q = Enum('Q', [('a', 42, 0)])
R = IntEnum('R', [[0, 42]])
[out]
_program.py:2: error: Too few arguments for Enum()
_program.py:3: error: Enum() expects a string, tuple, list or dict literal as the second argument
_program.py:4: error: Too many arguments for Enum()
_program.py:5: error: Enum() expects a string, tuple, list or dict literal as the second argument
_program.py:5: error: Name 'foo' is not defined
_program.py:7: error: Enum() expects a string, tuple, list or dict literal as the second argument
_program.py:8: error: Too few arguments for IntEnum()
_program.py:9: error: IntEnum() expects a string, tuple, list or dict literal as the second argument
_program.py:10: error: Too many arguments for IntEnum()
_program.py:11: error: Enum() needs at least one item
_program.py:12: error: Enum() needs at least one item
_program.py:13: error: IntEnum() needs at least one item
_program.py:14: error: Enum() with tuple or list of (name, value) pairs not yet supported
_program.py:15: error: Enum() with tuple or list of (name, value) pairs not yet supported
_program.py:16: error: IntEnum() with tuple or list of (name, value) pairs not yet supported

0 comments on commit d6c543b

Please sign in to comment.