From 62c6f6bd79646b13650aa088b2ed2cd4bb9ea08f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 3 Feb 2017 06:40:51 -0800 Subject: [PATCH] Support functional API for Enum. Fixes #2306. --- mypy/checker.py | 39 ++++---- mypy/checkexpr.py | 25 ++++- mypy/nodes.py | 19 ++++ mypy/semanal.py | 136 +++++++++++++++++++++++++++- mypy/strconv.py | 3 + mypy/treetransform.py | 6 +- mypy/visitor.py | 7 ++ test-data/unit/pythoneval-enum.test | 88 ++++++++++++++++++ 8 files changed, 299 insertions(+), 24 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 4fa03cdf0c7d..312c6082136f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -17,12 +17,17 @@ TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, IfStmt, WhileStmt, OperatorAssignmentStmt, WithStmt, AssertStmt, RaiseStmt, TryStmt, ForStmt, DelStmt, CallExpr, IntExpr, StrExpr, - UnicodeExpr, OpExpr, UnaryExpr, LambdaExpr, TempNode, SymbolTableNode, - Context, Decorator, PrintStmt, LITERAL_TYPE, BreakStmt, PassStmt, ContinueStmt, - ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, ImportFrom, ImportAll, ImportBase, - ARG_POS, CONTRAVARIANT, COVARIANT, ExecStmt, GlobalDecl, Import, NonlocalDecl, - MDEF, Node -) + BytesExpr, UnicodeExpr, FloatExpr, OpExpr, UnaryExpr, CastExpr, RevealTypeExpr, SuperExpr, + TypeApplication, DictExpr, SliceExpr, LambdaExpr, TempNode, SymbolTableNode, + Context, ListComprehension, ConditionalExpr, GeneratorExpr, + Decorator, SetExpr, TypeVarExpr, NewTypeExpr, PrintStmt, + LITERAL_TYPE, BreakStmt, PassStmt, ContinueStmt, ComparisonExpr, StarExpr, + YieldFromExpr, NamedTupleExpr, TypedDictExpr, SetComprehension, + DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr, + RefExpr, YieldExpr, BackquoteExpr, Import, ImportFrom, ImportAll, ImportBase, + AwaitExpr, PromoteExpr, Node, EnumCallExpr, + ARG_POS, MDEF, + CONTRAVARIANT, COVARIANT) from mypy import nodes from mypy.types import ( Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType, @@ -45,7 +50,7 @@ from mypy.semanal import set_callable_name, refers_to_fullname from mypy.erasetype import erase_typevars from mypy.expandtype import expand_type, expand_type_by_instance -from mypy.visitor import StatementVisitor +from mypy.visitor import NodeVisitor from mypy.join import join_types from mypy.treetransform import TransformVisitor from mypy.binder import ConditionalTypeBinder, get_declaration @@ -70,7 +75,7 @@ ]) -class TypeChecker(StatementVisitor[None]): +class TypeChecker(NodeVisitor[None]): """Mypy type checker. Type check mypy source files that have been semantically analyzed. @@ -2259,21 +2264,13 @@ def visit_break_stmt(self, s: BreakStmt) -> None: def visit_continue_stmt(self, s: ContinueStmt) -> None: self.binder.handle_continue() + return None - def visit_exec_stmt(self, s: ExecStmt) -> None: - pass - - def visit_global_decl(self, s: GlobalDecl) -> None: - pass - - def visit_nonlocal_decl(self, s: NonlocalDecl) -> None: - pass + def visit_typeddict_expr(self, e: TypedDictExpr) -> Type: + return self.expr_checker.visit_typeddict_expr(e) - def visit_var(self, s: Var) -> None: - pass - - def visit_pass_stmt(self, s: PassStmt) -> None: - pass + def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> Type: + return self.expr_checker.visit_enum_call_expr(o) # # Helpers diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 51a6f92e9ce6..e9d2f159a53a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -20,7 +20,8 @@ ConditionalExpr, ComparisonExpr, TempNode, SetComprehension, DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr, YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, - TypeAliasExpr, BackquoteExpr, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, + TypeAliasExpr, BackquoteExpr, EnumCallExpr, + ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, UNBOUND_TVAR, BOUND_TVAR, LITERAL_TYPE ) from mypy import nodes @@ -349,6 +350,12 @@ def check_call(self, callee: Type, args: List[Expression], """ arg_messages = arg_messages or self.msg if isinstance(callee, CallableType): + if (isinstance(callable_node, RefExpr) + and callable_node.fullname in ('enum.Enum', 'enum.IntEnum', + 'enum.Flag', 'enum.IntFlag')): + # An Enum() call that failed SemanticAnalyzer.check_enum_call(). + return callee.ret_type, callee + if (callee.is_type_obj() and callee.type_object().is_abstract # Exceptions for Type[...] and classmethod first argument and not callee.from_type_type and not callee.is_classmethod_class): @@ -2199,6 +2206,22 @@ def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type: # TODO: Perhaps return a type object type? return AnyType() + def visit_enum_call_expr(self, e: EnumCallExpr) -> Type: + for name, value in zip(e.items, e.values): + if value is not None: + typ = self.accept(value) + if not isinstance(typ, AnyType): + var = e.info.names[name].node + if isinstance(var, Var): + # Inline TypeCheker.set_inferred_type(), + # without the lvalue. (This doesn't really do + # much, since the value attribute is defined + # to have type Any in the typeshed stub.) + var.type = typ + var.is_inferred = True + # TODO: Perhaps return a type object type? + return AnyType() + def visit_typeddict_expr(self, e: TypedDictExpr) -> Type: # TODO: Perhaps return a type object type? return AnyType() diff --git a/mypy/nodes.py b/mypy/nodes.py index e5a26d638518..4584245b9904 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1830,6 +1830,25 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_typeddict_expr(self) +class EnumCallExpr(Expression): + """Named tuple expression Enum('name', 'val1 val2 ...').""" + + # The class representation of this enumerated type + info = None # type: TypeInfo + # The item names (for debugging) + items = None # type: List[str] + values = None # type: List[Optional[Expression]] + + def __init__(self, info: 'TypeInfo', items: List[str], + values: List[Optional[Expression]]) -> None: + self.info = info + self.items = items + self.values = values + + def accept(self, visitor: ExpressionVisitor[T]) -> T: + return visitor.visit_enum_call_expr(self) + + class PromoteExpr(Expression): """Ducktype class decorator expression _promote(...).""" diff --git a/mypy/semanal.py b/mypy/semanal.py index ea5a350cdfd5..0ff9d427c62c 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -64,7 +64,7 @@ YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SymbolNode, SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr, YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr, - IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode, + IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr, COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES, ARG_OPT, nongen_builtins, collections_type_aliases, get_member_expr_fullname, ) @@ -1498,6 +1498,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: self.process_typevar_declaration(s) self.process_namedtuple_definition(s) self.process_typeddict_definition(s) + self.process_enum_call(s) if (len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and s.lvalues[0].name == '__all__' and s.lvalues[0].kind == GDEF and @@ -2327,6 +2328,139 @@ def is_classvar(self, typ: Type) -> bool: def fail_invalid_classvar(self, context: Context) -> None: self.fail('ClassVar can only be used for assignments in class body', context) + def process_enum_call(self, s: AssignmentStmt) -> None: + """Check if s defines an Enum; if yes, store the definition in symbol table.""" + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return + lvalue = s.lvalues[0] + name = lvalue.name + enum_call = self.check_enum_call(s.rvalue, name) + if enum_call is None: + return + # Yes, it's a valid Enum definition. Add it to the symbol table. + node = self.lookup(name, s) + if node: + node.kind = GDEF # TODO locally defined Enum + node.node = enum_call + + def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: + """Check if a call defines an Enum. + + Example: + + A = enum.Enum('A', 'foo bar') + + is equivalent to: + + class A(enum.Enum): + foo = 1 + bar = 2 + """ + if not isinstance(node, CallExpr): + return None + call = node + callee = call.callee + if not isinstance(callee, RefExpr): + return None + fullname = callee.fullname + if fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'): + return None + items, values, ok = self.parse_enum_call_args(call, fullname.split('.')[-1]) + if not ok: + # Error. Construct dummy return value. + return self.build_enum_call_typeinfo('Enum', [], fullname) + name = cast(StrExpr, call.args[0]).value + if name != var_name or self.is_func_scope(): + # Give it a unique name derived from the line number. + name += '@' + str(call.line) + info = self.build_enum_call_typeinfo(name, items, fullname) + # Store it as a global just in case it would remain anonymous. + # (Or in the nearest class if there is one.) + stnode = SymbolTableNode(GDEF, info, self.cur_mod_id) + if self.type: + self.type.names[name] = stnode + else: + self.globals[name] = stnode + call.analyzed = EnumCallExpr(info, items, values) + call.analyzed.set_line(call.line, call.column) + return info + + def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) -> TypeInfo: + base = self.named_type_or_none(fullname) + assert base is not None + info = self.basic_new_typeinfo(name, base) + info.is_enum = True + for item in items: + var = Var(item) + var.info = info + var.is_property = True + info.names[item] = SymbolTableNode(MDEF, var) + return info + + def parse_enum_call_args(self, call: CallExpr, + class_name: str) -> Tuple[List[str], + List[Optional[Expression]], bool]: + args = call.args + if len(args) < 2: + return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call) + if len(args) > 2: + return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call) + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call) + if not isinstance(args[0], (StrExpr, UnicodeExpr)): + return self.fail_enum_call_arg( + "%s() expects a string literal as the first argument" % class_name, call) + items = [] + values = [] # type: List[Optional[Expression]] + if isinstance(args[1], (StrExpr, UnicodeExpr)): + fields = args[1].value + for field in fields.replace(',', ' ').split(): + items.append(field) + elif isinstance(args[1], (TupleExpr, ListExpr)): + seq_items = args[1].items + if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items): + items = [cast(StrExpr, seq_item).value for seq_item in seq_items] + elif all(isinstance(seq_item, (TupleExpr, ListExpr)) + and len(seq_item.items) == 2 + and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr)) + for seq_item in seq_items): + for seq_item in seq_items: + assert isinstance(seq_item, (TupleExpr, ListExpr)) + name, value = seq_item.items + assert isinstance(name, (StrExpr, UnicodeExpr)) + items.append(name.value) + values.append(value) + else: + return self.fail_enum_call_arg( + "%s() with tuple or list expects strings or (name, value) pairs" % + class_name, + call) + elif isinstance(args[1], DictExpr): + for key, value in args[1].items: + if not isinstance(key, (StrExpr, UnicodeExpr)): + return self.fail_enum_call_arg( + "%s() with dict literal requires string literals" % class_name, call) + items.append(key.value) + values.append(value) + else: + # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}? + return self.fail_enum_call_arg( + "%s() expects a string, tuple, list or dict literal as the second argument" % + class_name, + call) + if len(items) == 0: + return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call) + if not values: + values = [None] * len(items) + assert len(items) == len(values) + return items, values, True + + def fail_enum_call_arg(self, message: str, + context: Context) -> Tuple[List[str], + List[Optional[Expression]], bool]: + self.fail(message, context) + return [], [], False + def visit_decorator(self, dec: Decorator) -> None: for d in dec.decorators: d.accept(self) diff --git a/mypy/strconv.py b/mypy/strconv.py index 61411f9b13d8..b8bda6d0224c 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -431,6 +431,9 @@ def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> str: o.info.name(), o.info.tuple_type) + def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> str: + return 'EnumCallExpr:{}({}, {})'.format(o.line, o.info.name(), o.items) + def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> str: return 'TypedDictExpr:{}({})'.format(o.line, o.info.name()) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index dde17f2ab2f6..f0a3bcba8caa 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -19,7 +19,8 @@ ComparisonExpr, TempNode, StarExpr, Statement, Expression, YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, OverloadPart + YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, + OverloadPart, EnumCallExpr, ) from mypy.types import Type, FunctionLike from mypy.traverser import TraverserVisitor @@ -486,6 +487,9 @@ def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr: def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr: return NamedTupleExpr(node.info) + def visit_enum_call_expr(self, node: EnumCallExpr) -> EnumCallExpr: + return EnumCallExpr(node.info, node.items, node.values) + def visit_typeddict_expr(self, node: TypedDictExpr) -> Node: return TypedDictExpr(node.info) diff --git a/mypy/visitor.py b/mypy/visitor.py index ab691b04be34..6bd7520f4fb6 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -156,6 +156,10 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T: def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T: pass + @abstractmethod + def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T: + pass + @abstractmethod def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T: pass @@ -514,6 +518,9 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T: def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T: pass + def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T: + pass + def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T: pass diff --git a/test-data/unit/pythoneval-enum.test b/test-data/unit/pythoneval-enum.test index dfbd8f73cef0..1f4a34f91f1b 100644 --- a/test-data/unit/pythoneval-enum.test +++ b/test-data/unit/pythoneval-enum.test @@ -157,3 +157,91 @@ class E(IntEnum): E[1] [out] _program.py:4: error: Enum index should be a string (actual index type "int") + +[case testFunctionalEnumString] +from enum import Enum, IntEnum +E = Enum('E', 'foo bar') +I = IntEnum('I', ' bar, baz ') +reveal_type(E.foo) +reveal_type(E.bar.value) +reveal_type(I.bar) +reveal_type(I.baz.value) +[out] +_program.py:4: error: Revealed type is '_testFunctionalEnumString.E' +_program.py:5: error: Revealed type is 'Any' +_program.py:6: error: Revealed type is '_testFunctionalEnumString.I' +_program.py:7: error: Revealed type is 'builtins.int' + +[case testFunctionalEnumListOfStrings] +from enum import Enum, IntEnum +E = Enum('E', ('foo', 'bar')) +F = IntEnum('F', ['bar', 'baz']) +reveal_type(E.foo) +reveal_type(F.baz) +[out] +_program.py:4: error: Revealed type is '_testFunctionalEnumListOfStrings.E' +_program.py:5: error: Revealed type is '_testFunctionalEnumListOfStrings.F' + +[case testFunctionalEnumListOfPairs] +from enum import Enum, IntEnum +E = Enum('E', [('foo', 1), ['bar', 2]]) +F = IntEnum('F', (['bar', 1], ('baz', 2))) +reveal_type(E.foo) +reveal_type(F.baz) +reveal_type(E.foo.value) +reveal_type(F.bar.name) +[out] +_program.py:4: error: Revealed type is '_testFunctionalEnumListOfPairs.E' +_program.py:5: error: Revealed type is '_testFunctionalEnumListOfPairs.F' +_program.py:6: error: Revealed type is 'Any' +_program.py:7: error: Revealed type is 'builtins.str' + +[case testFunctionalEnumDict] +from enum import Enum, IntEnum +E = Enum('E', {'foo': 1, 'bar': 2}) +F = IntEnum('F', {'bar': 1, 'baz': 2}) +reveal_type(E.foo) +reveal_type(F.baz) +reveal_type(E.foo.value) +reveal_type(F.bar.name) +[out] +_program.py:4: error: Revealed type is '_testFunctionalEnumDict.E' +_program.py:5: error: Revealed type is '_testFunctionalEnumDict.F' +_program.py:6: error: Revealed type is 'Any' +_program.py:7: error: Revealed type is 'builtins.str' + +[case testFunctionalEnumErrors] +from enum import Enum, IntEnum +A = Enum('A') +B = Enum('B', 42) +C = Enum('C', 'a b', 'x') +D = Enum('D', foo) +bar = 'x y z' +E = Enum('E', bar) +I = IntEnum('I') +J = IntEnum('I', 42) +K = IntEnum('I', 'p q', 'z') +L = Enum('L', ' ') +M = Enum('M', ()) +N = IntEnum('M', []) +P = Enum('P', [42]) +Q = Enum('Q', [('a', 42, 0)]) +R = IntEnum('R', [[0, 42]]) +S = Enum('S', {1: 1}) +[out] +_program.py:2: error: Too few arguments for Enum() +_program.py:3: error: Enum() expects a string, tuple, list or dict literal as the second argument +_program.py:4: error: Too many arguments for Enum() +_program.py:5: error: Enum() expects a string, tuple, list or dict literal as the second argument +_program.py:5: error: Name 'foo' is not defined +_program.py:7: error: Enum() expects a string, tuple, list or dict literal as the second argument +_program.py:8: error: Too few arguments for IntEnum() +_program.py:9: error: IntEnum() expects a string, tuple, list or dict literal as the second argument +_program.py:10: error: Too many arguments for IntEnum() +_program.py:11: error: Enum() needs at least one item +_program.py:12: error: Enum() needs at least one item +_program.py:13: error: IntEnum() needs at least one item +_program.py:14: error: Enum() with tuple or list expects strings or (name, value) pairs +_program.py:15: error: Enum() with tuple or list expects strings or (name, value) pairs +_program.py:16: error: IntEnum() with tuple or list expects strings or (name, value) pairs +_program.py:17: error: Enum() with dict literal requires string literals