Skip to content

Commit

Permalink
Copy some typeshed changes to builtins.pytd and fix a load_pytd bug.
Browse files Browse the repository at this point in the history
We need to make classmethod and staticmethod generic to unblock
python/typeshed#5703.

Running pytype over that PR also exposes a circular dependency bug involving
star imports, which I've fixed.

PiperOrigin-RevId: 420213268
  • Loading branch information
rchen152 committed Jan 7, 2022
1 parent 4f2f2a0 commit 6aaf955
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 7 deletions.
15 changes: 15 additions & 0 deletions pytype/load_pytd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,21 @@ def test_module_alias(self):
""").strip()
self.assertMultiLineEqual(pytd_utils.Print(ast), expected)

def test_star_import_in_circular_dep(self):
stub3_ast = self._import(stub1="""
from stub2 import Foo
from typing import Mapping as Mapping
""", stub2="""
from stub3 import Mapping
class Foo: ...
""", stub3="""
from stub1 import *
""")
self.assertEqual(stub3_ast.Lookup("stub3.Foo").type,
pytd.ClassType("stub2.Foo"))
self.assertEqual(stub3_ast.Lookup("stub3.Mapping").type,
pytd.ClassType("typing.Mapping"))


class ImportTypeMacroTest(_LoaderTest):

Expand Down
26 changes: 25 additions & 1 deletion pytype/pytd/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,25 @@ def _ResolveUsingGetattr(self, module_name, module):
assert len(g.signatures) == 1
return g.signatures[0].return_type

def _ResolveUsingStarImport(self, module, name):
"""Try to use any star imports in 'module' to resolve 'name'."""
wanted_name = self._ModulePrefix() + name
for alias in module.aliases:
type_name = alias.type.name
if not type_name or not type_name.endswith(".*"):
continue
imported_module = type_name[:-2]
# 'module' contains 'from imported_module import *'. If we can find an AST
# for imported_module, check whether any of the imported names match the
# one we want to resolve.
if imported_module not in self._module_map:
continue
imported_aliases, _ = self._ImportAll(imported_module)
for imported_alias in imported_aliases:
if imported_alias.name == wanted_name:
return imported_alias
return None

def EnterAlias(self, t):
super().EnterAlias(t)
assert not self._alias_name
Expand Down Expand Up @@ -431,7 +450,12 @@ def VisitNamedType(self, t):
except KeyError as e:
item = self._ResolveUsingGetattr(module_name, module)
if item is None:
raise KeyError("No %s in module %s" % (name, module_name)) from e
# If 'module' is involved in a circular dependency, it may contain a
# star import that has not yet been resolved via the usual mechanism, so
# we need to manually resolve it here.
item = self._ResolveUsingStarImport(module, name)
if item is None:
raise KeyError("No %s in module %s" % (name, module_name)) from e
if not self._in_generic_type and isinstance(item, pytd.Alias):
# If `item` contains type parameters and is not inside a GenericType, then
# we replace the parameters with Any.
Expand Down
12 changes: 6 additions & 6 deletions pytype/stubs/builtins/builtins.pytd
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,15 @@ class property(object):
def __delete__(self, *args, **kwargs) -> Any: ...

# staticmethod and classmethod are handled in special_builtins.py.
class staticmethod(typing.Callable):
class staticmethod(typing.Callable, Generic[_T]):
__slots__ = []
def __init__(self, func) -> NoneType: ...
def __get__(self, *args, **kwargs) -> Any: ...
def __init__(self: staticmethod[_T], __f: Callable[..., _T]) -> None: ...
def __get__(self, __obj: _T2, __type: Type[_T2] | None = ...) -> Callable[..., _T]: ...

class classmethod(typing.Callable):
class classmethod(typing.Callable, Generic[_T]):
__slots__ = []
def __init__(self, func) -> NoneType: ...
def __get__(self, *args, **kwargs) -> Any: ...
def __init__(self: classmethod[_T], __f: Callable[..., _T]) -> None: ...
def __get__(self, __obj: _T2, __type: Type[_T2] | None = ...) -> Callable[..., _T]: ...

_T_str = TypeVar('_T_str', bound=str)
class str(Sequence[str], Hashable):
Expand Down

0 comments on commit 6aaf955

Please sign in to comment.