From ee4d7023335eb9bd9c2d9d8666b21df9936ce57d Mon Sep 17 00:00:00 2001 From: Greg Kahanamoku-Meyer Date: Wed, 22 May 2024 15:43:15 -1000 Subject: [PATCH] enable simulation of controlled gates in classical simulator (#6589) --- cirq-core/cirq/sim/classical_simulator.py | 20 +++++++++- .../cirq/sim/classical_simulator_test.py | 38 +++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/classical_simulator.py b/cirq-core/cirq/sim/classical_simulator.py index a5287637bfc..02879e518a1 100644 --- a/cirq-core/cirq/sim/classical_simulator.py +++ b/cirq-core/cirq/sim/classical_simulator.py @@ -117,12 +117,25 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos Raises: ValueError: If initial_state shape for type np.ndarray is not equal to 1. - If gate is not one of X, CNOT, SWAP, CCNOT, or a measurement. + If gate is not one of X, SWAP, a controlled version of X or SWAP, + or a measurement. """ if isinstance(self._state.basis, np.ndarray) and len(self._state.basis.shape) != 1: raise ValueError('initial_state shape for type np.ndarray is not equal to 1') gate = action.gate if isinstance(action, ops.Operation) else action mapped_qubits = [self.qubit_map[i] for i in qubits] + + if isinstance(gate, ops.ControlledGate): + control_qubits = mapped_qubits[: gate.num_controls()] + mapped_qubits = mapped_qubits[gate.num_controls() :] + + controls_state = tuple(self._state.basis[c] for c in control_qubits) + if controls_state not in gate.control_values.expand(): + # gate has no effect; controls were off + return True + + gate = gate.sub_gate + if _is_identity(gate): pass elif gate == ops.X: @@ -138,7 +151,10 @@ def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompos c1, c2, q = mapped_qubits self._state.basis[q] ^= self._state.basis[c1] & self._state.basis[c2] else: - raise ValueError(f'{gate} is not one of X, CNOT, SWAP, CCNOT, or a measurement') + raise ValueError( + f'{gate} is not one of X, SWAP; a controlled version ' + 'of X or SWAP; or a measurement' + ) return True diff --git a/cirq-core/cirq/sim/classical_simulator_test.py b/cirq-core/cirq/sim/classical_simulator_test.py index 3cf8c170bd8..96c8a4afdc0 100644 --- a/cirq-core/cirq/sim/classical_simulator_test.py +++ b/cirq-core/cirq/sim/classical_simulator_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from itertools import product import numpy as np import pytest import cirq @@ -78,6 +79,43 @@ def test_CCNOT(): np.testing.assert_equal(results, expected_results) +@pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=4)]) +def test_CCCX(initial_state): + CCCX = cirq.CCNOT.controlled() + qubits = cirq.LineQubit.range(4) + + circuit = cirq.Circuit() + circuit.append(CCCX(*qubits)) + circuit.append(cirq.measure(qubits, key='key')) + + final_state = initial_state.copy() + final_state[-1] ^= all(final_state[:-1]) + + sim = cirq.ClassicalStateSimulator() + results = sim.simulate(circuit, initial_state=initial_state).measurements['key'] + np.testing.assert_equal(results, final_state) + + +@pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=3)]) +def test_CSWAP(initial_state): + CSWAP = cirq.SWAP.controlled() + qubits = cirq.LineQubit.range(3) + circuit = cirq.Circuit() + + circuit = cirq.Circuit() + circuit.append(CSWAP(*qubits)) + circuit.append(cirq.measure(qubits, key='key')) + + a, b, c = initial_state + if a: + b, c = c, b + final_state = [a, b, c] + + sim = cirq.ClassicalStateSimulator() + results = sim.simulate(circuit, initial_state=initial_state).measurements['key'] + np.testing.assert_equal(results, final_state) + + def test_measurement_gate(): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit()