-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Parse inout annotations in function signatures #316
Changes from all commits
2bd7fdf
b5a590a
c455a54
68ece53
11f2e25
bf070ab
77a4dc8
ff545a7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,20 +1,17 @@ | ||||||
from collections.abc import Sequence | ||||||
from dataclasses import dataclass, field | ||||||
from itertools import repeat | ||||||
from typing import TYPE_CHECKING, Literal | ||||||
|
||||||
from hugr.serialization import tys | ||||||
|
||||||
from guppylang.ast_util import AstNode | ||||||
from guppylang.definition.common import DefId | ||||||
from guppylang.definition.ty import OpaqueTypeDef, TypeDef | ||||||
from guppylang.error import GuppyError | ||||||
from guppylang.error import GuppyError, InternalGuppyError | ||||||
from guppylang.tys.arg import Argument, ConstArg, TypeArg | ||||||
from guppylang.tys.param import ConstParam, TypeParam | ||||||
from guppylang.tys.ty import ( | ||||||
FuncInput, | ||||||
FunctionType, | ||||||
InputFlags, | ||||||
NoneType, | ||||||
NumericType, | ||||||
OpaqueType, | ||||||
|
@@ -27,7 +24,7 @@ | |||||
|
||||||
|
||||||
@dataclass(frozen=True) | ||||||
class _CallableTypeDef(TypeDef): | ||||||
class CallableTypeDef(TypeDef): | ||||||
"""Type definition associated with the builtin `Callable` type. | ||||||
|
||||||
Any impls on functions can be registered with this definition. | ||||||
|
@@ -38,20 +35,8 @@ class _CallableTypeDef(TypeDef): | |||||
def check_instantiate( | ||||||
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None | ||||||
) -> FunctionType: | ||||||
# We get the inputs/output as a flattened list: `args = [*inputs, output]`. | ||||||
if not args: | ||||||
raise GuppyError(f"Missing parameter for type `{self.name}`", loc) | ||||||
args = [ | ||||||
# TODO: Better error location | ||||||
TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty | ||||||
for i, arg in enumerate(args) | ||||||
] | ||||||
*input_tys, output = args | ||||||
inputs = [ | ||||||
FuncInput(ty, flags) | ||||||
for ty, flags in zip(input_tys, repeat(InputFlags.NoFlags), strict=False) | ||||||
] | ||||||
return FunctionType(list(inputs), output) | ||||||
# Callable types are constructed using special login in the type parser | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could consider moving some of that code into CallableTypeDef itself, just so you could have a comment "Callable types are constructed using instantiate_func_type()" or something like that, but this is fine |
||||||
raise InternalGuppyError("Tried to `Callable` type via `check_instantiate`") | ||||||
|
||||||
|
||||||
@dataclass(frozen=True) | ||||||
|
@@ -157,7 +142,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> tys.Type: | |||||
return tys.Type(ty) | ||||||
|
||||||
|
||||||
callable_type_def = _CallableTypeDef(DefId.fresh(), None) | ||||||
callable_type_def = CallableTypeDef(DefId.fresh(), None) | ||||||
tuple_type_def = _TupleTypeDef(DefId.fresh(), None) | ||||||
none_type_def = _NoneTypeDef(DefId.fresh(), None) | ||||||
bool_type_def = OpaqueTypeDef( | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -11,9 +11,18 @@ | |||||
from guppylang.definition.ty import TypeDef | ||||||
from guppylang.error import GuppyError | ||||||
from guppylang.tys.arg import Argument, ConstArg, TypeArg | ||||||
from guppylang.tys.builtin import CallableTypeDef | ||||||
from guppylang.tys.const import ConstValue | ||||||
from guppylang.tys.param import Parameter, TypeParam | ||||||
from guppylang.tys.ty import NoneType, NumericType, TupleType, Type | ||||||
from guppylang.tys.ty import ( | ||||||
FuncInput, | ||||||
FunctionType, | ||||||
InputFlags, | ||||||
NoneType, | ||||||
NumericType, | ||||||
TupleType, | ||||||
Type, | ||||||
) | ||||||
|
||||||
|
||||||
def arg_from_ast( | ||||||
|
@@ -28,6 +37,11 @@ def arg_from_ast( | |||||
if x not in globals: | ||||||
raise GuppyError("Unknown identifier", node) | ||||||
match globals[x]: | ||||||
# Special case for the `Callable` type | ||||||
case CallableTypeDef(): | ||||||
return TypeArg( | ||||||
_parse_callable_type([], node, globals, param_var_mapping) | ||||||
) | ||||||
# Either a defined type (e.g. `int`, `bool`, ...) | ||||||
case TypeDef() as defn: | ||||||
return TypeArg(defn.check_instantiate([], globals, node)) | ||||||
|
@@ -50,21 +64,16 @@ def arg_from_ast( | |||||
x = node.value.id | ||||||
if x in globals: | ||||||
defn = globals[x] | ||||||
if isinstance(defn, TypeDef): | ||||||
arg_nodes = ( | ||||||
node.slice.elts | ||||||
if isinstance(node.slice, ast.Tuple) | ||||||
else [node.slice] | ||||||
arg_nodes = ( | ||||||
node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] | ||||||
) | ||||||
if isinstance(defn, CallableTypeDef): | ||||||
# Special case for the `Callable[[S1, S2, ...], T]` type to support the | ||||||
# input list syntax and @inout annotations. | ||||||
return TypeArg( | ||||||
_parse_callable_type(arg_nodes, node, globals, param_var_mapping) | ||||||
) | ||||||
# Hack: Flatten argument lists to support the `Callable` type. For | ||||||
# example, we turn `Callable[[int, int], bool]` into | ||||||
# `Callable[int, int, bool]`. | ||||||
# TODO: We can get rid of this once we added support for variadic params | ||||||
arg_nodes = [ | ||||||
n | ||||||
for arg in arg_nodes | ||||||
for n in (arg.elts if isinstance(arg, ast.List) else (arg,)) | ||||||
] | ||||||
if isinstance(defn, TypeDef): | ||||||
args = [ | ||||||
arg_from_ast(arg_node, globals, param_var_mapping) | ||||||
for arg_node in arg_nodes | ||||||
|
@@ -102,35 +111,120 @@ def arg_from_ast( | |||||
|
||||||
# Finally, we also support delayed annotations in strings | ||||||
if isinstance(node, ast.Constant) and isinstance(node.value, str): | ||||||
try: | ||||||
[stmt] = ast.parse(node.value).body | ||||||
if not isinstance(stmt, ast.Expr): | ||||||
raise GuppyError("Invalid Guppy type", node) | ||||||
set_location_from(stmt, loc=node) | ||||||
shift_loc( | ||||||
stmt, | ||||||
delta_lineno=node.lineno - 1, # -1 since lines start at 1 | ||||||
delta_col_offset=node.col_offset + 1, # +1 to remove the `"` | ||||||
) | ||||||
return arg_from_ast(stmt.value, globals, param_var_mapping) | ||||||
except (SyntaxError, ValueError): | ||||||
raise GuppyError("Invalid Guppy type", node) from None | ||||||
node = _parse_delayed_annotation(node.value, node) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, breaking that out is a good move :) |
||||||
return arg_from_ast(node, globals, param_var_mapping) | ||||||
|
||||||
raise GuppyError("Not a valid type argument", node) | ||||||
|
||||||
|
||||||
def _parse_delayed_annotation(ast_str: str, node: ast.Constant) -> ast.expr: | ||||||
"""Parses a delayed type annotation in a string.""" | ||||||
try: | ||||||
[stmt] = ast.parse(ast_str).body | ||||||
if not isinstance(stmt, ast.Expr): | ||||||
raise GuppyError("Invalid Guppy type", node) | ||||||
set_location_from(stmt, loc=node) | ||||||
shift_loc( | ||||||
stmt, | ||||||
delta_lineno=node.lineno - 1, # -1 since lines start at 1 | ||||||
delta_col_offset=node.col_offset + 1, # +1 to remove the `"` | ||||||
) | ||||||
except (SyntaxError, ValueError): | ||||||
raise GuppyError("Invalid Guppy type", node) from None | ||||||
else: | ||||||
return stmt.value | ||||||
|
||||||
|
||||||
def _parse_callable_type( | ||||||
args: list[ast.expr], | ||||||
loc: AstNode, | ||||||
globals: Globals, | ||||||
param_var_mapping: dict[str, Parameter] | None, | ||||||
) -> FunctionType: | ||||||
"""Helper function to parse a `Callable[[<arguments>], <return types>]` type.""" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Or are multiple return types allowed?? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No it should only be a single return type (that could be a tuple of multiple types) 👍 |
||||||
err = ( | ||||||
"Function types should be specified via " | ||||||
"`Callable[[<arguments>], <return types>]`" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto
Suggested change
I wonder about also noting, this is standard python hyping! |
||||||
) | ||||||
if len(args) != 2: | ||||||
raise GuppyError(err, loc) | ||||||
[inputs, output] = args | ||||||
if not isinstance(inputs, ast.List): | ||||||
raise GuppyError(err, loc) | ||||||
inouts, output = parse_function_io_types( | ||||||
inputs.elts, output, loc, globals, param_var_mapping | ||||||
) | ||||||
return FunctionType(inouts, output) | ||||||
|
||||||
|
||||||
def parse_function_io_types( | ||||||
input_nodes: list[ast.expr], | ||||||
output_node: ast.expr, | ||||||
loc: AstNode, | ||||||
globals: Globals, | ||||||
param_var_mapping: dict[str, Parameter] | None, | ||||||
) -> tuple[list[FuncInput], Type]: | ||||||
"""Parses the inputs and output types of a function type. | ||||||
|
||||||
This function takes care of parsing `@inout` annotations and any related checks. | ||||||
|
||||||
Returns the parsed input and output types. | ||||||
""" | ||||||
inputs = [] | ||||||
for inp in input_nodes: | ||||||
ty, flags = type_with_flags_from_ast(inp, globals, param_var_mapping) | ||||||
if InputFlags.Inout in flags and not ty.linear: | ||||||
raise GuppyError( | ||||||
f"Non-linear type `{ty}` cannot be annotated as `@inout`", loc | ||||||
) | ||||||
inputs.append(FuncInput(ty, flags)) | ||||||
output = type_from_ast(output_node, globals, param_var_mapping) | ||||||
return inputs, output | ||||||
|
||||||
|
||||||
_type_param = TypeParam(0, "T", True) | ||||||
|
||||||
|
||||||
def type_with_flags_from_ast( | ||||||
node: AstNode, | ||||||
globals: Globals, | ||||||
param_var_mapping: dict[str, Parameter] | None = None, | ||||||
) -> tuple[Type, InputFlags]: | ||||||
"""Turns an AST expression into a Guppy type with some optional @flags.""" | ||||||
# Check for `type @flag` annotations | ||||||
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult): | ||||||
ty, flags = type_with_flags_from_ast(node.left, globals, param_var_mapping) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note this means There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, I think that should be fine though |
||||||
match node.right: | ||||||
case ast.Name(id="inout"): | ||||||
if not ty.linear: | ||||||
raise GuppyError( | ||||||
f"Non-linear type `{ty}` cannot be annotated as `@inout`", | ||||||
node.right, | ||||||
) | ||||||
flags |= InputFlags.Inout | ||||||
case _: | ||||||
raise GuppyError("Invalid annotation", node.right) | ||||||
return ty, flags | ||||||
# We also need to handle the case that this could be a delayed string annotation | ||||||
elif isinstance(node, ast.Constant) and isinstance(node.value, str): | ||||||
node = _parse_delayed_annotation(node.value, node) | ||||||
return type_with_flags_from_ast(node, globals, param_var_mapping) | ||||||
else: | ||||||
# Parse an argument and check that it's valid for a `TypeParam` | ||||||
arg = arg_from_ast(node, globals, param_var_mapping) | ||||||
return _type_param.check_arg(arg, node).ty, InputFlags.NoFlags | ||||||
|
||||||
|
||||||
def type_from_ast( | ||||||
node: AstNode, | ||||||
globals: Globals, | ||||||
param_var_mapping: dict[str, Parameter] | None = None, | ||||||
) -> Type: | ||||||
"""Turns an AST expression into a Guppy type.""" | ||||||
# Parse an argument and check that it's valid for a `TypeParam` | ||||||
arg = arg_from_ast(node, globals, param_var_mapping) | ||||||
return _type_param.check_arg(arg, node).ty | ||||||
ty, flags = type_with_flags_from_ast(node, globals, param_var_mapping) | ||||||
if flags != InputFlags.NoFlags: | ||||||
raise GuppyError("`@` type annotations are not allowed in this position", node) | ||||||
return ty | ||||||
|
||||||
|
||||||
def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]: | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Guppy compilation failed. Error in file $FILE:11 | ||
|
||
9: @guppy.declare(module) | ||
10: def foo(x: int @inout) -> qubit: ... | ||
^^^^^ | ||
GuppyError: Non-linear type `int` cannot be annotated as `@inout` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
from guppylang.prelude.builtins import inout | ||
from guppylang.prelude.quantum import qubit | ||
|
||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy.declare(module) | ||
def foo(x: int @inout) -> qubit: ... | ||
|
||
|
||
module.compile() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Guppy compilation failed. Error in file $FILE:12 | ||
|
||
10: @guppy.declare(module) | ||
11: def foo(f: Callable[[int @inout], None]) -> None: ... | ||
^^^^^ | ||
GuppyError: Non-linear type `int` cannot be annotated as `@inout` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from typing import Callable | ||
|
||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
from guppylang.prelude.builtins import inout | ||
|
||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy.declare(module) | ||
def foo(f: Callable[[int @inout], None]) -> None: ... | ||
|
||
|
||
module.compile() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Guppy compilation failed. Error in file $FILE:10 | ||
|
||
8: @guppy.declare(module) | ||
9: def foo(f: Callable) -> None: ... | ||
^^^^^^^^ | ||
GuppyError: Function types should be specified via `Callable[[<arguments>], <return types>]` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from typing import Callable | ||
|
||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
|
||
|
||
module = GuppyModule("test") | ||
|
||
@guppy.declare(module) | ||
def foo(f: Callable) -> None: ... | ||
|
||
|
||
module.compile() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Guppy compilation failed. Error in file $FILE:10 | ||
|
||
8: @guppy.declare(module) | ||
9: def foo(f: "Callable[int, float, bool]") -> None: ... | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
GuppyError: Function types should be specified via `Callable[[<arguments>], <return types>]` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to allow
def foo(x: int @inout)
, since Python (in its infinite wisdom cough splutter boo) disallowsdef foo(x: @inout int)
. Obviously, it's fine to do matrix multiplication inside a type annotation but not invoke a decorator....However, have you thought about the "pythonic" alternative, which I believe would be
def foo(x: Annotated[int, "inout"]
?@inout
is a bastardization of python syntax anyway so what would users think? I admitAnnotated
is rather long - what aboutdef foo(x: Inout[int])
? (No way, I hear you say - and I see the downside - maybe that lengthy Annotated isn't sooo bad, then....)Not seriously recommending you change, but how different / how much simpler might that make things, would it impact the design at all?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I created #359 for discussion.
I think this is mostly a syntax concern. The design wouldn't be affected much imo