Skip to content

Commit

Permalink
feat: ToNode interface to treat builders as nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 13, 2024
1 parent 81caec3 commit 7a4fa78
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 47 deletions.
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
from ._hugr import Hugr, Node, Wire
from ._hugr import Hugr, Node, Wire, ParentBuilder
from ._dfg import DfBase, _from_base
from ._tys import Type, FunctionType, TypeRow, Sum
import hugr._ops as ops
Expand All @@ -17,7 +17,7 @@ def single_successor_outputs(self, *outputs: Wire) -> None:


@dataclass
class Cfg:
class Cfg(ParentBuilder):
hugr: Hugr
root: Node
_entry_block: Block
Expand Down
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import Sequence, Iterable, TYPE_CHECKING, Generic, TypeVar, cast
import typing
from ._hugr import Hugr, Node, Wire, OutPort
from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder

import hugr._ops as ops
from ._exceptions import NoSiblingAncestor
Expand All @@ -16,7 +16,7 @@


@dataclass()
class DfBase(Generic[DP]):
class DfBase(ParentBuilder, Generic[DP]):
hugr: Hugr
root: Node
input_node: Node
Expand Down
88 changes: 55 additions & 33 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,8 @@ def out_port(self) -> OutPort:
return self


@dataclass(frozen=True, eq=True, order=True)
class Node(Wire):
idx: int
_num_out_ports: int | None = field(default=None, compare=False)
class ToNode(Wire, Protocol):
def to_node(self) -> Node: ...

@overload
def __getitem__(self, index: int) -> OutPort: ...
Expand All @@ -74,6 +72,32 @@ def __getitem__(self, index: tuple[int, ...]) -> Iterator[OutPort]: ...

def __getitem__(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
return self.to_node()._index(index)

def out_port(self) -> "OutPort":
return OutPort(self.to_node(), 0)

def inp(self, offset: int) -> InPort:
return InPort(self.to_node(), offset)

def out(self, offset: int) -> OutPort:
return OutPort(self.to_node(), offset)

def port(self, offset: int, direction: Direction) -> InPort | OutPort:
if direction == Direction.INCOMING:
return self.inp(offset)
else:
return self.out(offset)


@dataclass(frozen=True, eq=True, order=True)
class Node(ToNode):
idx: int
_num_out_ports: int | None = field(default=None, compare=False)

def _index(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
match index:
case int(index):
Expand All @@ -93,20 +117,8 @@ def __getitem__(
case tuple(xs):
return (self[i] for i in xs)

def out_port(self) -> "OutPort":
return OutPort(self, 0)

def inp(self, offset: int) -> InPort:
return InPort(self, offset)

def out(self, offset: int) -> OutPort:
return OutPort(self, offset)

def port(self, offset: int, direction: Direction) -> InPort | OutPort:
if direction == Direction.INCOMING:
return self.inp(offset)
else:
return self.out(offset)
def to_node(self) -> Node:
return self


@dataclass()
Expand Down Expand Up @@ -140,6 +152,13 @@ def next_sub_offset(self) -> Self:
_SI = _SubPort[InPort]


class ParentBuilder(ToNode):
root: Node

def to_node(self) -> Node:
return self.root


@dataclass()
class Hugr(Mapping[Node, NodeData]):
root: Node
Expand All @@ -153,7 +172,8 @@ def __init__(self, root_op: Op) -> None:
self._nodes = []
self.root = self._add_node(root_op, None, 0)

def __getitem__(self, key: Node) -> NodeData:
def __getitem__(self, key: ToNode) -> NodeData:
key = key.to_node()
try:
n = self._nodes[key.idx]
except IndexError:
Expand All @@ -168,16 +188,17 @@ def __iter__(self):
def __len__(self) -> int:
return self.num_nodes()

def children(self, node: Node | None = None) -> list[Node]:
def children(self, node: ToNode | None = None) -> list[Node]:
node = node or self.root
return self[node].children

def _add_node(
self,
op: Op,
parent: Node | None = None,
parent: ToNode | None = None,
num_outs: int | None = None,
) -> Node:
parent = parent.to_node() if parent else None
node_data = NodeData(op, parent)

if self._free_nodes:
Expand All @@ -194,13 +215,14 @@ def _add_node(
def add_node(
self,
op: Op,
parent: Node | None = None,
parent: ToNode | None = None,
num_outs: int | None = None,
) -> Node:
parent = parent or self.root
return self._add_node(op, parent, num_outs)

def delete_node(self, node: Node) -> NodeData | None:
def delete_node(self, node: ToNode) -> NodeData | None:
node = node.to_node()
parent = self[node].parent
if parent:
self[parent].children.remove(node)
Expand Down Expand Up @@ -248,17 +270,17 @@ def delete_link(self, src: OutPort, dst: InPort) -> None:
def num_nodes(self) -> int:
return len(self._nodes) - len(self._free_nodes)

def num_ports(self, node: Node, direction: Direction) -> int:
def num_ports(self, node: ToNode, direction: Direction) -> int:
return (
self.num_in_ports(node)
if direction == Direction.INCOMING
else self.num_out_ports(node)
)

def num_in_ports(self, node: Node) -> int:
def num_in_ports(self, node: ToNode) -> int:
return self[node]._num_inps

def num_out_ports(self, node: Node) -> int:
def num_out_ports(self, node: ToNode) -> int:
return self[node]._num_outs

def _linked_ports(
Expand All @@ -283,14 +305,14 @@ def linked_ports(self, port: OutPort | InPort):

# TODO: single linked port

def outgoing_order_links(self, node: Node) -> Iterable[Node]:
def outgoing_order_links(self, node: ToNode) -> Iterable[Node]:
return (p.node for p in self.linked_ports(node.out(-1)))

def incoming_order_links(self, node: Node) -> Iterable[Node]:
def incoming_order_links(self, node: ToNode) -> Iterable[Node]:
return (p.node for p in self.linked_ports(node.inp(-1)))

def _node_links(
self, node: Node, links: dict[_SubPort[P], _SubPort[K]]
self, node: ToNode, links: dict[_SubPort[P], _SubPort[K]]
) -> Iterable[tuple[P, list[K]]]:
try:
direction = next(iter(links.keys())).port.direction
Expand All @@ -301,23 +323,23 @@ def _node_links(
port = cast(P, node.port(offset, direction))
yield port, list(self._linked_ports(port, links))

def outgoing_links(self, node: Node) -> Iterable[tuple[OutPort, list[InPort]]]:
def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]]:
return self._node_links(node, self._links.fwd)

def incoming_links(self, node: Node) -> Iterable[tuple[InPort, list[OutPort]]]:
def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]]:
return self._node_links(node, self._links.bck)

def num_incoming(self, node: Node) -> int:
# connecetd links
return sum(1 for _ in self.incoming_links(node))

def num_outgoing(self, node: Node) -> int:
def num_outgoing(self, node: ToNode) -> int:
# connecetd links
return sum(1 for _ in self.outgoing_links(node))

# TODO: num_links and _linked_ports

def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node]:
def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]:
mapping: dict[Node, Node] = {}

