From 81f62264a1e8d0e969908ba4b0440a901fe5425d Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 16 Dec 2024 15:13:32 +0000 Subject: [PATCH] fix: Allow borrowing inside comprehensions --- examples/demo.ipynb | 6 +- guppylang/checker/linearity_checker.py | 45 ++++-- guppylang/compiler/expr_compiler.py | 3 + guppylang/nodes.py | 3 + .../borrow_after_consume.err | 12 ++ .../borrow_after_consume.py | 24 ++++ .../comprehension_errors/borrow_leaked1.err | 10 ++ .../comprehension_errors/borrow_leaked1.py | 20 +++ .../comprehension_errors/borrow_leaked2.err | 8 ++ .../comprehension_errors/borrow_leaked2.py | 20 +++ .../comprehension_errors/borrow_leaked3.err | 11 ++ .../comprehension_errors/borrow_leaked3.py | 20 +++ .../comprehension_errors/borrow_leaked4.err | 8 ++ .../comprehension_errors/borrow_leaked4.py | 20 +++ .../inout_errors/override_after_call.err | 6 +- .../error/inout_errors/unused_after_call.err | 6 +- tests/integration/test_array_comprehension.py | 75 +++++++++- tests/integration/test_comprehension.py | 133 ++++++++++++++++++ 18 files changed, 409 insertions(+), 21 deletions(-) create mode 100644 tests/error/comprehension_errors/borrow_after_consume.err create mode 100644 tests/error/comprehension_errors/borrow_after_consume.py create mode 100644 tests/error/comprehension_errors/borrow_leaked1.err create mode 100644 tests/error/comprehension_errors/borrow_leaked1.py create mode 100644 tests/error/comprehension_errors/borrow_leaked2.err create mode 100644 tests/error/comprehension_errors/borrow_leaked2.py create mode 100644 tests/error/comprehension_errors/borrow_leaked3.err create mode 100644 tests/error/comprehension_errors/borrow_leaked3.py create mode 100644 tests/error/comprehension_errors/borrow_leaked4.err create mode 100644 tests/error/comprehension_errors/borrow_leaked4.py diff --git a/examples/demo.ipynb b/examples/demo.ipynb index 331c370a..608a42e4 100644 --- a/examples/demo.ipynb +++ b/examples/demo.ipynb @@ -356,12 +356,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "Error: Linearity violation (at :7:7)\n", + "Error: Linearity violation (at :6:4)\n", " | \n", + "4 | @guppy(bad_module)\n", "5 | def bad(q: qubit @owned) -> qubit:\n", "6 | tmp = qubit()\n", - "7 | cx(tmp, q)\n", - " | ^^^ Variable `tmp` with linear type `qubit` is leaked\n", + " | ^^^ Variable `tmp` with linear type `qubit` is leaked\n", "\n", "Help: Make sure that `tmp` is consumed or returned to avoid the leak\n", "\n", diff --git a/guppylang/checker/linearity_checker.py b/guppylang/checker/linearity_checker.py index 5ff7534e..d2d6d323 100644 --- a/guppylang/checker/linearity_checker.py +++ b/guppylang/checker/linearity_checker.py @@ -317,13 +317,13 @@ def _reassign_inout_args(self, func_ty: FunctionType, call: AnyCall) -> None: if InputFlags.Inout in inp.flags: match arg: case PlaceNode(place=place): - self._reassign_single_inout_arg(place, arg) + self._reassign_single_inout_arg(place, place.defined_at or arg) case arg if inp.ty.linear: err = DropAfterCallError(arg, inp.ty, self._call_name(call)) err.add_sub_diagnostic(DropAfterCallError.Assign(None)) raise GuppyError(err) - def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None: + def _reassign_single_inout_arg(self, place: Place, node: AstNode) -> None: """Helper function to reassign a single borrowed argument after a function call.""" # Places involving subscripts are given back by visiting the `__setitem__` call @@ -496,6 +496,11 @@ def _check_comprehension( continue for leaf in leaf_places(place): x = leaf.id + # Also ignore borrowed variables + if x in inner_scope.used_parent and ( + inner_scope.used_parent[x].kind == UseKind.BORROW + ): + continue if not self.scope.used(x) and place.ty.linear: err = PlaceNotUsedError(place.defined_at, place) err.add_sub_diagnostic( @@ -508,6 +513,22 @@ def _check_comprehension( # Recursively check the remaining generators self._check_comprehension(gens, elt) + # Look for any linear variables that were borrowed from the outer scope + gen.borrowed_outer_places = [] + for x, use in inner_scope.used_parent.items(): + if use.kind == UseKind.BORROW: + # Since `x` was borrowed, we know that is now also assigned in the + # inner scope since it gets reassigned in the local scope after the + # borrow expires + place = inner_scope[x] + gen.borrowed_outer_places.append(place) + # Also mark this place as implicitly used so we don't complain about + # it later. + for leaf in leaf_places(place): + inner_scope.use( + leaf.id, InoutReturnSentinel(leaf), UseKind.RETURN + ) + # Check the iter finalizer so we record a final use of the iterator self.visit(gen.iterend) @@ -519,15 +540,17 @@ def _check_comprehension( if leaf.ty.linear and not inner_scope.used(x): raise GuppyTypeError(PlaceNotUsedError(leaf.defined_at, leaf)) - # On the other hand, we have to ensure that no linear places from the - # outer scope have been used inside the comprehension (they would be used - # multiple times since the comprehension body is executed repeatedly) - for x, use in inner_scope.used_parent.items(): - place = inner_scope[x] - if place.ty.linear: - raise GuppyTypeError( - ComprAlreadyUsedError(use.node, place, use.kind) - ) + # On the other hand, we have to ensure that no linear places from the + # outer scope have been used inside the comprehension (they would be used + # multiple times since the comprehension body is executed repeatedly) + for x, use in inner_scope.used_parent.items(): + place = inner_scope[x] + # The only exception are values that are only borrowed from the outer + # scope. These can be safely reassigned. + if use.kind == UseKind.BORROW: + self._reassign_single_inout_arg(place, use.node) + elif place.ty.linear: + raise GuppyTypeError(ComprAlreadyUsedError(use.node, place, use.kind)) def leaf_places(place: Place) -> Iterator[Place]: diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index d71b1c23..b3d8184f 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -543,6 +543,9 @@ def _build_generators( assert isinstance(gen.iter, PlaceNode) assert isinstance(gen.hasnext, PlaceNode) inputs = [gen.iter] + [PlaceNode(place=var) for var in loop_vars] + inputs += [ + PlaceNode(place=place) for place in gen.borrowed_outer_places + ] # Remember to finalize the iterator once we are done with it. Note that # we need to use partial in the callback, so that we bind the *current* # value of `gen` instead of only last diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 101be0ab..2eca1ea4 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -207,6 +207,8 @@ class DesugaredGenerator(ast.expr): hasnext: ast.expr ifs: list[ast.expr] + borrowed_outer_places: "list[Place]" + _fields = ( "iter_assign", "hasnext_assign", @@ -215,6 +217,7 @@ class DesugaredGenerator(ast.expr): "iter", "hasnext", "ifs", + "borrowed_outer_places", ) diff --git a/tests/error/comprehension_errors/borrow_after_consume.err b/tests/error/comprehension_errors/borrow_after_consume.err new file mode 100644 index 00000000..d925da2d --- /dev/null +++ b/tests/error/comprehension_errors/borrow_after_consume.err @@ -0,0 +1,12 @@ +Error: Linearity violation (at $FILE:21:16) + | +19 | @guppy(module) +20 | def foo(qs: list[qubit] @owned) -> list[int]: +21 | return [baz(q) for q in qs if bar(q)] + | ^ Variable `q` with linear type `qubit` cannot be borrowed + | ... + | +21 | return [baz(q) for q in qs if bar(q)] + | - since it was already consumed here + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/comprehension_errors/borrow_after_consume.py b/tests/error/comprehension_errors/borrow_after_consume.py new file mode 100644 index 00000000..7ea87a4d --- /dev/null +++ b/tests/error/comprehension_errors/borrow_after_consume.py @@ -0,0 +1,24 @@ +import guppylang.std.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.quantum import qubit +from guppylang.std.builtins import owned + +module = GuppyModule("test") +module.load_all(quantum) + + +@guppy.declare(module) +def bar(q: qubit @owned) -> int: ... + + +@guppy.declare(module) +def baz(q: qubit) -> int: ... + + +@guppy(module) +def foo(qs: list[qubit] @owned) -> list[int]: + return [baz(q) for q in qs if bar(q)] + + +module.compile() diff --git a/tests/error/comprehension_errors/borrow_leaked1.err b/tests/error/comprehension_errors/borrow_leaked1.err new file mode 100644 index 00000000..37128bf3 --- /dev/null +++ b/tests/error/comprehension_errors/borrow_leaked1.err @@ -0,0 +1,10 @@ +Error: Linearity violation (at $FILE:17:16) + | +15 | @guppy(module) +16 | def foo(n: int, q: qubit @owned) -> list[int]: +17 | return [bar(q) for _ in range(n)] + | ^ Variable `q` with linear type `qubit` is leaked + +Help: Make sure that `q` is consumed or returned to avoid the leak + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/comprehension_errors/borrow_leaked1.py b/tests/error/comprehension_errors/borrow_leaked1.py new file mode 100644 index 00000000..3ff39850 --- /dev/null +++ b/tests/error/comprehension_errors/borrow_leaked1.py @@ -0,0 +1,20 @@ +import guppylang.std.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.quantum import qubit +from guppylang.std.builtins import owned + +module = GuppyModule("test") +module.load_all(quantum) + + +@guppy.declare(module) +def bar(q: qubit) -> int: ... + + +@guppy(module) +def foo(n: int, q: qubit @owned) -> list[int]: + return [bar(q) for _ in range(n)] + + +module.compile() diff --git a/tests/error/comprehension_errors/borrow_leaked2.err b/tests/error/comprehension_errors/borrow_leaked2.err new file mode 100644 index 00000000..35887987 --- /dev/null +++ b/tests/error/comprehension_errors/borrow_leaked2.err @@ -0,0 +1,8 @@ +Error: Linearity violation (at $FILE:17:23) + | +15 | @guppy(module) +16 | def foo(qs: list[qubit] @owned) -> list[int]: +17 | return [bar(q) for q in qs] + | ^ Variable `q` with linear type `qubit` is leaked + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/comprehension_errors/borrow_leaked2.py b/tests/error/comprehension_errors/borrow_leaked2.py new file mode 100644 index 00000000..00fdb613 --- /dev/null +++ b/tests/error/comprehension_errors/borrow_leaked2.py @@ -0,0 +1,20 @@ +import guppylang.std.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.quantum import qubit +from guppylang.std.builtins import owned + +module = GuppyModule("test") +module.load_all(quantum) + + +@guppy.declare(module) +def bar(q: qubit) -> int: ... + + +@guppy(module) +def foo(qs: list[qubit] @owned) -> list[int]: + return [bar(q) for q in qs] + + +module.compile() diff --git a/tests/error/comprehension_errors/borrow_leaked3.err b/tests/error/comprehension_errors/borrow_leaked3.err new file mode 100644 index 00000000..9b7efded --- /dev/null +++ b/tests/error/comprehension_errors/borrow_leaked3.err @@ -0,0 +1,11 @@ +Error: Linearity violation (at $FILE:17:18) + | +15 | @guppy(module) +16 | def foo(qs: list[qubit] @owned) -> list[int]: +17 | return [0 for q in qs if bar(q)] + | ^ Variable `q` with linear type `qubit` may be leaked ... + | +17 | return [0 for q in qs if bar(q)] + | ------ if this expression is `False` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/comprehension_errors/borrow_leaked3.py b/tests/error/comprehension_errors/borrow_leaked3.py new file mode 100644 index 00000000..194816f2 --- /dev/null +++ b/tests/error/comprehension_errors/borrow_leaked3.py @@ -0,0 +1,20 @@ +import guppylang.std.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.quantum import qubit +from guppylang.std.builtins import owned + +module = GuppyModule("test") +module.load_all(quantum) + + +@guppy.declare(module) +def bar(q: qubit) -> bool: ... + + +@guppy(module) +def foo(qs: list[qubit] @owned) -> list[int]: + return [0 for q in qs if bar(q)] + + +module.compile() diff --git a/tests/error/comprehension_errors/borrow_leaked4.err b/tests/error/comprehension_errors/borrow_leaked4.err new file mode 100644 index 00000000..3ea63d8d --- /dev/null +++ b/tests/error/comprehension_errors/borrow_leaked4.err @@ -0,0 +1,8 @@ +Error: Linearity violation (at $FILE:17:18) + | +15 | @guppy(module) +16 | def foo(qs: list[qubit] @owned) -> list[qubit]: +17 | return [r for q in qs for r in bar(q)] + | ^ Variable `q` with linear type `qubit` is leaked + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/comprehension_errors/borrow_leaked4.py b/tests/error/comprehension_errors/borrow_leaked4.py new file mode 100644 index 00000000..ea60c4f7 --- /dev/null +++ b/tests/error/comprehension_errors/borrow_leaked4.py @@ -0,0 +1,20 @@ +import guppylang.std.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.quantum import qubit +from guppylang.std.builtins import owned + +module = GuppyModule("test") +module.load_all(quantum) + + +@guppy.declare(module) +def bar(q: qubit) -> list[qubit]: ... + + +@guppy(module) +def foo(qs: list[qubit] @owned) -> list[qubit]: + return [r for q in qs for r in bar(q)] + + +module.compile() diff --git a/tests/error/inout_errors/override_after_call.err b/tests/error/inout_errors/override_after_call.err index 3b7b5290..2e686d7f 100644 --- a/tests/error/inout_errors/override_after_call.err +++ b/tests/error/inout_errors/override_after_call.err @@ -1,9 +1,9 @@ -Error: Linearity violation (at $FILE:16:13) +Error: Linearity violation (at $FILE:15:9) | +13 | 14 | @guppy(module) 15 | def test(q1: qubit @owned, q2: qubit @owned) -> tuple[qubit, qubit]: -16 | q1 = foo(q1, q2) - | ^^ Variable `q1` with linear type `qubit` is leaked + | ^^^^^^^^^^^^^^^^ Variable `q1` with linear type `qubit` is leaked Help: Make sure that `q1` is consumed or returned to avoid the leak diff --git a/tests/error/inout_errors/unused_after_call.err b/tests/error/inout_errors/unused_after_call.err index e2bc014c..17635bba 100644 --- a/tests/error/inout_errors/unused_after_call.err +++ b/tests/error/inout_errors/unused_after_call.err @@ -1,9 +1,9 @@ -Error: Linearity violation (at $FILE:16:7) +Error: Linearity violation (at $FILE:15:9) | +13 | 14 | @guppy(module) 15 | def test(q: qubit @owned) -> None: -16 | foo(q) - | ^ Variable `q` with linear type `qubit` is leaked + | ^^^^^^^^^^^^^^^ Variable `q` with linear type `qubit` is leaked Help: Make sure that `q` is consumed or returned to avoid the leak diff --git a/tests/integration/test_array_comprehension.py b/tests/integration/test_array_comprehension.py index bdd4a05a..e012b3a5 100644 --- a/tests/integration/test_array_comprehension.py +++ b/tests/integration/test_array_comprehension.py @@ -2,7 +2,7 @@ from guppylang.decorator import guppy from guppylang.module import GuppyModule -from guppylang.std.builtins import array +from guppylang.std.builtins import array, owned from guppylang.std.quantum import qubit import guppylang.std.quantum_functional as quantum @@ -88,3 +88,76 @@ def test(xs: array[int, n]) -> array[int, n]: return array(x + 1 for x in xs) validate(module.compile()) + + +def test_borrow(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + n = guppy.nat_var("n", module) + + @guppy.declare(module) + def foo(q: qubit) -> int: ... + + @guppy(module) + def test(q: qubit) -> array[int, n]: + return array(foo(q) for _ in range(n)) + + validate(module.compile()) + + +def test_borrow_twice(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + n = guppy.nat_var("n", module) + + @guppy.declare(module) + def foo(q: qubit) -> int: ... + + @guppy(module) + def test(q: qubit) -> array[int, n]: + return array(foo(q) + foo(q) for _ in range(n)) + + validate(module.compile()) + + +def test_borrow_struct(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + n = guppy.nat_var("n", module) + + @guppy.struct(module) + class MyStruct: + q1: qubit + q2: qubit + + @guppy.declare(module) + def foo(s: MyStruct) -> int: ... + + @guppy(module) + def test(s: MyStruct) -> array[int, n]: + return array(foo(s) for _ in range(n)) + + validate(module.compile()) + + +def test_borrow_and_consume(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + n = guppy.nat_var("n", module) + + @guppy.declare(module) + def foo(q: qubit) -> int: ... + + @guppy.declare(module) + def bar(q: qubit @ owned) -> int: ... + + @guppy(module) + def test(qs: array[qubit, n] @ owned) -> array[int, n]: + return array(foo(q) + bar(q) for q in qs) + + validate(module.compile()) + diff --git a/tests/integration/test_comprehension.py b/tests/integration/test_comprehension.py index 22967d74..bec9fc72 100644 --- a/tests/integration/test_comprehension.py +++ b/tests/integration/test_comprehension.py @@ -310,3 +310,136 @@ def test(mt: MyType, xs: list[int]) -> list[tuple[int, int]]: return [(x, x + y) for x in mt for y in xs] validate(module.compile()) + + +def test_borrow(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + + @guppy.declare(module) + def foo(q: qubit) -> int: ... + + @guppy(module) + def test(q: qubit, n: int) -> list[int]: + return [foo(q) for _ in range(n)] + + validate(module.compile()) + + +def test_borrow_nested(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + + @guppy.declare(module) + def foo(q: qubit) -> int: ... + + @guppy(module) + def test(q: qubit, n: int, m: int) -> list[int]: + return [foo(q) for _ in range(n) for _ in range(m)] + + validate(module.compile()) + + +def test_borrow_guarded(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + + @guppy.declare(module) + def foo(q: qubit) -> int: ... + + @guppy(module) + def test(q: qubit, n: int) -> list[int]: + return [foo(q) for i in range(n) if i % 2 == 0] + + validate(module.compile()) + + +def test_borrow_twice(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + + @guppy.declare(module) + def foo(q: qubit) -> int: ... + + @guppy(module) + def test(q: qubit, n: int) -> list[int]: + return [foo(q) + foo(q) for _ in range(n)] + + validate(module.compile()) + + +def test_borrow_in_guard(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + + @guppy.declare(module) + def foo(q: qubit) -> int: ... + + @guppy.declare(module) + def bar(q: qubit) -> bool: ... + + @guppy(module) + def test(q: qubit, n: int) -> list[int]: + return [foo(q) for _ in range(n) if bar(q)] + + validate(module.compile()) + + +def test_borrow_in_iter(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + + @guppy.declare(module) + def foo(q: qubit) -> int: ... + + @guppy(module) + def test(q: qubit @ owned) -> tuple[list[int], qubit]: + return [foo(q) for _ in range(foo(q))], q + + validate(module.compile()) + + +def test_borrow_struct(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + + @guppy.struct(module) + class MyStruct: + q1: qubit + q2: qubit + + @guppy.declare(module) + def foo(s: MyStruct) -> int: ... + + @guppy(module) + def test(s: MyStruct, n: int) -> list[int]: + return [foo(s) for _ in range(n)] + + validate(module.compile()) + + +def test_borrow_and_consume(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + + @guppy.declare(module) + def foo(q: qubit) -> int: ... + + @guppy.declare(module) + def bar(q: qubit @ owned) -> int: ... + + @guppy(module) + def test(qs: list[qubit] @ owned) -> list[int]: + return [foo(q) + bar(q) for q in qs] + + validate(module.compile()) + +