From ded9e1c9332e8d6ae08c9535f16ac2c5c3cac66c Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:59:16 +0100 Subject: [PATCH] feat: Update linearity checker to handle subscripts (#421) Closes #417 and closes #252. See #415 for context. --- guppylang/checker/linearity_checker.py | 76 ++++++++++++++++--- .../array_errors/subscript_after_use.err | 7 ++ .../error/array_errors/subscript_after_use.py | 21 +++++ tests/error/array_errors/subscript_drop.err | 7 ++ tests/error/array_errors/subscript_drop.py | 21 +++++ .../array_errors/subscript_non_inout.err | 7 ++ .../error/array_errors/subscript_non_inout.py | 18 +++++ .../array_errors/use_after_subscript.err | 7 ++ .../error/array_errors/use_after_subscript.py | 21 +++++ 9 files changed, 173 insertions(+), 12 deletions(-) create mode 100644 tests/error/array_errors/subscript_after_use.err create mode 100644 tests/error/array_errors/subscript_after_use.py create mode 100644 tests/error/array_errors/subscript_drop.err create mode 100644 tests/error/array_errors/subscript_drop.py create mode 100644 tests/error/array_errors/subscript_non_inout.err create mode 100644 tests/error/array_errors/subscript_non_inout.py create mode 100644 tests/error/array_errors/use_after_subscript.err create mode 100644 tests/error/array_errors/use_after_subscript.py diff --git a/guppylang/checker/linearity_checker.py b/guppylang/checker/linearity_checker.py index 7a518646..21cb1b6a 100644 --- a/guppylang/checker/linearity_checker.py +++ b/guppylang/checker/linearity_checker.py @@ -32,9 +32,14 @@ InoutReturnSentinel, LocalCall, PlaceNode, + SubscriptAccessAndDrop, TensorCall, ) -from guppylang.tys.ty import FunctionType, InputFlags, StructType +from guppylang.tys.ty import ( + FunctionType, + InputFlags, + StructType, +) class Scope(Locals[PlaceId, Place]): @@ -136,16 +141,31 @@ def visit_PlaceNode(self, node: PlaceNode, /, is_inout_arg: bool = False) -> Non "ownership of the value.", node, ) - for place in leaf_places(node.place): - x = place.id - if (use := self.scope.used(x)) and place.ty.linear: + # Places involving subscripts are handled differently since we ignore everything + # after the subscript for the purposes of linearity checking + if subscript := contains_subscript(node.place): + if not is_inout_arg and subscript.parent.ty.linear: raise GuppyError( - f"{place.describe} with linear type `{place.ty}` was already " - "used (at {0})", + "Subscripting on expression with linear type " + f"`{subscript.parent.ty}` is only allowed in `@inout` position", node, - [use], ) - self.scope.use(x, node) + self.scope.assign(subscript.item) + # Visiting the `__getitem__(place.parent, place.item)` call ensures that we + # linearity-check the parent and element. + self.visit(subscript.getitem_call) + # For all other places, we record uses of all leafs + else: + for place in leaf_places(node.place): + x = place.id + if (use := self.scope.used(x)) and place.ty.linear: + raise GuppyError( + f"{place.describe} with linear type `{place.ty}` was already " + "used (at {0})", + node, + [use], + ) + self.scope.use(x, node) def visit_Assign(self, node: ast.Assign) -> None: self.visit(node.value) @@ -169,10 +189,7 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N if InputFlags.Inout in inp.flags: match arg: case PlaceNode(place=place): - for leaf in leaf_places(place): - assert not isinstance(leaf, SubscriptAccess) - leaf = leaf.replace_defined_at(arg) - self.scope.assign(leaf) + self._reassign_single_inout_arg(place, arg) case arg if inp.ty.linear: raise GuppyError( f"Inout argument with linear type `{inp.ty}` would be " @@ -182,6 +199,19 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N arg, ) + def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None: + """Helper function to reassign a single inout argument after a function call.""" + # Places involving subscripts are given back by visiting the `__setitem__` call + if subscript := contains_subscript(place): + assert subscript.setitem_call is not None + self.visit(subscript.setitem_call) + self._reassign_single_inout_arg(subscript.parent, node) + else: + for leaf in leaf_places(place): + assert not isinstance(leaf, SubscriptAccess) + leaf = leaf.replace_defined_at(node) + self.scope.assign(leaf) + def visit_GlobalCall(self, node: GlobalCall) -> None: func = self.globals[node.def_id] assert isinstance(func, CallableDef) @@ -214,6 +244,19 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None: node.value, ) + def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> None: + # A subscript access on a value that is not a place. This means the value can no + # longer be accessed after the item has been projected out. Thus, this is only + # legal if the items in the container are not linear + elem_ty = get_type(node.getitem_expr) + if elem_ty.linear: + raise GuppyTypeError( + f"Remaining linear items with type `{elem_ty}` are not used", node + ) + self.visit(node.item_expr) + self.scope.assign(node.item) + self.visit(node.getitem_expr) + def visit_Expr(self, node: ast.Expr) -> None: # An expression statement where the return value is discarded self.visit(node.value) @@ -357,6 +400,15 @@ def leaf_places(place: Place) -> Iterator[Place]: yield place +def contains_subscript(place: Place) -> SubscriptAccess | None: + """Checks if a place contains a subscript access and returns the rightmost one.""" + while not isinstance(place, Variable): + if isinstance(place, SubscriptAccess): + return place + place = place.parent + return None + + def is_inout_var(place: Place) -> TypeGuard[Variable]: """Checks whether a place is an @inout variable.""" return isinstance(place, Variable) and InputFlags.Inout in place.flags diff --git a/tests/error/array_errors/subscript_after_use.err b/tests/error/array_errors/subscript_after_use.err new file mode 100644 index 00000000..098b28ed --- /dev/null +++ b/tests/error/array_errors/subscript_after_use.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def main(qs: array[qubit, 42]) -> array[qubit, 42]: +18: return foo(qs, qs[0]) + ^^ +GuppyError: Variable `qs` with linear type `array[qubit, 42]` was already used (at 18:15) diff --git a/tests/error/array_errors/subscript_after_use.py b/tests/error/array_errors/subscript_after_use.py new file mode 100644 index 00000000..37c9e6d2 --- /dev/null +++ b/tests/error/array_errors/subscript_after_use.py @@ -0,0 +1,21 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array, inout +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def foo(qs: array[qubit, 42], q: qubit @inout) -> array[qubit, 42]: ... + + +@guppy(module) +def main(qs: array[qubit, 42]) -> array[qubit, 42]: + return foo(qs, qs[0]) + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/subscript_drop.err b/tests/error/array_errors/subscript_drop.err new file mode 100644 index 00000000..448abbbc --- /dev/null +++ b/tests/error/array_errors/subscript_drop.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def main() -> qubit: +18: return foo()[0] + ^^^^^^^^ +GuppyTypeError: Remaining linear items with type `qubit` are not used diff --git a/tests/error/array_errors/subscript_drop.py b/tests/error/array_errors/subscript_drop.py new file mode 100644 index 00000000..efd91b92 --- /dev/null +++ b/tests/error/array_errors/subscript_drop.py @@ -0,0 +1,21 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def foo() -> array[qubit, 10]: ... + + +@guppy(module) +def main() -> qubit: + return foo()[0] + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/subscript_non_inout.err b/tests/error/array_errors/subscript_non_inout.err new file mode 100644 index 00000000..55df1950 --- /dev/null +++ b/tests/error/array_errors/subscript_non_inout.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:14 + +12: @guppy(module) +13: def main(qs: array[qubit, 42]) -> tuple[qubit, array[qubit, 42]]: +14: q = qs[0] + ^^^^^ +GuppyError: Subscripting on expression with linear type `array[qubit, 42]` is only allowed in `@inout` position diff --git a/tests/error/array_errors/subscript_non_inout.py b/tests/error/array_errors/subscript_non_inout.py new file mode 100644 index 00000000..70756ead --- /dev/null +++ b/tests/error/array_errors/subscript_non_inout.py @@ -0,0 +1,18 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main(qs: array[qubit, 42]) -> tuple[qubit, array[qubit, 42]]: + q = qs[0] + return q, qs + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/use_after_subscript.err b/tests/error/array_errors/use_after_subscript.err new file mode 100644 index 00000000..566f65de --- /dev/null +++ b/tests/error/array_errors/use_after_subscript.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def main(qs: array[qubit, 42]) -> array[qubit, 42]: +18: return foo(qs[0], qs) + ^^^^^ +GuppyError: Variable `qs` with linear type `array[qubit, 42]` was already used (at 18:22) diff --git a/tests/error/array_errors/use_after_subscript.py b/tests/error/array_errors/use_after_subscript.py new file mode 100644 index 00000000..49b5e808 --- /dev/null +++ b/tests/error/array_errors/use_after_subscript.py @@ -0,0 +1,21 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array, inout +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def foo(q: qubit @inout, qs: array[qubit, 42]) -> array[qubit, 42]: ... + + +@guppy(module) +def main(qs: array[qubit, 42]) -> array[qubit, 42]: + return foo(qs[0], qs) + + +module.compile() \ No newline at end of file