Skip to content

Commit

Permalink
Add visitor to undo AddNamePrefix.
Browse files Browse the repository at this point in the history
Adds a new visitor, `visitor.RemoveNamePrefix`, that performs
the inverse operation of `visitor.AddNamePrefix` as much as is
possible.

Some conversions performed by `AddNamePrefix` are lossy; we choose
a consistent inverse operation in those cases.

Adds tests for the same.

#tftypes

PiperOrigin-RevId: 574974230
  • Loading branch information
Spyboticsguy authored and rchen152 committed Oct 20, 2023
1 parent 74f68cb commit 88d1ae5
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 1 deletion.
114 changes: 113 additions & 1 deletion pytype/pytd/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
import logging
import re
from typing import Callable, List, Optional, Set, cast
from typing import Callable, List, Optional, Set, TypeVar, cast

from pytype import datatypes
from pytype import module_utils
Expand All @@ -19,6 +19,10 @@
from pytype.pytd.parse import parser_constants # pylint: disable=g-importing-member


_N = TypeVar("_N", bound=pytd.Node)
_T = TypeVar("_T", bound=pytd.Type)


class ContainerError(Exception):
pass

Expand Down Expand Up @@ -1332,6 +1336,114 @@ def VisitModule(self, node):
return self._VisitNamedNode(node)


class RemoveNamePrefix(Visitor):
"""Visitor which removes the fully-qualified-names added by AddNamePrefix."""

def __init__(self):
super().__init__()
self.cls_stack: list[pytd.Class] = []
self.classes: set[str] = set()
self.prefix = None
self.name = None

def _removeprefix(self, s: str, prefix: str) -> str:
"""Removes the given prefix from the string, if present."""
if not s.startswith(prefix):
return s
return s[len(prefix) :]

def _SuperClassString(self) -> str:
classes = ".".join(
cls.name.rsplit(".", 1)[-1] for cls in self.cls_stack[:-1]
)
if classes:
classes = classes + "."
return self.prefix + classes

def EnterTypeDeclUnit(self, node: pytd.TypeDeclUnit) -> None:
self.name = node.name
self.prefix = node.name + "."
self.classes = {
self._removeprefix(cls.name, self._SuperClassString())
for cls in node.classes
}

def EnterClass(self, cls: pytd.Class) -> None:
self.cls_stack.append(cls)

def LeaveClass(self, cls: pytd.Class) -> None:
assert self.cls_stack[-1] is cls
self.cls_stack.pop()

def VisitClassType(self, node: pytd.ClassType) -> pytd.ClassType:
if node.cls is not None:
raise ValueError("RemoveNamePrefix visitor called after resolving")
return self._VisitType(node)

def VisitLateType(self, node: pytd.LateType) -> pytd.LateType:
return self._VisitType(node)

def VisitNamedType(self, node: pytd.NamedType) -> pytd.NamedType:
return self._VisitType(node)

def _VisitType(self, node: _T) -> _T:
"""Unprefix a pytd.Type."""
if not node.name:
return node
name = self._removeprefix(node.name, self.prefix)
if name.split(".")[0] in self.classes:
# We need to check just the first part, in case we have a class constant
# like Foo.BAR, or some similarly nested name.
return node.Replace(name=name)
if self.cls_stack:
name = self._removeprefix(node.name, self._SuperClassString())
if name == self.cls_stack[-1].name:
# We're referencing a class from within itself.
return node.Replace(name=name)
elif "." in name:
prefix = name.rsplit(".", 1)[0]
if prefix == self.cls_stack[-1].name:
# The parser leaves aliases to nested classes as
# ImmediateOuter.Nested, so we need to preserve the outer class.
return node.Replace(name=name)
return node

def VisitClass(self, node: pytd.Class) -> pytd.Class:
name = self._removeprefix(node.name, self._SuperClassString())
return node.Replace(name=name)

def VisitTypeParameter(self, node: pytd.TypeParameter) -> pytd.TypeParameter:
if not node.scope:
return node
# If the type parameter's scope was the module name, set it back to its
# original value of None.
if node.scope == self.name:
return node.Replace(scope=None)
scope = self._removeprefix(node.scope, self.prefix)
return node.Replace(scope=scope)

def _VisitNamedNode(self, node: _N) -> _N:
if self.cls_stack:
return node
else:
# global constant. Rename to its relative name.
return node.Replace(
name=module_utils.get_relative_name(self.name, node.name)
)

def VisitFunction(self, node: pytd.Function) -> pytd.Function:
return self._VisitNamedNode(node)

def VisitConstant(self, node: pytd.Constant) -> pytd.Constant:
return self._VisitNamedNode(node)

def VisitAlias(self, node: pytd.Alias) -> pytd.Alias:
return self._VisitNamedNode(node)

def VisitModule(self, node: pytd.Module) -> pytd.Module:
return self._VisitNamedNode(node)


class CollectDependencies(Visitor):
"""Visitor for retrieving module names from external types.
Expand Down
154 changes: 154 additions & 0 deletions pytype/pytd/visitors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,160 @@ def f(
""").strip())


class RemoveNamePrefixTest(parser_test_base.ParserTest):
"""Tests for RemoveNamePrefix."""

def test_remove_name_prefix(self):
src = textwrap.dedent("""
from typing import TypeVar
def f(a: T) -> T: ...
T = TypeVar("T")
class X(Generic[T]):
pass
""")
expected = textwrap.dedent("""
from typing import TypeVar
T = TypeVar('T')
class X(Generic[T]): ...
def f(a: T) -> T: ...
""").strip()
tree = self.Parse(src)

# type parameters
t = tree.Lookup("T").Replace(scope="foo")

# classes
x = tree.Lookup("X")
x_template = x.template[0]
x_type_param = x_template.type_param.Replace(scope="foo.X")
x_template = x_template.Replace(type_param=x_type_param)
x = x.Replace(name="foo.X", template=(x_template,))

# functions
f = tree.Lookup("f")
f_sig = f.signatures[0]
f_param = f_sig.params[0]
f_type_param = f_param.type.Replace(scope="foo.f")
f_param = f_param.Replace(type=f_type_param)
f_template = f_sig.template[0].Replace(type_param=f_type_param)
f_sig = f_sig.Replace(
params=(f_param,), return_type=f_type_param, template=(f_template,)
)
f = f.Replace(name="foo.f", signatures=(f_sig,))

tree = tree.Replace(
classes=(x,), functions=(f,), type_params=(t,), name="foo"
)
tree = tree.Visit(visitors.RemoveNamePrefix())
self.assertMultiLineEqual(expected, pytd_utils.Print(tree))

def test_remove_name_prefix_twice(self):
src = textwrap.dedent("""
from typing import Any, TypeVar
x = ... # type: Any
T = TypeVar("T")
class X(Generic[T]): ...
""")
expected_one = textwrap.dedent("""
from typing import Any, TypeVar
foo.x: Any
T = TypeVar('T')
class foo.X(Generic[T]): ...
""").strip()
expected_two = textwrap.dedent("""
from typing import Any, TypeVar
x: Any
T = TypeVar('T')
class X(Generic[T]): ...
""").strip()
tree = self.Parse(src)

# constants
x = tree.Lookup("x").Replace(name="foo.foo.x")

# type parameters
t = tree.Lookup("T").Replace(scope="foo.foo")

# classes
x_cls = tree.Lookup("X")
x_template = x_cls.template[0]
x_type_param = x_template.type_param.Replace(scope="foo.foo.X")
x_template = x_template.Replace(type_param=x_type_param)
x_cls = x_cls.Replace(name="foo.foo.X", template=(x_template,))

tree = tree.Replace(
classes=(x_cls,), constants=(x,), type_params=(t,), name="foo"
)
tree = tree.Visit(visitors.RemoveNamePrefix())
self.assertMultiLineEqual(expected_one, pytd_utils.Print(tree))
tree = tree.Visit(visitors.RemoveNamePrefix())
self.assertMultiLineEqual(expected_two, pytd_utils.Print(tree))

def test_remove_name_prefix_on_class_type(self):
src = textwrap.dedent("""
x = ... # type: y
class Y: ...
""")
expected = textwrap.dedent("""
x: Y
class Y: ...
""").strip()
tree = self.Parse(src)

# constants
x = tree.Lookup("x").Replace(name="foo.x", type=pytd.ClassType("foo.Y"))

# classes
y = tree.Lookup("Y").Replace(name="foo.Y")

tree = tree.Replace(classes=(y,), constants=(x,), name="foo")
tree = tree.Visit(visitors.RemoveNamePrefix())
self.assertMultiLineEqual(expected, pytd_utils.Print(tree))

def test_remove_name_prefix_on_nested_class(self):
src = textwrap.dedent("""
class A:
class B:
class C: ...
D = A.B.C
""")
expected = textwrap.dedent("""
from typing import Type
class A:
class B:
class C: ...
D: Type[A.B.C]
""").strip()
tree = self.Parse(src)

# classes
a = tree.Lookup("A")
b = a.Lookup("B")
c = b.Lookup("C").Replace(name="foo.A.B.C")
d = b.Lookup("D")
d_type = d.type
d_generic = d.type.parameters[0].Replace(name="foo.A.B.C")
d_type = d_type.Replace(parameters=(d_generic,))
d = d.Replace(type=d_type)
b = b.Replace(classes=(c,), constants=(d,), name="foo.A.B")
a = a.Replace(classes=(b,), name="foo.A")

tree = tree.Replace(classes=(a,), name="foo")
tree = tree.Visit(visitors.RemoveNamePrefix())
self.assertMultiLineEqual(expected, pytd_utils.Print(tree))


class ReplaceModulesWithAnyTest(unittest.TestCase):

def test_any_replacement(self):
Expand Down

0 comments on commit 88d1ae5

Please sign in to comment.