From 263d527010f95c6dc2510cef87b1a5f7e4f8b7ca Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Nov 2024 16:35:48 +0000 Subject: [PATCH 1/6] feat: Generic function definitions --- guppylang/checker/cfg_checker.py | 16 ++- guppylang/checker/core.py | 2 + guppylang/checker/expr_checker.py | 17 ++- guppylang/checker/func_checker.py | 14 ++- guppylang/checker/stmt_checker.py | 2 +- guppylang/compiler/expr_compiler.py | 10 +- guppylang/definition/const.py | 2 +- guppylang/definition/extern.py | 2 +- guppylang/definition/function.py | 11 +- guppylang/definition/struct.py | 10 +- guppylang/nodes.py | 11 ++ guppylang/tys/parsing.py | 78 +++++++++----- guppylang/tys/printing.py | 4 +- tests/error/poly_errors/define.err | 6 -- tests/error/poly_errors/nested_generic.err | 7 ++ tests/error/poly_errors/nested_generic.py | 16 +++ .../poly_errors/type_param_not_value.err | 7 ++ .../{define.py => type_param_not_value.py} | 4 +- tests/integration/test_poly_def.py | 101 ++++++++++++++++++ 19 files changed, 260 insertions(+), 60 deletions(-) delete mode 100644 tests/error/poly_errors/define.err create mode 100644 tests/error/poly_errors/nested_generic.err create mode 100644 tests/error/poly_errors/nested_generic.py create mode 100644 tests/error/poly_errors/type_param_not_value.err rename tests/error/poly_errors/{define.py => type_param_not_value.py} (84%) create mode 100644 tests/integration/test_poly_def.py diff --git a/guppylang/checker/cfg_checker.py b/guppylang/checker/cfg_checker.py index 0f7a4e53..5ad9500e 100644 --- a/guppylang/checker/cfg_checker.py +++ b/guppylang/checker/cfg_checker.py @@ -16,6 +16,7 @@ from guppylang.checker.expr_checker import ExprSynthesizer, to_bool from guppylang.checker.stmt_checker import StmtChecker from guppylang.error import GuppyError +from guppylang.tys.param import Parameter from guppylang.tys.ty import InputFlags, Type Row = Sequence[V] @@ -58,7 +59,11 @@ def __init__(self, input_tys: list[Type], output_ty: Type) -> None: def check_cfg( - cfg: CFG, inputs: Row[Variable], return_ty: Type, globals: Globals + cfg: CFG, + inputs: Row[Variable], + return_ty: Type, + generic_params: dict[str, Parameter], + globals: Globals, ) -> CheckedCFG[Place]: """Type checks a control-flow graph. @@ -74,7 +79,7 @@ def check_cfg( # We start by compiling the entry BB checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty) checked_cfg.entry_bb = check_bb( - cfg.entry_bb, checked_cfg, inputs, return_ty, globals + cfg.entry_bb, checked_cfg, inputs, return_ty, generic_params, globals ) compiled = {cfg.entry_bb: checked_cfg.entry_bb} @@ -100,7 +105,9 @@ def check_cfg( check_rows_match(input_row, compiled[bb].sig.input_row, bb) else: # Otherwise, check the BB and enqueue its successors - checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals) + checked_bb = check_bb( + bb, checked_cfg, input_row, return_ty, generic_params, globals + ) queue += [ # We enumerate the successor starting from the back, so we start with # the `True` branch. This way, we find errors in a more natural order @@ -134,6 +141,7 @@ def check_bb( checked_cfg: CheckedCFG[Variable], inputs: Row[Variable], return_ty: Type, + generic_params: dict[str, Parameter], globals: Globals, ) -> CheckedBB[Variable]: cfg = bb.containing_cfg @@ -147,7 +155,7 @@ def check_bb( raise GuppyError(f"Variable `{x}` is not defined", use) # Check the basic block - ctx = Context(globals, Locals({v.name: v for v in inputs})) + ctx = Context(globals, Locals({v.name: v for v in inputs}), generic_params) checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) # If we branch, we also have to check the branch predicate diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index ce2d6818..29460826 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -32,6 +32,7 @@ sized_iter_type_def, tuple_type_def, ) +from guppylang.tys.param import Parameter from guppylang.tys.ty import ( BoundTypeVar, ExistentialTypeVar, @@ -366,6 +367,7 @@ class Context(NamedTuple): globals: Globals locals: Locals[str, Variable] + generic_params: dict[str, Parameter] class DummyEvalDict(PyScope): diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index c21ffede..61d296bd 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -27,6 +27,8 @@ from dataclasses import replace from typing import Any, NoReturn, cast +from typing_extensions import assert_never + from guppylang.ast_util import ( AstNode, AstVisitor, @@ -62,6 +64,7 @@ DesugaredGenerator, DesugaredListComp, FieldAccessAndDrop, + GenericParamValue, GlobalName, InoutReturnSentinel, IterEnd, @@ -84,7 +87,7 @@ is_list_type, list_type, ) -from guppylang.tys.param import TypeParam +from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import ( ExistentialTypeVar, @@ -353,6 +356,16 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]: if x in self.ctx.locals: var = self.ctx.locals[x] return with_loc(node, PlaceNode(place=var)), var.ty + elif x in self.ctx.generic_params: + param = self.ctx.generic_params[x] + match param: + case ConstParam() as param: + ast_node = with_loc(node, GenericParamValue(id=x, param=param)) + return ast_node, param.ty + case TypeParam() as param: + raise GuppyError(f"Expected a value, got type `{param.name}`", node) + case _: + return assert_never(param) elif x in self.ctx.globals: defn = self.ctx.globals[x] return self._check_global(defn, x, node) @@ -1034,7 +1047,7 @@ def synthesize_comprehension( # The rest is checked in a new nested context to ensure that variables don't escape # their scope inner_locals: Locals[str, Variable] = Locals({}, parent_scope=ctx.locals) - inner_ctx = Context(ctx.globals, inner_locals) + inner_ctx = Context(ctx.globals, inner_locals, ctx.generic_params) expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx) gen.hasnext_assign = stmt_chk.visit_Assign(gen.hasnext_assign) gen.next_assign = stmt_chk.visit_Assign(gen.next_assign) diff --git a/guppylang/checker/func_checker.py b/guppylang/checker/func_checker.py index 47d59b2a..c3e7bd5b 100644 --- a/guppylang/checker/func_checker.py +++ b/guppylang/checker/func_checker.py @@ -36,7 +36,10 @@ def check_global_func_def( Variable(x, inp.ty, loc, inp.flags) for x, inp, loc in zip(ty.input_names, ty.inputs, args, strict=True) ] - return check_cfg(cfg, inputs, ty.output, globals) + generic_params = { + param.name: param.with_idx(i) for i, param in enumerate(ty.params) + } + return check_cfg(cfg, inputs, ty.output, generic_params, globals) def check_nested_func_def( @@ -46,6 +49,11 @@ def check_nested_func_def( func_ty = check_signature(func_def, ctx.globals) assert func_ty.input_names is not None + if func_ty.parametrized: + raise GuppyError( + "Nested generic function definitions are not supported", func_def + ) + # We've already built the CFG for this function while building the CFG of the # enclosing function cfg = func_def.cfg @@ -102,7 +110,7 @@ def check_nested_func_def( # Otherwise, we treat it like a local name inputs.append(Variable(func_def.name, func_def.ty, func_def)) - checked_cfg = check_cfg(cfg, inputs, func_ty.output, globals) + checked_cfg = check_cfg(cfg, inputs, func_ty.output, {}, globals) checked_def = CheckedNestedFunctionDef( def_id, checked_cfg, @@ -153,7 +161,7 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType input_nodes.append(ty_ast) input_names.append(inp.arg) inputs, output = parse_function_io_types( - input_nodes, func_def.returns, func_def, globals, param_var_mapping + input_nodes, func_def.returns, func_def, globals, param_var_mapping, True ) return FunctionType( inputs, diff --git a/guppylang/checker/stmt_checker.py b/guppylang/checker/stmt_checker.py index b4f8db48..7c71369d 100644 --- a/guppylang/checker/stmt_checker.py +++ b/guppylang/checker/stmt_checker.py @@ -132,7 +132,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: raise GuppyError( "Variable declaration is not supported. Assignment is required", node ) - ty = type_from_ast(node.annotation, self.ctx.globals) + ty = type_from_ast(node.annotation, self.ctx.globals, self.ctx.generic_params) node.value, subst = self._check_expr(node.value, ty) assert not ty.unsolved_vars # `ty` must be closed! assert len(subst) == 0 diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index fefe6736..c3f4ad1d 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -20,7 +20,7 @@ from guppylang.checker.core import Variable from guppylang.checker.linearity_checker import contains_subscript from guppylang.compiler.core import CompilerBase, DFContainer -from guppylang.compiler.hugr_extension import PartialOp +from guppylang.compiler.hugr_extension import PartialOp, UnsupportedOp from guppylang.definition.custom import CustomFunctionDef from guppylang.definition.value import CompiledCallableDef, CompiledValueDef from guppylang.error import GuppyError, InternalGuppyError @@ -28,6 +28,7 @@ DesugaredGenerator, DesugaredListComp, FieldAccessAndDrop, + GenericParamValue, GlobalCall, GlobalName, InoutReturnSentinel, @@ -195,6 +196,13 @@ def visit_GlobalName(self, node: GlobalName) -> Wire: ) return defn.load(self.dfg, self.globals, node) + def visit_GenericParamValue(self, node: GenericParamValue) -> Wire: + # TODO: We need a way to look up the concrete value of a generic type arg in + # Hugr. For example, a new op that captures the value during monomorphisation + return self.builder.add_op( + UnsupportedOp("load_type_param", [], [node.param.ty.to_hugr()]).ext_op + ) + def visit_Name(self, node: ast.Name) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppylang/definition/const.py b/guppylang/definition/const.py index d7479425..ea15f6c0 100644 --- a/guppylang/definition/const.py +++ b/guppylang/definition/const.py @@ -28,7 +28,7 @@ def parse(self, globals: Globals) -> "ConstDef": self.id, self.name, self.defined_at, - type_from_ast(self.type_ast, globals, None), + type_from_ast(self.type_ast, globals, {}), self.type_ast, self.value, ) diff --git a/guppylang/definition/extern.py b/guppylang/definition/extern.py index 3950750d..62a6c526 100644 --- a/guppylang/definition/extern.py +++ b/guppylang/definition/extern.py @@ -28,7 +28,7 @@ def parse(self, globals: Globals) -> "ExternDef": self.id, self.name, self.defined_at, - type_from_ast(self.type_ast, globals, None), + type_from_ast(self.type_ast, globals, {}), self.symbol, self.constant, self.type_ast, diff --git a/guppylang/definition/function.py b/guppylang/definition/function.py index c6601fc7..6a32a523 100644 --- a/guppylang/definition/function.py +++ b/guppylang/definition/function.py @@ -58,10 +58,6 @@ def parse(self, globals: Globals) -> "ParsedFunctionDef": """Parses and checks the user-provided signature of the function.""" func_ast, docstring = parse_py_func(self.python_func) ty = check_signature(func_ast, globals.with_python_scope(self.python_scope)) - if ty.parametrized: - raise GuppyError( - "Generic function definitions are not supported yet", func_ast - ) return ParsedFunctionDef( self.id, self.name, func_ast, ty, self.python_scope, docstring ) @@ -155,9 +151,10 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledFunctionDe access to the other compiled functions yet. The body is compiled later in `CompiledFunctionDef.compile_inner()`. """ - func_type = self.ty.to_hugr() - func_def = module.define_function(self.name, func_type.input) - func_def.declare_outputs(func_type.output) + func_type = self.ty.to_hugr_poly() + func_def = module.define_function( + self.name, func_type.body.input, func_type.body.output, func_type.params + ) return CompiledFunctionDef( self.id, self.name, diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index 521b5549..ddd65c28 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -157,9 +157,9 @@ def check(self, globals: Globals) -> "CheckedStructDef": # otherwise the code below would not terminate. # TODO: This is not ideal (see todo in `check_instantiate`) globals = globals.with_python_scope(self.python_scope) - check_not_recursive(self, globals) - param_var_mapping = {p.name: p for p in self.params} + check_not_recursive(self, globals, param_var_mapping) + fields = [ StructField(f.name, type_from_ast(f.type_ast, globals, param_var_mapping)) for f in self.fields @@ -295,7 +295,9 @@ def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Paramet return params -def check_not_recursive(defn: ParsedStructDef, globals: Globals) -> None: +def check_not_recursive( + defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter] +) -> None: """Throws a user error if the given struct definition is recursive.""" # TODO: The implementation below hijacks the type parsing logic to detect recursive @@ -324,4 +326,4 @@ def check_instantiate( } dummy_globals = replace(globals, defs=globals.defs | dummy_defs) for field in defn.fields: - type_from_ast(field.type_ast, dummy_globals, {}) + type_from_ast(field.type_ast, dummy_globals, param_var_mapping) diff --git a/guppylang/nodes.py b/guppylang/nodes.py index ec916433..e6d87449 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -15,6 +15,7 @@ from guppylang.checker.core import Place, Variable from guppylang.definition.common import DefId from guppylang.definition.struct import StructField + from guppylang.tys.param import ConstParam class PlaceNode(ast.expr): @@ -33,6 +34,16 @@ class GlobalName(ast.Name): ) +class GenericParamValue(ast.Name): + id: str + param: "ConstParam" + + _fields = ( + "id", + "param", + ) + + class LocalCall(ast.expr): func: ast.expr args: list[ast.expr] diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index 56985f47..4caf4c89 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -32,12 +32,15 @@ def arg_from_ast( node: AstNode, globals: Globals, - param_var_mapping: dict[str, Parameter] | None = None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> Argument: """Turns an AST expression into an argument.""" # A single (possibly qualified) identifier if defn := _try_parse_defn(node, globals): - return _arg_from_instantiated_defn(defn, [], globals, node, param_var_mapping) + return _arg_from_instantiated_defn( + defn, [], globals, node, param_var_mapping, allow_free_vars + ) # A parametrised type, e.g. `list[??]` if isinstance(node, ast.Subscript) and ( @@ -47,13 +50,16 @@ def arg_from_ast( node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] ) return _arg_from_instantiated_defn( - defn, arg_nodes, globals, node, param_var_mapping + defn, arg_nodes, globals, node, param_var_mapping, allow_free_vars ) # We allow tuple types to be written as `(int, bool)` if isinstance(node, ast.Tuple): ty = TupleType( - [type_from_ast(el, globals, param_var_mapping) for el in node.elts] + [ + type_from_ast(el, globals, param_var_mapping, allow_free_vars) + for el in node.elts + ] ) return TypeArg(ty) @@ -73,7 +79,7 @@ def arg_from_ast( # Py-expressions can also be used to specify static numbers if py_expr := is_py_expression(node): - v = eval_py_expr(py_expr, Context(globals, Locals({}))) + v = eval_py_expr(py_expr, Context(globals, Locals({}), {})) if isinstance(v, int): nat_ty = NumericType(NumericType.Kind.Nat) return ConstArg(ConstValue(nat_ty, v)) @@ -87,7 +93,7 @@ def arg_from_ast( # Finally, we also support delayed annotations in strings if isinstance(node, ast.Constant) and isinstance(node.value, str): node = _parse_delayed_annotation(node.value, node) - return arg_from_ast(node, globals, param_var_mapping) + return arg_from_ast(node, globals, param_var_mapping, allow_free_vars) raise GuppyError("Not a valid type argument", node) @@ -123,19 +129,22 @@ def _arg_from_instantiated_defn( arg_nodes: list[ast.expr], globals: Globals, node: AstNode, - param_var_mapping: dict[str, Parameter] | None = None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> Argument: """Parses a globals definition with type args into an argument.""" match defn: # Special case for the `Callable` type case CallableTypeDef(): return TypeArg( - _parse_callable_type(arg_nodes, node, globals, param_var_mapping) + _parse_callable_type( + arg_nodes, node, globals, param_var_mapping, allow_free_vars + ) ) # Either a defined type (e.g. `int`, `bool`, ...) case TypeDef() as defn: args = [ - arg_from_ast(arg_node, globals, param_var_mapping) + arg_from_ast(arg_node, globals, param_var_mapping, allow_free_vars) for arg_node in arg_nodes ] ty = defn.check_instantiate(args, globals, node) @@ -149,12 +158,13 @@ def _arg_from_instantiated_defn( f"are not supported", node, ) - if param_var_mapping is None: - raise GuppyError( - "Free type variable. Only function types can be generic", node - ) if defn.name not in param_var_mapping: - param_var_mapping[defn.name] = defn.to_param(len(param_var_mapping)) + if allow_free_vars: + param_var_mapping[defn.name] = defn.to_param(len(param_var_mapping)) + else: + raise GuppyError( + "Free type variable. Only function types can be generic", node + ) return param_var_mapping[defn.name].to_bound() case defn: raise GuppyError( @@ -184,7 +194,8 @@ def _parse_callable_type( args: list[ast.expr], loc: AstNode, globals: Globals, - param_var_mapping: dict[str, Parameter] | None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> FunctionType: """Helper function to parse a `Callable[[], ]` type.""" err = ( @@ -197,7 +208,7 @@ def _parse_callable_type( if not isinstance(inputs, ast.List): raise GuppyError(err, loc) inouts, output = parse_function_io_types( - inputs.elts, output, loc, globals, param_var_mapping + inputs.elts, output, loc, globals, param_var_mapping, allow_free_vars ) return FunctionType(inouts, output) @@ -207,7 +218,8 @@ def parse_function_io_types( output_node: ast.expr, loc: AstNode, globals: Globals, - param_var_mapping: dict[str, Parameter] | None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> tuple[list[FuncInput], Type]: """Parses the inputs and output types of a function type. @@ -217,7 +229,9 @@ def parse_function_io_types( """ inputs = [] for inp in input_nodes: - ty, flags = type_with_flags_from_ast(inp, globals, param_var_mapping) + ty, flags = type_with_flags_from_ast( + inp, globals, param_var_mapping, allow_free_vars + ) if InputFlags.Owned in flags and not ty.linear: raise GuppyError( f"Non-linear type `{ty}` cannot be annotated as `@owned`", loc @@ -226,7 +240,7 @@ def parse_function_io_types( flags |= InputFlags.Inout inputs.append(FuncInput(ty, flags)) - output = type_from_ast(output_node, globals, param_var_mapping) + output = type_from_ast(output_node, globals, param_var_mapping, allow_free_vars) return inputs, output @@ -236,10 +250,13 @@ def parse_function_io_types( def type_with_flags_from_ast( node: AstNode, globals: Globals, - param_var_mapping: dict[str, Parameter] | None = None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> tuple[Type, InputFlags]: if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult): - ty, flags = type_with_flags_from_ast(node.left, globals, param_var_mapping) + ty, flags = type_with_flags_from_ast( + node.left, globals, param_var_mapping, allow_free_vars + ) match node.right: case ast.Name(id="owned"): if not ty.linear: @@ -256,10 +273,12 @@ def type_with_flags_from_ast( # We also need to handle the case that this could be a delayed string annotation elif isinstance(node, ast.Constant) and isinstance(node.value, str): node = _parse_delayed_annotation(node.value, node) - return type_with_flags_from_ast(node, globals, param_var_mapping) + return type_with_flags_from_ast( + node, globals, param_var_mapping, allow_free_vars + ) else: # Parse an argument and check that it's valid for a `TypeParam` - arg = arg_from_ast(node, globals, param_var_mapping) + arg = arg_from_ast(node, globals, param_var_mapping, allow_free_vars) tyarg = _type_param.check_arg(arg, node) return tyarg.ty, InputFlags.NoFlags @@ -267,10 +286,13 @@ def type_with_flags_from_ast( def type_from_ast( node: AstNode, globals: Globals, - param_var_mapping: dict[str, Parameter] | None = None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> Type: """Turns an AST expression into a Guppy type.""" - ty, flags = type_with_flags_from_ast(node, globals, param_var_mapping) + ty, flags = type_with_flags_from_ast( + node, globals, param_var_mapping, allow_free_vars + ) if flags != InputFlags.NoFlags: assert InputFlags.Inout not in flags # Users shouldn't be able to set this raise GuppyError( @@ -280,7 +302,9 @@ def type_from_ast( return ty -def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]: +def type_row_from_ast( + node: ast.expr, globals: "Globals", allow_free_vars: bool = False +) -> Sequence[Type]: """Turns an AST expression into a Guppy type row. This is needed to interpret the return type annotation of functions. @@ -288,7 +312,7 @@ def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]: # The return type `-> None` is represented in the ast as `ast.Constant(value=None)` if isinstance(node, ast.Constant) and node.value is None: return [] - ty = type_from_ast(node, globals) + ty = type_from_ast(node, globals, {}, allow_free_vars) if isinstance(ty, TupleType): return ty.element_types else: diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py index 579d5336..1e2cde16 100644 --- a/guppylang/tys/printing.py +++ b/guppylang/tys/printing.py @@ -62,7 +62,9 @@ def _visit(self, ty: Type, inside_row: bool) -> str: @_visit.register def _visit_BoundVar(self, var: BoundVar, inside_row: bool) -> str: - return self.bound_names[var.idx] + if var.idx < len(self.bound_names): + return self.bound_names[var.idx] + return var.display_name @_visit.register def _visit_ExistentialVar(self, var: ExistentialVar, inside_row: bool) -> str: diff --git a/tests/error/poly_errors/define.err b/tests/error/poly_errors/define.err deleted file mode 100644 index d2500379..00000000 --- a/tests/error/poly_errors/define.err +++ /dev/null @@ -1,6 +0,0 @@ -Guppy compilation failed. Error in file $FILE:11 - -9: @guppy(module) -10: def main(x: T) -> T: - ^^^^^^^^^^^^^^^^^^^^ -GuppyError: Generic function definitions are not supported yet diff --git a/tests/error/poly_errors/nested_generic.err b/tests/error/poly_errors/nested_generic.err new file mode 100644 index 00000000..d45261ed --- /dev/null +++ b/tests/error/poly_errors/nested_generic.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy(module) +11: def foo() -> None: +12: def bar(x: T) -> T: + ^^^^^^^^^^^^^^^^^^^ +GuppyError: Nested generic function definitions are not supported diff --git a/tests/error/poly_errors/nested_generic.py b/tests/error/poly_errors/nested_generic.py new file mode 100644 index 00000000..79ec06ae --- /dev/null +++ b/tests/error/poly_errors/nested_generic.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var("T", module=module) + + +@guppy(module) +def foo() -> None: + def bar(x: T) -> T: + return x + + +module.compile() diff --git a/tests/error/poly_errors/type_param_not_value.err b/tests/error/poly_errors/type_param_not_value.err new file mode 100644 index 00000000..d85ddc99 --- /dev/null +++ b/tests/error/poly_errors/type_param_not_value.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy(module) +11: def foo(x: T) -> None: +12: y = T + ^ +GuppyError: Expected a value, got type `T` diff --git a/tests/error/poly_errors/define.py b/tests/error/poly_errors/type_param_not_value.py similarity index 84% rename from tests/error/poly_errors/define.py rename to tests/error/poly_errors/type_param_not_value.py index 72d1ae57..7faa7a97 100644 --- a/tests/error/poly_errors/define.py +++ b/tests/error/poly_errors/type_param_not_value.py @@ -8,8 +8,8 @@ @guppy(module) -def main(x: T) -> T: - return x +def foo(x: T) -> None: + y = T module.compile() diff --git a/tests/integration/test_poly_def.py b/tests/integration/test_poly_def.py new file mode 100644 index 00000000..9581c282 --- /dev/null +++ b/tests/integration/test_poly_def.py @@ -0,0 +1,101 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + + +def test_id(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + + @guppy(module) + def identity(x: T) -> T: + return x + + validate(module.compile()) + + +def test_nonlinear(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + + @guppy(module) + def copy(x: T) -> tuple[T, T]: + return x, x + + validate(module.compile()) + + +def test_apply(validate): + module = GuppyModule("test") + S = guppy.type_var("S", module=module) + T = guppy.type_var("T", module=module) + + @guppy(module) + def apply(f: Callable[[S], T], x: S) -> T: + return f(x) + + validate(module.compile()) + + +def test_annotate(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + + @guppy(module) + def identity(x: T) -> T: + y: T = x + return y + + validate(module.compile()) + + +def test_recurse(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + + @guppy(module) + def empty() -> T: + return empty() + + validate(module.compile()) + + +def test_call(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + + @guppy(module) + def identity(x: T) -> T: + return x + + @guppy(module) + def main() -> float: + return identity(5) + identity(42.0) + + validate(module.compile()) + + +def test_nat(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + n = guppy.nat_var("n", module=module) + + @guppy(module) + def foo(xs: array[T, n]) -> array[T, n]: + return xs + + validate(module.compile()) + + +def test_nat_use(validate): + module = GuppyModule("test") + n = guppy.nat_var("n", module=module) + + @guppy(module) + def foo(xs: array[int, n]) -> int: + return int(n) + + validate(module.compile()) + From d1af305fb4f8f5b99f0e563b4afe3fb713aabb27 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 7 Nov 2024 10:33:58 +0000 Subject: [PATCH 2/6] fix: Fix generic array functions --- guppylang/prelude/_internal/compiler/array.py | 15 ++++++++++---- guppylang/tys/builtin.py | 5 ++--- tests/integration/test_array.py | 20 +++++++++++++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/guppylang/prelude/_internal/compiler/array.py b/guppylang/prelude/_internal/compiler/array.py index 5c571d4e..7d161690 100644 --- a/guppylang/prelude/_internal/compiler/array.py +++ b/guppylang/prelude/_internal/compiler/array.py @@ -99,7 +99,8 @@ class NewArrayCompiler(ArrayCompiler): def build_classical_array(self, elems: list[Wire]) -> Wire: """Lowers a call to `array.__new__` for classical arrays.""" - return self.builder.add_op(array_new(self.elem_ty, len(elems)), *elems) + # See https://github.com/CQCL/guppylang/issues/629 + return self.build_linear_array(elems) def build_linear_array(self, elems: list[Wire]) -> Wire: """Lowers a call to `array.__new__` for linear arrays.""" @@ -121,9 +122,12 @@ class ArrayGetitemCompiler(ArrayCompiler): def build_classical_getitem(self, array: Wire, idx: Wire) -> CallReturnWires: """Lowers a call to `array.__getitem__` for classical arrays.""" + # See https://github.com/CQCL/guppylang/issues/629 + elem_opt_ty = ht.Option(self.elem_ty) idx = self.builder.add_op(convert_itousize(), idx) - result = self.builder.add_op(array_get(self.elem_ty, self.length), array, idx) - elem = build_unwrap(self.builder, result, "Array index out of bounds") + result = self.builder.add_op(array_get(elem_opt_ty, self.length), array, idx) + elem_opt = build_unwrap(self.builder, result, "Array index out of bounds") + elem = build_unwrap(self.builder, elem_opt, "array.__getitem__: Internal error") return CallReturnWires(regular_returns=[elem], inout_returns=[array]) def build_linear_getitem(self, array: Wire, idx: Wire) -> CallReturnWires: @@ -163,9 +167,12 @@ def build_classical_setitem( self, array: Wire, idx: Wire, elem: Wire ) -> CallReturnWires: """Lowers a call to `array.__setitem__` for classical arrays.""" + # See https://github.com/CQCL/guppylang/issues/629 + elem_opt_ty = ht.Option(self.elem_ty) idx = self.builder.add_op(convert_itousize(), idx) + elem_opt = self.builder.add_op(ops.Tag(1, elem_opt_ty), elem) result = self.builder.add_op( - array_set(self.elem_ty, self.length), array, idx, elem + array_set(elem_opt_ty, self.length), array, idx, elem_opt ) # Unwrap the result, but we don't have to hold onto the returned old value _, array = build_unwrap_right(self.builder, result, "Array index out of bounds") diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index f201ca61..12e53836 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -132,9 +132,8 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: # Linear elements are turned into an optional to enable unsafe indexing. # See `ArrayGetitemCompiler` for details. - elem_ty = ( - ht.Option(ty_arg.ty.to_hugr()) if ty_arg.ty.linear else ty_arg.ty.to_hugr() - ) + # Same also for classical arrays, see https://github.com/CQCL/guppylang/issues/629 + elem_ty = ht.Option(ty_arg.ty.to_hugr()) array = hugr.std.PRELUDE.get_type("array") return array.instantiate([len_arg.to_hugr(), ht.TypeTypeArg(elem_ty)]) diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index d71f5b2a..9b65f8d1 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -227,6 +227,26 @@ def main(a: A @owned, i: int, j: int, k: int) -> A: validate(module.compile()) + +def test_generic_function(validate): + module = GuppyModule("test") + module.load(qubit) + T = guppy.type_var("T", linear=True, module=module) + n = guppy.nat_var("n", module=module) + + @guppy(module) + def foo(xs: array[T, n] @owned) -> array[T, n]: + return xs + + @guppy(module) + def main() -> tuple[array[int, 3], array[qubit, 2]]: + xs = array(1, 2, 3) + ys = array(qubit(), qubit()) + return foo(xs), foo(ys) + + validate(module.compile()) + + def test_exec_array(validate, run_int_fn): module = GuppyModule("test") From a81e3f1b648ee6d79bb81817f8871184f1e8481a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 7 Nov 2024 11:30:30 +0000 Subject: [PATCH 3/6] Disable array results --- guppylang/prelude/_internal/checker.py | 6 ++++-- tests/integration/test_result.py | 3 +++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/guppylang/prelude/_internal/checker.py b/guppylang/prelude/_internal/checker.py index 7a2eafe2..21f1d665 100644 --- a/guppylang/prelude/_internal/checker.py +++ b/guppylang/prelude/_internal/checker.py @@ -268,8 +268,10 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: assert isinstance(len_arg, ConstArg) if not self._is_numeric_or_bool_type(ty_arg.ty): raise GuppyError(err, value) - base_ty = ty_arg.ty - array_len = len_arg.const + _base_ty = ty_arg.ty + _array_len = len_arg.const + # See https://github.com/CQCL/guppylang/issues/631 + raise GuppyError("Array results are currently disabled", value) else: raise GuppyError(err, value) node = ResultExpr(value, base_ty, array_len, tag.value) diff --git a/tests/integration/test_result.py b/tests/integration/test_result.py index dea78add..a9f1dba8 100644 --- a/tests/integration/test_result.py +++ b/tests/integration/test_result.py @@ -1,3 +1,5 @@ +import pytest + from guppylang.prelude.builtins import result, nat, array from tests.util import compile_guppy @@ -21,6 +23,7 @@ def main(w: nat, x: int, y: float, z: bool) -> None: validate(main) +@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/631") def test_array(validate): @compile_guppy def main( From 8c4f4d10a1f3575e67715c76128d9e6b1dcef973 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 14 Nov 2024 09:41:50 +0000 Subject: [PATCH 4/6] Update diagnostics --- guppylang/checker/expr_checker.py | 4 +++- guppylang/checker/func_checker.py | 2 +- guppylang/definition/function.py | 2 +- tests/error/poly_errors/nested_generic.err | 15 +++++++++------ tests/error/poly_errors/type_param_not_value.err | 13 +++++++------ tests/integration/test_poly_def.py | 2 +- 6 files changed, 22 insertions(+), 16 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index bf37ba16..60ebcc1f 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -378,7 +378,9 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]: ast_node = with_loc(node, GenericParamValue(id=x, param=param)) return ast_node, param.ty case TypeParam() as param: - raise GuppyError(f"Expected a value, got type `{param.name}`", node) + raise GuppyError( + ExpectedError(node, "a value", got=f"type `{param.name}`") + ) case _: return assert_never(param) elif x in self.ctx.globals: diff --git a/guppylang/checker/func_checker.py b/guppylang/checker/func_checker.py index b0f134d1..7463cf7b 100644 --- a/guppylang/checker/func_checker.py +++ b/guppylang/checker/func_checker.py @@ -89,7 +89,7 @@ def check_nested_func_def( if func_ty.parametrized: raise GuppyError( - "Nested generic function definitions are not supported", func_def + UnsupportedError(func_def, "Nested generic function definitions") ) # We've already built the CFG for this function while building the CFG of the diff --git a/guppylang/definition/function.py b/guppylang/definition/function.py index 4914c425..94f85f19 100644 --- a/guppylang/definition/function.py +++ b/guppylang/definition/function.py @@ -14,7 +14,7 @@ from guppylang.ast_util import AstNode, annotate_location, with_loc from guppylang.checker.cfg_checker import CheckedCFG from guppylang.checker.core import Context, Globals, Place, PyScope -from guppylang.checker.errors.generic import ExpectedError, UnsupportedError +from guppylang.checker.errors.generic import ExpectedError from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.checker.func_checker import ( check_global_func_def, diff --git a/tests/error/poly_errors/nested_generic.err b/tests/error/poly_errors/nested_generic.err index d45261ed..eff911f0 100644 --- a/tests/error/poly_errors/nested_generic.err +++ b/tests/error/poly_errors/nested_generic.err @@ -1,7 +1,10 @@ -Guppy compilation failed. Error in file $FILE:12 +Error: Unsupported (at $FILE:12:4) + | +10 | @guppy(module) +11 | def foo() -> None: +12 | def bar(x: T) -> T: + | ^^^^^^^^^^^^^^^^^^^ +13 | return x + | ^^^^^^^^^^^^^^^^ Nested generic function definitions are not supported -10: @guppy(module) -11: def foo() -> None: -12: def bar(x: T) -> T: - ^^^^^^^^^^^^^^^^^^^ -GuppyError: Nested generic function definitions are not supported +Guppy compilation failed due to 1 previous error diff --git a/tests/error/poly_errors/type_param_not_value.err b/tests/error/poly_errors/type_param_not_value.err index d85ddc99..d63e4ffb 100644 --- a/tests/error/poly_errors/type_param_not_value.err +++ b/tests/error/poly_errors/type_param_not_value.err @@ -1,7 +1,8 @@ -Guppy compilation failed. Error in file $FILE:12 +Error: Expected a value (at $FILE:12:8) + | +10 | @guppy(module) +11 | def foo(x: T) -> None: +12 | y = T + | ^ Expected a value, got type `T` -10: @guppy(module) -11: def foo(x: T) -> None: -12: y = T - ^ -GuppyError: Expected a value, got type `T` +Guppy compilation failed due to 1 previous error diff --git a/tests/integration/test_poly_def.py b/tests/integration/test_poly_def.py index 9581c282..b054228d 100644 --- a/tests/integration/test_poly_def.py +++ b/tests/integration/test_poly_def.py @@ -2,7 +2,7 @@ from guppylang.decorator import guppy from guppylang.module import GuppyModule -from guppylang.prelude.builtins import array +from guppylang.std.builtins import array def test_id(validate): From 5efa146ca2b8f07c82178b2871a4e0330361884c Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 14 Nov 2024 10:08:58 +0000 Subject: [PATCH 5/6] Add nat call and recurse tests --- tests/integration/test_poly_def.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/integration/test_poly_def.py b/tests/integration/test_poly_def.py index b054228d..bef1f9fc 100644 --- a/tests/integration/test_poly_def.py +++ b/tests/integration/test_poly_def.py @@ -99,3 +99,30 @@ def foo(xs: array[int, n]) -> int: validate(module.compile()) + +def test_nat_call(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + n = guppy.nat_var("n", module=module) + + @guppy(module) + def foo() -> array[T, n]: + return foo() + + @guppy(module) + def main() -> tuple[array[int, 10], array[float, 20]]: + return foo(), foo() + + validate(module.compile()) + + +def test_nat_recurse(validate): + module = GuppyModule("test") + n = guppy.nat_var("n", module=module) + + @guppy(module) + def empty() -> array[int, n]: + return empty() + + validate(module.compile()) + From 0de51a3f7e48bbf78ac3869e5303feade0161d7b Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 14 Nov 2024 10:24:46 +0000 Subject: [PATCH 6/6] Update diagnostics --- guppylang/std/_internal/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppylang/std/_internal/checker.py b/guppylang/std/_internal/checker.py index 6e8e83a0..9afc259b 100644 --- a/guppylang/std/_internal/checker.py +++ b/guppylang/std/_internal/checker.py @@ -270,7 +270,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: _base_ty = ty_arg.ty _array_len = len_arg.const # See https://github.com/CQCL/guppylang/issues/631 - raise GuppyError("Array results are currently disabled", value) + raise GuppyError(UnsupportedError(value, "Array results")) else: raise GuppyError(err) node = ResultExpr(value, base_ty, array_len, tag.value)