Skip to content

Commit

Permalink
feat: Update linearity checker to handle subscripts (#421)
Browse files Browse the repository at this point in the history
Closes #417 and closes #252. See #415 for context.
  • Loading branch information
mark-koch authored Aug 30, 2024
1 parent 61997cb commit ded9e1c
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 12 deletions.
76 changes: 64 additions & 12 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,14 @@
InoutReturnSentinel,
LocalCall,
PlaceNode,
SubscriptAccessAndDrop,
TensorCall,
)
from guppylang.tys.ty import FunctionType, InputFlags, StructType
from guppylang.tys.ty import (
FunctionType,
InputFlags,
StructType,
)


class Scope(Locals[PlaceId, Place]):
Expand Down Expand Up @@ -136,16 +141,31 @@ def visit_PlaceNode(self, node: PlaceNode, /, is_inout_arg: bool = False) -> Non
"ownership of the value.",
node,
)
for place in leaf_places(node.place):
x = place.id
if (use := self.scope.used(x)) and place.ty.linear:
# Places involving subscripts are handled differently since we ignore everything
# after the subscript for the purposes of linearity checking
if subscript := contains_subscript(node.place):
if not is_inout_arg and subscript.parent.ty.linear:
raise GuppyError(
f"{place.describe} with linear type `{place.ty}` was already "
"used (at {0})",
"Subscripting on expression with linear type "
f"`{subscript.parent.ty}` is only allowed in `@inout` position",
node,
[use],
)
self.scope.use(x, node)
self.scope.assign(subscript.item)
# Visiting the `__getitem__(place.parent, place.item)` call ensures that we
# linearity-check the parent and element.
self.visit(subscript.getitem_call)
# For all other places, we record uses of all leafs
else:
for place in leaf_places(node.place):
x = place.id
if (use := self.scope.used(x)) and place.ty.linear:
raise GuppyError(
f"{place.describe} with linear type `{place.ty}` was already "
"used (at {0})",
node,
[use],
)
self.scope.use(x, node)

def visit_Assign(self, node: ast.Assign) -> None:
self.visit(node.value)
Expand All @@ -169,10 +189,7 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N
if InputFlags.Inout in inp.flags:
match arg:
case PlaceNode(place=place):
for leaf in leaf_places(place):
assert not isinstance(leaf, SubscriptAccess)
leaf = leaf.replace_defined_at(arg)
self.scope.assign(leaf)
self._reassign_single_inout_arg(place, arg)
case arg if inp.ty.linear:
raise GuppyError(
f"Inout argument with linear type `{inp.ty}` would be "
Expand All @@ -182,6 +199,19 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N
arg,
)

def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None:
"""Helper function to reassign a single inout argument after a function call."""
# Places involving subscripts are given back by visiting the `__setitem__` call
if subscript := contains_subscript(place):
assert subscript.setitem_call is not None
self.visit(subscript.setitem_call)
self._reassign_single_inout_arg(subscript.parent, node)
else:
for leaf in leaf_places(place):
assert not isinstance(leaf, SubscriptAccess)
leaf = leaf.replace_defined_at(node)
self.scope.assign(leaf)

def visit_GlobalCall(self, node: GlobalCall) -> None:
func = self.globals[node.def_id]
assert isinstance(func, CallableDef)
Expand Down Expand Up @@ -214,6 +244,19 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None:
node.value,
)

def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> None:
# A subscript access on a value that is not a place. This means the value can no
# longer be accessed after the item has been projected out. Thus, this is only
# legal if the items in the container are not linear
elem_ty = get_type(node.getitem_expr)
if elem_ty.linear:
raise GuppyTypeError(
f"Remaining linear items with type `{elem_ty}` are not used", node
)
self.visit(node.item_expr)
self.scope.assign(node.item)
self.visit(node.getitem_expr)

def visit_Expr(self, node: ast.Expr) -> None:
# An expression statement where the return value is discarded
self.visit(node.value)
Expand Down Expand Up @@ -357,6 +400,15 @@ def leaf_places(place: Place) -> Iterator[Place]:
yield place


def contains_subscript(place: Place) -> SubscriptAccess | None:
"""Checks if a place contains a subscript access and returns the rightmost one."""
while not isinstance(place, Variable):
if isinstance(place, SubscriptAccess):
return place
place = place.parent
return None


def is_inout_var(place: Place) -> TypeGuard[Variable]:
"""Checks whether a place is an @inout variable."""
return isinstance(place, Variable) and InputFlags.Inout in place.flags
Expand Down
7 changes: 7 additions & 0 deletions tests/error/array_errors/subscript_after_use.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:18

16: @guppy(module)
17: def main(qs: array[qubit, 42]) -> array[qubit, 42]:
18: return foo(qs, qs[0])
^^
GuppyError: Variable `qs` with linear type `array[qubit, 42]` was already used (at 18:15)
21 changes: 21 additions & 0 deletions tests/error/array_errors/subscript_after_use.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array, inout
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")
module.load(quantum)


@guppy.declare(module)
def foo(qs: array[qubit, 42], q: qubit @inout) -> array[qubit, 42]: ...


@guppy(module)
def main(qs: array[qubit, 42]) -> array[qubit, 42]:
return foo(qs, qs[0])


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/subscript_drop.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:18

16: @guppy(module)
17: def main() -> qubit:
18: return foo()[0]
^^^^^^^^
GuppyTypeError: Remaining linear items with type `qubit` are not used
21 changes: 21 additions & 0 deletions tests/error/array_errors/subscript_drop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")
module.load(quantum)


@guppy.declare(module)
def foo() -> array[qubit, 10]: ...


@guppy(module)
def main() -> qubit:
return foo()[0]


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/subscript_non_inout.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:14

12: @guppy(module)
13: def main(qs: array[qubit, 42]) -> tuple[qubit, array[qubit, 42]]:
14: q = qs[0]
^^^^^
GuppyError: Subscripting on expression with linear type `array[qubit, 42]` is only allowed in `@inout` position
18 changes: 18 additions & 0 deletions tests/error/array_errors/subscript_non_inout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main(qs: array[qubit, 42]) -> tuple[qubit, array[qubit, 42]]:
q = qs[0]
return q, qs


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/use_after_subscript.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:18

16: @guppy(module)
17: def main(qs: array[qubit, 42]) -> array[qubit, 42]:
18: return foo(qs[0], qs)
^^^^^
GuppyError: Variable `qs` with linear type `array[qubit, 42]` was already used (at 18:22)
21 changes: 21 additions & 0 deletions tests/error/array_errors/use_after_subscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array, inout
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")
module.load(quantum)


@guppy.declare(module)
def foo(q: qubit @inout, qs: array[qubit, 42]) -> array[qubit, 42]: ...


@guppy(module)
def main(qs: array[qubit, 42]) -> array[qubit, 42]:
return foo(qs[0], qs)


module.compile()

0 comments on commit ded9e1c

Please sign in to comment.