Skip to content

Commit

Permalink
feat: flow input types to output in cfg building
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 17, 2024
1 parent dcfebc4 commit e8b3a17
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 52 deletions.
59 changes: 27 additions & 32 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence
from dataclasses import dataclass, replace

import hugr._ops as ops

from ._dfg import _DfBase
from ._exceptions import NoSiblingAncestor, NotInSameCfg
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, Sum, TypeRow, Type
from ._tys import FunctionType, TypeRow, Type


class Block(_DfBase[ops.DataflowBlock]):
Expand Down Expand Up @@ -46,37 +45,34 @@ class Cfg(ParentBuilder[ops.CFG]):
_entry_block: Block
exit: Node

def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=output_types))
def __init__(self, input_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=[]))
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, input_types, output_types)
self._init_impl(hugr, hugr.root, input_types)

def _init_impl(
self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow, output_types: TypeRow
) -> None:
def _init_impl(self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow) -> None:
self.hugr = hugr
self.parent_node = root
# to ensure entry is first child, add a dummy entry at the start
self._entry_block = Block.new_nested(
ops.DataflowBlock(input_types, []), hugr, root
)

self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.parent_node)
self.exit = self.hugr.add_node(ops.ExitBlock([]), self.parent_node)

@classmethod
def new_nested(
cls,
input_types: TypeRow,
output_types: TypeRow,
hugr: Hugr,
parent: ToNode | None = None,
) -> Cfg:
new = cls.__new__(cls)
root = hugr.add_node(
ops.CFG(FunctionType(input=input_types, output=output_types)),
ops.CFG(FunctionType(input=input_types, output=[])),
parent or hugr.root,
)
new._init_impl(hugr, root, input_types, output_types)
new._init_impl(hugr, root, input_types)
return new

@property
Expand All @@ -89,30 +85,29 @@ def _entry_op(self) -> ops.DataflowBlock:
def _exit_op(self) -> ops.ExitBlock:
return self.hugr._get_typed_op(self.exit, ops.ExitBlock)

def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block:
# update entry block types
self._entry_op().sum_rows = list(sum_rows)
self._entry_op().other_outputs = other_outputs
self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs]
def add_entry(self) -> Block:
return self._entry_block

def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block:
return self.add_entry([[]] * n_branches, other_outputs)

def add_block(
self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow
) -> Block:
def add_block(self, input_types: TypeRow) -> Block:
new_block = Block.new_nested(
ops.DataflowBlock(input_types, list(sum_rows), other_outputs),
ops.DataflowBlock(input_types, [], []),
self.hugr,
self.parent_node,
)
return new_block

def simple_block(
self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow
) -> Block:
return self.add_block(input_types, [[]] * n_branches, other_outputs)

def branch(self, src: Wire, dst: ToNode) -> None:
self.hugr.add_link(src.out_port(), dst.inp(0))
src = src.out_port()
self.hugr.add_link(src, dst.inp(0))

if dst == self.exit:
src_block = self.hugr._get_typed_op(src.node, ops.DataflowBlock)
out_types = [*src_block.sum_rows[src.offset], *src_block.other_outputs]
if self._exit_op().cfg_outputs:
if self._exit_op().cfg_outputs != out_types:
raise MismatchedExit(src.node.idx)

Check warning on line 108 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L108

Added line #L108 was not covered by tests
else:
self._exit_op().cfg_outputs = out_types
self.parent_op().signature = replace(
self.parent_op().signature, output=out_types
)
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def add_cfg(
) -> Cfg:
from ._cfg import Cfg

cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.parent_node)
cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node)
self._wire_up(cfg.parent_node, args)
return cfg

Expand Down
11 changes: 11 additions & 0 deletions hugr-py/src/hugr/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,16 @@ def msg(self):
return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up."


@dataclass
class MismatchedExit(Exception):
src: int

@property
def msg(self):
return (

Check warning on line 30 in hugr-py/src/hugr/_exceptions.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_exceptions.py#L30

Added line #L30 was not covered by tests
f"Exit branch from node {self.src} does not match existing exit block type."
)


class ParentBeforeChild(Exception):
msg: str = "Parent node must be added before child node."
26 changes: 25 additions & 1 deletion hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,30 @@ def _set_in_types(self, types: tys.TypeRow) -> None:
assert isinstance(t, tys.Sum), f"Expected unary Sum, got {t}"
(row,) = t.variant_rows
self.types = row
print(row)


UnpackTuple = UnpackTupleDef()


