From 7519b9096a02cf75672313bd0bc90c613e5230ee Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:17:14 +0000 Subject: [PATCH] feat: Generic function definitions (#618) Allow function definitions that are generic over parameters of kind `type` or `nat`. * Store currently available type parameters in the type checking context * Refactor type parsing logic to explicitly opt in to add free variables to the parameter mapping * Add a new `GenericParamValue` AST node to represent usages of generic nat params `n` inside function bodies. We need a way to encode this in Hugr. For now, we just emit a dummy node. See https://github.com/CQCL/hugr/issues/1629 Note: Nested generic functions are not supported yet, so we don't have to worry about scoping of type params. Closes #522 --- guppylang/checker/cfg_checker.py | 17 ++- guppylang/checker/core.py | 2 + guppylang/checker/expr_checker.py | 19 ++- guppylang/checker/func_checker.py | 14 +- guppylang/checker/stmt_checker.py | 2 +- guppylang/compiler/expr_compiler.py | 10 +- guppylang/definition/const.py | 2 +- guppylang/definition/extern.py | 2 +- guppylang/definition/function.py | 11 +- guppylang/definition/struct.py | 10 +- guppylang/nodes.py | 11 ++ guppylang/tys/parsing.py | 74 ++++++---- guppylang/tys/printing.py | 4 +- tests/error/poly_errors/define.err | 10 -- tests/error/poly_errors/nested_generic.err | 10 ++ tests/error/poly_errors/nested_generic.py | 16 +++ .../poly_errors/type_param_not_value.err | 8 ++ .../{define.py => type_param_not_value.py} | 4 +- tests/integration/test_poly_def.py | 128 ++++++++++++++++++ 19 files changed, 293 insertions(+), 61 deletions(-) create mode 100644 tests/error/poly_errors/nested_generic.err create mode 100644 tests/error/poly_errors/nested_generic.py create mode 100644 tests/error/poly_errors/type_param_not_value.err rename tests/error/poly_errors/{define.py => type_param_not_value.py} (84%) create mode 100644 tests/integration/test_poly_def.py diff --git a/guppylang/checker/cfg_checker.py b/guppylang/checker/cfg_checker.py index 909d1d5b..74f1b5ce 100644 --- a/guppylang/checker/cfg_checker.py +++ b/guppylang/checker/cfg_checker.py @@ -18,6 +18,7 @@ from guppylang.checker.stmt_checker import StmtChecker from guppylang.diagnostic import Error, Note from guppylang.error import GuppyError +from guppylang.tys.param import Parameter from guppylang.tys.ty import InputFlags, Type Row = Sequence[V] @@ -60,7 +61,12 @@ def __init__(self, input_tys: list[Type], output_ty: Type) -> None: def check_cfg( - cfg: CFG, inputs: Row[Variable], return_ty: Type, func_name: str, globals: Globals + cfg: CFG, + inputs: Row[Variable], + return_ty: Type, + generic_params: dict[str, Parameter], + func_name: str, + globals: Globals, ) -> CheckedCFG[Place]: """Type checks a control-flow graph. @@ -76,7 +82,7 @@ def check_cfg( # We start by compiling the entry BB checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty) checked_cfg.entry_bb = check_bb( - cfg.entry_bb, checked_cfg, inputs, return_ty, globals + cfg.entry_bb, checked_cfg, inputs, return_ty, generic_params, globals ) compiled = {cfg.entry_bb: checked_cfg.entry_bb} @@ -102,7 +108,9 @@ def check_cfg( check_rows_match(input_row, compiled[bb].sig.input_row, bb) else: # Otherwise, check the BB and enqueue its successors - checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals) + checked_bb = check_bb( + bb, checked_cfg, input_row, return_ty, generic_params, globals + ) queue += [ # We enumerate the successor starting from the back, so we start with # the `True` branch. This way, we find errors in a more natural order @@ -174,6 +182,7 @@ def check_bb( checked_cfg: CheckedCFG[Variable], inputs: Row[Variable], return_ty: Type, + generic_params: dict[str, Parameter], globals: Globals, ) -> CheckedBB[Variable]: cfg = bb.containing_cfg @@ -187,7 +196,7 @@ def check_bb( raise GuppyError(VarNotDefinedError(use, x)) # Check the basic block - ctx = Context(globals, Locals({v.name: v for v in inputs})) + ctx = Context(globals, Locals({v.name: v for v in inputs}), generic_params) checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) # If we branch, we also have to check the branch predicate diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 1c4b880f..c2ec79a0 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -32,6 +32,7 @@ sized_iter_type_def, tuple_type_def, ) +from guppylang.tys.param import Parameter from guppylang.tys.ty import ( BoundTypeVar, ExistentialTypeVar, @@ -381,6 +382,7 @@ class Context(NamedTuple): globals: Globals locals: Locals[str, Variable] + generic_params: dict[str, Parameter] class DummyEvalDict(PyScope): diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 574cbae7..60ebcc1f 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -27,6 +27,8 @@ from dataclasses import replace from typing import Any, NoReturn, cast +from typing_extensions import assert_never + from guppylang.ast_util import ( AstNode, AstVisitor, @@ -85,6 +87,7 @@ DesugaredGenerator, DesugaredListComp, FieldAccessAndDrop, + GenericParamValue, GlobalName, InoutReturnSentinel, IterEnd, @@ -108,7 +111,7 @@ is_list_type, list_type, ) -from guppylang.tys.param import TypeParam +from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import ( ExistentialTypeVar, @@ -368,6 +371,18 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]: if x in self.ctx.locals: var = self.ctx.locals[x] return with_loc(node, PlaceNode(place=var)), var.ty + elif x in self.ctx.generic_params: + param = self.ctx.generic_params[x] + match param: + case ConstParam() as param: + ast_node = with_loc(node, GenericParamValue(id=x, param=param)) + return ast_node, param.ty + case TypeParam() as param: + raise GuppyError( + ExpectedError(node, "a value", got=f"type `{param.name}`") + ) + case _: + return assert_never(param) elif x in self.ctx.globals: defn = self.ctx.globals[x] return self._check_global(defn, x, node) @@ -1031,7 +1046,7 @@ def synthesize_comprehension( # The rest is checked in a new nested context to ensure that variables don't escape # their scope inner_locals: Locals[str, Variable] = Locals({}, parent_scope=ctx.locals) - inner_ctx = Context(ctx.globals, inner_locals) + inner_ctx = Context(ctx.globals, inner_locals, ctx.generic_params) expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx) gen.hasnext_assign = stmt_chk.visit_Assign(gen.hasnext_assign) gen.next_assign = stmt_chk.visit_Assign(gen.next_assign) diff --git a/guppylang/checker/func_checker.py b/guppylang/checker/func_checker.py index f194b8f1..7463cf7b 100644 --- a/guppylang/checker/func_checker.py +++ b/guppylang/checker/func_checker.py @@ -74,7 +74,10 @@ def check_global_func_def( Variable(x, inp.ty, loc, inp.flags) for x, inp, loc in zip(ty.input_names, ty.inputs, args, strict=True) ] - return check_cfg(cfg, inputs, ty.output, func_def.name, globals) + generic_params = { + param.name: param.with_idx(i) for i, param in enumerate(ty.params) + } + return check_cfg(cfg, inputs, ty.output, generic_params, func_def.name, globals) def check_nested_func_def( @@ -84,6 +87,11 @@ def check_nested_func_def( func_ty = check_signature(func_def, ctx.globals) assert func_ty.input_names is not None + if func_ty.parametrized: + raise GuppyError( + UnsupportedError(func_def, "Nested generic function definitions") + ) + # We've already built the CFG for this function while building the CFG of the # enclosing function cfg = func_def.cfg @@ -137,7 +145,7 @@ def check_nested_func_def( # Otherwise, we treat it like a local name inputs.append(Variable(func_def.name, func_def.ty, func_def)) - checked_cfg = check_cfg(cfg, inputs, func_ty.output, func_def.name, globals) + checked_cfg = check_cfg(cfg, inputs, func_ty.output, {}, func_def.name, globals) checked_def = CheckedNestedFunctionDef( def_id, checked_cfg, @@ -188,7 +196,7 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType input_nodes.append(ty_ast) input_names.append(inp.arg) inputs, output = parse_function_io_types( - input_nodes, func_def.returns, func_def, globals, param_var_mapping + input_nodes, func_def.returns, func_def, globals, param_var_mapping, True ) return FunctionType( inputs, diff --git a/guppylang/checker/stmt_checker.py b/guppylang/checker/stmt_checker.py index 9928a01f..177ba517 100644 --- a/guppylang/checker/stmt_checker.py +++ b/guppylang/checker/stmt_checker.py @@ -144,7 +144,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign: def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: if node.value is None: raise GuppyError(UnsupportedError(node, "Variable declarations")) - ty = type_from_ast(node.annotation, self.ctx.globals) + ty = type_from_ast(node.annotation, self.ctx.globals, self.ctx.generic_params) node.value, subst = self._check_expr(node.value, ty) assert not ty.unsolved_vars # `ty` must be closed! assert len(subst) == 0 diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 5eaed826..acbfa733 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -21,7 +21,7 @@ from guppylang.checker.errors.generic import UnsupportedError from guppylang.checker.linearity_checker import contains_subscript from guppylang.compiler.core import CompilerBase, DFContainer -from guppylang.compiler.hugr_extension import PartialOp +from guppylang.compiler.hugr_extension import PartialOp, UnsupportedOp from guppylang.definition.custom import CustomFunctionDef from guppylang.definition.value import CompiledCallableDef, CompiledValueDef from guppylang.error import GuppyError, InternalGuppyError @@ -29,6 +29,7 @@ DesugaredGenerator, DesugaredListComp, FieldAccessAndDrop, + GenericParamValue, GlobalCall, GlobalName, InoutReturnSentinel, @@ -196,6 +197,13 @@ def visit_GlobalName(self, node: GlobalName) -> Wire: raise GuppyError(err) return defn.load(self.dfg, self.globals, node) + def visit_GenericParamValue(self, node: GenericParamValue) -> Wire: + # TODO: We need a way to look up the concrete value of a generic type arg in + # Hugr. For example, a new op that captures the value during monomorphisation + return self.builder.add_op( + UnsupportedOp("load_type_param", [], [node.param.ty.to_hugr()]).ext_op + ) + def visit_Name(self, node: ast.Name) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppylang/definition/const.py b/guppylang/definition/const.py index 1f691317..4b8a0a09 100644 --- a/guppylang/definition/const.py +++ b/guppylang/definition/const.py @@ -29,7 +29,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ConstDef": self.id, self.name, self.defined_at, - type_from_ast(self.type_ast, globals, None), + type_from_ast(self.type_ast, globals, {}), self.type_ast, self.value, ) diff --git a/guppylang/definition/extern.py b/guppylang/definition/extern.py index a664c91a..79c17ed6 100644 --- a/guppylang/definition/extern.py +++ b/guppylang/definition/extern.py @@ -29,7 +29,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ExternDef": self.id, self.name, self.defined_at, - type_from_ast(self.type_ast, globals, None), + type_from_ast(self.type_ast, globals, {}), self.symbol, self.constant, self.type_ast, diff --git a/guppylang/definition/function.py b/guppylang/definition/function.py index 7b2e7cc7..94f85f19 100644 --- a/guppylang/definition/function.py +++ b/guppylang/definition/function.py @@ -14,7 +14,7 @@ from guppylang.ast_util import AstNode, annotate_location, with_loc from guppylang.checker.cfg_checker import CheckedCFG from guppylang.checker.core import Context, Globals, Place, PyScope -from guppylang.checker.errors.generic import ExpectedError, UnsupportedError +from guppylang.checker.errors.generic import ExpectedError from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.checker.func_checker import ( check_global_func_def, @@ -65,8 +65,6 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef": """Parses and checks the user-provided signature of the function.""" func_ast, docstring = parse_py_func(self.python_func, sources) ty = check_signature(func_ast, globals.with_python_scope(self.python_scope)) - if ty.parametrized: - raise GuppyError(UnsupportedError(func_ast, "Generic function definitions")) return ParsedFunctionDef( self.id, self.name, func_ast, ty, self.python_scope, docstring ) @@ -160,9 +158,10 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledFunctionDe access to the other compiled functions yet. The body is compiled later in `CompiledFunctionDef.compile_inner()`. """ - func_type = self.ty.to_hugr() - func_def = module.define_function(self.name, func_type.input) - func_def.declare_outputs(func_type.output) + func_type = self.ty.to_hugr_poly() + func_def = module.define_function( + self.name, func_type.body.input, func_type.body.output, func_type.params + ) return CompiledFunctionDef( self.id, self.name, diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index 39507bad..ccbc3a51 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -185,9 +185,9 @@ def check(self, globals: Globals) -> "CheckedStructDef": # otherwise the code below would not terminate. # TODO: This is not ideal (see todo in `check_instantiate`) globals = globals.with_python_scope(self.python_scope) - check_not_recursive(self, globals) - param_var_mapping = {p.name: p for p in self.params} + check_not_recursive(self, globals, param_var_mapping) + fields = [ StructField(f.name, type_from_ast(f.type_ast, globals, param_var_mapping)) for f in self.fields @@ -330,7 +330,9 @@ def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Paramet return params -def check_not_recursive(defn: ParsedStructDef, globals: Globals) -> None: +def check_not_recursive( + defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter] +) -> None: """Throws a user error if the given struct definition is recursive.""" # TODO: The implementation below hijacks the type parsing logic to detect recursive @@ -359,4 +361,4 @@ def check_instantiate( } dummy_globals = replace(globals, defs=globals.defs | dummy_defs) for field in defn.fields: - type_from_ast(field.type_ast, dummy_globals, {}) + type_from_ast(field.type_ast, dummy_globals, param_var_mapping) diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 8b18ea47..7f286e0a 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -15,6 +15,7 @@ from guppylang.checker.core import Place, Variable from guppylang.definition.common import DefId from guppylang.definition.struct import StructField + from guppylang.tys.param import ConstParam class PlaceNode(ast.expr): @@ -33,6 +34,16 @@ class GlobalName(ast.Name): ) +class GenericParamValue(ast.Name): + id: str + param: "ConstParam" + + _fields = ( + "id", + "param", + ) + + class LocalCall(ast.expr): func: ast.expr args: list[ast.expr] diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index c17e76d3..a64189ae 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -45,12 +45,15 @@ def arg_from_ast( node: AstNode, globals: Globals, - param_var_mapping: dict[str, Parameter] | None = None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> Argument: """Turns an AST expression into an argument.""" # A single (possibly qualified) identifier if defn := _try_parse_defn(node, globals): - return _arg_from_instantiated_defn(defn, [], globals, node, param_var_mapping) + return _arg_from_instantiated_defn( + defn, [], globals, node, param_var_mapping, allow_free_vars + ) # A parametrised type, e.g. `list[??]` if isinstance(node, ast.Subscript) and ( @@ -60,13 +63,16 @@ def arg_from_ast( node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] ) return _arg_from_instantiated_defn( - defn, arg_nodes, globals, node, param_var_mapping + defn, arg_nodes, globals, node, param_var_mapping, allow_free_vars ) # We allow tuple types to be written as `(int, bool)` if isinstance(node, ast.Tuple): ty = TupleType( - [type_from_ast(el, globals, param_var_mapping) for el in node.elts] + [ + type_from_ast(el, globals, param_var_mapping, allow_free_vars) + for el in node.elts + ] ) return TypeArg(ty) @@ -86,7 +92,7 @@ def arg_from_ast( # Py-expressions can also be used to specify static numbers if py_expr := is_py_expression(node): - v = eval_py_expr(py_expr, Context(globals, Locals({}))) + v = eval_py_expr(py_expr, Context(globals, Locals({}), {})) if isinstance(v, int): nat_ty = NumericType(NumericType.Kind.Nat) return ConstArg(ConstValue(nat_ty, v)) @@ -96,7 +102,7 @@ def arg_from_ast( # Finally, we also support delayed annotations in strings if isinstance(node, ast.Constant) and isinstance(node.value, str): node = _parse_delayed_annotation(node.value, node) - return arg_from_ast(node, globals, param_var_mapping) + return arg_from_ast(node, globals, param_var_mapping, allow_free_vars) raise GuppyError(InvalidTypeArgError(node)) @@ -133,19 +139,22 @@ def _arg_from_instantiated_defn( arg_nodes: list[ast.expr], globals: Globals, node: AstNode, - param_var_mapping: dict[str, Parameter] | None = None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> Argument: """Parses a globals definition with type args into an argument.""" match defn: # Special case for the `Callable` type case CallableTypeDef(): return TypeArg( - _parse_callable_type(arg_nodes, node, globals, param_var_mapping) + _parse_callable_type( + arg_nodes, node, globals, param_var_mapping, allow_free_vars + ) ) # Either a defined type (e.g. `int`, `bool`, ...) case TypeDef() as defn: args = [ - arg_from_ast(arg_node, globals, param_var_mapping) + arg_from_ast(arg_node, globals, param_var_mapping, allow_free_vars) for arg_node in arg_nodes ] ty = defn.check_instantiate(args, globals, node) @@ -155,10 +164,11 @@ def _arg_from_instantiated_defn( # We don't allow parametrised variables like `T[int]` if arg_nodes: raise GuppyError(HigherKindedTypeVarError(node, defn)) - if param_var_mapping is None: - raise GuppyError(FreeTypeVarError(node, defn)) if defn.name not in param_var_mapping: - param_var_mapping[defn.name] = defn.to_param(len(param_var_mapping)) + if allow_free_vars: + param_var_mapping[defn.name] = defn.to_param(len(param_var_mapping)) + else: + raise GuppyError(FreeTypeVarError(node, defn)) return param_var_mapping[defn.name].to_bound() case defn: err = ExpectedError(node, "a type", got=f"{defn.description} `{defn.name}`") @@ -187,7 +197,8 @@ def _parse_callable_type( args: list[ast.expr], loc: AstNode, globals: Globals, - param_var_mapping: dict[str, Parameter] | None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> FunctionType: """Helper function to parse a `Callable[[], ]` type.""" err = InvalidCallableTypeError(loc) @@ -197,7 +208,7 @@ def _parse_callable_type( if not isinstance(inputs, ast.List): raise GuppyError(err) inouts, output = parse_function_io_types( - inputs.elts, output, loc, globals, param_var_mapping + inputs.elts, output, loc, globals, param_var_mapping, allow_free_vars ) return FunctionType(inouts, output) @@ -207,7 +218,8 @@ def parse_function_io_types( output_node: ast.expr, loc: AstNode, globals: Globals, - param_var_mapping: dict[str, Parameter] | None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> tuple[list[FuncInput], Type]: """Parses the inputs and output types of a function type. @@ -217,14 +229,16 @@ def parse_function_io_types( """ inputs = [] for inp in input_nodes: - ty, flags = type_with_flags_from_ast(inp, globals, param_var_mapping) + ty, flags = type_with_flags_from_ast( + inp, globals, param_var_mapping, allow_free_vars + ) if InputFlags.Owned in flags and not ty.linear: raise GuppyError(NonLinearOwnedError(loc, ty)) if ty.linear and InputFlags.Owned not in flags: flags |= InputFlags.Inout inputs.append(FuncInput(ty, flags)) - output = type_from_ast(output_node, globals, param_var_mapping) + output = type_from_ast(output_node, globals, param_var_mapping, allow_free_vars) return inputs, output @@ -234,10 +248,13 @@ def parse_function_io_types( def type_with_flags_from_ast( node: AstNode, globals: Globals, - param_var_mapping: dict[str, Parameter] | None = None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> tuple[Type, InputFlags]: if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult): - ty, flags = type_with_flags_from_ast(node.left, globals, param_var_mapping) + ty, flags = type_with_flags_from_ast( + node.left, globals, param_var_mapping, allow_free_vars + ) match node.right: case ast.Name(id="owned"): if not ty.linear: @@ -249,10 +266,12 @@ def type_with_flags_from_ast( # We also need to handle the case that this could be a delayed string annotation elif isinstance(node, ast.Constant) and isinstance(node.value, str): node = _parse_delayed_annotation(node.value, node) - return type_with_flags_from_ast(node, globals, param_var_mapping) + return type_with_flags_from_ast( + node, globals, param_var_mapping, allow_free_vars + ) else: # Parse an argument and check that it's valid for a `TypeParam` - arg = arg_from_ast(node, globals, param_var_mapping) + arg = arg_from_ast(node, globals, param_var_mapping, allow_free_vars) tyarg = _type_param.check_arg(arg, node) return tyarg.ty, InputFlags.NoFlags @@ -260,17 +279,22 @@ def type_with_flags_from_ast( def type_from_ast( node: AstNode, globals: Globals, - param_var_mapping: dict[str, Parameter] | None = None, + param_var_mapping: dict[str, Parameter], + allow_free_vars: bool = False, ) -> Type: """Turns an AST expression into a Guppy type.""" - ty, flags = type_with_flags_from_ast(node, globals, param_var_mapping) + ty, flags = type_with_flags_from_ast( + node, globals, param_var_mapping, allow_free_vars + ) if flags != InputFlags.NoFlags: assert InputFlags.Inout not in flags # Users shouldn't be able to set this raise GuppyError(FlagNotAllowedError(node)) return ty -def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]: +def type_row_from_ast( + node: ast.expr, globals: "Globals", allow_free_vars: bool = False +) -> Sequence[Type]: """Turns an AST expression into a Guppy type row. This is needed to interpret the return type annotation of functions. @@ -278,7 +302,7 @@ def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]: # The return type `-> None` is represented in the ast as `ast.Constant(value=None)` if isinstance(node, ast.Constant) and node.value is None: return [] - ty = type_from_ast(node, globals) + ty = type_from_ast(node, globals, {}, allow_free_vars) if isinstance(ty, TupleType): return ty.element_types else: diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py index 579d5336..1e2cde16 100644 --- a/guppylang/tys/printing.py +++ b/guppylang/tys/printing.py @@ -62,7 +62,9 @@ def _visit(self, ty: Type, inside_row: bool) -> str: @_visit.register def _visit_BoundVar(self, var: BoundVar, inside_row: bool) -> str: - return self.bound_names[var.idx] + if var.idx < len(self.bound_names): + return self.bound_names[var.idx] + return var.display_name @_visit.register def _visit_ExistentialVar(self, var: ExistentialVar, inside_row: bool) -> str: diff --git a/tests/error/poly_errors/define.err b/tests/error/poly_errors/define.err index 4fef9e69..e69de29b 100644 --- a/tests/error/poly_errors/define.err +++ b/tests/error/poly_errors/define.err @@ -1,10 +0,0 @@ -Error: Unsupported (at $FILE:11:0) - | - 9 | -10 | @guppy(module) -11 | def main(x: T) -> T: - | ^^^^^^^^^^^^^^^^^^^^ -12 | return x - | ^^^^^^^^^^^^ Generic function definitions are not supported - -Guppy compilation failed due to 1 previous error diff --git a/tests/error/poly_errors/nested_generic.err b/tests/error/poly_errors/nested_generic.err new file mode 100644 index 00000000..eff911f0 --- /dev/null +++ b/tests/error/poly_errors/nested_generic.err @@ -0,0 +1,10 @@ +Error: Unsupported (at $FILE:12:4) + | +10 | @guppy(module) +11 | def foo() -> None: +12 | def bar(x: T) -> T: + | ^^^^^^^^^^^^^^^^^^^ +13 | return x + | ^^^^^^^^^^^^^^^^ Nested generic function definitions are not supported + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/poly_errors/nested_generic.py b/tests/error/poly_errors/nested_generic.py new file mode 100644 index 00000000..79ec06ae --- /dev/null +++ b/tests/error/poly_errors/nested_generic.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var("T", module=module) + + +@guppy(module) +def foo() -> None: + def bar(x: T) -> T: + return x + + +module.compile() diff --git a/tests/error/poly_errors/type_param_not_value.err b/tests/error/poly_errors/type_param_not_value.err new file mode 100644 index 00000000..d63e4ffb --- /dev/null +++ b/tests/error/poly_errors/type_param_not_value.err @@ -0,0 +1,8 @@ +Error: Expected a value (at $FILE:12:8) + | +10 | @guppy(module) +11 | def foo(x: T) -> None: +12 | y = T + | ^ Expected a value, got type `T` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/poly_errors/define.py b/tests/error/poly_errors/type_param_not_value.py similarity index 84% rename from tests/error/poly_errors/define.py rename to tests/error/poly_errors/type_param_not_value.py index 72d1ae57..7faa7a97 100644 --- a/tests/error/poly_errors/define.py +++ b/tests/error/poly_errors/type_param_not_value.py @@ -8,8 +8,8 @@ @guppy(module) -def main(x: T) -> T: - return x +def foo(x: T) -> None: + y = T module.compile() diff --git a/tests/integration/test_poly_def.py b/tests/integration/test_poly_def.py new file mode 100644 index 00000000..bef1f9fc --- /dev/null +++ b/tests/integration/test_poly_def.py @@ -0,0 +1,128 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.builtins import array + + +def test_id(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + + @guppy(module) + def identity(x: T) -> T: + return x + + validate(module.compile()) + + +def test_nonlinear(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + + @guppy(module) + def copy(x: T) -> tuple[T, T]: + return x, x + + validate(module.compile()) + + +def test_apply(validate): + module = GuppyModule("test") + S = guppy.type_var("S", module=module) + T = guppy.type_var("T", module=module) + + @guppy(module) + def apply(f: Callable[[S], T], x: S) -> T: + return f(x) + + validate(module.compile()) + + +def test_annotate(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + + @guppy(module) + def identity(x: T) -> T: + y: T = x + return y + + validate(module.compile()) + + +def test_recurse(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + + @guppy(module) + def empty() -> T: + return empty() + + validate(module.compile()) + + +def test_call(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + + @guppy(module) + def identity(x: T) -> T: + return x + + @guppy(module) + def main() -> float: + return identity(5) + identity(42.0) + + validate(module.compile()) + + +def test_nat(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + n = guppy.nat_var("n", module=module) + + @guppy(module) + def foo(xs: array[T, n]) -> array[T, n]: + return xs + + validate(module.compile()) + + +def test_nat_use(validate): + module = GuppyModule("test") + n = guppy.nat_var("n", module=module) + + @guppy(module) + def foo(xs: array[int, n]) -> int: + return int(n) + + validate(module.compile()) + + +def test_nat_call(validate): + module = GuppyModule("test") + T = guppy.type_var("T", module=module) + n = guppy.nat_var("n", module=module) + + @guppy(module) + def foo() -> array[T, n]: + return foo() + + @guppy(module) + def main() -> tuple[array[int, 10], array[float, 20]]: + return foo(), foo() + + validate(module.compile()) + + +def test_nat_recurse(validate): + module = GuppyModule("test") + n = guppy.nat_var("n", module=module) + + @guppy(module) + def empty() -> array[int, n]: + return empty() + + validate(module.compile()) +