diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index f11950b0..f1c6e0ee 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -61,6 +61,7 @@ LocalName, MakeIter, PyExpr, + TensorCall, TypeApply, ) from guppylang.tys.arg import TypeArg @@ -83,6 +84,8 @@ TupleType, Type, TypeBase, + function_tensor_signature, + parse_function_tensor, row_to_type, unify, ) @@ -231,14 +234,37 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: if isinstance(defn, CallableDef): return defn.check_call(node.args, ty, node, self.ctx) - # Otherwise, it must be a function as a higher-order value + # Otherwise, it must be a function as a higher-order value - something + # whose type is either a FunctionType or a Tuple of FunctionTypes if isinstance(func_ty, FunctionType): args, return_ty, inst = check_call(func_ty, node.args, ty, node, self.ctx) check_inst(func_ty, inst, node) node.func = instantiate_poly(node.func, func_ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty - elif f := self.ctx.globals.get_instance_func(func_ty, "__call__"): - return f.check_call(node.args, ty, node, self.ctx) + + if isinstance(func_ty, TupleType) and ( + function_elements := parse_function_tensor(func_ty) + ): + if any(f.parametrized for f in function_elements): + raise GuppyTypeError( + "Polymorphic functions in tuples are not supported", node.func + ) + + tensor_ty = function_tensor_signature(function_elements) + + processed_args, subst, inst = check_call( + tensor_ty, node.args, ty, node, self.ctx + ) + assert len(inst) == 0 + return with_loc( + node, + TensorCall( + func=node.func, args=processed_args, out_tys=tensor_ty.output + ), + ), subst + + elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"): + return callee.check_call(node.args, ty, node, self.ctx) else: raise GuppyTypeError(f"Expected function type, got `{func_ty}`", node.func) @@ -332,11 +358,12 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]: ) raise InternalGuppyError( f"Variable `{x}` is not defined in `TypeSynthesiser`. This should have " - f"been caught by program analysis!" + "been caught by program analysis!" ) def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]: elems = [self.synthesize(elem) for elem in node.elts] + node.elts = [n for n, _ in elems] return node, TupleType([ty for _, ty in elems]) @@ -470,11 +497,29 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: if isinstance(defn, CallableDef): return defn.synthesize_call(node.args, node, self.ctx) - # Otherwise, it must be a function as a higher-order value + # Otherwise, it must be a function as a higher-order value, or a tensor if isinstance(ty, FunctionType): args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx) node.func = instantiate_poly(node.func, ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty + elif isinstance(ty, TupleType) and ( + function_elems := parse_function_tensor(ty) + ): + if any(f.parametrized for f in function_elems): + raise GuppyTypeError( + "Polymorphic functions in tuples are not supported", node.func + ) + + tensor_ty = function_tensor_signature(function_elems) + args, return_ty, inst = synthesize_call( + tensor_ty, node.args, node, self.ctx + ) + assert len(inst) == 0 + + return with_loc( + node, TensorCall(func=node.func, args=args, out_tys=tensor_ty.output) + ), return_ty + elif f := self.ctx.globals.get_instance_func(ty, "__call__"): return f.synthesize_call(node.args, node, self.ctx) else: @@ -797,7 +842,7 @@ def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr: if len(inst) > 0: node = with_loc(node, TypeApply(value=with_type(ty, node), inst=inst)) return with_type(ty.instantiate(inst), node) - return node + return with_type(ty, node) def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type]: diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 10e4cc7f..e219d3da 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -22,6 +22,7 @@ GlobalName, LocalCall, LocalName, + TensorCall, TypeApply, ) from guppylang.tys.builtin import bool_type, get_element_type, is_list_type @@ -164,6 +165,10 @@ def visit_List(self, node: ast.List) -> OutPortV: ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts] ).add_out_port(get_type(node)) + def _unpack_tuple(self, wire: OutPortV) -> list[OutPortV]: + unpack_node = self.graph.add_unpack_tuple(wire, self.dfg.node) + return list(unpack_node.out_ports) + def _pack_returns(self, returns: list[OutPortV], return_ty: Type) -> OutPortV: """Groups function return values into a tuple""" if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve: @@ -183,6 +188,47 @@ def visit_LocalCall(self, node: LocalCall) -> OutPortV: rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.output)))] return self._pack_returns(rets, func.ty.output) + def visit_TensorCall(self, node: TensorCall) -> OutPortV: + func = self.visit(node.func) + args = [self.visit(arg) for arg in node.args] + + assert isinstance(func.ty, TupleType) + + rets: list[OutPortV] = [] + remaining_args = args + for elem in self._unpack_tuple(func): + outs, remaining_args = self._compile_tensor_with_leftovers( + elem, remaining_args + ) + rets.extend(outs) + assert remaining_args == [] + + return self._pack_returns(rets, node.out_tys) + + def _compile_tensor_with_leftovers( + self, func: OutPortV, args: list[OutPortV] + ) -> tuple[ + list[OutPortV], # Compiled outputs + list[OutPortV], + ]: # Leftover args + if isinstance(func.ty, TupleType): + remaining_args = args + all_outs = [] + for elem in self._unpack_tuple(func): + outs, remaining_args = self._compile_tensor_with_leftovers( + elem, remaining_args + ) + all_outs.extend(outs) + return all_outs, remaining_args + + elif isinstance(func.ty, FunctionType): + input_len = len(func.ty.inputs) + call = self.graph.add_indirect_call(func, args[0:input_len]) + + return list(call.out_ports), args[input_len:] + else: + raise InternalGuppyError("Tensor element wasn't function or tuple") + def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: func = self.globals[node.def_id] assert isinstance(func, CompiledCallableDef) diff --git a/guppylang/hugr/hugr.py b/guppylang/hugr/hugr.py index dcd23c7c..4e0f798b 100644 --- a/guppylang/hugr/hugr.py +++ b/guppylang/hugr/hugr.py @@ -492,6 +492,7 @@ def add_call( ) -> VNode: """Adds a `Call` node to the graph.""" assert isinstance(def_port.ty, FunctionType) + return self.add_node( ops.Call(), None, @@ -505,6 +506,7 @@ def add_indirect_call( ) -> VNode: """Adds an `IndirectCall` node to the graph.""" assert isinstance(fun_port.ty, FunctionType) + return self.add_node( ops.CallIndirect(), None, diff --git a/guppylang/nodes.py b/guppylang/nodes.py index af260062..286d7bf0 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any from guppylang.tys.subst import Inst -from guppylang.tys.ty import FunctionType +from guppylang.tys.ty import FunctionType, Type if TYPE_CHECKING: from guppylang.cfg.cfg import CFG @@ -52,6 +52,21 @@ class GlobalCall(ast.expr): ) +class TensorCall(ast.expr): + """A call to a tuple of functions. Behaves like a local call, but more + unpacking of tuples is required at compilation""" + + func: ast.expr + args: list[ast.expr] + out_tys: Type + + _fields = ( + "func", + "args", + "out_tys", + ) + + class TypeApply(ast.expr): value: ast.expr inst: Inst diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 6f6d9916..4e114663 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -520,3 +520,32 @@ def _unify_args( case _: return None return subst + + +### Helpers for working with tuples of functions + + +def parse_function_tensor(ty: TupleType) -> list[FunctionType] | None: + """Parses a nested tuple of function types into a flat list of functions.""" + result = [] + for el in ty.element_types: + if isinstance(el, FunctionType): + result.append(el) + elif isinstance(el, TupleType): + funcs = parse_function_tensor(el) + if funcs: + result.extend(funcs) + else: + return None + return result + + +def function_tensor_signature(tys: list[FunctionType]) -> FunctionType: + """Compute the combined function signature of a list of functions""" + inputs: list[Type] = [] + outputs: list[Type] = [] + for fun_ty in tys: + assert not fun_ty.parametrized + inputs.extend(fun_ty.inputs) + outputs.extend(type_to_row(fun_ty.output)) + return FunctionType(inputs, row_to_type(outputs)) diff --git a/tests/error/tensor_errors/poly_tensor.err b/tests/error/tensor_errors/poly_tensor.err new file mode 100644 index 00000000..c6b871e2 --- /dev/null +++ b/tests/error/tensor_errors/poly_tensor.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:14 + +12: @guppy(module) +13: def main() -> int: +14: return (foo, foo)(42, 42) + ^^^^^^^^^^ +GuppyTypeError: Polymorphic functions in tuples are not supported diff --git a/tests/error/tensor_errors/poly_tensor.py b/tests/error/tensor_errors/poly_tensor.py new file mode 100644 index 00000000..0931b137 --- /dev/null +++ b/tests/error/tensor_errors/poly_tensor.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + +@guppy.declare(module) +def foo(x: T) -> T: + ... + +@guppy(module) +def main() -> int: + return (foo, foo)(42, 42) + +module.compile() diff --git a/tests/error/tensor_errors/too_few_args.err b/tests/error/tensor_errors/too_few_args.err new file mode 100644 index 00000000..3f63bb04 --- /dev/null +++ b/tests/error/tensor_errors/too_few_args.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy(module) +11: def main() -> int: +12: return (foo, foo)(42) + ^^^^^^^^^^^^^^ +GuppyTypeError: Not enough arguments passed (expected 2, got 1) diff --git a/tests/error/tensor_errors/too_few_args.py b/tests/error/tensor_errors/too_few_args.py new file mode 100644 index 00000000..6bd2d17a --- /dev/null +++ b/tests/error/tensor_errors/too_few_args.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + +@guppy.declare(module) +def foo(x: int) -> int: + ... + +@guppy(module) +def main() -> int: + return (foo, foo)(42) + +module.compile() diff --git a/tests/error/tensor_errors/too_many_args.err b/tests/error/tensor_errors/too_many_args.err new file mode 100644 index 00000000..d6e23bfe --- /dev/null +++ b/tests/error/tensor_errors/too_many_args.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy(module) +11: def main() -> int: +12: return (foo, foo)(1, 2, 3) + ^ +GuppyTypeError: Unexpected argument diff --git a/tests/error/tensor_errors/too_many_args.py b/tests/error/tensor_errors/too_many_args.py new file mode 100644 index 00000000..7f70d380 --- /dev/null +++ b/tests/error/tensor_errors/too_many_args.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + +@guppy.declare(module) +def foo(x: int) -> int: + ... + +@guppy(module) +def main() -> int: + return (foo, foo)(1, 2, 3) + +module.compile() diff --git a/tests/error/tensor_errors/type_mismatch.err b/tests/error/tensor_errors/type_mismatch.err new file mode 100644 index 00000000..cfc31e7e --- /dev/null +++ b/tests/error/tensor_errors/type_mismatch.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy(module) +11: def main() -> int: +12: return (foo, foo)(42, False) + ^^^^^ +GuppyTypeError: Expected argument of type `int`, got `bool` diff --git a/tests/error/tensor_errors/type_mismatch.py b/tests/error/tensor_errors/type_mismatch.py new file mode 100644 index 00000000..9e4d40a2 --- /dev/null +++ b/tests/error/tensor_errors/type_mismatch.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + +@guppy.declare(module) +def foo(x: int) -> int: + ... + +@guppy(module) +def main() -> int: + return (foo, foo)(42, False) + +module.compile() diff --git a/tests/error/test_tensor_errors.py b/tests/error/test_tensor_errors.py new file mode 100644 index 00000000..38c2959b --- /dev/null +++ b/tests/error/test_tensor_errors.py @@ -0,0 +1,15 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "tensor_errors" +files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_type_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/integration/test_call.py b/tests/integration/test_call.py index 70067adb..82fe8c5d 100644 --- a/tests/integration/test_call.py +++ b/tests/integration/test_call.py @@ -58,11 +58,11 @@ def test_unary_tuple(validate): @guppy(module) def foo(x: int) -> tuple[int]: - return x, + return (x,) @guppy(module) def bar(x: int) -> int: - y, = foo(x) + (y,) = foo(x) return y - validate(module.compile()) \ No newline at end of file + validate(module.compile()) diff --git a/tests/integration/test_tensor.py b/tests/integration/test_tensor.py new file mode 100644 index 00000000..d0f41467 --- /dev/null +++ b/tests/integration/test_tensor.py @@ -0,0 +1,186 @@ +# ruff: noqa: F821 +# ^ Stop ruff complaining about not knowing what "tensor" is + +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +def test_check_callable(validate): + module = GuppyModule("module") + + @guppy(module) + def bar(f: Callable[[int], bool]) -> Callable[[int], bool]: + return f + + @guppy(module) + def foo(f: Callable[[int], bool]) -> tuple[Callable[[int], bool]]: + return (f,) + + @guppy(module) + def is_42(x: int) -> bool: + return x == 42 + + @guppy(module) + def baz(x: int) -> bool: + return foo(is_42)(x) + + @guppy(module) + def baz1() -> tuple[Callable[[int], bool]]: + return foo(is_42) + + @guppy(module) + def baz2(x: int) -> bool: + return (foo,)(is_42)(x) + + validate(module.compile()) + + +def test_call(validate): + module = GuppyModule("module") + + @guppy(module) + def foo() -> int: + return 42 + + @guppy(module) + def bar() -> bool: + return True + + @guppy(module) + def baz_ho() -> tuple[Callable[[], int], Callable[[], bool]]: + return (foo, bar) + + @guppy(module) + def baz_ho_id() -> tuple[Callable[[], int], Callable[[], bool]]: + return baz_ho() + + @guppy(module) + def baz_ho_call() -> tuple[int, bool]: + return baz_ho_id()() + + @guppy(module) + def baz() -> tuple[int, bool]: + return (foo, bar)() + + @guppy(module) + def call_var() -> tuple[int, bool, int]: + f = foo + g = (foo, bar, f) + return g() + + validate(module.compile()) + + +def test_call_inplace(validate): + module = GuppyModule("module") + + @guppy(module) + def local(f: Callable[[int], bool], g: Callable[[bool], int]) -> tuple[bool, int]: + return (f, g)(42, True) + + validate(module.compile()) + + +def test_singleton(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int, y: int) -> tuple[int, int]: + return y, x + + @guppy(module) + def baz(x: int) -> tuple[int, int]: + return (foo,)(x, x) + + validate(module.compile()) + + +def test_call_back(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int) -> int: + return x + + @guppy(module) + def bar(x: int) -> int: + return x * x + + @guppy(module) + def baz(x: int, y: int) -> tuple[int, int]: + return (foo, bar)(x, y) + + @guppy(module) + def call_var(x: int) -> tuple[int, int, int]: + f = foo, baz + return f(x, x, x) + + validate(module.compile()) + + +def test_normal(validate): + module = GuppyModule("module") + + @guppy(module) + def glo(x: int) -> int: + return x + + @guppy(module) + def foo(x: int) -> int: + return glo(x) + + validate(module.compile()) + + +def test_higher_order(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int) -> bool: + return x > 42 + + @guppy(module) + def bar(x: float) -> int: + if x < 5.0: + return 0 + else: + return 1 + + @guppy(module) + def baz() -> tuple[Callable[[int], bool], Callable[[float], int]]: + return foo, bar + + # For the future: + # + # @guppy(module) + # def apply(f: Callable[[int, float], + # tuple[bool, int]], + # args: tuple[int, float]) -> tuple[bool, int]: + # return f(*args) + # + # @guppy(module) + # def apply_call(args: tuple[int, float]) -> tuple[bool, int]: + # return apply(baz, args) + + validate(module.compile()) + + +def test_nesting(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int, y: int) -> int: + return x + y + + @guppy(module) + def bar(x: int) -> int: + return -x + + @guppy(module) + def call(x: int) -> tuple[int, int, int]: + f = bar, (bar, foo) + return f(x, x, x, x) + + validate(module.compile())