diff --git a/guppylang/checker/func_checker.py b/guppylang/checker/func_checker.py index c79c9e97..afab3a27 100644 --- a/guppylang/checker/func_checker.py +++ b/guppylang/checker/func_checker.py @@ -16,8 +16,8 @@ from guppylang.definition.common import DefId from guppylang.error import GuppyError from guppylang.nodes import CheckedNestedFunctionDef, NestedFunctionDef -from guppylang.tys.parsing import type_from_ast -from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, NoneType +from guppylang.tys.parsing import parse_function_io_types +from guppylang.tys.ty import FunctionType, NoneType if TYPE_CHECKING: from guppylang.tys.param import Parameter @@ -143,37 +143,19 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType # TODO: Prepopulate mapping when using Python 3.12 style generic functions param_var_mapping: dict[str, Parameter] = {} - inputs = [] + input_nodes = [] input_names = [] for inp in func_def.args.args: - ty_ast = inp.annotation - if ty_ast is None: + if inp.annotation is None: raise GuppyError("Argument type must be annotated", inp) - flags = InputFlags.NoFlags - # Detect `@flag` argument annotations - # TODO: This doesn't work if the type annotation is a string forward ref. We - # should rethink how we handle these... - if isinstance(ty_ast, ast.BinOp) and isinstance(ty_ast.op, ast.MatMult): - ty = type_from_ast(ty_ast.left, globals, param_var_mapping) - match ty_ast.right: - case ast.Name(id="inout"): - if not ty.linear: - raise GuppyError( - f"Non-linear type `{ty}` cannot be annotated as `@inout`", - ty_ast.right, - ) - flags |= InputFlags.Inout - case _: - raise GuppyError("Invalid annotation", ty_ast.right) - else: - ty = type_from_ast(ty_ast, globals, param_var_mapping) - inputs.append(FuncInput(ty, flags)) + input_nodes.append(inp.annotation) input_names.append(inp.arg) - ret_type = type_from_ast(func_def.returns, globals, param_var_mapping) - + inputs, output = parse_function_io_types( + input_nodes, func_def.returns, func_def, globals, param_var_mapping + ) return FunctionType( inputs, - ret_type, + output, input_names, sorted(param_var_mapping.values(), key=lambda v: v.idx), ) diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 6de919dc..1eb6f82d 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -1,6 +1,5 @@ from collections.abc import Sequence from dataclasses import dataclass, field -from itertools import repeat from typing import TYPE_CHECKING, Literal from hugr.serialization import tys @@ -8,13 +7,11 @@ from guppylang.ast_util import AstNode from guppylang.definition.common import DefId from guppylang.definition.ty import OpaqueTypeDef, TypeDef -from guppylang.error import GuppyError +from guppylang.error import GuppyError, InternalGuppyError from guppylang.tys.arg import Argument, ConstArg, TypeArg from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.ty import ( - FuncInput, FunctionType, - InputFlags, NoneType, NumericType, OpaqueType, @@ -27,7 +24,7 @@ @dataclass(frozen=True) -class _CallableTypeDef(TypeDef): +class CallableTypeDef(TypeDef): """Type definition associated with the builtin `Callable` type. Any impls on functions can be registered with this definition. @@ -38,20 +35,8 @@ class _CallableTypeDef(TypeDef): def check_instantiate( self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None ) -> FunctionType: - # We get the inputs/output as a flattened list: `args = [*inputs, output]`. - if not args: - raise GuppyError(f"Missing parameter for type `{self.name}`", loc) - args = [ - # TODO: Better error location - TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty - for i, arg in enumerate(args) - ] - *input_tys, output = args - inputs = [ - FuncInput(ty, flags) - for ty, flags in zip(input_tys, repeat(InputFlags.NoFlags), strict=False) - ] - return FunctionType(list(inputs), output) + # Callable types are constructed using special login in the type parser + raise InternalGuppyError("Tried to `Callable` type via `check_instantiate`") @dataclass(frozen=True) @@ -157,7 +142,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> tys.Type: return tys.Type(ty) -callable_type_def = _CallableTypeDef(DefId.fresh(), None) +callable_type_def = CallableTypeDef(DefId.fresh(), None) tuple_type_def = _TupleTypeDef(DefId.fresh(), None) none_type_def = _NoneTypeDef(DefId.fresh(), None) bool_type_def = OpaqueTypeDef( diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index c73222b3..1edd005f 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -11,9 +11,18 @@ from guppylang.definition.ty import TypeDef from guppylang.error import GuppyError from guppylang.tys.arg import Argument, ConstArg, TypeArg +from guppylang.tys.builtin import CallableTypeDef from guppylang.tys.const import ConstValue from guppylang.tys.param import Parameter, TypeParam -from guppylang.tys.ty import NoneType, NumericType, TupleType, Type +from guppylang.tys.ty import ( + FuncInput, + FunctionType, + InputFlags, + NoneType, + NumericType, + TupleType, + Type, +) def arg_from_ast( @@ -28,6 +37,11 @@ def arg_from_ast( if x not in globals: raise GuppyError("Unknown identifier", node) match globals[x]: + # Special case for the `Callable` type + case CallableTypeDef(): + return TypeArg( + _parse_callable_type([], node, globals, param_var_mapping) + ) # Either a defined type (e.g. `int`, `bool`, ...) case TypeDef() as defn: return TypeArg(defn.check_instantiate([], globals, node)) @@ -50,21 +64,16 @@ def arg_from_ast( x = node.value.id if x in globals: defn = globals[x] - if isinstance(defn, TypeDef): - arg_nodes = ( - node.slice.elts - if isinstance(node.slice, ast.Tuple) - else [node.slice] + arg_nodes = ( + node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] + ) + if isinstance(defn, CallableTypeDef): + # Special case for the `Callable[[S1, S2, ...], T]` type to support the + # input list syntax and @inout annotations. + return TypeArg( + _parse_callable_type(arg_nodes, node, globals, param_var_mapping) ) - # Hack: Flatten argument lists to support the `Callable` type. For - # example, we turn `Callable[[int, int], bool]` into - # `Callable[int, int, bool]`. - # TODO: We can get rid of this once we added support for variadic params - arg_nodes = [ - n - for arg in arg_nodes - for n in (arg.elts if isinstance(arg, ast.List) else (arg,)) - ] + if isinstance(defn, TypeDef): args = [ arg_from_ast(arg_node, globals, param_var_mapping) for arg_node in arg_nodes @@ -102,35 +111,120 @@ def arg_from_ast( # Finally, we also support delayed annotations in strings if isinstance(node, ast.Constant) and isinstance(node.value, str): - try: - [stmt] = ast.parse(node.value).body - if not isinstance(stmt, ast.Expr): - raise GuppyError("Invalid Guppy type", node) - set_location_from(stmt, loc=node) - shift_loc( - stmt, - delta_lineno=node.lineno - 1, # -1 since lines start at 1 - delta_col_offset=node.col_offset + 1, # +1 to remove the `"` - ) - return arg_from_ast(stmt.value, globals, param_var_mapping) - except (SyntaxError, ValueError): - raise GuppyError("Invalid Guppy type", node) from None + node = _parse_delayed_annotation(node.value, node) + return arg_from_ast(node, globals, param_var_mapping) raise GuppyError("Not a valid type argument", node) +def _parse_delayed_annotation(ast_str: str, node: ast.Constant) -> ast.expr: + """Parses a delayed type annotation in a string.""" + try: + [stmt] = ast.parse(ast_str).body + if not isinstance(stmt, ast.Expr): + raise GuppyError("Invalid Guppy type", node) + set_location_from(stmt, loc=node) + shift_loc( + stmt, + delta_lineno=node.lineno - 1, # -1 since lines start at 1 + delta_col_offset=node.col_offset + 1, # +1 to remove the `"` + ) + except (SyntaxError, ValueError): + raise GuppyError("Invalid Guppy type", node) from None + else: + return stmt.value + + +def _parse_callable_type( + args: list[ast.expr], + loc: AstNode, + globals: Globals, + param_var_mapping: dict[str, Parameter] | None, +) -> FunctionType: + """Helper function to parse a `Callable[[], ]` type.""" + err = ( + "Function types should be specified via " + "`Callable[[], ]`" + ) + if len(args) != 2: + raise GuppyError(err, loc) + [inputs, output] = args + if not isinstance(inputs, ast.List): + raise GuppyError(err, loc) + inouts, output = parse_function_io_types( + inputs.elts, output, loc, globals, param_var_mapping + ) + return FunctionType(inouts, output) + + +def parse_function_io_types( + input_nodes: list[ast.expr], + output_node: ast.expr, + loc: AstNode, + globals: Globals, + param_var_mapping: dict[str, Parameter] | None, +) -> tuple[list[FuncInput], Type]: + """Parses the inputs and output types of a function type. + + This function takes care of parsing `@inout` annotations and any related checks. + + Returns the parsed input and output types. + """ + inputs = [] + for inp in input_nodes: + ty, flags = type_with_flags_from_ast(inp, globals, param_var_mapping) + if InputFlags.Inout in flags and not ty.linear: + raise GuppyError( + f"Non-linear type `{ty}` cannot be annotated as `@inout`", loc + ) + inputs.append(FuncInput(ty, flags)) + output = type_from_ast(output_node, globals, param_var_mapping) + return inputs, output + + _type_param = TypeParam(0, "T", True) +def type_with_flags_from_ast( + node: AstNode, + globals: Globals, + param_var_mapping: dict[str, Parameter] | None = None, +) -> tuple[Type, InputFlags]: + """Turns an AST expression into a Guppy type with some optional @flags.""" + # Check for `type @flag` annotations + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult): + ty, flags = type_with_flags_from_ast(node.left, globals, param_var_mapping) + match node.right: + case ast.Name(id="inout"): + if not ty.linear: + raise GuppyError( + f"Non-linear type `{ty}` cannot be annotated as `@inout`", + node.right, + ) + flags |= InputFlags.Inout + case _: + raise GuppyError("Invalid annotation", node.right) + return ty, flags + # We also need to handle the case that this could be a delayed string annotation + elif isinstance(node, ast.Constant) and isinstance(node.value, str): + node = _parse_delayed_annotation(node.value, node) + return type_with_flags_from_ast(node, globals, param_var_mapping) + else: + # Parse an argument and check that it's valid for a `TypeParam` + arg = arg_from_ast(node, globals, param_var_mapping) + return _type_param.check_arg(arg, node).ty, InputFlags.NoFlags + + def type_from_ast( node: AstNode, globals: Globals, param_var_mapping: dict[str, Parameter] | None = None, ) -> Type: """Turns an AST expression into a Guppy type.""" - # Parse an argument and check that it's valid for a `TypeParam` - arg = arg_from_ast(node, globals, param_var_mapping) - return _type_param.check_arg(arg, node).ty + ty, flags = type_with_flags_from_ast(node, globals, param_var_mapping) + if flags != InputFlags.NoFlags: + raise GuppyError("`@` type annotations are not allowed in this position", node) + return ty def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]: diff --git a/tests/error/inout_errors/nonlinear_callable.err b/tests/error/inout_errors/nonlinear_callable.err index 4df85b50..0bb28859 100644 --- a/tests/error/inout_errors/nonlinear_callable.err +++ b/tests/error/inout_errors/nonlinear_callable.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:12 10: @guppy.declare(module) 11: def foo(f: Callable[[int @inout], None]) -> None: ... - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + ^^^^^ GuppyError: Non-linear type `int` cannot be annotated as `@inout` diff --git a/tests/error/misc_errors/callable_no_args.err b/tests/error/misc_errors/callable_no_args.err new file mode 100644 index 00000000..b4391cb4 --- /dev/null +++ b/tests/error/misc_errors/callable_no_args.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:10 + +8: @guppy.declare(module) +9: def foo(f: Callable) -> None: ... + ^^^^^^^^ +GuppyError: Function types should be specified via `Callable[[], ]` diff --git a/tests/error/misc_errors/callable_no_args.py b/tests/error/misc_errors/callable_no_args.py new file mode 100644 index 00000000..29c147fa --- /dev/null +++ b/tests/error/misc_errors/callable_no_args.py @@ -0,0 +1,13 @@ +from typing import Callable + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + +@guppy.declare(module) +def foo(f: Callable) -> None: ... + + +module.compile() diff --git a/tests/error/misc_errors/callable_not_list1.err b/tests/error/misc_errors/callable_not_list1.err new file mode 100644 index 00000000..fb09294a --- /dev/null +++ b/tests/error/misc_errors/callable_not_list1.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:10 + +8: @guppy.declare(module) +9: def foo(f: "Callable[int, float, bool]") -> None: ... + ^^^^^^^^^^^^^^^^^^^^^^^^^^ +GuppyError: Function types should be specified via `Callable[[], ]` diff --git a/tests/error/misc_errors/callable_not_list1.py b/tests/error/misc_errors/callable_not_list1.py new file mode 100644 index 00000000..1aebc43a --- /dev/null +++ b/tests/error/misc_errors/callable_not_list1.py @@ -0,0 +1,13 @@ +from typing import Callable + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + +@guppy.declare(module) +def foo(f: "Callable[int, float, bool]") -> None: ... + + +module.compile() diff --git a/tests/error/misc_errors/callable_not_list2.err b/tests/error/misc_errors/callable_not_list2.err new file mode 100644 index 00000000..44246aad --- /dev/null +++ b/tests/error/misc_errors/callable_not_list2.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:10 + +8: @guppy.declare(module) +9: def foo(f: "Callable[None]") -> None: ... + ^^^^^^^^^^^^^^ +GuppyError: Function types should be specified via `Callable[[], ]` diff --git a/tests/error/misc_errors/callable_not_list2.py b/tests/error/misc_errors/callable_not_list2.py new file mode 100644 index 00000000..6401cf95 --- /dev/null +++ b/tests/error/misc_errors/callable_not_list2.py @@ -0,0 +1,13 @@ +from typing import Callable + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + +@guppy.declare(module) +def foo(f: "Callable[None]") -> None: ... + + +module.compile() diff --git a/tests/error/misc_errors/nested_arg_flag.err b/tests/error/misc_errors/nested_arg_flag.err index 4f06ad67..4bbb2b53 100644 --- a/tests/error/misc_errors/nested_arg_flag.err +++ b/tests/error/misc_errors/nested_arg_flag.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:12 10: @guppy.declare(module) 11: def foo(x: list[qubit @inout]) -> qubit: ... - ^^^^^^^^^^^^^^^^^^ -GuppyError: `@` type annotations are not allowed in this position + ^^^^^^^^^^^^ +GuppyError: Not a valid type argument diff --git a/tests/error/misc_errors/return_flag_callable.err b/tests/error/misc_errors/return_flag_callable.err index 3c69888d..f36a7a8f 100644 --- a/tests/error/misc_errors/return_flag_callable.err +++ b/tests/error/misc_errors/return_flag_callable.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:14 12: @guppy.declare(module) 13: def foo(f: Callable[[], qubit @inout]) -> None: ... - ^^^^^^^^^^^^^^^^^^^^^^^^^^ + ^^^^^^^^^^^^ GuppyError: `@` type annotations are not allowed in this position