Skip to content

Commit

Permalink
feat: Inout arguments (#311)
Browse files Browse the repository at this point in the history
This is the dev branch for the inout argument feature (tracked by #282).

The idea is to allow explicit `@inout` annotations on function arguments
that "give back" the passed value after the function returns:

```python
@guppy
def foo(q: qubit @inout) -> None: ...

@guppy
def bar(q1: qubit @inout, q2: qubit @inout) -> bool: ...

@guppy
def main() -> None:
   q1, q2 = qubit(), qubit()
   foo(q1)          # Desugars to `q1 = foo(q1)`
   x = bar(q1, q2)  # Desugars to `q1, q2, x = bar(q1, q2)`
   y = bar(q1, q1)  # Error: Linearity violation, q1 used twice
```

To enable this, we need to enforce that `@inout` arguments are not moved
in the body of the function (apart from passing them in another `@inout`
position). This means that the argument will always be bound to the same
name and never aliased which allows us to desugar `@inout` functions
like

```python
@guppy
def bar(q1: qubit, q2: qubit) -> bool:
   [body]
   return [expr]
```

into

```python
@guppy
def bar(q1: qubit, q2: qubit) -> tuple[qubit, qubit, bool]:
   [body]
   return q1, q2, [expr]  # Linearity checker needs to ensure that q1, q2 are unused
```

Note that we only allow `@inout` annotations on linear types, since they
would be useless for classical ones (unless we also implement an
ownership system for classical values). Supporting them would make the
checking logic more complicated without providing any meaningful
benefit.

Tracked PRs:
* #315
* #316
* #349
* #344
* #321
* #331
* #350
* #339 
* #340
* #351
  • Loading branch information
mark-koch authored Aug 22, 2024
1 parent c2ab65f commit 060649b
Show file tree
Hide file tree
Showing 69 changed files with 1,457 additions and 171 deletions.
2 changes: 2 additions & 0 deletions .typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
fle = "fle"
ine = "ine"
inot = "inot"
inout = "inout"
inouts = "inouts"
7 changes: 7 additions & 0 deletions guppylang/cfg/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ def compute_variable_stats(self) -> VariableStats[str]:
self._vars = visitor.stats
return visitor.stats

@property
def is_exit(self) -> bool:
"""Whether this is the exit BB."""
# The exit BB is the only one without successors (otherwise we would have gotten
# an unreachable code error during CFG building)
return len(self.successors) == 0


class VariableVisitor(ast.NodeVisitor):
"""Visitor that computes used and assigned variables in a BB."""
Expand Down
8 changes: 7 additions & 1 deletion guppylang/cfg/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Result,
)
from guppylang.cfg.bb import BB, BBStatement, VariableStats
from guppylang.nodes import InoutReturnSentinel

T = TypeVar("T", bound=BB)

Expand Down Expand Up @@ -61,9 +62,14 @@ def link(self, src_bb: BB, tgt_bb: BB) -> None:
tgt_bb.predecessors.append(src_bb)

def analyze(
self, def_ass_before: set[str], maybe_ass_before: set[str]
self,
def_ass_before: set[str],
maybe_ass_before: set[str],
inout_vars: list[str],
) -> dict[BB, VariableStats[str]]:
stats = {bb: bb.compute_variable_stats() for bb in self.bbs}
# Mark all @inout variables as implicitly used in the exit BB
stats[self.exit_bb].used |= {x: InoutReturnSentinel(var=x) for x in inout_vars}
self.live_before = LivenessAnalysis(stats).run(self.bbs)
self.ass_before, self.maybe_ass_before = AssignmentAnalysis(
stats, def_ass_before, maybe_ass_before
Expand Down
9 changes: 5 additions & 4 deletions guppylang/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from guppylang.checker.expr_checker import ExprSynthesizer, to_bool
from guppylang.checker.stmt_checker import StmtChecker
from guppylang.error import GuppyError
from guppylang.tys.ty import Type
from guppylang.tys.ty import InputFlags, Type

Row = Sequence[V]

Expand Down Expand Up @@ -68,7 +68,8 @@ def check_cfg(
"""
# First, we need to run program analysis
ass_before = {v.name for v in inputs}
cfg.analyze(ass_before, ass_before)
inout_vars = [v.name for v in inputs if InputFlags.Inout in v.flags]
cfg.analyze(ass_before, ass_before, inout_vars)

# We start by compiling the entry BB
checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty)
Expand All @@ -89,7 +90,7 @@ def check_cfg(
while len(queue) > 0:
pred, num_output, bb = queue.popleft()
input_row = [
Variable(v.name, v.ty, v.defined_at)
Variable(v.name, v.ty, v.defined_at, v.flags)
for v in pred.sig.output_rows[num_output]
]

Expand Down Expand Up @@ -123,7 +124,7 @@ def check_cfg(
# Finally, run the linearity check
from guppylang.checker.linearity_checker import check_cfg_linearity

linearity_checked_cfg = check_cfg_linearity(checked_cfg)
linearity_checked_cfg = check_cfg_linearity(checked_cfg, globals)

return linearity_checked_cfg

Expand Down
12 changes: 11 additions & 1 deletion guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import copy
import itertools
from collections.abc import Iterable, Iterator, Mapping
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import cached_property
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -36,6 +36,7 @@
BoundTypeVar,
ExistentialTypeVar,
FunctionType,
InputFlags,
NoneType,
NumericType,
OpaqueType,
Expand Down Expand Up @@ -72,6 +73,7 @@ class Variable:
name: str
ty: Type
defined_at: AstNode | None
flags: InputFlags = InputFlags.NoFlags

@dataclass(frozen=True)
class Id:
Expand All @@ -93,6 +95,10 @@ def __str__(self) -> str:
"""String representation of this place."""
return self.name

def replace_defined_at(self, node: AstNode | None) -> "Variable":
"""Returns a new `Variable` instance with an updated definition location."""
return replace(self, defined_at=node)


@dataclass(frozen=True)
class FieldAccess:
Expand Down Expand Up @@ -143,6 +149,10 @@ def __str__(self) -> str:
"""String representation of this place."""
return f"{self.parent}.{self.field.name}"

def replace_defined_at(self, node: AstNode | None) -> "FieldAccess":
"""Returns a new `FieldAccess` instance with an updated definition location."""
return replace(self, exact_defined_at=node)


PyScope = dict[str, Any]

Expand Down
36 changes: 23 additions & 13 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
ExistentialTypeVar,
FuncInput,
FunctionType,
InputFlags,
NoneType,
OpaqueType,
StructType,
Expand Down Expand Up @@ -261,9 +263,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:
assert len(inst) == 0
return with_loc(
node,
TensorCall(
func=node.func, args=processed_args, out_tys=tensor_ty.output
),
TensorCall(func=node.func, args=processed_args, tensor_ty=tensor_ty),
), subst

elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"):
Expand Down Expand Up @@ -496,7 +496,10 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]:
def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType(
[ty, ExistentialTypeVar.fresh("Key", False)],
[
FuncInput(ty, InputFlags.NoFlags),
FuncInput(ExistentialTypeVar.fresh("Key", False), InputFlags.NoFlags),
],
ExistentialTypeVar.fresh("Val", False),
)
return self._synthesize_instance_func(
Expand Down Expand Up @@ -534,7 +537,7 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
assert len(inst) == 0

return with_loc(
node, TensorCall(func=node.func, args=args, out_tys=tensor_ty.output)
node, TensorCall(func=node.func, args=args, tensor_ty=tensor_ty)
), return_ty

elif f := self.ctx.globals.get_instance_func(ty, "__call__"):
Expand All @@ -544,7 +547,9 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:

def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([ty], ExistentialTypeVar.fresh("Iter", False))
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], ExistentialTypeVar.fresh("Iter", False)
)
expr, ty = self._synthesize_instance_func(
node.value, [], "__iter__", "not iterable", exp_sig
)
Expand All @@ -566,23 +571,26 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:

def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([ty], TupleType([bool_type(), ty]))
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], TupleType([bool_type(), ty])
)
return self._synthesize_instance_func(
node.value, [], "__hasnext__", "not an iterator", exp_sig, True
)

def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType(
[ty], TupleType([ExistentialTypeVar.fresh("T", False), ty])
[FuncInput(ty, InputFlags.NoFlags)],
TupleType([ExistentialTypeVar.fresh("T", False), ty]),
)
return self._synthesize_instance_func(
node.value, [], "__next__", "not an iterator", exp_sig, True
)

def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([ty], NoneType())
exp_sig = FunctionType([FuncInput(ty, InputFlags.NoFlags)], NoneType())
return self._synthesize_instance_func(
node.value, [], "__end__", "not an iterator", exp_sig, True
)
Expand Down Expand Up @@ -704,14 +712,16 @@ def type_check_args(
check_num_args(len(func_ty.inputs), len(inputs), node)

new_args: list[ast.expr] = []
for inp, ty in zip(inputs, func_ty.inputs, strict=True):
a, s = ExprChecker(ctx).check(inp, ty.substitute(subst), "argument")
for inp, func_inp in zip(inputs, func_ty.inputs, strict=True):
a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument")
new_args.append(a)
subst |= s

# If the argument check succeeded, this means that we must have found instantiations
# for all unification variables occurring in the input types
assert all(set.issubset(inp.unsolved_vars, subst.keys()) for inp in func_ty.inputs)
assert all(
set.issubset(inp.ty.unsolved_vars, subst.keys()) for inp in func_ty.inputs
)

# We also have to check that we found instantiations for all vars in the return type
if not set.issubset(func_ty.output.unsolved_vars, subst.keys()):
Expand Down Expand Up @@ -991,7 +1001,7 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type
[], globals
)
return FunctionType(
[qubit] * v.n_qubits,
[FuncInput(qubit, InputFlags.NoFlags)] * v.n_qubits,
row_to_type(
[qubit] * v.n_qubits + [bool_type()] * v.n_bits
),
Expand Down
42 changes: 27 additions & 15 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from guppylang.definition.common import DefId
from guppylang.error import GuppyError
from guppylang.nodes import CheckedNestedFunctionDef, NestedFunctionDef
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.ty import FunctionType, NoneType
from guppylang.tys.parsing import parse_function_io_types
from guppylang.tys.ty import FunctionType, InputFlags, NoneType

if TYPE_CHECKING:
from guppylang.tys.param import Parameter
Expand All @@ -33,8 +33,8 @@ def check_global_func_def(

cfg = CFGBuilder().build(func_def.body, returns_none, globals)
inputs = [
Variable(x, ty, loc)
for x, ty, loc in zip(ty.input_names, ty.inputs, args, strict=True)
Variable(x, inp.ty, loc, inp.flags)
for x, inp, loc in zip(ty.input_names, ty.inputs, args, strict=True)
]
return check_cfg(cfg, inputs, ty.output, globals)

Expand All @@ -54,7 +54,8 @@ def check_nested_func_def(
parent_cfg = bb.containing_cfg
def_ass_before = set(func_ty.input_names) | ctx.locals.keys()
maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb]
cfg.analyze(def_ass_before, maybe_ass_before)
inout_vars = inout_var_names(func_ty)
cfg.analyze(def_ass_before, maybe_ass_before, inout_vars)
captured = {
x: (ctx.locals[x], using_bb.vars.used[x])
for x, using_bb in cfg.live_before[cfg.entry_bb].items()
Expand All @@ -75,8 +76,8 @@ def check_nested_func_def(

# Construct inputs for checking the body CFG
inputs = [v for v, _ in captured.values()] + [
Variable(x, ty, func_def.args.args[i])
for i, (x, ty) in enumerate(
Variable(x, inp.ty, func_def.args.args[i], inp.flags)
for i, (x, inp) in enumerate(
zip(func_ty.input_names, func_ty.inputs, strict=True)
)
]
Expand Down Expand Up @@ -143,19 +144,20 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType

# TODO: Prepopulate mapping when using Python 3.12 style generic functions
param_var_mapping: dict[str, Parameter] = {}
input_tys = []
input_nodes = []
input_names = []
for inp in func_def.args.args:
if inp.annotation is None:
ty_ast = inp.annotation
if ty_ast is None:
raise GuppyError("Argument type must be annotated", inp)
ty = type_from_ast(inp.annotation, globals, param_var_mapping)
input_tys.append(ty)
input_nodes.append(ty_ast)
input_names.append(inp.arg)
ret_type = type_from_ast(func_def.returns, globals, param_var_mapping)

inputs, output = parse_function_io_types(
input_nodes, func_def.returns, func_def, globals, param_var_mapping
)
return FunctionType(
input_tys,
ret_type,
inputs,
output,
input_names,
sorted(param_var_mapping.values(), key=lambda v: v.idx),
)
Expand All @@ -180,3 +182,13 @@ def parse_docstring(func_ast: ast.FunctionDef) -> tuple[ast.FunctionDef, str | N
case _:
pass
return func_ast, docstring


def inout_var_names(func_ty: FunctionType) -> list[str]:
"""Returns the names of all `@inout` arguments of a function type."""
assert func_ty.input_names is not None
return [
x
for inp, x in zip(func_ty.inputs, func_ty.input_names, strict=True)
if InputFlags.Inout in inp.flags
]
Loading

0 comments on commit 060649b

Please sign in to comment.