diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index b439da829..760369046 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -12,6 +12,7 @@ from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareTransformer from libcst.codemod.visitors._add_imports import AddImportsVisitor +from libcst.codemod.visitors._gather_global_names import GatherGlobalNamesVisitor from libcst.codemod.visitors._gather_imports import GatherImportsVisitor from libcst.helpers import get_full_name_for_node from libcst.metadata import PositionProvider, QualifiedNameProvider @@ -595,6 +596,12 @@ def __init__( self.current_assign: Optional[cst.Assign] = None self.typevars: Dict[str, cst.Assign] = {} + # Global variables and classes defined on the toplevel of the target module. + # Used to help determine which names we need to check are in scope, and add + # quotations to avoid undefined forward references in type annotations. + self.global_names: Set[str] = set() + + @staticmethod def store_stub_in_context( context: CodemodContext, @@ -631,11 +638,17 @@ def transform_module_impl( Collect type annotations from all stubs and apply them to ``tree``. Gather existing imports from ``tree`` so that we don't add duplicate imports. + + Gather global names from ``tree`` so forward references are quoted. """ import_gatherer = GatherImportsVisitor(CodemodContext()) tree.visit(import_gatherer) existing_import_names = _get_imported_names(import_gatherer.all_imports) + global_names_gatherer = GatherGlobalNamesVisitor(CodemodContext()) + tree.visit(global_names_gatherer) + self.global_names = global_names_gatherer.global_names.union(global_names_gatherer.class_names) + context_contents = self.context.scratch.get( ApplyTypeAnnotationsVisitor.CONTEXT_KEY ) @@ -677,6 +690,26 @@ def transform_module_impl( else: return tree + # helpers for processing annotation nodes + def _quote_future_annotations(self, annotation: cst.Annotation) -> cst.Annotation: + # TODO: We probably want to make sure references to classes defined in the current + # module come to us fully qualified - so we can do the dequalification here and + # know to look for what is in-scope without also catching builtins like "None" in the + # quoting. This should probably also be extended to handle what imports are in scope, + # as well as subscriptable types. + # Note: We are collecting all imports and passing this to the type collector grabbing + # annotations from the stub file; should consolidate import handling somewhere too. + node = annotation.annotation + if ( + isinstance(node, cst.Name) + and (node.value in self.global_names) + and not (node.value in self.visited_classes) + ): + return annotation.with_changes( + annotation=cst.SimpleString(value=f'"{node.value}"') + ) + return annotation + # smart constructors: all applied annotations happen via one of these def _apply_annotation_to_attribute_or_global( @@ -691,7 +724,7 @@ def _apply_annotation_to_attribute_or_global( self.annotation_counts.attribute_annotations += 1 return cst.AnnAssign( cst.Name(name), - annotation, + self._quote_future_annotations(annotation), value, ) @@ -702,7 +735,7 @@ def _apply_annotation_to_parameter( ) -> cst.Param: self.annotation_counts.parameter_annotations += 1 return parameter.with_changes( - annotation=annotation, + annotation=self._quote_future_annotations(annotation), ) def _apply_annotation_to_return( @@ -711,7 +744,9 @@ def _apply_annotation_to_return( annotation: cst.Annotation, ) -> cst.FunctionDef: self.annotation_counts.return_annotations += 1 - return function_def.with_changes(returns=annotation) + return function_def.with_changes( + returns=self._quote_future_annotations(annotation), + ) # private methods used in the visit and leave methods @@ -948,13 +983,13 @@ def visit_ClassDef( node: cst.ClassDef, ) -> None: self.qualifier.append(node.name.value) - self.visited_classes.add(node.name.value) def leave_ClassDef( self, original_node: cst.ClassDef, updated_node: cst.ClassDef, ) -> cst.ClassDef: + self.visited_classes.add(original_node.name.value) cls_name = ".".join(self.qualifier) self.qualifier.pop() definition = self.annotations.class_definitions.get(cls_name) diff --git a/libcst/codemod/visitors/_gather_global_names.py b/libcst/codemod/visitors/_gather_global_names.py index 10862b4ad..c4a5d57db 100644 --- a/libcst/codemod/visitors/_gather_global_names.py +++ b/libcst/codemod/visitors/_gather_global_names.py @@ -2,8 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -# -from typing import Dict, List, Sequence, Set, Tuple, Union + +from typing import Set import libcst from libcst.codemod._context import CodemodContext @@ -48,7 +48,7 @@ def visit_ClassDef(self, node: libcst.ClassDef) -> None: self.class_names.add(node.name.value) self.scope_depth += 1 - def leave_ClassDef(self, node: libcst.ClassDef) -> None: + def leave_ClassDef(self, original_node: libcst.ClassDef) -> None: self.scope_depth -= 1 def visit_FunctionDef(self, node: libcst.FunctionDef) -> None: @@ -56,7 +56,7 @@ def visit_FunctionDef(self, node: libcst.FunctionDef) -> None: self.function_names.add(node.name.value) self.scope_depth += 1 - def leave_FunctionDef(self, node: libcst.FunctionDef) -> None: + def leave_FunctionDef(self, original_node: libcst.FunctionDef) -> None: self.scope_depth -= 1 def visit_Assign(self, node: libcst.Assign) -> None: diff --git a/libcst/codemod/visitors/tests/test_apply_type_annotations.py b/libcst/codemod/visitors/tests/test_apply_type_annotations.py index c7e3695d5..9a3e59d79 100644 --- a/libcst/codemod/visitors/tests/test_apply_type_annotations.py +++ b/libcst/codemod/visitors/tests/test_apply_type_annotations.py @@ -1058,6 +1058,124 @@ def test_annotate_functions_with_existing_annotations( overwrite_existing_annotations=True, ) + @data_provider( + { + "return_self": ( + """ + class Foo: + def f(self) -> Foo: ... + """, + """ + class Foo: + def f(self): + return self + """, + """ + class Foo: + def f(self) -> "Foo": + return self + """, + ), + "return_forward_reference": ( + """ + class Foo: + def f(self) -> Bar: ... + + class Bar: + ... + """, + """ + class Foo: + def f(self): + return Bar() + + class Bar: + pass + """, + """ + class Foo: + def f(self) -> "Bar": + return Bar() + + class Bar: + pass + """, + ), + "return_backward_reference": ( + """ + class Bar: + ... + + class Foo: + def f(self) -> Bar: ... + """, + """ + class Bar: + pass + + class Foo: + def f(self): + return Bar() + """, + """ + class Bar: + pass + + class Foo: + def f(self) -> Bar: + return Bar() + """, + ), + "return_undefined_name": ( + """ + class Foo: + def f(self) -> Bar: ... + """, + """ + class Foo: + def f(self): + return self + """, + """ + class Foo: + def f(self) -> Bar: + return self + """, + ), + "parameter_forward_reference": ( + """ + def f(input: Bar) -> None: ... + + class Bar: + ... + """, + """ + def f(input): + pass + + class Bar: + pass + """, + """ + def f(input: "Bar") -> None: + pass + + class Bar: + pass + """, + ), + } + ) + def test_annotate_with_forward_references( + self, stub: str, before: str, after: str + ) -> None: + self.run_test_case_with_flags( + stub=stub, + before=before, + after=after, + overwrite_existing_annotations=True, + ) + @data_provider( { "fully_annotated_with_untyped_stub": (