diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 9af0567c4..ee4ec1704 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -15,7 +15,8 @@ from hugr._tys import Type, TypeRow, get_first_sum, FunctionType, TypeArg, FunctionKind from ._exceptions import NoSiblingAncestor -from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire +from ._hugr import Hugr, ParentBuilder +from ._node_port import Node, OutPort, Wire, ToNode if TYPE_CHECKING: from ._cfg import Cfg @@ -164,6 +165,28 @@ def load(self, const: ToNode | val.Value) -> Node: return load + def call( + self, + func: ToNode, + *args: Wire, + instantiation: FunctionType | None = None, + type_args: list[TypeArg] | None = None, + ) -> Node: + f_op = self.hugr[func] + f_kind = f_op.op.port_kind(func.out(0)) + match f_kind: + case FunctionKind(sig): + signature = sig + case _: + raise ValueError(f"Expected function type, got {f_kind}") + call_op = ops.Call(signature, instantiation, type_args) + call_n = self.hugr.add_node(call_op, self.parent_node, call_op.num_out) + self.hugr.add_link(func.out(0), call_n.inp(call_op.function_port_offset())) + + self._wire_up(call_n, args) + + return call_n + def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow: tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops.PartialOp): diff --git a/hugr-py/src/hugr/_function.py b/hugr-py/src/hugr/_function.py new file mode 100644 index 000000000..d8ffa578b --- /dev/null +++ b/hugr-py/src/hugr/_function.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import hugr._ops as ops +import hugr._val as val + +from ._dfg import _DfBase +from hugr._node_port import Node +from ._hugr import Hugr +from ._tys import TypeRow, TypeParam, PolyFuncType + + +@dataclass +class Function(_DfBase[ops.FuncDefn]): + def __init__( + self, + name: str, + input_types: TypeRow, + type_params: list[TypeParam] | None = None, + ) -> None: + root_op = ops.FuncDefn(name, input_types, type_params or []) + super().__init__(root_op) + + +@dataclass +class Module: + hugr: Hugr + + def __init__(self) -> None: + self.hugr = Hugr(ops.Module()) + + def define_function( + self, + name: str, + input_types: TypeRow, + type_params: list[TypeParam] | None = None, + ) -> Function: + parent_op = ops.FuncDefn(name, input_types, type_params or []) + return Function.new_nested(parent_op, self.hugr) + + def define_main(self, input_types: TypeRow) -> Function: + return self.define_function("main", input_types) + + def declare_function(self, name: str, signature: PolyFuncType) -> Node: + return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.root) + + def add_const(self, value: val.Value) -> Node: + return self.hugr.add_node(ops.Const(value), self.hugr.root) diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index e379ad600..7676fef00 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -13,8 +13,8 @@ ) -from hugr._ops import Op, DataflowOp, Const -from hugr._tys import Type, Kind +from hugr._ops import Op, DataflowOp, Const, Call +from hugr._tys import Type, Kind, ValueKind from hugr._val import Value from hugr._node_port import Direction, InPort, OutPort, ToNode, Node, _SubPort from hugr.serialization.ops import OpType as SerialOp @@ -257,6 +257,10 @@ def port_type(self, port: InPort | OutPort) -> Type | None: op = self[port.node].op if isinstance(op, DataflowOp): return op.port_type(port) + if isinstance(op, Call) and isinstance(port, OutPort): + kind = self.port_kind(port) + assert isinstance(kind, ValueKind) + return kind.ty return None def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]: diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index a590775f6..5bc63c79a 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -14,6 +14,15 @@ from hugr._hugr import Hugr +@dataclass +class InvalidPort(Exception): + port: InPort | OutPort + + @property + def msg(self) -> str: + return f"Invalid port {self.port}" + + @runtime_checkable class Op(Protocol): @property @@ -25,6 +34,14 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... def port_kind(self, port: InPort | OutPort) -> tys.Kind: ... +def _sig_port_type(sig: tys.FunctionType, port: InPort | OutPort) -> tys.Type: + from hugr._hugr import Direction + + if port.direction == Direction.INCOMING: + return sig.input[port.offset] + return sig.output[port.offset] + + @runtime_checkable class DataflowOp(Op, Protocol): def outer_signature(self) -> tys.FunctionType: ... @@ -35,12 +52,7 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: return tys.ValueKind(self.port_type(port)) def port_type(self, port: InPort | OutPort) -> tys.Type: - from hugr._hugr import Direction - - sig = self.outer_signature() - if port.direction == Direction.INCOMING: - return sig.input[port.offset] - return sig.output[port.offset] + return _sig_port_type(self.outer_signature(), port) def __call__(self, *args) -> Command: return Command(self, list(args)) @@ -359,7 +371,11 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Const: ) def port_kind(self, port: InPort | OutPort) -> tys.Kind: - return tys.ConstKind(self.val.type_()) + match port: + case OutPort(_, 0): + return tys.ConstKind(self.val.type_()) + case _: + raise InvalidPort(port) @dataclass @@ -378,6 +394,15 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant: def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=[], output=[self.type_()]) + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + match port: + case InPort(_, 0): + return tys.ConstKind(self.type_()) + case OutPort(_, 0): + return tys.ValueKind(self.type_()) + case _: + raise InvalidPort(port) + @dataclass() class Conditional(DataflowOp): @@ -417,15 +442,12 @@ def nth_inputs(self, n: int) -> tys.TypeRow: class Case(DfParentOp): inputs: tys.TypeRow _outputs: tys.TypeRow | None = None + num_out: int | None = 0 @property def outputs(self) -> tys.TypeRow: return _check_complete(self._outputs) - @property - def num_out(self) -> int | None: - return 0 - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Case: return sops.Case( parent=parent.idx, signature=self.inner_signature().to_serial() @@ -435,7 +457,7 @@ def inner_signature(self) -> tys.FunctionType: return tys.FunctionType(self.inputs, self.outputs) def port_kind(self, port: InPort | OutPort) -> tys.Kind: - raise NotImplementedError("Case nodes have no external ports.") + raise InvalidPort(port) def _set_out_types(self, types: tys.TypeRow) -> None: self._outputs = types @@ -486,3 +508,132 @@ def _set_out_types(self, types: tys.TypeRow) -> None: def _inputs(self) -> tys.TypeRow: return self.just_inputs + self.rest + + +@dataclass +class FuncDefn(DfParentOp): + name: str + inputs: tys.TypeRow + params: list[tys.TypeParam] = field(default_factory=list) + _outputs: tys.TypeRow | None = None + num_out: int | None = 1 + + @property + def outputs(self) -> tys.TypeRow: + return _check_complete(self._outputs) + + @property + def signature(self) -> tys.PolyFuncType: + return tys.PolyFuncType( + self.params, tys.FunctionType(self.inputs, self.outputs) + ) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.FuncDefn: + return sops.FuncDefn( + parent=parent.idx, + name=self.name, + signature=self.signature.to_serial(), + ) + + def inner_signature(self) -> tys.FunctionType: + return self.signature.body + + def _set_out_types(self, types: tys.TypeRow) -> None: + self._outputs = types + + def _inputs(self) -> tys.TypeRow: + return self.inputs + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + match port: + case OutPort(_, 0): + return tys.FunctionKind(self.signature) + case _: + raise InvalidPort(port) + + +@dataclass +class FuncDecl(Op): + name: str + signature: tys.PolyFuncType + num_out: int | None = 0 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.FuncDecl: + return sops.FuncDecl( + parent=parent.idx, + name=self.name, + signature=self.signature.to_serial(), + ) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + match port: + case OutPort(_, 0): + return tys.FunctionKind(self.signature) + case _: + raise InvalidPort(port) + + +@dataclass +class Module(Op): + num_out: int | None = 0 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Module: + return sops.Module(parent=parent.idx) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + raise InvalidPort(port) + + +class NoConcreteFunc(Exception): + pass + + +@dataclass +class Call(Op): + signature: tys.PolyFuncType + instantiation: tys.FunctionType + type_args: list[tys.TypeArg] + + def __init__( + self, + signature: tys.PolyFuncType, + instantiation: tys.FunctionType | None = None, + type_args: list[tys.TypeArg] | None = None, + ) -> None: + self.signature = signature + if len(signature.params) == 0: + self.instantiation = signature.body + self.type_args = [] + + else: + # TODO substitute type args into signature to get instantiation + if instantiation is None: + raise NoConcreteFunc("Missing instantiation for polymorphic function.") + type_args = type_args or [] + + if len(signature.params) != len(type_args): + raise NoConcreteFunc("Mismatched number of type arguments.") + self.instantiation = instantiation + self.type_args = type_args + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Call: + return sops.Call( + parent=parent.idx, + func_sig=self.signature.to_serial(), + type_args=ser_it(self.type_args), + instantiation=self.instantiation.to_serial(), + ) + + @property + def num_out(self) -> int | None: + return len(self.signature.body.output) + + def function_port_offset(self) -> int: + return len(self.signature.body.input) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + match port: + case InPort(_, offset) if offset == self.function_port_offset(): + return tys.FunctionKind(self.signature) + case _: + return tys.ValueKind(_sig_port_type(self.instantiation, port)) diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 48dd43b08..6e2e3584e 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -36,6 +36,9 @@ def to_serial(self) -> stys.BaseType: ... def to_serial_root(self) -> stys.Type: return stys.Type(root=self.to_serial()) # type: ignore[arg-type] + def type_arg(self) -> TypeTypeArg: + return TypeTypeArg(self) + TypeRow = list[Type] diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index d2fedf027..a81ca95ba 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -58,6 +58,9 @@ class Module(BaseOp): op: Literal["Module"] = "Module" + def deserialize(self) -> _ops.Module: + return _ops.Module() + class FuncDefn(BaseOp): """A function definition. Children nodes are the body of the definition.""" @@ -67,6 +70,12 @@ class FuncDefn(BaseOp): name: str signature: PolyFuncType + def deserialize(self) -> _ops.FuncDefn: + poly_func = self.signature.deserialize() + return _ops.FuncDefn( + self.name, inputs=poly_func.body.input, _outputs=poly_func.body.output + ) + class FuncDecl(BaseOp): """External function declaration, linked at runtime.""" @@ -75,6 +84,9 @@ class FuncDecl(BaseOp): name: str signature: PolyFuncType + def deserialize(self) -> _ops.FuncDecl: + return _ops.FuncDecl(self.name, self.signature.deserialize()) + class CustomConst(ConfiguredBaseModel): c: str @@ -298,6 +310,13 @@ class Call(DataflowOp): } ) + def deserialize(self) -> _ops.Call: + return _ops.Call( + self.func_sig.deserialize(), + self.instantiation.deserialize(), + deser_it(self.type_args), + ) + class CallIndirect(DataflowOp): """Call a function indirectly. diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 7598350f7..cad65436d 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -7,11 +7,12 @@ from hugr._hugr import Hugr from hugr._dfg import Dfg, _ancestral_sibling -from hugr._ops import Custom, Command +from hugr._ops import Custom, Command, NoConcreteFunc import hugr._ops as ops from hugr.serialization import SerialHugr import hugr._tys as tys import hugr._val as val +from hugr._function import Module import pytest import json @@ -315,3 +316,43 @@ def test_vals(val: val.Value): d.set_outputs(d.load(val)) _validate(d.hugr) + + +def test_poly_function() -> None: + mod = Module() + f_id = mod.declare_function( + "id", + tys.PolyFuncType( + [tys.TypeTypeParam(tys.TypeBound.Any)], + tys.FunctionType.endo([tys.Variable(0, tys.TypeBound.Any)]), + ), + ) + + f_main = mod.define_main([tys.Qubit]) + q = f_main.input_node[0] + with pytest.raises(NoConcreteFunc, match="Missing instantiation"): + f_main.call(f_id, q) + call = f_main.call( + f_id, + q, + # for now concrete instantiations have to be provided. + instantiation=tys.FunctionType.endo([tys.Qubit]), + type_args=[tys.Qubit.type_arg()], + ) + f_main.set_outputs(call) + + _validate(mod.hugr, True) + + +def test_mono_function() -> None: + mod = Module() + f_id = mod.define_function("id", [tys.Qubit]) + f_id.set_outputs(f_id.input_node[0]) + + f_main = mod.define_main([tys.Qubit]) + q = f_main.input_node[0] + # monomorphic functions don't need instantiation specified + call = f_main.call(f_id, q) + f_main.set_outputs(call) + + _validate(mod.hugr, True)