-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(hugr-py): add builders for Conditional and TailLoop #1210
Changes from all commits
881d9c3
d0bc866
b1665d9
532a82b
d0f1ec5
e2d6644
9863375
7c56560
939772c
c8105bb
9166f7f
4b5cd90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
|
||
import hugr._ops as ops | ||
|
||
from ._dfg import _DfBase | ||
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire | ||
from ._tys import Sum, TypeRow | ||
|
||
|
||
class Case(_DfBase[ops.Case]): | ||
_parent_cond: Conditional | None = None | ||
|
||
def set_outputs(self, *outputs: Wire) -> None: | ||
super().set_outputs(*outputs) | ||
if self._parent_cond is not None: | ||
self._parent_cond._update_outputs(self._wire_types(outputs)) | ||
|
||
|
||
class ConditionalError(Exception): | ||
pass | ||
|
||
|
||
@dataclass | ||
class _IfElse(Case): | ||
def __init__(self, case: Case) -> None: | ||
self.hugr = case.hugr | ||
self.parent_node = case.parent_node | ||
self.input_node = case.input_node | ||
self.output_node = case.output_node | ||
self._parent_cond = case._parent_cond | ||
|
||
def _parent_conditional(self) -> Conditional: | ||
if self._parent_cond is None: | ||
raise ConditionalError("If must have a parent conditional.") | ||
return self._parent_cond | ||
|
||
|
||
class If(_IfElse): | ||
def add_else(self) -> Else: | ||
return Else(self._parent_conditional().add_case(0)) | ||
|
||
|
||
class Else(_IfElse): | ||
def finish(self) -> Node: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice if it were possible to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I thought about this but unfortunately the form of the "trivial else" depends on linearity of types, so instead of making too many assumptions decided to leave it explicit for now. |
||
return self._parent_conditional().parent_node | ||
|
||
|
||
@dataclass | ||
class Conditional(ParentBuilder[ops.Conditional]): | ||
cases: dict[int, Node | None] | ||
|
||
def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None: | ||
root_op = ops.Conditional(sum_ty, other_inputs) | ||
hugr = Hugr(root_op) | ||
self._init_impl(hugr, hugr.root, len(sum_ty.variant_rows)) | ||
|
||
def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None: | ||
self.hugr = hugr | ||
self.parent_node = root | ||
self.cases = {i: None for i in range(n_cases)} | ||
|
||
@classmethod | ||
def new_nested( | ||
cls, | ||
sum_ty: Sum, | ||
other_inputs: TypeRow, | ||
hugr: Hugr, | ||
parent: ToNode | None = None, | ||
) -> Conditional: | ||
new = cls.__new__(cls) | ||
root = hugr.add_node( | ||
ops.Conditional(sum_ty, other_inputs), | ||
parent or hugr.root, | ||
) | ||
new._init_impl(hugr, root, len(sum_ty.variant_rows)) | ||
return new | ||
|
||
def _update_outputs(self, outputs: TypeRow) -> None: | ||
if self.parent_op._outputs is None: | ||
self.parent_op._outputs = outputs | ||
else: | ||
if outputs != self.parent_op._outputs: | ||
raise ConditionalError("Mismatched case outputs.") | ||
|
||
def add_case(self, case_id: int) -> Case: | ||
if case_id not in self.cases: | ||
raise ConditionalError(f"Case {case_id} out of possible range.") | ||
input_types = self.parent_op.nth_inputs(case_id) | ||
new_case = Case.new_nested( | ||
ops.Case(input_types), | ||
self.hugr, | ||
self.parent_node, | ||
) | ||
new_case._parent_cond = self | ||
self.cases[case_id] = new_case.parent_node | ||
return new_case | ||
|
||
# TODO insert_case | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO now or later? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. later - I'm not convinced it's a common enough use case to do immediately |
||
|
||
|
||
@dataclass | ||
class TailLoop(_DfBase[ops.TailLoop]): | ||
def __init__(self, just_inputs: TypeRow, rest: TypeRow) -> None: | ||
root_op = ops.TailLoop(just_inputs, rest) | ||
super().__init__(root_op) | ||
|
||
def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None: | ||
self.set_outputs(sum_wire, *rest) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,20 +4,22 @@ | |
from typing import ( | ||
TYPE_CHECKING, | ||
Iterable, | ||
Sequence, | ||
TypeVar, | ||
) | ||
|
||
from typing_extensions import Self | ||
|
||
import hugr._ops as ops | ||
import hugr._val as val | ||
from hugr._tys import Type, TypeRow | ||
from hugr._tys import Type, TypeRow, get_first_sum | ||
|
||
from ._exceptions import NoSiblingAncestor | ||
from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire | ||
|
||
if TYPE_CHECKING: | ||
from ._cfg import Cfg | ||
from ._cond_loop import Conditional, If, TailLoop | ||
|
||
|
||
DP = TypeVar("DP", bound=ops.DfParentOp) | ||
|
@@ -72,39 +74,73 @@ | |
def add(self, com: ops.Command) -> Node: | ||
return self.add_op(com.op, *com.incoming) | ||
|
||
def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node: | ||
mapping = self.hugr.insert_hugr(builder.hugr, self.parent_node) | ||
self._wire_up(mapping[builder.parent_node], args) | ||
return mapping[builder.parent_node] | ||
|
||
def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: | ||
mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node) | ||
self._wire_up(mapping[dfg.parent_node], args) | ||
return mapping[dfg.parent_node] | ||
return self._insert_nested_impl(dfg, *args) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this delegated to an internal method? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the internal method is generic, the surface ones are concrete for cases we know |
||
|
||
def add_nested( | ||
self, | ||
*args: Wire, | ||
) -> Dfg: | ||
from ._dfg import Dfg | ||
|
||
input_types = [self._get_dataflow_type(w) for w in args] | ||
|
||
parent_op = ops.DFG(list(input_types)) | ||
parent_op = ops.DFG(self._wire_types(args)) | ||
dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node) | ||
self._wire_up(dfg.parent_node, args) | ||
return dfg | ||
|
||
def _wire_types(self, args: Iterable[Wire]) -> TypeRow: | ||
return [self._get_dataflow_type(w) for w in args] | ||
|
||
def add_cfg( | ||
self, | ||
input_types: TypeRow, | ||
*args: Wire, | ||
) -> Cfg: | ||
from ._cfg import Cfg | ||
|
||
cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node) | ||
cfg = Cfg.new_nested(self._wire_types(args), self.hugr, self.parent_node) | ||
self._wire_up(cfg.parent_node, args) | ||
return cfg | ||
|
||
def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: | ||
mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node) | ||
self._wire_up(mapping[cfg.parent_node], args) | ||
return mapping[cfg.parent_node] | ||
return self._insert_nested_impl(cfg, *args) | ||
|
||
def add_conditional(self, cond: Wire, *args: Wire) -> Conditional: | ||
from ._cond_loop import Conditional | ||
|
||
args = (cond, *args) | ||
(sum_, other_inputs) = get_first_sum(self._wire_types(args)) | ||
cond = Conditional.new_nested(sum_, other_inputs, self.hugr, self.parent_node) | ||
self._wire_up(cond.parent_node, args) | ||
return cond | ||
|
||
def insert_conditional(self, cond: Conditional, *args: Wire) -> Node: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a test? |
||
return self._insert_nested_impl(cond, *args) | ||
|
||
def add_if(self, cond: Wire, *args: Wire) -> If: | ||
from ._cond_loop import If | ||
|
||
conditional = self.add_conditional(cond, *args) | ||
return If(conditional.add_case(1)) | ||
|
||
def add_tail_loop( | ||
self, just_inputs: Sequence[Wire], rest: Sequence[Wire] | ||
) -> TailLoop: | ||
from ._cond_loop import TailLoop | ||
|
||
just_input_types = self._wire_types(just_inputs) | ||
rest_types = self._wire_types(rest) | ||
parent_op = ops.TailLoop(just_input_types, rest_types) | ||
tl = TailLoop.new_nested(parent_op, self.hugr, self.parent_node) | ||
self._wire_up(tl.parent_node, (*just_inputs, *rest)) | ||
return tl | ||
|
||
def insert_tail_loop(self, tl: TailLoop, *args: Wire) -> Node: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a test? |
||
return self._insert_nested_impl(tl, *args) | ||
|
||
def set_outputs(self, *args: Wire) -> None: | ||
self._wire_up(self.output_node, args) | ||
|
@@ -117,22 +153,22 @@ | |
def add_const(self, val: val.Value) -> Node: | ||
return self.hugr.add_const(val, self.parent_node) | ||
|
||
def load_const(self, const_node: ToNode) -> Node: | ||
const_op = self.hugr._get_typed_op(const_node, ops.Const) | ||
def load(self, const: ToNode | val.Value) -> Node: | ||
if isinstance(const, val.Value): | ||
const = self.add_const(const) | ||
const_op = self.hugr._get_typed_op(const, ops.Const) | ||
load_op = ops.LoadConst(const_op.val.type_()) | ||
|
||
load = self.add(load_op()) | ||
self.hugr.add_link(const_node.out_port(), load.inp(0)) | ||
self.hugr.add_link(const.out_port(), load.inp(0)) | ||
|
||
return load | ||
|
||
def add_load_const(self, val: val.Value) -> Node: | ||
return self.load_const(self.add_const(val)) | ||
|
||
def _wire_up(self, node: Node, ports: Iterable[Wire]): | ||
def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow: | ||
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] | ||
if isinstance(op := self.hugr[node].op, ops.PartialOp): | ||
op.set_in_types(tys) | ||
return tys | ||
|
||
def _get_dataflow_type(self, wire: Wire) -> Type: | ||
port = wire.out_port() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO now or later?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
later - I'm not convinced it's a common enough use case to do immediately