Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the quantum state generic #5255

Merged
merged 2 commits into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def sample(


@value.value_equality
class MPSState(ActOnArgs):
class MPSState(ActOnArgs[_MPSHandler]):
"""A state of the MPS simulation."""

def __init__(
Expand Down Expand Up @@ -621,7 +621,6 @@ def __init__(
log_of_measurement_results=log_of_measurement_results,
classical_data=classical_data,
)
self._state: _MPSHandler = state

def i_str(self, i: int) -> str:
# Returns the index name for the i'th qid.
Expand Down
7 changes: 5 additions & 2 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Objects and methods for acting efficiently on a state tensor."""
import abc
import copy
from typing import (
Any,
cast,
Dict,
Generic,
Iterator,
List,
Mapping,
Expand All @@ -35,12 +37,13 @@
from cirq.sim.operation_target import OperationTarget

TSelf = TypeVar('TSelf', bound='ActOnArgs')
TState = TypeVar('TState', bound='cirq.QuantumStateRepresentation')

if TYPE_CHECKING:
import cirq


class ActOnArgs(OperationTarget[TSelf]):
class ActOnArgs(OperationTarget, Generic[TState], metaclass=abc.ABCMeta):
"""State and context for an operation acting on a state tensor."""

def __init__(
Expand Down Expand Up @@ -77,7 +80,7 @@ def __init__(
for k, v in (log_of_measurement_results or {}).items()
}
)
self._state = state
self._state = cast(TState, state)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we specify state: TState in the arg list instead of casting? Or does that break the callers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it has to be a cast because it's an optional parameter. Ideally we could make it non optional now, but it occurs after some other optional parameters in the method signature. That's part of why I deprecated positional arguments in the other PR.

if state is None:
_warn_or_error('This function will require a valid `state` input in cirq v0.16.')

Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def can_represent_mixed_states(self) -> bool:
return True


class ActOnDensityMatrixArgs(ActOnArgs):
class ActOnDensityMatrixArgs(ActOnArgs[_BufferedDensityMatrix]):
"""State and context for an operation acting on a density matrix.

To act on this object, directly edit the `target_tensor` property, which is
Expand Down Expand Up @@ -301,7 +301,6 @@ def __init__(
log_of_measurement_results=log_of_measurement_results,
classical_data=classical_data,
)
self._state: _BufferedDensityMatrix = state

def _act_on_fallback_(
self,
Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/sim/act_on_state_vector_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def supports_factor(self) -> bool:
return True


class ActOnStateVectorArgs(ActOnArgs):
class ActOnStateVectorArgs(ActOnArgs[_BufferedStateVector]):
"""State and context for an operation acting on a state vector.

There are two common ways to act on this object:
Expand Down Expand Up @@ -390,7 +390,6 @@ def __init__(
log_of_measurement_results=log_of_measurement_results,
classical_data=classical_data,
)
self._state: _BufferedStateVector = state

@_compat.deprecated(
deadline='v0.16',
Expand Down
5 changes: 3 additions & 2 deletions cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
TStabilizerState = TypeVar('TStabilizerState', bound='cirq.StabilizerState')


class ActOnStabilizerArgs(ActOnArgs, Generic[TStabilizerState], metaclass=abc.ABCMeta):
class ActOnStabilizerArgs(
ActOnArgs[TStabilizerState], Generic[TStabilizerState], metaclass=abc.ABCMeta
):
"""Abstract wrapper around a stabilizer state for the act_on protocol."""

def __init__(
Expand Down Expand Up @@ -64,7 +66,6 @@ def __init__(
log_of_measurement_results=log_of_measurement_results,
classical_data=classical_data,
)
self._state: TStabilizerState = state

@property
def state(self) -> TStabilizerState:
Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/sim/simulator_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,14 @@ def copy(self, deep_copy_buffers: bool = True) -> 'CountingState':
)


class CountingActOnArgs(cirq.ActOnArgs):
class CountingActOnArgs(cirq.ActOnArgs[CountingState]):
def __init__(self, state, qubits, classical_data):
state_obj = CountingState(state)
super().__init__(
state=state_obj,
qubits=qubits,
classical_data=classical_data,
)
self._state: CountingState = state_obj

def _act_on_fallback_(
self,
Expand Down