Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply Type Comments: Allow for skipping quotes when applying type comments #644

Merged
merged 3 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 68 additions & 17 deletions libcst/codemod/commands/convert_type_comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import ast
import builtins
import dataclasses
Expand Down Expand Up @@ -83,18 +84,27 @@ def _is_builtin(annotation: str) -> bool:
return annotation in _builtins()


def _convert_annotation(raw: str) -> cst.Annotation:
# Convert annotation comments to string annotations to be safe,
# otherwise runtime errors would be common.
#
# Special-case builtins to reduce the amount of quoting noise.
#
# NOTE: we could potentially detect more cases for skipping quotes
# using ScopeProvider, which would make the output prettier.
def _convert_annotation(
raw: str,
quote_annotations: bool,
) -> cst.Annotation:
"""
Convert a raw annotation - which is a string coming from a type
comment - into a suitable libcst Annotation node.

If `quote_annotations`, we'll always quote annotations unless they are builtin
types. The reason for this is to make the codemod safer to apply
on legacy code where type comments may well include invalid types
that would crash at runtime.
"""
if _is_builtin(raw):
return cst.Annotation(annotation=cst.Name(value=raw))
else:
return cst.Annotation(annotation=cst.SimpleString(f'"{raw}"'))
if not quote_annotations:
try:
return cst.Annotation(annotation=cst.parse_expression(raw))
except cst.ParserSyntaxError:
pass
return cst.Annotation(annotation=cst.SimpleString(f'"{raw}"'))


