Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update cirq.decompose protocol to perform a DFS instead of a BFS on the decomposed OP-TREE #6116

Merged
merged 3 commits into from
Jun 1, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 57 additions & 110 deletions cirq-core/cirq/protocols/decompose_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
overload,
Expand Down Expand Up @@ -128,6 +129,54 @@ def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> DecomposeResult:
pass


def _try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult:
if decomposer is None or not isinstance(val, ops.Operation):
return None
return decomposer(val)


def _decompose_dfs(item: Any, **kwargs) -> Iterator['cirq.Operation']:
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
from cirq.circuits import CircuitOperation, FrozenCircuit

preserve_structure = kwargs.get('preserve_structure', False)
keep = kwargs.get('keep', None)
if isinstance(item, ops.Operation):
item_untagged = item.untagged
if isinstance(item_untagged, CircuitOperation) and preserve_structure:
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
new_fc = FrozenCircuit(_decompose_dfs(item_untagged.circuit, **kwargs))
yield item_untagged.replace(circuit=new_fc).with_tags(*item.tags)
return
if keep is not None and keep(item):
yield item
return

decomposed = _try_op_decomposer(item, kwargs.get('intercepting_decomposer', None))

if decomposed is NotImplemented or decomposed is None:
decomposed = decompose_once(item, default=None)

if decomposed is NotImplemented or decomposed is None:
decomposed = _try_op_decomposer(item, kwargs.get('fallback_decomposer', None))

if decomposed is NotImplemented or decomposed is None:
if not isinstance(item, ops.Operation) and isinstance(item, Iterable):
decomposed = item

if decomposed is NotImplemented or decomposed is None:
on_stuck_raise = kwargs.get('on_stuck_raise', None)
if keep is not None and on_stuck_raise is not None:
if isinstance(on_stuck_raise, Exception):
raise on_stuck_raise
elif callable(on_stuck_raise):
error = on_stuck_raise(item)
if error is not None:
raise error
yield item
else:
for val in ops.flatten_to_ops(decomposed):
yield from _decompose_dfs(val, **kwargs)


