diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index aec48cbc02a..db41ce3b5b3 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -17,10 +17,10 @@ https://arxiv.org/abs/2002.07730 """ +import dataclasses import math -from typing import Any, Dict, List, Iterator, Optional, Sequence, Set, TYPE_CHECKING, Iterable +from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Iterable, Union -import dataclasses import numpy as np import quimb.tensor as qtn @@ -54,7 +54,9 @@ class MPSOptions: class MPSSimulator( simulator.SimulatesSamples, - simulator.SimulatesIntermediateState['MPSSimulatorStepResult', 'MPSTrialResult', 'MPSState'], + simulator.SimulatesIntermediateState[ + 'MPSSimulatorStepResult', 'MPSTrialResult', 'MPSState', 'MPSState' + ], ): """An efficient simulator for MPS circuits.""" @@ -82,58 +84,68 @@ def __init__( self.simulation_options = simulation_options self.grouping = grouping - def _base_iterator( - self, circuit: circuits.Circuit, qubit_order: ops.QubitOrderOrList, initial_state: int - ) -> Iterator['MPSSimulatorStepResult']: - """Iterator over MPSSimulatorStepResult from Moments of a Circuit + def _create_act_on_args( + self, + initial_state: Union[int, 'MPSState'], + qubits: Sequence['cirq.Qid'], + ) -> 'MPSState': + """Creates MPSState args for simulating the Circuit. Args: - circuit: The circuit to simulate. - qubit_order: Determines the canonical ordering of the qubits. This - is often used in specifying the initial state, i.e. the - ordering of the computational basis states. initial_state: The initial state for the simulation in the computational basis. Represented as a big endian int. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. - Yields: - MPSStepResult from simulating a Moment of the Circuit. + Returns: + MPSState args for simulating the Circuit. """ - qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(circuit.all_qubits()) + if isinstance(initial_state, MPSState): + return initial_state + + return MPSState( + qubits=qubits, + prng=self.prng, + simulation_options=self.simulation_options, + grouping=self.grouping, + initial_state=initial_state, + ) - qubit_map = {q: i for i, q in enumerate(qubits)} + def _core_iterator( + self, + circuit: circuits.Circuit, + sim_state: 'MPSState', + ): + """Iterator over MPSSimulatorStepResult from Moments of a Circuit + + Args: + circuit: The circuit to simulate. + sim_state: The initial state args for the simulation in the + computational basis. + Yields: + MPSStepResult from simulating a Moment of the Circuit. + """ if len(circuit) == 0: yield MPSSimulatorStepResult( - measurements={}, - state=MPSState( - qubit_map, - self.prng, - self.simulation_options, - self.grouping, - initial_state=initial_state, - ), + measurements=sim_state.log_of_measurement_results, state=sim_state ) return - state = MPSState( - qubit_map, - self.prng, - self.simulation_options, - self.grouping, - initial_state=initial_state, - ) - noisy_moments = self.noise.noisy_moments(circuit, sorted(circuit.all_qubits())) for op_tree in noisy_moments: for op in flatten_to_ops(op_tree): if protocols.is_measurement(op) or protocols.has_mixture(op): - state.axes = tuple(qubit_map[qubit] for qubit in op.qubits) - protocols.act_on(op, state) + sim_state.axes = tuple(sim_state.qubit_map[qubit] for qubit in op.qubits) + protocols.act_on(op, sim_state) else: raise NotImplementedError(f"Unrecognized operation: {op!r}") - yield MPSSimulatorStepResult(measurements=state.log_of_measurement_results, state=state) - state.log_of_measurement_results.clear() + yield MPSSimulatorStepResult( + measurements=sim_state.log_of_measurement_results, state=sim_state + ) + sim_state.log_of_measurement_results.clear() def _create_simulator_trial_result( self, @@ -286,7 +298,7 @@ class MPSState(ActOnArgs): def __init__( self, - qubit_map: Dict['cirq.Qid', int], + qubits: Sequence['cirq.Qid'], prng: np.random.RandomState, simulation_options: MPSOptions = MPSOptions(), grouping: Optional[Dict['cirq.Qid', int]] = None, @@ -297,7 +309,9 @@ def __init__( """Creates and MPSState Args: - qubit_map: A map from Qid to an integer that uniquely identifies it. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. prng: A random number generator, used to simulate measurements. simulation_options: Numerical options for the simulation. grouping: How to group qubits together, if None all are individual. @@ -307,8 +321,8 @@ def __init__( log_of_measurement_results: A mutable object that measurements are being recorded into. """ - super().__init__(prng, axes, log_of_measurement_results) - self.qubit_map = qubit_map + super().__init__(prng, qubits, axes, log_of_measurement_results) + qubit_map = self.qubit_map self.grouping = qubit_map if grouping is None else grouping if self.grouping.keys() != self.qubit_map.keys(): raise ValueError('Grouping must cover exactly the qubits.') @@ -364,10 +378,10 @@ def _value_equality_values_(self) -> Any: def copy(self) -> 'MPSState': state = MPSState( - self.qubit_map, - self.prng, - self.simulation_options, - self.grouping, + qubits=self.qubits, + prng=self.prng, + simulation_options=self.simulation_options, + grouping=self.grouping, ) state.M = [x.copy() for x in self.M] state.estimated_gate_error_list = self.estimated_gate_error_list @@ -584,6 +598,5 @@ def perform_measurement( def _perform_measurement(self) -> List[int]: """Measures the axes specified by the simulator.""" - qubit_map_inv = {v: k for k, v in self.qubit_map.items()} - qubits = [qubit_map_inv[key] for key in self.axes] + qubits = [self.qubits[key] for key in self.axes] return self.perform_measurement(qubits, self.prng) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index ade54e10f98..f9053d12b5a 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -205,6 +205,23 @@ def test_cnot_flipped(): ) +def test_act_on_args(): + q0, q1 = qubit_order = cirq.LineQubit.range(2) + circuit = cirq.Circuit(cirq.CNOT(q1, q0)) + mps_simulator = ccq.mps_simulator.MPSSimulator() + ref_simulator = cirq.Simulator() + for initial_state in range(4): + args = mps_simulator._create_act_on_args(initial_state=initial_state, qubits=(q0, q1)) + actual = mps_simulator.simulate(circuit, qubit_order=qubit_order, initial_state=args) + expected = ref_simulator.simulate( + circuit, qubit_order=qubit_order, initial_state=initial_state + ) + np.testing.assert_allclose( + actual.final_state.to_numpy(), expected.final_state_vector, atol=1e-4 + ) + assert len(actual.measurements) == 0 + + def test_three_qubits(): q0, q1, q2 = cirq.LineQubit.range(3) circuit = cirq.Circuit(cirq.CCX(q0, q1, q2)) @@ -257,7 +274,7 @@ def test_measurement_str(): def test_trial_result_str(): q0 = cirq.LineQubit(0) final_simulator_state = ccq.mps_simulator.MPSState( - qubit_map={q0: 0}, + qubits=(q0,), prng=value.parse_random_state(0), simulation_options=ccq.mps_simulator.MPSOptions(), ) @@ -278,7 +295,7 @@ def test_trial_result_str(): def test_empty_step_result(): q0 = cirq.LineQubit(0) - state = ccq.mps_simulator.MPSState(qubit_map={q0: 0}, prng=value.parse_random_state(0)) + state = ccq.mps_simulator.MPSState(qubits=(q0,), prng=value.parse_random_state(0)) step_result = ccq.mps_simulator.MPSSimulatorStepResult(state, measurements={'0': [1]}) assert ( str(step_result) @@ -292,17 +309,17 @@ def test_empty_step_result(): def test_state_equal(): q0, q1 = cirq.LineQubit.range(2) state0 = ccq.mps_simulator.MPSState( - qubit_map={q0: 0}, + qubits=(q0,), prng=value.parse_random_state(0), simulation_options=ccq.mps_simulator.MPSOptions(cutoff=1e-3, sum_prob_atol=1e-3), ) state1a = ccq.mps_simulator.MPSState( - qubit_map={q1: 0}, + qubits=(q1,), prng=value.parse_random_state(0), simulation_options=ccq.mps_simulator.MPSOptions(cutoff=1e-3, sum_prob_atol=1e-3), ) state1b = ccq.mps_simulator.MPSState( - qubit_map={q1: 0}, + qubits=(q1,), prng=value.parse_random_state(0), simulation_options=ccq.mps_simulator.MPSOptions(cutoff=1729.0, sum_prob_atol=1e-3), ) @@ -500,7 +517,7 @@ def test_state_copy(): def test_state_act_on_args_initializer(): s = ccq.mps_simulator.MPSState( - qubit_map={cirq.LineQubit(0): 0}, + qubits=(cirq.LineQubit(0),), prng=np.random.RandomState(0), axes=[2], log_of_measurement_results={'test': 4}, diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index d0533d0527c..33eb2fab7b4 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -13,13 +13,18 @@ # limitations under the License. """Objects and methods for acting efficiently on a state tensor.""" import abc -from typing import Any, Iterable, Dict, List +from typing import Any, Iterable, Dict, List, TypeVar, TYPE_CHECKING, Sequence import numpy as np from cirq import protocols from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits +TSelf = TypeVar('TSelf', bound='ActOnArgs') + +if TYPE_CHECKING: + import cirq + class ActOnArgs: """State and context for an operation acting on a state tensor.""" @@ -27,23 +32,31 @@ class ActOnArgs: def __init__( self, prng: np.random.RandomState, + qubits: Sequence['cirq.Qid'] = None, axes: Iterable[int] = None, log_of_measurement_results: Dict[str, Any] = None, ): """ Args: - axes: The indices of axes corresponding to the qubits that the - operation is supposed to act upon. prng: The pseudo random number generator to use for probabilistic effects. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + axes: The indices of axes corresponding to the qubits that the + operation is supposed to act upon. log_of_measurement_results: A mutable object that measurements are being recorded into. Edit it easily by calling `ActOnStateVectorArgs.record_measurement_result`. """ + if qubits is None: + qubits = () if axes is None: - axes = [] + axes = () if log_of_measurement_results is None: log_of_measurement_results = {} + self.qubits = tuple(qubits) + self.qubit_map = {q: i for i, q in enumerate(self.qubits)} self.axes = tuple(axes) self.prng = prng self.log_of_measurement_results = log_of_measurement_results @@ -68,6 +81,10 @@ def _perform_measurement(self) -> List[int]: """Child classes that perform measurements should implement this with the implementation.""" + @abc.abstractmethod + def copy(self: TSelf) -> TSelf: + """Creates a copy of the object.""" + def strat_act_on_from_apply_decompose( val: Any, diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index cb57f0b3815..a029772bd61 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -13,13 +13,16 @@ # limitations under the License. """Objects and methods for acting efficiently on a density matrix.""" -from typing import Any, Iterable, Dict, List, Tuple +from typing import Any, Iterable, Dict, List, Tuple, TYPE_CHECKING, Sequence import numpy as np from cirq import protocols, sim from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose +if TYPE_CHECKING: + import cirq + class ActOnDensityMatrixArgs(ActOnArgs): """State and context for an operation acting on a density matrix. @@ -36,6 +39,7 @@ def __init__( qid_shape: Tuple[int, ...], prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], + qubits: Sequence['cirq.Qid'] = None, ): """ Args: @@ -46,6 +50,9 @@ def __init__( `target_tensor`. Used by operations that cannot be applied to `target_tensor` inline, in order to avoid unnecessary allocations. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. axes: The indices of axes corresponding to the qubits that the operation is supposed to act upon. qid_shape: The shape of the target tensor. @@ -55,7 +62,7 @@ def __init__( being recorded into. Edit it easily by calling `ActOnStateVectorArgs.record_measurement_result`. """ - super().__init__(prng, axes, log_of_measurement_results) + super().__init__(prng, qubits, axes, log_of_measurement_results) self.target_tensor = target_tensor self.available_buffer = available_buffer self.qid_shape = qid_shape @@ -92,6 +99,17 @@ def _perform_measurement(self) -> List[int]: ) return bits + def copy(self) -> 'cirq.ActOnDensityMatrixArgs': + return ActOnDensityMatrixArgs( + target_tensor=self.target_tensor.copy(), + available_buffer=[b.copy() for b in self.available_buffer], + qubits=self.qubits, + axes=self.axes, + qid_shape=self.qid_shape, + prng=self.prng, + log_of_measurement_results=self.log_of_measurement_results.copy(), + ) + def _strat_apply_channel_to_state( action: Any, diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 5ac08918e7f..9fb5dcaa46f 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -13,7 +13,7 @@ # limitations under the License. """Objects and methods for acting efficiently on a state vector.""" -from typing import Any, Iterable, Tuple, TYPE_CHECKING, Union, Dict, List +from typing import Any, Iterable, Tuple, TYPE_CHECKING, Union, Dict, List, Sequence import numpy as np @@ -43,6 +43,7 @@ def __init__( axes: Iterable[int], prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], + qubits: Sequence['cirq.Qid'] = None, ): """ Args: @@ -54,6 +55,9 @@ def __init__( `target_tensor` inline, in order to avoid unnecessary allocations. Passing `available_buffer` into `swap_target_tensor_for` will swap it for `target_tensor`. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. axes: The indices of axes corresponding to the qubits that the operation is supposed to act upon. prng: The pseudo random number generator to use for probabilistic @@ -62,7 +66,7 @@ def __init__( being recorded into. Edit it easily by calling `ActOnStateVectorArgs.record_measurement_result`. """ - super().__init__(prng, axes, log_of_measurement_results) + super().__init__(prng, qubits, axes, log_of_measurement_results) self.target_tensor = target_tensor self.available_buffer = available_buffer @@ -167,6 +171,16 @@ def _perform_measurement(self) -> List[int]: ) return bits + def copy(self) -> 'cirq.ActOnStateVectorArgs': + return ActOnStateVectorArgs( + target_tensor=self.target_tensor.copy(), + available_buffer=self.available_buffer.copy(), + qubits=self.qubits, + axes=self.axes, + prng=self.prng, + log_of_measurement_results=self.log_of_measurement_results.copy(), + ) + def _strat_act_on_state_vector_from_apply_unitary( unitary_value: Any, diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 54d5414604c..312e4b18430 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -14,7 +14,7 @@ """A protocol for implementing high performance clifford tableau evolutions for Clifford Simulator.""" -from typing import Any, Dict, Iterable, TYPE_CHECKING, List +from typing import Any, Dict, Iterable, TYPE_CHECKING, List, Sequence import numpy as np @@ -43,11 +43,15 @@ def __init__( axes: Iterable[int], prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], + qubits: Sequence['cirq.Qid'] = None, ): """ Args: tableau: The CliffordTableau to act on. Operations are expected to perform inplace edits of this object. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. axes: The indices of axes corresponding to the qubits that the operation is supposed to act upon. prng: The pseudo random number generator to use for probabilistic @@ -56,7 +60,7 @@ def __init__( being recorded into. Edit it easily by calling `ActOnCliffordTableauArgs.record_measurement_result`. """ - super().__init__(prng, axes, log_of_measurement_results) + super().__init__(prng, qubits, axes, log_of_measurement_results) self.tableau = tableau def _act_on_fallback_(self, action: Any, allow_decompose: bool): @@ -77,6 +81,15 @@ def _perform_measurement(self) -> List[int]: """Returns the measurement from the tableau.""" return [self.tableau._measure(q, self.prng) for q in self.axes] + def copy(self) -> 'cirq.ActOnCliffordTableauArgs': + return ActOnCliffordTableauArgs( + tableau=self.tableau.copy(), + qubits=self.qubits, + axes=self.axes, + prng=self.prng, + log_of_measurement_results=self.log_of_measurement_results.copy(), + ) + def _strat_act_on_clifford_tableau_from_single_qubit_decompose( val: Any, args: 'cirq.ActOnCliffordTableauArgs' diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py index 9fd16de6918..fff652f998c 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py @@ -82,3 +82,21 @@ class NoDetailsSingleQubitGate(cirq.SingleQubitGate): with pytest.raises(TypeError, match="Failed to act"): cirq.act_on(NoDetailsSingleQubitGate(), args) + + +def test_copy(): + args = cirq.ActOnCliffordTableauArgs( + tableau=cirq.CliffordTableau(num_qubits=3), + axes=[1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + args1 = args.copy() + assert isinstance(args1, cirq.ActOnCliffordTableauArgs) + assert args is not args1 + assert args.tableau is not args1.tableau + assert args.tableau == args1.tableau + assert args.axes == args1.axes + assert args.prng is args1.prng + assert args.log_of_measurement_results is not args1.log_of_measurement_results + assert args.log_of_measurement_results == args.log_of_measurement_results diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index b3470f36bb0..071a2a6c21b 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterable, TYPE_CHECKING, List +from typing import Any, Dict, Iterable, TYPE_CHECKING, List, Sequence import numpy as np @@ -40,11 +40,15 @@ def __init__( axes: Iterable[int], prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], + qubits: Sequence['cirq.Qid'] = None, ): """Initializes with the given state and the axes for the operation. Args: state: The StabilizerStateChForm to act on. Operations are expected to perform inplace edits of this object. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. axes: The indices of axes corresponding to the qubits that the operation is supposed to act upon. prng: The pseudo random number generator to use for probabilistic @@ -53,7 +57,7 @@ def __init__( being recorded into. Edit it easily by calling `ActOnStabilizerCHFormArgs.record_measurement_result`. """ - super().__init__(prng, axes, log_of_measurement_results) + super().__init__(prng, qubits, axes, log_of_measurement_results) self.state = state def _act_on_fallback_(self, action: Any, allow_decompose: bool): @@ -72,6 +76,15 @@ def _perform_measurement(self) -> List[int]: """Returns the measurement from the stabilizer state form.""" return [self.state._measure(q, self.prng) for q in self.axes] + def copy(self) -> 'cirq.ActOnStabilizerCHFormArgs': + return ActOnStabilizerCHFormArgs( + state=self.state.copy(), + qubits=self.qubits, + axes=self.axes, + prng=self.prng, + log_of_measurement_results=self.log_of_measurement_results.copy(), + ) + def _strat_act_on_stabilizer_ch_form_from_single_qubit_decompose( val: Any, args: 'cirq.ActOnStabilizerCHFormArgs' diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py index 8f3963cd4ce..803de773eba 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py @@ -106,3 +106,21 @@ def _unitary_(self): ) cirq.act_on(cirq.H, expected_args) np.testing.assert_allclose(args.state.state_vector(), expected_args.state.state_vector()) + + +def test_copy(): + args = cirq.ActOnStabilizerCHFormArgs( + state=cirq.StabilizerStateChForm(num_qubits=3), + axes=[1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + args1 = args.copy() + assert isinstance(args1, cirq.ActOnStabilizerCHFormArgs) + assert args is not args1 + assert args.state is not args1.state + np.testing.assert_equal(args.state.state_vector(), args1.state.state_vector()) + assert args.axes == args1.axes + assert args.prng is args1.prng + assert args.log_of_measurement_results is not args1.log_of_measurement_results + assert args.log_of_measurement_results == args.log_of_measurement_results diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 03d31e33e79..9f440570fca 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -29,23 +29,26 @@ to state vector amplitudes. """ -from typing import Any, Dict, List, Iterator, Sequence +from typing import Any, Dict, List, Sequence, Union import numpy as np import cirq from cirq import circuits, study, ops, protocols, value +from cirq._compat import deprecated from cirq.ops.dense_pauli_string import DensePauliString from cirq.protocols import act_on from cirq.sim import clifford, simulator -from cirq._compat import deprecated from cirq.sim.simulator import check_all_resolved class CliffordSimulator( simulator.SimulatesSamples, simulator.SimulatesIntermediateState[ - 'CliffordSimulatorStepResult', 'CliffordTrialResult', 'CliffordState' + 'CliffordSimulatorStepResult', + 'CliffordTrialResult', + 'CliffordState', + clifford.ActOnStabilizerCHFormArgs, ], ): """An efficient simulator for Clifford circuits.""" @@ -65,55 +68,76 @@ def is_supported_operation(op: 'cirq.Operation') -> bool: # TODO: support more general Pauli measurements return protocols.has_stabilizer_effect(op) - def _base_iterator( - self, circuit: circuits.Circuit, qubit_order: ops.QubitOrderOrList, initial_state: int - ) -> Iterator['cirq.CliffordSimulatorStepResult']: - """Iterator over CliffordSimulatorStepResult from Moments of a Circuit + def _create_act_on_args( + self, + initial_state: Union[int, clifford.ActOnStabilizerCHFormArgs], + qubits: Sequence['cirq.Qid'], + ) -> clifford.ActOnStabilizerCHFormArgs: + """Creates the ActOnStabilizerChFormArgs for a circuit. Args: - circuit: The circuit to simulate. - qubit_order: Determines the canonical ordering of the qubits. This - is often used in specifying the initial state, i.e. the - ordering of the computational basis states. initial_state: The initial state for the simulation in the computational basis. Represented as a big endian int. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + + Returns: + ActOnStabilizerChFormArgs for the circuit. + """ + if isinstance(initial_state, clifford.ActOnStabilizerCHFormArgs): + return initial_state + + qubit_map = {q: i for i, q in enumerate(qubits)} + + state = CliffordState(qubit_map, initial_state=initial_state) + return clifford.ActOnStabilizerCHFormArgs( + state=state.ch_form, + axes=[], + prng=self._prng, + log_of_measurement_results={}, + qubits=qubits, + ) + def _core_iterator( + self, + circuit: circuits.Circuit, + sim_state: clifford.ActOnStabilizerCHFormArgs, + ): + """Iterator over CliffordSimulatorStepResult from Moments of a Circuit + + Args: + circuit: The circuit to simulate. + sim_state: The initial state args for the simulation in the + computational basis. Yields: CliffordStepResult from simulating a Moment of the Circuit. """ - qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(circuit.all_qubits()) - qubit_map = {q: i for i, q in enumerate(qubits)} + def create_state(): + return CliffordState(sim_state.qubit_map, sim_state.state.copy()) if len(circuit) == 0: yield CliffordSimulatorStepResult( - measurements={}, state=CliffordState(qubit_map, initial_state=initial_state) + measurements=sim_state.log_of_measurement_results, state=create_state() ) return - state = CliffordState(qubit_map, initial_state=initial_state) - ch_form_args = clifford.ActOnStabilizerCHFormArgs( - state.ch_form, - [], - self._prng, - {}, - ) - for moment in circuit: - ch_form_args.log_of_measurement_results = {} + sim_state.log_of_measurement_results = {} for op in moment: try: - ch_form_args.axes = tuple(state.qubit_map[i] for i in op.qubits) - act_on(op, ch_form_args) + sim_state.axes = tuple(sim_state.qubit_map[i] for i in op.qubits) + act_on(op, sim_state) except TypeError: raise NotImplementedError( f"CliffordSimulator doesn't support {op!r}" ) # type: ignore yield CliffordSimulatorStepResult( - measurements=ch_form_args.log_of_measurement_results, state=state + measurements=sim_state.log_of_measurement_results, state=create_state() ) def _create_simulator_trial_result( @@ -173,7 +197,7 @@ def __str__(self) -> str: class CliffordSimulatorStepResult(simulator.StepResult['CliffordState']): """A `StepResult` that includes `StateVectorMixin` methods.""" - def __init__(self, state, measurements): + def __init__(self, state: 'CliffordState', measurements): """Results of a step of the simulator. Attributes: state: A CliffordState @@ -235,11 +259,15 @@ class CliffordState: Gates and measurements are applied to each representation in O(n^2) time. """ - def __init__(self, qubit_map, initial_state=0): + def __init__(self, qubit_map, initial_state: Union[int, clifford.StabilizerStateChForm] = 0): self.qubit_map = qubit_map self.n = len(qubit_map) - self.ch_form = clifford.StabilizerStateChForm(self.n, initial_state) + self.ch_form = ( + initial_state + if isinstance(initial_state, clifford.StabilizerStateChForm) + else clifford.StabilizerStateChForm(self.n, initial_state) + ) def _json_dict_(self): return { diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py index ebeebdb4f84..2b9810370ef 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py @@ -97,6 +97,27 @@ def test_simulate_initial_state(): ) +def test_simulate_act_on_args(): + q0, q1 = cirq.LineQubit.range(2) + simulator = cirq.CliffordSimulator() + for b0 in [0, 1]: + for b1 in [0, 1]: + circuit = cirq.Circuit() + if b0: + circuit.append(cirq.X(q0)) + if b1: + circuit.append(cirq.X(q1)) + circuit.append(cirq.measure(q0, q1)) + + args = simulator._create_act_on_args(initial_state=1, qubits=(q0, q1)) + result = simulator.simulate(circuit, initial_state=args) + expected_state = np.zeros(shape=(2, 2)) + expected_state[b0][1 - b1] = 1.0 + np.testing.assert_almost_equal( + result.final_state.to_numpy(), np.reshape(expected_state, 4) + ) + + def test_simulate_qubit_order(): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.CliffordSimulator() diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index ceecae558ca..ba311385f94 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -13,7 +13,7 @@ # limitations under the License. """Simulator for density matrices that simulates noisy quantum circuits.""" import collections -from typing import Any, Dict, Iterator, List, TYPE_CHECKING, Tuple, Union +from typing import Any, Dict, List, TYPE_CHECKING, Tuple, Union, Sequence import numpy as np @@ -30,7 +30,10 @@ class DensityMatrixSimulator( simulator.SimulatesSamples, simulator.SimulatesIntermediateState[ - 'DensityMatrixStepResult', 'DensityMatrixTrialResult', 'DensityMatrixSimulatorState' + 'DensityMatrixStepResult', + 'DensityMatrixTrialResult', + 'DensityMatrixSimulatorState', + act_on_density_matrix_args.ActOnDensityMatrixArgs, ], ): """A simulator for density matrices and noisy quantum circuits. @@ -167,42 +170,36 @@ def _run( param_resolver = param_resolver or study.ParamResolver({}) resolved_circuit = protocols.resolve_parameters(circuit, param_resolver) check_all_resolved(resolved_circuit) - qubit_order = sorted(resolved_circuit.all_qubits()) + qubits = tuple(sorted(resolved_circuit.all_qubits())) + act_on_args = self._create_act_on_args(0, qubits) prefix, general_suffix = split_into_matching_protocol_then_general( resolved_circuit, lambda op: not protocols.is_measurement(op) ) step_result = None - for step_result in self._base_iterator( + for step_result in self._core_iterator( circuit=prefix, - qubit_order=qubit_order, - initial_state=0, + sim_state=act_on_args, ): pass assert step_result is not None - intermediate_state = step_result._density_matrix if general_suffix.are_all_measurements_terminal() and not any( general_suffix.findall_operations(lambda op: isinstance(op, circuits.CircuitOperation)) ): - return self._run_sweep_sample( - general_suffix, repetitions, qubit_order, intermediate_state - ) - return self._run_sweep_repeat(general_suffix, repetitions, qubit_order, intermediate_state) + return self._run_sweep_sample(general_suffix, repetitions, act_on_args) + return self._run_sweep_repeat(general_suffix, repetitions, act_on_args) def _run_sweep_sample( self, circuit: circuits.Circuit, repetitions: int, - qubit_order: ops.QubitOrderOrList, - intermediate_state: np.ndarray, + act_on_args: act_on_density_matrix_args.ActOnDensityMatrixArgs, ) -> Dict[str, np.ndarray]: - for step_result in self._base_iterator( + for step_result in self._core_iterator( circuit=circuit, - qubit_order=qubit_order, - initial_state=intermediate_state, + sim_state=act_on_args, all_measurements_are_terminal=True, - is_raw_state=True, ): pass measurement_ops = [ @@ -214,17 +211,14 @@ def _run_sweep_repeat( self, circuit: circuits.Circuit, repetitions: int, - qubit_order: ops.QubitOrderOrList, - intermediate_state: np.ndarray, + act_on_args: act_on_density_matrix_args.ActOnDensityMatrixArgs, ) -> Dict[str, np.ndarray]: measurements = {} # type: Dict[str, List[np.ndarray]] for _ in range(repetitions): - all_step_results = self._base_iterator( + all_step_results = self._core_iterator( circuit, - qubit_order=qubit_order, - initial_state=intermediate_state, - is_raw_state=True, + sim_state=act_on_args.copy(), ) for step_result in all_step_results: for k, v in step_result.measurements.items(): @@ -233,41 +227,71 @@ def _run_sweep_repeat( measurements[k].append(np.array(v, dtype=np.uint8)) return {k: np.array(v) for k, v in measurements.items()} - def _base_iterator( + def _create_act_on_args( self, - circuit: circuits.Circuit, - qubit_order: ops.QubitOrderOrList, - initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'], - all_measurements_are_terminal=False, - is_raw_state=False, - ) -> Iterator['DensityMatrixStepResult']: - qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(circuit.all_qubits()) + initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.ActOnDensityMatrixArgs'], + qubits: Sequence['cirq.Qid'], + ) -> 'cirq.ActOnDensityMatrixArgs': + """Creates the ActOnDensityMatrixArgs for a circuit. + + Args: + initial_state: The initial state for the simulation in the + computational basis. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + + Returns: + ActOnDensityMatrixArgs for the circuit. + """ + if isinstance(initial_state, act_on_density_matrix_args.ActOnDensityMatrixArgs): + return initial_state + qid_shape = protocols.qid_shape(qubits) - qubit_map = {q: i for i, q in enumerate(qubits)} - initial_matrix = ( - qis.to_valid_density_matrix( - initial_state, len(qid_shape), qid_shape=qid_shape, dtype=self._dtype - ) - if not is_raw_state - else initial_state + initial_matrix = qis.to_valid_density_matrix( + initial_state, len(qid_shape), qid_shape=qid_shape, dtype=self._dtype ) if np.may_share_memory(initial_matrix, initial_state): initial_matrix = initial_matrix.copy() - if len(circuit) == 0: - yield DensityMatrixStepResult(initial_matrix, {}, qubit_map, self._dtype) - return - tensor = initial_matrix.reshape(qid_shape * 2) - sim_state = act_on_density_matrix_args.ActOnDensityMatrixArgs( + return act_on_density_matrix_args.ActOnDensityMatrixArgs( target_tensor=tensor, available_buffer=[np.empty_like(tensor) for _ in range(3)], + qubits=qubits, axes=[], qid_shape=qid_shape, prng=self._prng, log_of_measurement_results={}, ) + def _core_iterator( + self, + circuit: circuits.Circuit, + sim_state: act_on_density_matrix_args.ActOnDensityMatrixArgs, + all_measurements_are_terminal: bool = False, + ): + """Iterator over DensityMatrixStepResult from Moments of a Circuit + + Args: + circuit: The circuit to simulate. + sim_state: The initial state args for the simulation in the + computational basis. + all_measurements_are_terminal: Indicator that all measurements + are terminal, allowing optimization. + + Yields: + DensityMatrixStepResult from simulating a Moment of the Circuit. + """ + if len(circuit) == 0: + yield DensityMatrixStepResult( + density_matrix=sim_state.target_tensor, + measurements=dict(sim_state.log_of_measurement_results), + qubit_map=sim_state.qubit_map, + dtype=self._dtype, + ) + return + noisy_moments = self.noise.noisy_moments(circuit, sorted(circuit.all_qubits())) measured = collections.defaultdict(bool) # type: Dict[Tuple[cirq.Qid, ...], bool] for moment in noisy_moments: @@ -282,13 +306,13 @@ def _base_iterator( continue if self._ignore_measurement_results: op = ops.phase_damp(1).on(*op.qubits) - sim_state.axes = tuple(qubit_map[qubit] for qubit in op.qubits) + sim_state.axes = tuple(sim_state.qubit_map[qubit] for qubit in op.qubits) protocols.act_on(op, sim_state) yield DensityMatrixStepResult( density_matrix=sim_state.target_tensor, measurements=dict(sim_state.log_of_measurement_results), - qubit_map=qubit_map, + qubit_map=sim_state.qubit_map, dtype=self._dtype, ) sim_state.log_of_measurement_results.clear() diff --git a/cirq-core/cirq/sim/density_matrix_simulator_test.py b/cirq-core/cirq/sim/density_matrix_simulator_test.py index 6e652f8e325..3a78a81dbb9 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulator_test.py @@ -269,7 +269,7 @@ def _channel_(self): def test_run_measure_at_end_no_repetitions(dtype): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit( @@ -287,7 +287,7 @@ def test_run_measure_at_end_no_repetitions(dtype): def test_run_repetitions_measure_at_end(dtype): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit( @@ -303,7 +303,7 @@ def test_run_repetitions_measure_at_end(dtype): def test_run_qudits_repetitions_measure_at_end(dtype): q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1, 2]: circuit = cirq.Circuit( @@ -321,7 +321,7 @@ def test_run_qudits_repetitions_measure_at_end(dtype): def test_run_measurement_not_terminal_no_repetitions(dtype): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit( @@ -344,7 +344,7 @@ def test_run_measurement_not_terminal_no_repetitions(dtype): def test_run_repetitions_measurement_not_terminal(dtype): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit( @@ -365,7 +365,7 @@ def test_run_repetitions_measurement_not_terminal(dtype): def test_run_qudits_repetitions_measurement_not_terminal(dtype): q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1, 2]: circuit = cirq.Circuit( @@ -560,6 +560,20 @@ def test_simulate_initial_state(dtype): np.testing.assert_equal(result.final_density_matrix, expected_density_matrix) +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_simulate_act_on_args(dtype): + q0, q1 = cirq.LineQubit.range(2) + simulator = cirq.DensityMatrixSimulator(dtype=dtype) + for b0 in [0, 1]: + for b1 in [0, 1]: + circuit = cirq.Circuit((cirq.X ** b0)(q0), (cirq.X ** b1)(q1)) + args = simulator._create_act_on_args(initial_state=1, qubits=(q0, q1)) + result = simulator.simulate(circuit, initial_state=args) + expected_density_matrix = np.zeros(shape=(4, 4)) + expected_density_matrix[b0 * 2 + 1 - b1, b0 * 2 + 1 - b1] = 1.0 + np.testing.assert_equal(result.final_density_matrix, expected_density_matrix) + + def test_simulate_tps_initial_state(): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator() @@ -710,7 +724,10 @@ def test_simulate_moment_steps_empty_circuit(dtype): for step in simulator.simulate_moment_steps(circuit): pass assert step._simulator_state() == cirq.DensityMatrixSimulatorState( - density_matrix=np.array([[1]]), qubit_map={} + density_matrix=np.array( + 1, + ), + qubit_map={}, ) @@ -1304,7 +1321,7 @@ def test_nonmeasuring_subcircuits_do_not_cause_sweep_repeat(): cirq.measure(q, key='x'), ) simulator = cirq.DensityMatrixSimulator() - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: simulator.run(circuit, repetitions=10) assert mock_sim.call_count == 2 @@ -1316,7 +1333,7 @@ def test_measuring_subcircuits_cause_sweep_repeat(): cirq.measure(q, key='x'), ) simulator = cirq.DensityMatrixSimulator() - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: simulator.run(circuit, repetitions=10) assert mock_sim.call_count == 11 diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 3b5e3b5d48b..2feea53bc6c 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -51,6 +51,7 @@ from cirq import circuits, ops, protocols, study, value, work from cirq._compat import deprecated +from cirq.sim.act_on_args import ActOnArgs if TYPE_CHECKING: import cirq @@ -59,6 +60,7 @@ TStepResult = TypeVar('TStepResult', bound='StepResult') TSimulationTrialResult = TypeVar('TSimulationTrialResult', bound='SimulationTrialResult') TSimulatorState = TypeVar('TSimulatorState') +TActOnArgs = TypeVar('TActOnArgs', bound=ActOnArgs) class SimulatesSamples(work.Sampler, metaclass=abc.ABCMeta): @@ -374,7 +376,7 @@ def simulate_sweep( class SimulatesIntermediateState( - Generic[TStepResult, TSimulationTrialResult, TSimulatorState], + Generic[TStepResult, TSimulationTrialResult, TSimulatorState, TActOnArgs], SimulatesFinalState[TSimulationTrialResult], metaclass=abc.ABCMeta, ): @@ -410,8 +412,9 @@ def simulate_sweep( qubit_order: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - initial_state: The initial state for the simulation. The form of - this state depends on the simulation implementation. See + initial_state: The initial state for the simulation. This can be + either a raw state or a `TActOnArgs`. The form of the + raw state depends on the simulation implementation. See documentation of the implementing class for details. Returns: @@ -455,8 +458,9 @@ def simulate_moment_steps( qubit_order: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - initial_state: The initial state for the simulation. The form of - this state depends on the simulation implementation. See + initial_state: The initial state for the simulation. This can be + either a raw state or a `TActOnArgs`. The form of the + raw state depends on the simulation implementation. See documentation of the implementing class for details. Returns: @@ -503,7 +507,6 @@ def _simulator_iterator( """ return self.simulate_moment_steps(circuit, param_resolver, qubit_order, initial_state) - @abc.abstractmethod def _base_iterator( self, circuit: circuits.Circuit, @@ -512,10 +515,14 @@ def _base_iterator( ) -> Iterator[TStepResult]: """Iterator over StepResult from Moments of a Circuit. + This is a thin wrapper around `create_act_on_args` and `_core_iterator`. + Overriding this method was the old way of creating a circuit iterator, + and this method is planned to be formally put on the deprecation path. + Going forward, override the aforementioned two methods in custom + simulators. + Args: circuit: The circuit to simulate. - param_resolver: A ParamResolver for determining values of - Symbols. qubit_order: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. @@ -523,6 +530,52 @@ def _base_iterator( this state depends on the simulation implementation. See documentation of the implementing class for details. + Yields: + StepResults from simulating a Moment of the Circuit. + """ + qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(circuit.all_qubits()) + act_on_args = self._create_act_on_args(initial_state, qubits) + return self._core_iterator(circuit, act_on_args) + + @abc.abstractmethod + def _create_act_on_args( + self, + initial_state: Any, + qubits: Sequence['cirq.Qid'], + ) -> TActOnArgs: + """Creates the ActOnArgs state for a simulator. + + Custom simulators should implement this method. + + Args: + initial_state: The initial state for the simulation. The form of + this state depends on the simulation implementation. See + documentation of the implementing class for details. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + + Returns: + The ActOnArgs for this simulator. + """ + raise NotImplementedError() + + @abc.abstractmethod + def _core_iterator( + self, + circuit: circuits.Circuit, + sim_state: TActOnArgs, + ) -> Iterator[TStepResult]: + """Iterator over StepResult from Moments of a Circuit. + + Custom simulators should implement this method. + + Args: + circuit: The circuit to simulate. + sim_state: The initial args for the simulation. The form of + this state depends on the simulation implementation. See + documentation of the implementing class for details. + Yields: StepResults from simulating a Moment of the Circuit. """ diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index c2d6865289d..5364b6d7c80 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -25,12 +25,13 @@ TSimulatorState, SimulatesIntermediateState, SimulationTrialResult, + TActOnArgs, ) class SimulatesIntermediateStateImpl( - Generic[TStepResult, TSimulatorState], - SimulatesIntermediateState[TStepResult, 'SimulationTrialResult', TSimulatorState], + Generic[TStepResult, TSimulatorState, TActOnArgs], + SimulatesIntermediateState[TStepResult, 'SimulationTrialResult', TSimulatorState, TActOnArgs], metaclass=abc.ABCMeta, ): """A SimulatesIntermediateState that uses the default SimulationTrialResult type.""" diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index 51a35438eee..6058eacff84 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -18,13 +18,13 @@ from typing import ( Any, Dict, - Iterator, List, Type, TYPE_CHECKING, DefaultDict, Union, cast, + Sequence, ) import numpy as np @@ -173,7 +173,8 @@ def _run( param_resolver = param_resolver or study.ParamResolver({}) resolved_circuit = protocols.resolve_parameters(circuit, param_resolver) check_all_resolved(resolved_circuit) - qubit_order = sorted(resolved_circuit.all_qubits()) + qubits = tuple(sorted(resolved_circuit.all_qubits())) + act_on_args = self._create_act_on_args(0, qubits) # Simulate as many unitary operations as possible before having to # repeat work for each sample. @@ -183,10 +184,9 @@ def _run( else (resolved_circuit[0:0], resolved_circuit) ) step_result = None - for step_result in self._base_iterator( + for step_result in self._core_iterator( circuit=unitary_prefix, - qubit_order=qubit_order, - initial_state=0, + sim_state=act_on_args, ): pass assert step_result is not None @@ -201,69 +201,98 @@ def _run( seed=self._prng, ) - qid_shape = protocols.qid_shape(qubit_order) - intermediate_state = step_result.state_vector().reshape(qid_shape) return self._brute_force_samples( - initial_state=intermediate_state, + act_on_args=act_on_args, circuit=general_suffix, repetitions=repetitions, - qubit_order=qubit_order, ) def _brute_force_samples( self, - initial_state: np.ndarray, + act_on_args: act_on_state_vector_args.ActOnStateVectorArgs, circuit: circuits.Circuit, - qubit_order: 'cirq.QubitOrderOrList', repetitions: int, ) -> Dict[str, np.ndarray]: """Repeatedly simulate a circuit in order to produce samples.""" measurements: DefaultDict[str, List[np.ndarray]] = collections.defaultdict(list) for _ in range(repetitions): - all_step_results = self._base_iterator( - circuit, initial_state=initial_state, qubit_order=qubit_order - ) + all_step_results = self._core_iterator(circuit, sim_state=act_on_args.copy()) for step_result in all_step_results: for k, v in step_result.measurements.items(): measurements[k].append(np.array(v, dtype=np.uint8)) return {k: np.array(v) for k, v in measurements.items()} - def _base_iterator( + def _create_act_on_args( self, - circuit: circuits.Circuit, - qubit_order: ops.QubitOrderOrList, - initial_state: 'cirq.STATE_VECTOR_LIKE', - ) -> Iterator['SparseSimulatorStep']: - qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(circuit.all_qubits()) + initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.ActOnStateVectorArgs'], + qubits: Sequence['cirq.Qid'], + ): + """Creates the ActOnStateVectorArgs for a circuit. + + Args: + initial_state: The initial state for the simulation in the + computational basis. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + + Returns: + ActOnStateVectorArgs for the circuit. + """ + if isinstance(initial_state, act_on_state_vector_args.ActOnStateVectorArgs): + return initial_state + num_qubits = len(qubits) qid_shape = protocols.qid_shape(qubits) - qubit_map = {q: i for i, q in enumerate(qubits)} state = qis.to_valid_state_vector( initial_state, num_qubits, qid_shape=qid_shape, dtype=self._dtype ) - if len(circuit) == 0: - yield SparseSimulatorStep(state, {}, qubit_map, self._dtype) - sim_state = act_on_state_vector_args.ActOnStateVectorArgs( + return act_on_state_vector_args.ActOnStateVectorArgs( target_tensor=np.reshape(state, qid_shape), available_buffer=np.empty(qid_shape, dtype=self._dtype), + qubits=qubits, axes=[], prng=self._prng, log_of_measurement_results={}, ) + def _core_iterator( + self, + circuit: circuits.Circuit, + sim_state: act_on_state_vector_args.ActOnStateVectorArgs, + ): + """Iterator over SparseSimulatorStep from Moments of a Circuit + + Args: + circuit: The circuit to simulate. + sim_state: The initial state args for the simulation in the + computational basis. + + Yields: + SparseSimulatorStep from simulating a Moment of the Circuit. + """ + if len(circuit) == 0: + yield SparseSimulatorStep( + state_vector=sim_state.target_tensor, + measurements=dict(sim_state.log_of_measurement_results), + qubit_map=sim_state.qubit_map, + dtype=self._dtype, + ) + return + noisy_moments = self.noise.noisy_moments(circuit, sorted(circuit.all_qubits())) for op_tree in noisy_moments: for op in flatten_to_ops(op_tree): - sim_state.axes = tuple(qubit_map[qubit] for qubit in op.qubits) + sim_state.axes = tuple(sim_state.qubit_map[qubit] for qubit in op.qubits) protocols.act_on(op, sim_state) yield SparseSimulatorStep( state_vector=sim_state.target_tensor, measurements=dict(sim_state.log_of_measurement_results), - qubit_map=qubit_map, + qubit_map=sim_state.qubit_map, dtype=self._dtype, ) sim_state.log_of_measurement_results.clear() diff --git a/cirq-core/cirq/sim/sparse_simulator_test.py b/cirq-core/cirq/sim/sparse_simulator_test.py index 93775f2b2b2..3590281d500 100644 --- a/cirq-core/cirq/sim/sparse_simulator_test.py +++ b/cirq-core/cirq/sim/sparse_simulator_test.py @@ -89,7 +89,7 @@ def test_run_bit_flips(dtype): def test_run_measure_at_end_no_repetitions(dtype): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit( @@ -114,7 +114,7 @@ def test_run_repetitions_terminal_measurement_stochastic(): def test_run_repetitions_measure_at_end(dtype): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit( @@ -131,7 +131,7 @@ def test_run_repetitions_measure_at_end(dtype): def test_run_invert_mask_measure_not_terminal(dtype): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit( @@ -151,7 +151,7 @@ def test_run_invert_mask_measure_not_terminal(dtype): def test_run_partial_invert_mask_measure_not_terminal(dtype): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit( @@ -171,7 +171,7 @@ def test_run_partial_invert_mask_measure_not_terminal(dtype): def test_run_measurement_not_terminal_no_repetitions(dtype): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit( @@ -194,7 +194,7 @@ def test_run_measurement_not_terminal_no_repetitions(dtype): def test_run_repetitions_measurement_not_terminal(dtype): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype) - with mock.patch.object(simulator, '_base_iterator', wraps=simulator._base_iterator) as mock_sim: + with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit( @@ -449,6 +449,20 @@ def test_simulate_initial_state(dtype): np.testing.assert_equal(result.final_state_vector, np.reshape(expected_state, 4)) +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_simulate_act_on_args(dtype): + q0, q1 = cirq.LineQubit.range(2) + simulator = cirq.Simulator(dtype=dtype) + for b0 in [0, 1]: + for b1 in [0, 1]: + circuit = cirq.Circuit((cirq.X ** b0)(q0), (cirq.X ** b1)(q1)) + args = simulator._create_act_on_args(initial_state=1, qubits=(q0, q1)) + result = simulator.simulate(circuit, initial_state=args) + expected_state = np.zeros(shape=(2, 2)) + expected_state[b0][1 - b1] = 1.0 + np.testing.assert_equal(result.final_state_vector, np.reshape(expected_state, 4)) + + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) def test_simulate_qubit_order(dtype): q0, q1 = cirq.LineQubit.range(2) diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index 4e6e013833e..12de1c3a5a3 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -21,6 +21,7 @@ from cirq import ops, study, value from cirq.sim import simulator, state_vector +from cirq.sim.act_on_state_vector_args import ActOnStateVectorArgs if TYPE_CHECKING: import cirq @@ -33,7 +34,10 @@ class SimulatesIntermediateStateVector( Generic[TStateVectorStepResult], simulator.SimulatesAmplitudes, simulator.SimulatesIntermediateState[ - TStateVectorStepResult, 'StateVectorTrialResult', 'StateVectorSimulatorState' + TStateVectorStepResult, + 'StateVectorTrialResult', + 'StateVectorSimulatorState', + ActOnStateVectorArgs, ], metaclass=abc.ABCMeta, ): diff --git a/cirq-google/cirq_google/calibration/engine_simulator.py b/cirq-google/cirq_google/calibration/engine_simulator.py index 15f5243bc98..bb1998e43da 100644 --- a/cirq-google/cirq_google/calibration/engine_simulator.py +++ b/cirq-google/cirq_google/calibration/engine_simulator.py @@ -30,6 +30,7 @@ SimulatesSamples, SimulatesIntermediateStateVector, StateVectorStepResult, + ActOnStateVectorArgs, ) from cirq.study import ParamResolver from cirq.value import RANDOM_STATE_OR_SEED_LIKE, parse_random_state @@ -395,14 +396,20 @@ def _run( converted = _convert_to_circuit_with_drift(self, circuit) return self._simulator._run(converted, param_resolver, repetitions) - def _base_iterator( + def _core_iterator( self, circuit: Circuit, - qubit_order: QubitOrderOrList, - initial_state: Any, + sim_state: Any, ) -> Iterator[StateVectorStepResult]: converted = _convert_to_circuit_with_drift(self, circuit) - return self._simulator._base_iterator(converted, qubit_order, initial_state) + return self._simulator._core_iterator(converted, sim_state) + + def _create_act_on_args( + self, + initial_state: Union[int, ActOnStateVectorArgs], + qubits: Sequence[Qid], + ) -> ActOnStateVectorArgs: + return self._simulator._create_act_on_args(initial_state, qubits) class _PhasedFSimConverter(PointOptimizer):