From 1d27a0062faf9a042f965a9be9cd12602be37321 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Tue, 21 Dec 2021 08:37:27 -0800 Subject: [PATCH] Fix str/repr explosion in separated states (#4518) * Fix str/repr explosion in separated states * format * Fix str/repr explosion in state vector * format * nondirac_str * TrialResultBase * cleanup * Add get_state_containing_qubit * Ensure eval(repr(obj)) holds * substates not a property * fix test * coverage * test * fix windows Co-authored-by: Cirq Bot --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/_compat.py | 3 + cirq-core/cirq/contrib/quimb/mps_simulator.py | 4 +- .../cirq/protocols/json_test_data/spec.py | 1 + cirq-core/cirq/sim/__init__.py | 3 +- .../cirq/sim/act_on_density_matrix_args.py | 15 +- .../cirq/sim/act_on_state_vector_args.py | 14 +- .../cirq/sim/clifford/clifford_simulator.py | 8 +- .../cirq/sim/density_matrix_simulator.py | 35 +++-- .../cirq/sim/density_matrix_simulator_test.py | 131 ++++++++++++++---- cirq-core/cirq/sim/simulator.py | 2 +- cirq-core/cirq/sim/simulator_base.py | 47 +++++++ cirq-core/cirq/sim/simulator_base_test.py | 2 +- cirq-core/cirq/sim/sparse_simulator.py | 11 +- cirq-core/cirq/sim/sparse_simulator_test.py | 54 ++++++++ cirq-core/cirq/sim/state_vector_simulator.py | 39 ++++-- .../cirq/sim/state_vector_simulator_test.py | 62 ++++++--- 17 files changed, 348 insertions(+), 84 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 110856c2993..f45004b9347 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -436,6 +436,7 @@ SimulatesIntermediateStateVector, SimulatesSamples, SimulationTrialResult, + SimulationTrialResultBase, Simulator, SimulatorBase, SparseSimulatorStep, diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 15acbf0d8ca..6dc089e0f6b 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -87,6 +87,9 @@ def _print(self, expr, **kwargs): f'\n)' ) + if isinstance(value, Dict): + return '{' + ','.join(f"{proper_repr(k)}: {proper_repr(v)}" for k, v in value.items()) + '}' + return repr(value) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 4247bcf6b97..1073672f4dc 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -25,7 +25,7 @@ import quimb.tensor as qtn from cirq import devices, study, ops, protocols, value -from cirq.sim import simulator, simulator_base +from cirq.sim import simulator_base from cirq.sim.act_on_args import ActOnArgs if TYPE_CHECKING: @@ -146,7 +146,7 @@ def _create_simulator_trial_result( ) -class MPSTrialResult(simulator.SimulationTrialResult): +class MPSTrialResult(simulator_base.SimulationTrialResultBase['MPSState', 'MPSState']): """A single trial reult""" def __init__( diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index eb951540e6d..e2ce8086e53 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -63,6 +63,7 @@ 'QuilFormatter', 'QuilOutput', 'SimulationTrialResult', + 'SimulationTrialResultBase', 'SparseSimulatorStep', 'StateVectorMixin', 'TextDiagramDrawer', diff --git a/cirq-core/cirq/sim/__init__.py b/cirq-core/cirq/sim/__init__.py index 02443dbeff1..e09f43eae59 100644 --- a/cirq-core/cirq/sim/__init__.py +++ b/cirq-core/cirq/sim/__init__.py @@ -64,8 +64,9 @@ ) from cirq.sim.simulator_base import ( - StepResultBase, + SimulationTrialResultBase, SimulatorBase, + StepResultBase, ) from cirq.sim.sparse_simulator import ( diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index d1b529d87a9..2e9734954cd 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -18,6 +18,7 @@ import numpy as np from cirq import protocols, sim +from cirq._compat import proper_repr from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose from cirq.linalg import transformations @@ -37,8 +38,8 @@ def __init__( target_tensor: np.ndarray, available_buffer: List[np.ndarray], qid_shape: Tuple[int, ...], - prng: np.random.RandomState, - log_of_measurement_results: Dict[str, Any], + prng: np.random.RandomState = None, + log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, ): """Inits ActOnDensityMatrixArgs. @@ -168,6 +169,16 @@ def sample( seed=seed, ) + def __repr__(self) -> str: + return ( + 'cirq.ActOnDensityMatrixArgs(' + f'target_tensor={proper_repr(self.target_tensor)},' + f' available_buffer={proper_repr(self.available_buffer)},' + f' qid_shape={self.qid_shape!r},' + f' qubits={self.qubits!r},' + f' log_of_measurement_results={proper_repr(self.log_of_measurement_results)})' + ) + def _strat_apply_channel_to_state( action: Any, args: ActOnDensityMatrixArgs, qubits: Sequence['cirq.Qid'] diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 11211f491da..a523047c71d 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -18,6 +18,7 @@ import numpy as np from cirq import linalg, protocols, sim +from cirq._compat import proper_repr from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose from cirq.linalg import transformations @@ -40,8 +41,8 @@ def __init__( self, target_tensor: np.ndarray, available_buffer: np.ndarray, - prng: np.random.RandomState, - log_of_measurement_results: Dict[str, Any], + prng: np.random.RandomState = None, + log_of_measurement_results: Dict[str, Any] = None, qubits: Sequence['cirq.Qid'] = None, ): """Inits ActOnStateVectorArgs. @@ -224,6 +225,15 @@ def sample( seed=seed, ) + def __repr__(self) -> str: + return ( + 'cirq.ActOnStateVectorArgs(' + f'target_tensor={proper_repr(self.target_tensor)},' + f' available_buffer={proper_repr(self.available_buffer)},' + f' qubits={self.qubits!r},' + f' log_of_measurement_results={proper_repr(self.log_of_measurement_results)})' + ) + def _strat_act_on_state_vector_from_apply_unitary( unitary_value: Any, diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index f017d07bcf4..194373625a4 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -36,7 +36,7 @@ import cirq from cirq import study, protocols, value from cirq.protocols import act_on -from cirq.sim import clifford, simulator, simulator_base +from cirq.sim import clifford, simulator_base class CliffordSimulator( @@ -114,7 +114,11 @@ def _create_simulator_trial_result( ) -class CliffordTrialResult(simulator.SimulationTrialResult): +class CliffordTrialResult( + simulator_base.SimulationTrialResultBase[ + 'clifford.CliffordState', 'clifford.ActOnStabilizerCHFormArgs' + ] +): def __init__( self, params: study.ParamResolver, diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index bad0553f0df..2cecac2f76d 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -17,6 +17,7 @@ import numpy as np from cirq import ops, protocols, qis, study, value +from cirq._compat import proper_repr from cirq.sim import ( simulator, act_on_density_matrix_args, @@ -283,7 +284,7 @@ class DensityMatrixStepResult( def __init__( self, sim_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]', - simulator: DensityMatrixSimulator, + simulator: DensityMatrixSimulator = None, dtype: 'DTypeLike' = np.complex64, ): """DensityMatrixStepResult. @@ -315,7 +316,8 @@ def set_density_matrix(self, density_matrix_repr: Union[int, np.ndarray]): mixed state it must be correctly sized and positive semidefinite with trace one. """ - self._sim_state = self._simulator._create_act_on_args(density_matrix_repr, self._qubits) + if self._simulator: + self._sim_state = self._simulator._create_act_on_args(density_matrix_repr, self._qubits) def density_matrix(self, copy=True): """Returns the density matrix at this step in the simulation. @@ -361,6 +363,12 @@ def density_matrix(self, copy=True): self._density_matrix = np.reshape(matrix, (size, size)) return self._density_matrix.copy() if copy else self._density_matrix + def __repr__(self) -> str: + return ( + f'cirq.DensityMatrixStepResult(sim_state={self._sim_state!r},' + f' dtype=np.{self._dtype.__name__})' + ) + @value.value_equality(unhashable=True) class DensityMatrixSimulatorState: @@ -381,7 +389,7 @@ def _qid_shape_(self) -> Tuple[int, ...]: return self._qid_shape def _value_equality_values_(self) -> Any: - return (self.density_matrix.tolist(), self.qubit_map) + return self.density_matrix.tolist(), self.qubit_map def __repr__(self) -> str: return ( @@ -392,7 +400,11 @@ def __repr__(self) -> str: @value.value_equality(unhashable=True) -class DensityMatrixTrialResult(simulator.SimulationTrialResult): +class DensityMatrixTrialResult( + simulator_base.SimulationTrialResultBase[ + 'DensityMatrixSimulatorState', act_on_density_matrix_args.ActOnDensityMatrixArgs + ] +): """A `SimulationTrialResult` for `DensityMatrixSimulator` runs. The density matrix that is stored in this result is returned in the @@ -452,17 +464,24 @@ def final_density_matrix(self): def _value_equality_values_(self) -> Any: measurements = {k: v.tolist() for k, v in sorted(self.measurements.items())} - return (self.params, measurements, self._final_simulator_state) + return self.params, measurements, self._final_simulator_state def __str__(self) -> str: samples = super().__str__() - return f'measurements: {samples}\nfinal density matrix:\n{self.final_density_matrix}' + ret = f'measurements: {samples}' + for substate in self._get_substates(): + tensor = substate.target_tensor + size = np.prod([tensor.shape[i] for i in range(tensor.ndim // 2)], dtype=np.int64) + dm = tensor.reshape((size, size)) + label = f'qubits: {substate.qubits}' if substate.qubits else 'phase:' + ret += f'\n\n{label}\nfinal density matrix:\n{dm}' + return ret def __repr__(self) -> str: return ( 'cirq.DensityMatrixTrialResult(' - f'params={self.params!r}, measurements={self.measurements!r}, ' - f'final_simulator_state={self._final_simulator_state!r})' + f'params={self.params!r}, measurements={proper_repr(self.measurements)}, ' + f'final_step_result={self._final_step_result!r})' ) def _repr_pretty_(self, p: Any, cycle: bool): diff --git a/cirq-core/cirq/sim/density_matrix_simulator_test.py b/cirq-core/cirq/sim/density_matrix_simulator_test.py index 213932eb408..3e03cc19a05 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulator_test.py @@ -1086,26 +1086,35 @@ def test_density_matrix_trial_result_qid_shape(): def test_density_matrix_trial_result_repr(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.StepResult) - final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} + args = cirq.ActOnDensityMatrixArgs( + target_tensor=np.ones((2, 2)) * 0.5, + available_buffer=[], + qid_shape=(2,), + prng=np.random.RandomState(0), + log_of_measurement_results={}, + qubits=[q0], ) - assert ( - repr( - cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({'s': 1}), - measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, - ) - ) - == "cirq.DensityMatrixTrialResult(" + final_step_result = cirq.DensityMatrixStepResult(args, cirq.DensityMatrixSimulator()) + trial_result = cirq.DensityMatrixTrialResult( + params=cirq.ParamResolver({'s': 1}), + measurements={'m': np.array([[1]], dtype=np.int32)}, + final_step_result=final_step_result, + ) + expected_repr = ( + "cirq.DensityMatrixTrialResult(" "params=cirq.ParamResolver({'s': 1}), " - "measurements={'m': array([[1]])}, " - "final_simulator_state=cirq.DensityMatrixSimulatorState(" - "density_matrix=np.array([[0.5, 0.5], [0.5, 0.5]]), " - "qubit_map={cirq.LineQubit(0): 0}))" - "" + "measurements={'m': np.array([[1]], dtype=np.int32)}, " + "final_step_result=cirq.DensityMatrixStepResult(" + "sim_state=cirq.ActOnDensityMatrixArgs(" + "target_tensor=np.array([[0.5, 0.5], [0.5, 0.5]], dtype=np.float64), " + "available_buffer=[], " + "qid_shape=(2,), " + "qubits=(cirq.LineQubit(0),), " + "log_of_measurement_results={}), " + "dtype=np.complex64))" ) + assert repr(trial_result) == expected_repr + assert eval(expected_repr) == trial_result class XAsOp(cirq.Operation): @@ -1192,10 +1201,15 @@ def test_works_on_pauli_string(): def test_density_matrix_trial_result_str(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.StepResult) - final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} + args = cirq.ActOnDensityMatrixArgs( + target_tensor=np.ones((2, 2)) * 0.5, + available_buffer=[], + qid_shape=(2,), + prng=np.random.RandomState(0), + log_of_measurement_results={}, + qubits=[q0], ) + final_step_result = cirq.DensityMatrixStepResult(args, cirq.DensityMatrixSimulator()) result = cirq.DensityMatrixTrialResult( params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result ) @@ -1204,14 +1218,24 @@ def test_density_matrix_trial_result_str(): # Eliminate whitespace to harden tests against this variation result_no_whitespace = str(result).replace('\n', '').replace(' ', '') assert result_no_whitespace == ( - 'measurements:(nomeasurements)finaldensitymatrix:[[0.50.5][0.50.5]]' + 'measurements:(nomeasurements)' + 'qubits:(cirq.LineQubit(0),)' + 'finaldensitymatrix:[[0.50.5][0.50.5]]' ) def test_density_matrix_trial_result_repr_pretty(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.StepResult) - final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState( + args = cirq.ActOnDensityMatrixArgs( + target_tensor=np.ones((2, 2)) * 0.5, + available_buffer=[], + qid_shape=(2,), + prng=np.random.RandomState(0), + log_of_measurement_results={}, + qubits=[q0], + ) + final_step_result = cirq.DensityMatrixStepResult(args, cirq.DensityMatrixSimulator()) + final_step_result._simulator_state = cirq.DensityMatrixSimulatorState( density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} ) result = cirq.DensityMatrixTrialResult( @@ -1224,7 +1248,9 @@ def test_density_matrix_trial_result_repr_pretty(): # Eliminate whitespace to harden tests against this variation result_no_whitespace = fake_printer.text_pretty.replace('\n', '').replace(' ', '') assert result_no_whitespace == ( - 'measurements:(nomeasurements)finaldensitymatrix:[[0.50.5][0.50.5]]' + 'measurements:(nomeasurements)' + 'qubits:(cirq.LineQubit(0),)' + 'finaldensitymatrix:[[0.50.5][0.50.5]]' ) cirq.testing.assert_repr_pretty(result, "cirq.DensityMatrixTrialResult(...)", cycle=True) @@ -1573,7 +1599,7 @@ def test_density_matrices_same_with_or_without_split_untangled_states(): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.CX.on(q0, q1), cirq.reset(q1)) result1 = sim.simulate(circuit).final_density_matrix - sim = cirq.DensityMatrixSimulator(split_untangled_states=True) + sim = cirq.DensityMatrixSimulator() result2 = sim.simulate(circuit).final_density_matrix assert np.allclose(result1, result2) @@ -1590,17 +1616,68 @@ def test_large_untangled_okay(): _ = cirq.DensityMatrixSimulator(split_untangled_states=False).simulate(circuit) # Validate a simulation run - result = cirq.DensityMatrixSimulator(split_untangled_states=True).simulate(circuit) + result = cirq.DensityMatrixSimulator().simulate(circuit) assert set(result._final_step_result._qubits) == set(cirq.LineQubit.range(59)) # _ = result.final_density_matrix hangs (as expected) # Validate a trial run and sampling - result = cirq.DensityMatrixSimulator(split_untangled_states=True).run(circuit, repetitions=1000) + result = cirq.DensityMatrixSimulator().run(circuit, repetitions=1000) assert len(result.measurements) == 59 assert len(result.measurements['0']) == 1000 assert (result.measurements['0'] == np.full(1000, 1)).all() +def test_separated_states_str_does_not_merge(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0), + cirq.measure(q1), + cirq.X(q0), + ) + + result = cirq.DensityMatrixSimulator().simulate(circuit) + assert ( + str(result) + == """measurements: 0=0 1=0 + +qubits: (cirq.LineQubit(0),) +final density matrix: +[[0.+0.j 0.+0.j] + [0.+0.j 1.+0.j]] + +qubits: (cirq.LineQubit(1),) +final density matrix: +[[1.+0.j 0.+0.j] + [0.+0.j 0.+0.j]] + +phase: +final density matrix: +[[1.+0.j]]""" + ) + + +def test_unseparated_states_str(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0), + cirq.measure(q1), + cirq.X(q0), + ) + + result = cirq.DensityMatrixSimulator(split_untangled_states=False).simulate(circuit) + assert ( + str(result) + == """measurements: 0=0 1=0 + +qubits: (cirq.LineQubit(0), cirq.LineQubit(1)) +final density matrix: +[[0.+0.j 0.+0.j 0.+0.j 0.+0.j] + [0.+0.j 0.+0.j 0.+0.j 0.+0.j] + [0.+0.j 0.+0.j 1.+0.j 0.+0.j] + [0.+0.j 0.+0.j 0.+0.j 0.+0.j]]""" + ) + + def test_sweep_unparameterized_prefix_not_repeated_even_non_unitaries(): q = cirq.LineQubit(0) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 4750f621ce4..a9ac1d07759 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -889,7 +889,7 @@ def _repr_pretty_(self, p: Any, cycle: bool) -> None: def _value_equality_values_(self) -> Any: measurements = {k: v.tolist() for k, v in sorted(self.measurements.items())} - return (self.params, measurements, self._final_simulator_state) + return self.params, measurements, self._final_simulator_state @property def qubit_map(self) -> Dict[ops.Qid, int]: diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 7d017b20a79..201eeca70b9 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -43,6 +43,7 @@ SimulatesIntermediateState, SimulatesSamples, StepResult, + SimulationTrialResult, check_all_resolved, split_into_matching_protocol_then_general, ) @@ -404,3 +405,49 @@ def sample( seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: return self._sim_state.sample(qubits, repetitions, seed) + + +class SimulationTrialResultBase( + Generic[TSimulatorState, TActOnArgs], SimulationTrialResult, abc.ABC +): + """A base class for trial results.""" + + def __init__( + self, + params: study.ParamResolver, + measurements: Dict[str, np.ndarray], + final_step_result: StepResultBase[TSimulatorState, TActOnArgs], + ) -> None: + """Initializes the `SimulationTrialResultBase` class. + + Args: + params: A ParamResolver of settings used for this result. + measurements: A dictionary from measurement gate key to measurement + results. Measurement results are a numpy ndarray of actual + boolean measurement results (ordered by the qubits acted on by + the measurement gate.) + final_step_result: The step result coming from the simulation, that + can be used to get the final simulator state. + """ + super().__init__(params, measurements, final_step_result=final_step_result) + self._final_step_result_typed = final_step_result + + def get_state_containing_qubit(self, qubit: 'cirq.Qid') -> TActOnArgs: + """Returns the independent state space containing the qubit. + + Args: + qubit: The qubit whose state space is required. + + Returns: + The state space containing the qubit.""" + return self._final_step_result_typed._sim_state[qubit] + + def _get_substates(self) -> Sequence[TActOnArgs]: + state = self._final_step_result_typed._sim_state + if isinstance(state, ActOnArgsContainer): + substates = dict() # type: Dict[TActOnArgs, int] + for q in state.qubits: + substates[self.get_state_containing_qubit(q)] = 0 + substates[state[None]] = 0 + return tuple(substates.keys()) + return [state.create_merged_state()] diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index d98e31d1cd0..6509b9a1caf 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -92,7 +92,7 @@ def _simulator_state(self) -> CountingActOnArgs: return self._merged_sim_state -class CountingTrialResult(cirq.SimulationTrialResult): +class CountingTrialResult(cirq.SimulationTrialResultBase[CountingActOnArgs, CountingActOnArgs]): pass diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index 7a72f6e2299..7c4fa5edea6 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -252,7 +252,7 @@ class SparseSimulatorStep( def __init__( self, sim_state: 'cirq.OperationTarget[cirq.ActOnStateVectorArgs]', - simulator: Simulator, + simulator: Simulator = None, dtype: 'DTypeLike' = np.complex64, ): """Results of a step of the simulator. @@ -329,4 +329,11 @@ def set_state_vector(self, state: 'cirq.STATE_VECTOR_LIKE'): corresponding to a computational basis state. If a numpy array this is the full state vector. """ - self._sim_state = self._simulator._create_act_on_args(state, self._qubits) + if self._simulator: + self._sim_state = self._simulator._create_act_on_args(state, self._qubits) + + def __repr__(self) -> str: + return ( + f'cirq.SparseSimulatorStep(sim_state={self._sim_state!r},' + f' dtype=np.{self._dtype.__name__})' + ) diff --git a/cirq-core/cirq/sim/sparse_simulator_test.py b/cirq-core/cirq/sim/sparse_simulator_test.py index 7414fdfbadc..9d2d1369a05 100644 --- a/cirq-core/cirq/sim/sparse_simulator_test.py +++ b/cirq-core/cirq/sim/sparse_simulator_test.py @@ -1330,3 +1330,57 @@ def test_noise_model(): result = simulator.run(circuit, repetitions=100) assert 20 <= sum(result.measurements['0'])[0] < 80 + + +def test_separated_states_str_does_not_merge(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0), + cirq.measure(q1), + cirq.H(q0), + cirq.global_phase_operation(0 + 1j), + ) + + result = cirq.Simulator().simulate(circuit) + assert ( + str(result) + == """measurements: 0=0 1=0 + +qubits: (cirq.LineQubit(0),) +output vector: 0.707|0⟩ + 0.707|1⟩ + +qubits: (cirq.LineQubit(1),) +output vector: |0⟩ + +phase: +output vector: 1j|⟩""" + ) + + +def test_separable_non_dirac_str(): + circuit = cirq.Circuit() + for i in range(4): + circuit.append(cirq.H(cirq.LineQubit(i))) + circuit.append(cirq.CX(cirq.LineQubit(0), cirq.LineQubit(i + 1))) + + result = cirq.Simulator().simulate(circuit) + assert '+0.j' in str(result) + + +def test_unseparated_states_str(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0), + cirq.measure(q1), + cirq.H(q0), + cirq.global_phase_operation(0 + 1j), + ) + + result = cirq.Simulator(split_untangled_states=False).simulate(circuit) + assert ( + str(result) + == """measurements: 0=0 1=0 + +qubits: (cirq.LineQubit(0), cirq.LineQubit(1)) +output vector: 0.707j|00⟩ + 0.707j|10⟩""" + ) diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index 5ea4649a421..d0fcb64c8f8 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -29,7 +29,8 @@ import numpy as np -from cirq import ops, study, value +from cirq import ops, study, value, qis +from cirq._compat import proper_repr from cirq.sim import simulator, state_vector, simulator_base from cirq.sim.act_on_state_vector_args import ActOnStateVectorArgs @@ -140,11 +141,16 @@ def __repr__(self) -> str: ) def _value_equality_values_(self) -> Any: - return (self.state_vector.tolist(), self.qubit_map) + return self.state_vector.tolist(), self.qubit_map @value.value_equality(unhashable=True) -class StateVectorTrialResult(state_vector.StateVectorMixin, simulator.SimulationTrialResult): +class StateVectorTrialResult( + state_vector.StateVectorMixin, + simulator_base.SimulationTrialResultBase[ + StateVectorSimulatorState, 'cirq.ActOnStateVectorArgs' + ], +): """A `SimulationTrialResult` that includes the `StateVectorMixin` methods. Attributes: @@ -201,16 +207,23 @@ def state_vector(self): def _value_equality_values_(self): measurements = {k: v.tolist() for k, v in sorted(self.measurements.items())} - return (self.params, measurements, self._final_simulator_state) + return self.params, measurements, self._final_simulator_state def __str__(self) -> str: samples = super().__str__() - final = self.state_vector() - if len([1 for e in final if abs(e) > 0.001]) < 16: - state_vector = self.dirac_notation(3) - else: - state_vector = str(final) - return f'measurements: {samples}\noutput vector: {state_vector}' + ret = f'measurements: {samples}' + for substate in self._get_substates(): + final = substate.target_tensor + shape = final.shape + size = np.prod(shape, dtype=np.int64) + final = final.reshape(size) + if len([1 for e in final if abs(e) > 0.001]) < 16: + state_vector = qis.dirac_notation(final, 3) + else: + state_vector = str(final) + label = f'qubits: {substate.qubits}' if substate.qubits else 'phase:' + ret += f'\n\n{label}\noutput vector: {state_vector}' + return ret def _repr_pretty_(self, p: Any, cycle: bool): """iPython (Jupyter) pretty print.""" @@ -222,7 +235,7 @@ def _repr_pretty_(self, p: Any, cycle: bool): def __repr__(self) -> str: return ( - f'cirq.StateVectorTrialResult(params={self.params!r}, ' - f'measurements={self.measurements!r}, ' - f'final_simulator_state={self._final_simulator_state!r})' + 'cirq.StateVectorTrialResult(' + f'params={self.params!r}, measurements={proper_repr(self.measurements)}, ' + f'final_step_result={self._final_step_result!r})' ) diff --git a/cirq-core/cirq/sim/state_vector_simulator_test.py b/cirq-core/cirq/sim/state_vector_simulator_test.py index d533919c2fb..128d54ef7ca 100644 --- a/cirq-core/cirq/sim/state_vector_simulator_test.py +++ b/cirq-core/cirq/sim/state_vector_simulator_test.py @@ -21,24 +21,34 @@ def test_state_vector_trial_result_repr(): - final_step_result = mock.Mock(cirq.StateVectorStepResult) - final_step_result._qubit_mapping = {} - final_step_result._simulator_state.return_value = cirq.StateVectorSimulatorState( - qubit_map={cirq.NamedQubit('a'): 0}, state_vector=np.array([0, 1]) - ) + q0 = cirq.NamedQubit('a') + args = cirq.ActOnStateVectorArgs( + target_tensor=np.array([0, 1], dtype=np.int32), + available_buffer=np.array([0, 1], dtype=np.int32), + prng=np.random.RandomState(0), + log_of_measurement_results={}, + qubits=[q0], + ) + final_step_result = cirq.SparseSimulatorStep(args, cirq.Simulator()) trial_result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), - measurements={'m': np.array([[1]])}, + measurements={'m': np.array([[1]], dtype=np.int32)}, final_step_result=final_step_result, ) - assert repr(trial_result) == ( + expected_repr = ( "cirq.StateVectorTrialResult(" "params=cirq.ParamResolver({'s': 1}), " - "measurements={'m': array([[1]])}, " - "final_simulator_state=cirq.StateVectorSimulatorState(" - "state_vector=np.array([0, 1]), " - "qubit_map={cirq.NamedQubit('a'): 0}))" + "measurements={'m': np.array([[1]], dtype=np.int32)}, " + "final_step_result=cirq.SparseSimulatorStep(" + "sim_state=cirq.ActOnStateVectorArgs(" + "target_tensor=np.array([0, 1], dtype=np.int32), " + "available_buffer=np.array([0, 1], dtype=np.int32), " + "qubits=(cirq.NamedQubit('a'),), " + "log_of_measurement_results={}), " + "dtype=np.complex64))" ) + assert repr(trial_result) == expected_repr + assert eval(expected_repr) == trial_result def test_state_vector_simulator_state_repr(): @@ -162,25 +172,31 @@ def test_state_vector_trial_state_vector_is_copy(): def test_str_big(): qs = cirq.LineQubit.range(20) - final_step_result = mock.Mock(cirq.StateVectorStepResult) - final_step_result._qubit_mapping = {} - final_step_result._simulator_state.return_value = cirq.StateVectorSimulatorState( - np.array([1] * 2 ** 10), {q: q.x for q in qs} - ) + args = cirq.ActOnStateVectorArgs( + target_tensor=np.array([1] * 2 ** 10), + available_buffer=np.array([1] * 2 ** 10), + prng=np.random.RandomState(0), + log_of_measurement_results={}, + qubits=qs, + ) + final_step_result = cirq.SparseSimulatorStep(args, cirq.Simulator()) result = cirq.StateVectorTrialResult( cirq.ParamResolver(), {}, final_step_result, ) - assert str(result).startswith('measurements: (no measurements)\noutput vector: [1 1 1 ..') + assert 'output vector: [1 1 1 ..' in str(result) def test_pretty_print(): - final_step_result = mock.Mock(cirq.StateVectorStepResult) - final_step_result._qubit_mapping = {} - final_step_result._simulator_state.return_value = cirq.StateVectorSimulatorState( - np.array([1]), {} - ) + args = cirq.ActOnStateVectorArgs( + target_tensor=np.array([1]), + available_buffer=np.array([1]), + prng=np.random.RandomState(0), + log_of_measurement_results={}, + qubits=[], + ) + final_step_result = cirq.SparseSimulatorStep(args, cirq.Simulator()) result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_step_result) # Test Jupyter console output from @@ -193,7 +209,7 @@ def text(self, to_print): p = FakePrinter() result._repr_pretty_(p, False) - assert p.text_pretty == 'measurements: (no measurements)\noutput vector: |⟩' + assert p.text_pretty == 'measurements: (no measurements)\n\nphase:\noutput vector: |⟩' # Test cycle handling p = FakePrinter()