diff --git a/libcst/codemod/visitors/__init__.py b/libcst/codemod/visitors/__init__.py index 1cbbd2c8c..632d6fa64 100644 --- a/libcst/codemod/visitors/__init__.py +++ b/libcst/codemod/visitors/__init__.py @@ -7,6 +7,7 @@ from libcst.codemod.visitors._apply_type_annotations import ApplyTypeAnnotationsVisitor from libcst.codemod.visitors._gather_comments import GatherCommentsVisitor from libcst.codemod.visitors._gather_exports import GatherExportsVisitor +from libcst.codemod.visitors._gather_global_names import GatherGlobalNamesVisitor from libcst.codemod.visitors._gather_imports import GatherImportsVisitor from libcst.codemod.visitors._gather_string_annotation_names import ( GatherNamesFromStringAnnotationsVisitor, @@ -20,6 +21,7 @@ "ApplyTypeAnnotationsVisitor", "GatherCommentsVisitor", "GatherExportsVisitor", + "GatherGlobalNamesVisitor", "GatherImportsVisitor", "GatherNamesFromStringAnnotationsVisitor", "GatherUnusedImportsVisitor", diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index b439da829..fe74b391c 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -12,6 +12,7 @@ from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareTransformer from libcst.codemod.visitors._add_imports import AddImportsVisitor +from libcst.codemod.visitors._gather_global_names import GatherGlobalNamesVisitor from libcst.codemod.visitors._gather_imports import GatherImportsVisitor from libcst.helpers import get_full_name_for_node from libcst.metadata import PositionProvider, QualifiedNameProvider @@ -595,6 +596,11 @@ def __init__( self.current_assign: Optional[cst.Assign] = None self.typevars: Dict[str, cst.Assign] = {} + # Global variables and classes defined on the toplevel of the target module. + # Used to help determine which names we need to check are in scope, and add + # quotations to avoid undefined forward references in type annotations. + self.global_names: Set[str] = set() + @staticmethod def store_stub_in_context( context: CodemodContext, @@ -631,11 +637,19 @@ def transform_module_impl( Collect type annotations from all stubs and apply them to ``tree``. Gather existing imports from ``tree`` so that we don't add duplicate imports. + + Gather global names from ``tree`` so forward references are quoted. """ import_gatherer = GatherImportsVisitor(CodemodContext()) tree.visit(import_gatherer) existing_import_names = _get_imported_names(import_gatherer.all_imports) + global_names_gatherer = GatherGlobalNamesVisitor(CodemodContext()) + tree.visit(global_names_gatherer) + self.global_names = global_names_gatherer.global_names.union( + global_names_gatherer.class_names + ) + context_contents = self.context.scratch.get( ApplyTypeAnnotationsVisitor.CONTEXT_KEY ) @@ -677,6 +691,26 @@ def transform_module_impl( else: return tree + # helpers for processing annotation nodes + def _quote_future_annotations(self, annotation: cst.Annotation) -> cst.Annotation: + # TODO: We probably want to make sure references to classes defined in the current + # module come to us fully qualified - so we can do the dequalification here and + # know to look for what is in-scope without also catching builtins like "None" in the + # quoting. This should probably also be extended to handle what imports are in scope, + # as well as subscriptable types. + # Note: We are collecting all imports and passing this to the type collector grabbing + # annotations from the stub file; should consolidate import handling somewhere too. + node = annotation.annotation + if ( + isinstance(node, cst.Name) + and (node.value in self.global_names) + and not (node.value in self.visited_classes) + ): + return annotation.with_changes( + annotation=cst.SimpleString(value=f'"{node.value}"') + ) + return annotation + # smart constructors: all applied annotations happen via one of these def _apply_annotation_to_attribute_or_global( @@ -691,7 +725,7 @@ def _apply_annotation_to_attribute_or_global( self.annotation_counts.attribute_annotations += 1 return cst.AnnAssign( cst.Name(name), - annotation, + self._quote_future_annotations(annotation), value, ) @@ -702,7 +736,7 @@ def _apply_annotation_to_parameter( ) -> cst.Param: self.annotation_counts.parameter_annotations += 1 return parameter.with_changes( - annotation=annotation, + annotation=self._quote_future_annotations(annotation), ) def _apply_annotation_to_return( @@ -711,7 +745,9 @@ def _apply_annotation_to_return( annotation: cst.Annotation, ) -> cst.FunctionDef: self.annotation_counts.return_annotations += 1 - return function_def.with_changes(returns=annotation) + return function_def.with_changes( + returns=self._quote_future_annotations(annotation), + ) # private methods used in the visit and leave methods @@ -948,13 +984,13 @@ def visit_ClassDef( node: cst.ClassDef, ) -> None: self.qualifier.append(node.name.value) - self.visited_classes.add(node.name.value) def leave_ClassDef( self, original_node: cst.ClassDef, updated_node: cst.ClassDef, ) -> cst.ClassDef: + self.visited_classes.add(original_node.name.value) cls_name = ".".join(self.qualifier) self.qualifier.pop() definition = self.annotations.class_definitions.get(cls_name) diff --git a/libcst/codemod/visitors/_gather_global_names.py b/libcst/codemod/visitors/_gather_global_names.py new file mode 100644 index 000000000..c4a5d57db --- /dev/null +++ b/libcst/codemod/visitors/_gather_global_names.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set + +import libcst +from libcst.codemod._context import CodemodContext +from libcst.codemod._visitor import ContextAwareVisitor + + +class GatherGlobalNamesVisitor(ContextAwareVisitor): + """ + Gathers all globally accessible names defined in a module and stores them as + attributes on the instance. + Intended to be instantiated and passed to a :class:`~libcst.Module` + :meth:`~libcst.CSTNode.visit` method in order to gather up information about + names defined on a module. Note that this is not a substitute for scope + analysis or qualified name support. Please see :ref:`libcst-scope-tutorial` + for a more robust way of determining the qualified name and definition for + an arbitrary node. + Names that are globally accessible through imports are currently not included + but can be retrieved with GatherImportsVisitor. + + After visiting a module the following attributes will be populated: + + global_names + A sequence of strings representing global variables defined in the module + toplevel. + class_names + A sequence of strings representing classes defined in the module toplevel. + function_names + A sequence of strings representing functions defined in the module toplevel. + + """ + + def __init__(self, context: CodemodContext) -> None: + super().__init__(context) + self.global_names: Set[str] = set() + self.class_names: Set[str] = set() + self.function_names: Set[str] = set() + # Track scope nesting + self.scope_depth: int = 0 + + def visit_ClassDef(self, node: libcst.ClassDef) -> None: + if self.scope_depth == 0: + self.class_names.add(node.name.value) + self.scope_depth += 1 + + def leave_ClassDef(self, original_node: libcst.ClassDef) -> None: + self.scope_depth -= 1 + + def visit_FunctionDef(self, node: libcst.FunctionDef) -> None: + if self.scope_depth == 0: + self.function_names.add(node.name.value) + self.scope_depth += 1 + + def leave_FunctionDef(self, original_node: libcst.FunctionDef) -> None: + self.scope_depth -= 1 + + def visit_Assign(self, node: libcst.Assign) -> None: + if self.scope_depth != 0: + return + for assign_target in node.targets: + target = assign_target.target + if isinstance(target, libcst.Name): + self.global_names.add(target.value) + + def visit_AnnAssign(self, node: libcst.AnnAssign) -> None: + if self.scope_depth != 0: + return + target = node.target + if isinstance(target, libcst.Name): + self.global_names.add(target.value) diff --git a/libcst/codemod/visitors/tests/test_apply_type_annotations.py b/libcst/codemod/visitors/tests/test_apply_type_annotations.py index c7e3695d5..9a3e59d79 100644 --- a/libcst/codemod/visitors/tests/test_apply_type_annotations.py +++ b/libcst/codemod/visitors/tests/test_apply_type_annotations.py @@ -1058,6 +1058,124 @@ def test_annotate_functions_with_existing_annotations( overwrite_existing_annotations=True, ) + @data_provider( + { + "return_self": ( + """ + class Foo: + def f(self) -> Foo: ... + """, + """ + class Foo: + def f(self): + return self + """, + """ + class Foo: + def f(self) -> "Foo": + return self + """, + ), + "return_forward_reference": ( + """ + class Foo: + def f(self) -> Bar: ... + + class Bar: + ... + """, + """ + class Foo: + def f(self): + return Bar() + + class Bar: + pass + """, + """ + class Foo: + def f(self) -> "Bar": + return Bar() + + class Bar: + pass + """, + ), + "return_backward_reference": ( + """ + class Bar: + ... + + class Foo: + def f(self) -> Bar: ... + """, + """ + class Bar: + pass + + class Foo: + def f(self): + return Bar() + """, + """ + class Bar: + pass + + class Foo: + def f(self) -> Bar: + return Bar() + """, + ), + "return_undefined_name": ( + """ + class Foo: + def f(self) -> Bar: ... + """, + """ + class Foo: + def f(self): + return self + """, + """ + class Foo: + def f(self) -> Bar: + return self + """, + ), + "parameter_forward_reference": ( + """ + def f(input: Bar) -> None: ... + + class Bar: + ... + """, + """ + def f(input): + pass + + class Bar: + pass + """, + """ + def f(input: "Bar") -> None: + pass + + class Bar: + pass + """, + ), + } + ) + def test_annotate_with_forward_references( + self, stub: str, before: str, after: str + ) -> None: + self.run_test_case_with_flags( + stub=stub, + before=before, + after=after, + overwrite_existing_annotations=True, + ) + @data_provider( { "fully_annotated_with_untyped_stub": ( diff --git a/libcst/codemod/visitors/tests/test_gather_global_names.py b/libcst/codemod/visitors/tests/test_gather_global_names.py new file mode 100644 index 000000000..8a7a7b8ba --- /dev/null +++ b/libcst/codemod/visitors/tests/test_gather_global_names.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +from libcst import parse_module +from libcst.codemod import CodemodContext, CodemodTest +from libcst.codemod.visitors import GatherGlobalNamesVisitor +from libcst.testing.utils import UnitTest + + +class TestGatherGlobalNamesVisitor(UnitTest): + def gather_global_names(self, code: str) -> GatherGlobalNamesVisitor: + transform_instance = GatherGlobalNamesVisitor( + CodemodContext(full_module_name="a.b.foobar") + ) + input_tree = parse_module(CodemodTest.make_fixture_data(code)) + input_tree.visit(transform_instance) + return transform_instance + + def test_gather_nothing(self) -> None: + code = """ + from a import b + b() + """ + gatherer = self.gather_global_names(code) + self.assertEqual(gatherer.global_names, set()) + self.assertEqual(gatherer.class_names, set()) + self.assertEqual(gatherer.function_names, set()) + + def test_globals(self) -> None: + code = """ + x = 1 + y = 2 + def foo(): pass + class Foo: pass + """ + gatherer = self.gather_global_names(code) + self.assertEqual(gatherer.global_names, {"x", "y"}) + self.assertEqual(gatherer.class_names, {"Foo"}) + self.assertEqual(gatherer.function_names, {"foo"}) + + def test_omit_nested(self) -> None: + code = """ + def foo(): + x = 1 + + class Foo: + def method(self): pass + """ + gatherer = self.gather_global_names(code) + self.assertEqual(gatherer.global_names, set()) + self.assertEqual(gatherer.class_names, {"Foo"}) + self.assertEqual(gatherer.function_names, {"foo"})