Skip to content

Commit

Permalink
Base class for quantum states (quantumlib#5065)
Browse files Browse the repository at this point in the history
Creates a base class for all the quantum state classes created in quantumlib#4979, and uses the inheritance to push the implementation of `ActOn<State>Args.kron`, `factor`, etc into the base class.

Closes quantumlib#4827
Resolves quantumlib#3841 (comment) that's been bugging me for a year.
  • Loading branch information
daxfohl authored and rht committed May 1, 2023
1 parent 3b3a087 commit 72e554e
Show file tree
Hide file tree
Showing 13 changed files with 235 additions and 244 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@
operation_to_superoperator,
QUANTUM_STATE_LIKE,
QuantumState,
QuantumStateRepresentation,
quantum_state,
STATE_VECTOR_LIKE,
StabilizerState,
Expand Down
40 changes: 14 additions & 26 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import numpy as np
import quimb.tensor as qtn

from cirq import devices, protocols, value
from cirq import devices, protocols, qis, value
from cirq._compat import deprecated
from cirq.sim import simulator_base
from cirq.sim.act_on_args import ActOnArgs
Expand Down Expand Up @@ -220,7 +220,7 @@ def _simulator_state(self):


@value.value_equality
class _MPSHandler:
class _MPSHandler(qis.QuantumStateRepresentation):
"""Quantum state of the MPS simulation."""

def __init__(
Expand Down Expand Up @@ -604,21 +604,24 @@ def __init__(
Raises:
ValueError: If the grouping does not cover the qubits.
"""
qubit_map = {q: i for i, q in enumerate(qubits)}
final_grouping = qubit_map if grouping is None else grouping
if final_grouping.keys() != qubit_map.keys():
raise ValueError('Grouping must cover exactly the qubits.')
state = _MPSHandler.create(
initial_state=initial_state,
qid_shape=tuple(q.dimension for q in qubits),
simulation_options=simulation_options,
grouping={qubit_map[k]: v for k, v in final_grouping.items()},
)
super().__init__(
state=state,
prng=prng,
qubits=qubits,
log_of_measurement_results=log_of_measurement_results,
classical_data=classical_data,
)
final_grouping = self.qubit_map if grouping is None else grouping
if final_grouping.keys() != self.qubit_map.keys():
raise ValueError('Grouping must cover exactly the qubits.')
self._state = _MPSHandler.create(
initial_state=initial_state,
qid_shape=tuple(q.dimension for q in qubits),
simulation_options=simulation_options,
grouping={self.qubit_map[k]: v for k, v in final_grouping.items()},
)
self._state: _MPSHandler = state

def i_str(self, i: int) -> str:
# Returns the index name for the i'th qid.
Expand All @@ -636,9 +639,6 @@ def __str__(self) -> str:
def _value_equality_values_(self) -> Any:
return self.qubits, self._state

def _on_copy(self, target: 'MPSState', deep_copy_buffers: bool = True):
target._state = self._state.copy(deep_copy_buffers)

def state_vector(self) -> np.ndarray:
"""Returns the full state vector.
Expand Down Expand Up @@ -709,15 +709,3 @@ def perform_measurement(
tolerance specified in simulation options.
"""
return self._state._measure(self.get_axes(qubits), prng, collapse_state_vector)

def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Measures the axes specified by the simulator."""
return self._state.measure(self.get_axes(qubits), self.prng)

def sample(
self,
qubits: Sequence['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
return self._state.sample(self.get_axes(qubits), repetitions, seed)
2 changes: 1 addition & 1 deletion cirq-core/cirq/qis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
superoperator_to_kraus,
)

from cirq.qis.clifford_tableau import CliffordTableau, StabilizerState
from cirq.qis.clifford_tableau import CliffordTableau, QuantumStateRepresentation, StabilizerState

from cirq.qis.measures import (
entanglement_fidelity,
Expand Down
93 changes: 89 additions & 4 deletions cirq-core/cirq/qis/clifford_tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,97 @@
# limitations under the License.

import abc
from typing import Any, Dict, List, TYPE_CHECKING
from typing import Any, Dict, List, Sequence, Tuple, TYPE_CHECKING, TypeVar
import numpy as np

from cirq import protocols
from cirq import protocols, value
from cirq.value import big_endian_int_to_digits, linear_dict

if TYPE_CHECKING:
import cirq

TSelf = TypeVar('TSelf', bound='QuantumStateRepresentation')

class StabilizerState(metaclass=abc.ABCMeta):

class QuantumStateRepresentation(metaclass=abc.ABCMeta):
@abc.abstractmethod
def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
"""Creates a copy of the object.
Args:
deep_copy_buffers: If True, buffers will also be deep-copied.
Otherwise the copy will share a reference to the original object's
buffers.
Returns:
A copied instance.
"""

@abc.abstractmethod
def measure(
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
) -> List[int]:
"""Measures the state.
Args:
axes: The axes to measure.
seed: The random number seed to use.
Returns:
The measurements in order.
"""

def sample(
self,
axes: Sequence[int],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
"""Samples the state. Subclasses can override with more performant method.
Args:
axes: The axes to sample.
repetitions: The number of samples to make.
seed: The random number seed to use.
Returns:
The samples in order.
"""
prng = value.parse_random_state(seed)
measurements = []
for _ in range(repetitions):
state = self.copy()
measurements.append(state.measure(axes, prng))
return np.array(measurements, dtype=bool)

def kron(self: TSelf, other: TSelf) -> TSelf:
"""Joins two state spaces together."""
raise NotImplementedError()

def factor(
self: TSelf, axes: Sequence[int], *, validate=True, atol=1e-07
) -> Tuple[TSelf, TSelf]:
"""Splits two state spaces after a measurement or reset."""
raise NotImplementedError()

def reindex(self: TSelf, axes: Sequence[int]) -> TSelf:
"""Physically reindexes the state by the new basis.
Args:
axes: The desired axis order.
Returns:
The state with qubit order transposed and underlying representation
updated.
"""
raise NotImplementedError()

@property
def supports_factor(self) -> bool:
"""Subclasses that allow factorization should override this."""
return False

@property
def can_represent_mixed_states(self) -> bool:
"""Subclasses that can represent mixed states should override this."""
return False


class StabilizerState(QuantumStateRepresentation, metaclass=abc.ABCMeta):
"""Interface for quantum stabilizer state representations.
This interface is used for CliffordTableau and StabilizerChForm quantum
Expand Down Expand Up @@ -222,7 +302,7 @@ def __eq__(self, other):
def __copy__(self) -> 'CliffordTableau':
return self.copy()

def copy(self) -> 'CliffordTableau':
def copy(self, deep_copy_buffers: bool = True) -> 'CliffordTableau':
state = CliffordTableau(self.n)
state.rs = self.rs.copy()
state.xs = self.xs.copy()
Expand Down Expand Up @@ -578,3 +658,8 @@ def apply_cx(

def apply_global_phase(self, coefficient: linear_dict.Scalar):
pass

def measure(
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
) -> List[int]:
return [self._measure(axis, seed) for axis in axes]
49 changes: 38 additions & 11 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Objects and methods for acting efficiently on a state tensor."""
import abc
import copy
import inspect
import warnings
from typing import (
Any,
cast,
Expand All @@ -28,7 +28,6 @@
TYPE_CHECKING,
Tuple,
)
import warnings

import numpy as np

Expand Down Expand Up @@ -59,6 +58,7 @@ def __init__(
log_of_measurement_results: Optional[Dict[str, List[int]]] = None,
ignore_measurement_results: bool = False,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
state: Optional['cirq.QuantumStateRepresentation'] = None,
):
"""Inits ActOnArgs.
Expand All @@ -76,6 +76,7 @@ def __init__(
simulators that can represent mixed states.
classical_data: The shared classical data container for this
simulation.
state: The underlying quantum state of the simulation.
"""
if prng is None:
prng = cast(np.random.RandomState, np.random)
Expand All @@ -90,6 +91,7 @@ def __init__(
}
)
self._ignore_measurement_results = ignore_measurement_results
self._state = state

@property
def prng(self) -> np.random.RandomState:
Expand Down Expand Up @@ -148,10 +150,21 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[
def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
return [self.qubit_map[q] for q in qubits]

@abc.abstractmethod
def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Child classes that perform measurements should implement this with
the implementation."""
"""Delegates the call to measure the density matrix."""
if self._state is not None:
return self._state.measure(self.get_axes(qubits), self.prng)
raise NotImplementedError()

def sample(
self,
qubits: Sequence['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
if self._state is not None:
return self._state.sample(self.get_axes(qubits), repetitions, seed)
raise NotImplementedError()

def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
"""Creates a copy of the object.
Expand All @@ -165,6 +178,10 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
A copied instance.
"""
args = copy.copy(self)
args._classical_data = self._classical_data.copy()
if self._state is not None:
args._state = self._state.copy(deep_copy_buffers=deep_copy_buffers)
return args
if 'deep_copy_buffers' in inspect.signature(self._on_copy).parameters:
self._on_copy(args, deep_copy_buffers)
else:
Expand All @@ -176,7 +193,6 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
DeprecationWarning,
)
self._on_copy(args)
args._classical_data = self._classical_data.copy()
return args

def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True):
Expand All @@ -190,7 +206,10 @@ def create_merged_state(self: TSelf) -> TSelf:
def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf:
"""Joins two state spaces together."""
args = self if inplace else copy.copy(self)
self._on_kronecker_product(other, args)
if self._state is not None and other._state is not None:
args._state = self._state.kron(other._state)
else:
self._on_kronecker_product(other, args)
args._set_qubits(self.qubits + other.qubits)
return args

Expand Down Expand Up @@ -225,15 +244,20 @@ def factor(
"""Splits two state spaces after a measurement or reset."""
extracted = copy.copy(self)
remainder = self if inplace else copy.copy(self)
self._on_factor(qubits, extracted, remainder, validate, atol)
if self._state is not None:
e, r = self._state.factor(self.get_axes(qubits), validate=validate, atol=atol)
extracted._state = e
remainder._state = r
else:
self._on_factor(qubits, extracted, remainder, validate, atol)
extracted._set_qubits(qubits)
remainder._set_qubits([q for q in self.qubits if q not in qubits])
return extracted, remainder

@property
def allows_factoring(self):
"""Subclasses that allow factorization should override this."""
return False
return self._state.supports_factor if self._state is not None else False

def _on_factor(
self: TSelf,
Expand Down Expand Up @@ -265,7 +289,10 @@ def transpose_to_qubit_order(
if len(self.qubits) != len(qubits) or set(qubits) != set(self.qubits):
raise ValueError(f'Qubits do not match. Existing: {self.qubits}, provided: {qubits}')
args = self if inplace else copy.copy(self)
self._on_transpose_to_qubit_order(qubits, args)
if self._state is not None:
args._state = self._state.reindex(self.get_axes(qubits))
else:
self._on_transpose_to_qubit_order(qubits, args)
args._set_qubits(qubits)
return args

Expand Down Expand Up @@ -356,7 +383,7 @@ def __iter__(self) -> Iterator[Optional['cirq.Qid']]:

@property
def can_represent_mixed_states(self) -> bool:
return False
return self._state.can_represent_mixed_states if self._state is not None else False


def strat_act_on_from_apply_decompose(
Expand Down
15 changes: 0 additions & 15 deletions cirq-core/cirq/sim/act_on_args_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,10 @@ def _act_on_fallback_(
) -> bool:
return True

def _on_copy(self, args):
pass

def _on_kronecker_product(self, other, target):
pass

def _on_transpose_to_qubit_order(self, qubits, target):
pass

def _on_factor(self, qubits, extracted, remainder, validate=True, atol=1e-07):
pass

@property
def allows_factoring(self):
return True

def sample(self, qubits, repetitions=1, seed=None):
pass


q0, q1, q2 = qs3 = cirq.LineQubit.range(3)
qs2 = cirq.LineQubit.range(2)
Expand Down
Loading

0 comments on commit 72e554e

Please sign in to comment.