Skip to content

Commit

Permalink
feat(hugr-py): CFG builder
Browse files Browse the repository at this point in the history
Closes #1188
  • Loading branch information
ss2165 committed Jun 12, 2024
1 parent cd5bf2d commit 81caec3
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 1 deletion.
74 changes: 74 additions & 0 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
from ._hugr import Hugr, Node, Wire
from ._dfg import DfBase, _from_base
from ._tys import Type, FunctionType, TypeRow, Sum
import hugr._ops as ops


class Block(DfBase[ops.DataflowBlock]):
def block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

def single_successor_outputs(self, *outputs: Wire) -> None:
# TODO requires constants
raise NotImplementedError


@dataclass
class Cfg:
hugr: Hugr
root: Node
_entry_block: Block
exit: Node

def __init__(
self, input_types: Sequence[Type], output_types: Sequence[Type]
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = ops.CFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
# to ensure entry is first child, add a dummy entry at the start
self._entry_block = _from_base(
Block, self.hugr.add_dfg(ops.DataflowBlock(input_types, []))
)

self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root)

@property
def entry(self) -> Node:
return self._entry_block.root

def _entry_op(self) -> ops.DataflowBlock:
dop = self.hugr[self.entry].op
assert isinstance(dop, ops.DataflowBlock)
return dop

def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block:
# update entry block types
self._entry_op().sum_rows = list(sum_rows)
self._entry_op().other_outputs = other_outputs
self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs]
return self._entry_block

def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block:
return self.add_entry([[]] * n_branches, other_outputs)

def add_block(
self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow
) -> Block:
new_block = self.hugr.add_dfg(
ops.DataflowBlock(input_types, list(sum_rows), other_outputs)
)
return _from_base(Block, new_block)

def simple_block(
self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow
) -> Block:
return self.add_block(input_types, [[]] * n_branches, other_outputs)

def branch(self, src: Wire, dst: Node) -> None:
self.hugr.add_link(src.out_port(), dst.inp(0))
10 changes: 10 additions & 0 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ def add_nested(
self._wire_up(dfg.root, args)
return _from_base(Dfg, dfg)

def add_cfg(
self,
input_types: Sequence[Type],
output_types: Sequence[Type],
*args: Wire,
) -> Cfg:
cfg = self.hugr.add_cfg(input_types, output_types)
self._wire_up(cfg.root, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.root)
self._wire_up(mapping[cfg.root], args)
Expand Down
18 changes: 18 additions & 0 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
TypeVar,
cast,
overload,
Sequence,
)

from typing_extensions import Self

from hugr._ops import Op
from hugr._tys import Type
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.utils import BiMap
Expand All @@ -26,6 +28,7 @@

if TYPE_CHECKING:
from ._dfg import DfBase, DP
from ._cfg import Cfg


class Direction(Enum):
Expand Down Expand Up @@ -346,6 +349,21 @@ def add_dfg(self, root_op: DP) -> DfBase[DP]:
dfg.root = mapping[dfg.root]
return dfg

def add_cfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Cfg:
from ._cfg import Cfg

cfg = Cfg(input_types, output_types)
mapping = self.insert_hugr(cfg.hugr, self.root)
cfg.hugr = self
cfg._entry_block.root = mapping[cfg.entry]
cfg._entry_block.input_node = mapping[cfg._entry_block.input_node]
cfg._entry_block.output_node = mapping[cfg._entry_block.output_node]
cfg._entry_block.hugr = self
cfg.exit = mapping[cfg.exit]
cfg.root = mapping[cfg.root]
# TODO this is horrible
return cfg

def to_serial(self) -> SerialHugr:
node_it = (node for node in self._nodes if node is not None)

Expand Down
1 change: 1 addition & 0 deletions hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,4 @@ def to_serial(self) -> stys.Qubit:

Qubit = QubitDef()
Bool = UnitSum(size=2)
Unit = UnitSum(size=1)
49 changes: 49 additions & 0 deletions hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from hugr._cfg import Cfg
import hugr._tys as tys
from hugr._dfg import Dfg
from .test_hugr_build import _validate, INT_T, DivMod


def build_basic_cfg(cfg: Cfg) -> None:
entry = cfg.simple_entry(1, [tys.Bool])

entry.block_outputs(*entry.inputs())
cfg.branch(entry.root.out(0), cfg.exit)


def test_basic_cfg() -> None:
cfg = Cfg([tys.Unit, tys.Bool], [tys.Bool])
build_basic_cfg(cfg)
_validate(cfg.hugr)


def test_branch() -> None:
cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T])
entry = cfg.simple_entry(2, [tys.Unit, INT_T])
entry.block_outputs(*entry.inputs())

middle_1 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T])
middle_1.block_outputs(*middle_1.inputs())
middle_2 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T])
u, i = middle_2.inputs()
n = middle_2.add(DivMod(i, i))
middle_2.block_outputs(u, n[0])

cfg.branch(entry.root.out(0), middle_1.root)
cfg.branch(entry.root.out(1), middle_2.root)

cfg.branch(middle_1.root.out(0), cfg.exit)
cfg.branch(middle_2.root.out(0), cfg.exit)

_validate(cfg.hugr)


def test_nested_cfg() -> None:
dfg = Dfg([tys.Unit, tys.Bool], [tys.Bool])

cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs())

build_basic_cfg(cfg)
dfg.set_outputs(cfg.root)

_validate(dfg.hugr, True)
2 changes: 1 addition & 1 deletion hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def test_build_inter_graph():

h.set_outputs(nested.root, b)

_validate(h.hugr, True)
_validate(h.hugr)

assert _SubPort(h.input_node.out(-1)) in h.hugr._links
assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order
Expand Down

0 comments on commit 81caec3

Please sign in to comment.