From df22528709c08ef4a3b54260d5d8adc1a64a8a57 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 7 Apr 2017 12:20:43 +0200 Subject: [PATCH 01/22] Stubgen using python3.6 variable syntax for stubs --- mypy/stubgen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 3123b18331eb..dc9d4998e0a8 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -487,7 +487,7 @@ def get_init(self, lvalue: str, rvalue: Expression) -> Optional[str]: return None self._vars[-1].append(lvalue) typename = self.get_str_type_of_node(rvalue) - return '%s%s = ... # type: %s\n' % (self._indent, lvalue, typename) + return '%s%s: %s\n' % (self._indent, lvalue, typename) def add(self, string: str) -> None: """Add text to generated stub.""" From cb5296defee0d834d03b0b409540659e52040cac Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 7 Apr 2017 13:21:24 +0200 Subject: [PATCH 02/22] Adjust stubgen tests to new output --- test-data/unit/stubgen.test | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 2795bb34538a..d69888b48dd9 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -81,20 +81,20 @@ def g(): ... [case testVariable] x = 1 [out] -x = ... # type: int +x: int [case testMultipleVariable] x = y = 1 [out] -x = ... # type: int -y = ... # type: int +x: int +y: int [case testClassVariable] class C: x = 1 [out] class C: - x = ... # type: int + x: int [case testSelfAssignment] class C: @@ -103,7 +103,7 @@ class C: x.y = 2 [out] class C: - x = ... # type: int + x: int def __init__(self) -> None: ... [case testSelfAndClassBodyAssignment] @@ -114,10 +114,10 @@ class C: self.x = 1 self.x = 1 [out] -x = ... # type: int +x: int class C: - x = ... # type: int + x: int def __init__(self) -> None: ... [case testEmptyClass] @@ -195,16 +195,16 @@ x, y = 1, 2 [out] from typing import Any -x = ... # type: Any -y = ... # type: Any +x: Any +y: Any [case testMultipleAssignment2] [x, y] = 1, 2 [out] from typing import Any -x = ... # type: Any -y = ... # type: Any +x: Any +y: Any [case testKeywordOnlyArg] def f(x, *, y=1): ... @@ -333,7 +333,7 @@ class A: def f(self): ... [out] class A: - x = ... # type: int + x: int def f(self): ... [case testSkipMultiplePrivateDefs] @@ -369,7 +369,7 @@ x = 1 [out] from re import match as match, sub as sub -x = ... # type: int +x: int [case testExportModule_import] import re @@ -379,7 +379,7 @@ y = 2 [out] import re as re -x = ... # type: int +x: int [case testExportModuleAs_import] import re as rex @@ -389,7 +389,7 @@ y = 2 [out] import re as rex -x = ... # type: int +x: int [case testExportModuleInPackage_import] import urllib.parse as p @@ -418,7 +418,7 @@ class C: [out] def f(): ... -x = ... # type: int +x: int class C: def g(self): ... @@ -577,7 +577,7 @@ class A: [out] class A: class B: - x = ... # type: int + x: int def f(self): ... def g(self): ... @@ -625,7 +625,7 @@ class A: from typing import Any, Optional class A: - x = ... # type: Any + x: Any def __init__(self, a: Optional[Any] = ...) -> None: ... -- More features/fixes: From da25e1c674f8c16883e435057272dbad420a5788 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 7 Apr 2017 17:30:40 +0200 Subject: [PATCH 03/22] Added annotation pass-through for variable annotations --- mypy/stubgen.py | 49 +++++++++++++++++++++++++++++++------ test-data/unit/stubgen.test | 31 +++++++++++++++++++++++ 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index dc9d4998e0a8..fda7aec7b63d 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -35,6 +35,7 @@ - we don't seem to always detect properties ('closed' in 'io', for example) """ +import builtins import glob import importlib import json @@ -63,6 +64,7 @@ from mypy.stubgenc import parse_all_signatures, find_unique_signatures, generate_stub_for_c_module from mypy.stubutil import is_c_module, write_header from mypy.options import Options as MypyOptions +from mypy.types import Type, TypeStrVisitor, UnboundType, NoneTyp, TupleType Options = NamedTuple('Options', [('pyversion', Tuple[int, int]), @@ -230,6 +232,30 @@ def generate_stub(path: str, NOT_IN_ALL = 'NOT_IN_ALL' +class AnnotationPrinter(TypeStrVisitor): + + def __init__(self, stubgen: 'StubGenerator') -> None: + super().__init__() + self.stubgen = stubgen + self.requires_quotes = False + + def visit_unbound_type(self, t: UnboundType)-> str: + s = t.name + base = s.split('.')[0] + # if the name is not defined, assume a forward reference + if (not self.requires_quotes and + base not in dir(builtins) and + not any(base in vs for vs in self.stubgen._vars)): + self.requires_quotes = True + + if t.args != []: + s += '[{}]'.format(self.list_str(t.args)) + return s + + def visit_none_type(self, t: NoneTyp) -> str: + return "None" + + class StubGenerator(mypy.traverser.TraverserVisitor): def __init__(self, _all_: Optional[List[str]], pyversion: Tuple[int, int], include_private: bool = False) -> None: @@ -365,17 +391,20 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: assert isinstance(o.rvalue, CallExpr) self.process_namedtuple(lvalue, o.rvalue) continue - if isinstance(lvalue, TupleExpr): - items = lvalue.items - elif isinstance(lvalue, ListExpr): + if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): items = lvalue.items + if isinstance(o.type, TupleType): + annotations = o.type.items + else: + annotations = [None] * len(items) else: items = [lvalue] + annotations = [o.type] sep = False found = False - for item in items: + for item, annotation in zip(items, annotations): if isinstance(item, NameExpr): - init = self.get_init(item.name, o.rvalue) + init = self.get_init(item.name, o.rvalue, annotation) if init: found = True if not sep and not self._indent and \ @@ -474,7 +503,7 @@ def visit_import(self, o: Import) -> None: self.add_import_line('import %s as %s\n' % (id, target_name)) self.record_name(target_name) - def get_init(self, lvalue: str, rvalue: Expression) -> Optional[str]: + def get_init(self, lvalue: str, rvalue: Expression, annotation: Type=None) -> Optional[str]: """Return initializer for a variable. Return None if we've generated one already or if the variable is internal. @@ -486,7 +515,13 @@ def get_init(self, lvalue: str, rvalue: Expression) -> Optional[str]: if self.is_private_name(lvalue) or self.is_not_in_all(lvalue): return None self._vars[-1].append(lvalue) - typename = self.get_str_type_of_node(rvalue) + if annotation is not None: + printer = AnnotationPrinter(self) + typename = annotation.accept(printer) + if printer.requires_quotes: + typename = "'{}'".format(typename) + else: + typename = self.get_str_type_of_node(rvalue) return '%s%s: %s\n' % (self._indent, lvalue, typename) def add(self, string: str) -> None: diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index d69888b48dd9..6ce165079df4 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -83,6 +83,31 @@ x = 1 [out] x: int +[case testAnnotatedVariable] +x: int = 1 +[out] +x: int + +[case testAnnotatedVariableGeneric] +x: Foo[int, str] = ... +[out] +x: 'Foo[int, str]' + +[case testAnnotatedVariableOldSyntax] +x = 1 # type: int +[out] +x: int + +[case testAnnotatedVariableNone] +x: None +[out] +x: None + +[case testAnnotatedVariableNoneOldSyntax] +x = None # type: None +[out] +x: None + [case testMultipleVariable] x = y = 1 [out] @@ -198,6 +223,12 @@ from typing import Any x: Any y: Any +[case testMultipleAssignmentAnnotated] +x, y = 1, "2" # type: int, str +[out] +x: int +y: str + [case testMultipleAssignment2] [x, y] = 1, 2 [out] From 96de85ce83750650c07866962ff592c64042d763 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 7 Apr 2017 21:29:28 +0200 Subject: [PATCH 04/22] Refactor handling on import/names and add imports for annotations --- mypy/stubgen.py | 126 +++++++++++++++++++++--------------- test-data/unit/stubgen.test | 66 +++++++++++++++++-- 2 files changed, 132 insertions(+), 60 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index fda7aec7b63d..be3925ec1719 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -45,9 +45,10 @@ import sys import textwrap import traceback +from collections import defaultdict from typing import ( - Any, List, Dict, Tuple, Iterable, Iterator, Optional, NamedTuple, Set, Union, cast + Any, List, Dict, Tuple, Iterable, Iterator, Mapping, Optional, NamedTuple, Set, Union, cast ) import mypy.build @@ -242,6 +243,7 @@ def __init__(self, stubgen: 'StubGenerator') -> None: def visit_unbound_type(self, t: UnboundType)-> str: s = t.name base = s.split('.')[0] + self.stubgen.import_tracker.require_name(base) # if the name is not defined, assume a forward reference if (not self.requires_quotes and base not in dir(builtins) and @@ -256,27 +258,73 @@ def visit_none_type(self, t: NoneTyp) -> str: return "None" +class ImportTracker: + + def __init__(self) -> None: + self.module_for = {} # type: Dict[str, Optional[str]] + self.direct_imports = {} # type: Dict[str, str] + self.reverse_alias = {} # type: Dict[str, str] + self.required_names = set() # type: Set[str] + + def add_import_from(self, module: str, names: List[Tuple[str, Optional[str]]]) -> None: + for name, alias in names: + self.module_for[alias or name] = module + if alias: + self.reverse_alias[alias] = name + + def add_import(self, module: str, alias: str=None) -> None: + name = module.split('.')[0] + self.module_for[alias or name] = None + self.direct_imports[name] = module + if alias: + self.reverse_alias[alias] = name + + def require_name(self, name: str) -> None: + self.required_names.add(name.split('.')[0]) + + def import_lines(self) -> List[str]: + result = [] + module_map = defaultdict(list) # type: Mapping[str, List[str]] + for name in self.required_names: + if name not in self.module_for: + continue + m = self.module_for[name] + if m is not None: + if name in self.reverse_alias: + name = '{} as {}'.format(self.reverse_alias[name], name) + module_map[m].append(name) + else: + if name in self.reverse_alias: + name, alias = self.reverse_alias[name], name + result.append("import {} as {}\n".format(self.direct_imports[name], alias)) + else: + result.append("import {}\n".format(self.direct_imports[name])) + for module, names in module_map.items(): + result.append("from {} import {}\n".format(module, ', '.join(sorted(names)))) + return result + + class StubGenerator(mypy.traverser.TraverserVisitor): def __init__(self, _all_: Optional[List[str]], pyversion: Tuple[int, int], include_private: bool = False) -> None: self._all_ = _all_ self._output = [] # type: List[str] self._import_lines = [] # type: List[str] - self._imports = [] # type: List[str] self._indent = '' self._vars = [[]] # type: List[List[str]] self._state = EMPTY self._toplevel_names = [] # type: List[str] - self._classes = set() # type: Set[str] - self._base_classes = [] # type: List[str] self._pyversion = pyversion self._include_private = include_private + self.import_tracker = ImportTracker() + # Add imports that could be implicitly generated + self.import_tracker.add_import_from("collections", [("namedtuple", None)]) + self.import_tracker.add_import_from("typing", [("Any", None), ("Optional", None)]) + # Names in __all__ are required + for name in _all_ or (): + self.import_tracker.require_name(name) def visit_mypy_file(self, o: MypyFile) -> None: - self._classes = find_classes(o) - for node in o.defs: - if isinstance(node, ClassDef): - self._base_classes.extend(self.get_base_types(node)) super().visit_mypy_file(o) undefined_names = [name for name in self._all_ or [] if name not in self._toplevel_names] @@ -356,6 +404,8 @@ def visit_class_def(self, o: ClassDef) -> None: base_types = self.get_base_types(o) if base_types: self.add('(%s)' % ', '.join(base_types)) + for base in base_types: + self.import_tracker.require_name(base) self.add(':\n') n = len(self._output) self._indent += ' ' @@ -363,6 +413,7 @@ def visit_class_def(self, o: ClassDef) -> None: super().visit_class_def(o) self._indent = self._indent[:-4] self._vars.pop() + self._vars[-1].append(o.name) if len(self._output) == n: if self._state == EMPTY_CLASS and sep is not None: self._output[sep] = '' @@ -380,7 +431,6 @@ def get_base_types(self, cdef: ClassDef) -> List[str]: elif isinstance(base, MemberExpr): modname = get_qualified_name(base.expr) base_types.append('%s.%s' % (modname, base.name)) - self.add_import_line('import %s\n' % modname) return base_types def visit_assignment_stmt(self, o: AssignmentStmt) -> None: @@ -426,7 +476,7 @@ def is_namedtuple(self, expr: Expression) -> bool: (isinstance(callee, MemberExpr) and callee.name == 'namedtuple')) def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: - self.add_import_line('from collections import namedtuple\n') + self.import_tracker.require_name('namedtuple') if self._state != EMPTY: self.add('\n') name = repr(getattr(rvalue.args[0], 'value', '')) @@ -438,7 +488,6 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: else: items = '' self.add('%s = namedtuple(%s, %s)\n' % (lvalue.name, name, items)) - self._classes.add(lvalue.name) self._state = CLASS def visit_if_stmt(self, o: IfStmt) -> None: @@ -457,51 +506,35 @@ def visit_import_all(self, o: ImportAll) -> None: def visit_import_from(self, o: ImportFrom) -> None: exported_names = set() # type: Set[str] + self.import_tracker.add_import_from('.' * o.relative + o.id, o.names) + self._vars[-1].extend(alias or name for name, alias in o.names) + for name, alias in o.names: + self.record_name(alias or name) + if self._all_: # Include import froms that import names defined in __all__. names = [name for name, alias in o.names if name in self._all_ and alias is None] exported_names.update(names) - self.import_and_export_names(o.id, o.relative, names) else: # Include import from targets that import from a submodule of a package. if o.relative: sub_names = [name for name, alias in o.names if alias is None] exported_names.update(sub_names) - self.import_and_export_names(o.id, o.relative, sub_names) - # Import names used as base classes. - base_names = [(name, alias) for name, alias in o.names - if alias or name in self._base_classes and name not in exported_names] - if base_names: - imp_names = [] # type: List[str] - for name, alias in base_names: - if alias is not None and alias != name: - imp_names.append('%s as %s' % (name, alias)) - else: - imp_names.append(name) - self.add_import_line('from %s%s import %s\n' % ( - '.' * o.relative, o.id, ', '.join(imp_names))) - - def import_and_export_names(self, module_id: str, relative: int, names: Iterable[str]) -> None: - """Import names from a module and export them (via from ... import x as x).""" - if names and module_id: - full_module_name = '%s%s' % ('.' * relative, module_id) - imported_names = ', '.join(['%s as %s' % (name, name) for name in names]) - self.add_import_line('from %s import %s\n' % (full_module_name, imported_names)) - for name in names: - self.record_name(name) + if o.id: + for name in sub_names: + self.import_tracker.require_name(name) def visit_import(self, o: Import) -> None: for id, as_id in o.ids: + self.import_tracker.add_import(id, as_id) if as_id is None: target_name = id.split('.')[0] else: target_name = as_id - if self._all_ and target_name in self._all_ and (as_id is not None or - '.' not in id): - self.add_import_line('import %s as %s\n' % (id, target_name)) - self.record_name(target_name) + self._vars[-1].append(target_name) + self.record_name(target_name) def get_init(self, lvalue: str, rvalue: Expression, annotation: Type=None) -> Optional[str]: """Return initializer for a variable. @@ -533,8 +566,7 @@ def add_typing_import(self, name: str) -> None: The import will be internal to the stub. """ - if name not in self._imports: - self._imports.append(name) + self.import_tracker.require_name(name) def add_import_line(self, line: str) -> None: """Add a line of text to the import section, unless it's already there.""" @@ -544,10 +576,9 @@ def add_import_line(self, line: str) -> None: def output(self) -> str: """Return the text for the stub.""" imports = '' - if self._imports: - imports += 'from typing import %s\n' % ", ".join(sorted(self._imports)) if self._import_lines: imports += ''.join(self._import_lines) + imports += ''.join(self.import_tracker.import_lines()) if imports and self._output: imports += '\n' return imports + ''.join(self._output) @@ -626,17 +657,6 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: return results -def find_classes(node: MypyFile) -> Set[str]: - results = set() # type: Set[str] - - class ClassTraverser(mypy.traverser.TraverserVisitor): - def visit_class_def(self, o: ClassDef) -> None: - results.add(o.name) - - node.accept(ClassTraverser()) - return results - - def get_qualified_name(o: Expression) -> str: if isinstance(o, NameExpr): return o.name diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 6ce165079df4..f11041284c03 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -398,7 +398,7 @@ from re import match, search, sub __all__ = ['match', 'sub', 'x'] x = 1 [out] -from re import match as match, sub as sub +from re import match, sub x: int @@ -408,7 +408,7 @@ __all__ = ['re', 'x'] x = 1 y = 2 [out] -import re as re +import re x: int @@ -429,12 +429,17 @@ __all__ = ['p'] import urllib.parse as p [case testExportModuleInPackageUnsupported_import] -import urllib.parse __all__ = ['urllib'] [out] # Names in __all__ with no definition: # urllib +[case testExportModuleInPackage_import] +import urllib.parse +__all__ = ['urllib'] +[out] +import urllib.parse + [case testRelativeImportAll] from .x import * [out] @@ -525,7 +530,6 @@ def f(): ... X = _namedtuple('X', 'a b') def g(): ... [out] -from collections import namedtuple as _namedtuple from collections import namedtuple def f(): ... @@ -615,12 +619,12 @@ class A: [case testExportViaRelativeImport] from .api import get [out] -from .api import get as get +from .api import get [case testExportViaRelativePackageImport] from .packages.urllib3.contrib import parse [out] -from .packages.urllib3.contrib import parse as parse +from .packages.urllib3.contrib import parse [case testNoExportViaRelativeImport] from . import get @@ -631,7 +635,7 @@ from .x import X class A(X): pass [out] -from .x import X as X +from .x import X class A(X): ... @@ -659,5 +663,53 @@ class A: x: Any def __init__(self, a: Optional[Any] = ...) -> None: ... +[case testAnnotationImportsFrom] +import foo +from collection import defaultdict +x: defaultdict + +[out] +from collection import defaultdict + +x: defaultdict + +[case testAnnotationImports] +import foo +import collection +x: collection.defaultdict + +[out] +import collection + +x: collection.defaultdict + + +[case testAnnotationImports] +from typing import List +import collection +x: List[collection.defaultdict] + +[out] +import collection +from typing import List + +x: List[collection.defaultdict] + + +[case testAnnotationFwRefs] +x: C + +class C: + attr: C + +y: C +[out] +x: 'C' + +class C: + attr: 'C' + +y: C + -- More features/fixes: -- do not export deleted names From fb422581468edbf9f683ea4351e63926d0e33381 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Sat, 8 Apr 2017 12:02:20 +0200 Subject: [PATCH 05/22] Preserve function annotations This also changes the style of initializer in generated stubs to be closer to PEP8 style, this implied a few test changes --- mypy/stubgen.py | 41 +++++++++++++++++++++++++++---------- test-data/unit/stubgen.test | 24 +++++++++++----------- 2 files changed, 42 insertions(+), 23 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index be3925ec1719..7d418b317acf 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -65,7 +65,7 @@ from mypy.stubgenc import parse_all_signatures, find_unique_signatures, generate_stub_for_c_module from mypy.stubutil import is_c_module, write_header from mypy.options import Options as MypyOptions -from mypy.types import Type, TypeStrVisitor, UnboundType, NoneTyp, TupleType +from mypy.types import Type, TypeStrVisitor, AnyType, UnboundType, NoneTyp, TupleType Options = NamedTuple('Options', [('pyversion', Tuple[int, int]), @@ -357,21 +357,36 @@ def visit_func_def(self, o: FuncDef) -> None: var = arg_.variable kind = arg_.kind name = var.name() + annotated_type = o.type.arg_types[i] if o.type else None + if annotated_type and not (i == 0 and name == 'self' and isinstance(annotated_type, AnyType)): + annotation = ": {}".format(self.print_annotation(annotated_type)) + else: + annotation = "" init_stmt = arg_.initialization_statement if init_stmt: + if isinstance(init_stmt.rvalue, NameExpr) and init_stmt.rvalue.name == 'None': + initializer = 'None' + else: + initializer = '...' if kind in (ARG_NAMED, ARG_NAMED_OPT) and '*' not in args: args.append('*') - typename = self.get_str_type_of_node(init_stmt.rvalue, True) - arg = '{}: {} = ...'.format(name, typename) + if not annotation: + typename = self.get_str_type_of_node(init_stmt.rvalue, True) + annotation = ': {}=...'.format(typename) + else: + annotation += '={}'.format(initializer) + arg = name + annotation elif kind == ARG_STAR: - arg = '*%s' % name + arg = '*%s%s' % (name, annotation) elif kind == ARG_STAR2: - arg = '**%s' % name + arg = '**%s%s' % (name, annotation) else: - arg = name + arg = name + annotation args.append(arg) retname = None - if o.name() == '__init__': + if o.type: + retname = self.print_annotation(o.type.ret_type) + elif o.name() == '__init__': retname = 'None' retfield = '' if retname is not None: @@ -549,10 +564,7 @@ def get_init(self, lvalue: str, rvalue: Expression, annotation: Type=None) -> Op return None self._vars[-1].append(lvalue) if annotation is not None: - printer = AnnotationPrinter(self) - typename = annotation.accept(printer) - if printer.requires_quotes: - typename = "'{}'".format(typename) + typename = self.print_annotation(annotation) else: typename = self.get_str_type_of_node(rvalue) return '%s%s: %s\n' % (self._indent, lvalue, typename) @@ -625,6 +637,13 @@ def get_str_type_of_node(self, rvalue: Expression, self.add_typing_import('Any') return 'Any' + def print_annotation(self, t: Type) -> str: + printer = AnnotationPrinter(self) + typename = t.accept(printer) + if printer.requires_quotes: + typename = "'{}'".format(typename) + return typename + def is_top_level(self) -> bool: """Are we processing the top level of a file?""" return self._indent == '' diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index f11041284c03..525ed5d4c6d2 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -20,42 +20,42 @@ def g(arg): ... def f(a, b=2): ... def g(b=-1, c=0): ... [out] -def f(a, b: int = ...): ... -def g(b: int = ..., c: int = ...): ... +def f(a, b: int=...): ... +def g(b: int=..., c: int=...): ... [case testDefaultArgNone] def f(x=None): ... [out] from typing import Any, Optional -def f(x: Optional[Any] = ...): ... +def f(x: Optional[Any]=...): ... [case testDefaultArgBool] def f(x=True, y=False): ... [out] -def f(x: bool = ..., y: bool = ...): ... +def f(x: bool=..., y: bool=...): ... [case testDefaultArgStr] def f(x='foo'): ... [out] -def f(x: str = ...): ... +def f(x: str=...): ... [case testDefaultArgBytes] def f(x=b'foo'): ... [out] -def f(x: bytes = ...): ... +def f(x: bytes=...): ... [case testDefaultArgFloat] def f(x=1.2): ... [out] -def f(x: float = ...): ... +def f(x: float=...): ... [case testDefaultArgOther] def f(x=ord): ... [out] from typing import Any -def f(x: Any = ...): ... +def f(x: Any=...): ... [case testVarArgs] def f(x, *y): ... @@ -241,8 +241,8 @@ y: Any def f(x, *, y=1): ... def g(x, *, y=1, z=2): ... [out] -def f(x, *, y: int = ...): ... -def g(x, *, y: int = ..., z: int = ...): ... +def f(x, *, y: int=...): ... +def g(x, *, y: int=..., z: int=...): ... [case testProperty] class A: @@ -654,14 +654,14 @@ def f(a): ... [case testInferOptionalOnlyFunc] class A: x = None - def __init__(self, a=None) -> None: + def method(self, a=None): self.x = [] [out] from typing import Any, Optional class A: x: Any - def __init__(self, a: Optional[Any] = ...) -> None: ... + def method(self, a: Optional[Any]=...): ... [case testAnnotationImportsFrom] import foo From 749b1acbf4a9f51facc0dfee4ce39ff6f8291273 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Thu, 13 Apr 2017 15:23:43 +0100 Subject: [PATCH 06/22] Add tests for preserving annotations --- test-data/unit/stubgen.test | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 525ed5d4c6d2..bdc45bbe269f 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -57,6 +57,16 @@ from typing import Any def f(x: Any=...): ... +[case testPreserveFunctionAnnotation] +def f(x: Foo) -> Bar: ... +[out] +def f(x: 'Foo') -> 'Bar': ... + +[case testPreserveVarAnnotation] +x: Foo +[out] +x: 'Foo' + [case testVarArgs] def f(x, *y): ... [out] From 6fbbf602a42b4b0d6fc26297a8a560c133c956e1 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Thu, 13 Apr 2017 15:39:59 +0100 Subject: [PATCH 07/22] Fix linter checks --- mypy/stubgen.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 7d418b317acf..45d15e18d039 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -65,7 +65,7 @@ from mypy.stubgenc import parse_all_signatures, find_unique_signatures, generate_stub_for_c_module from mypy.stubutil import is_c_module, write_header from mypy.options import Options as MypyOptions -from mypy.types import Type, TypeStrVisitor, AnyType, UnboundType, NoneTyp, TupleType +from mypy.types import Type, TypeStrVisitor, AnyType, CallableType, UnboundType, NoneTyp, TupleType Options = NamedTuple('Options', [('pyversion', Tuple[int, int]), @@ -357,8 +357,9 @@ def visit_func_def(self, o: FuncDef) -> None: var = arg_.variable kind = arg_.kind name = var.name() - annotated_type = o.type.arg_types[i] if o.type else None - if annotated_type and not (i == 0 and name == 'self' and isinstance(annotated_type, AnyType)): + annotated_type = o.type.arg_types[i] if isinstance(o.type, CallableType) else None + if annotated_type and not ( + i == 0 and name == 'self' and isinstance(annotated_type, AnyType)): annotation = ": {}".format(self.print_annotation(annotated_type)) else: annotation = "" @@ -384,7 +385,7 @@ def visit_func_def(self, o: FuncDef) -> None: arg = name + annotation args.append(arg) retname = None - if o.type: + if isinstance(o.type, CallableType): retname = self.print_annotation(o.type.ret_type) elif o.name() == '__init__': retname = 'None' From 4b98aa053612dcfab05e1f8060adc1f434308907 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Thu, 13 Apr 2017 18:33:19 +0100 Subject: [PATCH 08/22] Try to preserve typevars and type alias --- mypy/stubgen.py | 104 +++++++++++++++++++++++++++++++++--- test-data/unit/stubgen.test | 86 +++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+), 6 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 45d15e18d039..70371e1b6d8e 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -58,15 +58,16 @@ from mypy import defaults from mypy.nodes import ( Expression, IntExpr, UnaryExpr, StrExpr, BytesExpr, NameExpr, FloatExpr, MemberExpr, TupleExpr, - ListExpr, ComparisonExpr, CallExpr, ClassDef, MypyFile, Decorator, AssignmentStmt, + ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr, + ClassDef, MypyFile, Decorator, AssignmentStmt, IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, - ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT, + ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT, ) from mypy.stubgenc import parse_all_signatures, find_unique_signatures, generate_stub_for_c_module from mypy.stubutil import is_c_module, write_header from mypy.options import Options as MypyOptions from mypy.types import Type, TypeStrVisitor, AnyType, CallableType, UnboundType, NoneTyp, TupleType - +from mypy.visitor import NodeVisitor Options = NamedTuple('Options', [('pyversion', Tuple[int, int]), ('no_import', bool), @@ -258,6 +259,50 @@ def visit_none_type(self, t: NoneTyp) -> str: return "None" +class AliasPrinter(NodeVisitor[str]): + + def __init__(self, stubgen: 'StubGenerator') -> None: + self.stubgen = stubgen + super().__init__() + + def visit_call_expr(self, node: CallExpr) -> str: + callee = node.callee.accept(self) + args = [] + for name, arg, kind in zip(node.arg_names, node.args, node.arg_kinds): + if kind == ARG_POS: + args.append(arg.accept(self)) + elif kind == ARG_STAR: + args.append('*' + arg.accept(self)) + elif kind == ARG_STAR2: + args.append('**' + arg.accept(self)) + elif kind == ARG_NAMED: + args.append('{}={}'.format(name, arg.accept(self))) + else: + raise ValueError("Unknown argument kind %d in call" % kind) + return "{}({})".format(callee, ", ".join(args)) + + def visit_name_expr(self, node: NameExpr) -> str: + self.stubgen.import_tracker.require_name(node.name) + return node.name + + def visit_str_expr(self, node: StrExpr) -> str: + return repr(node.value) + + def visit_index_expr(self, node: IndexExpr) -> str: + base = node.base.accept(self) + index = node.index.accept(self) + return "{}[{}]".format(base, index) + + def visit_tuple_expr(self, node: TupleExpr) -> str: + return ", ".join(n.accept(self) for n in node.items) + + def visit_list_expr(self, node: ListExpr) -> str: + return "[{}]".format(", ".join(n.accept(self) for n in node.items)) + + def visit_ellipsis(self, node: EllipsisExpr) -> str: + return "..." + + class ImportTracker: def __init__(self) -> None: @@ -285,7 +330,7 @@ def require_name(self, name: str) -> None: def import_lines(self) -> List[str]: result = [] module_map = defaultdict(list) # type: Mapping[str, List[str]] - for name in self.required_names: + for name in sorted(self.required_names): if name not in self.module_for: continue m = self.module_for[name] @@ -299,7 +344,7 @@ def import_lines(self) -> List[str]: result.append("import {} as {}\n".format(self.direct_imports[name], alias)) else: result.append("import {}\n".format(self.direct_imports[name])) - for module, names in module_map.items(): + for module, names in sorted(module_map.items()): result.append("from {} import {}\n".format(module, ', '.join(sorted(names)))) return result @@ -319,7 +364,8 @@ def __init__(self, _all_: Optional[List[str]], pyversion: Tuple[int, int], self.import_tracker = ImportTracker() # Add imports that could be implicitly generated self.import_tracker.add_import_from("collections", [("namedtuple", None)]) - self.import_tracker.add_import_from("typing", [("Any", None), ("Optional", None)]) + typing_imports = "Any Optional TypeVar".split() + self.import_tracker.add_import_from("typing", [(t, None) for t in typing_imports]) # Names in __all__ are required for name in _all_ or (): self.import_tracker.require_name(name) @@ -457,6 +503,10 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: assert isinstance(o.rvalue, CallExpr) self.process_namedtuple(lvalue, o.rvalue) continue + if (self.is_top_level() and + isinstance(lvalue, NameExpr) and self.is_type_expression(o.rvalue)): + self.process_typealias(lvalue, o.rvalue) + continue if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): items = lvalue.items if isinstance(o.type, TupleType): @@ -506,6 +556,48 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: self.add('%s = namedtuple(%s, %s)\n' % (lvalue.name, name, items)) self._state = CLASS + def is_type_expression(self, expr: Expression, top_level: bool=True) -> bool: + """Return True for things that look like type expressions + + Used to know if assignments look like typealiases + """ + # Assignment of TypeVar(...) are passed through + if (isinstance(expr, CallExpr) and + isinstance(expr.callee, NameExpr) and + expr.callee.name == 'TypeVar'): + return True + elif isinstance(expr, EllipsisExpr): + return not top_level + elif isinstance(expr, NameExpr): + if expr.name in ('True', 'False'): + return False + elif expr.name == 'None': + return not top_level + else: + return True + elif isinstance(expr, IndexExpr) and isinstance(expr.base, NameExpr): + if isinstance(expr.index, TupleExpr): + indices = expr.index.items + else: + indices = [expr.index] + if expr.base.name == 'Callable' and len(indices) == 2: + args, ret = indices + if isinstance(args, EllipsisExpr): + indices = [ret] + elif isinstance(args, ListExpr): + indices = args.items + [ret] + else: + return False + return all(self.is_type_expression(i, top_level=False) for i in indices) + else: + return False + + def process_typealias(self, lvalue: NameExpr, rvalue: Expression) -> None: + p = AliasPrinter(self) + self.add("{} = {}\n".format(lvalue.name, rvalue.accept(p))) + self.record_name(lvalue.name) + self._vars[-1].append(lvalue.name) + def visit_if_stmt(self, o: IfStmt) -> None: # Ignore if __name__ == '__main__'. expr = o.expr[0] diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index bdc45bbe269f..5cc67991c2b0 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -721,5 +721,91 @@ class C: y: C +[case testTypeVarPreserved] +tv = TypeVar('tv') + +[out] +from typing import TypeVar + +tv = TypeVar('tv') + +[case testTypeVarArgsPreserved] +tv = TypeVar('tv', int, str) + +[out] +from typing import TypeVar + +tv = TypeVar('tv', int, str) + +[case testTypeVarNamedArgsPreserved] +tv = TypeVar('tv', bound=bool, covariant=True) + +[out] +from typing import TypeVar + +tv = TypeVar('tv', bound=bool, covariant=True) + +[case testTypeAliasPreserved] +alias = str + +[out] +alias = str + +[case testDeepTypeAliasPreserved] + +alias = Dict[str, List[str]] + +[out] +alias = Dict[str, List[str]] + +[case testEllipsisAliasPreserved] + +alias = Tuple[int, ...] + +[out] +alias = Tuple[int, ...] + +[case testCallableAliasPreserved] + +alias1 = Callable[..., int] +alias2 = Callable[[str, bool], None] + +[out] +alias1 = Callable[..., int] +alias2 = Callable[[str, bool], None] + +[case testAliasPullsImport] +from module import Container + +alias = Container[Any] + +[out] +from module import Container +from typing import Any + +alias = Container[Any] + +[case testAliasOnlyToplevel] +class Foo: + alias = str + +[out] +from typing import Any + +class Foo: + alias: Any + +[case testAliasExceptions] +noalias1 = None +noalias2 = ... +noalias3 = True + +[out] +from typing import Any + +noalias1: Any +noalias2: Any +noalias3: bool + -- More features/fixes: -- do not export deleted names From a85b3c2be4c7c6ba3dfbfe1ecc279aff2034fd5b Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Thu, 13 Apr 2017 19:17:19 +0100 Subject: [PATCH 09/22] Pass through generic base classes --- mypy/stubgen.py | 3 +++ test-data/unit/stubgen.test | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 70371e1b6d8e..b09067ed4983 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -493,6 +493,9 @@ def get_base_types(self, cdef: ClassDef) -> List[str]: elif isinstance(base, MemberExpr): modname = get_qualified_name(base.expr) base_types.append('%s.%s' % (modname, base.name)) + elif isinstance(base, IndexExpr): + p = AliasPrinter(self) + base_types.append(base.accept(p)) return base_types def visit_assignment_stmt(self, o: AssignmentStmt) -> None: diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 5cc67991c2b0..3e8c6dc3f4c2 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -596,6 +596,11 @@ from x import X as _X class A(_X): ... +[case testGenericClass] +class D(Generic[T]): ... +[out] +class D(Generic[T]): ... + [case testObjectBaseClass] class A(object): ... [out] From a4ec68622ab81de22ef252fd63e412bb9ed0bf80 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Thu, 20 Apr 2017 16:08:25 +0100 Subject: [PATCH 10/22] Use space around '=' on annotated function defaults PEP8 actually suggest this eception to the normal (no spaces) rule for annotated arguments --- mypy/stubgen.py | 2 +- test-data/unit/stubgen.test | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index b09067ed4983..6c174308c66d 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -419,7 +419,7 @@ def visit_func_def(self, o: FuncDef) -> None: args.append('*') if not annotation: typename = self.get_str_type_of_node(init_stmt.rvalue, True) - annotation = ': {}=...'.format(typename) + annotation = ': {} = ...'.format(typename) else: annotation += '={}'.format(initializer) arg = name + annotation diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 3e8c6dc3f4c2..0582980c8519 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -20,42 +20,42 @@ def g(arg): ... def f(a, b=2): ... def g(b=-1, c=0): ... [out] -def f(a, b: int=...): ... -def g(b: int=..., c: int=...): ... +def f(a, b: int = ...): ... +def g(b: int = ..., c: int = ...): ... [case testDefaultArgNone] def f(x=None): ... [out] from typing import Any, Optional -def f(x: Optional[Any]=...): ... +def f(x: Optional[Any] = ...): ... [case testDefaultArgBool] def f(x=True, y=False): ... [out] -def f(x: bool=..., y: bool=...): ... +def f(x: bool = ..., y: bool = ...): ... [case testDefaultArgStr] def f(x='foo'): ... [out] -def f(x: str=...): ... +def f(x: str = ...): ... [case testDefaultArgBytes] def f(x=b'foo'): ... [out] -def f(x: bytes=...): ... +def f(x: bytes = ...): ... [case testDefaultArgFloat] def f(x=1.2): ... [out] -def f(x: float=...): ... +def f(x: float = ...): ... [case testDefaultArgOther] def f(x=ord): ... [out] from typing import Any -def f(x: Any=...): ... +def f(x: Any = ...): ... [case testPreserveFunctionAnnotation] def f(x: Foo) -> Bar: ... @@ -251,8 +251,8 @@ y: Any def f(x, *, y=1): ... def g(x, *, y=1, z=2): ... [out] -def f(x, *, y: int=...): ... -def g(x, *, y: int=..., z: int=...): ... +def f(x, *, y: int = ...): ... +def g(x, *, y: int = ..., z: int = ...): ... [case testProperty] class A: @@ -676,7 +676,7 @@ from typing import Any, Optional class A: x: Any - def method(self, a: Optional[Any]=...): ... + def method(self, a: Optional[Any] = ...): ... [case testAnnotationImportsFrom] import foo From 530792e9d0146fd0c637649706393018a96232f8 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 1 Sep 2017 01:40:18 +0100 Subject: [PATCH 11/22] Fix broken tests after rebase --- mypy/stubgen.py | 7 ++++--- test-data/unit/stubgen.test | 10 +++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 6c174308c66d..06f2daf71eef 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -317,7 +317,7 @@ def add_import_from(self, module: str, names: List[Tuple[str, Optional[str]]]) - if alias: self.reverse_alias[alias] = name - def add_import(self, module: str, alias: str=None) -> None: + def add_import(self, module: str, alias: Optional[str]=None) -> None: name = module.split('.')[0] self.module_for[alias or name] = None self.direct_imports[name] = module @@ -513,7 +513,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): items = lvalue.items if isinstance(o.type, TupleType): - annotations = o.type.items + annotations = o.type.items # type: List[Optional[Type]] else: annotations = [None] * len(items) else: @@ -647,7 +647,8 @@ def visit_import(self, o: Import) -> None: self._vars[-1].append(target_name) self.record_name(target_name) - def get_init(self, lvalue: str, rvalue: Expression, annotation: Type=None) -> Optional[str]: + def get_init(self, lvalue: str, rvalue: Expression, + annotation: Optional[Type]=None) -> Optional[str]: """Return initializer for a variable. Return None if we've generated one already or if the variable is internal. diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 0582980c8519..b1ecacb9e80c 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -201,10 +201,10 @@ _x = 1 class A: _y = 1 [out] -_x = ... # type: int +_x: int class A: - _y = ... # type: int + _y: int [case testSpecialInternalVar] __all__ = [] @@ -397,9 +397,9 @@ class C: ... [out] class A: ... -_x = ... # type: int -_y = ... # type: int -_z = ... # type: int +_x: int +_y: int +_z: int class C: ... From dad8deb114474d9b59bd2e2b675a6acfd5221023 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 1 Sep 2017 10:19:05 +0100 Subject: [PATCH 12/22] Clarify on the use of visit_call_Expr to handl TypeVar --- mypy/stubgen.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 06f2daf71eef..825ec4442881 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -266,6 +266,8 @@ def __init__(self, stubgen: 'StubGenerator') -> None: super().__init__() def visit_call_expr(self, node: CallExpr) -> str: + # Call expressions are not usually types, but we also treat `X = TypeVar(...)` as a + # type alias that has to be preserved (even if TypeVar is not the same as an alias) callee = node.callee.accept(self) args = [] for name, arg, kind in zip(node.arg_names, node.args, node.arg_kinds): From e50e744cc80f27fe91626efbf5b13c1aeedc53d8 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 1 Sep 2017 10:27:51 +0100 Subject: [PATCH 13/22] Add test for infer Optional on __init__ arguments --- test-data/unit/stubgen.test | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index b1ecacb9e80c..f3be32ec1e75 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -669,6 +669,8 @@ def f(a): ... [case testInferOptionalOnlyFunc] class A: x = None + def __init__(self, a=None): + self.x = [] def method(self, a=None): self.x = [] [out] @@ -676,6 +678,7 @@ from typing import Any, Optional class A: x: Any + def __init__(self, a: Optional[Any] = ...) -> None: ... def method(self, a: Optional[Any] = ...): ... [case testAnnotationImportsFrom] From cb8d9469f110b067ad76a17be988837415fcb66e Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 1 Sep 2017 10:34:43 +0100 Subject: [PATCH 14/22] remove uoting in stub files, not needed --- mypy/stubgen.py | 13 +------------ test-data/unit/stubgen.test | 13 +++++++++---- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 825ec4442881..f957a220c23c 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -239,18 +239,10 @@ class AnnotationPrinter(TypeStrVisitor): def __init__(self, stubgen: 'StubGenerator') -> None: super().__init__() self.stubgen = stubgen - self.requires_quotes = False - def visit_unbound_type(self, t: UnboundType)-> str: s = t.name base = s.split('.')[0] self.stubgen.import_tracker.require_name(base) - # if the name is not defined, assume a forward reference - if (not self.requires_quotes and - base not in dir(builtins) and - not any(base in vs for vs in self.stubgen._vars)): - self.requires_quotes = True - if t.args != []: s += '[{}]'.format(self.list_str(t.args)) return s @@ -738,10 +730,7 @@ def get_str_type_of_node(self, rvalue: Expression, def print_annotation(self, t: Type) -> str: printer = AnnotationPrinter(self) - typename = t.accept(printer) - if printer.requires_quotes: - typename = "'{}'".format(typename) - return typename + return t.accept(printer) def is_top_level(self) -> bool: """Are we processing the top level of a file?""" diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index f3be32ec1e75..6565e51f9b02 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -60,12 +60,17 @@ def f(x: Any = ...): ... [case testPreserveFunctionAnnotation] def f(x: Foo) -> Bar: ... [out] -def f(x: 'Foo') -> 'Bar': ... +def f(x: Foo) -> Bar: ... [case testPreserveVarAnnotation] x: Foo [out] +x: Foo + +[case testPreserveVarAnnotationWithoutQuotes] x: 'Foo' +[out] +x: Foo [case testVarArgs] def f(x, *y): ... @@ -101,7 +106,7 @@ x: int [case testAnnotatedVariableGeneric] x: Foo[int, str] = ... [out] -x: 'Foo[int, str]' +x: Foo[int, str] [case testAnnotatedVariableOldSyntax] x = 1 # type: int @@ -722,10 +727,10 @@ class C: y: C [out] -x: 'C' +x: C class C: - attr: 'C' + attr: C y: C From 7302599237c2018ca33aae222b966564a3fef726 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 1 Sep 2017 11:04:58 +0100 Subject: [PATCH 15/22] Add some clarifications in complex methods and data structures in ImportTracker --- mypy/stubgen.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index f957a220c23c..79a07596c343 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -239,6 +239,7 @@ class AnnotationPrinter(TypeStrVisitor): def __init__(self, stubgen: 'StubGenerator') -> None: super().__init__() self.stubgen = stubgen + def visit_unbound_type(self, t: UnboundType)-> str: s = t.name base = s.split('.')[0] @@ -300,9 +301,25 @@ def visit_ellipsis(self, node: EllipsisExpr) -> str: class ImportTracker: def __init__(self) -> None: + # module_for['foo'] has the module name where 'foo' was imported from, or None if + # 'foo' is a module imported directly; examples + # 'from pkg.m import f as foo' ==> module_for['foo'] == 'pkg.m' + # 'from m import f' ==> module_for['f'] == 'm' + # 'import m' ==> module_for['m'] == None self.module_for = {} # type: Dict[str, Optional[str]] + + # direct_imports['foo'] is the module path used when the name 'foo' was added to the + # namespace. + # import foo.bar.baz ==> direct_imports['foo'] == 'foo.bar.baz' self.direct_imports = {} # type: Dict[str, str] + + # reverse_alias['foo'] is the name that 'foo' had originally when imported with an + # alias; examples + # 'import numpy as np' ==> reverse_alias['np'] == 'numpy' + # 'from decimal import Decimal as D' ==> reverse_alias['D'] == 'Decimal' self.reverse_alias = {} # type: Dict[str, str] + + # required_names is the set of names that are actually used in a type annotation self.required_names = set() # type: Set[str] def add_import_from(self, module: str, names: List[Tuple[str, Optional[str]]]) -> None: @@ -322,22 +339,38 @@ def require_name(self, name: str) -> None: self.required_names.add(name.split('.')[0]) def import_lines(self) -> List[str]: + """ + The list of required import lines (as strings with python code) + """ result = [] + + # To summarize multiple names imported from a same module, we collect those + # in the `module_map` dictionary, mapping a module path to the list of names that should + # be imported from it. the names can also be alias in the form 'original as alias' module_map = defaultdict(list) # type: Mapping[str, List[str]] + for name in sorted(self.required_names): + # If we haven't seen this name in an import statement, ignore it if name not in self.module_for: continue + m = self.module_for[name] if m is not None: + # This name was found in a from ... import ... + # Collect the name in the module_map if name in self.reverse_alias: name = '{} as {}'.format(self.reverse_alias[name], name) module_map[m].append(name) else: + # This name was found in an import ... + # We can already generate the import line if name in self.reverse_alias: name, alias = self.reverse_alias[name], name result.append("import {} as {}\n".format(self.direct_imports[name], alias)) else: result.append("import {}\n".format(self.direct_imports[name])) + + # Now generate all the from ... import ... lines collected in module_map for module, names in sorted(module_map.items()): result.append("from {} import {}\n".format(module, ', '.join(sorted(names)))) return result From df0d5cf09db2263bfb852d768f78b12253c012a4 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 1 Sep 2017 11:51:45 +0100 Subject: [PATCH 16/22] Add test for deep, generic type alias --- test-data/unit/stubgen.test | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 6565e51f9b02..a66126633f53 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -771,6 +771,18 @@ alias = Dict[str, List[str]] [out] alias = Dict[str, List[str]] +[case testDeepGenericTypeAliasPreserved] +from typing import TypeVar + +T = TypeVar('T') +alias = Union[T, List[T]] + +[out] +from typing import TypeVar + +T = TypeVar('T') +alias = Union[T, List[T]] + [case testEllipsisAliasPreserved] alias = Tuple[int, ...] From 7a7c7a9f7ba0695d52228dcc826d928d421d39cc Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 22 Sep 2017 12:43:34 +0100 Subject: [PATCH 17/22] remove unused import --- mypy/stubgen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 79a07596c343..2a8c746aa891 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -35,7 +35,6 @@ - we don't seem to always detect properties ('closed' in 'io', for example) """ -import builtins import glob import importlib import json From 83006baea40001d161bc6681b2d437455054afe5 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 22 Sep 2017 12:47:54 +0100 Subject: [PATCH 18/22] Remove code no longer required now that implicit optional is discouraged --- mypy/stubgen.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 2a8c746aa891..167a82461007 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -437,10 +437,7 @@ def visit_func_def(self, o: FuncDef) -> None: annotation = "" init_stmt = arg_.initialization_statement if init_stmt: - if isinstance(init_stmt.rvalue, NameExpr) and init_stmt.rvalue.name == 'None': - initializer = 'None' - else: - initializer = '...' + initializer = '...' if kind in (ARG_NAMED, ARG_NAMED_OPT) and '*' not in args: args.append('*') if not annotation: From a95e7a41ed436418b8e6072c76e3786324a675df Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 22 Sep 2017 13:25:58 +0100 Subject: [PATCH 19/22] Reexport (with alias) names imported in __all__ --- mypy/stubgen.py | 19 +++++++++++++++++-- test-data/unit/stubgen.test | 2 +- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 167a82461007..8af4683e0919 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -321,6 +321,9 @@ def __init__(self) -> None: # required_names is the set of names that are actually used in a type annotation self.required_names = set() # type: Set[str] + # Names that should be reexported if they come from another module + self.reexports = set() # type: Set[str] + def add_import_from(self, module: str, names: List[Tuple[str, Optional[str]]]) -> None: for name, alias in names: self.module_for[alias or name] = module @@ -337,6 +340,16 @@ def add_import(self, module: str, alias: Optional[str]=None) -> None: def require_name(self, name: str) -> None: self.required_names.add(name.split('.')[0]) + def reexport(self, name: str) -> None: + """ + Mark a given non qualified name as needed in __all__. This means that in case it + comes from a module, it should be imported with an alias even is the alias is the same + as the name. + + """ + self.require_name(name) + self.reexports.add(name) + def import_lines(self) -> List[str]: """ The list of required import lines (as strings with python code) @@ -359,6 +372,8 @@ def import_lines(self) -> List[str]: # Collect the name in the module_map if name in self.reverse_alias: name = '{} as {}'.format(self.reverse_alias[name], name) + elif name in self.reexports: + name = '{} as {}'.format(name, name) module_map[m].append(name) else: # This name was found in an import ... @@ -394,7 +409,7 @@ def __init__(self, _all_: Optional[List[str]], pyversion: Tuple[int, int], self.import_tracker.add_import_from("typing", [(t, None) for t in typing_imports]) # Names in __all__ are required for name in _all_ or (): - self.import_tracker.require_name(name) + self.import_tracker.reexport(name) def visit_mypy_file(self, o: MypyFile) -> None: super().visit_mypy_file(o) @@ -671,7 +686,7 @@ def visit_import(self, o: Import) -> None: self.record_name(target_name) def get_init(self, lvalue: str, rvalue: Expression, - annotation: Optional[Type]=None) -> Optional[str]: + annotation: Optional[Type] = None) -> Optional[str]: """Return initializer for a variable. Return None if we've generated one already or if the variable is internal. diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index a66126633f53..9bb9c70700fb 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -413,7 +413,7 @@ from re import match, search, sub __all__ = ['match', 'sub', 'x'] x = 1 [out] -from re import match, sub +from re import match as match, sub as sub x: int From dd09d38fcd54ca56cbad751ac9a38e7a10570037 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 22 Sep 2017 15:21:11 +0100 Subject: [PATCH 20/22] Reexport of modules, not just module attributes --- mypy/stubgen.py | 4 ++++ test-data/unit/stubgen.test | 21 +++++++++++++-------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 8af4683e0919..18a1c09dec28 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -121,6 +121,7 @@ def generate_stub_for_module(module: str, output_dir: str, quiet: bool = False, else: target += '.pyi' target = os.path.join(output_dir, target) + generate_stub(module_path, output_dir, module_all, target=target, add_header=add_header, module=module, pyversion=pyversion, include_private=include_private) @@ -381,6 +382,9 @@ def import_lines(self) -> List[str]: if name in self.reverse_alias: name, alias = self.reverse_alias[name], name result.append("import {} as {}\n".format(self.direct_imports[name], alias)) + elif name in self.reexports: + assert '.' not in name # Because reexports only has nonqualified names + result.append("import {} as {}\n".format(name, name)) else: result.append("import {}\n".format(self.direct_imports[name])) diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 9bb9c70700fb..da7942480208 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -423,7 +423,17 @@ __all__ = ['re', 'x'] x = 1 y = 2 [out] +import re as re + +x: int + +[case testExportModule_import] import re +__all__ = ['re', 'x'] +x = 1 +y = 2 +[out] +import re as re x: int @@ -443,17 +453,12 @@ __all__ = ['p'] [out] import urllib.parse as p -[case testExportModuleInPackageUnsupported_import] -__all__ = ['urllib'] -[out] -# Names in __all__ with no definition: -# urllib - -[case testExportModuleInPackage_import] +[case testExportPackageOfAModule_import] import urllib.parse __all__ = ['urllib'] + [out] -import urllib.parse +import urllib as urllib [case testRelativeImportAll] from .x import * From 6d383ac125dd06404e7b0581c852749f99dfb8f5 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 22 Sep 2017 15:54:30 +0100 Subject: [PATCH 21/22] Preserve initializer (with ...) in classes --- mypy/stubgen.py | 6 ++++-- test-data/unit/stubgen.test | 16 ++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 18a1c09dec28..729592e82289 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -59,7 +59,7 @@ Expression, IntExpr, UnaryExpr, StrExpr, BytesExpr, NameExpr, FloatExpr, MemberExpr, TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr, ClassDef, MypyFile, Decorator, AssignmentStmt, - IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, + IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, TempNode, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT, ) from mypy.stubgenc import parse_all_signatures, find_unique_signatures, generate_stub_for_c_module @@ -706,7 +706,9 @@ def get_init(self, lvalue: str, rvalue: Expression, typename = self.print_annotation(annotation) else: typename = self.get_str_type_of_node(rvalue) - return '%s%s: %s\n' % (self._indent, lvalue, typename) + has_rhs = not (isinstance(rvalue, TempNode) and rvalue.no_rhs) + initializer = " = ..." if has_rhs and not self.is_top_level() else "" + return '%s%s: %s%s\n' % (self._indent, lvalue, typename, initializer) def add(self, string: str) -> None: """Add text to generated stub.""" diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index da7942480208..6bb07249271e 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -134,7 +134,7 @@ class C: x = 1 [out] class C: - x: int + x: int = ... [case testSelfAssignment] class C: @@ -143,7 +143,7 @@ class C: x.y = 2 [out] class C: - x: int + x: int = ... def __init__(self) -> None: ... [case testSelfAndClassBodyAssignment] @@ -157,7 +157,7 @@ class C: x: int class C: - x: int + x: int = ... def __init__(self) -> None: ... [case testEmptyClass] @@ -209,7 +209,7 @@ class A: _x: int class A: - _y: int + _y: int = ... [case testSpecialInternalVar] __all__ = [] @@ -379,7 +379,7 @@ class A: def f(self): ... [out] class A: - x: int + x: int = ... def f(self): ... [case testSkipMultiplePrivateDefs] @@ -637,7 +637,7 @@ class A: [out] class A: class B: - x: int + x: int = ... def f(self): ... def g(self): ... @@ -687,7 +687,7 @@ class A: from typing import Any, Optional class A: - x: Any + x: Any = ... def __init__(self, a: Optional[Any] = ...) -> None: ... def method(self, a: Optional[Any] = ...): ... @@ -823,7 +823,7 @@ class Foo: from typing import Any class Foo: - alias: Any + alias: Any = ... [case testAliasExceptions] noalias1 = None From 93d300ae3da77079584cfe35868ef0c497b43971 Mon Sep 17 00:00:00 2001 From: Daniel F Moisset Date: Fri, 22 Sep 2017 16:54:44 +0100 Subject: [PATCH 22/22] Fix in annotation to avoid covariance issue --- mypy/stubgen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 729592e82289..a2f895414502 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -555,7 +555,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): items = lvalue.items if isinstance(o.type, TupleType): - annotations = o.type.items # type: List[Optional[Type]] + annotations = o.type.items # type: Iterable[Optional[Type]] else: annotations = [None] * len(items) else: