Skip to content

Commit

Permalink
Add support for decompositions of parameterized cirq.DiagonalGate (q…
Browse files Browse the repository at this point in the history
…uantumlib#5085)

- Adds support for decomposing parameterized `cirq.DiagonalGate`.
- Global phase is ignored for parameterized version because `cirq.GlobalPhaseGate` doesn't yet support symbols. 
- Part of quantumlib#4858
  • Loading branch information
tanujkhattar authored and rht committed May 1, 2023
1 parent 0cee02b commit c305340
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
17 changes: 9 additions & 8 deletions cirq-core/cirq/ops/diagonal_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _value_equality_values_(self) -> Any:
return tuple(self._diag_angles_radians)

def _decompose_for_basis(
self, index: int, bit_flip: int, theta: float, qubits: Sequence['cirq.Qid']
self, index: int, bit_flip: int, theta: value.TParamVal, qubits: Sequence['cirq.Qid']
) -> Iterator[Union['cirq.ZPowGate', 'cirq.CXPowGate']]:
if index == 0:
return []
Expand All @@ -166,7 +166,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
│ │ │ │
2: ───Rz(1)───@───────────@───────────────────────@───────────────────────@───────────
where the angles in Rz gates are corresponding to the fast-walsh-Hadamard transfrom
where the angles in Rz gates are corresponding to the fast-walsh-Hadamard transform
of diagonal_angles in the Gray Code order.
For n qubits decomposition looks similar but with 2^n-1 Rz gates and 2^n-2 CNOT gates.
Expand All @@ -176,19 +176,20 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
ancillas." New Journal of Physics 16.3 (2014): 033040.
https://iopscience.iop.org/article/10.1088/1367-2630/16/3/033040/meta
"""
if protocols.is_parameterized(self):
return NotImplemented

n = self._num_qubits_()
hat_angles = _fast_walsh_hadamard_transform(self._diag_angles_radians) / (2 ** n)

# There is one global phase shift between unitary matrix of the diagonal gate and the
# decomposed gates. On its own it is not physically observable. However, if using this
# diagonal gate for sub-system like controlled gate, it is no longer equivalent. Hence,
# we add global phase.
decomposed_circ: List[Any] = [
global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))
]
# Global phase is ignored for parameterized gates as `cirq.GlobalPhaseGate` expects a
# scalar value.
decomposed_circ: List[Any] = (
[global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))]
if not protocols.is_parameterized(hat_angles[0])
else []
)
for i, bit_flip in _gen_gray_code(n):
decomposed_circ.extend(self._decompose_for_basis(i, bit_flip, -hat_angles[i], qubits))
return decomposed_circ
Expand Down
25 changes: 18 additions & 7 deletions cirq-core/cirq/ops/diagonal_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,24 @@ def test_decomposition_diagonal_exponent(n):
np.testing.assert_allclose(decomposed_f, expected_f)


def test_decomposition_with_parameterization():
diagonal_gate = cirq.DiagonalGate([2, 3, 5, sympy.Symbol('a')])
op = diagonal_gate(*cirq.LineQubit.range(2))

# We do not support the decomposition of parameterized case yet.
# So cirq.decompose should do nothing.
assert cirq.decompose(op) == [op]
@pytest.mark.parametrize('n', [1, 2, 3, 4])
def test_decomposition_with_parameterization(n):
angles = sympy.symbols([f'x_{i}' for i in range(2 ** n)])
exponent = sympy.Symbol('e')
diagonal_gate = cirq.DiagonalGate(angles) ** exponent
parameterized_op = diagonal_gate(*cirq.LineQubit.range(n))
decomposed_circuit = cirq.Circuit(cirq.decompose(parameterized_op))
for exponent_value in [-0.5, 0.5, 1]:
for i in range(len(_candidate_angles) - 2 ** n + 1):
resolver = {exponent: exponent_value}
resolver.update(
{angles[j]: x_j for j, x_j in enumerate(_candidate_angles[i : i + 2 ** n])}
)
resolved_op = cirq.resolve_parameters(parameterized_op, resolver)
resolved_circuit = cirq.resolve_parameters(decomposed_circuit, resolver)
cirq.testing.assert_allclose_up_to_global_phase(
cirq.unitary(resolved_op), cirq.unitary(resolved_circuit), atol=1e-8
)


def test_diagram():
Expand Down

0 comments on commit c305340

Please sign in to comment.