diff --git a/hugr-py/src/hugr/dfg.py b/hugr-py/src/hugr/dfg.py index 7083b08c3..ba8e3b850 100644 --- a/hugr-py/src/hugr/dfg.py +++ b/hugr-py/src/hugr/dfg.py @@ -1,6 +1,8 @@ +"""Builder for HUGR datflow graphs.""" + from __future__ import annotations -from dataclasses import dataclass, replace +from dataclasses import dataclass, field, replace from typing import ( TYPE_CHECKING, Iterable, @@ -28,10 +30,20 @@ @dataclass() class _DfBase(ParentBuilder[DP]): - hugr: Hugr + """Base class for dataflow graph builders. + + Args: + parent_op: The parent operation of the dataflow graph. + """ + + #: The Hugr instance that the builder is using. + hugr: Hugr = field(repr=False) + #: The parent node of the dataflow graph. parent_node: Node - input_node: Node - output_node: Node + #: The input node of the dataflow graph. + input_node: Node = field(repr=False) + #: The output node of the dataflow graph. + output_node: Node = field(repr=False) def __init__(self, parent_op: DP) -> None: self.hugr = Hugr(parent_op) @@ -50,6 +62,20 @@ def _init_io_nodes(self, parent_op: DP): def new_nested( cls, parent_op: DP, hugr: Hugr, parent: ToNode | None = None ) -> Self: + """Start building a dataflow graph nested inside a larger hugr. + + Args: + parent_op: The parent operation of the new dataflow graph. + hugr: The host hugr instance to build the dataflow graph in. + parent: Parent of new dataflow graph's root node, defaults to the + host hugr root. + + Example: + >>> hugr = Hugr() + >>> dfg = Dfg.new_nested(ops.DFG([]), hugr) + >>> dfg.parent_node + Node(1) + """ new = cls.__new__(cls) new.hugr = hugr @@ -64,15 +90,49 @@ def _output_op(self) -> ops.Output: return self.hugr._get_typed_op(self.output_node, ops.Output) def inputs(self) -> list[OutPort]: + """List all incoming wires, output ports of the input node. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> dfg.inputs() + [OutPort(Node(1), 0)] + """ return [self.input_node.out(i) for i in range(len(self._input_op().types))] def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node: + """Add a dataflow operation to the graph, wiring in input ports. + + Args: + op: The operation to add. + args: The input wires to the operation. + + Returns: + The node holding the new operation. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> dfg.add_op(ops.Noop(), dfg.inputs()[0]) + Node(3) + """ new_n = self.hugr.add_node(op, self.parent_node) self._wire_up(new_n, args) return replace(new_n, _num_out_ports=op.num_out) def add(self, com: ops.Command) -> Node: + """Add a command (holding a dataflow operation and the incoming wires) + to the graph. + + Args: + com: The command to add. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> (i,) = dfg.inputs() + >>> dfg.add(ops.Noop()(i)) + Node(3) + + """ return self.add_op(com.op, *com.incoming) def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node: @@ -81,12 +141,42 @@ def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node: return mapping[builder.parent_node] def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: + """Insert a nested dataflow graph into the current graph, wiring in the + inputs. + + Args: + dfg: The dataflow graph to insert. + args: The input wires to the graph. + + Returns: + The root node of the inserted graph. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> dfg2 = Dfg(tys.Bool) + >>> dfg.insert_nested(dfg2, dfg.inputs()[0]) + Node(3) + """ return self._insert_nested_impl(dfg, *args) def add_nested( self, *args: Wire, ) -> Dfg: + """Start building a nested dataflow graph. + + Args: + args: The input wires to the nested DFG. + + Returns: + Builder for new nested dataflow graph. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> dfg2 = dfg.add_nested(dfg.inputs()[0]) + >>> dfg2.parent_node + Node(3) + """ from .dfg import Dfg parent_op = ops.DFG(self._wire_types(args)) @@ -101,6 +191,20 @@ def add_cfg( self, *args: Wire, ) -> Cfg: + """Start building a new CFG nested inside the current dataflow graph. + + Args: + args: The input wires to the new CFG. + + Returns: + Builder for new nested CFG. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> cfg = dfg.add_cfg(dfg.inputs()[0]) + >>> cfg.parent_op + CFG(inputs=[Bool]) + """ from .cfg import Cfg cfg = Cfg.new_nested(self._wire_types(args), self.hugr, self.parent_node) @@ -108,29 +212,121 @@ def add_cfg( return cfg def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: + """Insert a CFG into the current dataflow graph, wiring in the inputs. + + Args: + cfg: The CFG to insert. + args: The input wires to the CFG. + + Returns: + The root node of the inserted CFG. + + Example: + >>> from hugr.cfg import Cfg + >>> dfg = Dfg(tys.Bool) + >>> cfg = Cfg(tys.Bool) + >>> dfg.insert_cfg(cfg, dfg.inputs()[0]) + Node(3) + """ return self._insert_nested_impl(cfg, *args) - def add_conditional(self, cond: Wire, *args: Wire) -> Conditional: + def add_conditional(self, cond_wire: Wire, *args: Wire) -> Conditional: + """Start building a new conditional nested inside the current dataflow + graph. + + Args: + cond_wire: The wire holding the wire (of Sum type) to branch the + conditional on. + args: Remaining input wires to the conditional. + + Returns: + Builder for new nested conditional. + + Example: + >>> dfg = Dfg(tys.Bool, tys.Unit) + >>> (cond, unit) = dfg.inputs() + >>> cond = dfg.add_conditional(cond, unit) + >>> cond.parent_node + Node(3) + """ from .cond_loop import Conditional - args = (cond, *args) + args = (cond_wire, *args) (sum_, other_inputs) = tys.get_first_sum(self._wire_types(args)) - cond = Conditional.new_nested(sum_, other_inputs, self.hugr, self.parent_node) - self._wire_up(cond.parent_node, args) - return cond - - def insert_conditional(self, cond: Conditional, *args: Wire) -> Node: - return self._insert_nested_impl(cond, *args) + cond_wire = Conditional.new_nested( + sum_, other_inputs, self.hugr, self.parent_node + ) + self._wire_up(cond_wire.parent_node, args) + return cond_wire - def add_if(self, cond: Wire, *args: Wire) -> If: + def insert_conditional( + self, cond: Conditional, cond_wire: Wire, *args: Wire + ) -> Node: + """Insert a conditional into the current dataflow graph, wiring in the + inputs. + + Args: + cond: The conditional to insert. + cond_wire: The wire holding the wire (of Sum type) to branch the Conditional on. + args: Remaining input wires to the conditional. + + Returns: + The root node of the inserted conditional. + + Example: + >>> from hugr.cond_loop import Conditional + >>> cond = Conditional(tys.Bool, []) + >>> dfg = Dfg(tys.Bool) + >>> cond_n = dfg.insert_conditional(cond, dfg.inputs()[0]) + >>> dfg.hugr[cond_n].op + Conditional(sum_ty=Bool, other_inputs=[]) + """ + return self._insert_nested_impl(cond, *(cond_wire, *args)) + + def add_if(self, cond_wire: Wire, *args: Wire) -> If: + """Start building a new if block nested inside the current dataflow + graph. + + Args: + cond_wire: The wire holding the Bool wire to branch the If on. + args: Remaining input wires to the If (and subsequent Else). + + Returns: + Builder for new nested If. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> (cond,) = dfg.inputs() + >>> if_ = dfg.add_if(cond, cond) + >>> if_.parent_op + Case(inputs=[Bool]) + """ from .cond_loop import If - conditional = self.add_conditional(cond, *args) + conditional = self.add_conditional(cond_wire, *args) return If(conditional.add_case(1)) def add_tail_loop( self, just_inputs: Sequence[Wire], rest: Sequence[Wire] ) -> TailLoop: + """Start building a new tail loop nested inside the current dataflow + graph. + + Args: + just_inputs: input wires for types that are only inputs to the loop body. + rest: input wires for types that are inputs and outputs of the loop + body. + + Returns: + Builder for new nested TailLoop. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> (cond,) = dfg.inputs() + >>> tl = dfg.add_tail_loop([cond], [cond]) + >>> tl.parent_op + TailLoop(just_inputs=[Bool], rest=[Bool]) + """ from .cond_loop import TailLoop just_input_types = self._wire_types(just_inputs) @@ -143,20 +339,97 @@ def add_tail_loop( def insert_tail_loop( self, tl: TailLoop, just_inputs: Sequence[Wire], rest: Sequence[Wire] ) -> Node: + """Insert a tail loop into the current dataflow graph, wiring in the + inputs. + + Args: + tl: The tail loop to insert. + just_inputs: input wires for types that are only inputs to the loop body. + rest: input wires for types that are inputs and outputs of the loop + body. + + Returns: + The root node of the inserted tail loop. + + Example: + >>> from hugr.cond_loop import TailLoop + >>> tl = TailLoop([tys.Bool], [tys.Bool]) + >>> dfg = Dfg(tys.Bool) + >>> (b,) = dfg.inputs() + >>> tl_n = dfg.insert_tail_loop(tl, [b], [b]) + >>> dfg.hugr[tl_n].op + TailLoop(just_inputs=[Bool], rest=[Bool]) + """ return self._insert_nested_impl(tl, *(*just_inputs, *rest)) def set_outputs(self, *args: Wire) -> None: + """Set the outputs of the dataflow graph. + Connects wires to the output node. + + Args: + args: Wires to connect to the output node. + + Example: + + >>> dfg = Dfg(tys.Bool) + >>> dfg.set_outputs(dfg.inputs()[0]) # connect input to output + """ self._wire_up(self.output_node, args) self.parent_op._set_out_types(self._output_op().types) def add_state_order(self, src: Node, dst: Node) -> None: + """Add a state order link between two nodes. + + Args: + src: The source node. + dst: The destination node. + + Examples: + >>> df = dfg.Dfg() + >>> df.add_state_order(df.input_node, df.output_node) + >>> list(df.hugr.outgoing_order_links(df.input_node)) + [Node(2)] + """ # adds edge to the right of all existing edges self.hugr.add_link(src.out(-1), dst.inp(-1)) def add_const(self, val: val.Value) -> Node: + """Add a static constant to the graph. + + Args: + val: The value to add. + + Returns: + The node holding the :class:`Const ` operation. + + Example: + >>> dfg = Dfg() + >>> const_n = dfg.add_const(val.TRUE) + >>> dfg.hugr[const_n].op + Const(TRUE) + """ return self.hugr.add_const(val, self.parent_node) def load(self, const: ToNode | val.Value) -> Node: + """Load a constant into the graph as a dataflow value. + + Args: + const: The constant to load, either a Value that will be added as a + child Const node then loaded, or a node corresponding to an existing + Const. + + Returns: + The node holding the :class:`LoadConst ` + operation. + + Example: + >>> dfg = Dfg() + >>> const_n = dfg.load(val.TRUE) + >>> len(dfg.hugr) # parent, input, output, const, load + 5 + >>> dfg.hugr[const_n].op + LoadConst(Bool) + """ if isinstance(const, val.Value): const = self.add_const(const) const_op = self.hugr._get_typed_op(const, ops.Const) @@ -174,6 +447,21 @@ def call( instantiation: tys.FunctionType | None = None, type_args: Sequence[tys.TypeArg] | None = None, ) -> Node: + """Call a static function in the graph. + See :class:`Call ` for more on how polymorphic functions + are handled. + + Args: + func: The node corresponding to the function definition/declaration to call. + args: The input wires to the function call. + instantiation: The concrete function type to call (needed if polymorphic). + type_args: The type arguments for the function (needed if + polymorphic). + + Returns: + The node holding the :class:`Call ` operation. + """ + signature = self._fn_sig(func) call_op = ops.Call(signature, instantiation, type_args) call_n = self.hugr.add_node(call_op, self.parent_node, call_op.num_out) @@ -189,6 +477,17 @@ def load_function( instantiation: tys.FunctionType | None = None, type_args: Sequence[tys.TypeArg] | None = None, ) -> Node: + """Load a static function in to the graph as a higher-order value. + + Args: + func: The node corresponding to the function definition/declaration to load. + instantiation: The concrete function type to load (needed if polymorphic). + type_args: The type arguments for the function (needed if + polymorphic). + + Returns: + The node holding the :class:`LoadFunc ` operation. + """ signature = self._fn_sig(func) load_op = ops.LoadFunc(signature, instantiation, type_args) load_n = self.hugr.add_node(load_op, self.parent_node) @@ -231,6 +530,20 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> tys.Type: class Dfg(_DfBase[ops.DFG]): + """Builder for a simple nested Dataflow graph, with root node of type + :class:`DFG `. + + Args: + input_types: The input types of the the dataflow graph. Output types are + calculated by propagating types through the graph. + extension_delta: The extension delta of the graph. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> dfg.parent_op + DFG(inputs=[Bool]) + """ + def __init__( self, *input_types: tys.Type, extension_delta: tys.ExtensionSet | None = None ) -> None: @@ -239,6 +552,7 @@ def __init__( def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: + """Find the ancestor of `tgt` that is a sibling of `src`, if one exists.""" src_parent = h[src].parent while (tgt_parent := h[tgt].parent) is not None: