Skip to content

Commit

Permalink
feat(hugr-py): CFG builder (#1192)
Browse files Browse the repository at this point in the history
Closes #1188
  • Loading branch information
ss2165 authored Jun 17, 2024
1 parent 102d661 commit c5ea47f
Show file tree
Hide file tree
Showing 8 changed files with 341 additions and 39 deletions.
90 changes: 90 additions & 0 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Sequence
from ._hugr import Hugr, Node, Wire
from ._dfg import DfBase, _from_base
from ._tys import FunctionType, TypeRow, Sum
from ._exceptions import NoSiblingAncestor, NotInSameCfg
import hugr._ops as ops


class Block(DfBase[ops.DataflowBlock]):
def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

def set_single_successor_outputs(self, *outputs: Wire) -> None:
# TODO requires constants
raise NotImplementedError

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
cfg_node = self.hugr[self.root].parent
assert cfg_node is not None
src_parent = self.hugr[src.node].parent
try:
self._wire_up_port(node, i, p)
except NoSiblingAncestor:
# note this just checks if there is a common CFG ancestor
# it does not check for valid dominance between basic blocks
# that is deferred to full HUGR validation.
while cfg_node != src_parent:
if src_parent is None or src_parent == self.hugr.root:
raise NotInSameCfg(src.node.idx, node.idx)
src_parent = self.hugr[src_parent].parent

self.hugr.add_link(src, node.inp(i))


@dataclass
class Cfg:
hugr: Hugr
root: Node
_entry_block: Block
exit: Node

def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
# to ensure entry is first child, add a dummy entry at the start
self._entry_block = _from_base(
Block, self.hugr.add_dfg(ops.DataflowBlock(input_types, []))
)

self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root)

@property
def entry(self) -> Node:
return self._entry_block.root

def _entry_op(self) -> ops.DataflowBlock:
dop = self.hugr[self.entry].op
assert isinstance(dop, ops.DataflowBlock)
return dop

def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block:
# update entry block types
self._entry_op().sum_rows = list(sum_rows)
self._entry_op().other_outputs = other_outputs
self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs]
return self._entry_block

def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block:
return self.add_entry([[]] * n_branches, other_outputs)

def add_block(
self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow
) -> Block:
new_block = self.hugr.add_dfg(
ops.DataflowBlock(input_types, list(sum_rows), other_outputs)
)
return _from_base(Block, new_block)

def simple_block(
self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow
) -> Block:
return self.add_block(input_types, [[]] * n_branches, other_outputs)

def branch(self, src: Wire, dst: Node) -> None:
self.hugr.add_link(src.out_port(), dst.inp(0))
112 changes: 81 additions & 31 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence, Iterable
from typing import Iterable, TYPE_CHECKING, Generic, TypeVar, cast
import typing
from ._hugr import Hugr, Node, Wire, OutPort

from ._ops import Op, Command, Input, Output, DFG
import hugr._ops as ops
from ._exceptions import NoSiblingAncestor
from hugr._tys import FunctionType, Type
from hugr._tys import FunctionType, TypeRow

if TYPE_CHECKING:
from ._cfg import Cfg


DP = TypeVar("DP", bound=ops.DfParentOp)


@dataclass()
class Dfg:
class DfBase(Generic[DP]):
hugr: Hugr
root: Node
input_node: Node
output_node: Node

def __init__(
self, input_types: Sequence[Type], output_types: Sequence[Type]
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = DFG(FunctionType(input=input_types, output=output_types))
def __init__(self, root_op: DP) -> None:
input_types = root_op.input_types()
output_types = root_op.output_types()
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
Input(input_types), self.root, len(input_types)
ops.Input(input_types), self.root, len(input_types)
)
self.output_node = self.hugr.add_node(Output(output_types), self.root)
self.output_node = self.hugr.add_node(ops.Output(output_types), self.root)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> Input:
def _input_op(self) -> ops.Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, Input)
assert isinstance(dop, ops.Input)
return dop

