diff --git a/mypy/semanal.py b/mypy/semanal.py index 71a8323be292..5fa30fc55b1a 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -564,6 +564,8 @@ def check_function_signature(self, fdef: FuncItem) -> None: def visit_class_def(self, defn: ClassDef) -> None: self.clean_up_bases_and_infer_type_variables(defn) + if self.analyze_typeddict_classdef(defn): + return if self.analyze_namedtuple_classdef(defn): return self.setup_class_def_analysis(defn) @@ -944,6 +946,101 @@ def bind_class_type_variables_in_symbol_table( nodes.append(node) return nodes + def is_typeddict(self, expr: Expression) -> bool: + return (isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo) and + expr.node.typeddict_type is not None) + + def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: + # special case for TypedDict + possible = False + for base_expr in defn.base_type_exprs: + if isinstance(base_expr, RefExpr): + base_expr.accept(self) + if (base_expr.fullname == 'mypy_extensions.TypedDict' or + self.is_typeddict(base_expr)): + possible = True + if possible: + node = self.lookup(defn.name, defn) + if node is not None: + node.kind = GDEF # TODO in process_namedtuple_definition also applies here + if (len(defn.base_type_exprs) == 1 and + isinstance(defn.base_type_exprs[0], RefExpr) and + defn.base_type_exprs[0].fullname == 'mypy_extensions.TypedDict'): + # Building a new TypedDict + fields, types = self.check_typeddict_classdef(defn) + node.node = self.build_typeddict_typeinfo(defn.name, fields, types) + return True + # Extending/merging existing TypedDicts + if any(not isinstance(expr, RefExpr) or + expr.fullname != 'mypy_extensions.TypedDict' and + not self.is_typeddict(expr) for expr in defn.base_type_exprs): + self.fail("All bases of a new TypedDict must be TypedDict types", defn) + typeddict_bases = list(filter(self.is_typeddict, defn.base_type_exprs)) + newfields = [] # type: List[str] + newtypes = [] # type: List[Type] + tpdict = None # type: OrderedDict[str, Type] + for base in typeddict_bases: + assert isinstance(base, RefExpr) + assert isinstance(base.node, TypeInfo) + assert isinstance(base.node.typeddict_type, TypedDictType) + tpdict = base.node.typeddict_type.items + newdict = tpdict.copy() + for key in tpdict: + if key in newfields: + self.fail('Cannot overwrite TypedDict field "{}" while merging' + .format(key), defn) + newdict.pop(key) + newfields.extend(newdict.keys()) + newtypes.extend(newdict.values()) + fields, types = self.check_typeddict_classdef(defn, newfields) + newfields.extend(fields) + newtypes.extend(types) + node.node = self.build_typeddict_typeinfo(defn.name, newfields, newtypes) + return True + return False + + def check_typeddict_classdef(self, defn: ClassDef, + oldfields: List[str] = None) -> Tuple[List[str], List[Type]]: + TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; ' + 'expected "field_name: field_type"') + if self.options.python_version < (3, 6): + self.fail('TypedDict class syntax is only supported in Python 3.6', defn) + return [], [] + fields = [] # type: List[str] + types = [] # type: List[Type] + for stmt in defn.defs.body: + if not isinstance(stmt, AssignmentStmt): + # Still allow pass or ... (for empty TypedDict's). + if (not isinstance(stmt, PassStmt) and + not (isinstance(stmt, ExpressionStmt) and + isinstance(stmt.expr, EllipsisExpr))): + self.fail(TPDICT_CLASS_ERROR, stmt) + elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): + # An assignment, but an invalid one. + self.fail(TPDICT_CLASS_ERROR, stmt) + else: + name = stmt.lvalues[0].name + if name in (oldfields or []): + self.fail('Cannot overwrite TypedDict field "{}" while extending' + .format(name), stmt) + continue + if name in fields: + self.fail('Duplicate TypedDict field "{}"'.format(name), stmt) + continue + # Append name and type in this case... + fields.append(name) + types.append(AnyType() if stmt.type is None else self.anal_type(stmt.type)) + # ...despite possible minor failures that allow further analyzis. + if name.startswith('_'): + self.fail('TypedDict field name cannot start with an underscore: {}' + .format(name), stmt) + if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: + self.fail(TPDICT_CLASS_ERROR, stmt) + elif not isinstance(stmt.rvalue, TempNode): + # x: int assigns rvalue to TempNode(AnyType()) + self.fail('Right hand side values are not supported in TypedDict', stmt) + return fields, types + def visit_import(self, i: Import) -> None: for id, as_id in i.ids: if as_id is not None: diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 424c8b2b84e0..4f9c66c5b4fc 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -63,6 +63,151 @@ p = Point(x='meaning_of_life', y=1337) # E: Incompatible types (expression has [builtins fixtures/dict.pyi] +-- Define TypedDict (Class syntax) + +[case testCanCreateTypedDictWithClass] +# flags: --python-version 3.6 +from mypy_extensions import TypedDict + +class Point(TypedDict): + x: int + y: int + +p = Point(x=42, y=1337) +reveal_type(p) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.int, _fallback=typing.Mapping[builtins.str, builtins.int])' +[builtins fixtures/dict.pyi] + +[case testCanCreateTypedDictWithSubclass] +# flags: --python-version 3.6 +from mypy_extensions import TypedDict + +class Point1D(TypedDict): + x: int +class Point2D(Point1D): + y: int +r: Point1D +p: Point2D +reveal_type(r) # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=__main__.Point1D)' +reveal_type(p) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.int, _fallback=__main__.Point2D)' +[builtins fixtures/dict.pyi] + +[case testCanCreateTypedDictWithSubclass2] +# flags: --python-version 3.6 +from mypy_extensions import TypedDict + +class Point1D(TypedDict): + x: int +class Point2D(TypedDict, Point1D): # We also allow to include TypedDict in bases, it is simply ignored at runtime + y: int + +p: Point2D +reveal_type(p) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.int, _fallback=__main__.Point2D)' +[builtins fixtures/dict.pyi] + +[case testCanCreateTypedDictClassEmpty] +# flags: --python-version 3.6 +from mypy_extensions import TypedDict + +class EmptyDict(TypedDict): + pass + +p = EmptyDict() +reveal_type(p) # E: Revealed type is 'TypedDict(_fallback=typing.Mapping[builtins.str, builtins.None])' +[builtins fixtures/dict.pyi] + + +-- Define TypedDict (Class syntax errors) + +[case testCanCreateTypedDictWithClassOldVersion] +# flags: --python-version 3.5 +from mypy_extensions import TypedDict + +class Point(TypedDict): # E: TypedDict class syntax is only supported in Python 3.6 + pass +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictWithClassOtherBases] +# flags: --python-version 3.6 +from mypy_extensions import TypedDict + +class A: pass + +class Point1D(TypedDict, A): # E: All bases of a new TypedDict must be TypedDict types + x: int +class Point2D(Point1D, A): # E: All bases of a new TypedDict must be TypedDict types + y: int + +p: Point2D +reveal_type(p) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.int, _fallback=__main__.Point2D)' +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictWithClassWithOtherStuff] +# flags: --python-version 3.6 +from mypy_extensions import TypedDict + +class Point(TypedDict): + x: int + y: int = 1 # E: Right hand side values are not supported in TypedDict + def f(): pass # E: Invalid statement in TypedDict definition; expected "field_name: field_type" + z = int # E: Invalid statement in TypedDict definition; expected "field_name: field_type" + +p = Point(x=42, y=1337, z='whatever') +reveal_type(p) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.int, z=builtins.str, _fallback=typing.Mapping[builtins.str, builtins.object])' +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictWithClassUnderscores] +# flags: --python-version 3.6 +from mypy_extensions import TypedDict + +class Point(TypedDict): + x: int + _y: int # E: TypedDict field name cannot start with an underscore: _y + +p: Point +reveal_type(p) # E: Revealed type is 'TypedDict(x=builtins.int, _y=builtins.int, _fallback=__main__.Point)' +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictWithClassOverwriting] +# flags: --python-version 3.6 +from mypy_extensions import TypedDict + +class Bad(TypedDict): + x: int + x: str # E: Duplicate TypedDict field "x" + +b: Bad +reveal_type(b) # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=__main__.Bad)' +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictWithClassOverwriting2] +# flags: --python-version 3.6 +from mypy_extensions import TypedDict + +class Point1(TypedDict): + x: int +class Point2(TypedDict): + x: float +class Bad(Point1, Point2): # E: Cannot overwrite TypedDict field "x" while merging + pass + +b: Bad +reveal_type(b) # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=__main__.Bad)' +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictWithClassOverwriting2] +# flags: --python-version 3.6 +from mypy_extensions import TypedDict + +class Point1(TypedDict): + x: int +class Point2(Point1): + x: float # E: Cannot overwrite TypedDict field "x" while extending + +p2: Point2 +reveal_type(p2) # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=__main__.Point2)' +[builtins fixtures/dict.pyi] + + -- Subtyping [case testCanConvertTypedDictToItself]