Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Allow borrowing inside comprehensions #723

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,12 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Error: Linearity violation (at <In [10]>:7:7)\n",
"Error: Linearity violation (at <In [10]>:6:4)\n",
" | \n",
"4 | @guppy(bad_module)\n",
"5 | def bad(q: qubit @owned) -> qubit:\n",
"6 | tmp = qubit()\n",
"7 | cx(tmp, q)\n",
" | ^^^ Variable `tmp` with linear type `qubit` is leaked\n",
" | ^^^ Variable `tmp` with linear type `qubit` is leaked\n",
"\n",
"Help: Make sure that `tmp` is consumed or returned to avoid the leak\n",
"\n",
Expand Down
45 changes: 34 additions & 11 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,13 @@ def _reassign_inout_args(self, func_ty: FunctionType, call: AnyCall) -> None:
if InputFlags.Inout in inp.flags:
match arg:
case PlaceNode(place=place):
self._reassign_single_inout_arg(place, arg)
self._reassign_single_inout_arg(place, place.defined_at or arg)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive-by: Improve leak errors by keeping track of the original place where a variable was defined if possible

case arg if inp.ty.linear:
err = DropAfterCallError(arg, inp.ty, self._call_name(call))
err.add_sub_diagnostic(DropAfterCallError.Assign(None))
raise GuppyError(err)

