Skip to content

Commit

Permalink
feat(hugr-py)!: Override add rather than append + extend on all…
Browse files Browse the repository at this point in the history
… `_DfBase` (#1286)

Both `add` and `append` take commands, but `append` allows commands with
indices in. We can therefore replace `append` with an `add` override.
Makes for a more consistent interface.

Allows `extend` to be defined on Dfg too.

BREAKING CHANGE:
    - `TrackedDfg.append` removed, use `add` instead.
    - `extend` now takes `*args` rather than an iterable.
  • Loading branch information
ss2165 authored Jul 10, 2024
1 parent 50d3d98 commit af38154
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 28 deletions.
22 changes: 22 additions & 0 deletions hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,28 @@ def raise_no_ints():
)
return self.add_op(com.op, *wires)

def extend(self, *coms: ops.Command) -> list[Node]:
"""Add a series of commands to the DFG.
Shorthand for calling :meth:`add` on each command in `coms`.
Args:
coms: Commands to add.
Returns:
List of the new nodes in the same order as the commands.
Raises:
IndexError: If any input index is not a tracked wire.
Examples:
>>> dfg = Dfg(tys.Bool, tys.Unit)
>>> (b, u) = dfg.inputs()
>>> dfg.extend(ops.Noop()(b), ops.Noop()(u))
[Node(3), Node(4)]
"""
return [self.add(com) for com in coms]

def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(builder.hugr, self.parent_node)
self._wire_up(mapping[builder.parent_node], args)
Expand Down
28 changes: 5 additions & 23 deletions hugr-py/src/hugr/tracked_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,12 @@ def tracked_wire(self, index: int) -> Wire:
raise IndexError(msg)
return tracked

def append(self, com: Command) -> Node:
def add(self, com: Command) -> Node:
"""Add a command to the DFG.
Overrides :meth:`Dfg.add <hugr.dfg.Dfg.add>` to allow Command inputs
to be either :class:`Wire <hugr.node_port.Wire>` or indices to tracked wires.
Any incoming :class:`Wire <hugr.node_port.Wire>` will
be connected directly, while any integer will be treated as a reference
to the tracked wire at that index.
Expand All @@ -146,7 +149,7 @@ def append(self, com: Command) -> Node:
>>> dfg = TrackedDfg(tys.Bool, track_inputs=True)
>>> dfg.tracked
[OutPort(Node(1), 0)]
>>> dfg.append(ops.Noop()(0))
>>> dfg.add(ops.Noop()(0))
Node(3)
>>> dfg.tracked
[OutPort(Node(3), 0)]
Expand All @@ -169,27 +172,6 @@ def _to_wires(self, in_wires: Iterable[ComWire]) -> Iterable[Wire]:
self.tracked_wire(inc) if isinstance(inc, int) else inc for inc in in_wires
)

def extend(self, coms: Iterable[Command]) -> list[Node]:
"""Add a series of commands to the DFG.
Shorthand for calling :meth:`append` on each command in `coms`.
Args:
coms: Commands to append.
Returns:
List of the new nodes in the same order as the commands.
Raises:
IndexError: If any input index is not a tracked wire.
Examples:
>>> dfg = TrackedDfg(tys.Bool, tys.Unit, track_inputs=True)
>>> dfg.extend([ops.Noop()(0), ops.Noop()(1)])
[Node(3), Node(4)]
"""
return [self.append(com) for com in coms]

def set_indexed_outputs(self, *in_wires: ComWire) -> None:
"""Set the Dfg outputs, using either :class:`Wire <hugr.node_port.Wire>` or
indices to tracked wires.
Expand Down
10 changes: 5 additions & 5 deletions hugr-py/tests/test_tracked_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def simple_circuit(n_qb: int, float_in: int = 0) -> TrackedDfg:

def test_simple_circuit():
circ = simple_circuit(2)
circ.append(H(0))
[_h, cx_n] = circ.extend([H(0), CX(0, 1)])
circ.add(H(0))
[_h, cx_n] = circ.extend(H(0), CX(0, 1))

circ.set_tracked_outputs()

Expand All @@ -52,12 +52,12 @@ def test_complex_circuit():
circ = simple_circuit(2)
fl = circ.load(FloatVal(0.5))

circ.extend([H(0), Rz(0, fl)])
[_m0, m1] = circ.extend(Measure(i) for i in range(2))
circ.extend(H(0), Rz(0, fl))
[_m0, m1] = circ.extend(*(Measure(i) for i in range(2)))

m_idx = circ.track_wire(m1[1]) # track the bool out
assert m_idx == 2
circ.append(Not(m_idx))
circ.add(Not(m_idx))

circ.set_tracked_outputs()

Expand Down

0 comments on commit af38154

Please sign in to comment.