Skip to content

Commit

Permalink
Move cirq.ion.ion_gates.MSGate to cirq.ops module (quantumlib#5508)
Browse files Browse the repository at this point in the history
* Move cirq.ion.ion_gates.MSGate to cirq.ops module

* Fix failing tests
  • Loading branch information
tanujkhattar authored Jun 13, 2022
1 parent 77449f7 commit c79d884
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 168 deletions.
3 changes: 2 additions & 1 deletion cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@
MeasurementGate,
MutableDensePauliString,
MutablePauliString,
ms,
NamedQubit,
NamedQid,
OP_TREE,
Expand Down Expand Up @@ -667,7 +668,7 @@
with_rescoped_keys,
)

from cirq.ion import ConvertToIonGates, IonDevice, ms, two_qubit_matrix_to_ion_operations
from cirq.ion import ConvertToIonGates, IonDevice, two_qubit_matrix_to_ion_operations
from cirq.neutral_atoms import (
ConvertToNeutralAtomGates,
is_native_neutral_atom_gate,
Expand Down
2 changes: 1 addition & 1 deletion cirq/ion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Trapped ion devices, gates, and compiling utilties."""

from cirq.ion.ion_gates import ms
from cirq.ops import ms

from cirq.ion.ion_decomposition import two_qubit_matrix_to_ion_operations

Expand Down
81 changes: 0 additions & 81 deletions cirq/ion/ion_gates.py

This file was deleted.

80 changes: 0 additions & 80 deletions cirq/ion/ion_gates_test.py

This file was deleted.

2 changes: 1 addition & 1 deletion cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
transform_op_tree,
)

from cirq.ops.parity_gates import XX, XXPowGate, YY, YYPowGate, ZZ, ZZPowGate
from cirq.ops.parity_gates import XX, XXPowGate, YY, YYPowGate, ZZ, ZZPowGate, MSGate, ms

from cirq.ops.pauli_gates import Pauli, X, Y, Z

Expand Down
2 changes: 1 addition & 1 deletion cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def all_subclasses(cls):
cirq.circuits.qasm_output.QasmTwoQubitGate,
cirq.circuits.quil_output.QuilTwoQubitGate,
cirq.circuits.quil_output.QuilOneQubitGate,
cirq.ion.ion_gates.MSGate,
cirq.ops.MSGate,
# Gate features.
cirq.SingleQubitGate,
# Interop gates
Expand Down
63 changes: 60 additions & 3 deletions cirq/ops/parity_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

"""Quantum gates that phase with respect to product-of-pauli observables."""

from typing import List, Optional, Tuple, Union, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

import numpy as np

from cirq import protocols
from cirq import protocols, value
from cirq._compat import proper_repr
from cirq._doc import document
from cirq.ops import gate_features, eigen_gate, common_gates, pauli_gates
Expand Down Expand Up @@ -55,7 +55,7 @@ class XXPowGate(gate_features.InterchangeableQubitsGate, eigen_gate.EigenGate):
f = e^{\frac{i \pi t}{2}}.
$$
See also: `cirq.ion.ion_gates.MSGate` (the Mølmer–Sørensen gate), which is
See also: `cirq.ops.MSGate` (the Mølmer–Sørensen gate), which is
implemented via this class.
"""

Expand Down Expand Up @@ -352,6 +352,63 @@ def __repr__(self) -> str:
)


class MSGate(XXPowGate):
"""The Mølmer–Sørensen gate, a native two-qubit operation in ion traps.
A rotation around the XX axis in the two-qubit bloch sphere.
The gate implements the following unitary:
exp(-i t XX) = [ cos(t) 0 0 -isin(t)]
[ 0 cos(t) -isin(t) 0 ]
[ 0 -isin(t) cos(t) 0 ]
[-isin(t) 0 0 cos(t) ]
"""

def __init__(self, *, rads: float): # Forces keyword args.
XXPowGate.__init__(self, exponent=rads * 2 / np.pi, global_shift=-0.5)
self.rads = rads

