From 126468bb456a7028bc25da3709610e1d78104c90 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 16 Apr 2024 08:35:39 +0100 Subject: [PATCH 01/41] feat: Allow calling a tensor of functions --- guppylang/cfg/builder.py | 6 + guppylang/checker/core.py | 3 + guppylang/checker/expr_checker.py | 191 ++++++++++++++++++++++++---- guppylang/compiler/expr_compiler.py | 15 +++ guppylang/nodes.py | 26 ++++ guppylang/prelude/_internal.py | 15 ++- guppylang/tys/printing.py | 7 + guppylang/tys/ty.py | 60 ++++++++- tests/integration/test_tensor.py | 81 ++++++++++++ 9 files changed, 374 insertions(+), 30 deletions(-) create mode 100644 tests/integration/test_tensor.py diff --git a/guppylang/cfg/builder.py b/guppylang/cfg/builder.py index 60fc13c1..b661b3c3 100644 --- a/guppylang/cfg/builder.py +++ b/guppylang/cfg/builder.py @@ -18,6 +18,7 @@ from guppylang.nodes import ( DesugaredGenerator, DesugaredListComp, + FunctionTensor, IterEnd, IterHasNext, IterNext, @@ -345,6 +346,11 @@ def visit_Call(self, node: ast.Call) -> ast.AST: case args: arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load)) return with_loc(node, PyExpr(value=arg)) + elif isinstance(node.func, ast.Tuple): + new_elts = [self.visit(elt) for elt in node.func.elts] + node.func = FunctionTensor(new_elts) + return node + return self.generic_visit(node) def generic_visit(self, node: ast.AST) -> ast.AST: diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 8f653ff8..50b00a90 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -22,6 +22,7 @@ from guppylang.tys.ty import ( BoundTypeVar, ExistentialTypeVar, + FunctionTensorType, FunctionType, NoneType, OpaqueType, @@ -84,6 +85,8 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None pass case BoundTypeVar() | ExistentialTypeVar() | SumType(): return None + case FunctionTensorType(): + type_defn = callable_type_def case FunctionType(): type_defn = callable_type_def case OpaqueType() as ty: diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 574d7e8e..a57555d9 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -53,6 +53,8 @@ from guppylang.nodes import ( DesugaredGenerator, DesugaredListComp, + FunctionTensor, + GlobalCall, GlobalName, IterEnd, IterHasNext, @@ -61,6 +63,7 @@ LocalName, MakeIter, PyExpr, + TensorCall, TypeApply, ) from guppylang.tys.arg import TypeArg @@ -77,6 +80,7 @@ from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import ( ExistentialTypeVar, + FunctionTensorType, FunctionType, NoneType, OpaqueType, @@ -237,8 +241,46 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: check_inst(func_ty, inst, node) node.func = instantiate_poly(node.func, func_ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty - elif f := self.ctx.globals.get_instance_func(func_ty, "__call__"): - return f.check_call(node.args, ty, node, self.ctx) + elif isinstance(func_ty, FunctionTensorType): + assert isinstance(node.func, FunctionTensor) + remaining_args: list[ast.expr] = node.args + call_nodes: list[GlobalCall | LocalCall] = [] + big_subst: Subst = {} + for f, f_ty in zip(node.func.elts, func_ty.element_types): + # Use the concrete output type of the function, we'll try to + # unify all of the results with `ty` at the end + processed_args, subst, inst, remaining_args = check_call_with_leftovers( + f_ty, remaining_args, f_ty.output, f, self.ctx + ) + f_processed = instantiate_poly(f, f_ty, inst) + + check_inst(f_ty, inst, node) + # Expect that each function is a `CallableDef` + # TODO: What if it's a tensor? + if isinstance(f_processed, LocalName): + call_nodes.append(LocalCall(func=f_processed, args=processed_args)) + elif isinstance(f_processed, GlobalName): + assert isinstance(self.ctx.globals[f_processed.def_id], CallableDef) + call_nodes.append( + GlobalCall( + def_id=f_processed.def_id, + args=processed_args, + type_args=inst, + ) + ) + else: + raise GuppyError(f"Tensor isn't defined for {ast.dump(f)}") + + big_subst |= subst + assert all(isinstance(ty, FunctionType) for ty in func_ty.element_types) + + # If the substitution isn't empty, ... + subst = unify(ty, TupleType(func_ty.outputs()), big_subst) or big_subst + + return with_loc(node, TensorCall(call_nodes=call_nodes)), subst + + elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"): + return callee.check_call(node.args, ty, node, self.ctx) else: raise GuppyTypeError(f"Expected function type, got `{func_ty}`", node.func) @@ -332,7 +374,7 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]: ) raise InternalGuppyError( f"Variable `{x}` is not defined in `TypeSynthesiser`. This should have " - f"been caught by program analysis!" + "been caught by program analysis!" ) def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]: @@ -470,11 +512,35 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: if isinstance(defn, CallableDef): return defn.synthesize_call(node.args, node, self.ctx) - # Otherwise, it must be a function as a higher-order value + # Otherwise, it must be a function as a higher-order value, or a tensor if isinstance(ty, FunctionType): args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx) node.func = instantiate_poly(node.func, ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty + elif isinstance(ty, FunctionTensorType): + assert isinstance(node.func, FunctionTensor) + # Note: The FunctionTensorType is made up of FunctionTypes, none of which + # should have overlapping type arguments. + func_ty = FunctionType(ty.inputs(), TupleType(element_types=ty.outputs())) + new_elt_tys = [] + remaining_args = node.args + return_tys: list[Type] = [] + processed_args: list[ast.expr] = [] + for i, elt in enumerate(node.func.elts): + node.func.elts[i], new_elt_ty = self.visit(elt) + new_elt_tys.append(new_elt_ty) + args, return_ty, inst, remaining_args = synthesize_call_with_leftovers( + func_ty, remaining_args, node, self.ctx + ) + processed_args.extend(args) + if isinstance(return_ty, TupleType): + return_tys.extend(return_ty.element_types) + else: + return_tys.append(return_ty) + + call_node = TensorCall(func=node.func, args=processed_args) + return with_loc(node, call_node), TupleType(return_tys) + elif f := self.ctx.globals.get_instance_func(ty, "__call__"): return f.synthesize_call(node.args, node, self.ctx) else: @@ -559,6 +625,21 @@ def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, Type]: f"`{ast.unparse(node)}`" ) + def visit_FunctionTensor(self, node: ast.expr) -> tuple[ast.expr, Type]: + assert isinstance(node, FunctionTensor) + new_elts = [] + elt_tys = [] + for elt in node.elts: + new_elt, new_elt_ty = self.visit(elt) + if isinstance(new_elt_ty, FunctionType): + elt_tys.append(new_elt_ty) + else: + raise GuppyError("Expected function type for function tensor", node) + new_elts.append(new_elt) + node.elts = new_elts + ty = FunctionTensorType(elt_tys) + return with_type(ty, node), ty + def check_type_against( act: Type, exp: Type, node: AstNode, kind: str = "expression" @@ -612,18 +693,23 @@ def check_type_against( return subst, [] -def check_num_args(exp: int, act: int, node: AstNode) -> None: - """Checks that the correct number of arguments have been passed to a function.""" +def check_num_args_sufficient(exp: int, act: int, node: AstNode) -> None: + """Checks that enough arguments have been passed to a function.""" if act < exp: raise GuppyTypeError( f"Not enough arguments passed (expected {exp}, got {act})", node ) - if exp < act: + + +def check_leftovers_nil( + args_checked: int, leftovers: list[ast.expr], node: AstNode +) -> None: + if len(leftovers) > 0: if isinstance(node, ast.Call): - raise GuppyTypeError("Unexpected argument", node.args[exp]) - raise GuppyTypeError( - f"Too many arguments passed (expected {exp}, got {act})", node - ) + raise GuppyTypeError("Unexpected argument", leftovers[0]) + total_args = args_checked + len(leftovers) + msg = f"Too many arguments passed (expected {args_checked}, got {total_args})" + raise GuppyTypeError(msg, node) def type_check_args( @@ -635,18 +721,39 @@ def type_check_args( ) -> tuple[list[ast.expr], Subst]: """Checks the arguments of a function call and infers free type variables. + We expect that parameters have been replaced with free unification variables. + Checks that all unification variables can be inferred. + """ + exprs, subst, leftovers = type_check_args_with_leftovers( + inputs, func_ty, subst, ctx, node + ) + check_leftovers_nil(len(exprs), leftovers, node) + return exprs, subst + + +def type_check_args_with_leftovers( + inputs: list[ast.expr], + func_ty: FunctionType, + subst: Subst, + ctx: Context, + node: AstNode, +) -> tuple[list[ast.expr], Subst, list[ast.expr]]: + """Checks the arguments of a function call and infers free type variables. + We expect that parameters have been replaced with free unification variables. Checks that all unification variables can be inferred. """ assert not func_ty.parametrized - check_num_args(len(func_ty.inputs), len(inputs), node) + check_num_args_sufficient(len(func_ty.inputs), len(inputs), node) new_args: list[ast.expr] = [] - for inp, ty in zip(inputs, func_ty.inputs): + for ty, inp in zip(func_ty.inputs, inputs): a, s = ExprChecker(ctx).check(inp, ty.substitute(subst), "argument") new_args.append(a) subst |= s + leftovers = inputs[len(func_ty.inputs) :] + # If the argument check succeeded, this means that we must have found instantiations # for all unification variables occurring in the input types assert all(set.issubset(inp.unsolved_vars, subst.keys()) for inp in func_ty.inputs) @@ -659,24 +766,37 @@ def type_check_args( node, ) - return new_args, subst + return new_args, subst, leftovers def synthesize_call( func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[list[ast.expr], Type, Inst]: + exprs, tys, inst, leftovers = synthesize_call_with_leftovers( + func_ty, args, node, ctx + ) + check_leftovers_nil(len(exprs), leftovers, node) + return exprs, tys, inst + + +def synthesize_call_with_leftovers( + func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context +) -> tuple[list[ast.expr], Type, Inst, list[ast.expr]]: """Synthesizes the return type of a function call. Returns an annotated argument list, the synthesized return type, and an instantiation for the quantifiers in the function type. """ assert not func_ty.unsolved_vars - check_num_args(len(func_ty.inputs), len(args), node) + check_num_args_sufficient(len(func_ty.inputs), len(args), node) # Replace quantified variables with free unification variables and try to infer an # instantiation by checking the arguments unquantified, free_vars = func_ty.unquantified() - args, subst = type_check_args(args, unquantified, {}, ctx, node) + + args, subst, leftovers = type_check_args_with_leftovers( + args, unquantified, {}, ctx, node + ) # Success implies that the substitution is closed assert all(not t.unsolved_vars for t in subst.values()) @@ -685,7 +805,7 @@ def synthesize_call( # Finally, check that the instantiation respects the linearity requirements check_inst(func_ty, inst, node) - return args, unquantified.output.substitute(subst), inst + return args, unquantified.output.substitute(subst), inst, leftovers def check_call( @@ -696,13 +816,28 @@ def check_call( ctx: Context, kind: str = "expression", ) -> tuple[list[ast.expr], Subst, Inst]: + exprs, subst, inst, leftovers = check_call_with_leftovers( + func_ty, inputs, ty, node, ctx, kind + ) + check_leftovers_nil(len(exprs), leftovers, node) + return exprs, subst, inst + + +def check_call_with_leftovers( + func_ty: FunctionType, + inputs: list[ast.expr], + ty: Type, + node: AstNode, + ctx: Context, + kind: str = "expression", +) -> tuple[list[ast.expr], Subst, Inst, list[ast.expr]]: """Checks the return type of a function call against a given type. Returns an annotated argument list, a substitution for the free variables in the expected type, and an instantiation for the quantifiers in the function type. """ assert not func_ty.unsolved_vars - check_num_args(len(func_ty.inputs), len(inputs), node) + check_num_args_sufficient(len(func_ty.inputs), len(inputs), node) # When checking, we can use the information from the expected return type to infer # some type arguments. However, this pushes errors inwards. For example, given a @@ -724,18 +859,20 @@ def check_call( # in practice. Can we do better than that? # First, try to synthesize - res: tuple[Type, Inst] | None = None + res: tuple[Type, Inst, list[ast.expr]] | None = None try: - inputs, synth, inst = synthesize_call(func_ty, inputs, node, ctx) - res = synth, inst + inputs, synth, inst, leftovers = synthesize_call_with_leftovers( + func_ty, inputs, node, ctx + ) + res = synth, inst, leftovers except GuppyTypeInferenceError: pass if res is not None: - synth, inst = res + synth, inst, leftovers = res subst = unify(ty, synth, {}) if subst is None: raise GuppyTypeError(f"Expected {kind} of type `{ty}`, got `{synth}`", node) - return inputs, subst, inst + return inputs, subst, inst, leftovers # If synthesis fails, we try again, this time also using information from the # expected return type @@ -747,7 +884,9 @@ def check_call( ) # Try to infer more by checking against the arguments - inputs, subst = type_check_args(inputs, unquantified, subst, ctx, node) + inputs, subst, leftovers = type_check_args_with_leftovers( + inputs, unquantified, subst, ctx, node + ) # Also make sure we found an instantiation for all free vars in the type we're # checking against @@ -766,7 +905,7 @@ def check_call( # Finally, check that the instantiation respects the linearity requirements check_inst(func_ty, inst, node) - return inputs, subst, inst + return inputs, subst, inst, leftovers def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None: @@ -797,7 +936,7 @@ def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr: if len(inst) > 0: node = with_loc(node, TypeApply(value=with_type(ty, node), inst=inst)) return with_type(ty.instantiate(inst), node) - return node + return with_type(ty, node) def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type]: diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 91f13826..7eee3588 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -22,6 +22,7 @@ GlobalName, LocalCall, LocalName, + TensorCall, TypeApply, ) from guppylang.tys.builtin import bool_type, get_element_type, is_list_type @@ -164,6 +165,10 @@ def visit_List(self, node: ast.List) -> OutPortV: ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts] ).add_out_port(get_type(node)) + def _unpack_tuple(self, wire: OutPortV) -> list[OutPortV]: + unpack_node = self.graph.add_unpack_tuple(wire, self.dfg.node) + return list(unpack_node.out_ports) + def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: """Groups function return values into a tuple""" if len(returns) != 1: @@ -189,6 +194,16 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: ) return self._pack_returns(rets) + def visit_TensorCall(self, node: TensorCall) -> OutPortV: + outputs = [] + for call in node.call_nodes: + output = self.visit(call) + if isinstance(output.ty, TupleType): + outputs.extend(self._unpack_tuple(output)) + else: + outputs.append(output) + return self._pack_returns(outputs) + def visit_Call(self, node: ast.Call) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppylang/nodes.py b/guppylang/nodes.py index af260062..0d43208a 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from guppylang.error import GuppyError from guppylang.tys.subst import Inst from guppylang.tys.ty import FunctionType @@ -52,6 +53,15 @@ class GlobalCall(ast.expr): ) +class TensorCall(ast.expr): + """A call to a tuple of functions. Stores a call node for each function in the + tuple""" + + call_nodes: list[ast.expr] + + _fields = ("call_nodes",) + + class TypeApply(ast.expr): value: ast.expr inst: Inst @@ -187,3 +197,19 @@ def __init__( self.cfg = cfg self.ty = ty self.captured = captured + + +class FunctionTensor(ast.expr): + """A tensor product of one or more functions""" + + elts: list[ast.expr] + + _fields = ("elts",) + + def node_for_input(self, n: int, func_tys: list[FunctionType]) -> ast.expr: + for expr, func_ty in zip(self.elts, func_tys): + if n < len(func_ty.inputs): + return expr + else: + n -= len(func_ty.inputs) + raise GuppyError("Invalid call to node_for_input") diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index db0e0aee..bed4b16c 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -5,7 +5,11 @@ from guppylang.ast_util import AstNode, get_type, with_loc, with_type from guppylang.checker.core import Context -from guppylang.checker.expr_checker import ExprSynthesizer, check_num_args +from guppylang.checker.expr_checker import ( + ExprSynthesizer, + check_leftovers_nil, + check_num_args_sufficient, +) from guppylang.definition.custom import ( CustomCallChecker, CustomCallCompiler, @@ -25,6 +29,11 @@ INT_WIDTH = 6 # 2^6 = 64 bit +def check_num_args(exp: int, args: list[ast.expr], node: AstNode) -> None: + check_num_args_sufficient(exp, len(args), node) + check_leftovers_nil(exp, args[exp:], node) + + hugr_int_type = tys.Opaque( extension="arithmetic.int.types", id="int", @@ -193,7 +202,7 @@ def __init__(self, dunder_name: str, num_args: int = 1): self.num_args = num_args def _get_func(self, args: list[ast.expr]) -> tuple[list[ast.expr], CallableDef]: - check_num_args(self.num_args, len(args), self.node) + check_num_args(self.num_args, args, self.node) fst, *rest = args fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) func = self.ctx.globals.get_instance_func(ty, self.dunder_name) @@ -218,7 +227,7 @@ class CallableChecker(CustomCallChecker): """Call checker for the builtin `callable` function""" def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: - check_num_args(1, len(args), self.node) + check_num_args(1, args, self.node) [arg] = args arg, ty = ExprSynthesizer(self.ctx).synthesize(arg) is_callable = ( diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py index 252efb54..e5e16d3d 100644 --- a/guppylang/tys/printing.py +++ b/guppylang/tys/printing.py @@ -4,6 +4,7 @@ from guppylang.tys.arg import ConstArg, TypeArg from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.ty import ( + FunctionTensorType, FunctionType, NoneType, OpaqueType, @@ -81,6 +82,12 @@ def _visit_FunctionType(self, ty: FunctionType, inside_row: bool) -> str: return _wrap(f"forall {quantified}. {inputs} -> {output}", inside_row) return _wrap(f"{inputs} -> {output}", inside_row) + @_visit.register + def _visit_FunctionTensorType( + self, ty: FunctionTensorType, inside_row: bool + ) -> str: + return f"tensor[{ty.inputs()}\n -> \n{ty.outputs()}]" + @_visit.register def _visit_OpaqueType(self, ty: OpaqueType, inside_row: bool) -> str: if ty.args: diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 6fa1e625..4131db1b 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -390,6 +390,62 @@ def transform(self, transformer: Transformer) -> "Type": ) +@dataclass(frozen=True, init=False) +class FunctionTensorType(ParametrizedTypeBase): + params: list[Parameter] + element_types: list[FunctionType] + + def __init__(self, element_types: Sequence["FunctionType"]) -> None: + # We need a custom __init__ to set the args + params: list[Parameter] = [] + args: list[Argument] = [] + for func_ty in element_types: + params.extend(func_ty.params) + args += func_ty.args + params += func_ty.params + object.__setattr__(self, "args", args) + object.__setattr__(self, "params", params) + object.__setattr__(self, "element_types", element_types) + + def inputs(self) -> "TypeRow": + input_row: list[Type] = [] + for fn_ty in self.element_types: + assert isinstance(fn_ty, FunctionType) + input_row.extend(fn_ty.inputs) + return input_row + + def outputs(self) -> "TypeRow": + outputs: list[Type] = [] + for fn_ty in self.element_types: + if isinstance(fn_ty.output, TupleType): + outputs.extend(fn_ty.output.element_types) + else: + outputs.append(fn_ty.output) + return outputs + + @property + def intrinsically_linear(self) -> bool: + return False + + def visit(self, visitor: Visitor) -> None: + visitor.visit(self) + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + if all( + isinstance(e, TupleType) and len(e.element_types) == 0 + for e in self.element_types + ): + return tys.UnitSum(size=len(self.element_types)) + return tys.GeneralSum(row=[t.to_hugr() for t in self.element_types]) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or SumType( + [ty.transform(transformer) for ty in self.element_types] + ) + + @dataclass(frozen=True) class OpaqueType(ParametrizedTypeBase): """Type that is directly backed by a Hugr opaque type. @@ -429,7 +485,9 @@ def transform(self, transformer: Transformer) -> "Type": # Note that this might become obsolete in case the `@sealed` decorator is added: # * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types # * https://github.com/johnthagen/sealed-typing-pep -ParametrizedType: TypeAlias = FunctionType | TupleType | SumType | OpaqueType +ParametrizedType: TypeAlias = ( + FunctionType | TupleType | SumType | OpaqueType | FunctionTensorType +) Type: TypeAlias = BoundTypeVar | ExistentialTypeVar | NoneType | ParametrizedType TypeRow: TypeAlias = Sequence[Type] diff --git a/tests/integration/test_tensor.py b/tests/integration/test_tensor.py new file mode 100644 index 00000000..ff8c0fd5 --- /dev/null +++ b/tests/integration/test_tensor.py @@ -0,0 +1,81 @@ +# ruff: noqa: F821 +# ^ Stop ruff complaining about not knowing what "tensor" is + +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +def test_singleton(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int) -> bool: + return x == 42 + + @guppy(module) + def call_singleton(x: int) -> tuple[bool]: + return (foo,)(x) + + validate(module.compile()) + + +def test_call(validate): + module = GuppyModule("module") + + @guppy(module) + def foo() -> int: + return 42 + + @guppy(module) + def bar() -> bool: + return True + + @guppy(module) + def baz() -> tuple[int, bool]: + return (foo, bar)() + + validate(module.compile()) + + +def test_call_inplace(validate): + module = GuppyModule("module") + + @guppy(module) + def local(f: Callable[[int], bool], g: Callable[[bool], int]) -> tuple[bool, int]: + return (f, g)(42, True) + + validate(module.compile()) + + +def test_call_back(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int) -> int: + return x + + @guppy(module) + def bar(x: int) -> int: + return x * x + + @guppy(module) + def baz(x: int, y: int) -> tuple[int, int]: + return (foo, bar)(x, y) + + validate(module.compile()) + + +def test_normal(validate): + module = GuppyModule("module") + + @guppy(module) + def glo(x: int) -> int: + return x + + @guppy(module) + def foo(x: int) -> int: + return glo(x) + + validate(module.compile()) From 0943e68efaf9cf912adf7a7de21f879c7bfd1492 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 16 Apr 2024 09:03:48 +0100 Subject: [PATCH 02/41] cleanup: Remove redundant method --- guppylang/nodes.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 0d43208a..bab77a5e 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -4,7 +4,6 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from guppylang.error import GuppyError from guppylang.tys.subst import Inst from guppylang.tys.ty import FunctionType @@ -205,11 +204,3 @@ class FunctionTensor(ast.expr): elts: list[ast.expr] _fields = ("elts",) - - def node_for_input(self, n: int, func_tys: list[FunctionType]) -> ast.expr: - for expr, func_ty in zip(self.elts, func_tys): - if n < len(func_ty.inputs): - return expr - else: - n -= len(func_ty.inputs) - raise GuppyError("Invalid call to node_for_input") From ead0212396336a061088945e1942adbe03d9c523 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Wed, 17 Apr 2024 12:23:00 +0100 Subject: [PATCH 03/41] refactor: Remove `FunctionTensor` --- guppylang/cfg/builder.py | 3 +- guppylang/checker/expr_checker.py | 80 ++++++++++++++++++++----------- guppylang/nodes.py | 8 ---- tests/hugr/test_dummy_nodes.py | 4 +- 4 files changed, 56 insertions(+), 39 deletions(-) diff --git a/guppylang/cfg/builder.py b/guppylang/cfg/builder.py index b661b3c3..ea74f244 100644 --- a/guppylang/cfg/builder.py +++ b/guppylang/cfg/builder.py @@ -18,7 +18,6 @@ from guppylang.nodes import ( DesugaredGenerator, DesugaredListComp, - FunctionTensor, IterEnd, IterHasNext, IterNext, @@ -348,7 +347,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST: return with_loc(node, PyExpr(value=arg)) elif isinstance(node.func, ast.Tuple): new_elts = [self.visit(elt) for elt in node.func.elts] - node.func = FunctionTensor(new_elts) + node.func = ast.Tuple(new_elts) return node return self.generic_visit(node) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index a57555d9..e1d35e91 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -53,7 +53,6 @@ from guppylang.nodes import ( DesugaredGenerator, DesugaredListComp, - FunctionTensor, GlobalCall, GlobalName, IterEnd, @@ -192,13 +191,31 @@ def _synthesize( return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars) def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: - if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): - return self._fail(ty, node) - subst: Subst = {} - for i, el in enumerate(node.elts): - node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst)) - subst |= s - return node, subst + # Data tuples should be checkable + if isinstance(ty, TupleType) and len(ty.element_types) == len(node.elts): + subst: Subst = {} + for i, el in enumerate(node.elts): + node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst)) + subst |= s + return node, subst + else: + nodes: list[ast.expr] = [] + elem_tys: list[FunctionType] = [] + for elt in node.elts: + ann_node, fun_ty = self._synthesize(elt, allow_free_vars=True) + assert isinstance(fun_ty, FunctionType) + nodes.append(ann_node) + elem_tys.append(fun_ty) + + if all(isinstance(ty, FunctionType) for ty in elem_tys): + tensor_ty = FunctionTensorType(elem_tys) + func_ty = FunctionType( + tensor_ty.inputs(), TupleType(tensor_ty.outputs()) + ) + subst = unify(func_ty, ty, {}) or {} + return with_type(tensor_ty, node), subst + else: + return self._fail(ty, node) def visit_List(self, node: ast.List, ty: Type) -> tuple[ast.expr, Subst]: if not is_list_type(ty) and not is_linst_type(ty): @@ -242,7 +259,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: node.func = instantiate_poly(node.func, func_ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty elif isinstance(func_ty, FunctionTensorType): - assert isinstance(node.func, FunctionTensor) + assert isinstance(node.func, ast.Tuple) remaining_args: list[ast.expr] = node.args call_nodes: list[GlobalCall | LocalCall] = [] big_subst: Subst = {} @@ -378,9 +395,31 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]: ) def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]: - elems = [self.synthesize(elem) for elem in node.elts] - node.elts = [n for n, _ in elems] - return node, TupleType([ty for _, ty in elems]) + elems: list[tuple[ast.expr, Type]] = [ + self.synthesize(elem) for elem in node.elts + ] + + if all(isinstance(ty, FunctionType) for _, ty in elems): + input_row: list[Type] = [] + output_row: list[Type] = [] + func_tys: list[FunctionType] = [] + for i, (func_node, func_ty) in enumerate(elems): + assert isinstance(func_node.type, FunctionType) # type: ignore[attr-defined] + assert isinstance(func_ty, FunctionType) + node.elts[i] = func_node + func_tys.append(func_ty) + input_row.extend(func_ty.inputs) + if isinstance(func_ty.output, TupleType): + output_row.extend(func_ty.output.element_types) + else: + output_row.append(func_ty.output) + + tensor_ty = FunctionTensorType(func_tys) + tuple_node = ast.Tuple([node for node, _ in elems]) + return with_type(tensor_ty, tuple_node), tensor_ty + else: + node.elts = [n for n, _ in elems] + return node, TupleType([ty for _, ty in elems]) def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]: if len(node.elts) == 0: @@ -518,7 +557,7 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: node.func = instantiate_poly(node.func, ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty elif isinstance(ty, FunctionTensorType): - assert isinstance(node.func, FunctionTensor) + assert isinstance(node.func, ast.Tuple) # Note: The FunctionTensorType is made up of FunctionTypes, none of which # should have overlapping type arguments. func_ty = FunctionType(ty.inputs(), TupleType(element_types=ty.outputs())) @@ -625,21 +664,6 @@ def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, Type]: f"`{ast.unparse(node)}`" ) - def visit_FunctionTensor(self, node: ast.expr) -> tuple[ast.expr, Type]: - assert isinstance(node, FunctionTensor) - new_elts = [] - elt_tys = [] - for elt in node.elts: - new_elt, new_elt_ty = self.visit(elt) - if isinstance(new_elt_ty, FunctionType): - elt_tys.append(new_elt_ty) - else: - raise GuppyError("Expected function type for function tensor", node) - new_elts.append(new_elt) - node.elts = new_elts - ty = FunctionTensorType(elt_tys) - return with_type(ty, node), ty - def check_type_against( act: Type, exp: Type, node: AstNode, kind: str = "expression" diff --git a/guppylang/nodes.py b/guppylang/nodes.py index bab77a5e..c1cbdcd0 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -196,11 +196,3 @@ def __init__( self.cfg = cfg self.ty = ty self.captured = captured - - -class FunctionTensor(ast.expr): - """A tensor product of one or more functions""" - - elts: list[ast.expr] - - _fields = ("elts",) diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py index f102d090..ab415da6 100644 --- a/tests/hugr/test_dummy_nodes.py +++ b/tests/hugr/test_dummy_nodes.py @@ -22,7 +22,9 @@ def test_single_dummy(): def test_unique_names(): g = Hugr() defn = g.add_def( - FunctionType([bool_type()], TupleType([bool_type(), bool_type()])), g.root, "test" + FunctionType([bool_type()], TupleType([bool_type(), bool_type()])), + g.root, + "test", ) dfg = g.add_dfg(defn) inp = g.add_input([bool_type()], dfg).out_port(0) From 67b6b6783a5cb548f433941e3dcca29792c440ff Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Wed, 1 May 2024 13:48:23 +0100 Subject: [PATCH 04/41] new: More robust function tensor types --- guppylang/checker/core.py | 3 - guppylang/checker/expr_checker.py | 142 +++++++++++++++++----------- guppylang/compiler/expr_compiler.py | 55 +++++++++-- guppylang/hugr/hugr.py | 22 +++-- guppylang/tys/printing.py | 7 -- guppylang/tys/ty.py | 101 ++++++++------------ tests/integration/test_tensor.py | 93 +++++++++++++++++- 7 files changed, 280 insertions(+), 143 deletions(-) diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 50b00a90..8f653ff8 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -22,7 +22,6 @@ from guppylang.tys.ty import ( BoundTypeVar, ExistentialTypeVar, - FunctionTensorType, FunctionType, NoneType, OpaqueType, @@ -85,8 +84,6 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None pass case BoundTypeVar() | ExistentialTypeVar() | SumType(): return None - case FunctionTensorType(): - type_defn = callable_type_def case FunctionType(): type_defn = callable_type_def case OpaqueType() as ty: diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index e1d35e91..99262bd1 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -79,13 +79,14 @@ from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import ( ExistentialTypeVar, - FunctionTensorType, FunctionType, NoneType, OpaqueType, TupleType, Type, TypeBase, + function_tensor_signature, + parse_function_tensor, row_to_type, unify, ) @@ -192,30 +193,24 @@ def _synthesize( def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: # Data tuples should be checkable - if isinstance(ty, TupleType) and len(ty.element_types) == len(node.elts): + if not (isinstance(ty, TupleType) and len(ty.element_types) == len(node.elts)): + return self._fail(ty, node) + + if not parse_function_tensor(ty): subst: Subst = {} for i, el in enumerate(node.elts): node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst)) subst |= s return node, subst else: - nodes: list[ast.expr] = [] elem_tys: list[FunctionType] = [] - for elt in node.elts: - ann_node, fun_ty = self._synthesize(elt, allow_free_vars=True) + for i, (elt, elt_ty) in enumerate(zip(node.elts, ty.element_types)): + node.elts[i], fun_ty = self._synthesize(elt, allow_free_vars=True) assert isinstance(fun_ty, FunctionType) - nodes.append(ann_node) elem_tys.append(fun_ty) + subst = unify(fun_ty, elt_ty, {}) or {} - if all(isinstance(ty, FunctionType) for ty in elem_tys): - tensor_ty = FunctionTensorType(elem_tys) - func_ty = FunctionType( - tensor_ty.inputs(), TupleType(tensor_ty.outputs()) - ) - subst = unify(func_ty, ty, {}) or {} - return with_type(tensor_ty, node), subst - else: - return self._fail(ty, node) + return node, subst def visit_List(self, node: ast.List, ty: Type) -> tuple[ast.expr, Subst]: if not is_list_type(ty) and not is_linst_type(ty): @@ -258,43 +253,67 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: check_inst(func_ty, inst, node) node.func = instantiate_poly(node.func, func_ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty - elif isinstance(func_ty, FunctionTensorType): - assert isinstance(node.func, ast.Tuple) + + if isinstance(func_ty, TupleType) and parse_function_tensor(func_ty): + function_elements = parse_function_tensor(func_ty) + assert isinstance(function_elements, list) + tensor_ty = function_tensor_signature(function_elements) + remaining_args: list[ast.expr] = node.args call_nodes: list[GlobalCall | LocalCall] = [] big_subst: Subst = {} - for f, f_ty in zip(node.func.elts, func_ty.element_types): - # Use the concrete output type of the function, we'll try to - # unify all of the results with `ty` at the end - processed_args, subst, inst, remaining_args = check_call_with_leftovers( - f_ty, remaining_args, f_ty.output, f, self.ctx - ) - f_processed = instantiate_poly(f, f_ty, inst) - - check_inst(f_ty, inst, node) - # Expect that each function is a `CallableDef` - # TODO: What if it's a tensor? - if isinstance(f_processed, LocalName): - call_nodes.append(LocalCall(func=f_processed, args=processed_args)) - elif isinstance(f_processed, GlobalName): - assert isinstance(self.ctx.globals[f_processed.def_id], CallableDef) - call_nodes.append( - GlobalCall( - def_id=f_processed.def_id, - args=processed_args, - type_args=inst, + if isinstance(node.func, ast.Tuple): + for f, f_ty in zip(node.func.elts, func_ty.element_types): + assert isinstance(f_ty, FunctionType) + # Use the concrete output type of the function, we'll try to + # unify all of the results with `ty` at the end + processed_args, subst, inst, remaining_args = ( + check_call_with_leftovers( + f_ty, remaining_args, f_ty.output, f, self.ctx ) ) - else: - raise GuppyError(f"Tensor isn't defined for {ast.dump(f)}") + f_processed = instantiate_poly(f, f_ty, inst) + + check_inst(f_ty, inst, node) + # Expect that each function is a `CallableDef` + # TODO: What if it's a tensor? + if isinstance(f_processed, GlobalName): + assert isinstance( + self.ctx.globals[f_processed.def_id], CallableDef + ) + call_nodes.append( + GlobalCall( + def_id=f_processed.def_id, + args=processed_args, + type_args=inst, + ) + ) + else: + call_nodes.append( + LocalCall(func=f_processed, args=processed_args) + ) + + big_subst |= subst + + # If the substitution isn't empty, ... + subst = unify(ty, tensor_ty.output, big_subst) or big_subst - big_subst |= subst - assert all(isinstance(ty, FunctionType) for ty in func_ty.element_types) + return with_loc(node, TensorCall(call_nodes=call_nodes)), subst + + else: + # The func isn't a tuple, it could be a call or whatever + # we should do a check_call or something + processed_args, big_subst, inst = check_call( + tensor_ty, node.args, tensor_ty.output, node.func, self.ctx + ) + # f_processed = instantiate_poly(node.func, tensor_ty, inst) - # If the substitution isn't empty, ... - subst = unify(ty, TupleType(func_ty.outputs()), big_subst) or big_subst + # If the substitution isn't empty, ... + subst = unify(ty, tensor_ty.output, big_subst) or big_subst - return with_loc(node, TensorCall(call_nodes=call_nodes)), subst + return with_loc( + node, LocalCall(func=node.func, args=processed_args) + ), subst elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"): return callee.check_call(node.args, ty, node, self.ctx) @@ -414,9 +433,9 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]: else: output_row.append(func_ty.output) - tensor_ty = FunctionTensorType(func_tys) + tensor_ty = TupleType(func_tys) tuple_node = ast.Tuple([node for node, _ in elems]) - return with_type(tensor_ty, tuple_node), tensor_ty + return tuple_node, tensor_ty else: node.elts = [n for n, _ in elems] return node, TupleType([ty for _, ty in elems]) @@ -556,18 +575,21 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx) node.func = instantiate_poly(node.func, ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty - elif isinstance(ty, FunctionTensorType): - assert isinstance(node.func, ast.Tuple) - # Note: The FunctionTensorType is made up of FunctionTypes, none of which - # should have overlapping type arguments. - func_ty = FunctionType(ty.inputs(), TupleType(element_types=ty.outputs())) - new_elt_tys = [] + elif ( + isinstance(ty, TupleType) + and parse_function_tensor(ty) + and isinstance(node.func, ast.Tuple) + ): + # Note: None of the function types in a tuple of functions will have + # overlapping type arguments. + function_elems = parse_function_tensor(ty) + assert isinstance(function_elems, list) + func_ty = function_tensor_signature(function_elems) remaining_args = node.args return_tys: list[Type] = [] processed_args: list[ast.expr] = [] - for i, elt in enumerate(node.func.elts): - node.func.elts[i], new_elt_ty = self.visit(elt) - new_elt_tys.append(new_elt_ty) + call_nodes: list[ast.expr] = [] + for func in node.func.elts: args, return_ty, inst, remaining_args = synthesize_call_with_leftovers( func_ty, remaining_args, node, self.ctx ) @@ -577,7 +599,15 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: else: return_tys.append(return_ty) - call_node = TensorCall(func=node.func, args=processed_args) + if isinstance(func, GlobalName): + assert isinstance(self.ctx.globals[func.def_id], CallableDef) + call_nodes.append( + GlobalCall(def_id=func.def_id, args=args, type_args=inst) + ) + else: + call_nodes.append(LocalCall(func=func, args=args)) + + call_node = TensorCall(call_nodes) return with_loc(node, call_node), TupleType(return_tys) elif f := self.ctx.globals.get_instance_func(ty, "__call__"): diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 7eee3588..46c26163 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -33,6 +33,8 @@ NoneType, TupleType, Type, + function_tensor_signature, + parse_function_tensor, type_to_row, ) @@ -177,21 +179,62 @@ def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: def visit_LocalCall(self, node: LocalCall) -> OutPortV: func = self.visit(node.func) - assert isinstance(func.ty, FunctionType) + + assert isinstance(func.ty, FunctionType) or ( + isinstance(func.ty, TupleType) and parse_function_tensor(func.ty) + ) + if isinstance(func.ty, FunctionType): + output = func.ty.output + elif isinstance(func.ty, TupleType): + funcs = parse_function_tensor(func.ty) + assert isinstance(funcs, list) + output = function_tensor_signature(funcs).output args = [self.visit(arg) for arg in node.args] - call = self.graph.add_indirect_call(func, args) - rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.output)))] + + rets: list[OutPortV] = [] + if isinstance(func.ty, FunctionType): + call = self.graph.add_indirect_call(func, args) + rets = [call.out_port(i) for i in range(len(type_to_row(output)))] + elif isinstance(func.ty, TupleType) and parse_function_tensor(func.ty): + func_ports = self._unpack_tuple(func) + # Now we have to manage this hand-off + remaining_args = args + for func in func_ports: + outs, remaining_args = self._compile_tensor_with_leftovers(func, args) + rets.extend(outs) + else: + raise InternalGuppyError("Local call of something without a callable type") + return self._pack_returns(rets) + def _compile_tensor_with_leftovers( + self, func: OutPortV, args: list[OutPortV] + ) -> tuple[ + list[OutPortV], # Compiled outputs + list[OutPortV], + ]: # Leftover args + assert isinstance(func.ty, FunctionType) + input_len = len(func.ty.inputs) + call = self.graph.add_indirect_call(func, args[0:input_len]) + + return [ + call.out_port(i) for i in range(len(type_to_row(func.ty.output))) + ], args[input_len:] + def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: func = self.globals[node.def_id] assert isinstance(func, CompiledCallableDef) args = [self.visit(arg) for arg in node.args] - rets = func.compile_call( - args, list(node.type_args), self.dfg, self.graph, self.globals, node - ) + + if isinstance(func.ty, FunctionType): + rets = func.compile_call( + args, list(node.type_args), self.dfg, self.graph, self.globals, node + ) + else: + raise InternalGuppyError("Local call of something without a callable type") + return self._pack_returns(rets) def visit_TensorCall(self, node: TensorCall) -> OutPortV: diff --git a/guppylang/hugr/hugr.py b/guppylang/hugr/hugr.py index dcd23c7c..f6c54f67 100644 --- a/guppylang/hugr/hugr.py +++ b/guppylang/hugr/hugr.py @@ -16,6 +16,7 @@ SumType, TupleType, Type, + is_function, row_to_type, type_to_row, ) @@ -491,11 +492,13 @@ def add_call( self, def_port: OutPortV, args: list[OutPortV], parent: Node | None = None ) -> VNode: """Adds a `Call` node to the graph.""" - assert isinstance(def_port.ty, FunctionType) + func_ty = is_function(def_port.ty) + assert isinstance(func_ty, FunctionType) + return self.add_node( ops.Call(), None, - list(type_to_row(def_port.ty.output)), + list(type_to_row(func_ty.output)), parent, [*args, def_port], ) @@ -504,11 +507,13 @@ def add_indirect_call( self, fun_port: OutPortV, args: list[OutPortV], parent: Node | None = None ) -> VNode: """Adds an `IndirectCall` node to the graph.""" - assert isinstance(fun_port.ty, FunctionType) + func_ty = is_function(fun_port.ty) + assert isinstance(func_ty, FunctionType) + return self.add_node( ops.CallIndirect(), None, - list(type_to_row(fun_port.ty.output)), + list(type_to_row(func_ty.output)), parent, [fun_port, *args], ) @@ -535,11 +540,12 @@ def add_type_apply( self, func_port: OutPortV, args: Inst, parent: Node | None = None ) -> VNode: """Adds a `TypeApply` node to the graph.""" - assert isinstance(func_port.ty, FunctionType) - assert len(func_port.ty.params) == len(args) - result_ty = func_port.ty.instantiate(args) + func_ty = is_function(func_port.ty) + assert isinstance(func_ty, FunctionType) + assert len(func_ty.params) == len(args) + result_ty = func_ty.instantiate(args) ta = ops.TypeApplication( - input=func_port.ty.to_hugr(), + input=func_ty.to_hugr(), args=[arg.to_hugr() for arg in args], output=result_ty.to_hugr(), ) diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py index e5e16d3d..252efb54 100644 --- a/guppylang/tys/printing.py +++ b/guppylang/tys/printing.py @@ -4,7 +4,6 @@ from guppylang.tys.arg import ConstArg, TypeArg from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.ty import ( - FunctionTensorType, FunctionType, NoneType, OpaqueType, @@ -82,12 +81,6 @@ def _visit_FunctionType(self, ty: FunctionType, inside_row: bool) -> str: return _wrap(f"forall {quantified}. {inputs} -> {output}", inside_row) return _wrap(f"{inputs} -> {output}", inside_row) - @_visit.register - def _visit_FunctionTensorType( - self, ty: FunctionTensorType, inside_row: bool - ) -> str: - return f"tensor[{ty.inputs()}\n -> \n{ty.outputs()}]" - @_visit.register def _visit_OpaqueType(self, ty: OpaqueType, inside_row: bool) -> str: if ty.args: diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 4131db1b..1e3c793f 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -390,62 +390,6 @@ def transform(self, transformer: Transformer) -> "Type": ) -@dataclass(frozen=True, init=False) -class FunctionTensorType(ParametrizedTypeBase): - params: list[Parameter] - element_types: list[FunctionType] - - def __init__(self, element_types: Sequence["FunctionType"]) -> None: - # We need a custom __init__ to set the args - params: list[Parameter] = [] - args: list[Argument] = [] - for func_ty in element_types: - params.extend(func_ty.params) - args += func_ty.args - params += func_ty.params - object.__setattr__(self, "args", args) - object.__setattr__(self, "params", params) - object.__setattr__(self, "element_types", element_types) - - def inputs(self) -> "TypeRow": - input_row: list[Type] = [] - for fn_ty in self.element_types: - assert isinstance(fn_ty, FunctionType) - input_row.extend(fn_ty.inputs) - return input_row - - def outputs(self) -> "TypeRow": - outputs: list[Type] = [] - for fn_ty in self.element_types: - if isinstance(fn_ty.output, TupleType): - outputs.extend(fn_ty.output.element_types) - else: - outputs.append(fn_ty.output) - return outputs - - @property - def intrinsically_linear(self) -> bool: - return False - - def visit(self, visitor: Visitor) -> None: - visitor.visit(self) - - def to_hugr(self) -> tys.Type: - """Computes the Hugr representation of the type.""" - if all( - isinstance(e, TupleType) and len(e.element_types) == 0 - for e in self.element_types - ): - return tys.UnitSum(size=len(self.element_types)) - return tys.GeneralSum(row=[t.to_hugr() for t in self.element_types]) - - def transform(self, transformer: Transformer) -> "Type": - """Accepts a transformer on this type.""" - return transformer.transform(self) or SumType( - [ty.transform(transformer) for ty in self.element_types] - ) - - @dataclass(frozen=True) class OpaqueType(ParametrizedTypeBase): """Type that is directly backed by a Hugr opaque type. @@ -485,9 +429,7 @@ def transform(self, transformer: Transformer) -> "Type": # Note that this might become obsolete in case the `@sealed` decorator is added: # * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types # * https://github.com/johnthagen/sealed-typing-pep -ParametrizedType: TypeAlias = ( - FunctionType | TupleType | SumType | OpaqueType | FunctionTensorType -) +ParametrizedType: TypeAlias = FunctionType | TupleType | SumType | OpaqueType Type: TypeAlias = BoundTypeVar | ExistentialTypeVar | NoneType | ParametrizedType TypeRow: TypeAlias = Sequence[Type] @@ -572,3 +514,44 @@ def _unify_args( case _: return None return subst + + +### Helpers for working with tuples of functions + + +def parse_function_tensor(ty: TupleType) -> list[FunctionType] | None: + """Parses a nested tuple of function types into a flat list of functions.""" + result = [] + for el in ty.element_types: + if isinstance(el, FunctionType): + result.append(el) + elif isinstance(el, TupleType): + funcs = parse_function_tensor(el) + if funcs: + result.extend(funcs) + else: + return None + return result + + +def function_tensor_signature(tys: list[FunctionType]) -> FunctionType: + """Compute the combined function signature of a list of functions""" + inputs: list[Type] = [] + outputs: list[Type] = [] + for fun_ty in tys: + inputs.extend(fun_ty.inputs) + if isinstance(fun_ty.output, TupleType): + outputs.extend(fun_ty.output.element_types) + else: + outputs.append(fun_ty.output) + return FunctionType(inputs, TupleType(outputs)) + + +def is_function(ty: Type) -> FunctionType | None: + if isinstance(ty, FunctionType): + return ty + elif isinstance(ty, TupleType): + funcs = parse_function_tensor(ty) + if isinstance(funcs, list): + return function_tensor_signature(funcs) + return None diff --git a/tests/integration/test_tensor.py b/tests/integration/test_tensor.py index ff8c0fd5..fe3971d6 100644 --- a/tests/integration/test_tensor.py +++ b/tests/integration/test_tensor.py @@ -7,16 +7,50 @@ from guppylang.module import GuppyModule -def test_singleton(validate): +def test_bug(validate): module = GuppyModule("module") @guppy(module) - def foo(x: int) -> bool: + def bar(f: Callable[[int], bool]) -> Callable[[int], bool]: + return f + + @guppy(module) + def is_42(x: int) -> bool: return x == 42 @guppy(module) - def call_singleton(x: int) -> tuple[bool]: - return (foo,)(x) + def baz(x: int) -> tuple[bool]: + return (bar,)(is_42)(x) + + validate(module.compile()) + + +def test_check_callable(validate): + module = GuppyModule("module") + + @guppy(module) + def bar(f: Callable[[int], bool]) -> Callable[[int], bool]: + return f + + @guppy(module) + def foo(f: Callable[[int], bool]) -> tuple[Callable[[int], bool]]: + return (f,) + + @guppy(module) + def is_42(x: int) -> bool: + return x == 42 + + @guppy(module) + def baz(x: int) -> tuple[bool]: + return foo(is_42)(x) + + @guppy(module) + def baz1() -> tuple[Callable[[int], bool]]: + return (foo,)(is_42) + + @guppy(module) + def baz2(x: int) -> tuple[bool]: + return (foo,)(is_42)(x) validate(module.compile()) @@ -32,10 +66,28 @@ def foo() -> int: def bar() -> bool: return True + @guppy(module) + def baz_ho() -> tuple[Callable[[], int], Callable[[], bool]]: + return (foo, bar) + + @guppy(module) + def baz_ho_id() -> tuple[Callable[[], int], Callable[[], bool]]: + return baz_ho() + + @guppy(module) + def baz_ho_call() -> tuple[int, bool]: + return baz_ho_id()() + @guppy(module) def baz() -> tuple[int, bool]: return (foo, bar)() + @guppy(module) + def local_ho( + f: Callable[[int], bool], g: Callable[[bool], int] + ) -> tuple[bool, int]: + return (f, g)(2, True) + validate(module.compile()) @@ -79,3 +131,36 @@ def foo(x: int) -> int: return glo(x) validate(module.compile()) + + +def test_higher_order(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int) -> bool: + return x > 42 + + @guppy(module) + def bar(x: float) -> int: + if x < 5.0: + return 0 + else: + return 1 + + @guppy(module) + def baz() -> tuple[Callable[[int], bool], Callable[[float], int]]: + return foo, bar + + # For the future: + # + # @guppy(module) + # def apply(f: Callable[[int, float], + # tuple[bool, int]], + # args: tuple[int, float]) -> tuple[bool, int]: + # return f(*args) + # + # @guppy(module) + # def apply_call(args: tuple[int, float]) -> tuple[bool, int]: + # return apply(baz, args) + + validate(module.compile()) From 8a49a3a1db87aa20c07a77c20e0586bd6f87b3c6 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Fri, 3 May 2024 16:01:43 +0100 Subject: [PATCH 05/41] [fix] Substitution logic for checking tuples of functions --- guppylang/checker/expr_checker.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index f42b63d1..5facd194 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -196,21 +196,33 @@ def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: if not (isinstance(ty, TupleType) and len(ty.element_types) == len(node.elts)): return self._fail(ty, node) - if not parse_function_tensor(ty): + # Tuples can either be inert python tuples or tuples of functions which + # can be called in guppy. The former thing is checkable, but in the + # latter case we should be able to synthesise function types for the + # elements. Check here whether the given type is a tuple of function + # types to work out which case we're in. + function_types = parse_function_tensor(ty) + if not function_types: subst: Subst = {} for i, el in enumerate(node.elts): node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst)) subst |= s return node, subst else: + assert isinstance(function_types, list) elem_tys: list[FunctionType] = [] - for i, (elt, elt_ty) in enumerate(zip(node.elts, ty.element_types)): + # The substitution for the whole tuple of function types + big_subst: Subst = {} + for i, (elt, elt_ty) in enumerate(zip(node.elts, function_types)): node.elts[i], fun_ty = self._synthesize(elt, allow_free_vars=True) assert isinstance(fun_ty, FunctionType) elem_tys.append(fun_ty) + # Start with an empty substitution because the function types + # should have independent variables subst = unify(fun_ty, elt_ty, {}) or {} + big_subst |= {} - return node, subst + return node, big_subst def visit_List(self, node: ast.List, ty: Type) -> tuple[ast.expr, Subst]: if not is_list_type(ty) and not is_linst_type(ty): From 99e22d52d5747cb066fafee34f0152cb2475d1c2 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 7 May 2024 09:35:47 +0100 Subject: [PATCH 06/41] docs: Update comments --- guppylang/cfg/builder.py | 1 + guppylang/checker/expr_checker.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/guppylang/cfg/builder.py b/guppylang/cfg/builder.py index ea74f244..02d0246d 100644 --- a/guppylang/cfg/builder.py +++ b/guppylang/cfg/builder.py @@ -345,6 +345,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST: case args: arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load)) return with_loc(node, PyExpr(value=arg)) + # Unlike python, we can call a tuple of callable things elif isinstance(node.func, ast.Tuple): new_elts = [self.visit(elt) for elt in node.func.elts] node.func = ast.Tuple(new_elts) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 5facd194..023e0573 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -192,7 +192,6 @@ def _synthesize( return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars) def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: - # Data tuples should be checkable if not (isinstance(ty, TupleType) and len(ty.element_types) == len(node.elts)): return self._fail(ty, node) @@ -259,7 +258,8 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: if isinstance(defn, CallableDef): return defn.check_call(node.args, ty, node, self.ctx) - # Otherwise, it must be a function as a higher-order value + # Otherwise, it must be a function as a higher-order value - something + # whose type is either a FunctionType or a Tuple of FunctionTypes if isinstance(func_ty, FunctionType): args, return_ty, inst = check_call(func_ty, node.args, ty, node, self.ctx) check_inst(func_ty, inst, node) @@ -288,7 +288,6 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: check_inst(f_ty, inst, node) # Expect that each function is a `CallableDef` - # TODO: What if it's a tensor? if isinstance(f_processed, GlobalName): assert isinstance( self.ctx.globals[f_processed.def_id], CallableDef @@ -313,14 +312,19 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: return with_loc(node, TensorCall(call_nodes=call_nodes)), subst else: - # The func isn't a tuple, it could be a call or whatever - # we should do a check_call or something + # The func isn't a tuple, it could be a call or a variable. + # Here, the return type we expect has the outputs of all the + # function types merged together, i.e. + # f : Callable([A], tuple[B, C]) + # g : Callable([D], E) + # (f, g)(a, d) : tuple[B, C, E] processed_args, big_subst, inst = check_call( tensor_ty, node.args, tensor_ty.output, node.func, self.ctx ) + + # TODO: instantiate a tuple of functions # f_processed = instantiate_poly(node.func, tensor_ty, inst) - # If the substitution isn't empty, ... subst = unify(ty, tensor_ty.output, big_subst) or big_subst return with_loc( From 23b2769b5d9c87364e7eca816e37013489fc0295 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 7 May 2024 09:48:26 +0100 Subject: [PATCH 07/41] refactor: Don't do anything special synthesising tuples All of the logic for tuples of functions is handled when these things are *called*, so nothing needs to be done here. --- guppylang/checker/expr_checker.py | 27 +++------------------------ 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 023e0573..6797bbce 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -430,31 +430,10 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]: ) def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]: - elems: list[tuple[ast.expr, Type]] = [ - self.synthesize(elem) for elem in node.elts - ] - - if all(isinstance(ty, FunctionType) for _, ty in elems): - input_row: list[Type] = [] - output_row: list[Type] = [] - func_tys: list[FunctionType] = [] - for i, (func_node, func_ty) in enumerate(elems): - assert isinstance(func_node.type, FunctionType) # type: ignore[attr-defined] - assert isinstance(func_ty, FunctionType) - node.elts[i] = func_node - func_tys.append(func_ty) - input_row.extend(func_ty.inputs) - if isinstance(func_ty.output, TupleType): - output_row.extend(func_ty.output.element_types) - else: - output_row.append(func_ty.output) + elems = [self.synthesize(elem) for elem in node.elts] - tensor_ty = TupleType(func_tys) - tuple_node = ast.Tuple([node for node, _ in elems]) - return tuple_node, tensor_ty - else: - node.elts = [n for n, _ in elems] - return node, TupleType([ty for _, ty in elems]) + node.elts = [n for n, _ in elems] + return node, TupleType([ty for _, ty in elems]) def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]: if len(node.elts) == 0: From 12153ca935473717ab24702c401ad8f55ae1f58f Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 7 May 2024 12:01:42 +0100 Subject: [PATCH 08/41] fix: Revert `is_function` change for building hugr --- guppylang/hugr/hugr.py | 20 ++++++++------------ guppylang/tys/ty.py | 10 ---------- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/guppylang/hugr/hugr.py b/guppylang/hugr/hugr.py index f6c54f67..4e0f798b 100644 --- a/guppylang/hugr/hugr.py +++ b/guppylang/hugr/hugr.py @@ -16,7 +16,6 @@ SumType, TupleType, Type, - is_function, row_to_type, type_to_row, ) @@ -492,13 +491,12 @@ def add_call( self, def_port: OutPortV, args: list[OutPortV], parent: Node | None = None ) -> VNode: """Adds a `Call` node to the graph.""" - func_ty = is_function(def_port.ty) - assert isinstance(func_ty, FunctionType) + assert isinstance(def_port.ty, FunctionType) return self.add_node( ops.Call(), None, - list(type_to_row(func_ty.output)), + list(type_to_row(def_port.ty.output)), parent, [*args, def_port], ) @@ -507,13 +505,12 @@ def add_indirect_call( self, fun_port: OutPortV, args: list[OutPortV], parent: Node | None = None ) -> VNode: """Adds an `IndirectCall` node to the graph.""" - func_ty = is_function(fun_port.ty) - assert isinstance(func_ty, FunctionType) + assert isinstance(fun_port.ty, FunctionType) return self.add_node( ops.CallIndirect(), None, - list(type_to_row(func_ty.output)), + list(type_to_row(fun_port.ty.output)), parent, [fun_port, *args], ) @@ -540,12 +537,11 @@ def add_type_apply( self, func_port: OutPortV, args: Inst, parent: Node | None = None ) -> VNode: """Adds a `TypeApply` node to the graph.""" - func_ty = is_function(func_port.ty) - assert isinstance(func_ty, FunctionType) - assert len(func_ty.params) == len(args) - result_ty = func_ty.instantiate(args) + assert isinstance(func_port.ty, FunctionType) + assert len(func_port.ty.params) == len(args) + result_ty = func_port.ty.instantiate(args) ta = ops.TypeApplication( - input=func_ty.to_hugr(), + input=func_port.ty.to_hugr(), args=[arg.to_hugr() for arg in args], output=result_ty.to_hugr(), ) diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 1f3fc566..723d90de 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -551,13 +551,3 @@ def function_tensor_signature(tys: list[FunctionType]) -> FunctionType: else: outputs.append(fun_ty.output) return FunctionType(inputs, TupleType(outputs)) - - -def is_function(ty: Type) -> FunctionType | None: - if isinstance(ty, FunctionType): - return ty - elif isinstance(ty, TupleType): - funcs = parse_function_tensor(ty) - if isinstance(funcs, list): - return function_tensor_signature(funcs) - return None From a21a481fc70670b8fe42464b801f088f191691e4 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 7 May 2024 12:04:34 +0100 Subject: [PATCH 09/41] [refactor] Use type<->row helpers --- guppylang/tys/ty.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 723d90de..3a91e520 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -546,8 +546,5 @@ def function_tensor_signature(tys: list[FunctionType]) -> FunctionType: outputs: list[Type] = [] for fun_ty in tys: inputs.extend(fun_ty.inputs) - if isinstance(fun_ty.output, TupleType): - outputs.extend(fun_ty.output.element_types) - else: - outputs.append(fun_ty.output) - return FunctionType(inputs, TupleType(outputs)) + outputs.extend(type_to_row(fun_ty.output)) + return FunctionType(inputs, row_to_type(outputs)) From e71321c73bc868e96193f2e2dde2ba7c44728382 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 7 May 2024 12:07:30 +0100 Subject: [PATCH 10/41] cleanup: Redundant code handled by generic_visit --- guppylang/cfg/builder.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/guppylang/cfg/builder.py b/guppylang/cfg/builder.py index 02d0246d..60fc13c1 100644 --- a/guppylang/cfg/builder.py +++ b/guppylang/cfg/builder.py @@ -345,12 +345,6 @@ def visit_Call(self, node: ast.Call) -> ast.AST: case args: arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load)) return with_loc(node, PyExpr(value=arg)) - # Unlike python, we can call a tuple of callable things - elif isinstance(node.func, ast.Tuple): - new_elts = [self.visit(elt) for elt in node.func.elts] - node.func = ast.Tuple(new_elts) - return node - return self.generic_visit(node) def generic_visit(self, node: ast.AST) -> ast.AST: From d08e212dfb640a950b4ffe812ff7163918ebb3e4 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Wed, 8 May 2024 15:36:55 +0100 Subject: [PATCH 11/41] refactor: Remove redundant parse tensor calls --- guppylang/checker/expr_checker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 6797bbce..00e7bec0 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -266,8 +266,9 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: node.func = instantiate_poly(node.func, func_ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty - if isinstance(func_ty, TupleType) and parse_function_tensor(func_ty): - function_elements = parse_function_tensor(func_ty) + if isinstance(func_ty, TupleType) and ( + function_elements := parse_function_tensor(func_ty) + ): assert isinstance(function_elements, list) tensor_ty = function_tensor_signature(function_elements) @@ -572,12 +573,11 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: return with_loc(node, LocalCall(func=node.func, args=args)), return_ty elif ( isinstance(ty, TupleType) - and parse_function_tensor(ty) + and (function_elems := parse_function_tensor(ty)) and isinstance(node.func, ast.Tuple) ): # Note: None of the function types in a tuple of functions will have # overlapping type arguments. - function_elems = parse_function_tensor(ty) assert isinstance(function_elems, list) func_ty = function_tensor_signature(function_elems) remaining_args = node.args From b5d502637ae46967882615e2f880fe322041134b Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Wed, 8 May 2024 15:46:37 +0100 Subject: [PATCH 12/41] fix: Return the right substitution --- guppylang/checker/expr_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 00e7bec0..86a8b9a2 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -310,7 +310,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: # If the substitution isn't empty, ... subst = unify(ty, tensor_ty.output, big_subst) or big_subst - return with_loc(node, TensorCall(call_nodes=call_nodes)), subst + return with_loc(node, TensorCall(call_nodes=call_nodes)), big_subst else: # The func isn't a tuple, it could be a call or a variable. From d2267cb08e019cf33319544d0fefb99fdb6befcc Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Wed, 8 May 2024 15:46:56 +0100 Subject: [PATCH 13/41] fix: Throw user errors when unification fails --- guppylang/checker/expr_checker.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 86a8b9a2..d9c6be42 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -214,12 +214,17 @@ def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: big_subst: Subst = {} for i, (elt, elt_ty) in enumerate(zip(node.elts, function_types)): node.elts[i], fun_ty = self._synthesize(elt, allow_free_vars=True) - assert isinstance(fun_ty, FunctionType) - elem_tys.append(fun_ty) - # Start with an empty substitution because the function types - # should have independent variables - subst = unify(fun_ty, elt_ty, {}) or {} - big_subst |= {} + if not isinstance(fun_ty, FunctionType): + return self._fail(elt_ty, fun_ty, node.elts[i]) + else: + elem_tys.append(fun_ty) + # Start with an empty substitution because the function types + # should have independent variables + subst = unify(fun_ty, elt_ty, {}) or {} + if subst is None: + return self._fail(elt_ty, fun_ty, node.elts[i]) + else: + big_subst |= subst return node, big_subst @@ -308,7 +313,11 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: big_subst |= subst # If the substitution isn't empty, ... - subst = unify(ty, tensor_ty.output, big_subst) or big_subst + subst = unify(ty, tensor_ty.output, big_subst) + if subst is None: + return self._fail(ty, tensor_ty.output, call_nodes[-1]) + else: + big_subst |= subst return with_loc(node, TensorCall(call_nodes=call_nodes)), big_subst @@ -326,7 +335,9 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: # TODO: instantiate a tuple of functions # f_processed = instantiate_poly(node.func, tensor_ty, inst) - subst = unify(ty, tensor_ty.output, big_subst) or big_subst + subst = unify(ty, tensor_ty.output, big_subst) + if subst is None: + return self._fail(ty, tensor_ty.output, node) return with_loc( node, LocalCall(func=node.func, args=processed_args) From 661f7e173c6d56776bd2cd6b3b1a69a2fa0c84a9 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 09:26:30 +0100 Subject: [PATCH 14/41] fix: More subst fixes --- guppylang/checker/expr_checker.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index d9c6be42..3ec937aa 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -312,14 +312,13 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: big_subst |= subst - # If the substitution isn't empty, ... - subst = unify(ty, tensor_ty.output, big_subst) - if subst is None: - return self._fail(ty, tensor_ty.output, call_nodes[-1]) - else: - big_subst |= subst - - return with_loc(node, TensorCall(call_nodes=call_nodes)), big_subst + # If the substitution isn't empty, ... + if result_subst := unify(ty, tensor_ty.output, big_subst): + return with_loc( + node, TensorCall(call_nodes=call_nodes) + ), result_subst + else: + return self._fail(ty, tensor_ty.output, call_nodes[-1]) else: # The func isn't a tuple, it could be a call or a variable. @@ -335,14 +334,13 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: # TODO: instantiate a tuple of functions # f_processed = instantiate_poly(node.func, tensor_ty, inst) - subst = unify(ty, tensor_ty.output, big_subst) - if subst is None: + if result_subst := unify(ty, tensor_ty.output, big_subst): + return with_loc( + node, LocalCall(func=node.func, args=processed_args) + ), result_subst + else: return self._fail(ty, tensor_ty.output, node) - return with_loc( - node, LocalCall(func=node.func, args=processed_args) - ), subst - elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"): return callee.check_call(node.args, ty, node, self.ctx) else: From 990d95f0f8cea91f40b60143419d47d43b8bbd5c Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 09:41:28 +0100 Subject: [PATCH 15/41] fix: Use check_call instead of making a GlobalCall --- guppylang/checker/expr_checker.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 3ec937aa..0aa6c686 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -278,7 +278,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: tensor_ty = function_tensor_signature(function_elements) remaining_args: list[ast.expr] = node.args - call_nodes: list[GlobalCall | LocalCall] = [] + call_nodes: list[ast.expr] = [] big_subst: Subst = {} if isinstance(node.func, ast.Tuple): for f, f_ty in zip(node.func.elts, func_ty.element_types): @@ -295,16 +295,12 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: check_inst(f_ty, inst, node) # Expect that each function is a `CallableDef` if isinstance(f_processed, GlobalName): - assert isinstance( - self.ctx.globals[f_processed.def_id], CallableDef - ) - call_nodes.append( - GlobalCall( - def_id=f_processed.def_id, - args=processed_args, - type_args=inst, - ) + defn = self.ctx.globals[f_processed.def_id] + assert isinstance(defn, CallableDef) + call_node, subst = defn.check_call( + processed_args, f_ty.output, f_processed, self.ctx ) + call_nodes.append(call_node) else: call_nodes.append( LocalCall(func=f_processed, args=processed_args) From bf73897da64c0dffb184e9e72765aee3d96f91bf Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 10:15:58 +0100 Subject: [PATCH 16/41] fix: Subst fix fixes {} is falsey... --- guppylang/checker/expr_checker.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 0aa6c686..91a1246c 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -309,12 +309,13 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: big_subst |= subst # If the substitution isn't empty, ... - if result_subst := unify(ty, tensor_ty.output, big_subst): + result_subst = unify(ty, tensor_ty.output, big_subst) + if result_subst is None: + return self._fail(ty, tensor_ty.output, call_nodes[-1]) + else: return with_loc( node, TensorCall(call_nodes=call_nodes) ), result_subst - else: - return self._fail(ty, tensor_ty.output, call_nodes[-1]) else: # The func isn't a tuple, it could be a call or a variable. @@ -330,12 +331,13 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: # TODO: instantiate a tuple of functions # f_processed = instantiate_poly(node.func, tensor_ty, inst) - if result_subst := unify(ty, tensor_ty.output, big_subst): + result_subst = unify(ty, tensor_ty.output, big_subst) + if result_subst is None: + return self._fail(ty, tensor_ty.output, node) + else: return with_loc( node, LocalCall(func=node.func, args=processed_args) ), result_subst - else: - return self._fail(ty, tensor_ty.output, node) elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"): return callee.check_call(node.args, ty, node, self.ctx) From d7fb0ed4d4452029b9c6b5aaa866551aca8c2a73 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 10:21:41 +0100 Subject: [PATCH 17/41] Simplify function tuple semantics: (f) != f --- guppylang/checker/expr_checker.py | 37 +++++-------------------------- tests/integration/test_tensor.py | 8 +++---- 2 files changed, 9 insertions(+), 36 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 91a1246c..dd33fb40 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -195,38 +195,11 @@ def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: if not (isinstance(ty, TupleType) and len(ty.element_types) == len(node.elts)): return self._fail(ty, node) - # Tuples can either be inert python tuples or tuples of functions which - # can be called in guppy. The former thing is checkable, but in the - # latter case we should be able to synthesise function types for the - # elements. Check here whether the given type is a tuple of function - # types to work out which case we're in. - function_types = parse_function_tensor(ty) - if not function_types: - subst: Subst = {} - for i, el in enumerate(node.elts): - node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst)) - subst |= s - return node, subst - else: - assert isinstance(function_types, list) - elem_tys: list[FunctionType] = [] - # The substitution for the whole tuple of function types - big_subst: Subst = {} - for i, (elt, elt_ty) in enumerate(zip(node.elts, function_types)): - node.elts[i], fun_ty = self._synthesize(elt, allow_free_vars=True) - if not isinstance(fun_ty, FunctionType): - return self._fail(elt_ty, fun_ty, node.elts[i]) - else: - elem_tys.append(fun_ty) - # Start with an empty substitution because the function types - # should have independent variables - subst = unify(fun_ty, elt_ty, {}) or {} - if subst is None: - return self._fail(elt_ty, fun_ty, node.elts[i]) - else: - big_subst |= subst - - return node, big_subst + subst: Subst = {} + for i, el in enumerate(node.elts): + node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst)) + subst |= s + return node, subst def visit_List(self, node: ast.List, ty: Type) -> tuple[ast.expr, Subst]: if not is_list_type(ty) and not is_linst_type(ty): diff --git a/tests/integration/test_tensor.py b/tests/integration/test_tensor.py index fe3971d6..fcac9c73 100644 --- a/tests/integration/test_tensor.py +++ b/tests/integration/test_tensor.py @@ -20,7 +20,7 @@ def is_42(x: int) -> bool: @guppy(module) def baz(x: int) -> tuple[bool]: - return (bar,)(is_42)(x) + return bar(is_42)(x) validate(module.compile()) @@ -41,15 +41,15 @@ def is_42(x: int) -> bool: return x == 42 @guppy(module) - def baz(x: int) -> tuple[bool]: + def baz(x: int) -> bool: return foo(is_42)(x) @guppy(module) def baz1() -> tuple[Callable[[int], bool]]: - return (foo,)(is_42) + return foo(is_42) @guppy(module) - def baz2(x: int) -> tuple[bool]: + def baz2(x: int) -> bool: return (foo,)(is_42)(x) validate(module.compile()) From 8c6ac051fe7354b83e34ee2d72235deee0075c46 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 10:35:56 +0100 Subject: [PATCH 18/41] cleanup: Remove redundant test --- tests/integration/test_tensor.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/integration/test_tensor.py b/tests/integration/test_tensor.py index fcac9c73..98a291fc 100644 --- a/tests/integration/test_tensor.py +++ b/tests/integration/test_tensor.py @@ -7,24 +7,6 @@ from guppylang.module import GuppyModule -def test_bug(validate): - module = GuppyModule("module") - - @guppy(module) - def bar(f: Callable[[int], bool]) -> Callable[[int], bool]: - return f - - @guppy(module) - def is_42(x: int) -> bool: - return x == 42 - - @guppy(module) - def baz(x: int) -> tuple[bool]: - return bar(is_42)(x) - - validate(module.compile()) - - def test_check_callable(validate): module = GuppyModule("module") From 2e38476d077fbffdf13171a11ac21683444ddc69 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 10:46:11 +0100 Subject: [PATCH 19/41] cleanup: Remove redundant test --- tests/integration/test_tensor.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/integration/test_tensor.py b/tests/integration/test_tensor.py index 98a291fc..fb435381 100644 --- a/tests/integration/test_tensor.py +++ b/tests/integration/test_tensor.py @@ -64,12 +64,6 @@ def baz_ho_call() -> tuple[int, bool]: def baz() -> tuple[int, bool]: return (foo, bar)() - @guppy(module) - def local_ho( - f: Callable[[int], bool], g: Callable[[bool], int] - ) -> tuple[bool, int]: - return (f, g)(2, True) - validate(module.compile()) From 497daf9e62ad4135ee485746c6a4c8645fa70f08 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 11:15:31 +0100 Subject: [PATCH 20/41] fix: Bug in compiling calls of tuples --- guppylang/compiler/expr_compiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 46c26163..e2c86ddf 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -201,7 +201,9 @@ def visit_LocalCall(self, node: LocalCall) -> OutPortV: # Now we have to manage this hand-off remaining_args = args for func in func_ports: - outs, remaining_args = self._compile_tensor_with_leftovers(func, args) + outs, remaining_args = self._compile_tensor_with_leftovers( + func, remaining_args + ) rets.extend(outs) else: raise InternalGuppyError("Local call of something without a callable type") From e70a7ac8d5d83cd31304b704a8871f85025f94b9 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 11:17:43 +0100 Subject: [PATCH 21/41] cleanup: Remove special case for function tuples --- guppylang/checker/expr_checker.py | 67 +++++-------------------------- 1 file changed, 11 insertions(+), 56 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index dd33fb40..20fe2cbb 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -250,67 +250,22 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: assert isinstance(function_elements, list) tensor_ty = function_tensor_signature(function_elements) - remaining_args: list[ast.expr] = node.args - call_nodes: list[ast.expr] = [] big_subst: Subst = {} - if isinstance(node.func, ast.Tuple): - for f, f_ty in zip(node.func.elts, func_ty.element_types): - assert isinstance(f_ty, FunctionType) - # Use the concrete output type of the function, we'll try to - # unify all of the results with `ty` at the end - processed_args, subst, inst, remaining_args = ( - check_call_with_leftovers( - f_ty, remaining_args, f_ty.output, f, self.ctx - ) - ) - f_processed = instantiate_poly(f, f_ty, inst) - - check_inst(f_ty, inst, node) - # Expect that each function is a `CallableDef` - if isinstance(f_processed, GlobalName): - defn = self.ctx.globals[f_processed.def_id] - assert isinstance(defn, CallableDef) - call_node, subst = defn.check_call( - processed_args, f_ty.output, f_processed, self.ctx - ) - call_nodes.append(call_node) - else: - call_nodes.append( - LocalCall(func=f_processed, args=processed_args) - ) - big_subst |= subst + processed_args, big_subst, inst = check_call( + tensor_ty, node.args, tensor_ty.output, node.func, self.ctx + ) - # If the substitution isn't empty, ... - result_subst = unify(ty, tensor_ty.output, big_subst) - if result_subst is None: - return self._fail(ty, tensor_ty.output, call_nodes[-1]) - else: - return with_loc( - node, TensorCall(call_nodes=call_nodes) - ), result_subst + # TODO: instantiate a tuple of functions + # f_processed = instantiate_poly(node.func, tensor_ty, inst) + result_subst = unify(ty, tensor_ty.output, big_subst) + if result_subst is None: + return self._fail(ty, tensor_ty.output, node) else: - # The func isn't a tuple, it could be a call or a variable. - # Here, the return type we expect has the outputs of all the - # function types merged together, i.e. - # f : Callable([A], tuple[B, C]) - # g : Callable([D], E) - # (f, g)(a, d) : tuple[B, C, E] - processed_args, big_subst, inst = check_call( - tensor_ty, node.args, tensor_ty.output, node.func, self.ctx - ) - - # TODO: instantiate a tuple of functions - # f_processed = instantiate_poly(node.func, tensor_ty, inst) - - result_subst = unify(ty, tensor_ty.output, big_subst) - if result_subst is None: - return self._fail(ty, tensor_ty.output, node) - else: - return with_loc( - node, LocalCall(func=node.func, args=processed_args) - ), result_subst + return with_loc( + node, LocalCall(func=node.func, args=processed_args) + ), result_subst elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"): return callee.check_call(node.args, ty, node, self.ctx) From 88c4bf3506236d372609d875ee567317fb211723 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 11:31:17 +0100 Subject: [PATCH 22/41] test: Add tests of calling tuple variables --- tests/integration/test_tensor.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/integration/test_tensor.py b/tests/integration/test_tensor.py index fb435381..dbfce8ac 100644 --- a/tests/integration/test_tensor.py +++ b/tests/integration/test_tensor.py @@ -64,6 +64,12 @@ def baz_ho_call() -> tuple[int, bool]: def baz() -> tuple[int, bool]: return (foo, bar)() + @guppy(module) + def call_var() -> tuple[int, bool, int]: + f = foo + g = (foo, bar, f) + return g() + validate(module.compile()) @@ -92,6 +98,11 @@ def bar(x: int) -> int: def baz(x: int, y: int) -> tuple[int, int]: return (foo, bar)(x, y) + @guppy(module) + def call_var(x: int) -> tuple[int, int, int]: + f = foo, baz + return f(x, x, x) + validate(module.compile()) From 39a9d7245a080343619d073f887c0e0476315968 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 11:45:10 +0100 Subject: [PATCH 23/41] new: Handle nested function tuples when compiling --- guppylang/compiler/expr_compiler.py | 24 ++++++++++++++++++------ tests/integration/test_tensor.py | 19 +++++++++++++++++++ 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index e2c86ddf..38a47c8c 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -216,13 +216,25 @@ def _compile_tensor_with_leftovers( list[OutPortV], # Compiled outputs list[OutPortV], ]: # Leftover args - assert isinstance(func.ty, FunctionType) - input_len = len(func.ty.inputs) - call = self.graph.add_indirect_call(func, args[0:input_len]) + if isinstance(func.ty, TupleType): + remaining_args = args + all_outs = [] + for elem in self._unpack_tuple(func): + outs, remaining_args = self._compile_tensor_with_leftovers( + elem, remaining_args + ) + all_outs.extend(outs) + return all_outs, remaining_args + + elif isinstance(func.ty, FunctionType): + input_len = len(func.ty.inputs) + call = self.graph.add_indirect_call(func, args[0:input_len]) - return [ - call.out_port(i) for i in range(len(type_to_row(func.ty.output))) - ], args[input_len:] + return [ + call.out_port(i) for i in range(len(type_to_row(func.ty.output))) + ], args[input_len:] + else: + raise InternalGuppyError("Tensor element wasn't function or tuple") def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: func = self.globals[node.def_id] diff --git a/tests/integration/test_tensor.py b/tests/integration/test_tensor.py index dbfce8ac..b512667f 100644 --- a/tests/integration/test_tensor.py +++ b/tests/integration/test_tensor.py @@ -151,3 +151,22 @@ def baz() -> tuple[Callable[[int], bool], Callable[[float], int]]: # return apply(baz, args) validate(module.compile()) + + +def test_nesting(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int, y: int) -> int: + return x + y + + @guppy(module) + def bar(x: int) -> int: + return -x + + @guppy(module) + def call(x: int) -> tuple[int, int, int]: + f = bar, (bar, foo) + return f(x, x, x, x) + + validate(module.compile()) From 69b8c596653d56495d6b73983935bea8af39cbea Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 12:13:01 +0100 Subject: [PATCH 24/41] Complain when we have to instantiate in tuple call --- guppylang/checker/expr_checker.py | 48 +++++++++---------------------- 1 file changed, 14 insertions(+), 34 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 20fe2cbb..7c39a028 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -53,7 +53,6 @@ from guppylang.nodes import ( DesugaredGenerator, DesugaredListComp, - GlobalCall, GlobalName, IterEnd, IterHasNext, @@ -62,7 +61,6 @@ LocalName, MakeIter, PyExpr, - TensorCall, TypeApply, ) from guppylang.tys.arg import TypeArg @@ -256,8 +254,10 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: tensor_ty, node.args, tensor_ty.output, node.func, self.ctx ) - # TODO: instantiate a tuple of functions - # f_processed = instantiate_poly(node.func, tensor_ty, inst) + if len(inst) > 0: + raise GuppyTypeError( + "Polymorphic functions in tuples are not supported" + ) result_subst = unify(ty, tensor_ty.output, big_subst) if result_subst is None: @@ -506,39 +506,19 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx) node.func = instantiate_poly(node.func, ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty - elif ( - isinstance(ty, TupleType) - and (function_elems := parse_function_tensor(ty)) - and isinstance(node.func, ast.Tuple) + elif isinstance(ty, TupleType) and ( + function_elems := parse_function_tensor(ty) ): - # Note: None of the function types in a tuple of functions will have - # overlapping type arguments. - assert isinstance(function_elems, list) - func_ty = function_tensor_signature(function_elems) - remaining_args = node.args - return_tys: list[Type] = [] - processed_args: list[ast.expr] = [] - call_nodes: list[ast.expr] = [] - for func in node.func.elts: - args, return_ty, inst, remaining_args = synthesize_call_with_leftovers( - func_ty, remaining_args, node, self.ctx + tensor_ty = function_tensor_signature(function_elems) + args, return_ty, inst = synthesize_call( + tensor_ty, node.args, node, self.ctx + ) + if len(inst) > 0: + raise GuppyTypeError( + "Polymorphic functions in tuples are not supported" ) - processed_args.extend(args) - if isinstance(return_ty, TupleType): - return_tys.extend(return_ty.element_types) - else: - return_tys.append(return_ty) - - if isinstance(func, GlobalName): - assert isinstance(self.ctx.globals[func.def_id], CallableDef) - call_nodes.append( - GlobalCall(def_id=func.def_id, args=args, type_args=inst) - ) - else: - call_nodes.append(LocalCall(func=func, args=args)) - call_node = TensorCall(call_nodes) - return with_loc(node, call_node), TupleType(return_tys) + return with_loc(node, LocalCall(func=node.func, args=args)), return_ty elif f := self.ctx.globals.get_instance_func(ty, "__call__"): return f.synthesize_call(node.args, node, self.ctx) From 70e3db3c94893ffe503c829b7284e0dc5aaf9771 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 12:15:28 +0100 Subject: [PATCH 25/41] cleanup: Remove TensorCall node --- guppylang/compiler/expr_compiler.py | 11 ----------- guppylang/nodes.py | 9 --------- 2 files changed, 20 deletions(-) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 38a47c8c..c77415db 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -22,7 +22,6 @@ GlobalName, LocalCall, LocalName, - TensorCall, TypeApply, ) from guppylang.tys.builtin import bool_type, get_element_type, is_list_type @@ -251,16 +250,6 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: return self._pack_returns(rets) - def visit_TensorCall(self, node: TensorCall) -> OutPortV: - outputs = [] - for call in node.call_nodes: - output = self.visit(call) - if isinstance(output.ty, TupleType): - outputs.extend(self._unpack_tuple(output)) - else: - outputs.append(output) - return self._pack_returns(outputs) - def visit_Call(self, node: ast.Call) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppylang/nodes.py b/guppylang/nodes.py index c1cbdcd0..af260062 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -52,15 +52,6 @@ class GlobalCall(ast.expr): ) -class TensorCall(ast.expr): - """A call to a tuple of functions. Stores a call node for each function in the - tuple""" - - call_nodes: list[ast.expr] - - _fields = ("call_nodes",) - - class TypeApply(ast.expr): value: ast.expr inst: Inst From 7a12f75ec9f885c2d13f80dcd9dde11558f44775 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 17:15:50 +0100 Subject: [PATCH 26/41] cleanup: Undo changes adding _with_leftovers fns --- guppylang/checker/expr_checker.py | 100 +++++++----------------------- guppylang/prelude/_internal.py | 15 +---- 2 files changed, 24 insertions(+), 91 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 7c39a028..b92e6613 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -657,23 +657,18 @@ def check_type_against( return subst, [] -def check_num_args_sufficient(exp: int, act: int, node: AstNode) -> None: - """Checks that enough arguments have been passed to a function.""" +def check_num_args(exp: int, act: int, node: AstNode) -> None: + """Checks that the correct number of arguments have been passed to a function.""" if act < exp: raise GuppyTypeError( f"Not enough arguments passed (expected {exp}, got {act})", node ) - - -def check_leftovers_nil( - args_checked: int, leftovers: list[ast.expr], node: AstNode -) -> None: - if len(leftovers) > 0: + if exp < act: if isinstance(node, ast.Call): - raise GuppyTypeError("Unexpected argument", leftovers[0]) - total_args = args_checked + len(leftovers) - msg = f"Too many arguments passed (expected {args_checked}, got {total_args})" - raise GuppyTypeError(msg, node) + raise GuppyTypeError("Unexpected argument", node.args[exp]) + raise GuppyTypeError( + f"Too many arguments passed (expected {exp}, got {act})", node + ) def type_check_args( @@ -685,39 +680,18 @@ def type_check_args( ) -> tuple[list[ast.expr], Subst]: """Checks the arguments of a function call and infers free type variables. - We expect that parameters have been replaced with free unification variables. - Checks that all unification variables can be inferred. - """ - exprs, subst, leftovers = type_check_args_with_leftovers( - inputs, func_ty, subst, ctx, node - ) - check_leftovers_nil(len(exprs), leftovers, node) - return exprs, subst - - -def type_check_args_with_leftovers( - inputs: list[ast.expr], - func_ty: FunctionType, - subst: Subst, - ctx: Context, - node: AstNode, -) -> tuple[list[ast.expr], Subst, list[ast.expr]]: - """Checks the arguments of a function call and infers free type variables. - We expect that parameters have been replaced with free unification variables. Checks that all unification variables can be inferred. """ assert not func_ty.parametrized - check_num_args_sufficient(len(func_ty.inputs), len(inputs), node) + check_num_args(len(func_ty.inputs), len(inputs), node) new_args: list[ast.expr] = [] - for ty, inp in zip(func_ty.inputs, inputs): + for inp, ty in zip(inputs, func_ty.inputs): a, s = ExprChecker(ctx).check(inp, ty.substitute(subst), "argument") new_args.append(a) subst |= s - leftovers = inputs[len(func_ty.inputs) :] - # If the argument check succeeded, this means that we must have found instantiations # for all unification variables occurring in the input types assert all(set.issubset(inp.unsolved_vars, subst.keys()) for inp in func_ty.inputs) @@ -730,37 +704,24 @@ def type_check_args_with_leftovers( node, ) - return new_args, subst, leftovers + return new_args, subst def synthesize_call( func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[list[ast.expr], Type, Inst]: - exprs, tys, inst, leftovers = synthesize_call_with_leftovers( - func_ty, args, node, ctx - ) - check_leftovers_nil(len(exprs), leftovers, node) - return exprs, tys, inst - - -def synthesize_call_with_leftovers( - func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context -) -> tuple[list[ast.expr], Type, Inst, list[ast.expr]]: """Synthesizes the return type of a function call. Returns an annotated argument list, the synthesized return type, and an instantiation for the quantifiers in the function type. """ assert not func_ty.unsolved_vars - check_num_args_sufficient(len(func_ty.inputs), len(args), node) + check_num_args(len(func_ty.inputs), len(args), node) # Replace quantified variables with free unification variables and try to infer an # instantiation by checking the arguments unquantified, free_vars = func_ty.unquantified() - - args, subst, leftovers = type_check_args_with_leftovers( - args, unquantified, {}, ctx, node - ) + args, subst = type_check_args(args, unquantified, {}, ctx, node) # Success implies that the substitution is closed assert all(not t.unsolved_vars for t in subst.values()) @@ -769,7 +730,7 @@ def synthesize_call_with_leftovers( # Finally, check that the instantiation respects the linearity requirements check_inst(func_ty, inst, node) - return args, unquantified.output.substitute(subst), inst, leftovers + return args, unquantified.output.substitute(subst), inst def check_call( @@ -780,28 +741,13 @@ def check_call( ctx: Context, kind: str = "expression", ) -> tuple[list[ast.expr], Subst, Inst]: - exprs, subst, inst, leftovers = check_call_with_leftovers( - func_ty, inputs, ty, node, ctx, kind - ) - check_leftovers_nil(len(exprs), leftovers, node) - return exprs, subst, inst - - -def check_call_with_leftovers( - func_ty: FunctionType, - inputs: list[ast.expr], - ty: Type, - node: AstNode, - ctx: Context, - kind: str = "expression", -) -> tuple[list[ast.expr], Subst, Inst, list[ast.expr]]: """Checks the return type of a function call against a given type. Returns an annotated argument list, a substitution for the free variables in the expected type, and an instantiation for the quantifiers in the function type. """ assert not func_ty.unsolved_vars - check_num_args_sufficient(len(func_ty.inputs), len(inputs), node) + check_num_args(len(func_ty.inputs), len(inputs), node) # When checking, we can use the information from the expected return type to infer # some type arguments. However, this pushes errors inwards. For example, given a @@ -823,20 +769,18 @@ def check_call_with_leftovers( # in practice. Can we do better than that? # First, try to synthesize - res: tuple[Type, Inst, list[ast.expr]] | None = None + res: tuple[Type, Inst] | None = None try: - inputs, synth, inst, leftovers = synthesize_call_with_leftovers( - func_ty, inputs, node, ctx - ) - res = synth, inst, leftovers + inputs, synth, inst = synthesize_call(func_ty, inputs, node, ctx) + res = synth, inst except GuppyTypeInferenceError: pass if res is not None: - synth, inst, leftovers = res + synth, inst = res subst = unify(ty, synth, {}) if subst is None: raise GuppyTypeError(f"Expected {kind} of type `{ty}`, got `{synth}`", node) - return inputs, subst, inst, leftovers + return inputs, subst, inst # If synthesis fails, we try again, this time also using information from the # expected return type @@ -848,9 +792,7 @@ def check_call_with_leftovers( ) # Try to infer more by checking against the arguments - inputs, subst, leftovers = type_check_args_with_leftovers( - inputs, unquantified, subst, ctx, node - ) + inputs, subst = type_check_args(inputs, unquantified, subst, ctx, node) # Also make sure we found an instantiation for all free vars in the type we're # checking against @@ -869,7 +811,7 @@ def check_call_with_leftovers( # Finally, check that the instantiation respects the linearity requirements check_inst(func_ty, inst, node) - return inputs, subst, inst, leftovers + return inputs, subst, inst def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None: diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index bed4b16c..db0e0aee 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -5,11 +5,7 @@ from guppylang.ast_util import AstNode, get_type, with_loc, with_type from guppylang.checker.core import Context -from guppylang.checker.expr_checker import ( - ExprSynthesizer, - check_leftovers_nil, - check_num_args_sufficient, -) +from guppylang.checker.expr_checker import ExprSynthesizer, check_num_args from guppylang.definition.custom import ( CustomCallChecker, CustomCallCompiler, @@ -29,11 +25,6 @@ INT_WIDTH = 6 # 2^6 = 64 bit -def check_num_args(exp: int, args: list[ast.expr], node: AstNode) -> None: - check_num_args_sufficient(exp, len(args), node) - check_leftovers_nil(exp, args[exp:], node) - - hugr_int_type = tys.Opaque( extension="arithmetic.int.types", id="int", @@ -202,7 +193,7 @@ def __init__(self, dunder_name: str, num_args: int = 1): self.num_args = num_args def _get_func(self, args: list[ast.expr]) -> tuple[list[ast.expr], CallableDef]: - check_num_args(self.num_args, args, self.node) + check_num_args(self.num_args, len(args), self.node) fst, *rest = args fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) func = self.ctx.globals.get_instance_func(ty, self.dunder_name) @@ -227,7 +218,7 @@ class CallableChecker(CustomCallChecker): """Call checker for the builtin `callable` function""" def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: - check_num_args(1, args, self.node) + check_num_args(1, len(args), self.node) [arg] = args arg, ty = ExprSynthesizer(self.ctx).synthesize(arg) is_callable = ( From 87f404cb05cbc4c672f15a3658da109ae3dd1def Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 17:43:54 +0100 Subject: [PATCH 27/41] test: Add tests for function tensor errors --- tests/error/tensor_errors/poly_tensor.err | 7 +++++++ tests/error/tensor_errors/poly_tensor.py | 16 ++++++++++++++++ tests/error/tensor_errors/too_few_args.err | 7 +++++++ tests/error/tensor_errors/too_few_args.py | 14 ++++++++++++++ tests/error/tensor_errors/too_many_args.err | 7 +++++++ tests/error/tensor_errors/too_many_args.py | 14 ++++++++++++++ tests/error/tensor_errors/type_mismatch.err | 7 +++++++ tests/error/tensor_errors/type_mismatch.py | 14 ++++++++++++++ tests/error/test_tensor_errors.py | 15 +++++++++++++++ 9 files changed, 101 insertions(+) create mode 100644 tests/error/tensor_errors/poly_tensor.err create mode 100644 tests/error/tensor_errors/poly_tensor.py create mode 100644 tests/error/tensor_errors/too_few_args.err create mode 100644 tests/error/tensor_errors/too_few_args.py create mode 100644 tests/error/tensor_errors/too_many_args.err create mode 100644 tests/error/tensor_errors/too_many_args.py create mode 100644 tests/error/tensor_errors/type_mismatch.err create mode 100644 tests/error/tensor_errors/type_mismatch.py create mode 100644 tests/error/test_tensor_errors.py diff --git a/tests/error/tensor_errors/poly_tensor.err b/tests/error/tensor_errors/poly_tensor.err new file mode 100644 index 00000000..eb466d9a --- /dev/null +++ b/tests/error/tensor_errors/poly_tensor.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:14 + +12: @guppy(module) +13: def main() -> int: +14: return (foo, foo)(42, 42) + ^^ +GuppyTypeError: Expected argument of type `T`, got `int` diff --git a/tests/error/tensor_errors/poly_tensor.py b/tests/error/tensor_errors/poly_tensor.py new file mode 100644 index 00000000..0931b137 --- /dev/null +++ b/tests/error/tensor_errors/poly_tensor.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + +@guppy.declare(module) +def foo(x: T) -> T: + ... + +@guppy(module) +def main() -> int: + return (foo, foo)(42, 42) + +module.compile() diff --git a/tests/error/tensor_errors/too_few_args.err b/tests/error/tensor_errors/too_few_args.err new file mode 100644 index 00000000..ed7eb98d --- /dev/null +++ b/tests/error/tensor_errors/too_few_args.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy(module) +11: def main() -> int: +12: return (foo, foo)(42) + ^^^^^^^^^^ +GuppyTypeError: Not enough arguments passed (expected 2, got 1) diff --git a/tests/error/tensor_errors/too_few_args.py b/tests/error/tensor_errors/too_few_args.py new file mode 100644 index 00000000..6bd2d17a --- /dev/null +++ b/tests/error/tensor_errors/too_few_args.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + +@guppy.declare(module) +def foo(x: int) -> int: + ... + +@guppy(module) +def main() -> int: + return (foo, foo)(42) + +module.compile() diff --git a/tests/error/tensor_errors/too_many_args.err b/tests/error/tensor_errors/too_many_args.err new file mode 100644 index 00000000..d30f2666 --- /dev/null +++ b/tests/error/tensor_errors/too_many_args.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy(module) +11: def main() -> int: +12: return (foo, foo)(1, 2, 3) + ^^^^^^^^^^ +GuppyTypeError: Too many arguments passed (expected 2, got 3) diff --git a/tests/error/tensor_errors/too_many_args.py b/tests/error/tensor_errors/too_many_args.py new file mode 100644 index 00000000..7f70d380 --- /dev/null +++ b/tests/error/tensor_errors/too_many_args.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + +@guppy.declare(module) +def foo(x: int) -> int: + ... + +@guppy(module) +def main() -> int: + return (foo, foo)(1, 2, 3) + +module.compile() diff --git a/tests/error/tensor_errors/type_mismatch.err b/tests/error/tensor_errors/type_mismatch.err new file mode 100644 index 00000000..cfc31e7e --- /dev/null +++ b/tests/error/tensor_errors/type_mismatch.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy(module) +11: def main() -> int: +12: return (foo, foo)(42, False) + ^^^^^ +GuppyTypeError: Expected argument of type `int`, got `bool` diff --git a/tests/error/tensor_errors/type_mismatch.py b/tests/error/tensor_errors/type_mismatch.py new file mode 100644 index 00000000..9e4d40a2 --- /dev/null +++ b/tests/error/tensor_errors/type_mismatch.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + +@guppy.declare(module) +def foo(x: int) -> int: + ... + +@guppy(module) +def main() -> int: + return (foo, foo)(42, False) + +module.compile() diff --git a/tests/error/test_tensor_errors.py b/tests/error/test_tensor_errors.py new file mode 100644 index 00000000..38c2959b --- /dev/null +++ b/tests/error/test_tensor_errors.py @@ -0,0 +1,15 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "tensor_errors" +files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_type_errors(file, capsys): + run_error_test(file, capsys) From 16fd3b52eb8fe68cb46a2543b7a612118be84e2f Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 17:46:39 +0100 Subject: [PATCH 28/41] Update guppylang/compiler/expr_compiler.py Co-authored-by: Mark Koch <48097969+mark-koch@users.noreply.github.com> --- guppylang/compiler/expr_compiler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index c77415db..cc79e722 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -229,9 +229,7 @@ def _compile_tensor_with_leftovers( input_len = len(func.ty.inputs) call = self.graph.add_indirect_call(func, args[0:input_len]) - return [ - call.out_port(i) for i in range(len(type_to_row(func.ty.output))) - ], args[input_len:] + return list(call.out_ports), args[input_len:] else: raise InternalGuppyError("Tensor element wasn't function or tuple") From 22f2cc615e413822681af12993c37a7619e91ad6 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 9 May 2024 17:48:39 +0100 Subject: [PATCH 29/41] big_subst -> subst --- guppylang/checker/expr_checker.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index b92e6613..4c9c12d2 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -248,9 +248,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: assert isinstance(function_elements, list) tensor_ty = function_tensor_signature(function_elements) - big_subst: Subst = {} - - processed_args, big_subst, inst = check_call( + processed_args, subst, inst = check_call( tensor_ty, node.args, tensor_ty.output, node.func, self.ctx ) @@ -259,7 +257,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: "Polymorphic functions in tuples are not supported" ) - result_subst = unify(ty, tensor_ty.output, big_subst) + result_subst = unify(ty, tensor_ty.output, subst) if result_subst is None: return self._fail(ty, tensor_ty.output, node) else: From b3fe31756f8dbd4c6cc718a70195524dd3e70173 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Fri, 10 May 2024 09:11:45 +0100 Subject: [PATCH 30/41] refactor: check for parametrized fns earlier --- guppylang/checker/expr_checker.py | 20 ++++++++++++-------- guppylang/tys/ty.py | 1 + 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 4c9c12d2..60919150 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -246,16 +246,18 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: function_elements := parse_function_tensor(func_ty) ): assert isinstance(function_elements, list) + if any(f.parametrized for f in function_elements): + raise GuppyTypeError( + "Polymorphic functions in tuples are not supported" + ) + tensor_ty = function_tensor_signature(function_elements) processed_args, subst, inst = check_call( tensor_ty, node.args, tensor_ty.output, node.func, self.ctx ) - if len(inst) > 0: - raise GuppyTypeError( - "Polymorphic functions in tuples are not supported" - ) + assert len(inst) == 0 result_subst = unify(ty, tensor_ty.output, subst) if result_subst is None: @@ -507,14 +509,16 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: elif isinstance(ty, TupleType) and ( function_elems := parse_function_tensor(ty) ): + if any(f.parametrized for f in function_elems): + raise GuppyTypeError( + "Polymorphic functions in tuples are not supported" + ) + tensor_ty = function_tensor_signature(function_elems) args, return_ty, inst = synthesize_call( tensor_ty, node.args, node, self.ctx ) - if len(inst) > 0: - raise GuppyTypeError( - "Polymorphic functions in tuples are not supported" - ) + assert len(inst) == 0 return with_loc(node, LocalCall(func=node.func, args=args)), return_ty diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 3a91e520..4e114663 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -545,6 +545,7 @@ def function_tensor_signature(tys: list[FunctionType]) -> FunctionType: inputs: list[Type] = [] outputs: list[Type] = [] for fun_ty in tys: + assert not fun_ty.parametrized inputs.extend(fun_ty.inputs) outputs.extend(type_to_row(fun_ty.output)) return FunctionType(inputs, row_to_type(outputs)) From 34c1953fbfd9bbd13738ed717d622a74d96ea327 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Fri, 10 May 2024 09:13:10 +0100 Subject: [PATCH 31/41] Revert "cleanup: Remove TensorCall node" This reverts commit 70e3db3c94893ffe503c829b7284e0dc5aaf9771. --- guppylang/compiler/expr_compiler.py | 11 +++++++++++ guppylang/nodes.py | 9 +++++++++ 2 files changed, 20 insertions(+) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index cc79e722..1d1496f6 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -22,6 +22,7 @@ GlobalName, LocalCall, LocalName, + TensorCall, TypeApply, ) from guppylang.tys.builtin import bool_type, get_element_type, is_list_type @@ -248,6 +249,16 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: return self._pack_returns(rets) + def visit_TensorCall(self, node: TensorCall) -> OutPortV: + outputs = [] + for call in node.call_nodes: + output = self.visit(call) + if isinstance(output.ty, TupleType): + outputs.extend(self._unpack_tuple(output)) + else: + outputs.append(output) + return self._pack_returns(outputs) + def visit_Call(self, node: ast.Call) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppylang/nodes.py b/guppylang/nodes.py index af260062..c1cbdcd0 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -52,6 +52,15 @@ class GlobalCall(ast.expr): ) +class TensorCall(ast.expr): + """A call to a tuple of functions. Stores a call node for each function in the + tuple""" + + call_nodes: list[ast.expr] + + _fields = ("call_nodes",) + + class TypeApply(ast.expr): value: ast.expr inst: Inst From b0de9a89d20658bec78c9451fe0cc735f477ef97 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Fri, 10 May 2024 09:31:05 +0100 Subject: [PATCH 32/41] refactor: Bring back a TensorCall for compiling --- guppylang/checker/expr_checker.py | 9 ++-- guppylang/compiler/expr_compiler.py | 67 ++++++++++------------------- guppylang/nodes.py | 12 ++++-- 3 files changed, 36 insertions(+), 52 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 60919150..45728836 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -61,6 +61,7 @@ LocalName, MakeIter, PyExpr, + TensorCall, TypeApply, ) from guppylang.tys.arg import TypeArg @@ -190,9 +191,8 @@ def _synthesize( return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars) def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: - if not (isinstance(ty, TupleType) and len(ty.element_types) == len(node.elts)): + if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): return self._fail(ty, node) - subst: Subst = {} for i, el in enumerate(node.elts): node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst)) @@ -264,7 +264,8 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: return self._fail(ty, tensor_ty.output, node) else: return with_loc( - node, LocalCall(func=node.func, args=processed_args) + node, + TensorCall(func=node.func, args=processed_args), ), result_subst elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"): @@ -520,7 +521,7 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: ) assert len(inst) == 0 - return with_loc(node, LocalCall(func=node.func, args=args)), return_ty + return with_loc(node, TensorCall(func=node.func, args=args)), return_ty elif f := self.ctx.globals.get_instance_func(ty, "__call__"): return f.synthesize_call(node.args, node, self.ctx) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 1d1496f6..db4cc94e 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -33,8 +33,6 @@ NoneType, TupleType, Type, - function_tensor_signature, - parse_function_tensor, type_to_row, ) @@ -179,36 +177,32 @@ def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: def visit_LocalCall(self, node: LocalCall) -> OutPortV: func = self.visit(node.func) + assert isinstance(func.ty, FunctionType) - assert isinstance(func.ty, FunctionType) or ( - isinstance(func.ty, TupleType) and parse_function_tensor(func.ty) - ) - if isinstance(func.ty, FunctionType): - output = func.ty.output - elif isinstance(func.ty, TupleType): - funcs = parse_function_tensor(func.ty) - assert isinstance(funcs, list) - output = function_tensor_signature(funcs).output + args = [self.visit(arg) for arg in node.args] + call = self.graph.add_indirect_call(func, args) + rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.output)))] + return self._pack_returns(rets) + def visit_TensorCall(self, node: TensorCall) -> OutPortV: + func = self.visit(node.func) args = [self.visit(arg) for arg in node.args] + assert isinstance(func.ty, TupleType) + rets: list[OutPortV] = [] - if isinstance(func.ty, FunctionType): - call = self.graph.add_indirect_call(func, args) - rets = [call.out_port(i) for i in range(len(type_to_row(output)))] - elif isinstance(func.ty, TupleType) and parse_function_tensor(func.ty): - func_ports = self._unpack_tuple(func) - # Now we have to manage this hand-off - remaining_args = args - for func in func_ports: - outs, remaining_args = self._compile_tensor_with_leftovers( - func, remaining_args - ) - rets.extend(outs) - else: - raise InternalGuppyError("Local call of something without a callable type") + remaining_args = args + for elem in self._unpack_tuple(func): + outs, remaining_args = self._compile_tensor_with_leftovers( + elem, remaining_args + ) + rets.extend(outs) + assert remaining_args == [] - return self._pack_returns(rets) + if len(rets) == 1: + return rets[0] + else: + return self._pack_returns(rets) def _compile_tensor_with_leftovers( self, func: OutPortV, args: list[OutPortV] @@ -239,26 +233,11 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: assert isinstance(func, CompiledCallableDef) args = [self.visit(arg) for arg in node.args] - - if isinstance(func.ty, FunctionType): - rets = func.compile_call( - args, list(node.type_args), self.dfg, self.graph, self.globals, node - ) - else: - raise InternalGuppyError("Local call of something without a callable type") - + rets = func.compile_call( + args, list(node.type_args), self.dfg, self.graph, self.globals, node + ) return self._pack_returns(rets) - def visit_TensorCall(self, node: TensorCall) -> OutPortV: - outputs = [] - for call in node.call_nodes: - output = self.visit(call) - if isinstance(output.ty, TupleType): - outputs.extend(self._unpack_tuple(output)) - else: - outputs.append(output) - return self._pack_returns(outputs) - def visit_Call(self, node: ast.Call) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppylang/nodes.py b/guppylang/nodes.py index c1cbdcd0..94030ff6 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -53,12 +53,16 @@ class GlobalCall(ast.expr): class TensorCall(ast.expr): - """A call to a tuple of functions. Stores a call node for each function in the - tuple""" + """A call to a tuple of functions. Behaves like a local call, but more + unpacking of tuples is required at compilation""" - call_nodes: list[ast.expr] + func: ast.expr + args: list[ast.expr] - _fields = ("call_nodes",) + _fields = ( + "func", + "args", + ) class TypeApply(ast.expr): From 219fd96d4092e1c8cb2570eed15e511333394a28 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 14 May 2024 14:17:55 +0100 Subject: [PATCH 33/41] fix: Add out_tys to TensorCall; pack correct types --- guppylang/checker/expr_checker.py | 4 +++- guppylang/compiler/expr_compiler.py | 2 +- guppylang/nodes.py | 4 +++- tests/integration/test_call.py | 6 +++--- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index ec271e84..c568b1da 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -265,7 +265,9 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: else: return with_loc( node, - TensorCall(func=node.func, args=processed_args), + TensorCall( + func=node.func, args=processed_args, out_tys=tensor_ty.output + ), ), result_subst elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"): diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index f8e6c372..a5204505 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -206,7 +206,7 @@ def visit_TensorCall(self, node: TensorCall) -> OutPortV: if len(rets) == 1: return rets[0] else: - return self._pack_returns(rets, func.ty) + return self._pack_returns(rets, node.out_tys) def _compile_tensor_with_leftovers( self, func: OutPortV, args: list[OutPortV] diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 94030ff6..286d7bf0 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any from guppylang.tys.subst import Inst -from guppylang.tys.ty import FunctionType +from guppylang.tys.ty import FunctionType, Type if TYPE_CHECKING: from guppylang.cfg.cfg import CFG @@ -58,10 +58,12 @@ class TensorCall(ast.expr): func: ast.expr args: list[ast.expr] + out_tys: Type _fields = ( "func", "args", + "out_tys", ) diff --git a/tests/integration/test_call.py b/tests/integration/test_call.py index 70067adb..82fe8c5d 100644 --- a/tests/integration/test_call.py +++ b/tests/integration/test_call.py @@ -58,11 +58,11 @@ def test_unary_tuple(validate): @guppy(module) def foo(x: int) -> tuple[int]: - return x, + return (x,) @guppy(module) def bar(x: int) -> int: - y, = foo(x) + (y,) = foo(x) return y - validate(module.compile()) \ No newline at end of file + validate(module.compile()) From b89275dc0d2bf00481c722302de3b384f51801a4 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 14 May 2024 14:19:14 +0100 Subject: [PATCH 34/41] tests: Add a test --- tests/integration/test_tensor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/integration/test_tensor.py b/tests/integration/test_tensor.py index b512667f..d0f41467 100644 --- a/tests/integration/test_tensor.py +++ b/tests/integration/test_tensor.py @@ -83,6 +83,20 @@ def local(f: Callable[[int], bool], g: Callable[[bool], int]) -> tuple[bool, int validate(module.compile()) +def test_singleton(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int, y: int) -> tuple[int, int]: + return y, x + + @guppy(module) + def baz(x: int) -> tuple[int, int]: + return (foo,)(x, x) + + validate(module.compile()) + + def test_call_back(validate): module = GuppyModule("module") From e68fe5c17288c4ddcb4845ef252af4f2c2a76204 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 14 May 2024 14:19:59 +0100 Subject: [PATCH 35/41] fix: Add missing loc to poly tensor errors And update golden test --- guppylang/checker/expr_checker.py | 4 ++-- tests/error/tensor_errors/poly_tensor.err | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index c568b1da..02202bb1 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -248,7 +248,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: assert isinstance(function_elements, list) if any(f.parametrized for f in function_elements): raise GuppyTypeError( - "Polymorphic functions in tuples are not supported" + "Polymorphic functions in tuples are not supported", node.func ) tensor_ty = function_tensor_signature(function_elements) @@ -514,7 +514,7 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: ): if any(f.parametrized for f in function_elems): raise GuppyTypeError( - "Polymorphic functions in tuples are not supported" + "Polymorphic functions in tuples are not supported", node.func ) tensor_ty = function_tensor_signature(function_elems) diff --git a/tests/error/tensor_errors/poly_tensor.err b/tests/error/tensor_errors/poly_tensor.err index eb466d9a..5d4e15ff 100644 --- a/tests/error/tensor_errors/poly_tensor.err +++ b/tests/error/tensor_errors/poly_tensor.err @@ -1,7 +1,7 @@ -Guppy compilation failed. Error in file $FILE:14 +Guppy compilation failed. Error in file /Users/croy/work/guppy/tests/error/tensor_errors/poly_tensor.py:14 12: @guppy(module) 13: def main() -> int: 14: return (foo, foo)(42, 42) - ^^ -GuppyTypeError: Expected argument of type `T`, got `int` + ^^^^^^^^^^ +GuppyTypeError: Polymorphic functions in tuples are not supported From 45b62b40af2f4c10d0a6440d771b59dc3362430d Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 14 May 2024 14:22:25 +0100 Subject: [PATCH 36/41] fix: Remove file path in golden test --- tests/error/tensor_errors/poly_tensor.err | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/error/tensor_errors/poly_tensor.err b/tests/error/tensor_errors/poly_tensor.err index 5d4e15ff..c6b871e2 100644 --- a/tests/error/tensor_errors/poly_tensor.err +++ b/tests/error/tensor_errors/poly_tensor.err @@ -1,4 +1,4 @@ -Guppy compilation failed. Error in file /Users/croy/work/guppy/tests/error/tensor_errors/poly_tensor.py:14 +Guppy compilation failed. Error in file $FILE:14 12: @guppy(module) 13: def main() -> int: From 611b84da4fad274041dfe715dfe20c5b2f461c40 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 14 May 2024 16:27:16 +0100 Subject: [PATCH 37/41] Update guppylang/checker/expr_checker.py Co-authored-by: Mark Koch <48097969+mark-koch@users.noreply.github.com> --- guppylang/checker/expr_checker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 02202bb1..a28dcd5f 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -245,7 +245,6 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: if isinstance(func_ty, TupleType) and ( function_elements := parse_function_tensor(func_ty) ): - assert isinstance(function_elements, list) if any(f.parametrized for f in function_elements): raise GuppyTypeError( "Polymorphic functions in tuples are not supported", node.func From f618f66365e4524643a6be8923a642690e4338d2 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 14 May 2024 16:33:35 +0100 Subject: [PATCH 38/41] refactor: Simplify type synth logic for tuple calls --- guppylang/checker/expr_checker.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index a28dcd5f..b0e7e398 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -253,21 +253,15 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: tensor_ty = function_tensor_signature(function_elements) processed_args, subst, inst = check_call( - tensor_ty, node.args, tensor_ty.output, node.func, self.ctx + tensor_ty, node.args, ty, node.func, self.ctx ) - assert len(inst) == 0 - - result_subst = unify(ty, tensor_ty.output, subst) - if result_subst is None: - return self._fail(ty, tensor_ty.output, node) - else: - return with_loc( - node, - TensorCall( - func=node.func, args=processed_args, out_tys=tensor_ty.output - ), - ), result_subst + return with_loc( + node, + TensorCall( + func=node.func, args=processed_args, out_tys=tensor_ty.output + ), + ), subst elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"): return callee.check_call(node.args, ty, node, self.ctx) From 36d98dfb9ccc41628cc01d19ae41f87d4a122009 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 14 May 2024 16:36:27 +0100 Subject: [PATCH 39/41] feat: Change check_call location for better errors --- guppylang/checker/expr_checker.py | 2 +- tests/error/tensor_errors/too_few_args.err | 2 +- tests/error/tensor_errors/too_many_args.err | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index b0e7e398..efc14131 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -253,7 +253,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: tensor_ty = function_tensor_signature(function_elements) processed_args, subst, inst = check_call( - tensor_ty, node.args, ty, node.func, self.ctx + tensor_ty, node.args, ty, node, self.ctx ) assert len(inst) == 0 return with_loc( diff --git a/tests/error/tensor_errors/too_few_args.err b/tests/error/tensor_errors/too_few_args.err index ed7eb98d..3f63bb04 100644 --- a/tests/error/tensor_errors/too_few_args.err +++ b/tests/error/tensor_errors/too_few_args.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:12 10: @guppy(module) 11: def main() -> int: 12: return (foo, foo)(42) - ^^^^^^^^^^ + ^^^^^^^^^^^^^^ GuppyTypeError: Not enough arguments passed (expected 2, got 1) diff --git a/tests/error/tensor_errors/too_many_args.err b/tests/error/tensor_errors/too_many_args.err index d30f2666..d6e23bfe 100644 --- a/tests/error/tensor_errors/too_many_args.err +++ b/tests/error/tensor_errors/too_many_args.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:12 10: @guppy(module) 11: def main() -> int: 12: return (foo, foo)(1, 2, 3) - ^^^^^^^^^^ -GuppyTypeError: Too many arguments passed (expected 2, got 3) + ^ +GuppyTypeError: Unexpected argument From 7c73c10fac068b69f26d53f8718539060d7c1de2 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 14 May 2024 16:43:37 +0100 Subject: [PATCH 40/41] fix: Add missing out_tys param in synth TensorCall --- guppylang/checker/expr_checker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index efc14131..f1c6e0ee 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -516,7 +516,9 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: ) assert len(inst) == 0 - return with_loc(node, TensorCall(func=node.func, args=args)), return_ty + return with_loc( + node, TensorCall(func=node.func, args=args, out_tys=tensor_ty.output) + ), return_ty elif f := self.ctx.globals.get_instance_func(ty, "__call__"): return f.synthesize_call(node.args, node, self.ctx) From 0373dc6d8ff956bd33174e99b0573786d9c312ad Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 14 May 2024 16:44:46 +0100 Subject: [PATCH 41/41] cleanup: Remove duplicate checks --- guppylang/compiler/expr_compiler.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index a5204505..e219d3da 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -203,10 +203,7 @@ def visit_TensorCall(self, node: TensorCall) -> OutPortV: rets.extend(outs) assert remaining_args == [] - if len(rets) == 1: - return rets[0] - else: - return self._pack_returns(rets, node.out_tys) + return self._pack_returns(rets, node.out_tys) def _compile_tensor_with_leftovers( self, func: OutPortV, args: list[OutPortV]