
Stubgen generates invalid stubs for code with import chain #6831

Open
matangover opened this issue May 15, 2019 · 2 comments

@matangover
Contributor

matangover commented May 15, 2019

I encountered many issues while trying to generate stubs for TensorFlow (a lost cause, I guess, but I had to try). One of them is reproduced by the following toy example.

m1.py:

x = 123

m2.py:

from m1 import x
__all__ = []

m3.py:

from m2 import x
__all__ = ['x']
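For reference, the import chain itself works at runtime: from m2 import x succeeds even though m2's __all__ is empty, because __all__ only affects star imports. The sketch below writes the three files to a temporary directory and imports m3 to confirm this; the helper names (import_chain_value, _forget_modules) are just for illustration.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# The toy modules from the issue, as source text.
SOURCES = {
    "m1": "x = 123\n",
    "m2": "from m1 import x\n__all__ = []\n",
    "m3": "from m2 import x\n__all__ = ['x']\n",
}


def _forget_modules() -> None:
    """Drop any cached m1/m2/m3 so the next import reads the files afresh."""
    for name in SOURCES:
        sys.modules.pop(name, None)


def import_chain_value() -> int:
    """Write m1/m2/m3 to a temporary directory, import m3 and return m3.x."""
    with tempfile.TemporaryDirectory() as tmp:
        for name, source in SOURCES.items():
            Path(tmp, f"{name}.py").write_text(source)
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        _forget_modules()
        try:
            m3 = importlib.import_module("m3")
            # __all__ = [] in m2 does not block the explicit "from m2 import x".
            return m3.x
        finally:
            sys.path.remove(tmp)
            _forget_modules()


print(import_chain_value())  # 123
```

So the problem is only in the generated stubs, not in the code being stubbed.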

Generated stubs:

m1.pyi

x: int

m2.pyi (empty)

m3.pyi

from m2 import x as x
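To make the failure concrete, here is a rough checker that flags stub imports which don't resolve against the imported stub. It is a simplification built on Python's ast module: it assumes the stub rule that a plain from-import is not re-exported unless written in the "import x as x" form, and it ignores __all__, star imports and conditional definitions. exported_names and unresolved_imports are illustrative names, not part of mypy.

```python
import ast
from typing import Dict, List, Set

# The generated stubs from the issue (m2.pyi is empty).
STUBS: Dict[str, str] = {
    "m1": "x: int\n",
    "m2": "",
    "m3": "from m2 import x as x\n",
}


def exported_names(stub_source: str) -> Set[str]:
    """Approximate the top-level names a stub makes importable (simplified)."""
    names: Set[str] = set()
    for node in ast.parse(stub_source).body:
        if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
            names.add(node.target.id)
        elif isinstance(node, ast.Assign):
            names.update(t.id for t in node.targets if isinstance(t, ast.Name))
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            names.add(node.name)
        elif isinstance(node, ast.ImportFrom):
            # In stubs, only the redundant "import a as a" form re-exports a name.
            names.update(a.name for a in node.names if a.asname == a.name)
    return names


def unresolved_imports(stubs: Dict[str, str]) -> List[str]:
    """List 'importer: module.name' for from-imports between stubs that don't resolve."""
    problems: List[str] = []
    for module, source in stubs.items():
        for node in ast.parse(source).body:
            if isinstance(node, ast.ImportFrom) and node.module in stubs:
                available = exported_names(stubs[node.module])
                for alias in node.names:
                    if alias.name not in available:
                        problems.append(f"{module}: {node.module}.{alias.name}")
    return problems


print(unresolved_imports(STUBS))  # ['m3: m2.x']
```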

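The m3.pyi stub is invalid: m2.pyi is empty (x is not re-exported, since m2's __all__ is empty), so there is nothing for from m2 import x to import. Probably the most 'right' thing to do would be to have m3.pyi say: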

from m1 import x as x

But this doesn't seem easy with the current stubgen architecture.
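One way to picture that "right" stub is to follow the from-imports back to the module that actually binds the name. The sketch below does this on the toy sources with Python's ast module rather than with mypy's symbol tables. It only handles absolute from-imports (no relative or star imports), assumes that a name not bound by a from-import is defined in the module where the search stops, and the helpers find_origin and reexport_line are made up for illustration.

```python
import ast
from typing import Dict, Optional, Set, Tuple

# The toy modules from the issue, as source text.
SOURCES: Dict[str, str] = {
    "m1": "x = 123\n",
    "m2": "from m1 import x\n__all__ = []\n",
    "m3": "from m2 import x\n__all__ = ['x']\n",
}


def find_origin(sources: Dict[str, str], module: str, name: str) -> Optional[Tuple[str, str]]:
    """Follow absolute 'from X import Y [as Z]' chains back to where `name` is bound.

    Returns (defining_module, original_name), or None if the chain leaves
    `sources` or loops back on itself.
    """
    seen: Set[Tuple[str, str]] = set()
    while (module, name) not in seen:
        seen.add((module, name))
        if module not in sources:
            return None
        for node in ast.parse(sources[module]).body:
            if isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
                match = next((a for a in node.names if (a.asname or a.name) == name), None)
                if match is not None:
                    # `name` is re-bound from another module: follow it there.
                    module, name = node.module, match.name
                    break
        else:
            # No from-import binds `name` here, so treat this module as its origin.
            return module, name
    return None  # import cycle


def reexport_line(sources: Dict[str, str], module: str, name: str) -> Optional[str]:
    """Build the stub line that re-exports `name` directly from its defining module."""
    origin = find_origin(sources, module, name)
    if origin is None:
        return None
    origin_module, original_name = origin
    return f"from {origin_module} import {original_name} as {name}"


print(reexport_line(SOURCES, "m3", "x"))  # from m1 import x as x
```

Doing the same inside stubgen needs cross-module information, which is roughly what the patch below passes in as all_modules.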

@matangover
Contributor Author

Maybe something like this could work (it fixes the toy example; I'm not sure about real-world code).

diff --git a/mypy/stubgen.py b/mypy/stubgen.py
index bfe4dea0..a659a1dd 100755
--- a/mypy/stubgen.py
+++ b/mypy/stubgen.py
@@ -334,7 +334,8 @@ class ImportTracker:
 
 class StubGenerator(mypy.traverser.TraverserVisitor):
     def __init__(self, _all_: Optional[List[str]], pyversion: Tuple[int, int],
-                 include_private: bool = False, analyzed: bool = False) -> None:
+                 include_private: bool = False, analyzed: bool = False,
+                 all_modules: Optional[Dict[str, MypyFile]] = None) -> None:
         # Best known value of __all__.
         self._all_ = _all_
         self._output = []  # type: List[str]
@@ -358,6 +359,7 @@ class StubGenerator(mypy.traverser.TraverserVisitor):
         # Names in __all__ are required
         for name in _all_ or ():
             self.import_tracker.reexport(name)
+        self._all_modules = all_modules or {}
 
     def visit_mypy_file(self, o: MypyFile) -> None:
         super().visit_mypy_file(o)
@@ -670,7 +672,19 @@ class StubGenerator(mypy.traverser.TraverserVisitor):
 
     def visit_import_from(self, o: ImportFrom) -> None:
         exported_names = set()  # type: Set[str]
-        self.import_tracker.add_import_from('.' * o.relative + o.id, o.names)
+        module = self._all_modules.get(o.id)  # TODO: Fix relative import.
+        if module:
+            for imported_name in o.names:
+                imported = module.names.get(imported_name[0])
+                if imported and imported.fullname:
+                    name_parts = imported.fullname.split('.')
+                    containing_module = '.'.join(name_parts[:-1])
+                    original_name = name_parts[-1]
+                    self.import_tracker.add_import_from(containing_module, [(original_name, imported_name[1])])
+                else:
+                    self.import_tracker.add_import_from('.' * o.relative + o.id, [imported_name])
+        else:
+            self.import_tracker.add_import_from('.' * o.relative + o.id, o.names)
         self._vars[-1].extend(alias or name for name, alias in o.names)
         for name, alias in o.names:
             self.record_name(alias or name)
@@ -1026,7 +1040,8 @@ def generate_stub_from_ast(mod: StubSource,
                            parse_only: bool = False,
                            pyversion: Tuple[int, int] = defaults.PYTHON3_VERSION,
                            include_private: bool = False,
-                           add_header: bool = True) -> None:
+                           add_header: bool = True,
+                           all_modules: Optional[Dict[str, MypyFile]] = None) -> None:
     """Use analysed (or just parsed) AST to generate type stub for single file.
 
     If directory for target doesn't exist it will created. Existing stub
@@ -1035,7 +1050,8 @@ def generate_stub_from_ast(mod: StubSource,
     gen = StubGenerator(mod.runtime_all,
                         pyversion=pyversion,
                         include_private=include_private,
-                        analyzed=not parse_only)
+                        analyzed=not parse_only,
+                        all_modules=all_modules)
     assert mod.ast is not None, "This function must be used only with analyzed modules"
     mod.ast.accept(gen)
 
@@ -1081,6 +1097,7 @@ def generate_stubs(options: Options,
 
     # Use parsed sources to generate stubs for Python modules.
     generate_asts_for_modules(py_modules, options.parse_only, mypy_opts)
+    py_modules_by_fullname = {mod.ast.fullname(): mod.ast for mod in py_modules if mod.ast}
     for mod in py_modules:
         assert mod.path is not None, "Not found module was not skipped"
         target = mod.module.replace('.', '/')
@@ -1092,7 +1109,8 @@ def generate_stubs(options: Options,
         with generate_guarded(mod.module, target, options.ignore_errors, quiet):
             generate_stub_from_ast(mod, target,
                                    options.parse_only, options.pyversion,
-                                   options.include_private, add_header)
+                                   options.include_private, add_header,
+                                   py_modules_by_fullname)
 
     # Separately analyse C modules using different logic.
     for mod in c_modules:
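The core of the patch is the step in visit_import_from that splits a symbol's fullname into its containing module and original name. Below is a standalone sketch of just that step. FULLNAMES is a hypothetical stand-in for mypy's per-module symbol tables (the patch reads module.names[...].fullname), and rewrite_import_from is an illustrative name, not a stubgen function. Unlike the patch as written, it fills in an alias when the original name differs from the imported one: as far as I can tell, passing imported_name[1] through unchanged would bind x instead of y if m2 had done "from m1 import x as y". Like the patch, it assumes the part of the fullname before the last dot is a module.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical stand-in for mypy's symbol tables: for each module, map a
# top-level name to the fully qualified name of the symbol it refers to.
FULLNAMES: Dict[str, Dict[str, str]] = {
    "m1": {"x": "m1.x"},
    "m2": {"x": "m1.x", "y": "m1.x"},  # "y" as if m2 did "from m1 import x as y"
}


def rewrite_import_from(
    module: str, names: List[Tuple[str, Optional[str]]]
) -> List[Tuple[str, str, Optional[str]]]:
    """Point each (name, alias) imported from `module` at its defining module.

    Returns (containing_module, original_name, alias) triples.
    """
    result: List[Tuple[str, str, Optional[str]]] = []
    for name, alias in names:
        fullname = FULLNAMES.get(module, {}).get(name)
        if fullname is None:
            # Unknown symbol: keep the import exactly as written.
            result.append((module, name, alias))
            continue
        containing_module, _, original_name = fullname.rpartition(".")
        # Keep the name that the importing module actually sees: if the
        # original name differs from the imported one, an alias is needed.
        if alias is None and original_name != name:
            alias = name
        result.append((containing_module, original_name, alias))
    return result


print(rewrite_import_from("m2", [("x", None), ("y", None)]))
# [('m1', 'x', None), ('m1', 'x', 'y')]
```

The second triple corresponds to from m1 import x as y, which keeps y bound in the importing module.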

@ilevkivskyi
Member

TBH, I am not sure what the best fix is here. Also, is this a common pattern?
