Skip to content

Commit

Permalink
fix: Allow borrowing inside comprehensions
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Dec 16, 2024
1 parent d52a00a commit 81f6226
Show file tree
Hide file tree
Showing 18 changed files with 409 additions and 21 deletions.
6 changes: 3 additions & 3 deletions examples/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,12 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Error: Linearity violation (at <In [10]>:7:7)\n",
"Error: Linearity violation (at <In [10]>:6:4)\n",
" | \n",
"4 | @guppy(bad_module)\n",
"5 | def bad(q: qubit @owned) -> qubit:\n",
"6 | tmp = qubit()\n",
"7 | cx(tmp, q)\n",
" | ^^^ Variable `tmp` with linear type `qubit` is leaked\n",
" | ^^^ Variable `tmp` with linear type `qubit` is leaked\n",
"\n",
"Help: Make sure that `tmp` is consumed or returned to avoid the leak\n",
"\n",
Expand Down
45 changes: 34 additions & 11 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,13 @@ def _reassign_inout_args(self, func_ty: FunctionType, call: AnyCall) -> None:
if InputFlags.Inout in inp.flags:
match arg:
case PlaceNode(place=place):
self._reassign_single_inout_arg(place, arg)
self._reassign_single_inout_arg(place, place.defined_at or arg)
case arg if inp.ty.linear:
err = DropAfterCallError(arg, inp.ty, self._call_name(call))
err.add_sub_diagnostic(DropAfterCallError.Assign(None))
raise GuppyError(err)

