Skip to content

Commit

Permalink
Improve type safety with generics on simulators
Browse files Browse the repository at this point in the history
  • Loading branch information
daxfohl committed Feb 17, 2021
1 parent 6aa46d7 commit 90c6c99
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 28 deletions.
9 changes: 7 additions & 2 deletions cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@
from cirq.sim import simulator


class MPSSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState):
class MPSSimulator(
simulator.SimulatesSamples,
simulator.SimulatesIntermediateStateBase[
'MPSSimulatorStepResult', 'MPSTrialResult', 'MPSState'
],
):
"""An efficient simulator for MPS circuits."""

def __init__(
Expand Down Expand Up @@ -239,7 +244,7 @@ def __str__(self) -> str:
return f'measurements: {samples}\noutput state: {final}'


class MPSSimulatorStepResult(simulator.StepResult):
class MPSSimulatorStepResult(simulator.StepResult['MPSState']):
"""A `StepResult` that can perform measurements."""

def __init__(self, state, measurements):
Expand Down
4 changes: 3 additions & 1 deletion cirq/experiments/xeb_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def simulate_2q_xeb_circuits(
# Need an actual object; not np.random or else multiprocessing will
# fail to pickle the closure object:
# https://github.com/quantumlib/Cirq/issues/3717
simulator = sim.Simulator(seed=np.random.RandomState())
simulator = cast(
'cirq.SimulatesIntermediateState', sim.Simulator(seed=np.random.RandomState())
)
_simulate_2q_xeb_circuit = _Simulate_2q_XEB_Circuit(simulator=simulator)

tasks = tuple(
Expand Down
4 changes: 2 additions & 2 deletions cirq/google/calibration/engine_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
SimulatesSamples,
SimulatesIntermediateStateVector,
SparseSimulatorStep,
StepResult,
StateVectorStepResult,
)
from cirq.study import ParamResolver
from cirq.value import RANDOM_STATE_OR_SEED_LIKE, parse_random_state
Expand Down Expand Up @@ -402,7 +402,7 @@ def _base_iterator(
circuit: Circuit,
qubit_order: QubitOrderOrList,
initial_state: Any,
) -> Iterator[StepResult]:
) -> Iterator[StateVectorStepResult]:
converted = _convert_to_circuit_with_drift(self, circuit)
return self._simulator._base_iterator(converted, qubit_order, initial_state)

Expand Down
9 changes: 7 additions & 2 deletions cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@
from cirq.sim.simulator import check_all_resolved


class CliffordSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState):
class CliffordSimulator(
simulator.SimulatesSamples,
simulator.SimulatesIntermediateStateBase[
'CliffordSimulatorStepResult', 'CliffordTrialResult', 'CliffordState'
],
):
"""An efficient simulator for Clifford circuits."""

def __init__(self, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None):
Expand Down Expand Up @@ -168,7 +173,7 @@ def __str__(self) -> str:
return f'measurements: {samples}\noutput state: {final}'


class CliffordSimulatorStepResult(simulator.StepResult):
class CliffordSimulatorStepResult(simulator.StepResult['CliffordState']):
"""A `StepResult` that includes `StateVectorMixin` methods."""

def __init__(self, state, measurements):
Expand Down
9 changes: 7 additions & 2 deletions cirq/sim/density_matrix_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ def __init__(self, num_qubits: int, tensor: np.ndarray):
self.buffers = [np.empty_like(tensor) for _ in range(3)]


class DensityMatrixSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState):
class DensityMatrixSimulator(
simulator.SimulatesSamples,
simulator.SimulatesIntermediateStateBase[
'DensityMatrixStepResult', 'DensityMatrixTrialResult', 'DensityMatrixSimulatorState'
],
):
"""A simulator for density matrices and noisy quantum circuits.
This simulator can be applied on circuits that are made up of operations
Expand Down Expand Up @@ -320,7 +325,7 @@ def _create_simulator_trial_result(
)


class DensityMatrixStepResult(simulator.StepResult):
class DensityMatrixStepResult(simulator.StepResult['DensityMatrixSimulatorState']):
"""A single step in the simulation of the DensityMatrixSimulator.
Attributes:
Expand Down
8 changes: 4 additions & 4 deletions cirq/sim/mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,16 +273,16 @@ def final_density_matrix(

if can_do_unitary_simulation:
# pure case: use SparseSimulator
result = sparse_simulator.Simulator(dtype=dtype, seed=seed).simulate(
sparse_result = sparse_simulator.Simulator(dtype=dtype, seed=seed).simulate(
program=circuit_like,
initial_state=initial_state,
qubit_order=qubit_order,
param_resolver=param_resolver,
)
return cast(state_vector_simulator.StateVectorTrialResult, result).density_matrix_of()
return sparse_result.density_matrix_of()
else:
# noisy case: use DensityMatrixSimulator with dephasing
result = density_matrix_simulator.DensityMatrixSimulator(
density_result = density_matrix_simulator.DensityMatrixSimulator(
dtype=dtype,
noise=noise,
seed=seed,
Expand All @@ -293,4 +293,4 @@ def final_density_matrix(
qubit_order=qubit_order,
param_resolver=param_resolver,
)
return cast(density_matrix_simulator.DensityMatrixTrialResult, result).final_density_matrix
return density_result.final_density_matrix
61 changes: 50 additions & 11 deletions cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
TYPE_CHECKING,
Set,
cast,
TypeVar,
Generic,
)

import abc
Expand All @@ -53,6 +55,11 @@
import cirq


TStepResult = TypeVar('TStepResult', bound='StepResult')
TSimulationTrialResult = TypeVar('TSimulationTrialResult', bound='SimulationTrialResult')
TSimulatorState = TypeVar('TSimulatorState')


class SimulatesSamples(work.Sampler, metaclass=abc.ABCMeta):
"""Simulator that mimics running on quantum hardware.
Expand Down Expand Up @@ -288,7 +295,7 @@ def simulate_expectation_values_sweep(
"""


class SimulatesFinalState(metaclass=abc.ABCMeta):
class SimulatesFinalState(Generic[TSimulationTrialResult], metaclass=abc.ABCMeta):
"""Simulator that allows access to the simulator's final state.
Implementors of this interface should implement the simulate_sweep
Expand All @@ -305,7 +312,7 @@ def simulate(
param_resolver: 'study.ParamResolverOrSimilarType' = None,
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
initial_state: Any = None,
) -> 'SimulationTrialResult':
) -> TSimulationTrialResult:
"""Simulates the supplied Circuit.
This method returns a result which allows access to the entire
Expand Down Expand Up @@ -335,7 +342,7 @@ def simulate_sweep(
params: study.Sweepable,
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
initial_state: Any = None,
) -> List['SimulationTrialResult']:
) -> List[TSimulationTrialResult]:
"""Simulates the supplied Circuit.
This method returns a result which allows access to the entire final
Expand All @@ -359,7 +366,11 @@ def simulate_sweep(
raise NotImplementedError()


class SimulatesIntermediateState(SimulatesFinalState, metaclass=abc.ABCMeta):
class SimulatesIntermediateStateBase(
Generic[TStepResult, TSimulationTrialResult, TSimulatorState],
SimulatesFinalState[TSimulationTrialResult],
metaclass=abc.ABCMeta,
):
"""A SimulatesFinalState that simulates a circuit by moments.
Whereas a general SimulatesFinalState may return the entire simulator
Expand All @@ -379,7 +390,7 @@ def simulate_sweep(
params: study.Sweepable,
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
initial_state: Any = None,
) -> List['SimulationTrialResult']:
) -> List[TSimulationTrialResult]:
"""Simulates the supplied Circuit.
This method returns a result which allows access to the entire
Expand Down Expand Up @@ -425,7 +436,7 @@ def simulate_moment_steps(
param_resolver: 'study.ParamResolverOrSimilarType' = None,
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
initial_state: Any = None,
) -> Iterator:
) -> Iterator[TStepResult]:
"""Returns an iterator of StepResults for each moment simulated.
If the circuit being simulated is empty, a single step result should
Expand Down Expand Up @@ -456,7 +467,7 @@ def _simulator_iterator(
param_resolver: study.ParamResolver,
qubit_order: ops.QubitOrderOrList,
initial_state: Any,
) -> Iterator:
) -> Iterator[TStepResult]:
"""Iterator over StepResult from Moments of a Circuit.
If the initial state is an int, the state is set to the computational
Expand Down Expand Up @@ -493,7 +504,7 @@ def _base_iterator(
circuit: circuits.Circuit,
qubit_order: ops.QubitOrderOrList,
initial_state: Any,
) -> Iterator['StepResult']:
) -> Iterator[TStepResult]:
"""Iterator over StepResult from Moments of a Circuit.
Args:
Expand All @@ -512,13 +523,41 @@ def _base_iterator(
"""
raise NotImplementedError()

@abc.abstractmethod
def _create_simulator_trial_result(
self,
params: study.ParamResolver,
measurements: Dict[str, np.ndarray],
final_simulator_state: TSimulatorState,
) -> TSimulationTrialResult:
"""This method can be implemented to create a trial result.
Args:
params: The ParamResolver for this trial.
measurements: The measurement results for this trial.
final_simulator_state: The final state of the simulator for the
StepResult.
Returns:
The SimulationTrialResult.
"""
raise NotImplementedError()


class SimulatesIntermediateState(
Generic[TStepResult, TSimulatorState],
SimulatesIntermediateStateBase[TStepResult, 'SimulationTrialResult', TSimulatorState],
metaclass=abc.ABCMeta,
):
"""A SimulatesIntermediateState that uses the default SimulationTrialResult type."""

def _create_simulator_trial_result(
self,
params: study.ParamResolver,
measurements: Dict[str, np.ndarray],
final_simulator_state: Any,
) -> 'SimulationTrialResult':
"""This method can be overridden to creation of a trial result.
"""This method creates a default trial result.
Args:
params: The ParamResolver for this trial.
Expand All @@ -534,7 +573,7 @@ def _create_simulator_trial_result(
)


class StepResult(metaclass=abc.ABCMeta):
class StepResult(Generic[TSimulatorState], metaclass=abc.ABCMeta):
"""Results of a step of a SimulatesIntermediateState.
Attributes:
Expand All @@ -546,7 +585,7 @@ def __init__(self, measurements: Optional[Dict[str, List[int]]] = None) -> None:
self.measurements = measurements or collections.defaultdict(list)

@abc.abstractmethod
def _simulator_state(self) -> Any:
def _simulator_state(self) -> TSimulatorState:
"""Returns the simulator state of the simulator after this step.
This method starts with an underscore to indicate that it is private.
Expand Down
2 changes: 1 addition & 1 deletion cirq/sim/sparse_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

class Simulator(
simulator.SimulatesSamples,
state_vector_simulator.SimulatesIntermediateStateVector,
state_vector_simulator.SimulatesIntermediateStateVector['SparseSimulatorStep'],
simulator.SimulatesExpectationValues,
):
"""A sparse matrix state vector simulator that uses numpy.
Expand Down
16 changes: 13 additions & 3 deletions cirq/sim/state_vector_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import abc

from typing import Any, cast, Dict, Sequence, TYPE_CHECKING, Tuple
from typing import Any, cast, Dict, Sequence, TYPE_CHECKING, Tuple, Generic, TypeVar

import numpy as np

Expand All @@ -27,8 +27,16 @@
import cirq


TStateVectorStepResult = TypeVar('TStateVectorStepResult', bound='StateVectorStepResult')


class SimulatesIntermediateStateVector(
simulator.SimulatesAmplitudes, simulator.SimulatesIntermediateState, metaclass=abc.ABCMeta
Generic[TStateVectorStepResult],
simulator.SimulatesAmplitudes,
simulator.SimulatesIntermediateStateBase[
TStateVectorStepResult, 'StateVectorTrialResult', 'StateVectorSimulatorState'
],
metaclass=abc.ABCMeta,
):
"""A simulator that accesses its state vector as it does its simulation.
Expand Down Expand Up @@ -87,7 +95,9 @@ def __new__(cls, *args, **kwargs):
return SimulatesIntermediateStateVector.__new__(cls)


class StateVectorStepResult(simulator.StepResult, metaclass=abc.ABCMeta):
class StateVectorStepResult(
simulator.StepResult['StateVectorSimulatorState'], metaclass=abc.ABCMeta
):
@abc.abstractmethod
def _simulator_state(self) -> 'StateVectorSimulatorState':
"""Returns the simulator_state of the simulator after this step.
Expand Down

0 comments on commit 90c6c99

Please sign in to comment.