Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): only require input type annotations when building #1199

Merged
merged 17 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def set_single_successor_outputs(self, *outputs: Wire) -> None:

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
src = p.out_port()
cfg_node = self.hugr[self.root].parent
cfg_node = self.hugr[self.parent_node].parent
assert cfg_node is not None
src_parent = self.hugr[src.node].parent
try:
Expand All @@ -40,9 +40,9 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:


@dataclass
class Cfg(ParentBuilder):
class Cfg(ParentBuilder[ops.CFG]):
hugr: Hugr
root: Node
parent_node: Node
_entry_block: Block
exit: Node

Expand All @@ -55,13 +55,13 @@ def _init_impl(
self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow, output_types: TypeRow
) -> None:
self.hugr = hugr
self.root = root
self.parent_node = root
# to ensure entry is first child, add a dummy entry at the start
self._entry_block = Block.new_nested(
ops.DataflowBlock(input_types, []), hugr, root
)

self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root)
self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.parent_node)

@classmethod
def new_nested(
Expand All @@ -81,7 +81,7 @@ def new_nested(

@property
def entry(self) -> Node:
return self._entry_block.root
return self._entry_block.parent_node

def _entry_op(self) -> ops.DataflowBlock:
return self.hugr._get_typed_op(self.entry, ops.DataflowBlock)
Expand All @@ -105,7 +105,7 @@ def add_block(
new_block = Block.new_nested(
ops.DataflowBlock(input_types, list(sum_rows), other_outputs),
self.hugr,
self.root,
self.parent_node,
)
return new_block

Expand Down
43 changes: 20 additions & 23 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from typing import (
Iterable,
TYPE_CHECKING,
Generic,
TypeVar,
cast,
)
from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder

Expand All @@ -25,31 +23,33 @@


@dataclass()
class _DfBase(ParentBuilder, Generic[DP]):
class _DfBase(ParentBuilder[DP]):
hugr: Hugr
root: Node
parent_node: Node
input_node: Node
output_node: Node

def __init__(self, root_op: DP) -> None:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.parent_node = self.hugr.root
self._init_io_nodes(root_op)

def _init_io_nodes(self, root_op: DP):
inner_sig = root_op.inner_signature()

self.input_node = self.hugr.add_node(
ops.Input(inner_sig.input), self.root, len(inner_sig.input)
ops.Input(inner_sig.input), self.parent_node, len(inner_sig.input)
)
self.output_node = self.hugr.add_node(
ops.Output(inner_sig.output), self.parent_node
)
self.output_node = self.hugr.add_node(ops.Output(inner_sig.output), self.root)

@classmethod
def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Self:
new = cls.__new__(cls)

new.hugr = hugr
new.root = hugr.add_node(root_op, parent or hugr.root)
new.parent_node = hugr.add_node(root_op, parent or hugr.root)
new._init_io_nodes(root_op)
return new

Expand All @@ -59,14 +59,11 @@ def _input_op(self) -> ops.Input:
def _output_op(self) -> ops.Output:
return self.hugr._get_typed_op(self.output_node, ops.Output)

def root_op(self) -> DP:
return cast(DP, self.hugr[self.root].op)

def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node:
new_n = self.hugr.add_node(op, self.root)
new_n = self.hugr.add_node(op, self.parent_node)
self._wire_up(new_n, args)

return replace(new_n, _num_out_ports=op.num_out)
Expand All @@ -75,9 +72,9 @@ def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
self._wire_up(mapping[dfg.root], args)
return mapping[dfg.root]
mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node)
self._wire_up(mapping[dfg.parent_node], args)
return mapping[dfg.parent_node]

def add_nested(
self,
Expand All @@ -88,8 +85,8 @@ def add_nested(
input_types = [self._get_dataflow_type(w) for w in args]

root_op = ops.DFG(FunctionType(input=list(input_types), output=[]))
dfg = Dfg.new_nested(root_op, self.hugr, self.root)
self._wire_up(dfg.root, args)
dfg = Dfg.new_nested(root_op, self.hugr, self.parent_node)
self._wire_up(dfg.parent_node, args)
return dfg

def add_cfg(
Expand All @@ -100,18 +97,18 @@ def add_cfg(
) -> Cfg:
from ._cfg import Cfg

ss2165 marked this conversation as resolved.
Show resolved Hide resolved
cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.root)
self._wire_up(cfg.root, args)
cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.parent_node)
self._wire_up(cfg.parent_node, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.root)
self._wire_up(mapping[cfg.root], args)
return mapping[cfg.root]
mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node)
self._wire_up(mapping[cfg.parent_node], args)
return mapping[cfg.parent_node]

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
self.root_op()._set_out_types(self._output_op().types)
self.parent_op()._set_out_types(self._output_op().types)

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
Expand Down
18 changes: 12 additions & 6 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,25 @@ def next_sub_offset(self) -> Self:
_SI = _SubPort[InPort]


class ParentBuilder(ToNode, Protocol):
hugr: Hugr
root: Node
class ParentBuilder(ToNode, Protocol[OpVar]):
hugr: Hugr[OpVar]
parent_node: Node

def to_node(self) -> Node:
return self.root
return self.parent_node

def parent_op(self) -> OpVar:
return cast(OpVar, self.hugr[self.parent_node].op)


@dataclass()
class Hugr(Mapping[Node, NodeData]):
class Hugr(Mapping[Node, NodeData], Generic[OpVar]):
root: Node
_nodes: list[NodeData | None]
_links: BiMap[_SO, _SI]
_free_nodes: list[Node]

def __init__(self, root_op: Op) -> None:
def __init__(self, root_op: OpVar) -> None:
self._free_nodes = []
self._links = BiMap()
self._nodes = []
Expand Down Expand Up @@ -269,6 +272,9 @@ def delete_link(self, src: OutPort, dst: InPort) -> None:
return
# TODO make sure sub-offset is handled correctly

def root_op(self) -> OpVar:
return cast(OpVar, self[self.root].op)
mark-koch marked this conversation as resolved.
Show resolved Hide resolved

def num_nodes(self) -> int:
return len(self._nodes) - len(self._free_nodes)

Expand Down
2 changes: 1 addition & 1 deletion hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,4 @@ def test_ancestral_sibling():

nt = nested.add(Not(a))

assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.root
assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.parent_node