Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for decompositions of parameterized cirq.DiagonalGate #5085

Merged
merged 2 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions cirq-core/cirq/ops/diagonal_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _value_equality_values_(self) -> Any:
return tuple(self._diag_angles_radians)

def _decompose_for_basis(
self, index: int, bit_flip: int, theta: float, qubits: Sequence['cirq.Qid']
self, index: int, bit_flip: int, theta: value.TParamVal, qubits: Sequence['cirq.Qid']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be 'cirq.TParamVal' if you ever get back to this. xref #4383

) -> Iterator[Union['cirq.ZPowGate', 'cirq.CXPowGate']]:
if index == 0:
return []
Expand All @@ -166,7 +166,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
│ │ │ │
2: ───Rz(1)───@───────────@───────────────────────@───────────────────────@───────────

where the angles in Rz gates are corresponding to the fast-walsh-Hadamard transfrom
where the angles in Rz gates are corresponding to the fast-walsh-Hadamard transform
of diagonal_angles in the Gray Code order.

For n qubits decomposition looks similar but with 2^n-1 Rz gates and 2^n-2 CNOT gates.
Expand All @@ -176,19 +176,20 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
ancillas." New Journal of Physics 16.3 (2014): 033040.
https://iopscience.iop.org/article/10.1088/1367-2630/16/3/033040/meta
"""
if protocols.is_parameterized(self):
return NotImplemented

n = self._num_qubits_()
hat_angles = _fast_walsh_hadamard_transform(self._diag_angles_radians) / (2 ** n)

# There is one global phase shift between unitary matrix of the diagonal gate and the
# decomposed gates. On its own it is not physically observable. However, if using this
# diagonal gate for sub-system like controlled gate, it is no longer equivalent. Hence,
# we add global phase.
decomposed_circ: List[Any] = [
global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))
]
# Global phase is ignored for parameterized gates as `cirq.GlobalPhaseGate` expects a
# scalar value.
decomposed_circ: List[Any] = (
[global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))]
if not protocols.is_parameterized(hat_angles[0])
else []
)
for i, bit_flip in _gen_gray_code(n):
decomposed_circ.extend(self._decompose_for_basis(i, bit_flip, -hat_angles[i], qubits))
return decomposed_circ
Expand Down
25 changes: 18 additions & 7 deletions cirq-core/cirq/ops/diagonal_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,24 @@ def test_decomposition_diagonal_exponent(n):
np.testing.assert_allclose(decomposed_f, expected_f)


def test_decomposition_with_parameterization():
diagonal_gate = cirq.DiagonalGate([2, 3, 5, sympy.Symbol('a')])
op = diagonal_gate(*cirq.LineQubit.range(2))

# We do not support the decomposition of parameterized case yet.
# So cirq.decompose should do nothing.
assert cirq.decompose(op) == [op]
@pytest.mark.parametrize('n', [1, 2, 3, 4])
def test_decomposition_with_parameterization(n):
angles = sympy.symbols([f'x_{i}' for i in range(2 ** n)])
exponent = sympy.Symbol('e')
diagonal_gate = cirq.DiagonalGate(angles) ** exponent
parameterized_op = diagonal_gate(*cirq.LineQubit.range(n))
decomposed_circuit = cirq.Circuit(cirq.decompose(parameterized_op))
for exponent_value in [-0.5, 0.5, 1]:
for i in range(len(_candidate_angles) - 2 ** n + 1):
resolver = {exponent: exponent_value}
resolver.update(
{angles[j]: x_j for j, x_j in enumerate(_candidate_angles[i : i + 2 ** n])}
)
resolved_op = cirq.resolve_parameters(parameterized_op, resolver)
resolved_circuit = cirq.resolve_parameters(decomposed_circuit, resolver)
cirq.testing.assert_allclose_up_to_global_phase(
cirq.unitary(resolved_op), cirq.unitary(resolved_circuit), atol=1e-8
)


def test_diagram():
Expand Down