diff --git a/libcst/codemod/commands/convert_type_comments.py b/libcst/codemod/commands/convert_type_comments.py index baccae520..5af002731 100644 --- a/libcst/codemod/commands/convert_type_comments.py +++ b/libcst/codemod/commands/convert_type_comments.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import argparse import ast import builtins import dataclasses @@ -83,18 +84,27 @@ def _is_builtin(annotation: str) -> bool: return annotation in _builtins() -def _convert_annotation(raw: str) -> cst.Annotation: - # Convert annotation comments to string annotations to be safe, - # otherwise runtime errors would be common. - # - # Special-case builtins to reduce the amount of quoting noise. - # - # NOTE: we could potentially detect more cases for skipping quotes - # using ScopeProvider, which would make the output prettier. +def _convert_annotation( + raw: str, + quote_annotations: bool, +) -> cst.Annotation: + """ + Convert a raw annotation - which is a string coming from a type + comment - into a suitable libcst Annotation node. + + If `quote_annotations`, we'll always quote annotations unless they are builtin + types. The reason for this is to make the codemod safer to apply + on legacy code where type comments may well include invalid types + that would crash at runtime. + """ if _is_builtin(raw): return cst.Annotation(annotation=cst.Name(value=raw)) - else: - return cst.Annotation(annotation=cst.SimpleString(f'"{raw}"')) + if not quote_annotations: + try: + return cst.Annotation(annotation=cst.parse_expression(raw)) + except cst.ParserSyntaxError: + pass + return cst.Annotation(annotation=cst.SimpleString(f'"{raw}"')) def _is_type_comment(comment: Optional[cst.Comment]) -> bool: @@ -195,10 +205,14 @@ def annotated_bindings( def type_declaration( binding: cst.BaseAssignTargetExpression, raw_annotation: str, + quote_annotations: bool, ) -> cst.AnnAssign: return cst.AnnAssign( target=binding, - annotation=_convert_annotation(raw=raw_annotation), + annotation=_convert_annotation( + raw=raw_annotation, + quote_annotations=quote_annotations, + ), value=None, ) @@ -207,6 +221,7 @@ def type_declaration_statements( bindings: UnpackedBindings, annotations: UnpackedAnnotations, leading_lines: Sequence[cst.EmptyLine], + quote_annotations: bool, ) -> List[cst.SimpleStatementLine]: return [ cst.SimpleStatementLine( @@ -214,6 +229,7 @@ def type_declaration_statements( AnnotationSpreader.type_declaration( binding=binding, raw_annotation=raw_annotation, + quote_annotations=quote_annotations, ) ], leading_lines=leading_lines if i == 0 else [], @@ -230,6 +246,7 @@ def type_declaration_statements( def convert_Assign( node: cst.Assign, annotation: ast.expr, + quote_annotations: bool, ) -> Union[ _FailedToApplyAnnotation, cst.AnnAssign, @@ -255,7 +272,10 @@ def convert_Assign( binding, raw_annotation = annotated_targets[0][0] return cst.AnnAssign( target=binding, - annotation=_convert_annotation(raw=raw_annotation), + annotation=_convert_annotation( + raw=raw_annotation, + quote_annotations=quote_annotations, + ), value=node.value, semicolon=node.semicolon, ) @@ -264,7 +284,11 @@ def convert_Assign( # on the LHS or multiple `=` tokens or both), we need to add a type # declaration per individual LHS target. type_declarations = [ - AnnotationSpreader.type_declaration(binding, raw_annotation) + AnnotationSpreader.type_declaration( + binding, + raw_annotation, + quote_annotations=quote_annotations, + ) for annotated_bindings in annotated_targets for binding, raw_annotation in annotated_bindings ] @@ -388,7 +412,7 @@ class ConvertTypeComments(VisitorBasedCodemodCommand): - For parameters, we prefer inline type comments to function-level type comments if we find both. - We always apply the type comments as quoted annotations, unless + We always apply the type comments as quote_annotations annotations, unless we know that it refers to a builtin. We do not guarantee that the resulting string annotations would parse, but they should never cause failures at module import time. @@ -427,7 +451,22 @@ class ConvertTypeComments(VisitorBasedCodemodCommand): function_body_stack: List[cst.BaseSuite] aggressively_strip_type_comments: bool - def __init__(self, context: CodemodContext) -> None: + @staticmethod + def add_args(arg_parser: argparse.ArgumentParser) -> None: + arg_parser.add_argument( + "--no-quote-annotations", + action="store_true", + help=( + "Add unquoted annotations. This leads to prettier code " + + "but possibly more errors if type comments are invalid." + ), + ) + + def __init__( + self, + context: CodemodContext, + no_quote_annotations: bool = False, + ) -> None: if (sys.version_info.major, sys.version_info.minor) < (3, 9): # The ast module did not get `unparse` until Python 3.9, # or `type_comments` until Python 3.8 @@ -444,6 +483,9 @@ def __init__(self, context: CodemodContext) -> None: + "it is only libcst that needs a new Python version." ) super().__init__(context) + # flags used to control overall behavior + self.quote_annotations: bool = not no_quote_annotations + # state used to manage how we traverse nodes in various contexts self.function_type_info_stack = [] self.function_body_stack = [] self.aggressively_strip_type_comments = False @@ -480,6 +522,7 @@ def leave_SimpleStatementLine( converted = convert_Assign( node=assign, annotation=annotation, + quote_annotations=self.quote_annotations, ) if isinstance(converted, _FailedToApplyAnnotation): # We were unable to consume the type comment, so return the @@ -556,6 +599,7 @@ def leave_For( bindings=AnnotationSpreader.unpack_target(updated_node.target), annotations=AnnotationSpreader.unpack_annotation(annotation), leading_lines=updated_node.leading_lines, + quote_annotations=self.quote_annotations, ) except _ArityError: return updated_node @@ -606,6 +650,7 @@ def leave_With( bindings=AnnotationSpreader.unpack_target(target), annotations=AnnotationSpreader.unpack_annotation(annotation), leading_lines=updated_node.leading_lines, + quote_annotations=self.quote_annotations, ) except _ArityError: return updated_node @@ -773,7 +818,10 @@ def leave_Param( raw_annotation = function_type_info.arguments.get(updated_node.name.value) if raw_annotation is not None: return updated_node.with_changes( - annotation=_convert_annotation(raw=raw_annotation) + annotation=_convert_annotation( + raw=raw_annotation, + quote_annotations=self.quote_annotations, + ) ) else: return updated_node @@ -787,7 +835,10 @@ def leave_FunctionDef( function_type_info = self.function_type_info_stack.pop() if updated_node.returns is None and function_type_info.returns is not None: return updated_node.with_changes( - returns=_convert_annotation(raw=function_type_info.returns) + returns=_convert_annotation( + raw=function_type_info.returns, + quote_annotations=self.quote_annotations, + ) ) else: return updated_node diff --git a/libcst/codemod/commands/tests/test_convert_type_comments.py b/libcst/codemod/commands/tests/test_convert_type_comments.py index 6bd5a8a3e..98eaa7670 100644 --- a/libcst/codemod/commands/tests/test_convert_type_comments.py +++ b/libcst/codemod/commands/tests/test_convert_type_comments.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import sys +from typing import Any from libcst.codemod import CodemodTest from libcst.codemod.commands.convert_type_comments import ConvertTypeComments @@ -14,16 +15,16 @@ class TestConvertTypeCommentsBase(CodemodTest): maxDiff = 1500 TRANSFORM = ConvertTypeComments - def assertCodemod39Plus(self, before: str, after: str) -> None: + def assertCodemod39Plus(self, before: str, after: str, **kwargs: Any) -> None: """ Assert that the codemod works on Python 3.9+, and that we raise a NotImplementedError on other Python versions. """ if (sys.version_info.major, sys.version_info.minor) < (3, 9): with self.assertRaises(NotImplementedError): - super().assertCodemod(before, after) + super().assertCodemod(before, after, **kwargs) else: - super().assertCodemod(before, after) + super().assertCodemod(before, after, **kwargs) class TestConvertTypeComments_AssignForWith(TestConvertTypeCommentsBase): @@ -436,3 +437,37 @@ class WrapsAFunction: """ after = before self.assertCodemod39Plus(before, after) + + def test_no_quoting(self) -> None: + before = """ + def f(x): + # type: (Foo) -> Foo + pass + w = x # type: Foo + y, z = x, x # type: (Foo, Foo) + return w + + with get_context() as context: # type: Context + pass + + for loop_var in the_iterable: # type: LoopType + pass + """ + after = """ + def f(x: Foo) -> Foo: + pass + w: Foo = x + y: Foo + z: Foo + y, z = x, x + return w + + context: Context + with get_context() as context: + pass + + loop_var: LoopType + for loop_var in the_iterable: + pass + """ + self.assertCodemod39Plus(before, after, no_quote_annotations=True)