From dd758f374ce7384c2e1e34465e5ad8e4cf5bd75a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 24 Jul 2024 12:18:40 +0100 Subject: [PATCH] feat: Parse inout annotations in function signatures --- guppylang/checker/func_checker.py | 24 +++++++++++++++++++++--- guppylang/prelude/builtins.py | 10 ++++++++++ tests/error/inout_errors/__init__.py | 0 tests/error/inout_errors/nonlinear.err | 6 ++++++ tests/error/inout_errors/nonlinear.py | 14 ++++++++++++++ tests/integration/test_inout.py | 16 ++++++++++++++++ 6 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 tests/error/inout_errors/__init__.py create mode 100644 tests/error/inout_errors/nonlinear.err create mode 100644 tests/error/inout_errors/nonlinear.py create mode 100644 tests/integration/test_inout.py diff --git a/guppylang/checker/func_checker.py b/guppylang/checker/func_checker.py index 91c240b3..af7e68af 100644 --- a/guppylang/checker/func_checker.py +++ b/guppylang/checker/func_checker.py @@ -146,10 +146,28 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType inputs = [] input_names = [] for inp in func_def.args.args: - if inp.annotation is None: + ty_ast = inp.annotation + if ty_ast is None: raise GuppyError("Argument type must be annotated", inp) - ty = type_from_ast(inp.annotation, globals, param_var_mapping) - inputs.append((ty, InputFlags.NoFlags)) + flags = InputFlags.NoFlags + # Detect `@flag` argument annotations + # TODO: This doesn't work if the type annotation is a string forward ref. We + # should rethink how we handle these... + if isinstance(ty_ast, ast.BinOp) and isinstance(ty_ast.op, ast.MatMult): + ty = type_from_ast(ty_ast.left, globals, param_var_mapping) + match ty_ast.right: + case ast.Name(id="inout"): + if not ty.linear: + raise GuppyError( + f"Non-linear type `{ty}` cannot be annotated as `@inout`", + ty_ast.right, + ) + flags |= InputFlags.Inout + case _: + raise GuppyError("Invalid annotation", ty_ast.right) + else: + ty = type_from_ast(ty_ast, globals, param_var_mapping) + inputs.append((ty, flags)) input_names.append(inp.arg) ret_type = type_from_ast(func_def.returns, globals, param_var_mapping) diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 6e260406..b5c7c7cd 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -56,6 +56,16 @@ def py(*_args: Any) -> Any: raise GuppyError("`py` can only by used in a Guppy context") +class _Inout: + """Dummy class to support `@inout` annotations.""" + + def __rmatmul__(self, other: Any) -> Any: + return other + + +inout = _Inout() + + class nat: """Class to import in order to use nats.""" diff --git a/tests/error/inout_errors/__init__.py b/tests/error/inout_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/inout_errors/nonlinear.err b/tests/error/inout_errors/nonlinear.err new file mode 100644 index 00000000..0e6c0e26 --- /dev/null +++ b/tests/error/inout_errors/nonlinear.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:11 + +9: @guppy.declare(module) +10: def foo(x: int @inout) -> qubit: ... + ^^^^^ +GuppyError: Non-linear type `int` cannot be annotated as `@inout` diff --git a/tests/error/inout_errors/nonlinear.py b/tests/error/inout_errors/nonlinear.py new file mode 100644 index 00000000..7739e2a0 --- /dev/null +++ b/tests/error/inout_errors/nonlinear.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import inout +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") + + +@guppy.declare(module) +def foo(x: int @inout) -> qubit: ... + + +module.compile() diff --git a/tests/integration/test_inout.py b/tests/integration/test_inout.py new file mode 100644 index 00000000..dfe7a512 --- /dev/null +++ b/tests/integration/test_inout.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import inout +from guppylang.prelude.quantum import qubit + +import guppylang.prelude.quantum as quantum + + +def test_declare(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def test(q: qubit @inout) -> qubit: ... + + validate(module.compile())