Skip to content

Commit

Permalink
feat: Parse inout annotations in function signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Jul 24, 2024
1 parent 2bd7fdf commit dd758f3
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 3 deletions.
24 changes: 21 additions & 3 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,28 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
inputs = []
input_names = []
for inp in func_def.args.args:
if inp.annotation is None:
ty_ast = inp.annotation
if ty_ast is None:
raise GuppyError("Argument type must be annotated", inp)
ty = type_from_ast(inp.annotation, globals, param_var_mapping)
inputs.append((ty, InputFlags.NoFlags))
flags = InputFlags.NoFlags
# Detect `@flag` argument annotations
# TODO: This doesn't work if the type annotation is a string forward ref. We
# should rethink how we handle these...
if isinstance(ty_ast, ast.BinOp) and isinstance(ty_ast.op, ast.MatMult):
ty = type_from_ast(ty_ast.left, globals, param_var_mapping)
match ty_ast.right:
case ast.Name(id="inout"):
if not ty.linear:
raise GuppyError(
f"Non-linear type `{ty}` cannot be annotated as `@inout`",
ty_ast.right,
)
flags |= InputFlags.Inout
case _:
raise GuppyError("Invalid annotation", ty_ast.right)
else:
ty = type_from_ast(ty_ast, globals, param_var_mapping)
inputs.append((ty, flags))
input_names.append(inp.arg)
ret_type = type_from_ast(func_def.returns, globals, param_var_mapping)

Expand Down
10 changes: 10 additions & 0 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ def py(*_args: Any) -> Any:
raise GuppyError("`py` can only by used in a Guppy context")


class _Inout:
"""Dummy class to support `@inout` annotations."""

def __rmatmul__(self, other: Any) -> Any:
return other


inout = _Inout()


class nat:
"""Class to import in order to use nats."""

Expand Down
Empty file.
6 changes: 6 additions & 0 deletions tests/error/inout_errors/nonlinear.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:11

9: @guppy.declare(module)
10: def foo(x: int @inout) -> qubit: ...
^^^^^
GuppyError: Non-linear type `int` cannot be annotated as `@inout`
14 changes: 14 additions & 0 deletions tests/error/inout_errors/nonlinear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import inout
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")


@guppy.declare(module)
def foo(x: int @inout) -> qubit: ...


module.compile()
16 changes: 16 additions & 0 deletions tests/integration/test_inout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import inout
from guppylang.prelude.quantum import qubit

import guppylang.prelude.quantum as quantum


def test_declare(validate):
module = GuppyModule("test")
module.load(quantum)

@guppy.declare(module)
def test(q: qubit @inout) -> qubit: ...

validate(module.compile())

0 comments on commit dd758f3

Please sign in to comment.