def _is_type_comment(comment: Optional[cst.Comment]) -> bool:
Expand Down Expand Up @@ -195,10 +205,14 @@ def annotated_bindings(
def type_declaration(
binding: cst.BaseAssignTargetExpression,
raw_annotation: str,
quote_annotations: bool,
) -> cst.AnnAssign:
return cst.AnnAssign(
target=binding,
annotation=_convert_annotation(raw=raw_annotation),
annotation=_convert_annotation(
raw=raw_annotation,
quote_annotations=quote_annotations,
),
value=None,
)

Expand All @@ -207,13 +221,15 @@ def type_declaration_statements(
bindings: UnpackedBindings,
annotations: UnpackedAnnotations,
leading_lines: Sequence[cst.EmptyLine],
quote_annotations: bool,
) -> List[cst.SimpleStatementLine]:
return [
cst.SimpleStatementLine(
body=[
AnnotationSpreader.type_declaration(
binding=binding,
raw_annotation=raw_annotation,
quote_annotations=quote_annotations,
)
],
leading_lines=leading_lines if i == 0 else [],
Expand All @@ -230,6 +246,7 @@ def type_declaration_statements(
def convert_Assign(
node: cst.Assign,
annotation: ast.expr,
quote_annotations: bool,
) -> Union[
_FailedToApplyAnnotation,
cst.AnnAssign,
Expand All @@ -255,7 +272,10 @@ def convert_Assign(
binding, raw_annotation = annotated_targets[0][0]
return cst.AnnAssign(
target=binding,
annotation=_convert_annotation(raw=raw_annotation),
annotation=_convert_annotation(
raw=raw_annotation,
quote_annotations=quote_annotations,
),
value=node.value,
semicolon=node.semicolon,
)
Expand All @@ -264,7 +284,11 @@ def convert_Assign(
# on the LHS or multiple `=` tokens or both), we need to add a type
# declaration per individual LHS target.
type_declarations = [
AnnotationSpreader.type_declaration(binding, raw_annotation)
AnnotationSpreader.type_declaration(
binding,
raw_annotation,
quote_annotations=quote_annotations,
)
for annotated_bindings in annotated_targets
for binding, raw_annotation in annotated_bindings
]
Expand Down Expand Up @@ -388,7 +412,7 @@ class ConvertTypeComments(VisitorBasedCodemodCommand):
- For parameters, we prefer inline type comments to
function-level type comments if we find both.

We always apply the type comments as quoted annotations, unless
We always apply the type comments as quote_annotations annotations, unless
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this one is intended - in the docstring I do want to talk about "quoted annotations" when discussing the impact of the flag on output, quote_annotations is just the name of the python variable.

we know that it refers to a builtin. We do not guarantee that
the resulting string annotations would parse, but they should
never cause failures at module import time.
Expand Down Expand Up @@ -427,7 +451,22 @@ class ConvertTypeComments(VisitorBasedCodemodCommand):
function_body_stack: List[cst.BaseSuite]
aggressively_strip_type_comments: bool

def __init__(self, context: CodemodContext) -> None:
@staticmethod
def add_args(arg_parser: argparse.ArgumentParser) -> None:
arg_parser.add_argument(
"--no-quote-annotations",
action="store_true",
help=(
"Add unquoted annotations. This leads to prettier code "
+ "but possibly more errors if type comments are invalid."
),
)

def __init__(
self,
context: CodemodContext,
no_quote_annotations: bool = False,
) -> None:
if (sys.version_info.major, sys.version_info.minor) < (3, 9):
# The ast module did not get `unparse` until Python 3.9,
# or `type_comments` until Python 3.8
Expand All @@ -444,6 +483,9 @@ def __init__(self, context: CodemodContext) -> None:
+ "it is only libcst that needs a new Python version."
)
super().__init__(context)
# flags used to control overall behavior
self.quote_annotations: bool = not no_quote_annotations
# state used to manage how we traverse nodes in various contexts
self.function_type_info_stack = []
self.function_body_stack = []
self.aggressively_strip_type_comments = False
Expand Down Expand Up @@ -480,6 +522,7 @@ def leave_SimpleStatementLine(
converted = convert_Assign(
node=assign,
annotation=annotation,
quote_annotations=self.quote_annotations,
)
if isinstance(converted, _FailedToApplyAnnotation):
# We were unable to consume the type comment, so return the
Expand Down Expand Up @@ -556,6 +599,7 @@ def leave_For(
bindings=AnnotationSpreader.unpack_target(updated_node.target),
annotations=AnnotationSpreader.unpack_annotation(annotation),
leading_lines=updated_node.leading_lines,
quote_annotations=self.quote_annotations,
)
except _ArityError:
return updated_node
Expand Down Expand Up @@ -606,6 +650,7 @@ def leave_With(
bindings=AnnotationSpreader.unpack_target(target),
annotations=AnnotationSpreader.unpack_annotation(annotation),
leading_lines=updated_node.leading_lines,
quote_annotations=self.quote_annotations,
)
except _ArityError:
return updated_node
Expand Down Expand Up @@ -773,7 +818,10 @@ def leave_Param(
raw_annotation = function_type_info.arguments.get(updated_node.name.value)
if raw_annotation is not None:
return updated_node.with_changes(
annotation=_convert_annotation(raw=raw_annotation)
annotation=_convert_annotation(
raw=raw_annotation,
quote_annotations=self.quote_annotations,
)
)
else:
return updated_node
Expand All @@ -787,7 +835,10 @@ def leave_FunctionDef(
function_type_info = self.function_type_info_stack.pop()
if updated_node.returns is None and function_type_info.returns is not None:
return updated_node.with_changes(
returns=_convert_annotation(raw=function_type_info.returns)
returns=_convert_annotation(
raw=function_type_info.returns,
quote_annotations=self.quote_annotations,
)
)
else:
return updated_node
Expand Down
41 changes: 38 additions & 3 deletions libcst/codemod/commands/tests/test_convert_type_comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import sys
from typing import Any

from libcst.codemod import CodemodTest
from libcst.codemod.commands.convert_type_comments import ConvertTypeComments
Expand All @@ -14,16 +15,16 @@ class TestConvertTypeCommentsBase(CodemodTest):
maxDiff = 1500
TRANSFORM = ConvertTypeComments

def assertCodemod39Plus(self, before: str, after: str) -> None:
def assertCodemod39Plus(self, before: str, after: str, **kwargs: Any) -> None:
"""
Assert that the codemod works on Python 3.9+, and that we raise
a NotImplementedError on other Python versions.
"""
if (sys.version_info.major, sys.version_info.minor) < (3, 9):
with self.assertRaises(NotImplementedError):
super().assertCodemod(before, after)
super().assertCodemod(before, after, **kwargs)
else:
super().assertCodemod(before, after)
super().assertCodemod(before, after, **kwargs)


class TestConvertTypeComments_AssignForWith(TestConvertTypeCommentsBase):
Expand Down Expand Up @@ -436,3 +437,37 @@ class WrapsAFunction:
"""
after = before
self.assertCodemod39Plus(before, after)

def test_no_quoting(self) -> None:
before = """
def f(x):
# type: (Foo) -> Foo
pass
w = x # type: Foo
y, z = x, x # type: (Foo, Foo)
return w

with get_context() as context: # type: Context
pass

for loop_var in the_iterable: # type: LoopType
pass
"""
after = """
def f(x: Foo) -> Foo:
pass
w: Foo = x
y: Foo
z: Foo
y, z = x, x
return w

context: Context
with get_context() as context:
pass

loop_var: LoopType
for loop_var in the_iterable:
pass
"""
self.assertCodemod39Plus(before, after, no_quote_annotations=True)