diff --git a/unitary/alpha/qudit_gates.py b/unitary/alpha/qudit_gates.py index 628d9a45..177a12c9 100644 --- a/unitary/alpha/qudit_gates.py +++ b/unitary/alpha/qudit_gates.py @@ -12,6 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # + + +from typing import ( + List, + Dict, + Tuple, +) + import numpy as np import cirq @@ -51,6 +59,53 @@ def _circuit_diagram_info_(self, args): return f"X({self.source_state}_{self.destination_state})" +class QuditRzGate(cirq.ops.eigen_gate.EigenGate): + """Phase shifts a single state basis of the qudit. + + A generalization of the phase shift gate to qudits. + https://en.wikipedia.org/wiki/Quantum_logic_gate#Phase_shift_gates + + Implements Z_d as defined in eqn (5) of https://arxiv.org/abs/2008.00959 + + For a qudit of dimensionality d, shifts the phase of |d-1> by radians. + + Args: + dimension: dimension of the qudits, for instance, + a dimension of 3 would be a qutrit. + radians: The phase shift applied to basis d-1, measured in radians. + """ + + _eigencomponents: Dict[int, List[Tuple[float, np.ndarray]]] = {} + + def __init__(self, dimension: int, radians: float = np.pi): + super().__init__(exponent=radians / np.pi, global_shift=0) + self.dimension = dimension + + def _qid_shape_(self): + return (self.dimension,) + + def _eigen_components(self) -> List[Tuple[float, np.ndarray]]: + if self.dimension not in QuditRzGate._eigencomponents: + components = [] + for i in range(self.dimension): + half_turns = 0 + m = np.zeros((self.dimension, self.dimension)) + m[i][i] = 1 + if i == self.dimension - 1: + half_turns = 1 + components.append((half_turns, m)) + QuditRzGate._eigencomponents[self.dimension] = components + return QuditRzGate._eigencomponents[self.dimension] + + def _circuit_diagram_info_(self, args): + return cirq.CircuitDiagramInfo( + wire_symbols=("Z_d"), exponent=self._format_exponent_as_angle(args) + ) + + def _with_exponent(self, exponent: float) -> "QuditRzGate": + return QuditRzGate(rads=exponent * np.pi) + + class QuditPlusGate(cirq.Gate): """Cycles all the states using a permutation gate. diff --git a/unitary/alpha/qudit_gates_test.py b/unitary/alpha/qudit_gates_test.py index a228da33..1f273c1a 100644 --- a/unitary/alpha/qudit_gates_test.py +++ b/unitary/alpha/qudit_gates_test.py @@ -224,6 +224,20 @@ def test_iswap(q0: int, q1: int): assert np.all(results.measurements["m1"] == q0) +@pytest.mark.parametrize("dimension, phase_rads", [(2, np.pi), (3, 1), (4, np.pi * 2)]) +def test_rz_unitary(dimension: float, phase_rads: float): + rz = qudit_gates.QuditRzGate(dimension=dimension, radians=phase_rads) + expected_unitary = np.identity(n=dimension, dtype=np.complex64) + + # 1j = e ^ ( j * ( pi / 2 )), so we multiply phase_rads by 2 / pi. + expected_unitary[dimension - 1][dimension - 1] = 1j ** (phase_rads * 2 / np.pi) + + assert np.isclose(phase_rads / np.pi, rz._exponent) + rz_unitary = cirq.unitary(rz) + assert np.allclose(cirq.unitary(rz), expected_unitary) + assert np.allclose(np.eye(len(rz_unitary)), rz_unitary.dot(rz_unitary.T.conj())) + + @pytest.mark.parametrize( "q0, q1", [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)] )