Skip to content

Commit

Permalink
Add some clarifications in complex methods and data structures in Imp…
Browse files Browse the repository at this point in the history
…ortTracker
  • Loading branch information
dmoisset committed Sep 1, 2017
1 parent cb8d946 commit 7302599
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ class AnnotationPrinter(TypeStrVisitor):
def __init__(self, stubgen: 'StubGenerator') -> None:
super().__init__()
self.stubgen = stubgen

def visit_unbound_type(self, t: UnboundType)-> str:
s = t.name
base = s.split('.')[0]
Expand Down Expand Up @@ -300,9 +301,25 @@ def visit_ellipsis(self, node: EllipsisExpr) -> str:
class ImportTracker:

def __init__(self) -> None:
# module_for['foo'] has the module name where 'foo' was imported from, or None if
# 'foo' is a module imported directly; examples
# 'from pkg.m import f as foo' ==> module_for['foo'] == 'pkg.m'
# 'from m import f' ==> module_for['f'] == 'm'
# 'import m' ==> module_for['m'] == None
self.module_for = {} # type: Dict[str, Optional[str]]

# direct_imports['foo'] is the module path used when the name 'foo' was added to the
# namespace.
# import foo.bar.baz ==> direct_imports['foo'] == 'foo.bar.baz'
self.direct_imports = {} # type: Dict[str, str]

# reverse_alias['foo'] is the name that 'foo' had originally when imported with an
# alias; examples
# 'import numpy as np' ==> reverse_alias['np'] == 'numpy'
# 'from decimal import Decimal as D' ==> reverse_alias['D'] == 'Decimal'
self.reverse_alias = {} # type: Dict[str, str]

# required_names is the set of names that are actually used in a type annotation
self.required_names = set() # type: Set[str]

def add_import_from(self, module: str, names: List[Tuple[str, Optional[str]]]) -> None:
Expand All @@ -322,22 +339,38 @@ def require_name(self, name: str) -> None:
self.required_names.add(name.split('.')[0])

def import_lines(self) -> List[str]:
"""
The list of required import lines (as strings with python code)
"""
result = []

# To summarize multiple names imported from a same module, we collect those
# in the `module_map` dictionary, mapping a module path to the list of names that should
# be imported from it. the names can also be alias in the form 'original as alias'
module_map = defaultdict(list) # type: Mapping[str, List[str]]

for name in sorted(self.required_names):
# If we haven't seen this name in an import statement, ignore it
if name not in self.module_for:
continue

m = self.module_for[name]
if m is not None:
# This name was found in a from ... import ...
# Collect the name in the module_map
if name in self.reverse_alias:
name = '{} as {}'.format(self.reverse_alias[name], name)
module_map[m].append(name)
else:
# This name was found in an import ...
# We can already generate the import line
if name in self.reverse_alias:
name, alias = self.reverse_alias[name], name
result.append("import {} as {}\n".format(self.direct_imports[name], alias))
else:
result.append("import {}\n".format(self.direct_imports[name]))

# Now generate all the from ... import ... lines collected in module_map
for module, names in sorted(module_map.items()):
result.append("from {} import {}\n".format(module, ', '.join(sorted(names))))
return result
Expand Down

0 comments on commit 7302599

Please sign in to comment.