Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Parse inout annotations in function signatures #316

Merged
merged 8 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from guppylang.definition.common import DefId
from guppylang.error import GuppyError
from guppylang.nodes import CheckedNestedFunctionDef, NestedFunctionDef
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, NoneType
from guppylang.tys.parsing import parse_function_io_types
from guppylang.tys.ty import FunctionType, NoneType

if TYPE_CHECKING:
from guppylang.tys.param import Parameter
Expand Down Expand Up @@ -143,19 +143,19 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType

# TODO: Prepopulate mapping when using Python 3.12 style generic functions
param_var_mapping: dict[str, Parameter] = {}
inputs = []
input_nodes = []
input_names = []
for inp in func_def.args.args:
if inp.annotation is None:
raise GuppyError("Argument type must be annotated", inp)
ty = type_from_ast(inp.annotation, globals, param_var_mapping)
inputs.append(FuncInput(ty, InputFlags.NoFlags))
input_nodes.append(inp.annotation)
input_names.append(inp.arg)
ret_type = type_from_ast(func_def.returns, globals, param_var_mapping)

inputs, output = parse_function_io_types(
input_nodes, func_def.returns, func_def, globals, param_var_mapping
)
return FunctionType(
inputs,
ret_type,
output,
input_names,
sorted(param_var_mapping.values(), key=lambda v: v.idx),
)
Expand Down
10 changes: 10 additions & 0 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def py(*_args: Any) -> Any:
raise GuppyError("`py` can only by used in a Guppy context")


class _Inout:
"""Dummy class to support `@inout` annotations."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to allow def foo(x: int @inout), since Python (in its infinite wisdom cough splutter boo) disallows def foo(x: @inout int). Obviously, it's fine to do matrix multiplication inside a type annotation but not invoke a decorator....

However, have you thought about the "pythonic" alternative, which I believe would be def foo(x: Annotated[int, "inout"]?@inout is a bastardization of python syntax anyway so what would users think? I admit Annotated is rather long - what about def foo(x: Inout[int]) ? (No way, I hear you say - and I see the downside - maybe that lengthy Annotated isn't sooo bad, then....)

Not seriously recommending you change, but how different / how much simpler might that make things, would it impact the design at all?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created #359 for discussion.

Not seriously recommending you change, but how different / how much simpler might that make things, would it impact the design at all?

I think this is mostly a syntax concern. The design wouldn't be affected much imo


def __rmatmul__(self, other: Any) -> Any:
return other


inout = _Inout()


class nat:
"""Class to import in order to use nats."""

Expand Down
25 changes: 5 additions & 20 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from itertools import repeat
from typing import TYPE_CHECKING, Literal

from hugr.serialization import tys

from guppylang.ast_util import AstNode
from guppylang.definition.common import DefId
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.ty import (
FuncInput,
FunctionType,
InputFlags,
NoneType,
NumericType,
OpaqueType,
Expand All @@ -27,7 +24,7 @@


@dataclass(frozen=True)
class _CallableTypeDef(TypeDef):
class CallableTypeDef(TypeDef):
"""Type definition associated with the builtin `Callable` type.
Any impls on functions can be registered with this definition.
Expand All @@ -38,20 +35,8 @@ class _CallableTypeDef(TypeDef):
def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> FunctionType:
# We get the inputs/output as a flattened list: `args = [*inputs, output]`.
if not args:
raise GuppyError(f"Missing parameter for type `{self.name}`", loc)
args = [
# TODO: Better error location
TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty
for i, arg in enumerate(args)
]
*input_tys, output = args
inputs = [
FuncInput(ty, flags)
for ty, flags in zip(input_tys, repeat(InputFlags.NoFlags), strict=False)
]
return FunctionType(list(inputs), output)
# Callable types are constructed using special login in the type parser
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Callable types are constructed using special login in the type parser
# Callable types are constructed using special logic in the type parser

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could consider moving some of that code into CallableTypeDef itself, just so you could have a comment "Callable types are constructed using instantiate_func_type()" or something like that, but this is fine

raise InternalGuppyError("Tried to `Callable` type via `check_instantiate`")


@dataclass(frozen=True)
Expand Down Expand Up @@ -157,7 +142,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> tys.Type:
return tys.Type(ty)