def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None:
def _reassign_single_inout_arg(self, place: Place, node: AstNode) -> None:
"""Helper function to reassign a single borrowed argument after a function
call."""
# Places involving subscripts are given back by visiting the `__setitem__` call
Expand Down Expand Up @@ -496,6 +496,11 @@ def _check_comprehension(
continue
for leaf in leaf_places(place):
x = leaf.id
# Also ignore borrowed variables
if x in inner_scope.used_parent and (
inner_scope.used_parent[x].kind == UseKind.BORROW
):
continue
if not self.scope.used(x) and place.ty.linear:
err = PlaceNotUsedError(place.defined_at, place)
err.add_sub_diagnostic(
Expand All @@ -508,6 +513,22 @@ def _check_comprehension(
# Recursively check the remaining generators
self._check_comprehension(gens, elt)

# Look for any linear variables that were borrowed from the outer scope
gen.borrowed_outer_places = []
for x, use in inner_scope.used_parent.items():
if use.kind == UseKind.BORROW:
# Since `x` was borrowed, we know that is now also assigned in the
# inner scope since it gets reassigned in the local scope after the
# borrow expires
place = inner_scope[x]
gen.borrowed_outer_places.append(place)
# Also mark this place as implicitly used so we don't complain about
# it later.
for leaf in leaf_places(place):
inner_scope.use(
leaf.id, InoutReturnSentinel(leaf), UseKind.RETURN
)

# Check the iter finalizer so we record a final use of the iterator
self.visit(gen.iterend)

Expand All @@ -519,15 +540,17 @@ def _check_comprehension(
if leaf.ty.linear and not inner_scope.used(x):
raise GuppyTypeError(PlaceNotUsedError(leaf.defined_at, leaf))

# On the other hand, we have to ensure that no linear places from the
# outer scope have been used inside the comprehension (they would be used
# multiple times since the comprehension body is executed repeatedly)
for x, use in inner_scope.used_parent.items():
place = inner_scope[x]
if place.ty.linear:
raise GuppyTypeError(
ComprAlreadyUsedError(use.node, place, use.kind)
)
# On the other hand, we have to ensure that no linear places from the
# outer scope have been used inside the comprehension (they would be used
# multiple times since the comprehension body is executed repeatedly)
for x, use in inner_scope.used_parent.items():
place = inner_scope[x]
# The only exception are values that are only borrowed from the outer
# scope. These can be safely reassigned.
if use.kind == UseKind.BORROW:
self._reassign_single_inout_arg(place, use.node)
elif place.ty.linear:
raise GuppyTypeError(ComprAlreadyUsedError(use.node, place, use.kind))


def leaf_places(place: Place) -> Iterator[Place]:
Expand Down
3 changes: 3 additions & 0 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,9 @@ def _build_generators(
assert isinstance(gen.iter, PlaceNode)
assert isinstance(gen.hasnext, PlaceNode)
inputs = [gen.iter] + [PlaceNode(place=var) for var in loop_vars]
inputs += [
PlaceNode(place=place) for place in gen.borrowed_outer_places
]
# Remember to finalize the iterator once we are done with it. Note that
# we need to use partial in the callback, so that we bind the *current*
# value of `gen` instead of only last
Expand Down
3 changes: 3 additions & 0 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ class DesugaredGenerator(ast.expr):
hasnext: ast.expr
ifs: list[ast.expr]

borrowed_outer_places: "list[Place]"

_fields = (
"iter_assign",
"hasnext_assign",
Expand All @@ -215,6 +217,7 @@ class DesugaredGenerator(ast.expr):
"iter",
"hasnext",
"ifs",
"borrowed_outer_places",
)


Expand Down
12 changes: 12 additions & 0 deletions tests/error/comprehension_errors/borrow_after_consume.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Error: Linearity violation (at $FILE:21:16)
|
19 | @guppy(module)
20 | def foo(qs: list[qubit] @owned) -> list[int]:
21 | return [baz(q) for q in qs if bar(q)]
| ^ Variable `q` with linear type `qubit` cannot be borrowed
| ...
|
21 | return [baz(q) for q in qs if bar(q)]
| - since it was already consumed here

Guppy compilation failed due to 1 previous error
24 changes: 24 additions & 0 deletions tests/error/comprehension_errors/borrow_after_consume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit
from guppylang.std.builtins import owned

module = GuppyModule("test")
module.load_all(quantum)


@guppy.declare(module)
def bar(q: qubit @owned) -> int: ...


@guppy.declare(module)
def baz(q: qubit) -> int: ...


@guppy(module)
def foo(qs: list[qubit] @owned) -> list[int]:
return [baz(q) for q in qs if bar(q)]


module.compile()
10 changes: 10 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked1.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Error: Linearity violation (at $FILE:17:16)
|
15 | @guppy(module)
16 | def foo(n: int, q: qubit @owned) -> list[int]:
17 | return [bar(q) for _ in range(n)]
| ^ Variable `q` with linear type `qubit` is leaked

Help: Make sure that `q` is consumed or returned to avoid the leak

Guppy compilation failed due to 1 previous error
20 changes: 20 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit
from guppylang.std.builtins import owned

module = GuppyModule("test")
module.load_all(quantum)


@guppy.declare(module)
def bar(q: qubit) -> int: ...


@guppy(module)
def foo(n: int, q: qubit @owned) -> list[int]:
return [bar(q) for _ in range(n)]


module.compile()
8 changes: 8 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked2.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Linearity violation (at $FILE:17:23)
|
15 | @guppy(module)
16 | def foo(qs: list[qubit] @owned) -> list[int]:
17 | return [bar(q) for q in qs]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

obvs not for this PR but it would be super sweet if a hint said "try removing owned"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created #726

| ^ Variable `q` with linear type `qubit` is leaked

Guppy compilation failed due to 1 previous error
20 changes: 20 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit
from guppylang.std.builtins import owned

module = GuppyModule("test")
module.load_all(quantum)


@guppy.declare(module)
def bar(q: qubit) -> int: ...


@guppy(module)
def foo(qs: list[qubit] @owned) -> list[int]:
return [bar(q) for q in qs]


module.compile()
11 changes: 11 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked3.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Error: Linearity violation (at $FILE:17:18)
|
15 | @guppy(module)
16 | def foo(qs: list[qubit] @owned) -> list[int]:
17 | return [0 for q in qs if bar(q)]
| ^ Variable `q` with linear type `qubit` may be leaked ...
|
17 | return [0 for q in qs if bar(q)]
| ------ if this expression is `False`

Guppy compilation failed due to 1 previous error
20 changes: 20 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit
from guppylang.std.builtins import owned

module = GuppyModule("test")
module.load_all(quantum)


@guppy.declare(module)
def bar(q: qubit) -> bool: ...


@guppy(module)
def foo(qs: list[qubit] @owned) -> list[int]:
return [0 for q in qs if bar(q)]


module.compile()
8 changes: 8 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked4.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Linearity violation (at $FILE:17:18)
|
15 | @guppy(module)
16 | def foo(qs: list[qubit] @owned) -> list[qubit]:
17 | return [r for q in qs for r in bar(q)]
| ^ Variable `q` with linear type `qubit` is leaked

Guppy compilation failed due to 1 previous error
20 changes: 20 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit
from guppylang.std.builtins import owned

module = GuppyModule("test")
module.load_all(quantum)


@guppy.declare(module)
def bar(q: qubit) -> list[qubit]: ...


@guppy(module)
def foo(qs: list[qubit] @owned) -> list[qubit]:
return [r for q in qs for r in bar(q)]


module.compile()
6 changes: 3 additions & 3 deletions tests/error/inout_errors/override_after_call.err
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
Error: Linearity violation (at $FILE:16:13)
Error: Linearity violation (at $FILE:15:9)
|
13 |
14 | @guppy(module)
15 | def test(q1: qubit @owned, q2: qubit @owned) -> tuple[qubit, qubit]:
16 | q1 = foo(q1, q2)
| ^^ Variable `q1` with linear type `qubit` is leaked
| ^^^^^^^^^^^^^^^^ Variable `q1` with linear type `qubit` is leaked

Help: Make sure that `q1` is consumed or returned to avoid the leak

Expand Down
6 changes: 3 additions & 3 deletions tests/error/inout_errors/unused_after_call.err
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
Error: Linearity violation (at $FILE:16:7)
Error: Linearity violation (at $FILE:15:9)
|
13 |
14 | @guppy(module)
15 | def test(q: qubit @owned) -> None:
16 | foo(q)
| ^ Variable `q` with linear type `qubit` is leaked
| ^^^^^^^^^^^^^^^ Variable `q` with linear type `qubit` is leaked

Help: Make sure that `q` is consumed or returned to avoid the leak

Expand Down
75 changes: 74 additions & 1 deletion tests/integration/test_array_comprehension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import array
from guppylang.std.builtins import array, owned
from guppylang.std.quantum import qubit

import guppylang.std.quantum_functional as quantum
Expand Down Expand Up @@ -88,3 +88,76 @@ def test(xs: array[int, n]) -> array[int, n]:
return array(x + 1 for x in xs)

validate(module.compile())


def test_borrow(validate):
module = GuppyModule("test")
module.load_all(quantum)
module.load(qubit)
n = guppy.nat_var("n", module)

@guppy.declare(module)
def foo(q: qubit) -> int: ...

@guppy(module)
def test(q: qubit) -> array[int, n]:
return array(foo(q) for _ in range(n))

validate(module.compile())


def test_borrow_twice(validate):
module = GuppyModule("test")
module.load_all(quantum)
module.load(qubit)
n = guppy.nat_var("n", module)

@guppy.declare(module)
def foo(q: qubit) -> int: ...

@guppy(module)
def test(q: qubit) -> array[int, n]:
return array(foo(q) + foo(q) for _ in range(n))

validate(module.compile())


def test_borrow_struct(validate):
module = GuppyModule("test")
module.load_all(quantum)
module.load(qubit)
n = guppy.nat_var("n", module)

@guppy.struct(module)
class MyStruct:
q1: qubit
q2: qubit

@guppy.declare(module)
def foo(s: MyStruct) -> int: ...

@guppy(module)
def test(s: MyStruct) -> array[int, n]:
return array(foo(s) for _ in range(n))

validate(module.compile())


def test_borrow_and_consume(validate):
module = GuppyModule("test")
module.load_all(quantum)
module.load(qubit)
n = guppy.nat_var("n", module)

@guppy.declare(module)
def foo(q: qubit) -> int: ...

@guppy.declare(module)
def bar(q: qubit @ owned) -> int: ...

@guppy(module)
def test(qs: array[qubit, n] @ owned) -> array[int, n]:
return array(foo(q) + bar(q) for q in qs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flex


validate(module.compile())

Loading
Loading