diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 3123b18331eb..a2f895414502 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -44,9 +44,10 @@ import sys import textwrap import traceback +from collections import defaultdict from typing import ( - Any, List, Dict, Tuple, Iterable, Iterator, Optional, NamedTuple, Set, Union, cast + Any, List, Dict, Tuple, Iterable, Iterator, Mapping, Optional, NamedTuple, Set, Union, cast ) import mypy.build @@ -56,14 +57,16 @@ from mypy import defaults from mypy.nodes import ( Expression, IntExpr, UnaryExpr, StrExpr, BytesExpr, NameExpr, FloatExpr, MemberExpr, TupleExpr, - ListExpr, ComparisonExpr, CallExpr, ClassDef, MypyFile, Decorator, AssignmentStmt, - IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, - ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT, + ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr, + ClassDef, MypyFile, Decorator, AssignmentStmt, + IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, TempNode, + ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT, ) from mypy.stubgenc import parse_all_signatures, find_unique_signatures, generate_stub_for_c_module from mypy.stubutil import is_c_module, write_header from mypy.options import Options as MypyOptions - +from mypy.types import Type, TypeStrVisitor, AnyType, CallableType, UnboundType, NoneTyp, TupleType +from mypy.visitor import NodeVisitor Options = NamedTuple('Options', [('pyversion', Tuple[int, int]), ('no_import', bool), @@ -118,6 +121,7 @@ def generate_stub_for_module(module: str, output_dir: str, quiet: bool = False, else: target += '.pyi' target = os.path.join(output_dir, target) + generate_stub(module_path, output_dir, module_all, target=target, add_header=add_header, module=module, pyversion=pyversion, include_private=include_private) @@ -230,27 +234,188 @@ def generate_stub(path: str, NOT_IN_ALL = 'NOT_IN_ALL' +class AnnotationPrinter(TypeStrVisitor): + + def __init__(self, stubgen: 'StubGenerator') -> None: + super().__init__() + self.stubgen = stubgen + + def visit_unbound_type(self, t: UnboundType)-> str: + s = t.name + base = s.split('.')[0] + self.stubgen.import_tracker.require_name(base) + if t.args != []: + s += '[{}]'.format(self.list_str(t.args)) + return s + + def visit_none_type(self, t: NoneTyp) -> str: + return "None" + + +class AliasPrinter(NodeVisitor[str]): + + def __init__(self, stubgen: 'StubGenerator') -> None: + self.stubgen = stubgen + super().__init__() + + def visit_call_expr(self, node: CallExpr) -> str: + # Call expressions are not usually types, but we also treat `X = TypeVar(...)` as a + # type alias that has to be preserved (even if TypeVar is not the same as an alias) + callee = node.callee.accept(self) + args = [] + for name, arg, kind in zip(node.arg_names, node.args, node.arg_kinds): + if kind == ARG_POS: + args.append(arg.accept(self)) + elif kind == ARG_STAR: + args.append('*' + arg.accept(self)) + elif kind == ARG_STAR2: + args.append('**' + arg.accept(self)) + elif kind == ARG_NAMED: + args.append('{}={}'.format(name, arg.accept(self))) + else: + raise ValueError("Unknown argument kind %d in call" % kind) + return "{}({})".format(callee, ", ".join(args)) + + def visit_name_expr(self, node: NameExpr) -> str: + self.stubgen.import_tracker.require_name(node.name) + return node.name + + def visit_str_expr(self, node: StrExpr) -> str: + return repr(node.value) + + def visit_index_expr(self, node: IndexExpr) -> str: + base = node.base.accept(self) + index = node.index.accept(self) + return "{}[{}]".format(base, index) + + def visit_tuple_expr(self, node: TupleExpr) -> str: + return ", ".join(n.accept(self) for n in node.items) + + def visit_list_expr(self, node: ListExpr) -> str: + return "[{}]".format(", ".join(n.accept(self) for n in node.items)) + + def visit_ellipsis(self, node: EllipsisExpr) -> str: + return "..." + + +class ImportTracker: + + def __init__(self) -> None: + # module_for['foo'] has the module name where 'foo' was imported from, or None if + # 'foo' is a module imported directly; examples + # 'from pkg.m import f as foo' ==> module_for['foo'] == 'pkg.m' + # 'from m import f' ==> module_for['f'] == 'm' + # 'import m' ==> module_for['m'] == None + self.module_for = {} # type: Dict[str, Optional[str]] + + # direct_imports['foo'] is the module path used when the name 'foo' was added to the + # namespace. + # import foo.bar.baz ==> direct_imports['foo'] == 'foo.bar.baz' + self.direct_imports = {} # type: Dict[str, str] + + # reverse_alias['foo'] is the name that 'foo' had originally when imported with an + # alias; examples + # 'import numpy as np' ==> reverse_alias['np'] == 'numpy' + # 'from decimal import Decimal as D' ==> reverse_alias['D'] == 'Decimal' + self.reverse_alias = {} # type: Dict[str, str] + + # required_names is the set of names that are actually used in a type annotation + self.required_names = set() # type: Set[str] + + # Names that should be reexported if they come from another module + self.reexports = set() # type: Set[str] + + def add_import_from(self, module: str, names: List[Tuple[str, Optional[str]]]) -> None: + for name, alias in names: + self.module_for[alias or name] = module + if alias: + self.reverse_alias[alias] = name + + def add_import(self, module: str, alias: Optional[str]=None) -> None: + name = module.split('.')[0] + self.module_for[alias or name] = None + self.direct_imports[name] = module + if alias: + self.reverse_alias[alias] = name + + def require_name(self, name: str) -> None: + self.required_names.add(name.split('.')[0]) + + def reexport(self, name: str) -> None: + """ + Mark a given non qualified name as needed in __all__. This means that in case it + comes from a module, it should be imported with an alias even is the alias is the same + as the name. + + """ + self.require_name(name) + self.reexports.add(name) + + def import_lines(self) -> List[str]: + """ + The list of required import lines (as strings with python code) + """ + result = [] + + # To summarize multiple names imported from a same module, we collect those + # in the `module_map` dictionary, mapping a module path to the list of names that should + # be imported from it. the names can also be alias in the form 'original as alias' + module_map = defaultdict(list) # type: Mapping[str, List[str]] + + for name in sorted(self.required_names): + # If we haven't seen this name in an import statement, ignore it + if name not in self.module_for: + continue + + m = self.module_for[name] + if m is not None: + # This name was found in a from ... import ... + # Collect the name in the module_map + if name in self.reverse_alias: + name = '{} as {}'.format(self.reverse_alias[name], name) + elif name in self.reexports: + name = '{} as {}'.format(name, name) + module_map[m].append(name) + else: + # This name was found in an import ... + # We can already generate the import line + if name in self.reverse_alias: + name, alias = self.reverse_alias[name], name + result.append("import {} as {}\n".format(self.direct_imports[name], alias)) + elif name in self.reexports: + assert '.' not in name # Because reexports only has nonqualified names + result.append("import {} as {}\n".format(name, name)) + else: + result.append("import {}\n".format(self.direct_imports[name])) + + # Now generate all the from ... import ... lines collected in module_map + for module, names in sorted(module_map.items()): + result.append("from {} import {}\n".format(module, ', '.join(sorted(names)))) + return result + + class StubGenerator(mypy.traverser.TraverserVisitor): def __init__(self, _all_: Optional[List[str]], pyversion: Tuple[int, int], include_private: bool = False) -> None: self._all_ = _all_ self._output = [] # type: List[str] self._import_lines = [] # type: List[str] - self._imports = [] # type: List[str] self._indent = '' self._vars = [[]] # type: List[List[str]] self._state = EMPTY self._toplevel_names = [] # type: List[str] - self._classes = set() # type: Set[str] - self._base_classes = [] # type: List[str] self._pyversion = pyversion self._include_private = include_private + self.import_tracker = ImportTracker() + # Add imports that could be implicitly generated + self.import_tracker.add_import_from("collections", [("namedtuple", None)]) + typing_imports = "Any Optional TypeVar".split() + self.import_tracker.add_import_from("typing", [(t, None) for t in typing_imports]) + # Names in __all__ are required + for name in _all_ or (): + self.import_tracker.reexport(name) def visit_mypy_file(self, o: MypyFile) -> None: - self._classes = find_classes(o) - for node in o.defs: - if isinstance(node, ClassDef): - self._base_classes.extend(self.get_base_types(node)) super().visit_mypy_file(o) undefined_names = [name for name in self._all_ or [] if name not in self._toplevel_names] @@ -283,21 +448,34 @@ def visit_func_def(self, o: FuncDef) -> None: var = arg_.variable kind = arg_.kind name = var.name() + annotated_type = o.type.arg_types[i] if isinstance(o.type, CallableType) else None + if annotated_type and not ( + i == 0 and name == 'self' and isinstance(annotated_type, AnyType)): + annotation = ": {}".format(self.print_annotation(annotated_type)) + else: + annotation = "" init_stmt = arg_.initialization_statement if init_stmt: + initializer = '...' if kind in (ARG_NAMED, ARG_NAMED_OPT) and '*' not in args: args.append('*') - typename = self.get_str_type_of_node(init_stmt.rvalue, True) - arg = '{}: {} = ...'.format(name, typename) + if not annotation: + typename = self.get_str_type_of_node(init_stmt.rvalue, True) + annotation = ': {} = ...'.format(typename) + else: + annotation += '={}'.format(initializer) + arg = name + annotation elif kind == ARG_STAR: - arg = '*%s' % name + arg = '*%s%s' % (name, annotation) elif kind == ARG_STAR2: - arg = '**%s' % name + arg = '**%s%s' % (name, annotation) else: - arg = name + arg = name + annotation args.append(arg) retname = None - if o.name() == '__init__': + if isinstance(o.type, CallableType): + retname = self.print_annotation(o.type.ret_type) + elif o.name() == '__init__': retname = 'None' retfield = '' if retname is not None: @@ -330,6 +508,8 @@ def visit_class_def(self, o: ClassDef) -> None: base_types = self.get_base_types(o) if base_types: self.add('(%s)' % ', '.join(base_types)) + for base in base_types: + self.import_tracker.require_name(base) self.add(':\n') n = len(self._output) self._indent += ' ' @@ -337,6 +517,7 @@ def visit_class_def(self, o: ClassDef) -> None: super().visit_class_def(o) self._indent = self._indent[:-4] self._vars.pop() + self._vars[-1].append(o.name) if len(self._output) == n: if self._state == EMPTY_CLASS and sep is not None: self._output[sep] = '' @@ -354,7 +535,9 @@ def get_base_types(self, cdef: ClassDef) -> List[str]: elif isinstance(base, MemberExpr): modname = get_qualified_name(base.expr) base_types.append('%s.%s' % (modname, base.name)) - self.add_import_line('import %s\n' % modname) + elif isinstance(base, IndexExpr): + p = AliasPrinter(self) + base_types.append(base.accept(p)) return base_types def visit_assignment_stmt(self, o: AssignmentStmt) -> None: @@ -365,17 +548,24 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: assert isinstance(o.rvalue, CallExpr) self.process_namedtuple(lvalue, o.rvalue) continue - if isinstance(lvalue, TupleExpr): - items = lvalue.items - elif isinstance(lvalue, ListExpr): + if (self.is_top_level() and + isinstance(lvalue, NameExpr) and self.is_type_expression(o.rvalue)): + self.process_typealias(lvalue, o.rvalue) + continue + if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): items = lvalue.items + if isinstance(o.type, TupleType): + annotations = o.type.items # type: Iterable[Optional[Type]] + else: + annotations = [None] * len(items) else: items = [lvalue] + annotations = [o.type] sep = False found = False - for item in items: + for item, annotation in zip(items, annotations): if isinstance(item, NameExpr): - init = self.get_init(item.name, o.rvalue) + init = self.get_init(item.name, o.rvalue, annotation) if init: found = True if not sep and not self._indent and \ @@ -397,7 +587,7 @@ def is_namedtuple(self, expr: Expression) -> bool: (isinstance(callee, MemberExpr) and callee.name == 'namedtuple')) def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: - self.add_import_line('from collections import namedtuple\n') + self.import_tracker.require_name('namedtuple') if self._state != EMPTY: self.add('\n') name = repr(getattr(rvalue.args[0], 'value', '')) @@ -409,9 +599,50 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: else: items = '' self.add('%s = namedtuple(%s, %s)\n' % (lvalue.name, name, items)) - self._classes.add(lvalue.name) self._state = CLASS + def is_type_expression(self, expr: Expression, top_level: bool=True) -> bool: + """Return True for things that look like type expressions + + Used to know if assignments look like typealiases + """ + # Assignment of TypeVar(...) are passed through + if (isinstance(expr, CallExpr) and + isinstance(expr.callee, NameExpr) and + expr.callee.name == 'TypeVar'): + return True + elif isinstance(expr, EllipsisExpr): + return not top_level + elif isinstance(expr, NameExpr): + if expr.name in ('True', 'False'): + return False + elif expr.name == 'None': + return not top_level + else: + return True + elif isinstance(expr, IndexExpr) and isinstance(expr.base, NameExpr): + if isinstance(expr.index, TupleExpr): + indices = expr.index.items + else: + indices = [expr.index] + if expr.base.name == 'Callable' and len(indices) == 2: + args, ret = indices + if isinstance(args, EllipsisExpr): + indices = [ret] + elif isinstance(args, ListExpr): + indices = args.items + [ret] + else: + return False + return all(self.is_type_expression(i, top_level=False) for i in indices) + else: + return False + + def process_typealias(self, lvalue: NameExpr, rvalue: Expression) -> None: + p = AliasPrinter(self) + self.add("{} = {}\n".format(lvalue.name, rvalue.accept(p))) + self.record_name(lvalue.name) + self._vars[-1].append(lvalue.name) + def visit_if_stmt(self, o: IfStmt) -> None: # Ignore if __name__ == '__main__'. expr = o.expr[0] @@ -428,53 +659,38 @@ def visit_import_all(self, o: ImportAll) -> None: def visit_import_from(self, o: ImportFrom) -> None: exported_names = set() # type: Set[str] + self.import_tracker.add_import_from('.' * o.relative + o.id, o.names) + self._vars[-1].extend(alias or name for name, alias in o.names) + for name, alias in o.names: + self.record_name(alias or name) + if self._all_: # Include import froms that import names defined in __all__. names = [name for name, alias in o.names if name in self._all_ and alias is None] exported_names.update(names) - self.import_and_export_names(o.id, o.relative, names) else: # Include import from targets that import from a submodule of a package. if o.relative: sub_names = [name for name, alias in o.names if alias is None] exported_names.update(sub_names) - self.import_and_export_names(o.id, o.relative, sub_names) - # Import names used as base classes. - base_names = [(name, alias) for name, alias in o.names - if alias or name in self._base_classes and name not in exported_names] - if base_names: - imp_names = [] # type: List[str] - for name, alias in base_names: - if alias is not None and alias != name: - imp_names.append('%s as %s' % (name, alias)) - else: - imp_names.append(name) - self.add_import_line('from %s%s import %s\n' % ( - '.' * o.relative, o.id, ', '.join(imp_names))) - - def import_and_export_names(self, module_id: str, relative: int, names: Iterable[str]) -> None: - """Import names from a module and export them (via from ... import x as x).""" - if names and module_id: - full_module_name = '%s%s' % ('.' * relative, module_id) - imported_names = ', '.join(['%s as %s' % (name, name) for name in names]) - self.add_import_line('from %s import %s\n' % (full_module_name, imported_names)) - for name in names: - self.record_name(name) + if o.id: + for name in sub_names: + self.import_tracker.require_name(name) def visit_import(self, o: Import) -> None: for id, as_id in o.ids: + self.import_tracker.add_import(id, as_id) if as_id is None: target_name = id.split('.')[0] else: target_name = as_id - if self._all_ and target_name in self._all_ and (as_id is not None or - '.' not in id): - self.add_import_line('import %s as %s\n' % (id, target_name)) - self.record_name(target_name) + self._vars[-1].append(target_name) + self.record_name(target_name) - def get_init(self, lvalue: str, rvalue: Expression) -> Optional[str]: + def get_init(self, lvalue: str, rvalue: Expression, + annotation: Optional[Type] = None) -> Optional[str]: """Return initializer for a variable. Return None if we've generated one already or if the variable is internal. @@ -486,8 +702,13 @@ def get_init(self, lvalue: str, rvalue: Expression) -> Optional[str]: if self.is_private_name(lvalue) or self.is_not_in_all(lvalue): return None self._vars[-1].append(lvalue) - typename = self.get_str_type_of_node(rvalue) - return '%s%s = ... # type: %s\n' % (self._indent, lvalue, typename) + if annotation is not None: + typename = self.print_annotation(annotation) + else: + typename = self.get_str_type_of_node(rvalue) + has_rhs = not (isinstance(rvalue, TempNode) and rvalue.no_rhs) + initializer = " = ..." if has_rhs and not self.is_top_level() else "" + return '%s%s: %s%s\n' % (self._indent, lvalue, typename, initializer) def add(self, string: str) -> None: """Add text to generated stub.""" @@ -498,8 +719,7 @@ def add_typing_import(self, name: str) -> None: The import will be internal to the stub. """ - if name not in self._imports: - self._imports.append(name) + self.import_tracker.require_name(name) def add_import_line(self, line: str) -> None: """Add a line of text to the import section, unless it's already there.""" @@ -509,10 +729,9 @@ def add_import_line(self, line: str) -> None: def output(self) -> str: """Return the text for the stub.""" imports = '' - if self._imports: - imports += 'from typing import %s\n' % ", ".join(sorted(self._imports)) if self._import_lines: imports += ''.join(self._import_lines) + imports += ''.join(self.import_tracker.import_lines()) if imports and self._output: imports += '\n' return imports + ''.join(self._output) @@ -559,6 +778,10 @@ def get_str_type_of_node(self, rvalue: Expression, self.add_typing_import('Any') return 'Any' + def print_annotation(self, t: Type) -> str: + printer = AnnotationPrinter(self) + return t.accept(printer) + def is_top_level(self) -> bool: """Are we processing the top level of a file?""" return self._indent == '' @@ -591,17 +814,6 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: return results -def find_classes(node: MypyFile) -> Set[str]: - results = set() # type: Set[str] - - class ClassTraverser(mypy.traverser.TraverserVisitor): - def visit_class_def(self, o: ClassDef) -> None: - results.add(o.name) - - node.accept(ClassTraverser()) - return results - - def get_qualified_name(o: Expression) -> str: if isinstance(o, NameExpr): return o.name diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 2795bb34538a..6bb07249271e 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -57,6 +57,21 @@ from typing import Any def f(x: Any = ...): ... +[case testPreserveFunctionAnnotation] +def f(x: Foo) -> Bar: ... +[out] +def f(x: Foo) -> Bar: ... + +[case testPreserveVarAnnotation] +x: Foo +[out] +x: Foo + +[case testPreserveVarAnnotationWithoutQuotes] +x: 'Foo' +[out] +x: Foo + [case testVarArgs] def f(x, *y): ... [out] @@ -81,20 +96,45 @@ def g(): ... [case testVariable] x = 1 [out] -x = ... # type: int +x: int + +[case testAnnotatedVariable] +x: int = 1 +[out] +x: int + +[case testAnnotatedVariableGeneric] +x: Foo[int, str] = ... +[out] +x: Foo[int, str] + +[case testAnnotatedVariableOldSyntax] +x = 1 # type: int +[out] +x: int + +[case testAnnotatedVariableNone] +x: None +[out] +x: None + +[case testAnnotatedVariableNoneOldSyntax] +x = None # type: None +[out] +x: None [case testMultipleVariable] x = y = 1 [out] -x = ... # type: int -y = ... # type: int +x: int +y: int [case testClassVariable] class C: x = 1 [out] class C: - x = ... # type: int + x: int = ... [case testSelfAssignment] class C: @@ -103,7 +143,7 @@ class C: x.y = 2 [out] class C: - x = ... # type: int + x: int = ... def __init__(self) -> None: ... [case testSelfAndClassBodyAssignment] @@ -114,10 +154,10 @@ class C: self.x = 1 self.x = 1 [out] -x = ... # type: int +x: int class C: - x = ... # type: int + x: int = ... def __init__(self) -> None: ... [case testEmptyClass] @@ -166,10 +206,10 @@ _x = 1 class A: _y = 1 [out] -_x = ... # type: int +_x: int class A: - _y = ... # type: int + _y: int = ... [case testSpecialInternalVar] __all__ = [] @@ -195,16 +235,22 @@ x, y = 1, 2 [out] from typing import Any -x = ... # type: Any -y = ... # type: Any +x: Any +y: Any + +[case testMultipleAssignmentAnnotated] +x, y = 1, "2" # type: int, str +[out] +x: int +y: str [case testMultipleAssignment2] [x, y] = 1, 2 [out] from typing import Any -x = ... # type: Any -y = ... # type: Any +x: Any +y: Any [case testKeywordOnlyArg] def f(x, *, y=1): ... @@ -333,7 +379,7 @@ class A: def f(self): ... [out] class A: - x = ... # type: int + x: int = ... def f(self): ... [case testSkipMultiplePrivateDefs] @@ -356,9 +402,9 @@ class C: ... [out] class A: ... -_x = ... # type: int -_y = ... # type: int -_z = ... # type: int +_x: int +_y: int +_z: int class C: ... @@ -369,7 +415,17 @@ x = 1 [out] from re import match as match, sub as sub -x = ... # type: int +x: int + +[case testExportModule_import] +import re +__all__ = ['re', 'x'] +x = 1 +y = 2 +[out] +import re as re + +x: int [case testExportModule_import] import re @@ -379,7 +435,7 @@ y = 2 [out] import re as re -x = ... # type: int +x: int [case testExportModuleAs_import] import re as rex @@ -389,7 +445,7 @@ y = 2 [out] import re as rex -x = ... # type: int +x: int [case testExportModuleInPackage_import] import urllib.parse as p @@ -397,12 +453,12 @@ __all__ = ['p'] [out] import urllib.parse as p -[case testExportModuleInPackageUnsupported_import] +[case testExportPackageOfAModule_import] import urllib.parse __all__ = ['urllib'] + [out] -# Names in __all__ with no definition: -# urllib +import urllib as urllib [case testRelativeImportAll] from .x import * @@ -418,7 +474,7 @@ class C: [out] def f(): ... -x = ... # type: int +x: int class C: def g(self): ... @@ -494,7 +550,6 @@ def f(): ... X = _namedtuple('X', 'a b') def g(): ... [out] -from collections import namedtuple as _namedtuple from collections import namedtuple def f(): ... @@ -551,6 +606,11 @@ from x import X as _X class A(_X): ... +[case testGenericClass] +class D(Generic[T]): ... +[out] +class D(Generic[T]): ... + [case testObjectBaseClass] class A(object): ... [out] @@ -577,19 +637,19 @@ class A: [out] class A: class B: - x = ... # type: int + x: int = ... def f(self): ... def g(self): ... [case testExportViaRelativeImport] from .api import get [out] -from .api import get as get +from .api import get [case testExportViaRelativePackageImport] from .packages.urllib3.contrib import parse [out] -from .packages.urllib3.contrib import parse as parse +from .packages.urllib3.contrib import parse [case testNoExportViaRelativeImport] from . import get @@ -600,7 +660,7 @@ from .x import X class A(X): pass [out] -from .x import X as X +from .x import X class A(X): ... @@ -619,14 +679,163 @@ def f(a): ... [case testInferOptionalOnlyFunc] class A: x = None - def __init__(self, a=None) -> None: + def __init__(self, a=None): + self.x = [] + def method(self, a=None): self.x = [] [out] from typing import Any, Optional class A: - x = ... # type: Any + x: Any = ... def __init__(self, a: Optional[Any] = ...) -> None: ... + def method(self, a: Optional[Any] = ...): ... + +[case testAnnotationImportsFrom] +import foo +from collection import defaultdict +x: defaultdict + +[out] +from collection import defaultdict + +x: defaultdict + +[case testAnnotationImports] +import foo +import collection +x: collection.defaultdict + +[out] +import collection + +x: collection.defaultdict + + +[case testAnnotationImports] +from typing import List +import collection +x: List[collection.defaultdict] + +[out] +import collection +from typing import List + +x: List[collection.defaultdict] + + +[case testAnnotationFwRefs] +x: C + +class C: + attr: C + +y: C +[out] +x: C + +class C: + attr: C + +y: C + +[case testTypeVarPreserved] +tv = TypeVar('tv') + +[out] +from typing import TypeVar + +tv = TypeVar('tv') + +[case testTypeVarArgsPreserved] +tv = TypeVar('tv', int, str) + +[out] +from typing import TypeVar + +tv = TypeVar('tv', int, str) + +[case testTypeVarNamedArgsPreserved] +tv = TypeVar('tv', bound=bool, covariant=True) + +[out] +from typing import TypeVar + +tv = TypeVar('tv', bound=bool, covariant=True) + +[case testTypeAliasPreserved] +alias = str + +[out] +alias = str + +[case testDeepTypeAliasPreserved] + +alias = Dict[str, List[str]] + +[out] +alias = Dict[str, List[str]] + +[case testDeepGenericTypeAliasPreserved] +from typing import TypeVar + +T = TypeVar('T') +alias = Union[T, List[T]] + +[out] +from typing import TypeVar + +T = TypeVar('T') +alias = Union[T, List[T]] + +[case testEllipsisAliasPreserved] + +alias = Tuple[int, ...] + +[out] +alias = Tuple[int, ...] + +[case testCallableAliasPreserved] + +alias1 = Callable[..., int] +alias2 = Callable[[str, bool], None] + +[out] +alias1 = Callable[..., int] +alias2 = Callable[[str, bool], None] + +[case testAliasPullsImport] +from module import Container + +alias = Container[Any] + +[out] +from module import Container +from typing import Any + +alias = Container[Any] + +[case testAliasOnlyToplevel] +class Foo: + alias = str + +[out] +from typing import Any + +class Foo: + alias: Any = ... + +[case testAliasExceptions] +noalias1 = None +noalias2 = ... +noalias3 = True + +[out] +from typing import Any + +noalias1: Any +noalias2: Any +noalias3: bool -- More features/fixes: -- do not export deleted names