diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index d26eec4bc..4277edd10 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -8,15 +8,16 @@ from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire from ._tys import FunctionType, TypeRow, Type +import hugr._val as val class Block(_DfBase[ops.DataflowBlock]): def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: self.set_outputs(branching, *other_outputs) - def set_single_successor_outputs(self, *outputs: Wire) -> None: - # TODO requires constants - raise NotImplementedError + def set_single_succ_outputs(self, *outputs: Wire) -> None: + u = self.add_load_const(val.Unit) + self.set_outputs(u, *outputs) def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: src = p.out_port() diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 7e89b91bc..8d4d1447f 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -1,19 +1,20 @@ from __future__ import annotations + from dataclasses import dataclass, replace from typing import ( - Iterable, TYPE_CHECKING, + Iterable, TypeVar, ) -from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder from typing_extensions import Self + import hugr._ops as ops -from hugr._tys import TypeRow +import hugr._val as val +from hugr._tys import Type, TypeRow from ._exceptions import NoSiblingAncestor -from ._hugr import ToNode -from hugr._tys import Type +from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire if TYPE_CHECKING: from ._cfg import Cfg @@ -113,6 +114,21 @@ def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges self.hugr.add_link(src.out(-1), dst.inp(-1)) + def add_const(self, val: val.Value) -> Node: + return self.hugr.add_const(val, self.parent_node) + + def load_const(self, const_node: ToNode) -> Node: + const_op = self.hugr._get_typed_op(const_node, ops.Const) + load_op = ops.LoadConst(const_op.val.type_()) + + load = self.add(load_op()) + self.hugr.add_link(const_node.out_port(), load.inp(0)) + + return load + + def add_load_const(self, val: val.Value) -> Node: + return self.load_const(self.add_const(val)) + def _wire_up(self, node: Node, ports: Iterable[Wire]): tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops.PartialOp): diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index a54d35e6a..a13cb3d1d 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -17,8 +17,9 @@ from typing_extensions import Self -from hugr._ops import Op, DataflowOp +from hugr._ops import Op, DataflowOp, Const from hugr._tys import Type, Kind +from hugr._val import Value from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap @@ -228,6 +229,9 @@ def add_node( parent = parent or self.root return self._add_node(op, parent, num_outs) + def add_const(self, value: Value, parent: ToNode | None = None) -> Node: + return self.add_node(Const(value), parent) + def delete_node(self, node: ToNode) -> NodeData | None: node = node.to_node() parent = self[node].parent diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 026619890..b74541789 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -6,6 +6,7 @@ import hugr.serialization.ops as sops from hugr.utils import ser_it import hugr._tys as tys +import hugr._val as val from ._exceptions import IncompleteOp if TYPE_CHECKING: @@ -360,3 +361,37 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock: def port_kind(self, port: InPort | OutPort) -> tys.Kind: return tys.CFKind() + + +@dataclass +class Const(Op): + val: val.Value + num_out: int | None = 1 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Const: + return sops.Const( + parent=parent.idx, + v=self.val.to_serial_root(), + ) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + return tys.ConstKind(self.val.type_()) + + +@dataclass +class LoadConst(DataflowOp): + typ: tys.Type | None = None + + def type_(self) -> tys.Type: + if self.typ is None: + raise IncompleteOp() + return self.typ + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant: + return sops.LoadConstant( + parent=parent.idx, + datatype=self.type_().to_serial_root(), + ) + + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(input=[], output=[self.type_()]) diff --git a/hugr-py/src/hugr/_val.py b/hugr-py/src/hugr/_val.py new file mode 100644 index 000000000..d2a6277eb --- /dev/null +++ b/hugr-py/src/hugr/_val.py @@ -0,0 +1,103 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING +import hugr.serialization.ops as sops +import hugr.serialization.tys as stys +import hugr._tys as tys +from hugr.utils import ser_it + +if TYPE_CHECKING: + from hugr._hugr import Hugr + + +@runtime_checkable +class Value(Protocol): + def to_serial(self) -> sops.BaseValue: ... + def to_serial_root(self) -> sops.Value: + return sops.Value(root=self.to_serial()) # type: ignore[arg-type] + + def type_(self) -> tys.Type: ... + + +@dataclass +class Sum(Value): + tag: int + typ: tys.Sum + vals: list[Value] + + def type_(self) -> tys.Sum: + return self.typ + + def to_serial(self) -> sops.SumValue: + return sops.SumValue( + tag=self.tag, + typ=stys.SumType(root=self.type_().to_serial()), + vs=ser_it(self.vals), + ) + + +def bool_value(b: bool) -> Sum: + return Sum( + tag=int(b), + typ=tys.Bool, + vals=[], + ) + + +Unit = Sum(0, tys.Unit, []) +TRUE = bool_value(True) +FALSE = bool_value(False) + + +@dataclass +class Tuple(Value): + vals: list[Value] + + def type_(self) -> tys.Tuple: + return tys.Tuple(*(v.type_() for v in self.vals)) + + def to_serial(self) -> sops.TupleValue: + return sops.TupleValue( + vs=ser_it(self.vals), + ) + + +@dataclass +class Function(Value): + body: Hugr + + def type_(self) -> tys.FunctionType: + return self.body.root_op().inner_signature() + + def to_serial(self) -> sops.FunctionValue: + return sops.FunctionValue( + hugr=self.body.to_serial(), + ) + + +@dataclass +class Extension(Value): + name: str + typ: tys.Type + val: Any + extensions: tys.ExtensionSet = field(default_factory=tys.ExtensionSet) + + def type_(self) -> tys.Type: + return self.typ + + def to_serial(self) -> sops.ExtensionValue: + return sops.ExtensionValue( + typ=self.typ.to_serial_root(), + value=sops.CustomConst(c=self.name, v=self.val), + extensions=self.extensions, + ) + + +class ExtensionValue(Value, Protocol): + def to_value(self) -> Extension: ... + + def type_(self) -> tys.Type: + return self.to_value().type_() + + def to_serial(self) -> sops.ExtensionValue: + return self.to_value().to_serial() diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 049aa1d9f..16eb500be 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect import sys -from abc import ABC +from abc import ABC, abstractmethod from typing import Any, Literal from pydantic import Field, RootModel, ConfigDict @@ -80,36 +80,50 @@ class CustomConst(ConfiguredBaseModel): v: Any -class ExtensionValue(ConfiguredBaseModel): +class BaseValue(ABC, ConfiguredBaseModel): + @abstractmethod + def deserialize(self) -> _val.Value: ... + + +class ExtensionValue(BaseValue): """An extension constant value, that can check it is of a given [CustomType].""" - v: Literal["Extension"] = Field("Extension", title="ValueTag") + v: Literal["Extension"] = Field(default="Extension", title="ValueTag") extensions: ExtensionSet typ: Type value: CustomConst + def deserialize(self) -> _val.Value: + return _val.Extension(self.value.c, self.typ.deserialize(), self.value) -class FunctionValue(ConfiguredBaseModel): + +class FunctionValue(BaseValue): """A higher-order function value.""" - v: Literal["Function"] = Field("Function", title="ValueTag") + v: Literal["Function"] = Field(default="Function", title="ValueTag") hugr: Any # TODO + def deserialize(self) -> _val.Value: + return _val.Function(self.hugr) + -class TupleValue(ConfiguredBaseModel): +class TupleValue(BaseValue): """A constant tuple value.""" - v: Literal["Tuple"] = Field("Tuple", title="ValueTag") + v: Literal["Tuple"] = Field(default="Tuple", title="ValueTag") vs: list["Value"] + def deserialize(self) -> _val.Value: + return _val.Tuple(deser_it((v.root for v in self.vs))) -class SumValue(ConfiguredBaseModel): + +class SumValue(BaseValue): """A Sum variant For any Sum type where this value meets the type of the variant indicated by the tag """ - v: Literal["Sum"] = Field("Sum", title="ValueTag") + v: Literal["Sum"] = Field(default="Sum", title="ValueTag") tag: int typ: SumType vs: list["Value"] @@ -122,6 +136,11 @@ class SumValue(ConfiguredBaseModel): } ) + def deserialize(self) -> _val.Value: + return _val.Sum( + self.tag, self.typ.deserialize(), deser_it((v.root for v in self.vs)) + ) + class Value(RootModel): """A constant Value.""" @@ -282,6 +301,9 @@ class LoadConstant(DataflowOp): op: Literal["LoadConstant"] = "LoadConstant" datatype: Type + def deserialize(self) -> _ops.LoadConst: + return _ops.LoadConst(self.datatype.deserialize()) + class LoadFunction(DataflowOp): """Load a static function in to the local dataflow graph.""" @@ -560,3 +582,4 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): tys_model_rebuild(dict(classes)) from hugr import _ops # noqa: E402 # needed to avoid circular imports +from hugr import _val # noqa: E402 # needed to avoid circular imports diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index c8fc5f511..a5beed4b9 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -1,34 +1,35 @@ from hugr._cfg import Cfg import hugr._tys as tys +import hugr._val as val from hugr._dfg import Dfg import hugr._ops as ops -from .test_hugr_build import _validate, INT_T, DivMod +from .test_hugr_build import _validate, INT_T, DivMod, IntVal def build_basic_cfg(cfg: Cfg) -> None: entry = cfg.add_entry() - entry.set_block_outputs(*entry.inputs()) + entry.set_single_succ_outputs(*entry.inputs()) cfg.branch(entry[0], cfg.exit) def test_basic_cfg() -> None: - cfg = Cfg([tys.Unit, tys.Bool]) + cfg = Cfg([tys.Bool]) build_basic_cfg(cfg) _validate(cfg.hugr) def test_branch() -> None: - cfg = Cfg([tys.Bool, tys.Unit, INT_T]) + cfg = Cfg([tys.Bool, INT_T]) entry = cfg.add_entry() entry.set_block_outputs(*entry.inputs()) middle_1 = cfg.add_successor(entry[0]) - middle_1.set_block_outputs(*middle_1.inputs()) + middle_1.set_single_succ_outputs(*middle_1.inputs()) middle_2 = cfg.add_successor(entry[1]) - u, i = middle_2.inputs() + (i,) = middle_2.inputs() n = middle_2.add(DivMod(i, i)) - middle_2.set_block_outputs(u, n[0]) + middle_2.set_single_succ_outputs(n[0]) cfg.branch_exit(middle_1[0]) cfg.branch_exit(middle_2[0]) @@ -37,9 +38,9 @@ def test_branch() -> None: def test_nested_cfg() -> None: - dfg = Dfg(tys.Unit, tys.Bool) + dfg = Dfg(tys.Bool) - cfg = dfg.add_cfg([tys.Unit, tys.Bool], *dfg.inputs()) + cfg = dfg.add_cfg([tys.Bool], *dfg.inputs()) build_basic_cfg(cfg) dfg.set_outputs(cfg) @@ -68,16 +69,16 @@ def test_dom_edge() -> None: def test_asymm_types() -> None: # test different types going to entry block's susccessors - cfg = Cfg([tys.Bool, tys.Unit, INT_T]) + cfg = Cfg([]) entry = cfg.add_entry() - b, u, i = entry.inputs() - tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(i)) + int_load = entry.add_load_const(IntVal(34)) + tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(int_load)) entry.set_block_outputs(tagged_int) middle = cfg.add_successor(entry[0]) # discard the int and return the bool from entry - middle.set_block_outputs(u, b) + middle.set_single_succ_outputs(middle.add_load_const(val.TRUE)) # middle expects an int and exit expects a bool cfg.branch_exit(entry[1]) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index ff7a2759b..79dc24d0e 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -9,6 +9,7 @@ import hugr._ops as ops from hugr.serialization import SerialHugr import hugr._tys as tys +import hugr._val as val import pytest import json @@ -25,6 +26,14 @@ def int_t(width: int) -> tys.Opaque: INT_T = int_t(5) +@dataclass +class IntVal(val.ExtensionValue): + v: int + + def to_value(self) -> val.Extension: + return val.Extension("int", INT_T, self.v) + + @dataclass class LogicOps(Custom): extension: tys.ExtensionId = "logic"