diff --git a/cirq/contrib/quimb/mps_simulator.py b/cirq/contrib/quimb/mps_simulator.py index 0e433fcc5cf..b75c82a9499 100644 --- a/cirq/contrib/quimb/mps_simulator.py +++ b/cirq/contrib/quimb/mps_simulator.py @@ -29,7 +29,12 @@ from cirq.sim import simulator -class MPSSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState): +class MPSSimulator( + simulator.SimulatesSamples, + simulator.SimulatesIntermediateStateBase[ + 'MPSSimulatorStepResult', 'MPSTrialResult', 'MPSState' + ], +): """An efficient simulator for MPS circuits.""" def __init__( @@ -239,7 +244,7 @@ def __str__(self) -> str: return f'measurements: {samples}\noutput state: {final}' -class MPSSimulatorStepResult(simulator.StepResult): +class MPSSimulatorStepResult(simulator.StepResult['MPSState']): """A `StepResult` that can perform measurements.""" def __init__(self, state, measurements): diff --git a/cirq/experiments/xeb_simulation.py b/cirq/experiments/xeb_simulation.py index 37a347a84ca..cc4a89ed5c7 100644 --- a/cirq/experiments/xeb_simulation.py +++ b/cirq/experiments/xeb_simulation.py @@ -120,7 +120,9 @@ def simulate_2q_xeb_circuits( # Need an actual object; not np.random or else multiprocessing will # fail to pickle the closure object: # https://github.com/quantumlib/Cirq/issues/3717 - simulator = sim.Simulator(seed=np.random.RandomState()) + simulator = cast( + 'cirq.SimulatesIntermediateState', sim.Simulator(seed=np.random.RandomState()) + ) _simulate_2q_xeb_circuit = _Simulate_2q_XEB_Circuit(simulator=simulator) tasks = tuple( diff --git a/cirq/google/calibration/engine_simulator.py b/cirq/google/calibration/engine_simulator.py index 4e812501438..b94b5c63326 100644 --- a/cirq/google/calibration/engine_simulator.py +++ b/cirq/google/calibration/engine_simulator.py @@ -31,7 +31,7 @@ SimulatesSamples, SimulatesIntermediateStateVector, SparseSimulatorStep, - StepResult, + StateVectorStepResult, ) from cirq.study import ParamResolver from cirq.value import RANDOM_STATE_OR_SEED_LIKE, parse_random_state @@ -402,7 +402,7 @@ def _base_iterator( circuit: Circuit, qubit_order: QubitOrderOrList, initial_state: Any, - ) -> Iterator[StepResult]: + ) -> Iterator[StateVectorStepResult]: converted = _convert_to_circuit_with_drift(self, circuit) return self._simulator._base_iterator(converted, qubit_order, initial_state) diff --git a/cirq/sim/clifford/clifford_simulator.py b/cirq/sim/clifford/clifford_simulator.py index a407f5bae19..a08759f9afc 100644 --- a/cirq/sim/clifford/clifford_simulator.py +++ b/cirq/sim/clifford/clifford_simulator.py @@ -42,7 +42,12 @@ from cirq.sim.simulator import check_all_resolved -class CliffordSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState): +class CliffordSimulator( + simulator.SimulatesSamples, + simulator.SimulatesIntermediateStateBase[ + 'CliffordSimulatorStepResult', 'CliffordTrialResult', 'CliffordState' + ], +): """An efficient simulator for Clifford circuits.""" def __init__(self, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None): @@ -168,7 +173,7 @@ def __str__(self) -> str: return f'measurements: {samples}\noutput state: {final}' -class CliffordSimulatorStepResult(simulator.StepResult): +class CliffordSimulatorStepResult(simulator.StepResult['CliffordState']): """A `StepResult` that includes `StateVectorMixin` methods.""" def __init__(self, state, measurements): diff --git a/cirq/sim/density_matrix_simulator.py b/cirq/sim/density_matrix_simulator.py index fc9bd568bc1..d378bd0d7c4 100644 --- a/cirq/sim/density_matrix_simulator.py +++ b/cirq/sim/density_matrix_simulator.py @@ -35,7 +35,12 @@ def __init__(self, num_qubits: int, tensor: np.ndarray): self.buffers = [np.empty_like(tensor) for _ in range(3)] -class DensityMatrixSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState): +class DensityMatrixSimulator( + simulator.SimulatesSamples, + simulator.SimulatesIntermediateStateBase[ + 'DensityMatrixStepResult', 'DensityMatrixTrialResult', 'DensityMatrixSimulatorState' + ], +): """A simulator for density matrices and noisy quantum circuits. This simulator can be applied on circuits that are made up of operations @@ -320,7 +325,7 @@ def _create_simulator_trial_result( ) -class DensityMatrixStepResult(simulator.StepResult): +class DensityMatrixStepResult(simulator.StepResult['DensityMatrixSimulatorState']): """A single step in the simulation of the DensityMatrixSimulator. Attributes: diff --git a/cirq/sim/mux.py b/cirq/sim/mux.py index ad91f7429c6..b764acf3bf5 100644 --- a/cirq/sim/mux.py +++ b/cirq/sim/mux.py @@ -273,16 +273,16 @@ def final_density_matrix( if can_do_unitary_simulation: # pure case: use SparseSimulator - result = sparse_simulator.Simulator(dtype=dtype, seed=seed).simulate( + sparse_result = sparse_simulator.Simulator(dtype=dtype, seed=seed).simulate( program=circuit_like, initial_state=initial_state, qubit_order=qubit_order, param_resolver=param_resolver, ) - return cast(state_vector_simulator.StateVectorTrialResult, result).density_matrix_of() + return sparse_result.density_matrix_of() else: # noisy case: use DensityMatrixSimulator with dephasing - result = density_matrix_simulator.DensityMatrixSimulator( + density_result = density_matrix_simulator.DensityMatrixSimulator( dtype=dtype, noise=noise, seed=seed, @@ -293,4 +293,4 @@ def final_density_matrix( qubit_order=qubit_order, param_resolver=param_resolver, ) - return cast(density_matrix_simulator.DensityMatrixTrialResult, result).final_density_matrix + return density_result.final_density_matrix diff --git a/cirq/sim/simulator.py b/cirq/sim/simulator.py index e7245adf111..ebf51e528f0 100644 --- a/cirq/sim/simulator.py +++ b/cirq/sim/simulator.py @@ -39,6 +39,8 @@ TYPE_CHECKING, Set, cast, + TypeVar, + Generic, ) import abc @@ -53,6 +55,11 @@ import cirq +TStepResult = TypeVar('TStepResult', bound='StepResult') +TSimulationTrialResult = TypeVar('TSimulationTrialResult', bound='SimulationTrialResult') +TSimulatorState = TypeVar('TSimulatorState') + + class SimulatesSamples(work.Sampler, metaclass=abc.ABCMeta): """Simulator that mimics running on quantum hardware. @@ -288,7 +295,7 @@ def simulate_expectation_values_sweep( """ -class SimulatesFinalState(metaclass=abc.ABCMeta): +class SimulatesFinalState(Generic[TSimulationTrialResult], metaclass=abc.ABCMeta): """Simulator that allows access to the simulator's final state. Implementors of this interface should implement the simulate_sweep @@ -305,7 +312,7 @@ def simulate( param_resolver: 'study.ParamResolverOrSimilarType' = None, qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, initial_state: Any = None, - ) -> 'SimulationTrialResult': + ) -> TSimulationTrialResult: """Simulates the supplied Circuit. This method returns a result which allows access to the entire @@ -335,7 +342,7 @@ def simulate_sweep( params: study.Sweepable, qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, initial_state: Any = None, - ) -> List['SimulationTrialResult']: + ) -> List[TSimulationTrialResult]: """Simulates the supplied Circuit. This method returns a result which allows access to the entire final @@ -359,7 +366,11 @@ def simulate_sweep( raise NotImplementedError() -class SimulatesIntermediateState(SimulatesFinalState, metaclass=abc.ABCMeta): +class SimulatesIntermediateStateBase( + Generic[TStepResult, TSimulationTrialResult, TSimulatorState], + SimulatesFinalState[TSimulationTrialResult], + metaclass=abc.ABCMeta, +): """A SimulatesFinalState that simulates a circuit by moments. Whereas a general SimulatesFinalState may return the entire simulator @@ -379,7 +390,7 @@ def simulate_sweep( params: study.Sweepable, qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, initial_state: Any = None, - ) -> List['SimulationTrialResult']: + ) -> List[TSimulationTrialResult]: """Simulates the supplied Circuit. This method returns a result which allows access to the entire @@ -425,7 +436,7 @@ def simulate_moment_steps( param_resolver: 'study.ParamResolverOrSimilarType' = None, qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, initial_state: Any = None, - ) -> Iterator: + ) -> Iterator[TStepResult]: """Returns an iterator of StepResults for each moment simulated. If the circuit being simulated is empty, a single step result should @@ -456,7 +467,7 @@ def _simulator_iterator( param_resolver: study.ParamResolver, qubit_order: ops.QubitOrderOrList, initial_state: Any, - ) -> Iterator: + ) -> Iterator[TStepResult]: """Iterator over StepResult from Moments of a Circuit. If the initial state is an int, the state is set to the computational @@ -493,7 +504,7 @@ def _base_iterator( circuit: circuits.Circuit, qubit_order: ops.QubitOrderOrList, initial_state: Any, - ) -> Iterator['StepResult']: + ) -> Iterator[TStepResult]: """Iterator over StepResult from Moments of a Circuit. Args: @@ -512,13 +523,41 @@ def _base_iterator( """ raise NotImplementedError() + @abc.abstractmethod + def _create_simulator_trial_result( + self, + params: study.ParamResolver, + measurements: Dict[str, np.ndarray], + final_simulator_state: TSimulatorState, + ) -> TSimulationTrialResult: + """This method can be implemented to create a trial result. + + Args: + params: The ParamResolver for this trial. + measurements: The measurement results for this trial. + final_simulator_state: The final state of the simulator for the + StepResult. + + Returns: + The SimulationTrialResult. + """ + raise NotImplementedError() + + +class SimulatesIntermediateState( + Generic[TStepResult, TSimulatorState], + SimulatesIntermediateStateBase[TStepResult, 'SimulationTrialResult', TSimulatorState], + metaclass=abc.ABCMeta, +): + """A SimulatesIntermediateState that uses the default SimulationTrialResult type.""" + def _create_simulator_trial_result( self, params: study.ParamResolver, measurements: Dict[str, np.ndarray], final_simulator_state: Any, ) -> 'SimulationTrialResult': - """This method can be overridden to creation of a trial result. + """This method creates a default trial result. Args: params: The ParamResolver for this trial. @@ -534,7 +573,7 @@ def _create_simulator_trial_result( ) -class StepResult(metaclass=abc.ABCMeta): +class StepResult(Generic[TSimulatorState], metaclass=abc.ABCMeta): """Results of a step of a SimulatesIntermediateState. Attributes: @@ -546,7 +585,7 @@ def __init__(self, measurements: Optional[Dict[str, List[int]]] = None) -> None: self.measurements = measurements or collections.defaultdict(list) @abc.abstractmethod - def _simulator_state(self) -> Any: + def _simulator_state(self) -> TSimulatorState: """Returns the simulator state of the simulator after this step. This method starts with an underscore to indicate that it is private. diff --git a/cirq/sim/sparse_simulator.py b/cirq/sim/sparse_simulator.py index 85f99609472..913d42c11ca 100644 --- a/cirq/sim/sparse_simulator.py +++ b/cirq/sim/sparse_simulator.py @@ -46,7 +46,7 @@ class Simulator( simulator.SimulatesSamples, - state_vector_simulator.SimulatesIntermediateStateVector, + state_vector_simulator.SimulatesIntermediateStateVector['SparseSimulatorStep'], simulator.SimulatesExpectationValues, ): """A sparse matrix state vector simulator that uses numpy. diff --git a/cirq/sim/state_vector_simulator.py b/cirq/sim/state_vector_simulator.py index 6a863cf28c7..79fea7799cf 100644 --- a/cirq/sim/state_vector_simulator.py +++ b/cirq/sim/state_vector_simulator.py @@ -15,7 +15,7 @@ import abc -from typing import Any, cast, Dict, Sequence, TYPE_CHECKING, Tuple +from typing import Any, cast, Dict, Sequence, TYPE_CHECKING, Tuple, Generic, TypeVar import numpy as np @@ -27,8 +27,16 @@ import cirq +TStateVectorStepResult = TypeVar('TStateVectorStepResult', bound='StateVectorStepResult') + + class SimulatesIntermediateStateVector( - simulator.SimulatesAmplitudes, simulator.SimulatesIntermediateState, metaclass=abc.ABCMeta + Generic[TStateVectorStepResult], + simulator.SimulatesAmplitudes, + simulator.SimulatesIntermediateStateBase[ + TStateVectorStepResult, 'StateVectorTrialResult', 'StateVectorSimulatorState' + ], + metaclass=abc.ABCMeta, ): """A simulator that accesses its state vector as it does its simulation. @@ -87,7 +95,9 @@ def __new__(cls, *args, **kwargs): return SimulatesIntermediateStateVector.__new__(cls) -class StateVectorStepResult(simulator.StepResult, metaclass=abc.ABCMeta): +class StateVectorStepResult( + simulator.StepResult['StateVectorSimulatorState'], metaclass=abc.ABCMeta +): @abc.abstractmethod def _simulator_state(self) -> 'StateVectorSimulatorState': """Returns the simulator_state of the simulator after this step.