def _output_op(self) -> ops.Output:
dop = self.hugr[self.output_node].op
assert isinstance(dop, ops.Output)
return dop

def root_op(self) -> DP:
return cast(DP, self.hugr[self.root].op)

def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n

def add(self, com: Command) -> Node:
def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
Expand All @@ -55,13 +63,30 @@ def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:

def add_nested(
self,
input_types: Sequence[Type],
output_types: Sequence[Type],
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Dfg:
dfg = self.hugr.add_dfg(input_types, output_types)
dfg = self.hugr.add_dfg(
ops.DFG(FunctionType(input=input_types, output=output_types))
)
self._wire_up(dfg.root, args)
return dfg
return _from_base(Dfg, dfg)

def add_cfg(
self,
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Cfg:
cfg = self.hugr.add_cfg(input_types, output_types)
self._wire_up(cfg.root, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.root)
self._wire_up(mapping[cfg.root], args)
return mapping[cfg.root]

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
Expand All @@ -72,13 +97,38 @@ def add_state_order(self, src: Node, dst: Node) -> None:

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(i))
self._wire_up_port(node, i, p)

def _wire_up_port(self, node: Node, offset: int, p: Wire):
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(offset))


C = TypeVar("C", bound=DfBase)


def _from_base(cls: typing.Type[C], base: DfBase[DP]) -> C:
new = cls.__new__(cls)
new.hugr = base.hugr
new.root = base.root
new.input_node = base.input_node
new.output_node = base.output_node
return new


class Dfg(DfBase[ops.DFG]):
def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
super().__init__(root_op)

@classmethod
def endo(cls, types: TypeRow) -> Dfg:
return cls(types, types)


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
Expand Down
10 changes: 10 additions & 0 deletions hugr-py/src/hugr/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,15 @@ def msg(self):
return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up."


@dataclass
class NotInSameCfg(Exception):
src: int
tgt: int

@property
def msg(self):
return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up."


class ParentBeforeChild(Exception):
msg: str = "Parent node must be added before child node."
27 changes: 21 additions & 6 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Iterable,
Iterator,
Protocol,
Sequence,
TypeVar,
cast,
overload,
Expand All @@ -19,15 +18,16 @@
from typing_extensions import Self

from hugr._ops import Op
from hugr._tys import Type
from hugr._tys import TypeRow
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.utils import BiMap

from ._exceptions import ParentBeforeChild

if TYPE_CHECKING:
from ._dfg import Dfg
from ._dfg import DfBase, DP
from ._cfg import Cfg


class Direction(Enum):
Expand Down Expand Up @@ -337,17 +337,32 @@ def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node
)
return mapping

def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Dfg:
from ._dfg import Dfg
def add_dfg(self, root_op: DP) -> DfBase[DP]:
from ._dfg import DfBase

dfg = Dfg(input_types, output_types)
dfg = DfBase(root_op)
mapping = self.insert_hugr(dfg.hugr, self.root)
dfg.hugr = self
dfg.input_node = mapping[dfg.input_node]
dfg.output_node = mapping[dfg.output_node]
dfg.root = mapping[dfg.root]
return dfg

def add_cfg(self, input_types: TypeRow, output_types: TypeRow) -> Cfg:
from ._cfg import Cfg

cfg = Cfg(input_types, output_types)
mapping = self.insert_hugr(cfg.hugr, self.root)
cfg.hugr = self
cfg._entry_block.root = mapping[cfg.entry]
cfg._entry_block.input_node = mapping[cfg._entry_block.input_node]
cfg._entry_block.output_node = mapping[cfg._entry_block.output_node]
cfg._entry_block.hugr = self
cfg.exit = mapping[cfg.exit]
cfg.root = mapping[cfg.root]
# TODO this is horrible
return cfg

def to_serial(self) -> SerialHugr:
node_it = (node for node in self._nodes if node is not None)

Expand Down
Loading

0 comments on commit c5ea47f

Please sign in to comment.