def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None:
def _reassign_single_inout_arg(self, place: Place, node: AstNode) -> None:
"""Helper function to reassign a single borrowed argument after a function
call."""
# Places involving subscripts are given back by visiting the `__setitem__` call
Expand Down Expand Up @@ -496,6 +496,11 @@ def _check_comprehension(
continue
for leaf in leaf_places(place):
x = leaf.id
# Also ignore borrowed variables
if x in inner_scope.used_parent and (
inner_scope.used_parent[x].kind == UseKind.BORROW
):
continue
if not self.scope.used(x) and place.ty.linear:
err = PlaceNotUsedError(place.defined_at, place)
err.add_sub_diagnostic(
Expand All @@ -508,6 +513,22 @@ def _check_comprehension(
# Recursively check the remaining generators
self._check_comprehension(gens, elt)

# Look for any linear variables that were borrowed from the outer scope
gen.borrowed_outer_places = []
for x, use in inner_scope.used_parent.items():
if use.kind == UseKind.BORROW:
# Since `x` was borrowed, we know that is now also assigned in the
# inner scope since it gets reassigned in the local scope after the
# borrow expires
place = inner_scope[x]
gen.borrowed_outer_places.append(place)
# Also mark this place as implicitly used so we don't complain about
# it later.
for leaf in leaf_places(place):
inner_scope.use(
leaf.id, InoutReturnSentinel(leaf), UseKind.RETURN
)

# Check the iter finalizer so we record a final use of the iterator
self.visit(gen.iterend)

Expand All @@ -519,15 +540,17 @@ def _check_comprehension(
if leaf.ty.linear and not inner_scope.used(x):
raise GuppyTypeError(PlaceNotUsedError(leaf.defined_at, leaf))

# On the other hand, we have to ensure that no linear places from the
# outer scope have been used inside the comprehension (they would be used
# multiple times since the comprehension body is executed repeatedly)
for x, use in inner_scope.used_parent.items():
place = inner_scope[x]
if place.ty.linear:
raise GuppyTypeError(
ComprAlreadyUsedError(use.node, place, use.kind)
)
# On the other hand, we have to ensure that no linear places from the
# outer scope have been used inside the comprehension (they would be used
# multiple times since the comprehension body is executed repeatedly)
for x, use in inner_scope.used_parent.items():
place = inner_scope[x]
# The only exception are values that are only borrowed from the outer
# scope. These can be safely reassigned.
if use.kind == UseKind.BORROW:
self._reassign_single_inout_arg(place, use.node)
elif place.ty.linear:
raise GuppyTypeError(ComprAlreadyUsedError(use.node, place, use.kind))


def leaf_places(place: Place) -> Iterator[Place]:
Expand Down
3 changes: 3 additions & 0 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,9 @@ def _build_generators(
assert isinstance(gen.iter, PlaceNode)
assert isinstance(gen.hasnext, PlaceNode)
inputs = [gen.iter] + [PlaceNode(place=var) for var in loop_vars]
inputs += [
PlaceNode(place=place) for place in gen.borrowed_outer_places
]
# Remember to finalize the iterator once we are done with it. Note that
# we need to use partial in the callback, so that we bind the *current*
# value of `gen` instead of only last
Expand Down
3 changes: 3 additions & 0 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ class DesugaredGenerator(ast.expr):
hasnext: ast.expr
ifs: list[ast.expr]

borrowed_outer_places: "list[Place]"

_fields = (
"iter_assign",
"hasnext_assign",
Expand All @@ -215,6 +217,7 @@ class DesugaredGenerator(ast.expr):
"iter",
"hasnext",
"ifs",
"borrowed_outer_places",
)


Expand Down
12 changes: 12 additions & 0 deletions tests/error/comprehension_errors/borrow_after_consume.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Error: Linearity violation (at $FILE:21:16)
|
19 | @guppy(module)
20 | def foo(qs: list[qubit] @owned) -> list[int]:
21 | return [baz(q) for q in qs if bar(q)]
| ^ Variable `q` with linear type `qubit` cannot be borrowed
| ...
|
21 | return [baz(q) for q in qs if bar(q)]
| - since it was already consumed here

Guppy compilation failed due to 1 previous error
24 changes: 24 additions & 0 deletions tests/error/comprehension_errors/borrow_after_consume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit
from guppylang.std.builtins import owned

module = GuppyModule("test")
module.load_all(quantum)


@guppy.declare(module)
def bar(q: qubit @owned) -> int: ...


@guppy.declare(module)
def baz(q: qubit) -> int: ...


@guppy(module)
def foo(qs: list[qubit] @owned) -> list[int]:
return [baz(q) for q in qs if bar(q)]


module.compile()
10 changes: 10 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked1.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Error: Linearity violation (at $FILE:17:16)
|
15 | @guppy(module)
16 | def foo(n: int, q: qubit @owned) -> list[int]:
17 | return [bar(q) for _ in range(n)]
| ^ Variable `q` with linear type `qubit` is leaked

Help: Make sure that `q` is consumed or returned to avoid the leak

Guppy compilation failed due to 1 previous error
20 changes: 20 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit
from guppylang.std.builtins import owned

module = GuppyModule("test")
module.load_all(quantum)


@guppy.declare(module)
def bar(q: qubit) -> int: ...


@guppy(module)
def foo(n: int, q: qubit @owned) -> list[int]:
return [bar(q) for _ in range(n)]


module.compile()
8 changes: 8 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked2.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Linearity violation (at $FILE:17:23)
|
15 | @guppy(module)
16 | def foo(qs: list[qubit] @owned) -> list[int]:
17 | return [bar(q) for q in qs]
| ^ Variable `q` with linear type `qubit` is leaked

Guppy compilation failed due to 1 previous error
20 changes: 20 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit
from guppylang.std.builtins import owned

module = GuppyModule("test")
module.load_all(quantum)


@guppy.declare(module)
def bar(q: qubit) -> int: ...


@guppy(module)
def foo(qs: list[qubit] @owned) -> list[int]:
return [bar(q) for q in qs]


module.compile()
11 changes: 11 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked3.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Error: Linearity violation (at $FILE:17:18)
|
15 | @guppy(module)
16 | def foo(qs: list[qubit] @owned) -> list[int]:
17 | return [0 for q in qs if bar(q)]
| ^ Variable `q` with linear type `qubit` may be leaked ...
|
17 | return [0 for q in qs if bar(q)]
| ------ if this expression is `False`

Guppy compilation failed due to 1 previous error
20 changes: 20 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit
from guppylang.std.builtins import owned

module = GuppyModule("test")
module.load_all(quantum)


@guppy.declare(module)
def bar(q: qubit) -> bool: ...


@guppy(module)
def foo(qs: list[qubit] @owned) -> list[int]:
return [0 for q in qs if bar(q)]


module.compile()
8 changes: 8 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked4.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Linearity violation (at $FILE:17:18)
|
15 | @guppy(module)
16 | def foo(qs: list[qubit] @owned) -> list[qubit]:
17 | return [r for q in qs for r in bar(q)]
| ^ Variable `q` with linear type `qubit` is leaked

Guppy compilation failed due to 1 previous error
20 changes: 20 additions & 0 deletions tests/error/comprehension_errors/borrow_leaked4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit
from guppylang.std.builtins import owned

module = GuppyModule("test")
module.load_all(quantum)


@guppy.declare(module)
def bar(q: qubit) -> list[qubit]: ...


@guppy(module)
def foo(qs: list[qubit] @owned) -> list[qubit]:
return [r for q in qs for r in bar(q)]


module.compile()
6 changes: 3 additions & 3 deletions tests/error/inout_errors/override_after_call.err
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
Error: Linearity violation (at $FILE:16:13)
Error: Linearity violation (at $FILE:15:9)
|
13 |
14 | @guppy(module)
15 | def test(q1: qubit @owned, q2: qubit @owned) -> tuple[qubit, qubit]:
16 | q1 = foo(q1, q2)
| ^^ Variable `q1` with linear type `qubit` is leaked
| ^^^^^^^^^^^^^^^^ Variable `q1` with linear type `qubit` is leaked

Help: Make sure that `q1` is consumed or returned to avoid the leak

Expand Down
6 changes: 3 additions & 3 deletions tests/error/inout_errors/unused_after_call.err
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
Error: Linearity violation (at $FILE:16:7)
Error: Linearity violation (at $FILE:15:9)
|
13 |
14 | @guppy(module)
15 | def test(q: qubit @owned) -> None:
16 | foo(q)
| ^ Variable `q` with linear type `qubit` is leaked
| ^^^^^^^^^^^^^^^ Variable `q` with linear type `qubit` is leaked

Help: Make sure that `q` is consumed or returned to avoid the leak

Expand Down
75 changes: 74 additions & 1 deletion tests/integration/test_array_comprehension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import array
from guppylang.std.builtins import array, owned
from guppylang.std.quantum import qubit

import guppylang.std.quantum_functional as quantum
Expand Down Expand Up @@ -88,3 +88,76 @@ def test(xs: array[int, n]) -> array[int, n]:
return array(x + 1 for x in xs)

validate(module.compile())


def test_borrow(validate):
module = GuppyModule("test")
module.load_all(quantum)
module.load(qubit)
n = guppy.nat_var("n", module)

@guppy.declare(module)
def foo(q: qubit) -> int: ...

@guppy(module)
def test(q: qubit) -> array[int, n]:
return array(foo(q) for _ in range(n))

validate(module.compile())


def test_borrow_twice(validate):
module = GuppyModule("test")
module.load_all(quantum)
module.load(qubit)
n = guppy.nat_var("n", module)

@guppy.declare(module)
def foo(q: qubit) -> int: ...

@guppy(module)
def test(q: qubit) -> array[int, n]:
return array(foo(q) + foo(q) for _ in range(n))

validate(module.compile())


def test_borrow_struct(validate):
module = GuppyModule("test")
module.load_all(quantum)
module.load(qubit)
n = guppy.nat_var("n", module)

@guppy.struct(module)
class MyStruct:
q1: qubit
q2: qubit

@guppy.declare(module)
def foo(s: MyStruct) -> int: ...

@guppy(module)
def test(s: MyStruct) -> array[int, n]:
return array(foo(s) for _ in range(n))

validate(module.compile())


def test_borrow_and_consume(validate):
module = GuppyModule("test")
module.load_all(quantum)
module.load(qubit)
n = guppy.nat_var("n", module)

@guppy.declare(module)
def foo(q: qubit) -> int: ...

@guppy.declare(module)
def bar(q: qubit @ owned) -> int: ...

@guppy(module)
def test(qs: array[qubit, n] @ owned) -> array[int, n]:
return array(foo(q) + bar(q) for q in qs)

validate(module.compile())

Loading

0 comments on commit 81f6226

Please sign in to comment.