def _with_exponent(self: 'MSGate', exponent: value.TParamVal) -> 'MSGate':
return type(self)(rads=exponent * np.pi / 2)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> Union[str, 'protocols.CircuitDiagramInfo']:
angle_str = self._format_exponent_as_angle(args, order=4)
symbol = f'MS({angle_str})'
return protocols.CircuitDiagramInfo(wire_symbols=(symbol, symbol))

def __str__(self) -> str:
if self._exponent == 1:
return 'MS(π/2)'
return f'MS({self._exponent!r}π/2)'

def __repr__(self) -> str:
if self._exponent == 1:
return 'cirq.ms(np.pi/2)'
return f'cirq.ms({self._exponent!r}*np.pi/2)'

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ["rads"])

@classmethod
def _from_json_dict_(cls, rads: float, **kwargs: Any) -> 'MSGate':
return cls(rads=rads)


def ms(rads: float) -> MSGate:
"""A helper to construct the `cirq.MSGate` for the given angle specified in radians.
Args:
rads: The rotation angle in radians.
Returns:
Mølmer–Sørensen gate rotating by the desired amount.
"""
return MSGate(rads=rads)


XX = XXPowGate()
document(
XX,
Expand Down
63 changes: 63 additions & 0 deletions cirq/ops/parity_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,66 @@ def test_trace_distance():
assert cirq.approx_eq(cirq.trace_distance_bound(cirq.XX), 1.0)
assert cirq.approx_eq(cirq.trace_distance_bound(cirq.YY**0), 0)
assert cirq.approx_eq(cirq.trace_distance_bound(cirq.ZZ ** (1 / 3)), np.sin(np.pi / 6))


def test_ms_arguments():
eq_tester = cirq.testing.EqualsTester()
eq_tester.add_equality_group(cirq.ms(np.pi / 2), cirq.ops.MSGate(rads=np.pi / 2))
eq_tester.add_equality_group(cirq.XXPowGate(global_shift=-0.5))


def test_ms_str():
ms = cirq.ms(np.pi / 2)
assert str(ms) == 'MS(π/2)'
assert str(cirq.ms(np.pi)) == 'MS(2.0π/2)'
assert str(ms**0.5) == 'MS(0.5π/2)'
assert str(ms**2) == 'MS(2.0π/2)'
assert str(ms**-1) == 'MS(-1.0π/2)'


def test_ms_matrix():
s = np.sqrt(0.5)
# yapf: disable
np.testing.assert_allclose(cirq.unitary(cirq.ms(np.pi/4)),
np.array([[s, 0, 0, -1j*s],
[0, s, -1j*s, 0],
[0, -1j*s, s, 0],
[-1j*s, 0, 0, s]]),
atol=1e-8)
# yapf: enable
np.testing.assert_allclose(cirq.unitary(cirq.ms(np.pi)), np.diag([-1, -1, -1, -1]), atol=1e-8)


def test_ms_repr():
assert repr(cirq.ms(np.pi / 2)) == 'cirq.ms(np.pi/2)'
assert repr(cirq.ms(np.pi / 4)) == 'cirq.ms(0.5*np.pi/2)'
cirq.testing.assert_equivalent_repr(cirq.ms(np.pi / 4))
ms = cirq.ms(np.pi / 2)
assert repr(ms**2) == 'cirq.ms(2.0*np.pi/2)'
assert repr(ms**-0.5) == 'cirq.ms(-0.5*np.pi/2)'


def test_ms_diagrams():
a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')
circuit = cirq.Circuit(cirq.SWAP(a, b), cirq.X(a), cirq.Y(a), cirq.ms(np.pi).on(a, b))
cirq.testing.assert_has_diagram(
circuit,
"""
a: ───×───X───Y───MS(π)───
│ │
b: ───×───────────MS(π)───
""",
)


def test_json_serialization():
def custom_resolver(cirq_type: str):
if cirq_type == "MSGate":
return cirq.ops.MSGate
return None

assert cirq.read_json(
json_text=cirq.to_json(cirq.ms(np.pi / 2)), resolvers=[custom_resolver]
) == cirq.ms(np.pi / 2)
assert custom_resolver('X') is None

0 comments on commit c79d884

Please sign in to comment.