diff --git a/cirq/ops/gate_operation.py b/cirq/ops/gate_operation.py index 0d83f26fe78..f928a29a419 100644 --- a/cirq/ops/gate_operation.py +++ b/cirq/ops/gate_operation.py @@ -14,7 +14,6 @@ """Basic types defining qubits, gates, and operations.""" -import itertools import re from typing import ( AbstractSet, @@ -29,6 +28,7 @@ TypeVar, TYPE_CHECKING, Union, + List, ) import numpy as np @@ -146,17 +146,11 @@ def _group_interchangeable_qubits( ) -> Tuple[Union['cirq.Qid', Tuple[int, FrozenSet['cirq.Qid']]], ...]: if not isinstance(self.gate, gate_features.InterchangeableQubitsGate): return self.qubits - else: - - def make_key(i_q: Tuple[int, 'cirq.Qid']) -> int: - return cast( - gate_features.InterchangeableQubitsGate, self.gate - ).qubit_index_to_equivalence_group_key(i_q[0]) - - return tuple( - (k, frozenset(g for _, g in kg)) - for k, kg in itertools.groupby(enumerate(self.qubits), make_key) - ) + groups: Dict[int, List['cirq.Qid']] = {} + for i, q in enumerate(self.qubits): + k = self.gate.qubit_index_to_equivalence_group_key(i) + groups.setdefault(k, []).append(q) + return tuple(sorted((k, frozenset(v)) for k, v in groups.items())) def _value_equality_values_(self): return self.gate, self._group_interchangeable_qubits() diff --git a/cirq/ops/gate_operation_test.py b/cirq/ops/gate_operation_test.py index aeaceadbc29..3f2dd87073d 100644 --- a/cirq/ops/gate_operation_test.py +++ b/cirq/ops/gate_operation_test.py @@ -444,3 +444,24 @@ def _is_parameterized_(self): assert not cirq.is_parameterized(No1().on(q)) assert not cirq.is_parameterized(No2().on(q)) assert cirq.is_parameterized(Yes().on(q)) + + +def test_group_interchangeable_qubits_creates_tuples_with_unique_keys(): + class MyGate(cirq.Gate, cirq.InterchangeableQubitsGate): + def __init__(self, num_qubits) -> None: + self._num_qubits = num_qubits + + def num_qubits(self) -> int: + return self._num_qubits + + def qubit_index_to_equivalence_group_key(self, index: int) -> int: + if index % 2 == 0: + return index + return 0 + + qubits = cirq.LineQubit.range(4) + gate = MyGate(len(qubits)) + + assert gate(qubits[0], qubits[1], qubits[2], qubits[3]) == gate( + qubits[3], qubits[1], qubits[2], qubits[0] + )