Skip to content

Commit

Permalink
Make the quantum state generic (#5255)
Browse files Browse the repository at this point in the history
  • Loading branch information
daxfohl authored Apr 13, 2022
1 parent f85e731 commit f62f9e3
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 36 deletions.
3 changes: 1 addition & 2 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def sample(


@value.value_equality
class MPSState(ActOnArgs):
class MPSState(ActOnArgs[_MPSHandler]):
"""A state of the MPS simulation."""

@deprecated_parameter(
Expand Down Expand Up @@ -626,7 +626,6 @@ def __init__(
)
else:
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
self._state: _MPSHandler = state

def i_str(self, i: int) -> str:
# Returns the index name for the i'th qid.
Expand Down
8 changes: 5 additions & 3 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Any,
cast,
Dict,
Generic,
Iterator,
List,
Mapping,
Expand All @@ -36,12 +37,13 @@
from cirq.sim.operation_target import OperationTarget

TSelf = TypeVar('TSelf', bound='ActOnArgs')
TState = TypeVar('TState', bound='cirq.QuantumStateRepresentation')

if TYPE_CHECKING:
import cirq


class ActOnArgs(OperationTarget[TSelf], metaclass=abc.ABCMeta):
class ActOnArgs(OperationTarget, Generic[TState], metaclass=abc.ABCMeta):
"""State and context for an operation acting on a state tensor."""

@deprecated_parameter(
Expand All @@ -63,7 +65,7 @@ def __init__(
qubits: Optional[Sequence['cirq.Qid']] = None,
log_of_measurement_results: Optional[Dict[str, List[int]]] = None,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
state: Optional['cirq.QuantumStateRepresentation'] = None,
state: Optional[TState] = None,
):
"""Inits ActOnArgs.
Expand Down Expand Up @@ -91,7 +93,7 @@ def __init__(
for k, v in (log_of_measurement_results or {}).items()
}
)
self._state = state
self._state = cast(TState, state)
if state is None:
_warn_or_error('This function will require a valid `state` input in cirq v0.16.')

Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def can_represent_mixed_states(self) -> bool:
return True


class ActOnDensityMatrixArgs(ActOnArgs):
class ActOnDensityMatrixArgs(ActOnArgs[_BufferedDensityMatrix]):
"""State and context for an operation acting on a density matrix.
To act on this object, directly edit the `target_tensor` property, which is
Expand Down Expand Up @@ -286,7 +286,6 @@ def __init__(
buffer=available_buffer,
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
self._state: _BufferedDensityMatrix = state

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/sim/act_on_state_vector_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def supports_factor(self) -> bool:
return True


class ActOnStateVectorArgs(ActOnArgs):
class ActOnStateVectorArgs(ActOnArgs[_BufferedStateVector]):
"""State and context for an operation acting on a state vector.
There are two common ways to act on this object:
Expand Down Expand Up @@ -357,7 +357,6 @@ def __init__(
buffer=available_buffer,
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
self._state: _BufferedStateVector = state

@_compat.deprecated(
deadline='v0.16', fix='None, this function was unintentionally made public.'
Expand Down
5 changes: 3 additions & 2 deletions cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
TStabilizerState = TypeVar('TStabilizerState', bound='cirq.StabilizerState')


class ActOnStabilizerArgs(ActOnArgs, Generic[TStabilizerState], metaclass=abc.ABCMeta):
class ActOnStabilizerArgs(
ActOnArgs[TStabilizerState], Generic[TStabilizerState], metaclass=abc.ABCMeta
):
"""Abstract wrapper around a stabilizer state for the act_on protocol."""

@deprecated_parameter(
Expand Down Expand Up @@ -81,7 +83,6 @@ def __init__(
)
else:
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
self._state: TStabilizerState = state

@property
def state(self) -> TStabilizerState:
Expand Down
32 changes: 7 additions & 25 deletions cirq-core/cirq/sim/simulator_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,13 @@ def copy(self, deep_copy_buffers: bool = True) -> 'CountingState':
)


class CountingActOnArgs(cirq.ActOnArgs):
class CountingActOnArgs(cirq.ActOnArgs[CountingState]):
def __init__(self, state, qubits, classical_data):
state_obj = CountingState(state)
super().__init__(
state=state_obj,
qubits=qubits,
classical_data=classical_data,
)
self._state: CountingState = state_obj
super().__init__(state=state_obj, qubits=qubits, classical_data=classical_data)

def _act_on_fallback_(
self,
action: Any,
qubits: Sequence['cirq.Qid'],
allow_decompose: bool = True,
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
self._state.gate_count += 1
return True
Expand Down Expand Up @@ -120,10 +112,7 @@ class CountingSimulator(
]
):
def __init__(self, noise=None, split_untangled_states=False):
super().__init__(
noise=noise,
split_untangled_states=split_untangled_states,
)
super().__init__(noise=noise, split_untangled_states=split_untangled_states)

def _create_partial_act_on_args(
self,
Expand All @@ -142,18 +131,14 @@ def _create_simulator_trial_result(
return CountingTrialResult(params, measurements, final_step_result=final_step_result)

def _create_step_result(
self,
sim_state: cirq.OperationTarget[CountingActOnArgs],
self, sim_state: cirq.OperationTarget[CountingActOnArgs]
) -> CountingStepResult:
return CountingStepResult(sim_state)


class SplittableCountingSimulator(CountingSimulator):
def __init__(self, noise=None, split_untangled_states=True):
super().__init__(
noise=noise,
split_untangled_states=split_untangled_states,
)
super().__init__(noise=noise, split_untangled_states=split_untangled_states)

def _create_partial_act_on_args(
self,
Expand Down Expand Up @@ -390,10 +375,7 @@ def _has_unitary_(self):
return self.has_unitary

simulator = CountingSimulator()
params = [
cirq.ParamResolver({'a': 0}),
cirq.ParamResolver({'a': 1}),
]
params = [cirq.ParamResolver({'a': 0}), cirq.ParamResolver({'a': 1})]

op1 = TestOp(has_unitary=True)
op2 = TestOp(has_unitary=True)
Expand Down

0 comments on commit f62f9e3

Please sign in to comment.