for idx, node_data in enumerate(hugr._nodes):
Expand Down
12 changes: 6 additions & 6 deletions hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def build_basic_cfg(cfg: Cfg) -> None:
entry = cfg.simple_entry(1, [tys.Bool])

entry.block_outputs(*entry.inputs())
cfg.branch(entry.root.out(0), cfg.exit)
cfg.branch(entry[0], cfg.exit)


def test_basic_cfg() -> None:
Expand All @@ -29,11 +29,11 @@ def test_branch() -> None:
n = middle_2.add(DivMod(i, i))
middle_2.block_outputs(u, n[0])

cfg.branch(entry.root.out(0), middle_1.root)
cfg.branch(entry.root.out(1), middle_2.root)
cfg.branch(entry[0], middle_1.root)
cfg.branch(entry[1], middle_2.root)

cfg.branch(middle_1.root.out(0), cfg.exit)
cfg.branch(middle_2.root.out(0), cfg.exit)
cfg.branch(middle_1[0], cfg.exit)
cfg.branch(middle_2[0], cfg.exit)

_validate(cfg.hugr)

Expand All @@ -44,6 +44,6 @@ def test_nested_cfg() -> None:
cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs())

build_basic_cfg(cfg)
dfg.set_outputs(cfg.root)
dfg.set_outputs(cfg)

_validate(dfg.hugr, True)
8 changes: 4 additions & 4 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ def _nested_nop(dfg: Dfg):
nested = h.add_nested([tys.Bool], [tys.Bool], a)

_nested_nop(nested)
assert len(h.hugr.children(nested.root)) == 3
h.set_outputs(nested.root)
assert len(h.hugr.children(nested)) == 3
h.set_outputs(nested)

_validate(h.hugr)

Expand All @@ -238,14 +238,14 @@ def test_build_inter_graph():
nt = nested.add(Not(a))
nested.set_outputs(nt)

h.set_outputs(nested.root, b)
h.set_outputs(nested, b)

_validate(h.hugr)

assert _SubPort(h.input_node.out(-1)) in h.hugr._links
assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order
assert len(list(h.hugr.outgoing_order_links(h.input_node))) == 1
assert len(list(h.hugr.incoming_order_links(nested.root))) == 1
assert len(list(h.hugr.incoming_order_links(nested))) == 1
assert len(list(h.hugr.incoming_order_links(h.output_node))) == 0


Expand Down

0 comments on commit 7a4fa78

Please sign in to comment.