diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 8ac058a23..8936c077c 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -1,7 +1,8 @@ from __future__ import annotations from dataclasses import dataclass from typing import Iterable, Sequence -from ._hugr import Hugr, Node, Wire +from ._hugr import Hugr, Node, ToNode, Wire +from ._hugr import ParentBuilder from ._dfg import DfBase, _from_base from ._tys import FunctionType, TypeRow, Sum from ._exceptions import NoSiblingAncestor, NotInSameCfg @@ -37,7 +38,7 @@ def _wire_up(self, node: Node, ports: Iterable[Wire]): @dataclass -class Cfg: +class Cfg(ParentBuilder): hugr: Hugr root: Node _entry_block: Block @@ -86,5 +87,5 @@ def simple_block( ) -> Block: return self.add_block(input_types, [[]] * n_branches, other_outputs) - def branch(self, src: Wire, dst: Node) -> None: + def branch(self, src: Wire, dst: ToNode) -> None: self.hugr.add_link(src.out_port(), dst.inp(0)) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 17499a7b9..40b84e422 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Iterable, TYPE_CHECKING, Generic, TypeVar, cast import typing -from ._hugr import Hugr, Node, Wire, OutPort +from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder import hugr._ops as ops from ._exceptions import NoSiblingAncestor @@ -16,7 +16,7 @@ @dataclass() -class DfBase(Generic[DP]): +class DfBase(ParentBuilder, Generic[DP]): hugr: Hugr root: Node input_node: Node diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 3bee26bd6..2a55fd359 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -59,10 +59,8 @@ def out_port(self) -> OutPort: return self -@dataclass(frozen=True, eq=True, order=True) -class Node(Wire): - idx: int - _num_out_ports: int | None = field(default=None, compare=False) +class ToNode(Wire, Protocol): + def to_node(self) -> Node: ... @overload def __getitem__(self, index: int) -> OutPort: ... @@ -73,6 +71,32 @@ def __getitem__(self, index: tuple[int, ...]) -> Iterator[OutPort]: ... def __getitem__( self, index: int | slice | tuple[int, ...] + ) -> OutPort | Iterator[OutPort]: + return self.to_node()._index(index) + + def out_port(self) -> "OutPort": + return OutPort(self.to_node(), 0) + + def inp(self, offset: int) -> InPort: + return InPort(self.to_node(), offset) + + def out(self, offset: int) -> OutPort: + return OutPort(self.to_node(), offset) + + def port(self, offset: int, direction: Direction) -> InPort | OutPort: + if direction == Direction.INCOMING: + return self.inp(offset) + else: + return self.out(offset) + + +@dataclass(frozen=True, eq=True, order=True) +class Node(ToNode): + idx: int + _num_out_ports: int | None = field(default=None, compare=False) + + def _index( + self, index: int | slice | tuple[int, ...] ) -> OutPort | Iterator[OutPort]: match index: case int(index): @@ -92,20 +116,8 @@ def __getitem__( case tuple(xs): return (self[i] for i in xs) - def out_port(self) -> "OutPort": - return OutPort(self, 0) - - def inp(self, offset: int) -> InPort: - return InPort(self, offset) - - def out(self, offset: int) -> OutPort: - return OutPort(self, offset) - - def port(self, offset: int, direction: Direction) -> InPort | OutPort: - if direction == Direction.INCOMING: - return self.inp(offset) - else: - return self.out(offset) + def to_node(self) -> Node: + return self @dataclass() @@ -139,6 +151,13 @@ def next_sub_offset(self) -> Self: _SI = _SubPort[InPort] +class ParentBuilder(ToNode): + root: Node + + def to_node(self) -> Node: + return self.root + + @dataclass() class Hugr(Mapping[Node, NodeData]): root: Node @@ -152,7 +171,8 @@ def __init__(self, root_op: Op) -> None: self._nodes = [] self.root = self._add_node(root_op, None, 0) - def __getitem__(self, key: Node) -> NodeData: + def __getitem__(self, key: ToNode) -> NodeData: + key = key.to_node() try: n = self._nodes[key.idx] except IndexError: @@ -167,16 +187,17 @@ def __iter__(self): def __len__(self) -> int: return self.num_nodes() - def children(self, node: Node | None = None) -> list[Node]: + def children(self, node: ToNode | None = None) -> list[Node]: node = node or self.root return self[node].children def _add_node( self, op: Op, - parent: Node | None = None, + parent: ToNode | None = None, num_outs: int | None = None, ) -> Node: + parent = parent.to_node() if parent else None node_data = NodeData(op, parent) if self._free_nodes: @@ -193,13 +214,14 @@ def _add_node( def add_node( self, op: Op, - parent: Node | None = None, + parent: ToNode | None = None, num_outs: int | None = None, ) -> Node: parent = parent or self.root return self._add_node(op, parent, num_outs) - def delete_node(self, node: Node) -> NodeData | None: + def delete_node(self, node: ToNode) -> NodeData | None: + node = node.to_node() parent = self[node].parent if parent: self[parent].children.remove(node) @@ -247,17 +269,17 @@ def delete_link(self, src: OutPort, dst: InPort) -> None: def num_nodes(self) -> int: return len(self._nodes) - len(self._free_nodes) - def num_ports(self, node: Node, direction: Direction) -> int: + def num_ports(self, node: ToNode, direction: Direction) -> int: return ( self.num_in_ports(node) if direction == Direction.INCOMING else self.num_out_ports(node) ) - def num_in_ports(self, node: Node) -> int: + def num_in_ports(self, node: ToNode) -> int: return self[node]._num_inps - def num_out_ports(self, node: Node) -> int: + def num_out_ports(self, node: ToNode) -> int: return self[node]._num_outs def _linked_ports( @@ -282,14 +304,14 @@ def linked_ports(self, port: OutPort | InPort): # TODO: single linked port - def outgoing_order_links(self, node: Node) -> Iterable[Node]: + def outgoing_order_links(self, node: ToNode) -> Iterable[Node]: return (p.node for p in self.linked_ports(node.out(-1))) - def incoming_order_links(self, node: Node) -> Iterable[Node]: + def incoming_order_links(self, node: ToNode) -> Iterable[Node]: return (p.node for p in self.linked_ports(node.inp(-1))) def _node_links( - self, node: Node, links: dict[_SubPort[P], _SubPort[K]] + self, node: ToNode, links: dict[_SubPort[P], _SubPort[K]] ) -> Iterable[tuple[P, list[K]]]: try: direction = next(iter(links.keys())).port.direction @@ -300,23 +322,23 @@ def _node_links( port = cast(P, node.port(offset, direction)) yield port, list(self._linked_ports(port, links)) - def outgoing_links(self, node: Node) -> Iterable[tuple[OutPort, list[InPort]]]: + def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]]: return self._node_links(node, self._links.fwd) - def incoming_links(self, node: Node) -> Iterable[tuple[InPort, list[OutPort]]]: + def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]]: return self._node_links(node, self._links.bck) def num_incoming(self, node: Node) -> int: # connecetd links return sum(1 for _ in self.incoming_links(node)) - def num_outgoing(self, node: Node) -> int: + def num_outgoing(self, node: ToNode) -> int: # connecetd links return sum(1 for _ in self.outgoing_links(node)) # TODO: num_links and _linked_ports - def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node]: + def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]: mapping: dict[Node, Node] = {} for idx, node_data in enumerate(hugr._nodes): diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index 457334526..8d41ef8c0 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -8,7 +8,7 @@ def build_basic_cfg(cfg: Cfg) -> None: entry = cfg.simple_entry(1, [tys.Bool]) entry.set_block_outputs(*entry.inputs()) - cfg.branch(entry.root.out(0), cfg.exit) + cfg.branch(entry[0], cfg.exit) def test_basic_cfg() -> None: @@ -29,11 +29,11 @@ def test_branch() -> None: n = middle_2.add(DivMod(i, i)) middle_2.set_block_outputs(u, n[0]) - cfg.branch(entry.root.out(0), middle_1.root) - cfg.branch(entry.root.out(1), middle_2.root) + cfg.branch(entry[0], middle_1) + cfg.branch(entry[1], middle_2) - cfg.branch(middle_1.root.out(0), cfg.exit) - cfg.branch(middle_2.root.out(0), cfg.exit) + cfg.branch(middle_1[0], cfg.exit) + cfg.branch(middle_2[0], cfg.exit) _validate(cfg.hugr) @@ -44,7 +44,7 @@ def test_nested_cfg() -> None: cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs()) build_basic_cfg(cfg) - dfg.set_outputs(cfg.root) + dfg.set_outputs(cfg) _validate(dfg.hugr) @@ -62,10 +62,10 @@ def test_dom_edge() -> None: middle_2 = cfg.simple_block([INT_T], 1, [INT_T]) middle_2.set_block_outputs(u, *middle_2.inputs()) - cfg.branch(entry.root.out(0), middle_1.root) - cfg.branch(entry.root.out(1), middle_2.root) + cfg.branch(entry[0], middle_1) + cfg.branch(entry[1], middle_2) - cfg.branch(middle_1.root.out(0), cfg.exit) - cfg.branch(middle_2.root.out(0), cfg.exit) + cfg.branch(middle_1[0], cfg.exit) + cfg.branch(middle_2[0], cfg.exit) _validate(cfg.hugr) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 47dd9ce36..167f28d2f 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -224,8 +224,8 @@ def _nested_nop(dfg: Dfg): nested = h.add_nested([tys.Bool], [tys.Bool], a) _nested_nop(nested) - assert len(h.hugr.children(nested.root)) == 3 - h.set_outputs(nested.root) + assert len(h.hugr.children(nested)) == 3 + h.set_outputs(nested) _validate(h.hugr) @@ -238,14 +238,14 @@ def test_build_inter_graph(): nt = nested.add(Not(a)) nested.set_outputs(nt) - h.set_outputs(nested.root, b) + h.set_outputs(nested, b) _validate(h.hugr) assert _SubPort(h.input_node.out(-1)) in h.hugr._links assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order assert len(list(h.hugr.outgoing_order_links(h.input_node))) == 1 - assert len(list(h.hugr.incoming_order_links(nested.root))) == 1 + assert len(list(h.hugr.incoming_order_links(nested))) == 1 assert len(list(h.hugr.incoming_order_links(h.output_node))) == 0