Skip to content

Commit

Permalink
Quote forward references when applying annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
shannonzhu committed Mar 8, 2022
1 parent 5294942 commit da2dac2
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 4 deletions.
43 changes: 39 additions & 4 deletions libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._add_imports import AddImportsVisitor
from libcst.codemod.visitors._gather_global_names import GatherGlobalNamesVisitor
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.helpers import get_full_name_for_node
from libcst.metadata import PositionProvider, QualifiedNameProvider
Expand Down Expand Up @@ -595,6 +596,12 @@ def __init__(
self.current_assign: Optional[cst.Assign] = None
self.typevars: Dict[str, cst.Assign] = {}

# Global variables and classes defined on the toplevel of the target module.
# Used to help determine which names we need to check are in scope, and add
# quotations to avoid undefined forward references in type annotations.
self.global_names: Set[str] = set()


@staticmethod
def store_stub_in_context(
context: CodemodContext,
Expand Down Expand Up @@ -631,11 +638,17 @@ def transform_module_impl(
Collect type annotations from all stubs and apply them to ``tree``.
Gather existing imports from ``tree`` so that we don't add duplicate imports.
Gather global names from ``tree`` so forward references are quoted.
"""
import_gatherer = GatherImportsVisitor(CodemodContext())
tree.visit(import_gatherer)
existing_import_names = _get_imported_names(import_gatherer.all_imports)

global_names_gatherer = GatherGlobalNamesVisitor(CodemodContext())
tree.visit(global_names_gatherer)
self.global_names = global_names_gatherer.global_names.union(global_names_gatherer.class_names)

context_contents = self.context.scratch.get(
ApplyTypeAnnotationsVisitor.CONTEXT_KEY
)
Expand Down Expand Up @@ -677,6 +690,26 @@ def transform_module_impl(
else:
return tree

# helpers for processing annotation nodes
def _quote_future_annotations(self, annotation: cst.Annotation) -> cst.Annotation:
# TODO: We probably want to make sure references to classes defined in the current
# module come to us fully qualified - so we can do the dequalification here and
# know to look for what is in-scope without also catching builtins like "None" in the
# quoting. This should probably also be extended to handle what imports are in scope,
# as well as subscriptable types.
# Note: We are collecting all imports and passing this to the type collector grabbing
# annotations from the stub file; should consolidate import handling somewhere too.
node = annotation.annotation
if (
isinstance(node, cst.Name)
and (node.value in self.global_names)
and not (node.value in self.visited_classes)
):
return annotation.with_changes(
annotation=cst.SimpleString(value=f'"{node.value}"')
)
return annotation

# smart constructors: all applied annotations happen via one of these

def _apply_annotation_to_attribute_or_global(
Expand All @@ -691,7 +724,7 @@ def _apply_annotation_to_attribute_or_global(
self.annotation_counts.attribute_annotations += 1
return cst.AnnAssign(
cst.Name(name),
annotation,
self._quote_future_annotations(annotation),
value,
)

Expand All @@ -702,7 +735,7 @@ def _apply_annotation_to_parameter(
) -> cst.Param:
self.annotation_counts.parameter_annotations += 1
return parameter.with_changes(
annotation=annotation,
annotation=self._quote_future_annotations(annotation),
)

def _apply_annotation_to_return(
Expand All @@ -711,7 +744,9 @@ def _apply_annotation_to_return(
annotation: cst.Annotation,
) -> cst.FunctionDef:
self.annotation_counts.return_annotations += 1
return function_def.with_changes(returns=annotation)
return function_def.with_changes(
returns=self._quote_future_annotations(annotation),
)

# private methods used in the visit and leave methods

Expand Down Expand Up @@ -948,13 +983,13 @@ def visit_ClassDef(
node: cst.ClassDef,
) -> None:
self.qualifier.append(node.name.value)
self.visited_classes.add(node.name.value)

def leave_ClassDef(
self,
original_node: cst.ClassDef,
updated_node: cst.ClassDef,
) -> cst.ClassDef:
self.visited_classes.add(original_node.name.value)
cls_name = ".".join(self.qualifier)
self.qualifier.pop()
definition = self.annotations.class_definitions.get(cls_name)
Expand Down
118 changes: 118 additions & 0 deletions libcst/codemod/visitors/tests/test_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,124 @@ def test_annotate_functions_with_existing_annotations(
overwrite_existing_annotations=True,
)

@data_provider(
{
"return_self": (
"""
class Foo:
def f(self) -> Foo: ...
""",
"""
class Foo:
def f(self):
return self
""",
"""
class Foo:
def f(self) -> "Foo":
return self
""",
),
"return_forward_reference": (
"""
class Foo:
def f(self) -> Bar: ...
class Bar:
...
""",
"""
class Foo:
def f(self):
return Bar()
class Bar:
pass
""",
"""
class Foo:
def f(self) -> "Bar":
return Bar()
class Bar:
pass
""",
),
"return_backward_reference": (
"""
class Bar:
...
class Foo:
def f(self) -> Bar: ...
""",
"""
class Bar:
pass
class Foo:
def f(self):
return Bar()
""",
"""
class Bar:
pass
class Foo:
def f(self) -> Bar:
return Bar()
""",
),
"return_undefined_name": (
"""
class Foo:
def f(self) -> Bar: ...
""",
"""
class Foo:
def f(self):
return self
""",
"""
class Foo:
def f(self) -> Bar:
return self
""",
),
"parameter_forward_reference": (
"""
def f(input: Bar) -> None: ...
class Bar:
...
""",
"""
def f(input):
pass
class Bar:
pass
""",
"""
def f(input: "Bar") -> None:
pass
class Bar:
pass
""",
),
}
)
def test_annotate_with_forward_references(
self, stub: str, before: str, after: str
) -> None:
self.run_test_case_with_flags(
stub=stub,
before=before,
after=after,
overwrite_existing_annotations=True,
)

@data_provider(
{
"fully_annotated_with_untyped_stub": (
Expand Down

0 comments on commit da2dac2

Please sign in to comment.