Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): add builders for Conditional and TailLoop #1210

Merged
merged 12 commits into from
Jun 21, 2024
16 changes: 8 additions & 8 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass, replace
from dataclasses import dataclass

import hugr._ops as ops

from ._dfg import _DfBase
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, TypeRow, Type
from ._tys import TypeRow, Type
import hugr._val as val


Expand All @@ -16,7 +16,7 @@ def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

def set_single_succ_outputs(self, *outputs: Wire) -> None:
u = self.add_load_const(val.Unit)
u = self.load(val.Unit)
self.set_outputs(u, *outputs)

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
Expand Down Expand Up @@ -47,7 +47,7 @@ class Cfg(ParentBuilder[ops.CFG]):
exit: Node

def __init__(self, input_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=[]))
root_op = ops.CFG(inputs=input_types)
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, input_types)

Expand All @@ -68,7 +68,7 @@ def new_nested(
) -> Cfg:
new = cls.__new__(cls)
root = hugr.add_node(
ops.CFG(FunctionType(input=input_types, output=[])),
ops.CFG(inputs=input_types),
parent or hugr.root,
)
new._init_impl(hugr, root, input_types)
Expand Down Expand Up @@ -97,6 +97,8 @@ def add_block(self, input_types: TypeRow) -> Block:
)
return new_block

# TODO insert_block
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO now or later?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

later - I'm not convinced it's a common enough use case to do immediately


def add_successor(self, pred: Wire) -> Block:
b = self.add_block(self._nth_outputs(pred))

Expand Down Expand Up @@ -125,6 +127,4 @@ def branch_exit(self, src: Wire) -> None:
raise MismatchedExit(src.node.idx)
else:
self._exit_op._cfg_outputs = out_types
self.parent_op.signature = replace(
self.parent_op.signature, output=out_types
)
self.parent_op._outputs = out_types
110 changes: 110 additions & 0 deletions hugr-py/src/hugr/_cond_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import annotations

from dataclasses import dataclass

import hugr._ops as ops

from ._dfg import _DfBase
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import Sum, TypeRow


class Case(_DfBase[ops.Case]):
_parent_cond: Conditional | None = None

def set_outputs(self, *outputs: Wire) -> None:
super().set_outputs(*outputs)
if self._parent_cond is not None:
self._parent_cond._update_outputs(self._wire_types(outputs))


class ConditionalError(Exception):
pass


@dataclass
class _IfElse(Case):
def __init__(self, case: Case) -> None:
self.hugr = case.hugr
self.parent_node = case.parent_node
self.input_node = case.input_node
self.output_node = case.output_node
self._parent_cond = case._parent_cond

def _parent_conditional(self) -> Conditional:
if self._parent_cond is None:
raise ConditionalError("If must have a parent conditional.")

Check warning on line 36 in hugr-py/src/hugr/_cond_loop.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cond_loop.py#L36

Added line #L36 was not covered by tests
return self._parent_cond


class If(_IfElse):
def add_else(self) -> Else:
return Else(self._parent_conditional().add_case(0))


class Else(_IfElse):
def finish(self) -> Node:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if it were possible to finish an _IfElse with a trivial Else without having to explicitly add the Else.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I thought about this but unfortunately the form of the "trivial else" depends on linearity of types, so instead of making too many assumptions decided to leave it explicit for now.

return self._parent_conditional().parent_node


@dataclass
class Conditional(ParentBuilder[ops.Conditional]):
cases: dict[int, Node | None]

def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None:
root_op = ops.Conditional(sum_ty, other_inputs)
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, len(sum_ty.variant_rows))

def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None:
self.hugr = hugr
self.parent_node = root
self.cases = {i: None for i in range(n_cases)}

@classmethod
def new_nested(
cls,
sum_ty: Sum,
other_inputs: TypeRow,
hugr: Hugr,
parent: ToNode | None = None,
) -> Conditional:
new = cls.__new__(cls)
root = hugr.add_node(
ops.Conditional(sum_ty, other_inputs),
parent or hugr.root,
)
new._init_impl(hugr, root, len(sum_ty.variant_rows))
return new

def _update_outputs(self, outputs: TypeRow) -> None:
if self.parent_op._outputs is None:
self.parent_op._outputs = outputs
else:
if outputs != self.parent_op._outputs:
raise ConditionalError("Mismatched case outputs.")

Check warning on line 85 in hugr-py/src/hugr/_cond_loop.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cond_loop.py#L85

Added line #L85 was not covered by tests

def add_case(self, case_id: int) -> Case:
if case_id not in self.cases:
raise ConditionalError(f"Case {case_id} out of possible range.")
input_types = self.parent_op.nth_inputs(case_id)
new_case = Case.new_nested(
ops.Case(input_types),
self.hugr,
self.parent_node,
)
new_case._parent_cond = self
self.cases[case_id] = new_case.parent_node
return new_case

