Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix crashes with comments in parentheses #4453

Merged
merged 5 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

<!-- Changes that affect Black's stable style -->

- Fix crashes involving comments in parenthesised return types or `X | Y` style unions.
(#4453)

### Preview style

<!-- Changes that affect Black's preview style -->
Expand Down
82 changes: 49 additions & 33 deletions src/black/linegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,47 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None
)


def _ensure_trailing_comma(
leaves: List[Leaf], original: Line, opening_bracket: Leaf
) -> bool:
if not leaves:
return False
# Ensure a trailing comma for imports
if original.is_import:
return True
# ...and standalone function arguments
if not original.is_def:
return False
if opening_bracket.value != "(":
return False
# Don't add commas if we already have any commas
if any(
leaf.type == token.COMMA
and (
Preview.typed_params_trailing_comma not in original.mode
or not is_part_of_annotation(leaf)
)
for leaf in leaves
):
return False

# Find a leaf with a parent (comments don't have parents)
leaf_with_parent = next((leaf for leaf in leaves if leaf.parent), None)
if leaf_with_parent is None:
return True
# Don't add commas inside parenthesized return annotations
if get_annotation_type(leaf_with_parent) == "return":
return False
# Don't add commas inside PEP 604 unions
if (
leaf_with_parent.parent
and leaf_with_parent.parent.next_sibling
and leaf_with_parent.parent.next_sibling.type == token.VBAR
):
return False
return True


def bracket_split_build_line(
leaves: List[Leaf],
original: Line,
Expand All @@ -1099,40 +1140,15 @@ def bracket_split_build_line(
if component is _BracketSplitComponent.body:
result.inside_brackets = True
result.depth += 1
if leaves:
no_commas = (
# Ensure a trailing comma for imports and standalone function arguments
original.is_def
# Don't add one after any comments or within type annotations
and opening_bracket.value == "("
# Don't add one if there's already one there
and not any(
leaf.type == token.COMMA
and (
Preview.typed_params_trailing_comma not in original.mode
or not is_part_of_annotation(leaf)
)
for leaf in leaves
)
# Don't add one inside parenthesized return annotations
and get_annotation_type(leaves[0]) != "return"
# Don't add one inside PEP 604 unions
and not (
leaves[0].parent
and leaves[0].parent.next_sibling
and leaves[0].parent.next_sibling.type == token.VBAR
)
)

if original.is_import or no_commas:
for i in range(len(leaves) - 1, -1, -1):
if leaves[i].type == STANDALONE_COMMENT:
continue
if _ensure_trailing_comma(leaves, original, opening_bracket):
for i in range(len(leaves) - 1, -1, -1):
if leaves[i].type == STANDALONE_COMMENT:
continue

if leaves[i].type != token.COMMA:
new_comma = Leaf(token.COMMA, ",")
leaves.insert(i + 1, new_comma)
break
if leaves[i].type != token.COMMA:
new_comma = Leaf(token.COMMA, ",")
leaves.insert(i + 1, new_comma)
break

leaves_to_track: Set[LeafID] = set()
if component is _BracketSplitComponent.head:
Expand Down
1 change: 1 addition & 0 deletions src/black/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ def get_annotation_type(leaf: Leaf) -> Literal["return", "param", None]:

def is_part_of_annotation(leaf: Leaf) -> bool:
"""Returns whether this leaf is part of a type annotation."""
assert leaf.parent is not None
return get_annotation_type(leaf) is not None


Expand Down
2 changes: 1 addition & 1 deletion src/black/trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def do_match(self, line: Line) -> TMatchResult:
break
i += 1

if not is_part_of_annotation(leaf) and not contains_comment:
if not contains_comment and not is_part_of_annotation(leaf):
string_indices.append(idx)

# Advance to the next non-STRING leaf.
Expand Down
1 change: 1 addition & 0 deletions tests/data/cases/funcdef_return_type_trailing_comma.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def SimplePyFn(
Buffer[UInt8, 2],
Buffer[UInt8, 2],
]: ...

# output
# normal, short, function definition
def foo(a, b) -> tuple[int, float]: ...
Expand Down
130 changes: 130 additions & 0 deletions tests/data/cases/function_trailing_comma.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,64 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr
argument1, (one, two,), argument4, argument5, argument6
)

def foo() -> (
# comment inside parenthesised return type
int
):
...

def foo() -> (
# comment inside parenthesised return type
# more
int
# another
):
...

def foo() -> (
# comment inside parenthesised new union return type
int | str | bytes
):
...

def foo() -> (
# comment inside plain tuple
):
pass

def foo(arg: (# comment with non-return annotation
int
# comment with non-return annotation
)):
pass

def foo(arg: (# comment with non-return annotation
int | range | memoryview
# comment with non-return annotation
)):
pass

def foo(arg: (# only before
int
)):
pass

def foo(arg: (
int
# only after
)):
pass

variable: ( # annotation
because
# why not
)

variable: (
because
# why not
)

# output

def f(
Expand Down Expand Up @@ -176,3 +234,75 @@ def func() -> (
argument5,
argument6,
)


def foo() -> (
# comment inside parenthesised return type
int
): ...


def foo() -> (
# comment inside parenthesised return type
# more
int
# another
): ...


def foo() -> (
# comment inside parenthesised new union return type
int
| str
| bytes
): ...


def foo() -> (
# comment inside plain tuple
):
pass


def foo(
arg: ( # comment with non-return annotation
int
# comment with non-return annotation
),
):
pass


def foo(
arg: ( # comment with non-return annotation
int
| range
| memoryview
# comment with non-return annotation
),
):
pass


def foo(arg: int): # only before
pass


def foo(
arg: (
int
# only after
),
):
pass


variable: ( # annotation
because
# why not
)

variable: (
because
# why not
)
Loading