Skip to content

Commit

Permalink
feat!: Make linear types @inout by default; add @owned annotation (#486)
Browse files Browse the repository at this point in the history
Resolves #413 and #414.

BREAKING CHANGE: Linear function arguments are now borrowed by default;
removed the now redundant `@inout` annotation

---------

Co-authored-by: Agustín Borgna <[email protected]>
Co-authored-by: Agustin Borgna <[email protected]>
Co-authored-by: Mark Koch <[email protected]>
  • Loading branch information
4 people authored Sep 13, 2024
1 parent 00b18c1 commit e900c96
Show file tree
Hide file tree
Showing 167 changed files with 834 additions and 646 deletions.
14 changes: 8 additions & 6 deletions examples/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@
"metadata": {},
"outputs": [],
"source": [
"from guppylang.prelude.quantum import qubit, h, cx, measure\n",
"from guppylang.prelude.builtins import owned\n",
"from guppylang.prelude.quantum import qubit, measure\n",
"from guppylang.prelude.quantum_functional import h, cx\n",
"\n",
"module.load(qubit, h, cx, measure)\n",
"\n",
Expand Down Expand Up @@ -296,7 +298,7 @@
"Guppy compilation failed. Error in file <In [9]>:6\n",
"\n",
"4: @guppy(bad_module)\n",
"5: def bad(q: qubit) -> tuple[qubit, qubit]:\n",
"5: def bad(q: qubit @owned) -> tuple[qubit, qubit]:\n",
"6: return cx(q, q)\n",
" ^\n",
"GuppyError: Variable `q` with linear type `qubit` was already used (at 6:14)\n"
Expand All @@ -308,7 +310,7 @@
"bad_module.load_all(guppylang.prelude.quantum)\n",
"\n",
"@guppy(bad_module)\n",
"def bad(q: qubit) -> tuple[qubit, qubit]:\n",
"def bad(q: qubit @owned) -> tuple[qubit, qubit]:\n",
" return cx(q, q)\n",
"\n",
"bad_module.compile() # Raises an error"
Expand Down Expand Up @@ -338,7 +340,7 @@
"text": [
"Guppy compilation failed. Error in file <In [10]>:7\n",
"\n",
"5: def bad(q: qubit) -> qubit:\n",
"5: def bad(q: qubit @owned) -> qubit:\n",
"6: tmp = h(qubit())\n",
"7: tmp, q = cx(tmp, q)\n",
" ^^^\n",
Expand All @@ -351,7 +353,7 @@
"bad_module.load_all(guppylang.prelude.quantum)\n",
"\n",
"@guppy(bad_module)\n",
"def bad(q: qubit) -> qubit:\n",
"def bad(q: qubit @owned) -> qubit:\n",
" tmp = h(qubit())\n",
" tmp, q = cx(tmp, q)\n",
" #discard(tmp) # Compiler complains if tmp is not explicitly discarded\n",
Expand Down Expand Up @@ -420,7 +422,7 @@
" q2: qubit\n",
"\n",
" @guppy(module)\n",
" def method(self: \"QubitPair\") -> \"QubitPair\":\n",
" def method(self: \"QubitPair @owned\") -> \"QubitPair\":\n",
" self.q1 = h(self.q1)\n",
" self.q1, self.q2 = cx(self.q1, self.q2)\n",
" return self\n",
Expand Down
10 changes: 6 additions & 4 deletions examples/random_walk_qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from collections.abc import Callable

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.angles import angle
from guppylang.prelude.builtins import py, result
from guppylang.prelude.quantum import cx, discard, h, measure, qubit, rz, x
from guppylang.prelude.builtins import owned, py, result
from guppylang.prelude.quantum import discard, measure, qubit
from guppylang.prelude.quantum_functional import cx, h, rz, x


sqrt_e = math.sqrt(math.e)
Expand All @@ -21,7 +23,7 @@
@guppy
def random_walk_phase_estimation(
eigenstate: Callable[[], qubit],
controlled_oracle: Callable[[qubit, qubit, float], tuple[qubit, qubit]],
controlled_oracle: Callable[[qubit @owned, qubit @owned, float], tuple[qubit, qubit]],
num_iters: int,
reset_rate: int,
mu: float,
Expand Down Expand Up @@ -62,7 +64,7 @@ def random_walk_phase_estimation(


@guppy
def example_controlled_oracle(q1: qubit, q2: qubit, t: float) -> tuple[qubit, qubit]:
def example_controlled_oracle(q1: qubit @owned, q2: qubit @owned, t: float) -> tuple[qubit, qubit]:
"""A controlled e^itH gate for the example Hamiltonian H = -0.5 * Z"""
# This is just a controlled rz gate
a = angle(-0.5 * t)
Expand Down
15 changes: 8 additions & 7 deletions examples/t_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from guppylang.decorator import guppy
from guppylang.prelude.angles import angle, pi
from guppylang.prelude.builtins import linst, py
from guppylang.prelude.builtins import linst, owned, py
from guppylang.prelude.quantum import (
cz,
discard,
h,
measure,
qubit,
)
from guppylang.prelude.quantum_functional import (
cz,
h,
rx,
rz,
)
Expand All @@ -18,24 +20,23 @@


@guppy
def ry(q: qubit, theta: angle) -> qubit:
def ry(q: qubit @owned, theta: angle) -> qubit:
q = rx(q, pi / 2)
q = rz(q, theta + pi)
q = rx(q, pi / 2)
return rz(q, pi)


# Preparation of approximate T state, from https://arxiv.org/abs/2310.12106
@guppy
def prepare_approx(q: qubit) -> qubit:
def prepare_approx(q: qubit @owned) -> qubit:
q = ry(q, angle(py(phi)))
return rz(q, pi / 4)


# The inverse of the [[5,3,1]] encoder in figure 3 of https://arxiv.org/abs/2208.01863
@guppy
def distill(
target: qubit, q0: qubit, q1: qubit, q2: qubit, q3: qubit
target: qubit @owned, q0: qubit @owned, q1: qubit @owned, q2: qubit @owned, q3: qubit @owned
) -> tuple[qubit, bool]:
"""First argument is the target qubit which will be returned from the circuit.
Other arguments are ancillae, which should also be in an approximate T state.
Expand Down
7 changes: 5 additions & 2 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,12 @@ def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> BB | None:
def visit_FunctionDef(
self, node: ast.FunctionDef, bb: BB, jumps: Jumps
) -> BB | None:
from guppylang.checker.func_checker import check_signature, parse_docstring
from guppylang.checker.func_checker import (
check_signature,
parse_function_with_docstring,
)

node, docstring = parse_docstring(node)
node, docstring = parse_function_with_docstring(node)

func_ty = check_signature(node, self.globals)
returns_none = isinstance(func_ty.output, NoneType)
Expand Down
2 changes: 1 addition & 1 deletion guppylang/cfg/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def analyze(
inout_vars: list[str],
) -> dict[BB, VariableStats[str]]:
stats = {bb: bb.compute_variable_stats() for bb in self.bbs}
# Mark all @inout variables as implicitly used in the exit BB
# Mark all borrowed variables as implicitly used in the exit BB
stats[self.exit_bb].used |= {x: InoutReturnSentinel(var=x) for x in inout_vars}
self.live_before = LivenessAnalysis(stats).run(self.bbs)
self.ass_before, self.maybe_ass_before = AssignmentAnalysis(
Expand Down
20 changes: 11 additions & 9 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,9 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:

def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
flags = InputFlags.Owned if ty.linear else InputFlags.NoFlags
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], ExistentialTypeVar.fresh("Iter", False)
[FuncInput(ty, flags)], ExistentialTypeVar.fresh("Iter", False)
)
expr, ty = self.synthesize_instance_func(
node.value, [], "__iter__", "not iterable", exp_sig
Expand All @@ -648,17 +649,17 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:

def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], TupleType([bool_type(), ty])
)
flags = InputFlags.Owned if ty.linear else InputFlags.NoFlags
exp_sig = FunctionType([FuncInput(ty, flags)], TupleType([bool_type(), ty]))
return self.synthesize_instance_func(
node.value, [], "__hasnext__", "not an iterator", exp_sig, True
)

def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
flags = InputFlags.Owned if ty.linear else InputFlags.NoFlags
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)],
[FuncInput(ty, flags)],
TupleType([ExistentialTypeVar.fresh("T", False), ty]),
)
return self.synthesize_instance_func(
Expand All @@ -667,7 +668,8 @@ def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:

def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([FuncInput(ty, InputFlags.NoFlags)], NoneType())
flags = InputFlags.Owned if ty.linear else InputFlags.NoFlags
exp_sig = FunctionType([FuncInput(ty, flags)], NoneType())
return self.synthesize_instance_func(
node.value, [], "__end__", "not an iterator", exp_sig, True
)
Expand Down Expand Up @@ -814,7 +816,7 @@ def type_check_args(


def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
"""Performs additional checks for place arguments in @inout position.
"""Performs additional checks for borrowed place arguments.
In particular, we need to check that places involving `place[item]` subscripts
implement the corresponding `__setitem__` method.
Expand All @@ -830,7 +832,7 @@ def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
[
FuncInput(parent.ty, InputFlags.Inout),
FuncInput(item.ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.Owned),
],
NoneType(),
)
Expand All @@ -843,7 +845,7 @@ def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
setitem_args[0],
setitem_args[1:],
"__setitem__",
"not allowed in a subscripted `@inout` position",
"unable to have subscripted elements borrowed",
exp_sig,
True,
)
Expand Down
6 changes: 4 additions & 2 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
)


def parse_docstring(func_ast: ast.FunctionDef) -> tuple[ast.FunctionDef, str | None]:
def parse_function_with_docstring(
func_ast: ast.FunctionDef,
) -> tuple[ast.FunctionDef, str | None]:
"""Check if the first line of a function is a docstring.
If it is, return the function with the docstring removed, plus the docstring.
Expand All @@ -185,7 +187,7 @@ def parse_docstring(func_ast: ast.FunctionDef) -> tuple[ast.FunctionDef, str | N


def inout_var_names(func_ty: FunctionType) -> list[str]:
"""Returns the names of all `@inout` arguments of a function type."""
"""Returns the names of all borrowed arguments in a function type."""
assert func_ty.input_names is not None
return [
x
Expand Down
37 changes: 19 additions & 18 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ def new_scope(self) -> Generator[Scope, None, None]:
self.scope = scope

def visit_PlaceNode(self, node: PlaceNode, /, is_inout_arg: bool = False) -> None:
# Usage of @inout variables is generally forbidden. The only exception is using
# them in another @inout position of a function call. In that case, our
# Usage of borrowed variables is generally forbidden. The only exception is
# letting them be borrowed by another function call. In that case, our
# `_visit_call_args` helper will set `is_inout_arg=True`.
if is_inout_var(node.place) and not is_inout_arg:
raise GuppyError(
f"{node.place.describe} may only be used in an `@inout` position since "
"it is annotated as `@inout`. Consider removing the annotation to get "
f"{node.place.describe} may not be used in an `@owned` position since "
"it isn't owned. Consider adding a `@owned` annotation to get "
"ownership of the value.",
node,
)
Expand All @@ -150,7 +150,7 @@ def visit_PlaceNode(self, node: PlaceNode, /, is_inout_arg: bool = False) -> Non
if not is_inout_arg and subscript.parent.ty.linear:
raise GuppyError(
"Subscripting on expression with linear type "
f"`{subscript.parent.ty}` is only allowed in `@inout` position",
f"`{subscript.parent.ty}` is not allowed in `@owned` position",
node,
)
self.scope.assign(subscript.item)
Expand Down Expand Up @@ -187,23 +187,24 @@ def _visit_call_args(self, func_ty: FunctionType, args: list[ast.expr]) -> None:
self.visit(arg)

def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> None:
"""Helper function to reassign the @inout arguments after a function call."""
"""Helper function to reassign the borrowed arguments after a function call."""
for inp, arg in zip(func_ty.inputs, args, strict=True):
if InputFlags.Inout in inp.flags:
match arg:
case PlaceNode(place=place):
self._reassign_single_inout_arg(place, arg)
case arg if inp.ty.linear:
raise GuppyError(
f"Inout argument with linear type `{inp.ty}` would be "
f"Borrowed argument with linear type `{inp.ty}` would be "
"dropped after this function call. Consider assigning the "
"expression to a local variable before passing it to the "
"function.",
arg,
)

def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None:
"""Helper function to reassign a single inout argument after a function call."""
"""Helper function to reassign a single borrowed argument after a function
call."""
# Places involving subscripts are given back by visiting the `__setitem__` call
if subscript := contains_subscript(place):
assert subscript.setitem_call is not None
Expand Down Expand Up @@ -311,11 +312,11 @@ def _check_assign_targets(self, targets: list[ast.expr]) -> None:
[target] = targets
for tgt in find_nodes(lambda n: isinstance(n, PlaceNode), target):
assert isinstance(tgt, PlaceNode)
# Special error message for shadowing of @inout vars
# Special error message for shadowing of borrowed vars
x = tgt.place.id
if x in self.scope.vars and is_inout_var(self.scope[x]):
raise GuppyError(
f"Assignment shadows argument `{tgt.place}` annotated as `@inout`. "
f"Assignment shadows borrowed argument `{tgt.place}`. "
"Consider assigning to a different name.",
tgt,
)
Expand Down Expand Up @@ -432,7 +433,7 @@ def contains_subscript(place: Place) -> SubscriptAccess | None:


def is_inout_var(place: Place) -> TypeGuard[Variable]:
"""Checks whether a place is an @inout variable."""
"""Checks whether a place is a borrowed variable."""
return isinstance(place, Variable) and InputFlags.Inout in place.flags


Expand All @@ -452,7 +453,7 @@ def check_cfg_linearity(
for bb in cfg.bbs
}

# Check that @inout vars are not being shadowed. This would also be caught by
# Check that borrowed vars are not being shadowed. This would also be caught by
# the dataflow analysis below, however we can give nicer error messages here.
for bb, scope in scopes.items():
if bb == cfg.entry_bb:
Expand All @@ -466,12 +467,12 @@ def check_cfg_linearity(
entry_place = entry_scope[x]
if is_inout_var(entry_place):
raise GuppyError(
f"Assignment shadows argument `{entry_place}` annotated as "
"`@inout`. Consider assigning to a different name.",
f"Assignment shadows borrowed argument `{entry_place}`. "
"Consider assigning to a different name.",
place.defined_at,
)

# Mark the @inout variables as implicitly used in the exit BB
# Mark the borrowed variables as implicitly used in the exit BB
exit_scope = scopes[cfg.exit_bb]
for var in cfg.entry_bb.sig.input_row:
if InputFlags.Inout in var.flags:
Expand All @@ -498,13 +499,13 @@ def check_cfg_linearity(
if place.ty.linear and (prev_use := scope.used(x)):
use = use_scope.used_parent[x]
# Special case if this is a use arising from the implicit returning
# of an @inout argument
# of a borrowed argument
if isinstance(use, InoutReturnSentinel):
assert isinstance(use.var, Variable)
assert InputFlags.Inout in use.var.flags
raise GuppyError(
f"Argument `{use.var}` annotated as `@inout` cannot be "
f"returned to the caller since `{place}` is used at {{0}}. "
f"Borrowed argument `{use.var}` cannot be returned "
f"to the caller since `{place}` is used at {{0}}. "
f"Consider writing a value back into `{place}` before "
"returning.",
use.var.defined_at,
Expand Down
6 changes: 3 additions & 3 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,18 +243,18 @@ def _update_inout_ports(
inout_ports: Iterator[Wire],
func_ty: FunctionType,
) -> None:
"""Helper method that updates the ports for @inout arguments after a call."""
"""Helper method that updates the ports for borrowed arguments after a call."""
for inp, arg in zip(func_ty.inputs, args, strict=True):
if InputFlags.Inout in inp.flags:
# Linearity checker ensures that inout arguments that are not places
# Linearity checker ensures that borrowed arguments that are not places
# can be safely dropped after the call returns
if not isinstance(arg, PlaceNode):
next(inout_ports)
continue
self.dfg[arg.place] = next(inout_ports)
# Places involving subscripts need to generate code for the appropriate
# `__setitem__` call. Nested subscripts are handled automatically since
# `arg.place.parent` occurs as an inout arg of this call, so will also
# `arg.place.parent` occurs as an arg of this call, so will also
# be recursively reassigned.
if subscript := contains_subscript(arg.place):
self.visit(subscript.setitem_call)
Expand Down
Loading

0 comments on commit e900c96

Please sign in to comment.