From e907a3d32e1bcf72692625c7b1acdf983ab13e53 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 3 Feb 2017 06:40:51 -0800 Subject: [PATCH] Support functional API for Enum. Fixes #2306. --- mypy/checker.py | 5 +- mypy/checkexpr.py | 28 +++++- mypy/nodes.py | 19 ++++ mypy/semanal.py | 136 +++++++++++++++++++++++++++- mypy/strconv.py | 3 + mypy/treetransform.py | 5 +- mypy/visitor.py | 7 ++ test-data/unit/pythoneval-enum.test | 72 +++++++++++++++ 8 files changed, 270 insertions(+), 5 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 77c8c4b3d2f36..121e6fa94a9fd 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -24,7 +24,7 @@ YieldFromExpr, NamedTupleExpr, TypedDictExpr, SetComprehension, DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr, RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase, - AwaitExpr, PromoteExpr, Node, + AwaitExpr, PromoteExpr, Node, EnumCallExpr, ARG_POS, MDEF, CONTRAVARIANT, COVARIANT) from mypy import nodes @@ -2244,6 +2244,9 @@ def visit_newtype_expr(self, e: NewTypeExpr) -> Type: def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type: return self.expr_checker.visit_namedtuple_expr(e) + def visit_enum_call_expr(self, e: EnumCallExpr) -> Type: + return self.expr_checker.visit_enum_call_expr(e) + def visit_typeddict_expr(self, e: TypedDictExpr) -> Type: return self.expr_checker.visit_typeddict_expr(e) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 79454d029ee2b..5780362d1e984 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -19,8 +19,8 @@ ConditionalExpr, ComparisonExpr, TempNode, SetComprehension, DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr, YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, - TypeAliasExpr, BackquoteExpr, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, - UNBOUND_TVAR, BOUND_TVAR, + TypeAliasExpr, BackquoteExpr, EnumCallExpr, + ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, UNBOUND_TVAR, BOUND_TVAR, ) from mypy import nodes import mypy.checker @@ -343,6 +343,14 @@ def check_call(self, callee: Type, args: List[Expression], """ arg_messages = arg_messages or self.msg if isinstance(callee, CallableType): + if (callee.is_type_obj() + and isinstance(callee.ret_type, Instance) + and callee.ret_type.type.fullname() in ('enum.Enum', 'enum.IntEnum', + 'enum.Flag', 'enum.IntFlag')): + # Enum() calls are checked elsewhere + # XXX TODO This feels wrong + return callee.ret_type, callee + if callee.is_concrete_type_obj() and callee.type_object().is_abstract: type = callee.type_object() self.msg.cannot_instantiate_abstract_class( @@ -2156,6 +2164,22 @@ def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type: # TODO: Perhaps return a type object type? return AnyType() + def visit_enum_call_expr(self, e: EnumCallExpr) -> Type: + for name, value in zip(e.items, e.values): + if value is not None: + typ = self.accept(value) + if not isinstance(typ, AnyType): + var = e.info.names[name].node + if isinstance(var, Var): + # Inline TypeCheker.set_inferred_type(), + # without the lvalue. (This doesn't really do + # much, since the value attribute is defined + # to have type Any in the typeshed stub.) + var.type = typ + var.is_inferred = True + # TODO: Perhaps return a type object type? + return AnyType() + def visit_typeddict_expr(self, e: TypedDictExpr) -> Type: # TODO: Perhaps return a type object type? return AnyType() diff --git a/mypy/nodes.py b/mypy/nodes.py index c9485fe66d566..8442577345102 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1792,6 +1792,25 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_typeddict_expr(self) +class EnumCallExpr(Expression): + """Named tuple expression Enum('name', 'val1 val2 ...').""" + + # The class representation of this enumerated type + info = None # type: TypeInfo + # The item names (for debugging) + items = None # type: List[str] + values = None # type: List[Optional[Expression]] + + def __init__(self, info: 'TypeInfo', items: List[str], + values: List[Optional[Expression]]) -> None: + self.info = info + self.items = items + self.values = values + + def accept(self, visitor: ExpressionVisitor[T]) -> T: + return visitor.visit_enum_call_expr(self) + + class PromoteExpr(Expression): """Ducktype class decorator expression _promote(...).""" diff --git a/mypy/semanal.py b/mypy/semanal.py index 71a8323be292d..23c46a25084a8 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -64,7 +64,7 @@ YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SymbolNode, SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr, YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr, - IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode, + IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr, COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES, ) from mypy.typevars import has_no_typevars, fill_typevars @@ -1195,6 +1195,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: self.process_typevar_declaration(s) self.process_namedtuple_definition(s) self.process_typeddict_definition(s) + self.process_enum_call(s) if (len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and s.lvalues[0].name == '__all__' and s.lvalues[0].kind == GDEF and @@ -1991,6 +1992,139 @@ def build_typeddict_typeinfo(self, name: str, items: List[str], return info + def process_enum_call(self, s: AssignmentStmt) -> None: + """Check if s defines an Enum; if yes, store the definition in symbol table.""" + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return + lvalue = s.lvalues[0] + name = lvalue.name + enum_call = self.check_enum_call(s.rvalue, name) + if enum_call is None: + return + # Yes, it's a valid Enum definition. Add it to the symbol table. + node = self.lookup(name, s) + if node: + node.kind = GDEF # TODO locally defined Enum + node.node = enum_call + + def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: + """Check if a call defines an Enum. + + Example: + + A = enum.Enum('A', 'foo bar') + + is equivalent to: + + class A(enum.Enum): + foo = 1 + bar = 2 + """ + if not isinstance(node, CallExpr): + return None + call = node + callee = call.callee + if not isinstance(callee, RefExpr): + return None + fullname = callee.fullname + if fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'): + return None + items, types, values, ok = self.parse_enum_call_args(call, fullname.split('.')[-1]) + if not ok: + # Error. Construct dummy return value. + return self.build_enum_call_typeinfo('Enum', [], [], fullname) + name = cast(StrExpr, call.args[0]).value + if name != var_name or self.is_func_scope(): + # Give it a unique name derived from the line number. + name += '@' + str(call.line) + info = self.build_enum_call_typeinfo(name, items, types, fullname) + # Store it as a global just in case it would remain anonymous. + # (Or in the nearest class if there is one.) + stnode = SymbolTableNode(GDEF, info, self.cur_mod_id) + if self.type: + self.type.names[name] = stnode + else: + self.globals[name] = stnode + call.analyzed = EnumCallExpr(info, items, values) + call.analyzed.set_line(call.line, call.column) + return info + + def build_enum_call_typeinfo(self, name: str, items: List[str], + types: List[Type], fullname: str) -> TypeInfo: + base = self.named_type_or_none(fullname) + assert base is not None + info = self.basic_new_typeinfo(name, base) + info.is_enum = True + for item, typ in zip(items, types): + var = Var(item, typ) + var.info = info + var.is_property = True + info.names[item] = SymbolTableNode(MDEF, var) + return info + + def parse_enum_call_args(self, call: CallExpr, + class_name: str) -> Tuple[List[str], List[Type], + List[Optional[Expression]], bool]: + args = call.args + if len(args) < 2: + return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call) + if len(args) > 2: + return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call) + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call) + if not isinstance(args[0], (StrExpr, UnicodeExpr)): + return self.fail_enum_call_arg( + "%s() expects a string literal as the first argument" % class_name, call) + items = [] + types = [] # type: List[Type] + values = [] # type: List[Optional[Expression]] + if isinstance(args[1], (StrExpr, UnicodeExpr)): + fields = args[1].value + for field in fields.replace(',', ' ').split(): + items.append(field) + elif isinstance(args[1], (TupleExpr, ListExpr)): + seq_items = args[1].items + if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items): + items = [cast(StrExpr, seq_item).value for seq_item in seq_items] + elif all(isinstance(seq_item, (TupleExpr, ListExpr)) + and len(seq_item.items) == 2 + and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr)) + for seq_item in seq_items): + for seq_item in seq_items: + assert isinstance(seq_item, (TupleExpr, ListExpr)) + name, value = seq_item.items + assert isinstance(name, (StrExpr, UnicodeExpr)) + items.append(name.value) + values.append(value) + else: + return self.fail_enum_call_arg( + "%s() with tuple or list of (name, value) pairs not yet supported" % + class_name, + call) + elif isinstance(args[1], DictExpr): + # XXX TODO: {'p': 1, 'q': 2} + return self.fail_enum_call_arg( + "%s() with dict of name:value items not yet supported" % class_name, call) + else: + return self.fail_enum_call_arg( + "%s() expects a string, tuple, list or dict literal as the second argument" % + class_name, + call) + if len(items) == 0: + return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call) + if not types: + types = [AnyType()] * len(items) + if not values: + values = [None] * len(items) + assert len(items) == len(types) == len(values) + return items, types, values, True + + def fail_enum_call_arg(self, message: str, + context: Context) -> Tuple[List[str], List[Type], + List[Optional[Expression]], bool]: + self.fail(message, context) + return [], [], [], False + def visit_decorator(self, dec: Decorator) -> None: for d in dec.decorators: d.accept(self) diff --git a/mypy/strconv.py b/mypy/strconv.py index d7c1e48a813c7..b280f61691b5b 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -429,6 +429,9 @@ def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> str: o.info.name(), o.info.tuple_type) + def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> str: + return 'EnumCallExpr:{}({}, {})'.format(o.line, o.info.name(), o.items) + def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> str: return 'TypedDictExpr:{}({})'.format(o.line, o.info.name()) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index e6e4678097ca2..3640f65d84c90 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -19,7 +19,7 @@ ComparisonExpr, TempNode, StarExpr, Statement, Expression, YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, + YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, EnumCallExpr, ) from mypy.types import Type, FunctionLike from mypy.traverser import TraverserVisitor @@ -483,6 +483,9 @@ def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr: def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr: return NamedTupleExpr(node.info) + def visit_enum_call_expr(self, node: EnumCallExpr) -> EnumCallExpr: + return EnumCallExpr(node.info, node.items, node.values) + def visit_typeddict_expr(self, node: TypedDictExpr) -> Node: return TypedDictExpr(node.info) diff --git a/mypy/visitor.py b/mypy/visitor.py index df04ebde82139..bf46701a61fea 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -156,6 +156,10 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T: def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T: pass + @abstractmethod + def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T: + pass + @abstractmethod def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T: pass @@ -392,6 +396,9 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T: def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T: pass + def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T: + pass + def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T: pass diff --git a/test-data/unit/pythoneval-enum.test b/test-data/unit/pythoneval-enum.test index 3ae2df55f120d..d8cca9ac856d2 100644 --- a/test-data/unit/pythoneval-enum.test +++ b/test-data/unit/pythoneval-enum.test @@ -132,3 +132,75 @@ class E(N, Enum): def f(x: E) -> None: pass f(E.X) + +[case testFunctionalEnumString] +from enum import Enum, IntEnum +E = Enum('E', 'foo bar') +I = IntEnum('I', ' bar, baz ') +reveal_type(E.foo) +reveal_type(E.bar.value) +reveal_type(I.bar) +reveal_type(I.baz.value) +[out] +_program.py:4: error: Revealed type is '_testFunctionalEnumString.E' +_program.py:5: error: Revealed type is 'Any' +_program.py:6: error: Revealed type is '_testFunctionalEnumString.I' +_program.py:7: error: Revealed type is 'builtins.int' + +[case testFunctionalEnumListOfStrings] +from enum import Enum, IntEnum +E = Enum('E', ('foo', 'bar')) +F = Enum('F', ['bar', 'baz']) +reveal_type(E.foo) +reveal_type(F.baz) +[out] +_program.py:4: error: Revealed type is '_testFunctionalEnumListOfStrings.E' +_program.py:5: error: Revealed type is '_testFunctionalEnumListOfStrings.F' + +[case testFunctionalEnumListOfPairs] +from enum import Enum, IntEnum +E = Enum('E', [('foo', 1), ['bar', 2]]) +F = Enum('F', (['bar', 1], ('baz', 2))) +reveal_type(E.foo) +reveal_type(F.baz) +reveal_type(E.foo.value) +reveal_type(F.bar.name) +[out] +_program.py:4: error: Revealed type is '_testFunctionalEnumListOfPairs.E' +_program.py:5: error: Revealed type is '_testFunctionalEnumListOfPairs.F' +_program.py:6: error: Revealed type is 'Any' +_program.py:7: error: Revealed type is 'builtins.str' + +[case testFunctionalEnumErrors] +from enum import Enum, IntEnum +A = Enum('A') +B = Enum('B', 42) +C = Enum('C', 'a b', 'x') +D = Enum('D', foo) +bar = 'x y z' +E = Enum('E', bar) +I = IntEnum('I') +J = IntEnum('I', 42) +K = IntEnum('I', 'p q', 'z') +L = Enum('L', ' ') +M = Enum('M', ()) +N = IntEnum('M', []) +P = Enum('P', [42]) +Q = Enum('Q', [('a', 42, 0)]) +R = IntEnum('R', [[0, 42]]) +[out] +_program.py:2: error: Too few arguments for Enum() +_program.py:3: error: Enum() expects a string, tuple, list or dict literal as the second argument +_program.py:4: error: Too many arguments for Enum() +_program.py:5: error: Enum() expects a string, tuple, list or dict literal as the second argument +_program.py:5: error: Name 'foo' is not defined +_program.py:7: error: Enum() expects a string, tuple, list or dict literal as the second argument +_program.py:8: error: Too few arguments for IntEnum() +_program.py:9: error: IntEnum() expects a string, tuple, list or dict literal as the second argument +_program.py:10: error: Too many arguments for IntEnum() +_program.py:11: error: Enum() needs at least one item +_program.py:12: error: Enum() needs at least one item +_program.py:13: error: IntEnum() needs at least one item +_program.py:14: error: Enum() with tuple or list of (name, value) pairs not yet supported +_program.py:15: error: Enum() with tuple or list of (name, value) pairs not yet supported +_program.py:16: error: IntEnum() with tuple or list of (name, value) pairs not yet supported