From 60e8afd40c13dfa96017cae6e43bcf00a2c3e25f Mon Sep 17 00:00:00 2001 From: Noureldin Date: Mon, 7 Nov 2022 18:03:37 +0000 Subject: [PATCH] created decompose protocol and assert_decompose_is_consistent_with_t_complexity (#99) * created decompose protocol and assert_decompose_is_consistent_with_t_complexity * remove print * removed decompose * update decompose * update decompose * added interception and fallback decomposers to decompose_once_into_operations * remove unused import Co-authored-by: Tanuj Khattar --- cirq_qubitization/__init__.py | 1 + cirq_qubitization/decompose_protocol.py | 92 +++++++++++++++++++ cirq_qubitization/decompose_protocol_test.py | 27 ++++++ cirq_qubitization/t_complexity_protocol.py | 10 +- .../t_complexity_protocol_test.py | 2 + cirq_qubitization/testing.py | 16 +++- cirq_qubitization/testing_test.py | 40 ++++++++ 7 files changed, 182 insertions(+), 6 deletions(-) create mode 100644 cirq_qubitization/decompose_protocol.py create mode 100644 cirq_qubitization/decompose_protocol_test.py diff --git a/cirq_qubitization/__init__.py b/cirq_qubitization/__init__.py index 280bed42b..55055e59b 100644 --- a/cirq_qubitization/__init__.py +++ b/cirq_qubitization/__init__.py @@ -17,6 +17,7 @@ from cirq_qubitization.qrom import QROM from cirq_qubitization.select_swap_qroam import SelectSwapQROM from cirq_qubitization.t_complexity_protocol import TComplexity, t_complexity +from cirq_qubitization.decompose_protocol import decompose_once_into_operations from cirq_qubitization.select_swap_qroam import SelectSwapQROM from cirq_qubitization.programmable_rotation_gate_array import ( ProgrammableRotationGateArray, diff --git a/cirq_qubitization/decompose_protocol.py b/cirq_qubitization/decompose_protocol.py new file mode 100644 index 000000000..8bbb0b098 --- /dev/null +++ b/cirq_qubitization/decompose_protocol.py @@ -0,0 +1,92 @@ +from typing import Any, Callable, Optional, Tuple + +import cirq +from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits + + +DecomposeResult = Optional[Tuple[cirq.Operation, ...]] +OpDecomposer = Callable[[Any], DecomposeResult] + +_FREDKIN_GATESET = cirq.Gateset(cirq.FREDKIN, unroll_circuit_op=False) + + +def _fredkin(qubits: cirq.Qid) -> cirq.OP_TREE: + """Decomposition with 7 T and 10 clifford operations from https://arxiv.org/abs/1308.4134""" + c, t1, t2 = qubits + yield [cirq.CNOT(t2, t1)] + yield [cirq.CNOT(c, t1), cirq.H(t2)] + yield [cirq.T(c), cirq.T(t1) ** -1, cirq.T(t2)] + yield [cirq.CNOT(t2, t1)] + yield [cirq.CNOT(c, t2), cirq.T(t1)] + yield [cirq.CNOT(c, t1), cirq.T(t2) ** -1] + yield [cirq.T(t1) ** -1, cirq.CNOT(c, t2)] + yield [cirq.CNOT(t2, t1)] + yield [cirq.T(t1), cirq.H(t2)] + yield [cirq.CNOT(t2, t1)] + + +def _try_decompose_from_known_decompositions(val: Any) -> DecomposeResult: + """Returns a flattened decomposition of the object into operations, if possible. + + Args: + val: The object to decompose. + + Returns: + A flattened decomposition of `val` if it's a gate or operation with a known decomposition. + """ + known_decompositions = [(_FREDKIN_GATESET, _fredkin)] + if not isinstance(val, (cirq.Gate, cirq.Operation)): + return None + + classical_controls = None + if isinstance(val, cirq.ClassicallyControlledOperation): + classical_controls = val.classical_controls + val = val.without_classical_controls() + + if isinstance(val, cirq.Operation): + qubits = val.qubits + else: + qubits = cirq.LineQid.for_gate(val) + + for gateset, decomposer in known_decompositions: + if val in gateset: + decomposition = cirq.flatten_op_tree(decomposer(qubits)) + if classical_controls is not None: + return tuple(op.with_classical_controls(classical_controls) for op in decomposition) + else: + return tuple(decomposition) + return None + + +def decompose_once_into_operations( + val: Any, + intercepting_decomposer: Optional[OpDecomposer] = _try_decompose_from_known_decompositions, + fallback_decomposer: Optional[OpDecomposer] = None, +) -> DecomposeResult: + """Decomposes a value into operations, if possible. + + Args: + val: The value to decompose into operations. + intercepting_decomposer: An optional method that is called before the + default decomposer (the value's `_decompose_` method). If + `intercepting_decomposer` is specified and returns a result that + isn't `NotImplemented` or `None`, that result is used. Otherwise the + decomposition falls back to the default decomposer. + fallback_decomposer: An optional decomposition that used after the + `intercepting_decomposer` and the default decomposer (the value's + `_decompose_` method) both fail. + Returns: + A tuple of operations if decomposition succeeds. + """ + decomposers = ( + intercepting_decomposer, + lambda x: _try_decompose_into_operations_and_qubits(x)[0], + fallback_decomposer, + ) + for decomposer in decomposers: + if decomposer is None: + continue + res = decomposer(val) + if res is not None: + return res + return None diff --git a/cirq_qubitization/decompose_protocol_test.py b/cirq_qubitization/decompose_protocol_test.py new file mode 100644 index 000000000..22bacfa30 --- /dev/null +++ b/cirq_qubitization/decompose_protocol_test.py @@ -0,0 +1,27 @@ +import cirq +import numpy as np +from cirq_qubitization.decompose_protocol import decompose_once_into_operations, _fredkin + + +def test_fredkin_unitary(): + c, t1, t2 = cirq.LineQid.for_gate(cirq.FREDKIN) + np.testing.assert_allclose( + cirq.Circuit(_fredkin((c, t1, t2))).unitary(), + cirq.unitary(cirq.FREDKIN(c, t1, t2)), + atol=1e-8, + ) + + +def test_decompose_fredkin(): + c, t1, t2 = cirq.LineQid.for_gate(cirq.FREDKIN) + op = cirq.FREDKIN(c, t1, t2) + want = tuple(cirq.flatten_op_tree(_fredkin((c, t1, t2)))) + assert want == decompose_once_into_operations(op) + + op = cirq.FREDKIN(c, t1, t2).with_classical_controls('key') + classical_controls = op.classical_controls + want = tuple( + o.with_classical_controls(classical_controls) + for o in cirq.flatten_op_tree(_fredkin((c, t1, t2))) + ) + assert want == decompose_once_into_operations(op) diff --git a/cirq_qubitization/t_complexity_protocol.py b/cirq_qubitization/t_complexity_protocol.py index 9619ebefb..8653a826c 100644 --- a/cirq_qubitization/t_complexity_protocol.py +++ b/cirq_qubitization/t_complexity_protocol.py @@ -3,7 +3,7 @@ from typing_extensions import Protocol import cirq -from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits +from cirq_qubitization.decompose_protocol import decompose_once_into_operations _T_GATESET = cirq.Gateset(cirq.T, cirq.T**-1, unroll_circuit_op=False) @@ -74,9 +74,9 @@ def _is_iterable(it: Any) -> Optional[TComplexity]: return t -def _has_decomposition(stc: Any) -> Optional[TComplexity]: +def _from_decomposition(stc: Any) -> Optional[TComplexity]: # Decompose the object and recursively compute the complexity. - decomposition, _, _ = _try_decompose_into_operations_and_qubits(stc) + decomposition = decompose_once_into_operations(stc) if decomposition is None: return None return _is_iterable(decomposition) @@ -93,9 +93,9 @@ def t_complexity(stc: Any, fail_quietly: bool = False) -> Optional[TComplexity]: The TComplexity of the given object or None on failure (and fail_quietly=True). Raises: - TypeError if fail_quietly=False and the methods fails to compute TComplexity. + TypeError: if fail_quietly=False and the methods fails to compute TComplexity. """ - strategies = [_has_t_complexity, _is_clifford_or_t, _has_decomposition, _is_iterable] + strategies = [_has_t_complexity, _is_clifford_or_t, _from_decomposition, _is_iterable] for strategy in strategies: ret = strategy(stc) if ret is not None: diff --git a/cirq_qubitization/t_complexity_protocol_test.py b/cirq_qubitization/t_complexity_protocol_test.py index 89b510b9d..359a04a48 100644 --- a/cirq_qubitization/t_complexity_protocol_test.py +++ b/cirq_qubitization/t_complexity_protocol_test.py @@ -64,6 +64,8 @@ def test_gates(): assert t_complexity(And()) == TComplexity(t=4, clifford=9) assert t_complexity(And() ** -1) == TComplexity(clifford=4) + assert t_complexity(cirq.FREDKIN) == TComplexity(t=7, clifford=10) + def test_operations(): q = cirq.NamedQubit('q') diff --git a/cirq_qubitization/testing.py b/cirq_qubitization/testing.py index 481ffb515..7065847b7 100644 --- a/cirq_qubitization/testing.py +++ b/cirq_qubitization/testing.py @@ -1,12 +1,14 @@ from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Sequence, Dict, List +from typing import Any, Sequence, Dict, List import cirq import numpy as np import nbformat from nbconvert.preprocessors import ExecutePreprocessor +from cirq_qubitization.t_complexity_protocol import t_complexity +from cirq_qubitization.decompose_protocol import decompose_once_into_operations from cirq_qubitization.gate_with_registers import GateWithRegisters, Registers @@ -88,3 +90,15 @@ def execute_notebook(name: str): nb = nbformat.read(f, as_version=4) ep = ExecutePreprocessor(timeout=600, kernel_name="python3") ep.preprocess(nb) + + +def assert_decompose_is_consistent_with_t_complexity(val: Any): + t_complexity_method = getattr(val, '_t_complexity_', None) + expected = NotImplemented if t_complexity_method is None else t_complexity_method() + if expected is NotImplemented or expected is None: + return + decomposition = decompose_once_into_operations(val) + if decomposition is None: + return + from_decomposition = t_complexity(decomposition, fail_quietly=False) + assert expected == from_decomposition, f'{expected} != {from_decomposition}' diff --git a/cirq_qubitization/testing_test.py b/cirq_qubitization/testing_test.py index 19b642fea..18a83bdd3 100644 --- a/cirq_qubitization/testing_test.py +++ b/cirq_qubitization/testing_test.py @@ -1,5 +1,6 @@ import cirq import pytest +from cirq_qubitization.t_complexity_protocol import TComplexity import cirq_qubitization.testing as cq_testing from cirq_qubitization.and_gate import And @@ -30,3 +31,42 @@ def test_gate_helper(): } assert g.operation.qubits == tuple(g.all_qubits) assert len(g.circuit) == 1 + + +class DoesNotDecompose(cirq.Operation): + def _t_complexity_(self) -> TComplexity: + return TComplexity(t=1, clifford=2, rotations=3) + + @property + def qubits(self): + return [] + + def with_qubits(self, _): + pass + + +class InconsistentDecompostion(cirq.Operation): + def _t_complexity_(self) -> TComplexity: + return TComplexity(rotations=1) + + def _decompose_(self) -> cirq.OP_TREE: + yield cirq.X(self.qubits[0]) + + @property + def qubits(self): + return tuple(cirq.LineQubit(3).range(3)) + + def with_qubits(self, _): + pass + + +@pytest.mark.parametrize( + "val", [cirq.T, DoesNotDecompose(), cq_testing.GateHelper(And()).operation] +) +def test_assert_decompose_is_consistent_with_t_complexity(val): + cq_testing.assert_decompose_is_consistent_with_t_complexity(val) + + +def test_assert_decompose_is_consistent_with_t_complexity_raises(): + with pytest.raises(AssertionError): + cq_testing.assert_decompose_is_consistent_with_t_complexity(InconsistentDecompostion())