diff --git a/hugr-py/src/hugr/build/cond_loop.py b/hugr-py/src/hugr/build/cond_loop.py index 02b046251..7a3a1d732 100644 --- a/hugr-py/src/hugr/build/cond_loop.py +++ b/hugr-py/src/hugr/build/cond_loop.py @@ -5,7 +5,7 @@ from __future__ import annotations from contextlib import AbstractContextManager -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING from typing_extensions import Self @@ -19,6 +19,7 @@ if TYPE_CHECKING: from hugr.hugr.node_port import Node, ToNode, Wire from hugr.tys import TypeRow +import warnings class Case(DfBase[ops.Case]): @@ -104,8 +105,8 @@ class Conditional(ParentBuilder[ops.Conditional], AbstractContextManager): Conditional(sum_ty=Bool, other_inputs=[Qubit]) """ - #: map from case index to node holding the :class:`Case ` - cases: dict[int, Node | None] + #: builders for each case and whether they have been built by the user yet + _case_builders: list[tuple[Case, bool]] = field(default_factory=list) def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None: root_op = ops.Conditional(sum_ty, other_inputs) @@ -115,13 +116,40 @@ def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None: def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None: self.hugr = hugr self.parent_node = root - self.cases = {i: None for i in range(n_cases)} + self._case_builders = [] + + for case_id in range(n_cases): + new_case = Case.new_nested( + ops.Case(self.parent_op.nth_inputs(case_id)), + self.hugr, + self.parent_node, + ) + new_case._parent_cond = self + self._case_builders.append((new_case, False)) + + @property + def cases(self) -> dict[int, Node | None]: + """Map from case index to node holding the :class:`Case `. + + DEPRECATED + """ + # TODO remove in 0.10 + warnings.warn( + "The 'cases' property is deprecated and" + " will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + return { + i: case.parent_node if b else None + for i, (case, b) in enumerate(self._case_builders) + } def __enter__(self) -> Self: return self def __exit__(self, *args) -> None: - if any(c is None for c in self.cases.values()): + if not all(built for _, built in self._case_builders): msg = "All cases must be added before exiting context." raise ConditionalError(msg) return None @@ -185,18 +213,15 @@ def add_case(self, case_id: int) -> Case: >>> with cond.add_case(0) as case:\ case.set_outputs(*case.inputs()) """ - if case_id not in self.cases: + if case_id >= len(self._case_builders): msg = f"Case {case_id} out of possible range." raise ConditionalError(msg) - input_types = self.parent_op.nth_inputs(case_id) - new_case = Case.new_nested( - ops.Case(input_types), - self.hugr, - self.parent_node, - ) - new_case._parent_cond = self - self.cases[case_id] = new_case.parent_node - return new_case + case, built = self._case_builders[case_id] + if built: + msg = f"Case {case_id} already built." + raise ConditionalError(msg) + self._case_builders[case_id] = (case, True) + return case # TODO insert_case diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index 73fd55381..ac4f49f7e 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -134,3 +134,13 @@ def test_complex_tail_loop() -> None: h.set_outputs(*tl[:3]) validate(h.hugr) + + +def test_conditional_bug() -> None: + # bug with case ordering https://github.com/CQCL/hugr/issues/1596 + cond = Conditional(tys.Either([tys.USize()], [tys.Unit]), []) + with cond.add_case(1) as case: + case.set_outputs() + with cond.add_case(0) as case: + case.set_outputs() + validate(cond.hugr)