Skip to content

Commit

Permalink
feat(hugr-py): IndexDfg builder for appending operations by index (#…
Browse files Browse the repository at this point in the history
…1256)

Closes #1242

More test ideas welcome!
  • Loading branch information
ss2165 authored Jul 5, 2024
1 parent 361e01c commit df9b4cc
Show file tree
Hide file tree
Showing 5 changed files with 355 additions and 9 deletions.
10 changes: 9 additions & 1 deletion hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,15 @@ def add(self, com: ops.Command) -> Node:
Node(3)
"""
return self.add_op(com.op, *com.incoming)

def raise_no_ints():
error_message = "Command used with Dfg must hold Wire, not integer indices."
raise ValueError(error_message)

wires = (
(w if not isinstance(w, int) else raise_no_ints()) for w in com.incoming
)
return self.add_op(com.op, *wires)

def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(builder.hugr, self.parent_node)
Expand Down
11 changes: 7 additions & 4 deletions hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def _check_complete(op, v: V | None) -> V:
return v


ComWire = Wire | int


@dataclass(frozen=True)
class Command:
"""A :class:`DataflowOp` and its incoming :class:`Wire <hugr.nodeport.Wire>`
Expand All @@ -146,7 +149,7 @@ class Command:
"""

op: DataflowOp
incoming: list[Wire]
incoming: list[ComWire]


@dataclass()
Expand Down Expand Up @@ -244,7 +247,7 @@ def to_serial(self, parent: Node) -> sops.MakeTuple:
tys=ser_it(self.types),
)

def __call__(self, *elements: Wire) -> Command:
def __call__(self, *elements: ComWire) -> Command:
return super().__call__(*elements)

def outer_signature(self) -> tys.FunctionType:
Expand Down Expand Up @@ -282,7 +285,7 @@ def to_serial(self, parent: Node) -> sops.UnpackTuple:
tys=ser_it(self.types),
)

def __call__(self, tuple_: Wire) -> Command:
def __call__(self, tuple_: ComWire) -> Command:
return super().__call__(tuple_)

def outer_signature(self) -> tys.FunctionType:
Expand Down Expand Up @@ -925,7 +928,7 @@ def to_serial(self, parent: Node) -> sops.CallIndirect:
signature=self.signature.to_serial(),
)

def __call__(self, function: Wire, *args: Wire) -> Command: # type: ignore[override]
def __call__(self, function: ComWire, *args: ComWire) -> Command: # type: ignore[override]
return super().__call__(function, *args)

def outer_signature(self) -> tys.FunctionType:
Expand Down
220 changes: 220 additions & 0 deletions hugr-py/src/hugr/tracked_dfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""Dfg builder that allows tracking a set of wires and appending operations by index."""

from collections.abc import Iterable

from hugr import tys
from hugr.dfg import Dfg
from hugr.node_port import Node, Wire
from hugr.ops import Command, ComWire


class TrackedDfg(Dfg):
"""Dfg builder to append operations to wires by index.
Args:
*input_types: Input types of the Dfg.
track_inputs: Whether to track the input wires.
Examples:
>>> dfg = TrackedDfg(tys.Bool, tys.Unit, track_inputs=True)
>>> dfg.tracked
[OutPort(Node(1), 0), OutPort(Node(1), 1)]
"""

#: Tracked wires. None if index is no longer tracked.
tracked: list[Wire | None]

def __init__(self, *input_types: tys.Type, track_inputs: bool = False) -> None:
super().__init__(*input_types)
self.tracked = list(self.inputs()) if track_inputs else []

def track_wire(self, wire: Wire) -> int:
"""Add a wire from this DFG to the tracked wires, and return its index.
Args:
wire: Wire to track.
Returns:
Index of the tracked wire.
Examples:
>>> dfg = TrackedDfg(tys.Bool, tys.Unit)
>>> dfg.track_wire(dfg.inputs()[0])
0
"""
self.tracked.append(wire)
return len(self.tracked) - 1

def untrack_wire(self, index: int) -> Wire:
"""Untrack a wire by index and return it.
Args:
index: Index of the wire to untrack.
Returns:
Wire that was untracked.
Raises:
IndexError: If the index is not a tracked wire.
Examples:
>>> dfg = TrackedDfg(tys.Bool, tys.Unit)
>>> w = dfg.inputs()[0]
>>> idx = dfg.track_wire(w)
>>> dfg.untrack_wire(idx) == w
True
"""
w = self.tracked_wire(index)
self.tracked[index] = None
return w

def track_wires(self, wires: Iterable[Wire]) -> list[int]:
"""Set a list of wires to be tracked and return their indices.
Args:
wires: Wires to track.
Returns:
List of indices of the tracked wires.
Examples:
>>> dfg = TrackedDfg(tys.Bool, tys.Unit)
>>> dfg.track_wires(dfg.inputs())
[0, 1]
"""
return [self.track_wire(w) for w in wires]

def track_inputs(self) -> list[int]:
"""Track all input wires and return their indices.
Returns:
List of indices of the tracked input wires.
Examples:
>>> dfg = TrackedDfg(tys.Bool, tys.Unit)
>>> dfg.track_inputs()
[0, 1]
"""
return self.track_wires(self.inputs())

def tracked_wire(self, index: int) -> Wire:
"""Get the tracked wire at the given index.
Args:
index: Index of the tracked wire.
Raises:
IndexError: If the index is not a tracked wire.
Returns:
Tracked wire
Examples:
>>> dfg = TrackedDfg(tys.Bool, tys.Unit, track_inputs=True)
>>> dfg.tracked_wire(0) == dfg.inputs()[0]
True
"""
try:
tracked = self.tracked[index]
except IndexError:
tracked = None
if tracked is None:
msg = f"Index {index} not a tracked wire."
raise IndexError(msg)
return tracked

