Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Make linear types @inout by default; add @owned annotation #486

Merged
merged 28 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
40e9701
Drive-by: change name of parse_docstring
croyzor Sep 10, 2024
ec3072e
Drive-by: Say which identifiers are unknown
croyzor Sep 10, 2024
091802e
checkpoint: integration tests
croyzor Sep 11, 2024
1edb8f0
Update comprehension error tests
croyzor Sep 11, 2024
24e7d04
Update inout error files
croyzor Sep 11, 2024
d5db2e7
Update linear error files
croyzor Sep 11, 2024
e64155f
Update misc error tests
croyzor Sep 12, 2024
8a0976d
Only print the missing module for attributes
croyzor Sep 12, 2024
b982bc4
Update more error files
croyzor Sep 12, 2024
ea04238
Merge remote-tracking branch 'origin/main' into feat/owned
croyzor Sep 12, 2024
217d471
Fix printing of flags
croyzor Sep 12, 2024
718c2c7
Review comments
croyzor Sep 12, 2024
f07175a
Add new test case
croyzor Sep 12, 2024
6fa60fd
Update error files
croyzor Sep 12, 2024
8f9b7de
Review comment: update used_twice
croyzor Sep 12, 2024
8c911ee
revert pyproject change
croyzor Sep 12, 2024
45be26d
Shut up mypy
croyzor Sep 12, 2024
11d9642
Update references to @inout
croyzor Sep 12, 2024
fe69a36
Merge branch 'main' into feat/owned
croyzor Sep 12, 2024
df989bb
Take ownership of linear args to struct constructors
croyzor Sep 13, 2024
fc75c81
Add tests for linear struct initialisation
croyzor Sep 13, 2024
b8fe162
Make check an assert
croyzor Sep 13, 2024
b13e547
chore: Unskip now-passing tests (#475)
aborgna-q Sep 12, 2024
f32cd61
Merge remote-tracking branch 'origin/main' into feat/owned
croyzor Sep 13, 2024
a34d824
Update guppylang/tys/ty.py
croyzor Sep 13, 2024
46d32a1
Update guppylang/checker/linearity_checker.py
croyzor Sep 13, 2024
7c7e2d0
Update guppylang/tys/parsing.py
croyzor Sep 13, 2024
fdf10f2
Update tests; mypy annotations
croyzor Sep 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions examples/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@
"metadata": {},
"outputs": [],
"source": [
"from guppylang.prelude.quantum import qubit, h, cx, measure\n",
"from guppylang.prelude.builtins import owned\n",
"from guppylang.prelude.quantum import qubit, measure\n",
"from guppylang.prelude.quantum_functional import h, cx\n",
"\n",
"module.load(qubit, h, cx, measure)\n",
"\n",
Expand Down Expand Up @@ -296,7 +298,7 @@
"Guppy compilation failed. Error in file <In [9]>:6\n",
"\n",
"4: @guppy(bad_module)\n",
"5: def bad(q: qubit) -> tuple[qubit, qubit]:\n",
"5: def bad(q: qubit @owned) -> tuple[qubit, qubit]:\n",
"6: return cx(q, q)\n",
" ^\n",
"GuppyError: Variable `q` with linear type `qubit` was already used (at 6:14)\n"
Expand All @@ -308,7 +310,7 @@
"bad_module.load_all(guppylang.prelude.quantum)\n",
"\n",
"@guppy(bad_module)\n",
"def bad(q: qubit) -> tuple[qubit, qubit]:\n",
"def bad(q: qubit @owned) -> tuple[qubit, qubit]:\n",
" return cx(q, q)\n",
"\n",
"bad_module.compile() # Raises an error"
Expand Down Expand Up @@ -338,7 +340,7 @@
"text": [
"Guppy compilation failed. Error in file <In [10]>:7\n",
"\n",
"5: def bad(q: qubit) -> qubit:\n",
"5: def bad(q: qubit @owned) -> qubit:\n",
"6: tmp = h(qubit())\n",
"7: tmp, q = cx(tmp, q)\n",
" ^^^\n",
Expand All @@ -351,7 +353,7 @@
"bad_module.load_all(guppylang.prelude.quantum)\n",
"\n",
"@guppy(bad_module)\n",
"def bad(q: qubit) -> qubit:\n",
"def bad(q: qubit @owned) -> qubit:\n",
" tmp = h(qubit())\n",
" tmp, q = cx(tmp, q)\n",
" #discard(tmp) # Compiler complains if tmp is not explicitly discarded\n",
Expand Down Expand Up @@ -420,7 +422,7 @@
" q2: qubit\n",
"\n",
" @guppy(module)\n",
" def method(self: \"QubitPair\") -> \"QubitPair\":\n",
" def method(self: \"QubitPair @owned\") -> \"QubitPair\":\n",
" self.q1 = h(self.q1)\n",
" self.q1, self.q2 = cx(self.q1, self.q2)\n",
" return self\n",
Expand Down
10 changes: 6 additions & 4 deletions examples/random_walk_qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from collections.abc import Callable

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.angles import angle
from guppylang.prelude.builtins import py, result
from guppylang.prelude.quantum import cx, discard, h, measure, qubit, rz, x
from guppylang.prelude.builtins import owned, py, result
from guppylang.prelude.quantum import discard, measure, qubit
from guppylang.prelude.quantum_functional import cx, h, rz, x


sqrt_e = math.sqrt(math.e)
Expand All @@ -21,7 +23,7 @@
@guppy
def random_walk_phase_estimation(
eigenstate: Callable[[], qubit],
controlled_oracle: Callable[[qubit, qubit, float], tuple[qubit, qubit]],
controlled_oracle: Callable[[qubit @owned, qubit @owned, float], tuple[qubit, qubit]],
num_iters: int,
reset_rate: int,
mu: float,
Expand Down Expand Up @@ -62,7 +64,7 @@ def random_walk_phase_estimation(


@guppy
def example_controlled_oracle(q1: qubit, q2: qubit, t: float) -> tuple[qubit, qubit]:
def example_controlled_oracle(q1: qubit @owned, q2: qubit @owned, t: float) -> tuple[qubit, qubit]:
"""A controlled e^itH gate for the example Hamiltonian H = -0.5 * Z"""
# This is just a controlled rz gate
a = angle(-0.5 * t)
Expand Down
15 changes: 8 additions & 7 deletions examples/t_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from guppylang.decorator import guppy
from guppylang.prelude.angles import angle, pi
from guppylang.prelude.builtins import linst, py
from guppylang.prelude.builtins import linst, owned, py
from guppylang.prelude.quantum import (
cz,
discard,
h,
measure,
qubit,
)
from guppylang.prelude.quantum_functional import (
cz,
h,
rx,
rz,
)
Expand All @@ -18,24 +20,23 @@


@guppy
def ry(q: qubit, theta: angle) -> qubit:
def ry(q: qubit @owned, theta: angle) -> qubit:
q = rx(q, pi / 2)
q = rz(q, theta + pi)
q = rx(q, pi / 2)
return rz(q, pi)


# Preparation of approximate T state, from https://arxiv.org/abs/2310.12106
@guppy
def prepare_approx(q: qubit) -> qubit:
def prepare_approx(q: qubit @owned) -> qubit:
q = ry(q, angle(py(phi)))
return rz(q, pi / 4)


# The inverse of the [[5,3,1]] encoder in figure 3 of https://arxiv.org/abs/2208.01863
@guppy
def distill(
target: qubit, q0: qubit, q1: qubit, q2: qubit, q3: qubit
target: qubit @owned, q0: qubit @owned, q1: qubit @owned, q2: qubit @owned, q3: qubit @owned
) -> tuple[qubit, bool]:
"""First argument is the target qubit which will be returned from the circuit.
Other arguments are ancillae, which should also be in an approximate T state.
Expand Down
7 changes: 5 additions & 2 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,12 @@ def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> BB | None:
def visit_FunctionDef(
self, node: ast.FunctionDef, bb: BB, jumps: Jumps
) -> BB | None:
from guppylang.checker.func_checker import check_signature, parse_docstring
from guppylang.checker.func_checker import (
check_signature,
parse_function_with_docstring,
)

node, docstring = parse_docstring(node)
node, docstring = parse_function_with_docstring(node)

func_ty = check_signature(node, self.globals)
returns_none = isinstance(func_ty.output, NoneType)
Expand Down
2 changes: 1 addition & 1 deletion guppylang/cfg/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def analyze(
inout_vars: list[str],
) -> dict[BB, VariableStats[str]]:
stats = {bb: bb.compute_variable_stats() for bb in self.bbs}
# Mark all @inout variables as implicitly used in the exit BB
# Mark all borrowed variables as implicitly used in the exit BB
stats[self.exit_bb].used |= {x: InoutReturnSentinel(var=x) for x in inout_vars}
self.live_before = LivenessAnalysis(stats).run(self.bbs)
self.ass_before, self.maybe_ass_before = AssignmentAnalysis(
Expand Down
20 changes: 11 additions & 9 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,9 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:

def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
flags = InputFlags.Owned if ty.linear else InputFlags.NoFlags
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], ExistentialTypeVar.fresh("Iter", False)
[FuncInput(ty, flags)], ExistentialTypeVar.fresh("Iter", False)
)
expr, ty = self.synthesize_instance_func(
node.value, [], "__iter__", "not iterable", exp_sig
Expand All @@ -648,17 +649,17 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:

def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], TupleType([bool_type(), ty])
)
flags = InputFlags.Owned if ty.linear else InputFlags.NoFlags
exp_sig = FunctionType([FuncInput(ty, flags)], TupleType([bool_type(), ty]))
return self.synthesize_instance_func(
node.value, [], "__hasnext__", "not an iterator", exp_sig, True
)

def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
flags = InputFlags.Owned if ty.linear else InputFlags.NoFlags
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)],
[FuncInput(ty, flags)],
TupleType([ExistentialTypeVar.fresh("T", False), ty]),
)
return self.synthesize_instance_func(
Expand All @@ -667,7 +668,8 @@ def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:

def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([FuncInput(ty, InputFlags.NoFlags)], NoneType())
flags = InputFlags.Owned if ty.linear else InputFlags.NoFlags
exp_sig = FunctionType([FuncInput(ty, flags)], NoneType())
return self.synthesize_instance_func(
node.value, [], "__end__", "not an iterator", exp_sig, True
)
Expand Down Expand Up @@ -814,7 +816,7 @@ def type_check_args(


def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
"""Performs additional checks for place arguments in @inout position.
"""Performs additional checks for borrowed place arguments.

In particular, we need to check that places involving `place[item]` subscripts
implement the corresponding `__setitem__` method.
Expand All @@ -830,7 +832,7 @@ def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
[
FuncInput(parent.ty, InputFlags.Inout),
FuncInput(item.ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.Owned),
],
NoneType(),
)
Expand All @@ -843,7 +845,7 @@ def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
setitem_args[0],
setitem_args[1:],
"__setitem__",
"not allowed in a subscripted `@inout` position",
"unable to have subscripted elements borrowed",
exp_sig,
True,
)
Expand Down
6 changes: 4 additions & 2 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
)


def parse_docstring(func_ast: ast.FunctionDef) -> tuple[ast.FunctionDef, str | None]:
def parse_function_with_docstring(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by: I updated the name of this function because I found it too confusing

func_ast: ast.FunctionDef,
) -> tuple[ast.FunctionDef, str | None]:
"""Check if the first line of a function is a docstring.

If it is, return the function with the docstring removed, plus the docstring.
Expand All @@ -185,7 +187,7 @@ def parse_docstring(func_ast: ast.FunctionDef) -> tuple[ast.FunctionDef, str | N


def inout_var_names(func_ty: FunctionType) -> list[str]:
"""Returns the names of all `@inout` arguments of a function type."""
"""Returns the names of all borrowed arguments in a function type."""
assert func_ty.input_names is not None
return [
x
Expand Down
37 changes: 19 additions & 18 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ def new_scope(self) -> Generator[Scope, None, None]:
self.scope = scope

def visit_PlaceNode(self, node: PlaceNode, /, is_inout_arg: bool = False) -> None:
# Usage of @inout variables is generally forbidden. The only exception is using
# them in another @inout position of a function call. In that case, our
# Usage of borrowed variables is generally forbidden. The only exception is
# letting them be borrowed by another function call. In that case, our
# `_visit_call_args` helper will set `is_inout_arg=True`.
if is_inout_var(node.place) and not is_inout_arg:
raise GuppyError(
f"{node.place.describe} may only be used in an `@inout` position since "
"it is annotated as `@inout`. Consider removing the annotation to get "
f"{node.place.describe} may not be used in an `@owned` position since "
"it lacks an `@owned` annotation. Consider adding `@owned` to get "
croyzor marked this conversation as resolved.
Show resolved Hide resolved
"ownership of the value.",
node,
)
Expand All @@ -150,7 +150,7 @@ def visit_PlaceNode(self, node: PlaceNode, /, is_inout_arg: bool = False) -> Non
if not is_inout_arg and subscript.parent.ty.linear:
raise GuppyError(
"Subscripting on expression with linear type "
f"`{subscript.parent.ty}` is only allowed in `@inout` position",
f"`{subscript.parent.ty}` is not allowed in `@owned` position",
node,
)
self.scope.assign(subscript.item)
Expand Down Expand Up @@ -187,23 +187,24 @@ def _visit_call_args(self, func_ty: FunctionType, args: list[ast.expr]) -> None:
self.visit(arg)

def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> None:
"""Helper function to reassign the @inout arguments after a function call."""
"""Helper function to reassign the borrowed arguments after a function call."""
for inp, arg in zip(func_ty.inputs, args, strict=True):
if InputFlags.Inout in inp.flags:
match arg:
case PlaceNode(place=place):
self._reassign_single_inout_arg(place, arg)
case arg if inp.ty.linear:
raise GuppyError(
f"Inout argument with linear type `{inp.ty}` would be "
f"Borrowed argument with linear type `{inp.ty}` would be "
"dropped after this function call. Consider assigning the "
"expression to a local variable before passing it to the "
"function.",
arg,
)

def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None:
"""Helper function to reassign a single inout argument after a function call."""
"""Helper function to reassign a single borrowed argument after a function
call."""
# Places involving subscripts are given back by visiting the `__setitem__` call
if subscript := contains_subscript(place):
assert subscript.setitem_call is not None
Expand Down Expand Up @@ -311,11 +312,11 @@ def _check_assign_targets(self, targets: list[ast.expr]) -> None:
[target] = targets
for tgt in find_nodes(lambda n: isinstance(n, PlaceNode), target):
assert isinstance(tgt, PlaceNode)
# Special error message for shadowing of @inout vars
# Special error message for shadowing of borrowed vars
x = tgt.place.id
if x in self.scope.vars and is_inout_var(self.scope[x]):
raise GuppyError(
f"Assignment shadows argument `{tgt.place}` annotated as `@inout`. "
f"Assignment shadows borrowed argument `{tgt.place}`. "
"Consider assigning to a different name.",
tgt,
)
Expand Down Expand Up @@ -432,7 +433,7 @@ def contains_subscript(place: Place) -> SubscriptAccess | None:


def is_inout_var(place: Place) -> TypeGuard[Variable]:
"""Checks whether a place is an @inout variable."""
"""Checks whether a place is a borrowed variable."""
return isinstance(place, Variable) and InputFlags.Inout in place.flags


Expand All @@ -452,7 +453,7 @@ def check_cfg_linearity(
for bb in cfg.bbs
}

# Check that @inout vars are not being shadowed. This would also be caught by
# Check that borrowed vars are not being shadowed. This would also be caught by
# the dataflow analysis below, however we can give nicer error messages here.
for bb, scope in scopes.items():
if bb == cfg.entry_bb:
Expand All @@ -466,12 +467,12 @@ def check_cfg_linearity(
entry_place = entry_scope[x]
if is_inout_var(entry_place):
raise GuppyError(
f"Assignment shadows argument `{entry_place}` annotated as "
"`@inout`. Consider assigning to a different name.",
f"Assignment shadows borrowed argument `{entry_place}`. "
"Consider assigning to a different name.",
place.defined_at,
)

# Mark the @inout variables as implicitly used in the exit BB
# Mark the borrowed variables as implicitly used in the exit BB
exit_scope = scopes[cfg.exit_bb]
for var in cfg.entry_bb.sig.input_row:
if InputFlags.Inout in var.flags:
Expand All @@ -498,13 +499,13 @@ def check_cfg_linearity(
if place.ty.linear and (prev_use := scope.used(x)):
use = use_scope.used_parent[x]
# Special case if this is a use arising from the implicit returning
# of an @inout argument
# of a borrowed argument
if isinstance(use, InoutReturnSentinel):
assert isinstance(use.var, Variable)
assert InputFlags.Inout in use.var.flags
raise GuppyError(
f"Argument `{use.var}` annotated as `@inout` cannot be "
f"returned to the caller since `{place}` is used at {{0}}. "
f"Borrowed argument `{use.var}` cannot be returned "
f"to the caller since `{place}` is used at {{0}}. "
f"Consider writing a value back into `{place}` before "
"returning.",
use.var.defined_at,
Expand Down
6 changes: 3 additions & 3 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,18 +243,18 @@ def _update_inout_ports(
inout_ports: Iterator[Wire],
func_ty: FunctionType,
) -> None:
"""Helper method that updates the ports for @inout arguments after a call."""
"""Helper method that updates the ports for borrowed arguments after a call."""
for inp, arg in zip(func_ty.inputs, args, strict=True):
if InputFlags.Inout in inp.flags:
# Linearity checker ensures that inout arguments that are not places
# Linearity checker ensures that borrowed arguments that are not places
# can be safely dropped after the call returns
if not isinstance(arg, PlaceNode):
next(inout_ports)
continue
self.dfg[arg.place] = next(inout_ports)
# Places involving subscripts need to generate code for the appropriate
# `__setitem__` call. Nested subscripts are handled automatically since
# `arg.place.parent` occurs as an inout arg of this call, so will also
# `arg.place.parent` occurs as an arg of this call, so will also
# be recursively reassigned.
if subscript := contains_subscript(arg.place):
self.visit(subscript.setitem_call)
Expand Down
Loading
Loading