Skip to content

Commit

Permalink
Apply Type Comments: Allow for skipping quotes when applying type com…
Browse files Browse the repository at this point in the history
…ments (#644)

* Allow for skipping quotes when applying type comments

* Fix bad flag (tests don't check argparse, I ran it on pytorch)

* Run ufmt
  • Loading branch information
stroxler authored Feb 10, 2022
1 parent 3af6820 commit 0f42a78
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 20 deletions.
85 changes: 68 additions & 17 deletions libcst/codemod/commands/convert_type_comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import ast
import builtins
import dataclasses
Expand Down Expand Up @@ -83,18 +84,27 @@ def _is_builtin(annotation: str) -> bool:
return annotation in _builtins()


def _convert_annotation(raw: str) -> cst.Annotation:
# Convert annotation comments to string annotations to be safe,
# otherwise runtime errors would be common.
#
# Special-case builtins to reduce the amount of quoting noise.
#
# NOTE: we could potentially detect more cases for skipping quotes
# using ScopeProvider, which would make the output prettier.
def _convert_annotation(
raw: str,
quote_annotations: bool,
) -> cst.Annotation:
"""
Convert a raw annotation - which is a string coming from a type
comment - into a suitable libcst Annotation node.
If `quote_annotations`, we'll always quote annotations unless they are builtin
types. The reason for this is to make the codemod safer to apply
on legacy code where type comments may well include invalid types
that would crash at runtime.
"""
if _is_builtin(raw):
return cst.Annotation(annotation=cst.Name(value=raw))
else:
return cst.Annotation(annotation=cst.SimpleString(f'"{raw}"'))
if not quote_annotations:
try:
return cst.Annotation(annotation=cst.parse_expression(raw))
except cst.ParserSyntaxError:
pass
return cst.Annotation(annotation=cst.SimpleString(f'"{raw}"'))


def _is_type_comment(comment: Optional[cst.Comment]) -> bool:
Expand Down Expand Up @@ -195,10 +205,14 @@ def annotated_bindings(
def type_declaration(
binding: cst.BaseAssignTargetExpression,
raw_annotation: str,
quote_annotations: bool,
) -> cst.AnnAssign:
return cst.AnnAssign(
target=binding,
annotation=_convert_annotation(raw=raw_annotation),
annotation=_convert_annotation(
raw=raw_annotation,
quote_annotations=quote_annotations,
),
value=None,
)

Expand All @@ -207,13 +221,15 @@ def type_declaration_statements(
bindings: UnpackedBindings,
annotations: UnpackedAnnotations,
leading_lines: Sequence[cst.EmptyLine],
quote_annotations: bool,
) -> List[cst.SimpleStatementLine]:
return [
cst.SimpleStatementLine(
body=[
AnnotationSpreader.type_declaration(
binding=binding,
raw_annotation=raw_annotation,
quote_annotations=quote_annotations,
)
],
leading_lines=leading_lines if i == 0 else [],
Expand All @@ -230,6 +246,7 @@ def type_declaration_statements(
def convert_Assign(
node: cst.Assign,
annotation: ast.expr,
quote_annotations: bool,
) -> Union[
_FailedToApplyAnnotation,
cst.AnnAssign,
Expand All @@ -255,7 +272,10 @@ def convert_Assign(
binding, raw_annotation = annotated_targets[0][0]
return cst.AnnAssign(
target=binding,
annotation=_convert_annotation(raw=raw_annotation),
annotation=_convert_annotation(
raw=raw_annotation,
quote_annotations=quote_annotations,
),
value=node.value,
semicolon=node.semicolon,
)
Expand All @@ -264,7 +284,11 @@ def convert_Assign(
# on the LHS or multiple `=` tokens or both), we need to add a type
# declaration per individual LHS target.
type_declarations = [
AnnotationSpreader.type_declaration(binding, raw_annotation)
AnnotationSpreader.type_declaration(
binding,
raw_annotation,
quote_annotations=quote_annotations,
)
for annotated_bindings in annotated_targets
for binding, raw_annotation in annotated_bindings
]
Expand Down Expand Up @@ -388,7 +412,7 @@ class ConvertTypeComments(VisitorBasedCodemodCommand):
- For parameters, we prefer inline type comments to
function-level type comments if we find both.
We always apply the type comments as quoted annotations, unless
We always apply the type comments as quote_annotations annotations, unless
we know that it refers to a builtin. We do not guarantee that
the resulting string annotations would parse, but they should
never cause failures at module import time.
Expand Down Expand Up @@ -427,7 +451,22 @@ class ConvertTypeComments(VisitorBasedCodemodCommand):
function_body_stack: List[cst.BaseSuite]
aggressively_strip_type_comments: bool

def __init__(self, context: CodemodContext) -> None:
@staticmethod
def add_args(arg_parser: argparse.ArgumentParser) -> None:
arg_parser.add_argument(
"--no-quote-annotations",
action="store_true",
help=(
"Add unquoted annotations. This leads to prettier code "
+ "but possibly more errors if type comments are invalid."
),
)

def __init__(
self,
context: CodemodContext,
no_quote_annotations: bool = False,
) -> None:
if (sys.version_info.major, sys.version_info.minor) < (3, 9):
# The ast module did not get `unparse` until Python 3.9,
# or `type_comments` until Python 3.8
Expand All @@ -444,6 +483,9 @@ def __init__(self, context: CodemodContext) -> None:
+ "it is only libcst that needs a new Python version."
)
super().__init__(context)
# flags used to control overall behavior
self.quote_annotations: bool = not no_quote_annotations
# state used to manage how we traverse nodes in various contexts
self.function_type_info_stack = []
self.function_body_stack = []
self.aggressively_strip_type_comments = False
Expand Down Expand Up @@ -480,6 +522,7 @@ def leave_SimpleStatementLine(
converted = convert_Assign(
node=assign,
annotation=annotation,
quote_annotations=self.quote_annotations,
)
if isinstance(converted, _FailedToApplyAnnotation):
# We were unable to consume the type comment, so return the
Expand Down Expand Up @@ -556,6 +599,7 @@ def leave_For(
bindings=AnnotationSpreader.unpack_target(updated_node.target),
annotations=AnnotationSpreader.unpack_annotation(annotation),
leading_lines=updated_node.leading_lines,
quote_annotations=self.quote_annotations,
)
except _ArityError:
return updated_node
Expand Down Expand Up @@ -606,6 +650,7 @@ def leave_With(
bindings=AnnotationSpreader.unpack_target(target),
annotations=AnnotationSpreader.unpack_annotation(annotation),
leading_lines=updated_node.leading_lines,
quote_annotations=self.quote_annotations,
)
except _ArityError:
return updated_node
Expand Down Expand Up @@ -773,7 +818,10 @@ def leave_Param(
raw_annotation = function_type_info.arguments.get(updated_node.name.value)
if raw_annotation is not None:
return updated_node.with_changes(
annotation=_convert_annotation(raw=raw_annotation)
annotation=_convert_annotation(
raw=raw_annotation,
quote_annotations=self.quote_annotations,
)
)
else:
return updated_node
Expand All @@ -787,7 +835,10 @@ def leave_FunctionDef(
function_type_info = self.function_type_info_stack.pop()
if updated_node.returns is None and function_type_info.returns is not None:
return updated_node.with_changes(
returns=_convert_annotation(raw=function_type_info.returns)
returns=_convert_annotation(
raw=function_type_info.returns,
quote_annotations=self.quote_annotations,
)
)
else:
return updated_node
Expand Down
41 changes: 38 additions & 3 deletions libcst/codemod/commands/tests/test_convert_type_comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import sys
from typing import Any

from libcst.codemod import CodemodTest
from libcst.codemod.commands.convert_type_comments import ConvertTypeComments
Expand All @@ -14,16 +15,16 @@ class TestConvertTypeCommentsBase(CodemodTest):
maxDiff = 1500
TRANSFORM = ConvertTypeComments

def assertCodemod39Plus(self, before: str, after: str) -> None:
def assertCodemod39Plus(self, before: str, after: str, **kwargs: Any) -> None:
"""
Assert that the codemod works on Python 3.9+, and that we raise
a NotImplementedError on other Python versions.
"""
if (sys.version_info.major, sys.version_info.minor) < (3, 9):
with self.assertRaises(NotImplementedError):
super().assertCodemod(before, after)
super().assertCodemod(before, after, **kwargs)
else:
super().assertCodemod(before, after)
super().assertCodemod(before, after, **kwargs)


class TestConvertTypeComments_AssignForWith(TestConvertTypeCommentsBase):
Expand Down Expand Up @@ -436,3 +437,37 @@ class WrapsAFunction:
"""
after = before
self.assertCodemod39Plus(before, after)

def test_no_quoting(self) -> None:
before = """
def f(x):
# type: (Foo) -> Foo
pass
w = x # type: Foo
y, z = x, x # type: (Foo, Foo)
return w
with get_context() as context: # type: Context
pass
for loop_var in the_iterable: # type: LoopType
pass
"""
after = """
def f(x: Foo) -> Foo:
pass
w: Foo = x
y: Foo
z: Foo
y, z = x, x
return w
context: Context
with get_context() as context:
pass
loop_var: LoopType
for loop_var in the_iterable:
pass
"""
self.assertCodemod39Plus(before, after, no_quote_annotations=True)

0 comments on commit 0f42a78

Please sign in to comment.