From 32a21a38e0965c21199df3b60acde17cfba63b16 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Mon, 28 Mar 2022 08:07:25 -0700 Subject: [PATCH] Base class for quantum states (#5065) Creates a base class for all the quantum state classes created in #4979, and uses the inheritance to push the implementation of `ActOnArgs.kron`, `factor`, etc into the base class. Closes #4827 Resolves https://github.com/quantumlib/Cirq/pull/3841#discussion_r580705505 that's been bugging me for a year. --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/contrib/quimb/mps_simulator.py | 40 +++----- cirq-core/cirq/qis/__init__.py | 2 +- cirq-core/cirq/qis/clifford_tableau.py | 93 ++++++++++++++++++- cirq-core/cirq/sim/act_on_args.py | 49 +++++++--- .../cirq/sim/act_on_args_container_test.py | 15 --- .../cirq/sim/act_on_density_matrix_args.py | 69 ++++---------- .../cirq/sim/act_on_state_vector_args.py | 62 +++---------- .../clifford/act_on_clifford_tableau_args.py | 16 ---- .../sim/clifford/act_on_stabilizer_args.py | 3 +- .../act_on_stabilizer_ch_form_args.py | 34 +------ .../sim/clifford/stabilizer_state_ch_form.py | 9 +- cirq-core/cirq/sim/simulator_base_test.py | 86 ++++++++++------- 13 files changed, 235 insertions(+), 244 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 8a8d695c887..92997de79e5 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -433,6 +433,7 @@ operation_to_superoperator, QUANTUM_STATE_LIKE, QuantumState, + QuantumStateRepresentation, quantum_state, STATE_VECTOR_LIKE, StabilizerState, diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index cfa48ab651a..890ad503d81 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -24,7 +24,7 @@ import numpy as np import quimb.tensor as qtn -from cirq import devices, protocols, value +from cirq import devices, protocols, qis, value from cirq._compat import deprecated from cirq.sim import simulator_base from cirq.sim.act_on_args import ActOnArgs @@ -220,7 +220,7 @@ def _simulator_state(self): @value.value_equality -class _MPSHandler: +class _MPSHandler(qis.QuantumStateRepresentation): """Quantum state of the MPS simulation.""" def __init__( @@ -604,21 +604,24 @@ def __init__( Raises: ValueError: If the grouping does not cover the qubits. """ + qubit_map = {q: i for i, q in enumerate(qubits)} + final_grouping = qubit_map if grouping is None else grouping + if final_grouping.keys() != qubit_map.keys(): + raise ValueError('Grouping must cover exactly the qubits.') + state = _MPSHandler.create( + initial_state=initial_state, + qid_shape=tuple(q.dimension for q in qubits), + simulation_options=simulation_options, + grouping={qubit_map[k]: v for k, v in final_grouping.items()}, + ) super().__init__( + state=state, prng=prng, qubits=qubits, log_of_measurement_results=log_of_measurement_results, classical_data=classical_data, ) - final_grouping = self.qubit_map if grouping is None else grouping - if final_grouping.keys() != self.qubit_map.keys(): - raise ValueError('Grouping must cover exactly the qubits.') - self._state = _MPSHandler.create( - initial_state=initial_state, - qid_shape=tuple(q.dimension for q in qubits), - simulation_options=simulation_options, - grouping={self.qubit_map[k]: v for k, v in final_grouping.items()}, - ) + self._state: _MPSHandler = state def i_str(self, i: int) -> str: # Returns the index name for the i'th qid. @@ -636,9 +639,6 @@ def __str__(self) -> str: def _value_equality_values_(self) -> Any: return self.qubits, self._state - def _on_copy(self, target: 'MPSState', deep_copy_buffers: bool = True): - target._state = self._state.copy(deep_copy_buffers) - def state_vector(self) -> np.ndarray: """Returns the full state vector. @@ -709,15 +709,3 @@ def perform_measurement( tolerance specified in simulation options. """ return self._state._measure(self.get_axes(qubits), prng, collapse_state_vector) - - def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: - """Measures the axes specified by the simulator.""" - return self._state.measure(self.get_axes(qubits), self.prng) - - def sample( - self, - qubits: Sequence['cirq.Qid'], - repetitions: int = 1, - seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, - ) -> np.ndarray: - return self._state.sample(self.get_axes(qubits), repetitions, seed) diff --git a/cirq-core/cirq/qis/__init__.py b/cirq-core/cirq/qis/__init__.py index 9134f0b8c33..3faccde034e 100644 --- a/cirq-core/cirq/qis/__init__.py +++ b/cirq-core/cirq/qis/__init__.py @@ -25,7 +25,7 @@ superoperator_to_kraus, ) -from cirq.qis.clifford_tableau import CliffordTableau, StabilizerState +from cirq.qis.clifford_tableau import CliffordTableau, QuantumStateRepresentation, StabilizerState from cirq.qis.measures import ( entanglement_fidelity, diff --git a/cirq-core/cirq/qis/clifford_tableau.py b/cirq-core/cirq/qis/clifford_tableau.py index 017082961f0..60db36bb53c 100644 --- a/cirq-core/cirq/qis/clifford_tableau.py +++ b/cirq-core/cirq/qis/clifford_tableau.py @@ -13,17 +13,97 @@ # limitations under the License. import abc -from typing import Any, Dict, List, TYPE_CHECKING +from typing import Any, Dict, List, Sequence, Tuple, TYPE_CHECKING, TypeVar import numpy as np -from cirq import protocols +from cirq import protocols, value from cirq.value import big_endian_int_to_digits, linear_dict if TYPE_CHECKING: import cirq +TSelf = TypeVar('TSelf', bound='QuantumStateRepresentation') -class StabilizerState(metaclass=abc.ABCMeta): + +class QuantumStateRepresentation(metaclass=abc.ABCMeta): + @abc.abstractmethod + def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: + """Creates a copy of the object. + Args: + deep_copy_buffers: If True, buffers will also be deep-copied. + Otherwise the copy will share a reference to the original object's + buffers. + Returns: + A copied instance. + """ + + @abc.abstractmethod + def measure( + self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None + ) -> List[int]: + """Measures the state. + + Args: + axes: The axes to measure. + seed: The random number seed to use. + Returns: + The measurements in order. + """ + + def sample( + self, + axes: Sequence[int], + repetitions: int = 1, + seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + ) -> np.ndarray: + """Samples the state. Subclasses can override with more performant method. + + Args: + axes: The axes to sample. + repetitions: The number of samples to make. + seed: The random number seed to use. + Returns: + The samples in order. + """ + prng = value.parse_random_state(seed) + measurements = [] + for _ in range(repetitions): + state = self.copy() + measurements.append(state.measure(axes, prng)) + return np.array(measurements, dtype=bool) + + def kron(self: TSelf, other: TSelf) -> TSelf: + """Joins two state spaces together.""" + raise NotImplementedError() + + def factor( + self: TSelf, axes: Sequence[int], *, validate=True, atol=1e-07 + ) -> Tuple[TSelf, TSelf]: + """Splits two state spaces after a measurement or reset.""" + raise NotImplementedError() + + def reindex(self: TSelf, axes: Sequence[int]) -> TSelf: + """Physically reindexes the state by the new basis. + Args: + axes: The desired axis order. + Returns: + The state with qubit order transposed and underlying representation + updated. + """ + raise NotImplementedError() + + @property + def supports_factor(self) -> bool: + """Subclasses that allow factorization should override this.""" + return False + + @property + def can_represent_mixed_states(self) -> bool: + """Subclasses that can represent mixed states should override this.""" + return False + + +class StabilizerState(QuantumStateRepresentation, metaclass=abc.ABCMeta): """Interface for quantum stabilizer state representations. This interface is used for CliffordTableau and StabilizerChForm quantum @@ -222,7 +302,7 @@ def __eq__(self, other): def __copy__(self) -> 'CliffordTableau': return self.copy() - def copy(self) -> 'CliffordTableau': + def copy(self, deep_copy_buffers: bool = True) -> 'CliffordTableau': state = CliffordTableau(self.n) state.rs = self.rs.copy() state.xs = self.xs.copy() @@ -578,3 +658,8 @@ def apply_cx( def apply_global_phase(self, coefficient: linear_dict.Scalar): pass + + def measure( + self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None + ) -> List[int]: + return [self._measure(axis, seed) for axis in axes] diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 6a1e0bbbb1b..e305a9741fd 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Objects and methods for acting efficiently on a state tensor.""" -import abc import copy import inspect +import warnings from typing import ( Any, cast, @@ -28,7 +28,6 @@ TYPE_CHECKING, Tuple, ) -import warnings import numpy as np @@ -59,6 +58,7 @@ def __init__( log_of_measurement_results: Optional[Dict[str, List[int]]] = None, ignore_measurement_results: bool = False, classical_data: Optional['cirq.ClassicalDataStore'] = None, + state: Optional['cirq.QuantumStateRepresentation'] = None, ): """Inits ActOnArgs. @@ -76,6 +76,7 @@ def __init__( simulators that can represent mixed states. classical_data: The shared classical data container for this simulation. + state: The underlying quantum state of the simulation. """ if prng is None: prng = cast(np.random.RandomState, np.random) @@ -90,6 +91,7 @@ def __init__( } ) self._ignore_measurement_results = ignore_measurement_results + self._state = state @property def prng(self) -> np.random.RandomState: @@ -148,10 +150,21 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[ def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]: return [self.qubit_map[q] for q in qubits] - @abc.abstractmethod def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: - """Child classes that perform measurements should implement this with - the implementation.""" + """Delegates the call to measure the density matrix.""" + if self._state is not None: + return self._state.measure(self.get_axes(qubits), self.prng) + raise NotImplementedError() + + def sample( + self, + qubits: Sequence['cirq.Qid'], + repetitions: int = 1, + seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + ) -> np.ndarray: + if self._state is not None: + return self._state.sample(self.get_axes(qubits), repetitions, seed) + raise NotImplementedError() def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: """Creates a copy of the object. @@ -165,6 +178,10 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: A copied instance. """ args = copy.copy(self) + args._classical_data = self._classical_data.copy() + if self._state is not None: + args._state = self._state.copy(deep_copy_buffers=deep_copy_buffers) + return args if 'deep_copy_buffers' in inspect.signature(self._on_copy).parameters: self._on_copy(args, deep_copy_buffers) else: @@ -176,7 +193,6 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: DeprecationWarning, ) self._on_copy(args) - args._classical_data = self._classical_data.copy() return args def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True): @@ -190,7 +206,10 @@ def create_merged_state(self: TSelf) -> TSelf: def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf: """Joins two state spaces together.""" args = self if inplace else copy.copy(self) - self._on_kronecker_product(other, args) + if self._state is not None and other._state is not None: + args._state = self._state.kron(other._state) + else: + self._on_kronecker_product(other, args) args._set_qubits(self.qubits + other.qubits) return args @@ -225,7 +244,12 @@ def factor( """Splits two state spaces after a measurement or reset.""" extracted = copy.copy(self) remainder = self if inplace else copy.copy(self) - self._on_factor(qubits, extracted, remainder, validate, atol) + if self._state is not None: + e, r = self._state.factor(self.get_axes(qubits), validate=validate, atol=atol) + extracted._state = e + remainder._state = r + else: + self._on_factor(qubits, extracted, remainder, validate, atol) extracted._set_qubits(qubits) remainder._set_qubits([q for q in self.qubits if q not in qubits]) return extracted, remainder @@ -233,7 +257,7 @@ def factor( @property def allows_factoring(self): """Subclasses that allow factorization should override this.""" - return False + return self._state.supports_factor if self._state is not None else False def _on_factor( self: TSelf, @@ -265,7 +289,10 @@ def transpose_to_qubit_order( if len(self.qubits) != len(qubits) or set(qubits) != set(self.qubits): raise ValueError(f'Qubits do not match. Existing: {self.qubits}, provided: {qubits}') args = self if inplace else copy.copy(self) - self._on_transpose_to_qubit_order(qubits, args) + if self._state is not None: + args._state = self._state.reindex(self.get_axes(qubits)) + else: + self._on_transpose_to_qubit_order(qubits, args) args._set_qubits(qubits) return args @@ -356,7 +383,7 @@ def __iter__(self) -> Iterator[Optional['cirq.Qid']]: @property def can_represent_mixed_states(self) -> bool: - return False + return self._state.can_represent_mixed_states if self._state is not None else False def strat_act_on_from_apply_decompose( diff --git a/cirq-core/cirq/sim/act_on_args_container_test.py b/cirq-core/cirq/sim/act_on_args_container_test.py index 8afba6fc738..48ae94850d0 100644 --- a/cirq-core/cirq/sim/act_on_args_container_test.py +++ b/cirq-core/cirq/sim/act_on_args_container_test.py @@ -41,25 +41,10 @@ def _act_on_fallback_( ) -> bool: return True - def _on_copy(self, args): - pass - - def _on_kronecker_product(self, other, target): - pass - - def _on_transpose_to_qubit_order(self, qubits, target): - pass - - def _on_factor(self, qubits, extracted, remainder, validate=True, atol=1e-07): - pass - @property def allows_factoring(self): return True - def sample(self, qubits, repetitions=1, seed=None): - pass - q0, q1, q2 = qs3 = cirq.LineQubit.range(3) qs2 = cirq.LineQubit.range(2) diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index 015e373f8c5..51ba9f9069f 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -27,7 +27,7 @@ from numpy.typing import DTypeLike -class _BufferedDensityMatrix: +class _BufferedDensityMatrix(qis.QuantumStateRepresentation): """Contains the density matrix and buffers for efficient state evolution.""" def __init__(self, density_matrix: np.ndarray, buffer: Optional[List[np.ndarray]] = None): @@ -223,6 +223,14 @@ def sample( seed=seed, ) + @property + def supports_factor(self) -> bool: + return True + + @property + def can_represent_mixed_states(self) -> bool: + return True + class ActOnDensityMatrixArgs(ActOnArgs): """State and context for an operation acting on a density matrix. @@ -296,8 +304,15 @@ def __init__( ValueError: The dimension of `target_tensor` is not divisible by 2 and `qid_shape` is not provided. """ + state = _BufferedDensityMatrix.create( + initial_state=target_tensor if target_tensor is not None else initial_state, + qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, + dtype=dtype, + buffer=available_buffer, + ) if ignore_measurement_results: super().__init__( + state=state, prng=prng, qubits=qubits, log_of_measurement_results=log_of_measurement_results, @@ -306,17 +321,13 @@ def __init__( ) else: super().__init__( + state=state, prng=prng, qubits=qubits, log_of_measurement_results=log_of_measurement_results, classical_data=classical_data, ) - self._state = _BufferedDensityMatrix.create( - initial_state=target_tensor if target_tensor is not None else initial_state, - qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, - dtype=dtype, - buffer=available_buffer, - ) + self._state: _BufferedDensityMatrix = state def _act_on_fallback_( self, @@ -344,50 +355,6 @@ def _act_on_fallback_( "SupportsMixture or SupportsKraus or is a measurement: {!r}".format(action) ) - def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: - """Delegates the call to measure the density matrix.""" - return self._state.measure(self.get_axes(qubits), self.prng) - - def _on_copy(self, target: 'cirq.ActOnDensityMatrixArgs', deep_copy_buffers: bool = True): - target._state = self._state.copy(deep_copy_buffers) - - def _on_kronecker_product( - self, other: 'cirq.ActOnDensityMatrixArgs', target: 'cirq.ActOnDensityMatrixArgs' - ): - target._state = self._state.kron(other._state) - - def _on_factor( - self, - qubits: Sequence['cirq.Qid'], - extracted: 'cirq.ActOnDensityMatrixArgs', - remainder: 'cirq.ActOnDensityMatrixArgs', - validate=True, - atol=1e-07, - ): - axes = self.get_axes(qubits) - extracted._state, remainder._state = self._state.factor(axes, validate=validate, atol=atol) - - @property - def allows_factoring(self): - return True - - def _on_transpose_to_qubit_order( - self, qubits: Sequence['cirq.Qid'], target: 'cirq.ActOnDensityMatrixArgs' - ): - target._state = self._state.reindex(self.get_axes(qubits)) - - def sample( - self, - qubits: Sequence['cirq.Qid'], - repetitions: int = 1, - seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, - ) -> np.ndarray: - return self._state.sample(self.get_axes(qubits), repetitions, seed) - - @property - def can_represent_mixed_states(self) -> bool: - return True - def __repr__(self) -> str: return ( 'cirq.ActOnDensityMatrixArgs(' diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 7d9aeeb14c4..e29ea8d5503 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -27,7 +27,7 @@ from numpy.typing import DTypeLike -class _BufferedStateVector: +class _BufferedStateVector(qis.QuantumStateRepresentation): """Contains the state vector and buffer for efficient state evolution.""" def __init__(self, state_vector: np.ndarray, buffer: Optional[np.ndarray] = None): @@ -321,6 +321,10 @@ def _swap_target_tensor_for(self, new_target_tensor: np.ndarray): self._buffer = self._state_vector self._state_vector = new_target_tensor + @property + def supports_factor(self) -> bool: + return True + class ActOnStateVectorArgs(ActOnArgs): """State and context for an operation acting on a state vector. @@ -382,18 +386,20 @@ def __init__( classical_data: The shared classical data container for this simulation. """ + state = _BufferedStateVector.create( + initial_state=target_tensor if target_tensor is not None else initial_state, + qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, + dtype=dtype, + buffer=available_buffer, + ) super().__init__( + state=state, prng=prng, qubits=qubits, log_of_measurement_results=log_of_measurement_results, classical_data=classical_data, ) - self._state = _BufferedStateVector.create( - initial_state=target_tensor if target_tensor is not None else initial_state, - qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, - dtype=dtype, - buffer=available_buffer, - ) + self._state: _BufferedStateVector = state @_compat.deprecated( deadline='v0.16', @@ -480,7 +486,7 @@ def _act_on_fallback_( _strat_act_on_state_vector_from_channel, ] if allow_decompose: - strats.append(strat_act_on_from_apply_decompose) + strats.append(strat_act_on_from_apply_decompose) # type: ignore # Try each strategy, stopping if one works. for strat in strats: @@ -496,46 +502,6 @@ def _act_on_fallback_( "SupportsMixture or is a measurement: {!r}".format(action) ) - def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: - """Delegates the call to measure the state vector.""" - return self._state.measure(self.get_axes(qubits), self.prng) - - def _on_copy(self, target: 'cirq.ActOnStateVectorArgs', deep_copy_buffers: bool = True): - target._state = self._state.copy(deep_copy_buffers) - - def _on_kronecker_product( - self, other: 'cirq.ActOnStateVectorArgs', target: 'cirq.ActOnStateVectorArgs' - ): - target._state = self._state.kron(other._state) - - def _on_factor( - self, - qubits: Sequence['cirq.Qid'], - extracted: 'cirq.ActOnStateVectorArgs', - remainder: 'cirq.ActOnStateVectorArgs', - validate=True, - atol=1e-07, - ): - axes = self.get_axes(qubits) - extracted._state, remainder._state = self._state.factor(axes, validate=validate, atol=atol) - - @property - def allows_factoring(self): - return True - - def _on_transpose_to_qubit_order( - self, qubits: Sequence['cirq.Qid'], target: 'cirq.ActOnStateVectorArgs' - ): - target._state = self._state.reindex(self.get_axes(qubits)) - - def sample( - self, - qubits: Sequence['cirq.Qid'], - repetitions: int = 1, - seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, - ) -> np.ndarray: - return self._state.sample(self.get_axes(qubits), repetitions, seed) - def __repr__(self) -> str: return ( 'cirq.ActOnStateVectorArgs(' diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 0f76546a20f..4c20479fea4 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -69,19 +69,3 @@ def __init__( @property def tableau(self) -> 'cirq.CliffordTableau': return self.state - - def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: - """Returns the measurement from the tableau.""" - return [self.state._measure(self.qubit_map[q], self.prng) for q in qubits] - - def _on_copy(self, target: 'ActOnCliffordTableauArgs', deep_copy_buffers: bool = True): - target._state = self.state.copy() - - def sample( - self, - qubits: Sequence['cirq.Qid'], - repetitions: int = 1, - seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, - ) -> np.ndarray: - # Unnecessary for now but can be added later if there is a use case. - raise NotImplementedError() diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py index 5cc45ef6297..0f01f19f55b 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py @@ -58,12 +58,13 @@ def __init__( simulation. """ super().__init__( + state=state, prng=prng, qubits=qubits, log_of_measurement_results=log_of_measurement_results, classical_data=classical_data, ) - self._state = state + self._state: TStabilizerState = state @property def state(self) -> TStabilizerState: diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index d7a67d8e198..0cbb88c7ff3 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -16,7 +16,7 @@ import numpy as np -from cirq import _compat, value +from cirq import _compat from cirq.sim.clifford import stabilizer_state_ch_form from cirq.sim.clifford.act_on_stabilizer_args import ActOnStabilizerArgs @@ -88,35 +88,3 @@ def __init__( log_of_measurement_results=log_of_measurement_results, classical_data=classical_data, ) - - def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: - """Returns the measurement from the stabilizer state form.""" - return [self.state._measure(self.qubit_map[q], self.prng) for q in qubits] - - def _on_copy(self, target: 'ActOnStabilizerCHFormArgs', deep_copy_buffers: bool = True): - target._state = self.state.copy() - - def _on_kronecker_product( - self, other: 'cirq.ActOnStabilizerCHFormArgs', target: 'cirq.ActOnStabilizerCHFormArgs' - ): - target._state = self.state.kron(other.state) - - def _on_transpose_to_qubit_order( - self, qubits: Sequence['cirq.Qid'], target: 'cirq.ActOnStabilizerCHFormArgs' - ): - axes = [self.qubit_map[q] for q in qubits] - target._state = self.state.reindex(axes) - - def sample( - self, - qubits: Sequence['cirq.Qid'], - repetitions: int = 1, - seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, - ) -> np.ndarray: - prng = value.parse_random_state(seed) - axes = self.get_axes(qubits) - measurements = [] - for _ in range(repetitions): - state = self.state.copy() - measurements.append([state._measure(i, prng) for i in axes]) - return np.array(measurements, dtype=bool) diff --git a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py index 8a53ef5b3e2..ad6c3c63945 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Sequence, Union +from typing import Any, Dict, List, Sequence, Union import numpy as np import cirq @@ -80,7 +80,7 @@ def _from_json_dict_(cls, n, G, F, M, gamma, v, s, omega, **kwargs): def _value_equality_values_(self) -> Any: return (self.n, self.G, self.F, self.M, self.gamma, self.v, self.s, self.omega) - def copy(self) -> 'cirq.StabilizerStateChForm': + def copy(self, deep_copy_buffers: bool = True) -> 'cirq.StabilizerStateChForm': copy = StabilizerStateChForm(self.n) copy.G = self.G.copy() @@ -385,6 +385,11 @@ def apply_cx( def apply_global_phase(self, coefficient: value.Scalar): self.omega *= coefficient + def measure( + self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None + ) -> List[int]: + return [self._measure(axis, seed) for axis in axes] + def _phase(exponent, global_shift): return np.exp(1j * np.pi * global_shift * exponent) diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index a99527f3722..add1db86fd5 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Any, Dict, List, Sequence, Union +from typing import Any, Dict, List, Sequence, Tuple, Union import numpy as np import pytest @@ -21,30 +21,50 @@ import cirq -class CountingActOnArgs(cirq.ActOnArgs): - gate_count = 0 - measurement_count = 0 - - def __init__(self, state, qubits, classical_data): - super().__init__( - qubits=qubits, - classical_data=classical_data, - ) +class CountingState(cirq.qis.QuantumStateRepresentation): + def __init__(self, state, gate_count=0, measurement_count=0): self.state = state + self.gate_count = gate_count + self.measurement_count = measurement_count - def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: + def measure( + self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None + ) -> List[int]: self.measurement_count += 1 return [self.gate_count] - def copy(self, deep_copy_buffers: bool = True) -> 'CountingActOnArgs': - args = CountingActOnArgs( - qubits=self.qubits, - classical_data=self.classical_data.copy(), - state=self.state, + def kron(self: 'CountingState', other: 'CountingState') -> 'CountingState': + return CountingState( + self.state, + self.gate_count + other.gate_count, + self.measurement_count + other.measurement_count, + ) + + def factor( + self: 'CountingState', axes: Sequence[int], *, validate=True, atol=1e-07 + ) -> Tuple['CountingState', 'CountingState']: + return CountingState(self.state, self.gate_count, self.measurement_count), CountingState( + self.state + ) + + def reindex(self: 'CountingState', axes: Sequence[int]) -> 'CountingState': + return self.copy() + + def copy(self, deep_copy_buffers: bool = True) -> 'CountingState': + return CountingState( + state=self.state, gate_count=self.gate_count, measurement_count=self.measurement_count + ) + + +class CountingActOnArgs(cirq.ActOnArgs): + def __init__(self, state, qubits, classical_data): + state_obj = CountingState(state) + super().__init__( + state=state_obj, + qubits=qubits, + classical_data=classical_data, ) - args.gate_count = self.gate_count - args.measurement_count = self.measurement_count - return args + self._state: CountingState = state_obj def _act_on_fallback_( self, @@ -52,33 +72,27 @@ def _act_on_fallback_( qubits: Sequence['cirq.Qid'], allow_decompose: bool = True, ) -> bool: - self.gate_count += 1 + self._state.gate_count += 1 return True - def sample(self, qubits, repetitions=1, seed=None): - pass + @property + def state(self): + return self._state.state + @property + def gate_count(self): + return self._state.gate_count -class SplittableCountingActOnArgs(CountingActOnArgs): - def _on_kronecker_product( - self, other: 'SplittableCountingActOnArgs', target: 'SplittableCountingActOnArgs' - ): - target.gate_count = self.gate_count + other.gate_count - target.measurement_count = self.measurement_count + other.measurement_count + @property + def measurement_count(self): + return self._state.measurement_count - def _on_factor(self, qubits, extracted, remainder, validate=True, atol=1e-07): - remainder.gate_count = 0 - remainder.measurement_count = 0 +class SplittableCountingActOnArgs(CountingActOnArgs): @property def allows_factoring(self): return True - def _on_transpose_to_qubit_order( - self, qubits: Sequence['cirq.Qid'], target: 'SplittableCountingActOnArgs' - ): - pass - class CountingStepResult(cirq.StepResultBase[CountingActOnArgs, CountingActOnArgs]): def sample(