Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): ToNode interface to treat builders as nodes #1193

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Sequence
from ._hugr import Hugr, Node, Wire
from ._hugr import Hugr, Node, ToNode, Wire
from ._hugr import ParentBuilder
from ._dfg import DfBase, _from_base
from ._tys import FunctionType, TypeRow, Sum
from ._exceptions import NoSiblingAncestor, NotInSameCfg
Expand Down Expand Up @@ -37,7 +38,7 @@ def _wire_up(self, node: Node, ports: Iterable[Wire]):


@dataclass
class Cfg:
class Cfg(ParentBuilder):
hugr: Hugr
root: Node
_entry_block: Block
Expand Down Expand Up @@ -86,5 +87,5 @@ def simple_block(
) -> Block:
return self.add_block(input_types, [[]] * n_branches, other_outputs)

def branch(self, src: Wire, dst: Node) -> None:
def branch(self, src: Wire, dst: ToNode) -> None:
self.hugr.add_link(src.out_port(), dst.inp(0))
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import Iterable, TYPE_CHECKING, Generic, TypeVar, cast
import typing
from ._hugr import Hugr, Node, Wire, OutPort
from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder

import hugr._ops as ops
from ._exceptions import NoSiblingAncestor
Expand All @@ -16,7 +16,7 @@


@dataclass()
class DfBase(Generic[DP]):
class DfBase(ParentBuilder, Generic[DP]):
hugr: Hugr
root: Node
input_node: Node
Expand Down
88 changes: 55 additions & 33 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@ def out_port(self) -> OutPort:
return self


@dataclass(frozen=True, eq=True, order=True)
class Node(Wire):
idx: int
_num_out_ports: int | None = field(default=None, compare=False)
class ToNode(Wire, Protocol):
def to_node(self) -> Node: ...

@overload
def __getitem__(self, index: int) -> OutPort: ...
Expand All @@ -73,6 +71,32 @@ def __getitem__(self, index: tuple[int, ...]) -> Iterator[OutPort]: ...

def __getitem__(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
return self.to_node()._index(index)

def out_port(self) -> "OutPort":
return OutPort(self.to_node(), 0)

def inp(self, offset: int) -> InPort:
return InPort(self.to_node(), offset)

def out(self, offset: int) -> OutPort:
return OutPort(self.to_node(), offset)

def port(self, offset: int, direction: Direction) -> InPort | OutPort:
if direction == Direction.INCOMING:
return self.inp(offset)
else:
return self.out(offset)


@dataclass(frozen=True, eq=True, order=True)
class Node(ToNode):
idx: int
_num_out_ports: int | None = field(default=None, compare=False)

def _index(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
match index:
case int(index):
Expand All @@ -92,20 +116,8 @@ def __getitem__(
case tuple(xs):
return (self[i] for i in xs)

def out_port(self) -> "OutPort":
return OutPort(self, 0)

def inp(self, offset: int) -> InPort:
return InPort(self, offset)

def out(self, offset: int) -> OutPort:
return OutPort(self, offset)

def port(self, offset: int, direction: Direction) -> InPort | OutPort:
if direction == Direction.INCOMING:
return self.inp(offset)
else:
return self.out(offset)
def to_node(self) -> Node:
return self


@dataclass()
Expand Down Expand Up @@ -139,6 +151,13 @@ def next_sub_offset(self) -> Self:
_SI = _SubPort[InPort]


class ParentBuilder(ToNode):
root: Node

def to_node(self) -> Node:
return self.root


@dataclass()
class Hugr(Mapping[Node, NodeData]):
root: Node
Expand All @@ -152,7 +171,8 @@ def __init__(self, root_op: Op) -> None:
self._nodes = []
self.root = self._add_node(root_op, None, 0)

def __getitem__(self, key: Node) -> NodeData:
def __getitem__(self, key: ToNode) -> NodeData:
key = key.to_node()
try:
n = self._nodes[key.idx]
except IndexError:
Expand All @@ -167,16 +187,17 @@ def __iter__(self):
def __len__(self) -> int:
return self.num_nodes()

def children(self, node: Node | None = None) -> list[Node]:
def children(self, node: ToNode | None = None) -> list[Node]:
node = node or self.root
return self[node].children

def _add_node(
self,
op: Op,
parent: Node | None = None,
parent: ToNode | None = None,
num_outs: int | None = None,
) -> Node:
parent = parent.to_node() if parent else None
node_data = NodeData(op, parent)

if self._free_nodes:
Expand All @@ -193,13 +214,14 @@ def _add_node(
def add_node(
self,
op: Op,
parent: Node | None = None,
parent: ToNode | None = None,
num_outs: int | None = None,
) -> Node:
parent = parent or self.root
return self._add_node(op, parent, num_outs)

def delete_node(self, node: Node) -> NodeData | None:
def delete_node(self, node: ToNode) -> NodeData | None:
node = node.to_node()
parent = self[node].parent
if parent:
self[parent].children.remove(node)
Expand Down Expand Up @@ -247,17 +269,17 @@ def delete_link(self, src: OutPort, dst: InPort) -> None:
def num_nodes(self) -> int:
return len(self._nodes) - len(self._free_nodes)

def num_ports(self, node: Node, direction: Direction) -> int:
def num_ports(self, node: ToNode, direction: Direction) -> int:
return (
self.num_in_ports(node)
if direction == Direction.INCOMING
else self.num_out_ports(node)
)

def num_in_ports(self, node: Node) -> int:
def num_in_ports(self, node: ToNode) -> int:
return self[node]._num_inps

def num_out_ports(self, node: Node) -> int:
def num_out_ports(self, node: ToNode) -> int:
return self[node]._num_outs

def _linked_ports(
Expand All @@ -282,14 +304,14 @@ def linked_ports(self, port: OutPort | InPort):

# TODO: single linked port

def outgoing_order_links(self, node: Node) -> Iterable[Node]:
def outgoing_order_links(self, node: ToNode) -> Iterable[Node]:
return (p.node for p in self.linked_ports(node.out(-1)))

def incoming_order_links(self, node: Node) -> Iterable[Node]:
def incoming_order_links(self, node: ToNode) -> Iterable[Node]:
return (p.node for p in self.linked_ports(node.inp(-1)))

def _node_links(
self, node: Node, links: dict[_SubPort[P], _SubPort[K]]
self, node: ToNode, links: dict[_SubPort[P], _SubPort[K]]
) -> Iterable[tuple[P, list[K]]]:
try:
direction = next(iter(links.keys())).port.direction
Expand All @@ -300,23 +322,23 @@ def _node_links(
port = cast(P, node.port(offset, direction))
yield port, list(self._linked_ports(port, links))

def outgoing_links(self, node: Node) -> Iterable[tuple[OutPort, list[InPort]]]:
def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]]:
return self._node_links(node, self._links.fwd)

def incoming_links(self, node: Node) -> Iterable[tuple[InPort, list[OutPort]]]:
def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]]:
return self._node_links(node, self._links.bck)

def num_incoming(self, node: Node) -> int:
# connecetd links
return sum(1 for _ in self.incoming_links(node))

def num_outgoing(self, node: Node) -> int:
def num_outgoing(self, node: ToNode) -> int:
# connecetd links
return sum(1 for _ in self.outgoing_links(node))

# TODO: num_links and _linked_ports

def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node]:
def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]:
mapping: dict[Node, Node] = {}

for idx, node_data in enumerate(hugr._nodes):
Expand Down
20 changes: 10 additions & 10 deletions hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def build_basic_cfg(cfg: Cfg) -> None:
entry = cfg.simple_entry(1, [tys.Bool])

entry.set_block_outputs(*entry.inputs())
cfg.branch(entry.root.out(0), cfg.exit)
cfg.branch(entry[0], cfg.exit)


def test_basic_cfg() -> None:
Expand All @@ -29,11 +29,11 @@ def test_branch() -> None:
n = middle_2.add(DivMod(i, i))
middle_2.set_block_outputs(u, n[0])

cfg.branch(entry.root.out(0), middle_1.root)
cfg.branch(entry.root.out(1), middle_2.root)
cfg.branch(entry[0], middle_1)
cfg.branch(entry[1], middle_2)

cfg.branch(middle_1.root.out(0), cfg.exit)
cfg.branch(middle_2.root.out(0), cfg.exit)
cfg.branch(middle_1[0], cfg.exit)
cfg.branch(middle_2[0], cfg.exit)

_validate(cfg.hugr)

Expand All @@ -44,7 +44,7 @@ def test_nested_cfg() -> None:
cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs())

build_basic_cfg(cfg)
dfg.set_outputs(cfg.root)
dfg.set_outputs(cfg)

_validate(dfg.hugr)

Expand All @@ -62,10 +62,10 @@ def test_dom_edge() -> None:
middle_2 = cfg.simple_block([INT_T], 1, [INT_T])
middle_2.set_block_outputs(u, *middle_2.inputs())

cfg.branch(entry.root.out(0), middle_1.root)
cfg.branch(entry.root.out(1), middle_2.root)
cfg.branch(entry[0], middle_1)
cfg.branch(entry[1], middle_2)

cfg.branch(middle_1.root.out(0), cfg.exit)
cfg.branch(middle_2.root.out(0), cfg.exit)
cfg.branch(middle_1[0], cfg.exit)
cfg.branch(middle_2[0], cfg.exit)

_validate(cfg.hugr)
8 changes: 4 additions & 4 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ def _nested_nop(dfg: Dfg):
nested = h.add_nested([tys.Bool], [tys.Bool], a)

_nested_nop(nested)
assert len(h.hugr.children(nested.root)) == 3
h.set_outputs(nested.root)
assert len(h.hugr.children(nested)) == 3
h.set_outputs(nested)

_validate(h.hugr)

Expand All @@ -238,14 +238,14 @@ def test_build_inter_graph():
nt = nested.add(Not(a))
nested.set_outputs(nt)

h.set_outputs(nested.root, b)
h.set_outputs(nested, b)

_validate(h.hugr)

assert _SubPort(h.input_node.out(-1)) in h.hugr._links
assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order
assert len(list(h.hugr.outgoing_order_links(h.input_node))) == 1
assert len(list(h.hugr.incoming_order_links(nested.root))) == 1
assert len(list(h.hugr.incoming_order_links(nested))) == 1
assert len(list(h.hugr.incoming_order_links(h.output_node))) == 0


Expand Down