# TODO insert_case
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO now or later?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

later - I'm not convinced it's a common enough use case to do immediately



@dataclass
class TailLoop(_DfBase[ops.TailLoop]):
def __init__(self, just_inputs: TypeRow, rest: TypeRow) -> None:
root_op = ops.TailLoop(just_inputs, rest)
super().__init__(root_op)

def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None:
self.set_outputs(sum_wire, *rest)
74 changes: 55 additions & 19 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
from typing import (
TYPE_CHECKING,
Iterable,
Sequence,
TypeVar,
)

from typing_extensions import Self

import hugr._ops as ops
import hugr._val as val
from hugr._tys import Type, TypeRow
from hugr._tys import Type, TypeRow, get_first_sum

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire

if TYPE_CHECKING:
from ._cfg import Cfg
from ._cond_loop import Conditional, If, TailLoop

Check warning on line 22 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L22

Added line #L22 was not covered by tests


DP = TypeVar("DP", bound=ops.DfParentOp)
Expand Down Expand Up @@ -72,39 +74,73 @@
def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming)

def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(builder.hugr, self.parent_node)
self._wire_up(mapping[builder.parent_node], args)
return mapping[builder.parent_node]

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node)
self._wire_up(mapping[dfg.parent_node], args)
return mapping[dfg.parent_node]
return self._insert_nested_impl(dfg, *args)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this delegated to an internal method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the internal method is generic, the surface ones are concrete for cases we know
about. Allows special casing in future if necessary.


def add_nested(
self,
*args: Wire,
) -> Dfg:
from ._dfg import Dfg

input_types = [self._get_dataflow_type(w) for w in args]

parent_op = ops.DFG(list(input_types))
parent_op = ops.DFG(self._wire_types(args))
dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node)
self._wire_up(dfg.parent_node, args)
return dfg

def _wire_types(self, args: Iterable[Wire]) -> TypeRow:
return [self._get_dataflow_type(w) for w in args]

def add_cfg(
self,
input_types: TypeRow,
*args: Wire,
) -> Cfg:
from ._cfg import Cfg

cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node)
cfg = Cfg.new_nested(self._wire_types(args), self.hugr, self.parent_node)
self._wire_up(cfg.parent_node, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node)
self._wire_up(mapping[cfg.parent_node], args)
return mapping[cfg.parent_node]
return self._insert_nested_impl(cfg, *args)

Check warning on line 110 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L110

Added line #L110 was not covered by tests

def add_conditional(self, cond: Wire, *args: Wire) -> Conditional:
from ._cond_loop import Conditional

args = (cond, *args)
(sum_, other_inputs) = get_first_sum(self._wire_types(args))
cond = Conditional.new_nested(sum_, other_inputs, self.hugr, self.parent_node)
self._wire_up(cond.parent_node, args)
return cond

def insert_conditional(self, cond: Conditional, *args: Wire) -> Node:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test?

return self._insert_nested_impl(cond, *args)

def add_if(self, cond: Wire, *args: Wire) -> If:
from ._cond_loop import If

conditional = self.add_conditional(cond, *args)
return If(conditional.add_case(1))

def add_tail_loop(
self, just_inputs: Sequence[Wire], rest: Sequence[Wire]
) -> TailLoop:
from ._cond_loop import TailLoop

just_input_types = self._wire_types(just_inputs)
rest_types = self._wire_types(rest)
parent_op = ops.TailLoop(just_input_types, rest_types)
tl = TailLoop.new_nested(parent_op, self.hugr, self.parent_node)
self._wire_up(tl.parent_node, (*just_inputs, *rest))
return tl

def insert_tail_loop(self, tl: TailLoop, *args: Wire) -> Node:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test?

return self._insert_nested_impl(tl, *args)

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
Expand All @@ -117,22 +153,22 @@
def add_const(self, val: val.Value) -> Node:
return self.hugr.add_const(val, self.parent_node)

def load_const(self, const_node: ToNode) -> Node:
const_op = self.hugr._get_typed_op(const_node, ops.Const)
def load(self, const: ToNode | val.Value) -> Node:
if isinstance(const, val.Value):
const = self.add_const(const)
const_op = self.hugr._get_typed_op(const, ops.Const)
load_op = ops.LoadConst(const_op.val.type_())

load = self.add(load_op())
self.hugr.add_link(const_node.out_port(), load.inp(0))
self.hugr.add_link(const.out_port(), load.inp(0))

return load

def add_load_const(self, val: val.Value) -> Node:
return self.load_const(self.add_const(val))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
op.set_in_types(tys)
return tys

def _get_dataflow_type(self, wire: Wire) -> Type:
port = wire.out_port()
Expand Down
Loading
Loading