Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define gather global names visitor #657

Merged
merged 2 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libcst/codemod/visitors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from libcst.codemod.visitors._apply_type_annotations import ApplyTypeAnnotationsVisitor
from libcst.codemod.visitors._gather_comments import GatherCommentsVisitor
from libcst.codemod.visitors._gather_exports import GatherExportsVisitor
from libcst.codemod.visitors._gather_global_names import GatherGlobalNamesVisitor
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._gather_string_annotation_names import (
GatherNamesFromStringAnnotationsVisitor,
Expand All @@ -20,6 +21,7 @@
"ApplyTypeAnnotationsVisitor",
"GatherCommentsVisitor",
"GatherExportsVisitor",
"GatherGlobalNamesVisitor",
"GatherImportsVisitor",
"GatherNamesFromStringAnnotationsVisitor",
"GatherUnusedImportsVisitor",
Expand Down
44 changes: 40 additions & 4 deletions libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._add_imports import AddImportsVisitor
from libcst.codemod.visitors._gather_global_names import GatherGlobalNamesVisitor
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.helpers import get_full_name_for_node
from libcst.metadata import PositionProvider, QualifiedNameProvider
Expand Down Expand Up @@ -595,6 +596,11 @@ def __init__(
self.current_assign: Optional[cst.Assign] = None
self.typevars: Dict[str, cst.Assign] = {}

# Global variables and classes defined on the toplevel of the target module.
# Used to help determine which names we need to check are in scope, and add
# quotations to avoid undefined forward references in type annotations.
self.global_names: Set[str] = set()

@staticmethod
def store_stub_in_context(
context: CodemodContext,
Expand Down Expand Up @@ -631,11 +637,19 @@ def transform_module_impl(
Collect type annotations from all stubs and apply them to ``tree``.

Gather existing imports from ``tree`` so that we don't add duplicate imports.

Gather global names from ``tree`` so forward references are quoted.
"""
import_gatherer = GatherImportsVisitor(CodemodContext())
tree.visit(import_gatherer)
existing_import_names = _get_imported_names(import_gatherer.all_imports)

global_names_gatherer = GatherGlobalNamesVisitor(CodemodContext())
tree.visit(global_names_gatherer)
self.global_names = global_names_gatherer.global_names.union(
global_names_gatherer.class_names
)

context_contents = self.context.scratch.get(
ApplyTypeAnnotationsVisitor.CONTEXT_KEY
)
Expand Down Expand Up @@ -677,6 +691,26 @@ def transform_module_impl(
else:
return tree

# helpers for processing annotation nodes
def _quote_future_annotations(self, annotation: cst.Annotation) -> cst.Annotation:
# TODO: We probably want to make sure references to classes defined in the current
# module come to us fully qualified - so we can do the dequalification here and
# know to look for what is in-scope without also catching builtins like "None" in the
# quoting. This should probably also be extended to handle what imports are in scope,
# as well as subscriptable types.
# Note: We are collecting all imports and passing this to the type collector grabbing
# annotations from the stub file; should consolidate import handling somewhere too.
node = annotation.annotation
if (
isinstance(node, cst.Name)
and (node.value in self.global_names)
and not (node.value in self.visited_classes)
):
return annotation.with_changes(
annotation=cst.SimpleString(value=f'"{node.value}"')
)
return annotation

# smart constructors: all applied annotations happen via one of these

def _apply_annotation_to_attribute_or_global(
Expand All @@ -691,7 +725,7 @@ def _apply_annotation_to_attribute_or_global(
self.annotation_counts.attribute_annotations += 1
return cst.AnnAssign(
cst.Name(name),
annotation,
self._quote_future_annotations(annotation),
value,
)

Expand All @@ -702,7 +736,7 @@ def _apply_annotation_to_parameter(
) -> cst.Param:
self.annotation_counts.parameter_annotations += 1
return parameter.with_changes(
annotation=annotation,
annotation=self._quote_future_annotations(annotation),
)

def _apply_annotation_to_return(
Expand All @@ -711,7 +745,9 @@ def _apply_annotation_to_return(
annotation: cst.Annotation,
) -> cst.FunctionDef:
self.annotation_counts.return_annotations += 1
return function_def.with_changes(returns=annotation)
return function_def.with_changes(
returns=self._quote_future_annotations(annotation),
)

# private methods used in the visit and leave methods

Expand Down Expand Up @@ -948,13 +984,13 @@ def visit_ClassDef(
node: cst.ClassDef,
) -> None:
self.qualifier.append(node.name.value)
self.visited_classes.add(node.name.value)

def leave_ClassDef(
self,
original_node: cst.ClassDef,
updated_node: cst.ClassDef,
) -> cst.ClassDef:
self.visited_classes.add(original_node.name.value)
cls_name = ".".join(self.qualifier)
self.qualifier.pop()
definition = self.annotations.class_definitions.get(cls_name)
Expand Down
75 changes: 75 additions & 0 deletions libcst/codemod/visitors/_gather_global_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set

import libcst
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareVisitor


class GatherGlobalNamesVisitor(ContextAwareVisitor):
"""
Gathers all globally accessible names defined in a module and stores them as
attributes on the instance.
Intended to be instantiated and passed to a :class:`~libcst.Module`
:meth:`~libcst.CSTNode.visit` method in order to gather up information about
names defined on a module. Note that this is not a substitute for scope
analysis or qualified name support. Please see :ref:`libcst-scope-tutorial`
for a more robust way of determining the qualified name and definition for
an arbitrary node.
Names that are globally accessible through imports are currently not included
but can be retrieved with GatherImportsVisitor.

After visiting a module the following attributes will be populated:

global_names
A sequence of strings representing global variables defined in the module
toplevel.
class_names
A sequence of strings representing classes defined in the module toplevel.
function_names
A sequence of strings representing functions defined in the module toplevel.

"""

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
self.global_names: Set[str] = set()
self.class_names: Set[str] = set()
self.function_names: Set[str] = set()
# Track scope nesting
self.scope_depth: int = 0

def visit_ClassDef(self, node: libcst.ClassDef) -> None:
if self.scope_depth == 0:
self.class_names.add(node.name.value)
self.scope_depth += 1

def leave_ClassDef(self, original_node: libcst.ClassDef) -> None:
self.scope_depth -= 1

def visit_FunctionDef(self, node: libcst.FunctionDef) -> None:
if self.scope_depth == 0:
self.function_names.add(node.name.value)
self.scope_depth += 1

def leave_FunctionDef(self, original_node: libcst.FunctionDef) -> None:
self.scope_depth -= 1

def visit_Assign(self, node: libcst.Assign) -> None:
if self.scope_depth != 0:
return
for assign_target in node.targets:
target = assign_target.target
if isinstance(target, libcst.Name):
self.global_names.add(target.value)

def visit_AnnAssign(self, node: libcst.AnnAssign) -> None:
if self.scope_depth != 0:
return
target = node.target
if isinstance(target, libcst.Name):
self.global_names.add(target.value)
118 changes: 118 additions & 0 deletions libcst/codemod/visitors/tests/test_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,124 @@ def test_annotate_functions_with_existing_annotations(
overwrite_existing_annotations=True,
)

@data_provider(
{
"return_self": (
"""
class Foo:
def f(self) -> Foo: ...
""",
"""
class Foo:
def f(self):
return self
""",
"""
class Foo:
def f(self) -> "Foo":
return self
""",
),
"return_forward_reference": (
"""
class Foo:
def f(self) -> Bar: ...

class Bar:
...
""",
"""
class Foo:
def f(self):
return Bar()

class Bar:
pass
""",
"""
class Foo:
def f(self) -> "Bar":
return Bar()

class Bar:
pass
""",
),
"return_backward_reference": (
"""
class Bar:
...

class Foo:
def f(self) -> Bar: ...
""",
"""
class Bar:
pass

class Foo:
def f(self):
return Bar()
""",
"""
class Bar:
pass

class Foo:
def f(self) -> Bar:
return Bar()
""",
),
"return_undefined_name": (
"""
class Foo:
def f(self) -> Bar: ...
""",
"""
class Foo:
def f(self):
return self
""",
"""
class Foo:
def f(self) -> Bar:
return self
""",
),
"parameter_forward_reference": (
"""
def f(input: Bar) -> None: ...

class Bar:
...
""",
"""
def f(input):
pass

class Bar:
pass
""",
"""
def f(input: "Bar") -> None:
pass

class Bar:
pass
""",
),
}
)
def test_annotate_with_forward_references(
self, stub: str, before: str, after: str
) -> None:
self.run_test_case_with_flags(
stub=stub,
before=before,
after=after,
overwrite_existing_annotations=True,
)

@data_provider(
{
"fully_annotated_with_untyped_stub": (
Expand Down
54 changes: 54 additions & 0 deletions libcst/codemod/visitors/tests/test_gather_global_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from libcst import parse_module
from libcst.codemod import CodemodContext, CodemodTest
from libcst.codemod.visitors import GatherGlobalNamesVisitor
from libcst.testing.utils import UnitTest


class TestGatherGlobalNamesVisitor(UnitTest):
def gather_global_names(self, code: str) -> GatherGlobalNamesVisitor:
transform_instance = GatherGlobalNamesVisitor(
CodemodContext(full_module_name="a.b.foobar")
)
input_tree = parse_module(CodemodTest.make_fixture_data(code))
input_tree.visit(transform_instance)
return transform_instance

def test_gather_nothing(self) -> None:
code = """
from a import b
b()
"""
gatherer = self.gather_global_names(code)
self.assertEqual(gatherer.global_names, set())
self.assertEqual(gatherer.class_names, set())
self.assertEqual(gatherer.function_names, set())

def test_globals(self) -> None:
code = """
x = 1
y = 2
def foo(): pass
class Foo: pass
"""
gatherer = self.gather_global_names(code)
self.assertEqual(gatherer.global_names, {"x", "y"})
self.assertEqual(gatherer.class_names, {"Foo"})
self.assertEqual(gatherer.function_names, {"foo"})

def test_omit_nested(self) -> None:
code = """
def foo():
x = 1

class Foo:
def method(self): pass
"""
gatherer = self.gather_global_names(code)
self.assertEqual(gatherer.global_names, set())
self.assertEqual(gatherer.class_names, {"Foo"})
self.assertEqual(gatherer.function_names, {"foo"})