callable_type_def = _CallableTypeDef(DefId.fresh(), None)
callable_type_def = CallableTypeDef(DefId.fresh(), None)
tuple_type_def = _TupleTypeDef(DefId.fresh(), None)
none_type_def = _NoneTypeDef(DefId.fresh(), None)
bool_type_def = OpaqueTypeDef(
Expand Down
156 changes: 125 additions & 31 deletions guppylang/tys/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.builtin import CallableTypeDef
from guppylang.tys.const import ConstValue
from guppylang.tys.param import Parameter, TypeParam
from guppylang.tys.ty import NoneType, NumericType, TupleType, Type
from guppylang.tys.ty import (
FuncInput,
FunctionType,
InputFlags,
NoneType,
NumericType,
TupleType,
Type,
)


def arg_from_ast(
Expand All @@ -28,6 +37,11 @@ def arg_from_ast(
if x not in globals:
raise GuppyError("Unknown identifier", node)
match globals[x]:
# Special case for the `Callable` type
case CallableTypeDef():
return TypeArg(
_parse_callable_type([], node, globals, param_var_mapping)
)
# Either a defined type (e.g. `int`, `bool`, ...)
case TypeDef() as defn:
return TypeArg(defn.check_instantiate([], globals, node))
Expand All @@ -50,21 +64,16 @@ def arg_from_ast(
x = node.value.id
if x in globals:
defn = globals[x]
if isinstance(defn, TypeDef):
arg_nodes = (
node.slice.elts
if isinstance(node.slice, ast.Tuple)
else [node.slice]
arg_nodes = (
node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
)
if isinstance(defn, CallableTypeDef):
# Special case for the `Callable[[S1, S2, ...], T]` type to support the
# input list syntax and @inout annotations.
return TypeArg(
_parse_callable_type(arg_nodes, node, globals, param_var_mapping)
)
# Hack: Flatten argument lists to support the `Callable` type. For
# example, we turn `Callable[[int, int], bool]` into
# `Callable[int, int, bool]`.
# TODO: We can get rid of this once we added support for variadic params
arg_nodes = [
n
for arg in arg_nodes
for n in (arg.elts if isinstance(arg, ast.List) else (arg,))
]
if isinstance(defn, TypeDef):
args = [
arg_from_ast(arg_node, globals, param_var_mapping)
for arg_node in arg_nodes
Expand Down Expand Up @@ -102,35 +111,120 @@ def arg_from_ast(

# Finally, we also support delayed annotations in strings
if isinstance(node, ast.Constant) and isinstance(node.value, str):
try:
[stmt] = ast.parse(node.value).body
if not isinstance(stmt, ast.Expr):
raise GuppyError("Invalid Guppy type", node)
set_location_from(stmt, loc=node)
shift_loc(
stmt,
delta_lineno=node.lineno - 1, # -1 since lines start at 1
delta_col_offset=node.col_offset + 1, # +1 to remove the `"`
)
return arg_from_ast(stmt.value, globals, param_var_mapping)
except (SyntaxError, ValueError):
raise GuppyError("Invalid Guppy type", node) from None
node = _parse_delayed_annotation(node.value, node)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, breaking that out is a good move :)

return arg_from_ast(node, globals, param_var_mapping)

raise GuppyError("Not a valid type argument", node)


def _parse_delayed_annotation(ast_str: str, node: ast.Constant) -> ast.expr:
"""Parses a delayed type annotation in a string."""
try:
[stmt] = ast.parse(ast_str).body
if not isinstance(stmt, ast.Expr):
raise GuppyError("Invalid Guppy type", node)
set_location_from(stmt, loc=node)
shift_loc(
stmt,
delta_lineno=node.lineno - 1, # -1 since lines start at 1
delta_col_offset=node.col_offset + 1, # +1 to remove the `"`
)
except (SyntaxError, ValueError):
raise GuppyError("Invalid Guppy type", node) from None
else:
return stmt.value


def _parse_callable_type(
args: list[ast.expr],
loc: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None,
) -> FunctionType:
"""Helper function to parse a `Callable[[<arguments>], <return types>]` type."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Helper function to parse a `Callable[[<arguments>], <return types>]` type."""
"""Helper function to parse a `Callable[[<arguments>], <return type>]` type."""

Or are multiple return types allowed??

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it should only be a single return type (that could be a tuple of multiple types) 👍

err = (
"Function types should be specified via "
"`Callable[[<arguments>], <return types>]`"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

Suggested change
"`Callable[[<arguments>], <return types>]`"
"`Callable[[<arguments>], <return type>]`"

I wonder about also noting, this is standard python hyping!

)
if len(args) != 2:
raise GuppyError(err, loc)
[inputs, output] = args
if not isinstance(inputs, ast.List):
raise GuppyError(err, loc)
inouts, output = parse_function_io_types(
inputs.elts, output, loc, globals, param_var_mapping
)
return FunctionType(inouts, output)


def parse_function_io_types(
input_nodes: list[ast.expr],
output_node: ast.expr,
loc: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None,
) -> tuple[list[FuncInput], Type]:
"""Parses the inputs and output types of a function type.
This function takes care of parsing `@inout` annotations and any related checks.
Returns the parsed input and output types.
"""
inputs = []
for inp in input_nodes:
ty, flags = type_with_flags_from_ast(inp, globals, param_var_mapping)
if InputFlags.Inout in flags and not ty.linear:
raise GuppyError(
f"Non-linear type `{ty}` cannot be annotated as `@inout`", loc
)
inputs.append(FuncInput(ty, flags))
output = type_from_ast(output_node, globals, param_var_mapping)
return inputs, output


_type_param = TypeParam(0, "T", True)


def type_with_flags_from_ast(
node: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None = None,
) -> tuple[Type, InputFlags]:
"""Turns an AST expression into a Guppy type with some optional @flags."""
# Check for `type @flag` annotations
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
ty, flags = type_with_flags_from_ast(node.left, globals, param_var_mapping)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note this means x @inout @inout is also allowed and has the same effect as only once

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, I think that should be fine though

match node.right:
case ast.Name(id="inout"):
if not ty.linear:
raise GuppyError(
f"Non-linear type `{ty}` cannot be annotated as `@inout`",
node.right,
)
flags |= InputFlags.Inout
case _:
raise GuppyError("Invalid annotation", node.right)
return ty, flags
# We also need to handle the case that this could be a delayed string annotation
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
node = _parse_delayed_annotation(node.value, node)
return type_with_flags_from_ast(node, globals, param_var_mapping)
else:
# Parse an argument and check that it's valid for a `TypeParam`
arg = arg_from_ast(node, globals, param_var_mapping)
return _type_param.check_arg(arg, node).ty, InputFlags.NoFlags


def type_from_ast(
node: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None = None,
) -> Type:
"""Turns an AST expression into a Guppy type."""
# Parse an argument and check that it's valid for a `TypeParam`
arg = arg_from_ast(node, globals, param_var_mapping)
return _type_param.check_arg(arg, node).ty
ty, flags = type_with_flags_from_ast(node, globals, param_var_mapping)
if flags != InputFlags.NoFlags:
raise GuppyError("`@` type annotations are not allowed in this position", node)
return ty


def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]:
Expand Down
Empty file.
6 changes: 6 additions & 0 deletions tests/error/inout_errors/nonlinear.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:11

9: @guppy.declare(module)
10: def foo(x: int @inout) -> qubit: ...
^^^^^
GuppyError: Non-linear type `int` cannot be annotated as `@inout`
14 changes: 14 additions & 0 deletions tests/error/inout_errors/nonlinear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import inout
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")


@guppy.declare(module)
def foo(x: int @inout) -> qubit: ...


module.compile()
6 changes: 6 additions & 0 deletions tests/error/inout_errors/nonlinear_callable.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:12

10: @guppy.declare(module)
11: def foo(f: Callable[[int @inout], None]) -> None: ...
^^^^^
GuppyError: Non-linear type `int` cannot be annotated as `@inout`
15 changes: 15 additions & 0 deletions tests/error/inout_errors/nonlinear_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Callable

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import inout


module = GuppyModule("test")


@guppy.declare(module)
def foo(f: Callable[[int @inout], None]) -> None: ...


module.compile()
6 changes: 6 additions & 0 deletions tests/error/misc_errors/callable_no_args.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:10

8: @guppy.declare(module)
9: def foo(f: Callable) -> None: ...
^^^^^^^^
GuppyError: Function types should be specified via `Callable[[<arguments>], <return types>]`
13 changes: 13 additions & 0 deletions tests/error/misc_errors/callable_no_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Callable

from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")

@guppy.declare(module)
def foo(f: Callable) -> None: ...


module.compile()
6 changes: 6 additions & 0 deletions tests/error/misc_errors/callable_not_list1.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:10

8: @guppy.declare(module)
9: def foo(f: "Callable[int, float, bool]") -> None: ...
^^^^^^^^^^^^^^^^^^^^^^^^^^
GuppyError: Function types should be specified via `Callable[[<arguments>], <return types>]`
Loading
Loading