def append(self, com: Command) -> Node:
"""Add a command to the DFG.
Any incoming :class:`Wire <hugr.node_port.Wire>` will
be connected directly, while any integer will be treated as a reference
to the tracked wire at that index.
Any tracked wires will be updated to the output of the new node at the same port
as the incoming index.
Args:
com: Command to append.
Returns:
The new node.
Raises:
IndexError: If any input index is not a tracked wire.
Examples:
>>> dfg = TrackedDfg(tys.Bool, track_inputs=True)
>>> dfg.tracked
[OutPort(Node(1), 0)]
>>> dfg.append(ops.Noop()(0))
Node(3)
>>> dfg.tracked
[OutPort(Node(3), 0)]
"""
wires = self._to_wires(com.incoming)
n = self.add_op(com.op, *wires)

for port_offset, com_wire in enumerate(com.incoming):
if isinstance(com_wire, int):
tracked_idx = com_wire
else:
continue
# update tracked wires to matching port outputs of new node
self.tracked[tracked_idx] = n.out(port_offset)

return n

def _to_wires(self, in_wires: Iterable[ComWire]) -> Iterable[Wire]:
return (
self.tracked_wire(inc) if isinstance(inc, int) else inc for inc in in_wires
)

def extend(self, coms: Iterable[Command]) -> list[Node]:
"""Add a series of commands to the DFG.
Shorthand for calling :meth:`append` on each command in `coms`.
Args:
coms: Commands to append.
Returns:
List of the new nodes in the same order as the commands.
Raises:
IndexError: If any input index is not a tracked wire.
Examples:
>>> dfg = TrackedDfg(tys.Bool, tys.Unit, track_inputs=True)
>>> dfg.extend([ops.Noop()(0), ops.Noop()(1)])
[Node(3), Node(4)]
"""
return [self.append(com) for com in coms]

def set_indexed_outputs(self, *in_wires: ComWire) -> None:
"""Set the Dfg outputs, using either :class:`Wire <hugr.node_port.Wire>` or
indices to tracked wires.
Args:
*in_wires: Wires/indices to set as outputs.
Raises:
IndexError: If any input index is not a tracked wire.
Examples:
>>> dfg = TrackedDfg(tys.Bool, tys.Unit)
>>> (b, i) = dfg.inputs()
>>> dfg.track_wire(b)
0
>>> dfg.set_indexed_outputs(0, i)
"""
self.set_outputs(*self._to_wires(in_wires))

def set_tracked_outputs(self) -> None:
"""Set the Dfg outputs to the tracked wires.
Examples:
>>> dfg = TrackedDfg(tys.Bool, tys.Unit, track_inputs=True)
>>> dfg.set_tracked_outputs()
"""
self.set_outputs(*(w for w in self.tracked if w is not None))
55 changes: 51 additions & 4 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from hugr.serialization.serial_hugr import SerialHugr

if TYPE_CHECKING:
from hugr.node_port import Wire
from hugr.ops import ComWire


def int_t(width: int) -> tys.Opaque:
Expand All @@ -36,6 +36,22 @@ def to_value(self) -> val.Extension:
return val.Extension("int", INT_T, self.v)


FLOAT_T = tys.Opaque(
extension="arithmetic.float.types",
id="float64",
args=[],
bound=tys.TypeBound.Copyable,
)


@dataclass
class FloatVal(val.ExtensionValue):
v: float

def to_value(self) -> val.Extension:
return val.Extension("float", FLOAT_T, self.v)


@dataclass
class LogicOps(Custom):
extension: tys.ExtensionId = "logic"
Expand All @@ -51,7 +67,7 @@ class NotDef(LogicOps):
op_name: str = "Not"
signature: tys.FunctionType = _NotSig

def __call__(self, a: Wire) -> Command:
def __call__(self, a: ComWire) -> Command:
return super().__call__(a)


Expand All @@ -72,12 +88,28 @@ class OneQbGate(QuantumOps):
num_out: int = 1
signature: tys.FunctionType = _OneQbSig

def __call__(self, q: Wire) -> Command:
def __call__(self, q: ComWire) -> Command:
return super().__call__(q)


H = OneQbGate("H")


_TwoQbSig = tys.FunctionType.endo([tys.Qubit] * 2)


@dataclass
class TwoQbGate(QuantumOps):
op_name: str
num_out: int = 2
signature: tys.FunctionType = _TwoQbSig

def __call__(self, q0: ComWire, q1: ComWire) -> Command:
return super().__call__(q0, q1)


CX = TwoQbGate("CX")

_MeasSig = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool])


Expand All @@ -87,12 +119,27 @@ class MeasureDef(QuantumOps):
num_out: int = 2
signature: tys.FunctionType = _MeasSig

def __call__(self, q: Wire) -> Command:
def __call__(self, q: ComWire) -> Command:
return super().__call__(q)


Measure = MeasureDef()

_RzSig = tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit])


@dataclass
class RzDef(QuantumOps):
op_name: str = "Rz"
num_out: int = 1
signature: tys.FunctionType = _RzSig

def __call__(self, q: ComWire, fl_wire: ComWire) -> Command:
return super().__call__(q, fl_wire)


Rz = RzDef()


@dataclass
class IntOps(Custom):
Expand Down
Loading

0 comments on commit df9b4cc

Please sign in to comment.