def decompose(
val: Any,
*,
Expand Down Expand Up @@ -200,55 +249,14 @@ def decompose(
"acceptable to keep."
)

if preserve_structure:
return _decompose_preserving_structure(
val,
intercepting_decomposer=intercepting_decomposer,
fallback_decomposer=fallback_decomposer,
keep=keep,
on_stuck_raise=on_stuck_raise,
)

def try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult:
if decomposer is None or not isinstance(val, ops.Operation):
return None
return decomposer(val)

output = []
queue: List[Any] = [val]
while queue:
item = queue.pop(0)
Copy link
Collaborator

@NoureldinYosri NoureldinYosri Jun 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't replacing queue.pop(0) with queue.pop() and queue[:0] = x with queue.extend(reversed(x)) do the trick ?

dfs order is just stack order so replacing the lines (and renaming queue to stack) should do the job.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this would be sufficient for replacing the BFS order to a DFS order. But what we really is not just replacing the output ordering but also make sure that the items in the recursive OP-TREE are iterated upon in true DFS based ordering. So, in a follow up PR (since it requires a bunch of changes to the interfaces), I want the following test to pass:

def test_decompose_recursive_dfs():
    mock_qm = mock.Mock()

    class RecursiveDecompose(cirq.Gate):
        def __init__(self, recurse: bool = True):
            self.recurse = recurse

        def _num_qubits_(self) -> int:
            return 1

        def _decompose_(self, qubits):
            mock_qm.qalloc(self.recurse)
            yield RecursiveDecompose(recurse=False).on(*qubits) if self.recurse else cirq.Z(*qubits)
            mock_qm.qfree(self.recurse)

    gate = RecursiveDecompose()
    _ = cirq.decompose(gate.on(cirq.NamedQubit("q")))
    expected_calls = [
        mock.call.qalloc(True),
        mock.call.qalloc(False),
        mock.call.qfree(False),
        mock.call.qfree(True),
    ]
    mock_qm.assert_has_calls(expected_calls, any_order=False)

If you simply use a stack instead of a queue but continue the with the iterative approach; in the first step of iteration you'll end up yielding all the elements of RecursiveDecompose(True)._decompose_() instead of getting only the first value from the generator and then recursively decomposing it before moving on to the next statement.

If we get a list of all elements from RecursiveDecompose(True)._decompose_() in a single shot and insert them in a stack; the output ordering of decomposed operations would follow the DFS ordering but we would be processing all children of a node at once; instead of processing them one by one.

We could potentially get around this constraint by adding more if/else blocks and checking whether If the operation on top of the stack is an OP-TREE; only yield the first value from the generator and then push the modified OP-TREE back in the stack along with the first child. Then pop the first child and continue to decompose. But I think the recursive implementation here is a lot more intuitive; so unless there are significant performance concerns with using the recursive approach; I'd say let's make the change to a recursive approach now and we can always optimize back to the iterative approach later.

What do you think? I can also prototype the iterative approach directly if you have a strong preference.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flattening the OP_TREE and pushing it to the stack in reverse order is what we need

stack.extend(reversed(cirq.flatten_to_ops(cirq.decompose_once(val))))

I'm okay either way. it's just that it seems that modifying the iterative approach is a lot less work

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I highlighted in my previous comment, flattening the OP-TREE will not work in this case. My argument will become more clear in my next follow up PR. I'll tag you once it's out and we can come back to the iterative vs recursive discussion at that point.

if isinstance(item, ops.Operation) and keep is not None and keep(item):
output.append(item)
continue

decomposed = try_op_decomposer(item, intercepting_decomposer)

if decomposed is NotImplemented or decomposed is None:
decomposed = decompose_once(item, default=None)

if decomposed is NotImplemented or decomposed is None:
decomposed = try_op_decomposer(item, fallback_decomposer)

if decomposed is not NotImplemented and decomposed is not None:
queue[:0] = ops.flatten_to_ops(decomposed)
continue

if not isinstance(item, ops.Operation) and isinstance(item, Iterable):
queue[:0] = ops.flatten_to_ops(item)
continue

if keep is not None and on_stuck_raise is not None:
if isinstance(on_stuck_raise, Exception):
raise on_stuck_raise
elif callable(on_stuck_raise):
error = on_stuck_raise(item)
if error is not None:
raise error

output.append(item)

return output
kwargs = {
'intercepting_decomposer': intercepting_decomposer,
'fallback_decomposer': fallback_decomposer,
'keep': keep,
'on_stuck_raise': on_stuck_raise,
'preserve_structure': preserve_structure,
}
return [*_decompose_dfs(val, **kwargs)]


# pylint: disable=function-redefined
Expand Down Expand Up @@ -383,65 +391,4 @@ def _try_decompose_into_operations_and_qubits(
qid_shape_dict[q] = max(qid_shape_dict[q], level)
qubits = sorted(qubit_set)
return result, qubits, tuple(qid_shape_dict[q] for q in qubits)

return None, (), ()


def _decompose_preserving_structure(
val: Any,
*,
intercepting_decomposer: Optional[OpDecomposer] = None,
fallback_decomposer: Optional[OpDecomposer] = None,
keep: Optional[Callable[['cirq.Operation'], bool]] = None,
on_stuck_raise: Union[
None, Exception, Callable[['cirq.Operation'], Optional[Exception]]
] = _value_error_describing_bad_operation,
) -> List['cirq.Operation']:
"""Preserves structure (e.g. subcircuits) while decomposing ops.

This can be used to reduce a circuit to a particular gateset without
increasing its serialization size. See tests for examples.
"""

# This method provides a generated 'keep' to its decompose() calls.
# If the user-provided keep is not set, on_stuck_raise must be unset to
# ensure that failure to decompose does not generate errors.
on_stuck_raise = on_stuck_raise if keep is not None else None

from cirq.circuits import CircuitOperation, FrozenCircuit

visited_fcs = set()

def keep_structure(op: 'cirq.Operation'):
circuit = getattr(op.untagged, 'circuit', None)
if circuit is not None:
return circuit in visited_fcs
if keep is not None and keep(op):
return True

def dps_interceptor(op: 'cirq.Operation'):
if not isinstance(op.untagged, CircuitOperation):
if intercepting_decomposer is None:
return NotImplemented
return intercepting_decomposer(op)

new_fc = FrozenCircuit(
decompose(
op.untagged.circuit,
intercepting_decomposer=dps_interceptor,
fallback_decomposer=fallback_decomposer,
keep=keep_structure,
on_stuck_raise=on_stuck_raise,
)
)
visited_fcs.add(new_fc)
new_co = op.untagged.replace(circuit=new_fc)
return new_co if not op.tags else new_co.with_tags(*op.tags)

return decompose(
val,
intercepting_decomposer=dps_interceptor,
fallback_decomposer=fallback_decomposer,
keep=keep_structure,
on_stuck_raise=on_stuck_raise,
)