From d38de16c55d2a74800d15080e3f3b55cdcf57e5b Mon Sep 17 00:00:00 2001 From: Seneca Meeks Date: Wed, 7 Jun 2023 00:17:06 -0700 Subject: [PATCH] Update Density Matrix and State Vector Simulators to work when an operation allocates new qubits as part of its decomposition (#6108) * WIP add factoring and kron methods to sim state for adding and removing ancillas in state vector and density matrix simulators * add test cases * add delegating gate test case * update test * all tests pass * add test case for unitary Y * nit * addresses PR comments by adding empty checks. Applys formatter. Subsequent push will add more test cases per Tanuj's comment * nit formatting changes, add docustring with input/output for remove_qubits * merge this branch and tanujkhattar@ccde689 * merging branches, adding test coverage in next push * format files * add coverage tests * change assert * coverage and type check tests should pass * incorporate tanujkhattar@1db8ac5 * nit * remove block comment * add coverage --------- Co-authored-by: Tanuj Khattar --- cirq/protocols/act_on_protocol.py | 2 +- cirq/sim/density_matrix_simulation_state.py | 16 ++ .../density_matrix_simulation_state_test.py | 12 ++ cirq/sim/simulation_state.py | 47 ++++- cirq/sim/simulation_state_test.py | 176 ++++++++++++++++-- cirq/sim/state_vector_simulation_state.py | 16 ++ 6 files changed, 255 insertions(+), 14 deletions(-) diff --git a/cirq/protocols/act_on_protocol.py b/cirq/protocols/act_on_protocol.py index 07c8f95ee00..8ac4875c01c 100644 --- a/cirq/protocols/act_on_protocol.py +++ b/cirq/protocols/act_on_protocol.py @@ -149,7 +149,7 @@ def act_on( arg_fallback = getattr(sim_state, '_act_on_fallback_', None) if arg_fallback is not None: - qubits = action.qubits if isinstance(action, ops.Operation) else qubits + qubits = action.qubits if is_op else qubits result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose) if result is True: return diff --git a/cirq/sim/density_matrix_simulation_state.py b/cirq/sim/density_matrix_simulation_state.py index 1b7c5fa6e21..8091a3e59cd 100644 --- a/cirq/sim/density_matrix_simulation_state.py +++ b/cirq/sim/density_matrix_simulation_state.py @@ -285,6 +285,22 @@ def __init__( ) super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) + def add_qubits(self, qubits: Sequence['cirq.Qid']): + ret = super().add_qubits(qubits) + return ( + self.kronecker_product(type(self)(qubits=qubits), inplace=True) + if ret is NotImplemented + else ret + ) + + def remove_qubits(self, qubits: Sequence['cirq.Qid']): + ret = super().remove_qubits(qubits) + if ret is not NotImplemented: + return ret + extracted, remainder = self.factor(qubits) + remainder._state._density_matrix *= extracted._state._density_matrix.reshape(-1)[0] + return remainder + def _act_on_fallback_( self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True ) -> bool: diff --git a/cirq/sim/density_matrix_simulation_state_test.py b/cirq/sim/density_matrix_simulation_state_test.py index c2d84b0c4a5..3f37e7c663c 100644 --- a/cirq/sim/density_matrix_simulation_state_test.py +++ b/cirq/sim/density_matrix_simulation_state_test.py @@ -123,3 +123,15 @@ def test_initial_state_bad_shape(): cirq.DensityMatrixSimulationState( qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64 ) + + +def test_remove_qubits(): + """Test the remove_qubits method.""" + q1 = cirq.LineQubit(0) + q2 = cirq.LineQubit(1) + state = cirq.DensityMatrixSimulationState(qubits=[q1, q2]) + + new_state = state.remove_qubits([q1]) + + assert len(new_state.qubits) == 1 + assert q1 not in new_state.qubits diff --git a/cirq/sim/simulation_state.py b/cirq/sim/simulation_state.py index 88d1aa43221..b87d892b1d7 100644 --- a/cirq/sim/simulation_state.py +++ b/cirq/sim/simulation_state.py @@ -166,6 +166,38 @@ def create_merged_state(self) -> Self: """Creates a final merged state.""" return self + def add_qubits(self: Self, qubits: Sequence['cirq.Qid']): + """Add qubits to a new state space and take the kron product. + + Note that only Density Matrix and State Vector simulators + override this function. + + Args: + qubits: Sequence of qubits to be added. + + Returns: + NotImplemented: If the subclass does not implement this method. + + Raises: + ValueError: If a qubit being added is already tracked. + """ + if any(q in self.qubits for q in qubits): + raise ValueError(f"Qubit to add {qubits} should not already be tracked.") + return NotImplemented + + def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self: + """Remove qubits from the state space. + + Args: + qubits: Sequence of qubits to be added. + + Returns: + A new Simulation State with qubits removed. Or + `self` if there are no qubits to remove.""" + if qubits is None or not qubits: + return self + return NotImplemented + def kronecker_product(self, other: Self, *, inplace=False) -> Self: """Joins two state spaces together.""" args = self if inplace else copy.copy(self) @@ -294,13 +326,24 @@ def strat_act_on_from_apply_decompose( val: Any, args: 'cirq.SimulationState', qubits: Sequence['cirq.Qid'] ) -> bool: operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val) - assert len(qubits1) == len(qubits) - qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)} if operations is None: return NotImplemented + assert len(qubits1) == len(qubits) + all_qubits = frozenset([q for op in operations for q in op.qubits]) + qubit_map = dict(zip(all_qubits, all_qubits)) + qubit_map.update(dict(zip(qubits1, qubits))) + new_ancilla = tuple(q for q in sorted(all_qubits.difference(qubits)) if q not in args.qubits) + args = args.add_qubits(new_ancilla) + if args is NotImplemented: + return NotImplemented for operation in operations: operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits]) protocols.act_on(operation, args) + args = args.remove_qubits(new_ancilla) + if args is NotImplemented: # coverage: ignore + raise TypeError( # coverage: ignore + f"{type(args)} implements `add_qubits` but not `remove_qubits`." # coverage: ignore + ) # coverage: ignore return True diff --git a/cirq/sim/simulation_state_test.py b/cirq/sim/simulation_state_test.py index 0ba6d675662..42023af9074 100644 --- a/cirq/sim/simulation_state_test.py +++ b/cirq/sim/simulation_state_test.py @@ -19,6 +19,7 @@ import cirq from cirq.sim import simulation_state +from cirq.testing import PhaseUsingCleanAncilla, PhaseUsingDirtyAncilla class DummyQuantumState(cirq.QuantumStateRepresentation): @@ -33,8 +34,8 @@ def reindex(self, axes): class DummySimulationState(cirq.SimulationState): - def __init__(self): - super().__init__(state=DummyQuantumState(), qubits=cirq.LineQubit.range(2)) + def __init__(self, qubits=cirq.LineQubit.range(2)): + super().__init__(state=DummyQuantumState(), qubits=qubits) def _act_on_fallback_( self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True @@ -42,6 +43,70 @@ def _act_on_fallback_( return True +class AncillaZ(cirq.Gate): + def __init__(self, exponent=1): + self._exponent = exponent + + def num_qubits(self) -> int: + return 1 + + def _decompose_(self, qubits): + ancilla = cirq.NamedQubit('Ancilla') + yield cirq.CX(qubits[0], ancilla) + yield cirq.Z(ancilla) ** self._exponent + yield cirq.CX(qubits[0], ancilla) + + +class AncillaH(cirq.Gate): + def __init__(self, exponent=1): + self._exponent = exponent + + def num_qubits(self) -> int: + return 1 + + def _decompose_(self, qubits): + ancilla = cirq.NamedQubit('Ancilla') + yield cirq.H(ancilla) ** self._exponent + yield cirq.CX(ancilla, qubits[0]) + yield cirq.H(ancilla) ** self._exponent + + +class AncillaY(cirq.Gate): + def __init__(self, exponent=1): + self._exponent = exponent + + def num_qubits(self) -> int: + return 1 + + def _decompose_(self, qubits): + ancilla = cirq.NamedQubit('Ancilla') + yield cirq.Y(ancilla) ** self._exponent + yield cirq.CX(ancilla, qubits[0]) + yield cirq.Y(ancilla) ** self._exponent + + +class DelegatingAncillaZ(cirq.Gate): + def __init__(self, exponent=1): + self._exponent = exponent + + def num_qubits(self) -> int: + return 1 + + def _decompose_(self, qubits): + a = cirq.NamedQubit('a') + yield cirq.CX(qubits[0], a) + yield AncillaZ(self._exponent).on(a) + yield cirq.CX(qubits[0], a) + + +class Composite(cirq.Gate): + def num_qubits(self) -> int: + return 1 + + def _decompose_(self, qubits): + yield cirq.X(*qubits) + + def test_measurements(): args = DummySimulationState() args.measure([cirq.LineQubit(0)], "test", [False], {}) @@ -49,16 +114,10 @@ def test_measurements(): def test_decompose(): - class Composite(cirq.Gate): - def num_qubits(self) -> int: - return 1 - - def _decompose_(self, qubits): - yield cirq.X(*qubits) - args = DummySimulationState() - assert simulation_state.strat_act_on_from_apply_decompose( - Composite(), args, [cirq.LineQubit(0)] + assert ( + simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)]) + is NotImplemented ) @@ -101,3 +160,98 @@ def test_field_getters(): args = DummySimulationState() assert args.prng is np.random assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))} + + +@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3]) +def test_ancilla_z(exp): + q = cirq.LineQubit(0) + test_circuit = cirq.Circuit(AncillaZ(exp).on(q)) + + control_circuit = cirq.Circuit(cirq.ZPowGate(exponent=exp).on(q)) + + assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) + + +@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3]) +def test_ancilla_y(exp): + q = cirq.LineQubit(0) + test_circuit = cirq.Circuit(AncillaY(exp).on(q)) + + control_circuit = cirq.Circuit(cirq.Y(q)) + control_circuit.append(cirq.Y(q)) + control_circuit.append(cirq.XPowGate(exponent=exp).on(q)) + + assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) + + +@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3]) +def test_borrowable_qubit(exp): + q = cirq.LineQubit(0) + test_circuit = cirq.Circuit() + test_circuit.append(cirq.H(q)) + test_circuit.append(cirq.X(q)) + test_circuit.append(AncillaH(exp).on(q)) + + control_circuit = cirq.Circuit(cirq.H(q)) + + assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) + + +@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3]) +def test_delegating_gate_qubit(exp): + q = cirq.LineQubit(0) + + test_circuit = cirq.Circuit() + test_circuit.append(cirq.H(q)) + test_circuit.append(DelegatingAncillaZ(exp).on(q)) + + control_circuit = cirq.Circuit(cirq.H(q)) + control_circuit.append(cirq.ZPowGate(exponent=exp).on(q)) + + assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) + + +@pytest.mark.parametrize('num_ancilla', [1, 2, 3]) +def test_phase_using_dirty_ancilla(num_ancilla: int): + q = cirq.LineQubit(0) + anc = cirq.NamedQubit.range(num_ancilla, prefix='anc') + + u = cirq.MatrixGate(cirq.testing.random_unitary(2 ** (num_ancilla + 1))) + test_circuit = cirq.Circuit( + u.on(q, *anc), PhaseUsingDirtyAncilla(ancilla_bitsize=num_ancilla).on(q) + ) + control_circuit = cirq.Circuit(u.on(q, *anc), cirq.Z(q)) + assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) + + +@pytest.mark.parametrize('num_ancilla', [1, 2, 3]) +@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 10)) +def test_phase_using_clean_ancilla(num_ancilla: int, theta: float): + q = cirq.LineQubit(0) + u = cirq.MatrixGate(cirq.testing.random_unitary(2)) + test_circuit = cirq.Circuit( + u.on(q), PhaseUsingCleanAncilla(theta=theta, ancilla_bitsize=num_ancilla).on(q) + ) + control_circuit = cirq.Circuit(u.on(q), cirq.ZPowGate(exponent=theta).on(q)) + assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) + + +def test_add_qubits_raise_value_error(num_ancilla=1): + q = cirq.LineQubit(0) + args = cirq.StateVectorSimulationState(qubits=[q]) + + with pytest.raises(ValueError, match='should not already be tracked.'): + args.add_qubits([q]) + + +def test_remove_qubits_not_implemented(num_ancilla=1): + args = DummySimulationState() + + assert args.remove_qubits([cirq.LineQubit(0)]) is NotImplemented + + +def assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) -> None: + for test_simulator in ['cirq.final_state_vector', 'cirq.final_density_matrix']: + test_sim = eval(test_simulator)(test_circuit) + control_sim = eval(test_simulator)(control_circuit) + assert np.allclose(test_sim, control_sim) diff --git a/cirq/sim/state_vector_simulation_state.py b/cirq/sim/state_vector_simulation_state.py index 9a0f547c5a4..f721fb618a8 100644 --- a/cirq/sim/state_vector_simulation_state.py +++ b/cirq/sim/state_vector_simulation_state.py @@ -355,6 +355,22 @@ def __init__( ) super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) + def add_qubits(self, qubits: Sequence['cirq.Qid']): + ret = super().add_qubits(qubits) + return ( + self.kronecker_product(type(self)(qubits=qubits), inplace=True) + if ret is NotImplemented + else ret + ) + + def remove_qubits(self, qubits: Sequence['cirq.Qid']): + ret = super().remove_qubits(qubits) + if ret is not NotImplemented: + return ret + extracted, remainder = self.factor(qubits, inplace=True) + remainder._state._state_vector *= extracted._state._state_vector.reshape((-1,))[0] + return remainder + def _act_on_fallback_( self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True ) -> bool: