Skip to content

Commit

Permalink
Add ClassicalDataStore class to keep track of qubits measured (#4781)
Browse files Browse the repository at this point in the history
Adds a `ClassicalDataStore` class so we can keep track of which qubits are associated to which measurements.

Closes #3232. Initially this was created as part 14 (of 14) of https://tinyurl.com/cirq-feedforward to enable qudits in classical conditions, by storing and using dimensions of the measured qubits when calculating the integer value of each measurement when resolving sympy expressions. However it may have broader applicability.

This approach also sets us up to more easily add different types of measurements (#3233, #4274). It will also ease the path to #3002 and #4449., as we can eventually pass this into `Result` rather than the raw `log_of_measurement_results` dictionary. (The return type of `_run` will have to be changed to `Sequence[C;assicalDataStoreReader]`.

Related: #887, #3231 (open question @95-martin-orion whether this closes those or not)

This PR contains a `ClassicalDataStoreReader` and `ClassicalDataStoreBase` parent "interface" for the `ClassicalDataStore` class as well. This will allow us to swap in different representations that may have different performance characteristics. See #3808 for an example use case. This could be done by adding an optional `ClassicalDataStore` factory method argument to the `SimulatorBase` initializer, or separately to sampler classes.

(Note this is an alternative to #4778 for supporting qudits in sympy classical control expressions, as discussed here: https://github.com/quantumlib/Cirq/pull/4778/files#r774816995. The other PR was simpler and less invasive, but a bit hacky. I felt even though bigger, this seemed like the better approach and especially fits better with our future direction, and closed the other one).

**Breaking Changes**:
1. The abstract method `SimulatorBase._create_partial_act_on_args` argument `log_of_measurement_results: Dict` has been changed to `classical_data: ClassicalData`. Any third-party simulators that inherit `SimulatorBase` will need to update their implementation accordingly.
2. The abstract base class `ActOnArgs.__init__` argument `log_of_measurement_results: Dict` is now copied before use. For users that depend on the pass-by-reference semantics (this should be rare), they can use the new `classical_data: ClassicalData` argument instead, which is pass-by-reference.
  • Loading branch information
daxfohl authored Feb 7, 2022
1 parent 467c68d commit 6937e41
Show file tree
Hide file tree
Showing 32 changed files with 730 additions and 126 deletions.
4 changes: 4 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,9 @@
canonicalize_half_turns,
chosen_angle_to_canonical_half_turns,
chosen_angle_to_half_turns,
ClassicalDataDictionaryStore,
ClassicalDataStore,
ClassicalDataStoreReader,
Condition,
Duration,
DURATION_LIKE,
Expand All @@ -515,6 +518,7 @@
LinearDict,
MEASUREMENT_KEY_SEPARATOR,
MeasurementKey,
MeasurementType,
PeriodicValue,
RANDOM_STATE_OR_SEED_LIKE,
state_vector_to_probabilities,
Expand Down
17 changes: 13 additions & 4 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _create_partial_act_on_args(
self,
initial_state: Union[int, 'MPSState'],
qubits: Sequence['cirq.Qid'],
logs: Dict[str, Any],
classical_data: 'cirq.ClassicalDataStore',
) -> 'MPSState':
"""Creates MPSState args for simulating the Circuit.
Expand All @@ -101,7 +101,8 @@ def _create_partial_act_on_args(
qubits: Determines the canonical ordering of the qubits. This
is often used in specifying the initial state, i.e. the
ordering of the computational basis states.
logs: A mutable object that measurements are recorded into.
classical_data: The shared classical data container for this
simulation.
Returns:
MPSState args for simulating the Circuit.
Expand All @@ -115,7 +116,7 @@ def _create_partial_act_on_args(
simulation_options=self.simulation_options,
grouping=self.grouping,
initial_state=initial_state,
log_of_measurement_results=logs,
classical_data=classical_data,
)

def _create_step_result(
Expand Down Expand Up @@ -229,6 +230,7 @@ def __init__(
grouping: Optional[Dict['cirq.Qid', int]] = None,
initial_state: int = 0,
log_of_measurement_results: Dict[str, Any] = None,
classical_data: 'cirq.ClassicalDataStore' = None,
):
"""Creates and MPSState
Expand All @@ -242,11 +244,18 @@ def __init__(
initial_state: An integer representing the initial state.
log_of_measurement_results: A mutable object that measurements are
being recorded into.
classical_data: The shared classical data container for this
simulation.
Raises:
ValueError: If the grouping does not cover the qubits.
"""
super().__init__(prng, qubits, log_of_measurement_results)
super().__init__(
prng=prng,
qubits=qubits,
log_of_measurement_results=log_of_measurement_results,
classical_data=classical_data,
)
qubit_map = self.qubit_map
self.grouping = qubit_map if grouping is None else grouping
if self.grouping.keys() != self.qubit_map.keys():
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,10 +550,10 @@ def test_state_act_on_args_initializer():
s = ccq.mps_simulator.MPSState(
qubits=(cirq.LineQubit(0),),
prng=np.random.RandomState(0),
log_of_measurement_results={'test': 4},
log_of_measurement_results={'test': [4]},
)
assert s.qubits == (cirq.LineQubit(0),)
assert s.log_of_measurement_results == {'test': 4}
assert s.log_of_measurement_results == {'test': [4]}


def test_act_on_gate():
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def _parallel_gate_op(gate, qubits):
'Circuit': cirq.Circuit,
'CircuitOperation': cirq.CircuitOperation,
'ClassicallyControlledOperation': cirq.ClassicallyControlledOperation,
'ClassicalDataDictionaryStore': cirq.ClassicalDataDictionaryStore,
'CliffordState': cirq.CliffordState,
'CliffordTableau': cirq.CliffordTableau,
'CNotPowGate': cirq.CNotPowGate,
Expand Down Expand Up @@ -107,6 +108,7 @@ def _parallel_gate_op(gate, qubits):
'MixedUnitaryChannel': cirq.MixedUnitaryChannel,
'MeasurementKey': cirq.MeasurementKey,
'MeasurementGate': cirq.MeasurementGate,
'MeasurementType': cirq.MeasurementType,
'_MeasurementSpec': cirq.work._MeasurementSpec,
'Moment': cirq.Moment,
'MutableDensePauliString': cirq.MutableDensePauliString,
Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def _circuit_diagram_info_(
sub_info = protocols.circuit_diagram_info(self._sub_operation, sub_args, None)
if sub_info is None:
return NotImplemented # coverage: ignore

control_count = len({k for c in self._conditions for k in c.keys})
wire_symbols = sub_info.wire_symbols + ('^',) * control_count
if any(not isinstance(c, value.KeyCondition) for c in self._conditions):
Expand Down Expand Up @@ -176,7 +175,7 @@ def _json_dict_(self) -> Dict[str, Any]:
}

def _act_on_(self, args: 'cirq.OperationTarget') -> bool:
if all(c.resolve(args.log_of_measurement_results) for c in self._conditions):
if all(c.resolve(args.classical_data) for c in self._conditions):
protocols.act_on(self._sub_operation, args)
return True

Expand Down
36 changes: 36 additions & 0 deletions cirq-core/cirq/ops/classically_controlled_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import sympy
from sympy.parsing import sympy_parser
Expand Down Expand Up @@ -702,6 +704,40 @@ def test_sympy():
assert result.measurements['m_result'][0][0] == (j > i)


def test_sympy_qudits():
q0 = cirq.LineQid(0, 3)
q1 = cirq.LineQid(1, 5)
q_result = cirq.LineQubit(2)

class PlusGate(cirq.Gate):
def __init__(self, dimension, increment=1):
self.dimension = dimension
self.increment = increment % dimension

def _qid_shape_(self):
return (self.dimension,)

def _unitary_(self):
inc = (self.increment - 1) % self.dimension + 1
u = np.empty((self.dimension, self.dimension))
u[inc:] = np.eye(self.dimension)[:-inc]
u[:inc] = np.eye(self.dimension)[-inc:]
return u

for i in range(15):
digits = cirq.big_endian_int_to_digits(i, digit_count=2, base=(3, 5))
circuit = cirq.Circuit(
PlusGate(3, digits[0]).on(q0),
PlusGate(5, digits[1]).on(q1),
cirq.measure(q0, q1, key='m'),
cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m % 4 <= 1')),
cirq.measure(q_result, key='m_result'),
)

result = cirq.Simulator().run(circuit)
assert result.measurements['m_result'][0][0] == (i % 4 <= 1)


def test_sympy_path_prefix():
q = cirq.LineQubit(0)
op = cirq.X(q).with_classical_controls(sympy.Symbol('b'))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{
"cirq_type": "ClassicalDataDictionaryStore",
"measurements": [
[
{
"cirq_type": "MeasurementKey",
"name": "m",
"path": []
},
[0, 1]
]
],
"measured_qubits": [
[
{
"cirq_type": "MeasurementKey",
"name": "m",
"path": []
},
[
{
"cirq_type": "LineQubit",
"x": 0
},
{
"cirq_type": "LineQubit",
"x": 1
}
]
]
],
"channel_measurements": [
[
{
"cirq_type": "MeasurementKey",
"name": "c",
"path": []
},
3
]
],
"measurement_types": [
[
{
"cirq_type": "MeasurementKey",
"name": "m",
"path": []
},
1
],
[
{
"cirq_type": "MeasurementKey",
"name": "c",
"path": []
},
2
]
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [0, 1]}, _measured_qubits={cirq.MeasurementKey('m'): [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={cirq.MeasurementKey('c'): 3}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL})
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[1, 2]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[cirq.MeasurementType.MEASUREMENT, cirq.MeasurementType.CHANNEL]
3 changes: 0 additions & 3 deletions cirq-core/cirq/protocols/measurement_key_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
from cirq import value
from cirq._doc import doc_private

if TYPE_CHECKING:
import cirq

if TYPE_CHECKING:
import cirq

Expand Down
26 changes: 16 additions & 10 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

import numpy as np

from cirq import protocols, ops
from cirq import ops, protocols, value
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
from cirq.sim.operation_target import OperationTarget

Expand All @@ -50,6 +50,7 @@ def __init__(
qubits: Optional[Sequence['cirq.Qid']] = None,
log_of_measurement_results: Optional[Dict[str, List[int]]] = None,
ignore_measurement_results: bool = False,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
):
"""Inits ActOnArgs.
Expand All @@ -65,16 +66,21 @@ def __init__(
will treat measurement as dephasing instead of collapsing
process, and not log the result. This is only applicable to
simulators that can represent mixed states.
classical_data: The shared classical data container for this
simulation.
"""
if prng is None:
prng = cast(np.random.RandomState, np.random)
if qubits is None:
qubits = ()
if log_of_measurement_results is None:
log_of_measurement_results = {}
self._set_qubits(qubits)
self.prng = prng
self._log_of_measurement_results = log_of_measurement_results
self._classical_data = classical_data or value.ClassicalDataDictionaryStore(
_measurements={
value.MeasurementKey.parse_serialized(k): tuple(v)
for k, v in (log_of_measurement_results or {}).items()
}
)
self._ignore_measurement_results = ignore_measurement_results

def _set_qubits(self, qubits: Sequence['cirq.Qid']):
Expand Down Expand Up @@ -103,9 +109,9 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[
return
bits = self._perform_measurement(qubits)
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)]
if key in self._log_of_measurement_results:
raise ValueError(f"Measurement already logged to key {key!r}")
self._log_of_measurement_results[key] = corrected
self._classical_data.record_measurement(
value.MeasurementKey.parse_serialized(key), corrected, qubits
)

def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
return [self.qubit_map[q] for q in qubits]
Expand Down Expand Up @@ -138,7 +144,7 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
DeprecationWarning,
)
self._on_copy(args)
args._log_of_measurement_results = self.log_of_measurement_results.copy()
args._classical_data = self._classical_data.copy()
return args

def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True):
Expand Down Expand Up @@ -236,8 +242,8 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ
functionality, if supported."""

@property
def log_of_measurement_results(self) -> Dict[str, List[int]]:
return self._log_of_measurement_results
def classical_data(self) -> 'cirq.ClassicalDataStoreReader':
return self._classical_data

@property
def ignore_measurement_results(self) -> bool:
Expand Down
Loading

0 comments on commit 6937e41

Please sign in to comment.