From 8b4ad230a140c060cd76789bb1ed5d85289f224a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 24 Jun 2024 13:11:27 +0100 Subject: [PATCH 1/2] fix: Loading custom polymorphic function defs as values --- guppylang/definition/custom.py | 27 ++++++++++----------------- tests/integration/test_poly.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/guppylang/definition/custom.py b/guppylang/definition/custom.py index 67789fe0..2ea67f42 100644 --- a/guppylang/definition/custom.py +++ b/guppylang/definition/custom.py @@ -12,7 +12,7 @@ from guppylang.definition.common import ParsableDef from guppylang.definition.value import CompiledCallableDef from guppylang.error import GuppyError, InternalGuppyError -from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV +from guppylang.hugr_builder.hugr import Hugr, OutPortV from guppylang.nodes import GlobalCall from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row @@ -145,27 +145,20 @@ def load_with_args( node, ) assert len(self.ty.params) == len(type_args) - - # Find the module node by walking up the hierarchy - module: Node = dfg.node - while not isinstance(module.op, ops.Module): - if module.parent is None: - raise InternalGuppyError( - "Encountered node that is not contained in a module." - ) - module = module.parent + ty = self.ty.instantiate(type_args) # We create a `FunctionDef` that takes some inputs, compiles a call to the # function, and returns the results. - def_node = graph.add_def(self.ty, module, self.name) - _, inp_ports = graph.add_input_with_ports(list(self.ty.inputs), def_node) - returns = self.compile_call( - inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node - ) - graph.add_output(returns, parent=def_node) + def_node = graph.add_def(ty, dfg.node, self.name) + with graph.parent(def_node): + _, inp_ports = graph.add_input_with_ports(list(ty.inputs)) + returns = self.compile_call( + inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node + ) + graph.add_output(returns) # Finally, load the function into the local DFG - return graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0) + return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0) class CustomCallChecker(ABC): diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py index 98107a32..1d6fc01e 100644 --- a/tests/integration/test_poly.py +++ b/tests/integration/test_poly.py @@ -3,6 +3,8 @@ import pytest from guppylang.decorator import guppy +from guppylang.definition.custom import CustomCallCompiler +from guppylang.hugr_builder.hugr import OutPortV from guppylang.module import GuppyModule from guppylang.prelude.quantum import qubit @@ -261,6 +263,23 @@ def main() -> None: validate(module.compile()) +def test_custom_higher_order(): + class CustomCompiler(CustomCallCompiler): + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + return args + + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.custom(module, CustomCompiler()) + def foo(x: T) -> T: ... + + @guppy(module) + def main(x: int) -> int: + f: Callable[[int], int] = foo + return f(x) + + @pytest.mark.skip("Not yet supported") def test_higher_order_value(validate): module = GuppyModule("test") From 4c7ab7c982db3253ab3740e954760f7873a1660a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 25 Jun 2024 10:55:22 +0100 Subject: [PATCH 2/2] Add comments and rename ty to func_ty --- guppylang/definition/custom.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/guppylang/definition/custom.py b/guppylang/definition/custom.py index 2ea67f42..b2e86370 100644 --- a/guppylang/definition/custom.py +++ b/guppylang/definition/custom.py @@ -145,19 +145,22 @@ def load_with_args( node, ) assert len(self.ty.params) == len(type_args) - ty = self.ty.instantiate(type_args) # We create a `FunctionDef` that takes some inputs, compiles a call to the - # function, and returns the results. - def_node = graph.add_def(ty, dfg.node, self.name) + # function, and returns the results. If the function signature is polymorphic, + # we explicitly monomorphise here and invoke the call compiler with the + # inferred type args. + fun_ty = self.ty.instantiate(type_args) + def_node = graph.add_def(fun_ty, dfg.node, self.name) with graph.parent(def_node): - _, inp_ports = graph.add_input_with_ports(list(ty.inputs)) + _, inp_ports = graph.add_input_with_ports(list(fun_ty.inputs)) returns = self.compile_call( inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node ) graph.add_output(returns) - # Finally, load the function into the local DFG + # Finally, load the function into the local DFG. We already monomorphised, so we + # can load with empty type args return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0)