@dataclass()
class Tag(DataflowOp):
tag: int
variants: list[tys.TypeRow]
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Tag:
return sops.Tag(
parent=parent.idx,
tag=self.tag,
variants=[ser_it(r) for r in self.variants],
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(
input=self.variants[self.tag], output=[tys.Sum(self.variants)]
)


class DfParentOp(Op, Protocol):
def inner_signature(self) -> tys.FunctionType: ...

Expand Down Expand Up @@ -259,6 +277,12 @@ def inner_signature(self) -> tys.FunctionType:
def port_kind(self, port: _Port) -> tys.Kind:
return tys.CFKind()

Check warning on line 278 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L278

Added line #L278 was not covered by tests

def _set_out_types(self, types: tys.TypeRow) -> None:
(sum_, *other) = types
assert isinstance(sum_, tys.Sum), f"Expected Sum, got {sum_}"
self.sum_rows = sum_.variant_rows
self.other_outputs = other


@dataclass
class ExitBlock(Op):
Expand Down
20 changes: 12 additions & 8 deletions hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,6 @@ def to_serial(self) -> stys.Array:
return stys.Array(ty=self.ty.to_serial_root(), len=self.size)


@dataclass(frozen=True)
class UnitSum(Type):
size: int

def to_serial(self) -> stys.UnitSum:
return stys.UnitSum(size=self.size)


@dataclass()
class Sum(Type):
variant_rows: list[TypeRow]
Expand All @@ -181,6 +173,18 @@ def as_tuple(self) -> Tuple:
return Tuple(*self.variant_rows[0])


@dataclass()
class UnitSum(Sum):
size: int

def __init__(self, size: int):
self.size = size
super().__init__(variant_rows=[[]] * size)

def to_serial(self) -> stys.UnitSum: # type: ignore[override]
return stys.UnitSum(size=self.size)


@dataclass()
class Tuple(Sum):
def __init__(self, *tys: Type):
Expand Down
42 changes: 32 additions & 10 deletions hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
from hugr._cfg import Cfg
import hugr._tys as tys
from hugr._dfg import Dfg
import hugr._ops as ops
from .test_hugr_build import _validate, INT_T, DivMod


def build_basic_cfg(cfg: Cfg) -> None:
entry = cfg.simple_entry(1, [tys.Bool])
entry = cfg.add_entry()

entry.set_block_outputs(*entry.inputs())
cfg.branch(entry[0], cfg.exit)


def test_basic_cfg() -> None:
cfg = Cfg([tys.Unit, tys.Bool], [tys.Bool])
cfg = Cfg([tys.Unit, tys.Bool])
build_basic_cfg(cfg)
_validate(cfg.hugr)


def test_branch() -> None:
cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T])
entry = cfg.simple_entry(2, [tys.Unit, INT_T])
cfg = Cfg([tys.Bool, tys.Unit, INT_T])
entry = cfg.add_entry()
entry.set_block_outputs(*entry.inputs())

middle_1 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T])
middle_1 = cfg.add_block([tys.Unit, INT_T])
middle_1.set_block_outputs(*middle_1.inputs())
middle_2 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T])
middle_2 = cfg.add_block([tys.Unit, INT_T])
u, i = middle_2.inputs()
n = middle_2.add(DivMod(i, i))
middle_2.set_block_outputs(u, n[0])
Expand All @@ -50,16 +51,16 @@ def test_nested_cfg() -> None:


def test_dom_edge() -> None:
cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T])
entry = cfg.simple_entry(2, [INT_T])
cfg = Cfg([tys.Bool, tys.Unit, INT_T])
entry = cfg.add_entry()
b, u, i = entry.inputs()
entry.set_block_outputs(b, i)

# entry dominates both middles so Unit type can be used as inter-graph
# value between basic blocks
middle_1 = cfg.simple_block([INT_T], 1, [INT_T])
middle_1 = cfg.add_block([INT_T])
middle_1.set_block_outputs(u, *middle_1.inputs())
middle_2 = cfg.simple_block([INT_T], 1, [INT_T])
middle_2 = cfg.add_block([INT_T])
middle_2.set_block_outputs(u, *middle_2.inputs())

cfg.branch(entry[0], middle_1)
Expand All @@ -69,3 +70,24 @@ def test_dom_edge() -> None:
cfg.branch(middle_2[0], cfg.exit)

_validate(cfg.hugr)


def test_asymm_types() -> None:
# test different types going to entry block's susccessors
cfg = Cfg([tys.Bool, tys.Unit, INT_T])
entry = cfg.add_entry()
b, u, i = entry.inputs()

tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(i))
entry.set_block_outputs(tagged_int)

middle = cfg.add_block([INT_T])
# discard the int and return the bool from entry
middle.set_block_outputs(u, b)

# middle expects an int and exit expects a bool
cfg.branch(entry[0], middle)
cfg.branch(entry[1], cfg.exit)
cfg.branch(middle[0], cfg.exit)

_validate(cfg.hugr)

0 comments on commit e8b3a17

Please sign in to comment.