Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
stubgen: multiple fixes to the generated imports
Browse files Browse the repository at this point in the history
* Fix handling of nested imports
  Instead of assuming that a name is imported from a top level package,
  look in the imports for this name starting from the parent submodule
  up until the import is found
* Fix "from imports" getting rexported unnecessarily
* Fix import sorting when having import aliases

Fixes python#13661
Fixes python#7006
hamdanal committed Jul 8, 2023
1 parent 6cd8c00 commit ad37a6d
Showing 2 changed files with 74 additions and 10 deletions.
24 changes: 17 additions & 7 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
@@ -492,7 +492,9 @@ def add_import(self, module: str, alias: str | None = None) -> None:
name = name.rpartition(".")[0]

def require_name(self, name: str) -> None:
self.required_names.add(name.split(".")[0])
while name not in self.direct_imports and "." in name:
name = name.rsplit(".", 1)[0]
self.required_names.add(name)

def reexport(self, name: str) -> None:
"""Mark a given non qualified name as needed in __all__.
@@ -512,7 +514,10 @@ def import_lines(self) -> list[str]:
# be imported from it. the names can also be alias in the form 'original as alias'
module_map: Mapping[str, list[str]] = defaultdict(list)

for name in sorted(self.required_names):
for name in sorted(
self.required_names,
key=lambda n: (self.reverse_alias[n], n) if n in self.reverse_alias else (n, ""),
):
# If we haven't seen this name in an import statement, ignore it
if name not in self.module_for:
continue
@@ -536,7 +541,7 @@ def import_lines(self) -> list[str]:
assert "." not in name # Because reexports only has nonqualified names
result.append(f"import {name} as {name}\n")
else:
result.append(f"import {self.direct_imports[name]}\n")
result.append(f"import {name}\n")

# Now generate all the from ... import ... lines collected in module_map
for module, names in sorted(module_map.items()):
@@ -591,7 +596,7 @@ def visit_name_expr(self, e: NameExpr) -> None:
self.refs.add(e.name)

def visit_instance(self, t: Instance) -> None:
self.add_ref(t.type.fullname)
self.add_ref(t.type.name)
super().visit_instance(t)

def visit_unbound_type(self, t: UnboundType) -> None:
@@ -610,7 +615,10 @@ def visit_callable_type(self, t: CallableType) -> None:
t.ret_type.accept(self)

def add_ref(self, fullname: str) -> None:
self.refs.add(fullname.split(".")[-1])
self.refs.add(fullname)
while "." in fullname:
fullname = fullname.rsplit(".", 1)[0]
self.refs.add(fullname)


class StubGenerator(mypy.traverser.TraverserVisitor):
@@ -1250,6 +1258,7 @@ def visit_import_from(self, o: ImportFrom) -> None:
if (
as_name is None
and name not in self.referenced_names
and not any(n.startswith(name + ".") for n in self.referenced_names)
and (not self._all_ or name in IGNORED_DUNDERS)
and not is_private
and module not in ("abc", "asyncio") + TYPING_MODULE_NAMES
@@ -1258,14 +1267,15 @@ def visit_import_from(self, o: ImportFrom) -> None:
# exported, unless there is an explicit __all__. Note that we need to special
# case 'abc' since some references are deleted during semantic analysis.
exported = True
top_level = full_module.split(".")[0]
top_level = full_module.split(".", 1)[0]
self_top_level = self.module.split(".", 1)[0]
if (
as_name is None
and not self.export_less
and (not self._all_ or name in IGNORED_DUNDERS)
and self.module
and not is_private
and top_level in (self.module.split(".")[0], "_" + self.module.split(".")[0])
and top_level in (self_top_level, "_" + self_top_level)
):
# Export imports from the same package, since we can't reliably tell whether they
# are part of the public API.
60 changes: 57 additions & 3 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
@@ -2656,9 +2656,9 @@ y: b.Y
z: p.a.X

[out]
import p.a
import p.a as a
import p.b as b
import p.a

x: a.X
y: b.Y
@@ -2671,7 +2671,7 @@ from p import a
x: a.X

[out]
from p import a as a
from p import a

x: a.X

@@ -2693,7 +2693,7 @@ from p import a
x: a.X

[out]
from p import a as a
from p import a

x: a.X

@@ -2743,6 +2743,60 @@ import p.a
x: a.X
y: p.a.Y

[case testNestedImports]
import p
import p.m1
import p.m2

x: p.X
y: p.m1.Y
z: p.m2.Z

[out]
import p
import p.m1
import p.m2

x: p.X
y: p.m1.Y
z: p.m2.Z

[case testNestedImportsAliased]
import p as t
import p.m1 as pm1
import p.m2 as pm2

x: t.X
y: pm1.Y
z: pm2.Z

[out]
import p as t
import p.m1 as pm1
import p.m2 as pm2

x: t.X
y: pm1.Y
z: pm2.Z

[case testNestedFromImports]
from p import m1
from p.m1 import sm1
from p.m2 import sm2

x: m1.X
y: sm1.Y
z: sm2.Z

[out]
from p import m1
from p.m1 import sm1
from p.m2 import sm2

x: m1.X
y: sm1.Y
z: sm2.Z

[case testOverload_fromTypingImport]
from typing import Tuple, Union, overload

0 comments on commit ad37a6d

Please sign in to comment.