Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the quantum state generic #5255

Merged
merged 2 commits into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def sample(


@value.value_equality
class MPSState(ActOnArgs):
class MPSState(ActOnArgs[_MPSHandler]):
"""A state of the MPS simulation."""

@deprecated_parameter(
Expand Down Expand Up @@ -626,7 +626,6 @@ def __init__(
)
else:
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
self._state: _MPSHandler = state

def i_str(self, i: int) -> str:
# Returns the index name for the i'th qid.
Expand Down
8 changes: 5 additions & 3 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Any,
cast,
Dict,
Generic,
Iterator,
List,
Mapping,
Expand All @@ -36,12 +37,13 @@
from cirq.sim.operation_target import OperationTarget

TSelf = TypeVar('TSelf', bound='ActOnArgs')
TState = TypeVar('TState', bound='cirq.QuantumStateRepresentation')

if TYPE_CHECKING:
import cirq


class ActOnArgs(OperationTarget[TSelf], metaclass=abc.ABCMeta):
class ActOnArgs(OperationTarget, Generic[TState], metaclass=abc.ABCMeta):
"""State and context for an operation acting on a state tensor."""

@deprecated_parameter(
Expand All @@ -63,7 +65,7 @@ def __init__(
qubits: Optional[Sequence['cirq.Qid']] = None,
log_of_measurement_results: Optional[Dict[str, List[int]]] = None,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
state: Optional['cirq.QuantumStateRepresentation'] = None,
state: Optional[TState] = None,
):
"""Inits ActOnArgs.
Expand Down Expand Up @@ -91,7 +93,7 @@ def __init__(
for k, v in (log_of_measurement_results or {}).items()
}
)
self._state = state
self._state = cast(TState, state)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we specify state: TState in the arg list instead of casting? Or does that break the callers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it has to be a cast because it's an optional parameter. Ideally we could make it non optional now, but it occurs after some other optional parameters in the method signature. That's part of why I deprecated positional arguments in the other PR.

if state is None:
_warn_or_error('This function will require a valid `state` input in cirq v0.16.')

Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def can_represent_mixed_states(self) -> bool:
return True


class ActOnDensityMatrixArgs(ActOnArgs):
class ActOnDensityMatrixArgs(ActOnArgs[_BufferedDensityMatrix]):
"""State and context for an operation acting on a density matrix.
To act on this object, directly edit the `target_tensor` property, which is
Expand Down Expand Up @@ -286,7 +286,6 @@ def __init__(
buffer=available_buffer,
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
self._state: _BufferedDensityMatrix = state

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/sim/act_on_state_vector_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def supports_factor(self) -> bool:
return True


class ActOnStateVectorArgs(ActOnArgs):
class ActOnStateVectorArgs(ActOnArgs[_BufferedStateVector]):
"""State and context for an operation acting on a state vector.
There are two common ways to act on this object:
Expand Down Expand Up @@ -357,7 +357,6 @@ def __init__(
buffer=available_buffer,
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
self._state: _BufferedStateVector = state

@_compat.deprecated(
deadline='v0.16', fix='None, this function was unintentionally made public.'
Expand Down
5 changes: 3 additions & 2 deletions cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
TStabilizerState = TypeVar('TStabilizerState', bound='cirq.StabilizerState')


class ActOnStabilizerArgs(ActOnArgs, Generic[TStabilizerState], metaclass=abc.ABCMeta):
class ActOnStabilizerArgs(
ActOnArgs[TStabilizerState], Generic[TStabilizerState], metaclass=abc.ABCMeta
):
"""Abstract wrapper around a stabilizer state for the act_on protocol."""

@deprecated_parameter(
Expand Down Expand Up @@ -81,7 +83,6 @@ def __init__(
)
else:
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
self._state: TStabilizerState = state

@property
def state(self) -> TStabilizerState:
Expand Down
32 changes: 7 additions & 25 deletions cirq-core/cirq/sim/simulator_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,13 @@ def copy(self, deep_copy_buffers: bool = True) -> 'CountingState':
)


class CountingActOnArgs(cirq.ActOnArgs):
class CountingActOnArgs(cirq.ActOnArgs[CountingState]):
def __init__(self, state, qubits, classical_data):
state_obj = CountingState(state)
super().__init__(
state=state_obj,
qubits=qubits,
classical_data=classical_data,
)
self._state: CountingState = state_obj
super().__init__(state=state_obj, qubits=qubits, classical_data=classical_data)

def _act_on_fallback_(
self,
action: Any,
qubits: Sequence['cirq.Qid'],
allow_decompose: bool = True,
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
self._state.gate_count += 1
return True
Expand Down Expand Up @@ -120,10 +112,7 @@ class CountingSimulator(
]
):
def __init__(self, noise=None, split_untangled_states=False):
super().__init__(
noise=noise,
split_untangled_states=split_untangled_states,
)
super().__init__(noise=noise, split_untangled_states=split_untangled_states)

def _create_partial_act_on_args(
self,
Expand All @@ -142,18 +131,14 @@ def _create_simulator_trial_result(
return CountingTrialResult(params, measurements, final_step_result=final_step_result)

def _create_step_result(
self,
sim_state: cirq.OperationTarget[CountingActOnArgs],
self, sim_state: cirq.OperationTarget[CountingActOnArgs]
) -> CountingStepResult:
return CountingStepResult(sim_state)


class SplittableCountingSimulator(CountingSimulator):
def __init__(self, noise=None, split_untangled_states=True):
super().__init__(
noise=noise,
split_untangled_states=split_untangled_states,
)
super().__init__(noise=noise, split_untangled_states=split_untangled_states)

def _create_partial_act_on_args(
self,
Expand Down Expand Up @@ -390,10 +375,7 @@ def _has_unitary_(self):
return self.has_unitary

simulator = CountingSimulator()
params = [
cirq.ParamResolver({'a': 0}),
cirq.ParamResolver({'a': 1}),
]
params = [cirq.ParamResolver({'a': 0}), cirq.ParamResolver({'a': 1})]

op1 = TestOp(has_unitary=True)
op2 = TestOp(has_unitary=True)
Expand Down