diff --git a/cirq-core/cirq/ops/diagonal_gate.py b/cirq-core/cirq/ops/diagonal_gate.py index 8bde78b55d7..670a683f675 100644 --- a/cirq-core/cirq/ops/diagonal_gate.py +++ b/cirq-core/cirq/ops/diagonal_gate.py @@ -144,7 +144,7 @@ def _value_equality_values_(self) -> Any: return tuple(self._diag_angles_radians) def _decompose_for_basis( - self, index: int, bit_flip: int, theta: float, qubits: Sequence['cirq.Qid'] + self, index: int, bit_flip: int, theta: value.TParamVal, qubits: Sequence['cirq.Qid'] ) -> Iterator[Union['cirq.ZPowGate', 'cirq.CXPowGate']]: if index == 0: return [] @@ -166,7 +166,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': │ │ │ │ 2: ───Rz(1)───@───────────@───────────────────────@───────────────────────@─────────── - where the angles in Rz gates are corresponding to the fast-walsh-Hadamard transfrom + where the angles in Rz gates are corresponding to the fast-walsh-Hadamard transform of diagonal_angles in the Gray Code order. For n qubits decomposition looks similar but with 2^n-1 Rz gates and 2^n-2 CNOT gates. @@ -176,9 +176,6 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': ancillas." New Journal of Physics 16.3 (2014): 033040. https://iopscience.iop.org/article/10.1088/1367-2630/16/3/033040/meta """ - if protocols.is_parameterized(self): - return NotImplemented - n = self._num_qubits_() hat_angles = _fast_walsh_hadamard_transform(self._diag_angles_radians) / (2 ** n) @@ -186,9 +183,13 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': # decomposed gates. On its own it is not physically observable. However, if using this # diagonal gate for sub-system like controlled gate, it is no longer equivalent. Hence, # we add global phase. - decomposed_circ: List[Any] = [ - global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0])) - ] + # Global phase is ignored for parameterized gates as `cirq.GlobalPhaseGate` expects a + # scalar value. + decomposed_circ: List[Any] = ( + [global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))] + if not protocols.is_parameterized(hat_angles[0]) + else [] + ) for i, bit_flip in _gen_gray_code(n): decomposed_circ.extend(self._decompose_for_basis(i, bit_flip, -hat_angles[i], qubits)) return decomposed_circ diff --git a/cirq-core/cirq/ops/diagonal_gate_test.py b/cirq-core/cirq/ops/diagonal_gate_test.py index 1d87a94450c..c4529766628 100644 --- a/cirq-core/cirq/ops/diagonal_gate_test.py +++ b/cirq-core/cirq/ops/diagonal_gate_test.py @@ -77,13 +77,24 @@ def test_decomposition_diagonal_exponent(n): np.testing.assert_allclose(decomposed_f, expected_f) -def test_decomposition_with_parameterization(): - diagonal_gate = cirq.DiagonalGate([2, 3, 5, sympy.Symbol('a')]) - op = diagonal_gate(*cirq.LineQubit.range(2)) - - # We do not support the decomposition of parameterized case yet. - # So cirq.decompose should do nothing. - assert cirq.decompose(op) == [op] +@pytest.mark.parametrize('n', [1, 2, 3, 4]) +def test_decomposition_with_parameterization(n): + angles = sympy.symbols([f'x_{i}' for i in range(2 ** n)]) + exponent = sympy.Symbol('e') + diagonal_gate = cirq.DiagonalGate(angles) ** exponent + parameterized_op = diagonal_gate(*cirq.LineQubit.range(n)) + decomposed_circuit = cirq.Circuit(cirq.decompose(parameterized_op)) + for exponent_value in [-0.5, 0.5, 1]: + for i in range(len(_candidate_angles) - 2 ** n + 1): + resolver = {exponent: exponent_value} + resolver.update( + {angles[j]: x_j for j, x_j in enumerate(_candidate_angles[i : i + 2 ** n])} + ) + resolved_op = cirq.resolve_parameters(parameterized_op, resolver) + resolved_circuit = cirq.resolve_parameters(decomposed_circuit, resolver) + cirq.testing.assert_allclose_up_to_global_phase( + cirq.unitary(resolved_op), cirq.unitary(resolved_circuit), atol=1e-8 + ) def test_diagram():