Skip to content

Commit

Permalink
Add default decomposition for cirq.TwoQubitDiagonalGate (#5084)
Browse files Browse the repository at this point in the history
* Add default decomposition for cirq.TwoQubitDiagonalGate

* Update pyquil tests
  • Loading branch information
tanujkhattar authored Mar 16, 2022
1 parent 28f90b0 commit 2fb5651
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 11 deletions.
21 changes: 11 additions & 10 deletions cirq-core/cirq/circuits/quil_output_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,15 @@ def test_equivalent_unitaries():
CPHASE(pi/2) 0 1
"""

QUIL_DIAGONAL_DEFGATE_PROGRAM = """
DEFGATE USERGATE1:
1.0, 0.0, 0.0, 0.0
0.0, 1.0, 0.0, 0.0
0.0, 0.0, 1.0, 0.0
0.0, 0.0, 0.0, 1.0
USERGATE1 0 1
QUIL_DIAGONAL_DECOMPOSE_PROGRAM = """
RZ(0) 0
RZ(0) 1
CPHASE(0) 0 1
X 0
X 1
CPHASE(0) 0 1
X 0
X 1
"""


Expand All @@ -463,13 +464,13 @@ def test_two_qubit_diagonal_gate_quil_output():
# Qubit ordering differs between pyQuil and Cirq.
cirq_unitary = cirq.Circuit(cirq.SWAP(q0, q1), operations, cirq.SWAP(q0, q1)).unitary()
assert np.allclose(pyquil_unitary, cirq_unitary)
# Also test non-CPHASE case.
# Also test non-CPHASE case, which decomposes into X/RZ/CPhase
operations = [
cirq.TwoQubitDiagonalGate([0, 0, 0, 0])(q0, q1),
]
output = cirq.QuilOutput(operations, (q0, q1))
program = pyquil.Program(str(output))
assert f"\n{program.out()}" == QUIL_DIAGONAL_DEFGATE_PROGRAM
assert f"\n{program.out()}" == QUIL_DIAGONAL_DECOMPOSE_PROGRAM


def test_parseable_defgate_output():
Expand Down
12 changes: 11 additions & 1 deletion cirq-core/cirq/ops/two_qubit_diagonal_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from cirq import protocols, value
from cirq._compat import proper_repr
from cirq.ops import raw_types
from cirq.ops import raw_types, common_gates

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -72,6 +72,16 @@ def _unitary_(self) -> Optional[np.ndarray]:
return None
return np.diag([np.exp(1j * angle) for angle in self._diag_angles_radians])

def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
x0, x1, x2, x3 = self._diag_angles_radians
q0, q1 = qubits
yield common_gates.ZPowGate(exponent=x2 / np.pi).on(q0)
yield common_gates.ZPowGate(exponent=x1 / np.pi).on(q1)
yield common_gates.CZPowGate(exponent=(x3 - (x1 + x2)) / np.pi).on(q0, q1)
yield common_gates.XPowGate().on_each(q0, q1)
yield common_gates.CZPowGate(exponent=x0 / np.pi).on(q0, q1)
yield common_gates.XPowGate().on_each(q0, q1)

def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> np.ndarray:
if self._is_parameterized_():
return NotImplemented
Expand Down
16 changes: 16 additions & 0 deletions cirq-core/cirq/ops/two_qubit_diagonal_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ def test_consistent_protocols(gate):
cirq.testing.assert_implements_consistent_protocols(gate)


def test_parameterized_decompose():
angles = sympy.symbols('x0, x1, x2, x3')
parameterized_op = cirq.TwoQubitDiagonalGate(angles).on(*cirq.LineQubit.range(2))
decomposed_circuit = cirq.Circuit(cirq.decompose(parameterized_op))
for resolver in (
cirq.Linspace('x0', -2, 2, 6)
* cirq.Linspace('x1', -2, 2, 6)
* cirq.Linspace('x2', -2, 2, 6)
* cirq.Linspace('x3', -2, 2, 6)
):
np.testing.assert_allclose(
cirq.unitary(cirq.resolve_parameters(parameterized_op, resolver)),
cirq.unitary(cirq.resolve_parameters(decomposed_circuit, resolver)),
)


def test_unitary():
diagonal_angles = [2, 3, 5, 7]
assert cirq.has_unitary(cirq.TwoQubitDiagonalGate(diagonal_angles))
Expand Down

0 comments on commit 2fb5651

Please sign in to comment.