From f62f9e3ed36c6801e338c788c4afa03c1afdf4a7 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Wed, 13 Apr 2022 13:36:56 -0700 Subject: [PATCH] Make the quantum state generic (#5255) --- cirq-core/cirq/contrib/quimb/mps_simulator.py | 3 +- cirq-core/cirq/sim/act_on_args.py | 8 +++-- .../cirq/sim/act_on_density_matrix_args.py | 3 +- .../cirq/sim/act_on_state_vector_args.py | 3 +- .../sim/clifford/act_on_stabilizer_args.py | 5 +-- cirq-core/cirq/sim/simulator_base_test.py | 32 ++++--------------- 6 files changed, 18 insertions(+), 36 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 547640e5097..56276170210 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -562,7 +562,7 @@ def sample( @value.value_equality -class MPSState(ActOnArgs): +class MPSState(ActOnArgs[_MPSHandler]): """A state of the MPS simulation.""" @deprecated_parameter( @@ -626,7 +626,6 @@ def __init__( ) else: super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) - self._state: _MPSHandler = state def i_str(self, i: int) -> str: # Returns the index name for the i'th qid. diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 44ce3474a12..c442aa3ff2c 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -18,6 +18,7 @@ Any, cast, Dict, + Generic, Iterator, List, Mapping, @@ -36,12 +37,13 @@ from cirq.sim.operation_target import OperationTarget TSelf = TypeVar('TSelf', bound='ActOnArgs') +TState = TypeVar('TState', bound='cirq.QuantumStateRepresentation') if TYPE_CHECKING: import cirq -class ActOnArgs(OperationTarget[TSelf], metaclass=abc.ABCMeta): +class ActOnArgs(OperationTarget, Generic[TState], metaclass=abc.ABCMeta): """State and context for an operation acting on a state tensor.""" @deprecated_parameter( @@ -63,7 +65,7 @@ def __init__( qubits: Optional[Sequence['cirq.Qid']] = None, log_of_measurement_results: Optional[Dict[str, List[int]]] = None, classical_data: Optional['cirq.ClassicalDataStore'] = None, - state: Optional['cirq.QuantumStateRepresentation'] = None, + state: Optional[TState] = None, ): """Inits ActOnArgs. @@ -91,7 +93,7 @@ def __init__( for k, v in (log_of_measurement_results or {}).items() } ) - self._state = state + self._state = cast(TState, state) if state is None: _warn_or_error('This function will require a valid `state` input in cirq v0.16.') diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index 93d6c13c4a7..c596cbe5631 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -236,7 +236,7 @@ def can_represent_mixed_states(self) -> bool: return True -class ActOnDensityMatrixArgs(ActOnArgs): +class ActOnDensityMatrixArgs(ActOnArgs[_BufferedDensityMatrix]): """State and context for an operation acting on a density matrix. To act on this object, directly edit the `target_tensor` property, which is @@ -286,7 +286,6 @@ def __init__( buffer=available_buffer, ) super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) - self._state: _BufferedDensityMatrix = state def _act_on_fallback_( self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 1aba792a292..93c258de2ba 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -308,7 +308,7 @@ def supports_factor(self) -> bool: return True -class ActOnStateVectorArgs(ActOnArgs): +class ActOnStateVectorArgs(ActOnArgs[_BufferedStateVector]): """State and context for an operation acting on a state vector. There are two common ways to act on this object: @@ -357,7 +357,6 @@ def __init__( buffer=available_buffer, ) super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) - self._state: _BufferedStateVector = state @_compat.deprecated( deadline='v0.16', fix='None, this function was unintentionally made public.' diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py index 9a8425aeb51..d95b072eac1 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py @@ -32,7 +32,9 @@ TStabilizerState = TypeVar('TStabilizerState', bound='cirq.StabilizerState') -class ActOnStabilizerArgs(ActOnArgs, Generic[TStabilizerState], metaclass=abc.ABCMeta): +class ActOnStabilizerArgs( + ActOnArgs[TStabilizerState], Generic[TStabilizerState], metaclass=abc.ABCMeta +): """Abstract wrapper around a stabilizer state for the act_on protocol.""" @deprecated_parameter( @@ -81,7 +83,6 @@ def __init__( ) else: super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) - self._state: TStabilizerState = state @property def state(self) -> TStabilizerState: diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 7d42328ac65..5773d7deb62 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -56,21 +56,13 @@ def copy(self, deep_copy_buffers: bool = True) -> 'CountingState': ) -class CountingActOnArgs(cirq.ActOnArgs): +class CountingActOnArgs(cirq.ActOnArgs[CountingState]): def __init__(self, state, qubits, classical_data): state_obj = CountingState(state) - super().__init__( - state=state_obj, - qubits=qubits, - classical_data=classical_data, - ) - self._state: CountingState = state_obj + super().__init__(state=state_obj, qubits=qubits, classical_data=classical_data) def _act_on_fallback_( - self, - action: Any, - qubits: Sequence['cirq.Qid'], - allow_decompose: bool = True, + self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True ) -> bool: self._state.gate_count += 1 return True @@ -120,10 +112,7 @@ class CountingSimulator( ] ): def __init__(self, noise=None, split_untangled_states=False): - super().__init__( - noise=noise, - split_untangled_states=split_untangled_states, - ) + super().__init__(noise=noise, split_untangled_states=split_untangled_states) def _create_partial_act_on_args( self, @@ -142,18 +131,14 @@ def _create_simulator_trial_result( return CountingTrialResult(params, measurements, final_step_result=final_step_result) def _create_step_result( - self, - sim_state: cirq.OperationTarget[CountingActOnArgs], + self, sim_state: cirq.OperationTarget[CountingActOnArgs] ) -> CountingStepResult: return CountingStepResult(sim_state) class SplittableCountingSimulator(CountingSimulator): def __init__(self, noise=None, split_untangled_states=True): - super().__init__( - noise=noise, - split_untangled_states=split_untangled_states, - ) + super().__init__(noise=noise, split_untangled_states=split_untangled_states) def _create_partial_act_on_args( self, @@ -390,10 +375,7 @@ def _has_unitary_(self): return self.has_unitary simulator = CountingSimulator() - params = [ - cirq.ParamResolver({'a': 0}), - cirq.ParamResolver({'a': 1}), - ] + params = [cirq.ParamResolver({'a': 0}), cirq.ParamResolver({'a': 1})] op1 = TestOp(has_unitary=True) op2 = TestOp(has_unitary=True)