From 6937e417e372e4157ed70a6a38ec1d38964201ab Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Mon, 7 Feb 2022 12:08:30 -0800 Subject: [PATCH] Add ClassicalDataStore class to keep track of qubits measured (#4781) Adds a `ClassicalDataStore` class so we can keep track of which qubits are associated to which measurements. Closes #3232. Initially this was created as part 14 (of 14) of https://tinyurl.com/cirq-feedforward to enable qudits in classical conditions, by storing and using dimensions of the measured qubits when calculating the integer value of each measurement when resolving sympy expressions. However it may have broader applicability. This approach also sets us up to more easily add different types of measurements (#3233, #4274). It will also ease the path to #3002 and #4449., as we can eventually pass this into `Result` rather than the raw `log_of_measurement_results` dictionary. (The return type of `_run` will have to be changed to `Sequence[C;assicalDataStoreReader]`. Related: #887, #3231 (open question @95-martin-orion whether this closes those or not) This PR contains a `ClassicalDataStoreReader` and `ClassicalDataStoreBase` parent "interface" for the `ClassicalDataStore` class as well. This will allow us to swap in different representations that may have different performance characteristics. See #3808 for an example use case. This could be done by adding an optional `ClassicalDataStore` factory method argument to the `SimulatorBase` initializer, or separately to sampler classes. (Note this is an alternative to #4778 for supporting qudits in sympy classical control expressions, as discussed here: https://github.com/quantumlib/Cirq/pull/4778/files#r774816995. The other PR was simpler and less invasive, but a bit hacky. I felt even though bigger, this seemed like the better approach and especially fits better with our future direction, and closed the other one). **Breaking Changes**: 1. The abstract method `SimulatorBase._create_partial_act_on_args` argument `log_of_measurement_results: Dict` has been changed to `classical_data: ClassicalData`. Any third-party simulators that inherit `SimulatorBase` will need to update their implementation accordingly. 2. The abstract base class `ActOnArgs.__init__` argument `log_of_measurement_results: Dict` is now copied before use. For users that depend on the pass-by-reference semantics (this should be rare), they can use the new `classical_data: ClassicalData` argument instead, which is pass-by-reference. --- cirq-core/cirq/__init__.py | 4 + cirq-core/cirq/contrib/quimb/mps_simulator.py | 17 +- .../cirq/contrib/quimb/mps_simulator_test.py | 4 +- cirq-core/cirq/json_resolver_cache.py | 2 + .../ops/classically_controlled_operation.py | 3 +- .../classically_controlled_operation_test.py | 36 +++ .../ClassicalDataDictionaryStore.json | 60 +++++ .../ClassicalDataDictionaryStore.repr | 1 + .../json_test_data/MeasurementType.json | 1 + .../json_test_data/MeasurementType.repr | 1 + .../protocols/measurement_key_protocol.py | 3 - cirq-core/cirq/sim/act_on_args.py | 26 +- cirq-core/cirq/sim/act_on_args_container.py | 39 +-- .../cirq/sim/act_on_density_matrix_args.py | 13 +- .../cirq/sim/act_on_state_vector_args.py | 16 +- .../clifford/act_on_clifford_tableau_args.py | 18 +- .../act_on_stabilizer_ch_form_args.py | 18 +- .../cirq/sim/clifford/clifford_simulator.py | 11 +- .../clifford/stabilizer_state_ch_form_test.py | 5 +- .../cirq/sim/density_matrix_simulator.py | 7 +- cirq-core/cirq/sim/operation_target.py | 19 +- cirq-core/cirq/sim/simulator_base.py | 22 +- cirq-core/cirq/sim/simulator_base_test.py | 22 +- cirq-core/cirq/sim/sparse_simulator.py | 8 +- cirq-core/cirq/value/__init__.py | 7 + cirq-core/cirq/value/classical_data.py | 245 ++++++++++++++++++ cirq-core/cirq/value/classical_data_test.py | 136 ++++++++++ cirq-core/cirq/value/condition.py | 33 ++- cirq-core/cirq/value/condition_test.py | 48 ++-- cirq-core/cirq/value/measurement_key.py | 10 + cirq-core/cirq/value/measurement_key_test.py | 19 ++ .../calibration/engine_simulator.py | 2 +- 32 files changed, 730 insertions(+), 126 deletions(-) create mode 100644 cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json create mode 100644 cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr create mode 100644 cirq-core/cirq/protocols/json_test_data/MeasurementType.json create mode 100644 cirq-core/cirq/protocols/json_test_data/MeasurementType.repr create mode 100644 cirq-core/cirq/value/classical_data.py create mode 100644 cirq-core/cirq/value/classical_data_test.py diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 3d3aade6008..206a94611ef 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -507,6 +507,9 @@ canonicalize_half_turns, chosen_angle_to_canonical_half_turns, chosen_angle_to_half_turns, + ClassicalDataDictionaryStore, + ClassicalDataStore, + ClassicalDataStoreReader, Condition, Duration, DURATION_LIKE, @@ -515,6 +518,7 @@ LinearDict, MEASUREMENT_KEY_SEPARATOR, MeasurementKey, + MeasurementType, PeriodicValue, RANDOM_STATE_OR_SEED_LIKE, state_vector_to_probabilities, diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 8ad99b0b002..34daf9e10e3 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -91,7 +91,7 @@ def _create_partial_act_on_args( self, initial_state: Union[int, 'MPSState'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: 'cirq.ClassicalDataStore', ) -> 'MPSState': """Creates MPSState args for simulating the Circuit. @@ -101,7 +101,8 @@ def _create_partial_act_on_args( qubits: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - logs: A mutable object that measurements are recorded into. + classical_data: The shared classical data container for this + simulation. Returns: MPSState args for simulating the Circuit. @@ -115,7 +116,7 @@ def _create_partial_act_on_args( simulation_options=self.simulation_options, grouping=self.grouping, initial_state=initial_state, - log_of_measurement_results=logs, + classical_data=classical_data, ) def _create_step_result( @@ -229,6 +230,7 @@ def __init__( grouping: Optional[Dict['cirq.Qid', int]] = None, initial_state: int = 0, log_of_measurement_results: Dict[str, Any] = None, + classical_data: 'cirq.ClassicalDataStore' = None, ): """Creates and MPSState @@ -242,11 +244,18 @@ def __init__( initial_state: An integer representing the initial state. log_of_measurement_results: A mutable object that measurements are being recorded into. + classical_data: The shared classical data container for this + simulation. Raises: ValueError: If the grouping does not cover the qubits. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + classical_data=classical_data, + ) qubit_map = self.qubit_map self.grouping = qubit_map if grouping is None else grouping if self.grouping.keys() != self.qubit_map.keys(): diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index 1025f05e944..53b778d9fb3 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -550,10 +550,10 @@ def test_state_act_on_args_initializer(): s = ccq.mps_simulator.MPSState( qubits=(cirq.LineQubit(0),), prng=np.random.RandomState(0), - log_of_measurement_results={'test': 4}, + log_of_measurement_results={'test': [4]}, ) assert s.qubits == (cirq.LineQubit(0),) - assert s.log_of_measurement_results == {'test': 4} + assert s.log_of_measurement_results == {'test': [4]} def test_act_on_gate(): diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 19fc564df2a..8218e19e0e9 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -65,6 +65,7 @@ def _parallel_gate_op(gate, qubits): 'Circuit': cirq.Circuit, 'CircuitOperation': cirq.CircuitOperation, 'ClassicallyControlledOperation': cirq.ClassicallyControlledOperation, + 'ClassicalDataDictionaryStore': cirq.ClassicalDataDictionaryStore, 'CliffordState': cirq.CliffordState, 'CliffordTableau': cirq.CliffordTableau, 'CNotPowGate': cirq.CNotPowGate, @@ -107,6 +108,7 @@ def _parallel_gate_op(gate, qubits): 'MixedUnitaryChannel': cirq.MixedUnitaryChannel, 'MeasurementKey': cirq.MeasurementKey, 'MeasurementGate': cirq.MeasurementGate, + 'MeasurementType': cirq.MeasurementType, '_MeasurementSpec': cirq.work._MeasurementSpec, 'Moment': cirq.Moment, 'MutableDensePauliString': cirq.MutableDensePauliString, diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 3b8d051fb54..10fa65977f2 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -148,7 +148,6 @@ def _circuit_diagram_info_( sub_info = protocols.circuit_diagram_info(self._sub_operation, sub_args, None) if sub_info is None: return NotImplemented # coverage: ignore - control_count = len({k for c in self._conditions for k in c.keys}) wire_symbols = sub_info.wire_symbols + ('^',) * control_count if any(not isinstance(c, value.KeyCondition) for c in self._conditions): @@ -176,7 +175,7 @@ def _json_dict_(self) -> Dict[str, Any]: } def _act_on_(self, args: 'cirq.OperationTarget') -> bool: - if all(c.resolve(args.log_of_measurement_results) for c in self._conditions): + if all(c.resolve(args.classical_data) for c in self._conditions): protocols.act_on(self._sub_operation, args) return True diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 40b8e98f836..e3cf0f7b1c0 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import numpy as np import pytest import sympy from sympy.parsing import sympy_parser @@ -702,6 +704,40 @@ def test_sympy(): assert result.measurements['m_result'][0][0] == (j > i) +def test_sympy_qudits(): + q0 = cirq.LineQid(0, 3) + q1 = cirq.LineQid(1, 5) + q_result = cirq.LineQubit(2) + + class PlusGate(cirq.Gate): + def __init__(self, dimension, increment=1): + self.dimension = dimension + self.increment = increment % dimension + + def _qid_shape_(self): + return (self.dimension,) + + def _unitary_(self): + inc = (self.increment - 1) % self.dimension + 1 + u = np.empty((self.dimension, self.dimension)) + u[inc:] = np.eye(self.dimension)[:-inc] + u[:inc] = np.eye(self.dimension)[-inc:] + return u + + for i in range(15): + digits = cirq.big_endian_int_to_digits(i, digit_count=2, base=(3, 5)) + circuit = cirq.Circuit( + PlusGate(3, digits[0]).on(q0), + PlusGate(5, digits[1]).on(q1), + cirq.measure(q0, q1, key='m'), + cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m % 4 <= 1')), + cirq.measure(q_result, key='m_result'), + ) + + result = cirq.Simulator().run(circuit) + assert result.measurements['m_result'][0][0] == (i % 4 <= 1) + + def test_sympy_path_prefix(): q = cirq.LineQubit(0) op = cirq.X(q).with_classical_controls(sympy.Symbol('b')) diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json new file mode 100644 index 00000000000..d5c51d5839c --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json @@ -0,0 +1,60 @@ +{ + "cirq_type": "ClassicalDataDictionaryStore", + "measurements": [ + [ + { + "cirq_type": "MeasurementKey", + "name": "m", + "path": [] + }, + [0, 1] + ] + ], + "measured_qubits": [ + [ + { + "cirq_type": "MeasurementKey", + "name": "m", + "path": [] + }, + [ + { + "cirq_type": "LineQubit", + "x": 0 + }, + { + "cirq_type": "LineQubit", + "x": 1 + } + ] + ] + ], + "channel_measurements": [ + [ + { + "cirq_type": "MeasurementKey", + "name": "c", + "path": [] + }, + 3 + ] + ], + "measurement_types": [ + [ + { + "cirq_type": "MeasurementKey", + "name": "m", + "path": [] + }, + 1 + ], + [ + { + "cirq_type": "MeasurementKey", + "name": "c", + "path": [] + }, + 2 + ] + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr new file mode 100644 index 00000000000..c19b8190bfb --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr @@ -0,0 +1 @@ +cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [0, 1]}, _measured_qubits={cirq.MeasurementKey('m'): [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={cirq.MeasurementKey('c'): 3}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/MeasurementType.json b/cirq-core/cirq/protocols/json_test_data/MeasurementType.json new file mode 100644 index 00000000000..fd8ef095787 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/MeasurementType.json @@ -0,0 +1 @@ +[1, 2] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/MeasurementType.repr b/cirq-core/cirq/protocols/json_test_data/MeasurementType.repr new file mode 100644 index 00000000000..edeebfddc51 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/MeasurementType.repr @@ -0,0 +1 @@ +[cirq.MeasurementType.MEASUREMENT, cirq.MeasurementType.CHANNEL] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 333540b62ca..20b6fa29695 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -20,9 +20,6 @@ from cirq import value from cirq._doc import doc_private -if TYPE_CHECKING: - import cirq - if TYPE_CHECKING: import cirq diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index f8d1d955b8d..8694067131c 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -31,7 +31,7 @@ import numpy as np -from cirq import protocols, ops +from cirq import ops, protocols, value from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits from cirq.sim.operation_target import OperationTarget @@ -50,6 +50,7 @@ def __init__( qubits: Optional[Sequence['cirq.Qid']] = None, log_of_measurement_results: Optional[Dict[str, List[int]]] = None, ignore_measurement_results: bool = False, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Inits ActOnArgs. @@ -65,16 +66,21 @@ def __init__( will treat measurement as dephasing instead of collapsing process, and not log the result. This is only applicable to simulators that can represent mixed states. + classical_data: The shared classical data container for this + simulation. """ if prng is None: prng = cast(np.random.RandomState, np.random) if qubits is None: qubits = () - if log_of_measurement_results is None: - log_of_measurement_results = {} self._set_qubits(qubits) self.prng = prng - self._log_of_measurement_results = log_of_measurement_results + self._classical_data = classical_data or value.ClassicalDataDictionaryStore( + _measurements={ + value.MeasurementKey.parse_serialized(k): tuple(v) + for k, v in (log_of_measurement_results or {}).items() + } + ) self._ignore_measurement_results = ignore_measurement_results def _set_qubits(self, qubits: Sequence['cirq.Qid']): @@ -103,9 +109,9 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[ return bits = self._perform_measurement(qubits) corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)] - if key in self._log_of_measurement_results: - raise ValueError(f"Measurement already logged to key {key!r}") - self._log_of_measurement_results[key] = corrected + self._classical_data.record_measurement( + value.MeasurementKey.parse_serialized(key), corrected, qubits + ) def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]: return [self.qubit_map[q] for q in qubits] @@ -138,7 +144,7 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: DeprecationWarning, ) self._on_copy(args) - args._log_of_measurement_results = self.log_of_measurement_results.copy() + args._classical_data = self._classical_data.copy() return args def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True): @@ -236,8 +242,8 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ functionality, if supported.""" @property - def log_of_measurement_results(self) -> Dict[str, List[int]]: - return self._log_of_measurement_results + def classical_data(self) -> 'cirq.ClassicalDataStoreReader': + return self._classical_data @property def ignore_measurement_results(self) -> bool: diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index d5ee3d6dbc4..b00960b2b49 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -12,25 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import abc import inspect +import warnings +from collections import abc from typing import ( Dict, - TYPE_CHECKING, Generic, - Sequence, - Optional, Iterator, - Any, - Tuple, List, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, Union, ) -import warnings import numpy as np -from cirq import ops, protocols +from cirq import ops, protocols, value from cirq.sim.operation_target import OperationTarget from cirq.sim.simulator import ( TActOnArgs, @@ -52,7 +51,8 @@ def __init__( args: Dict[Optional['cirq.Qid'], TActOnArgs], qubits: Sequence['cirq.Qid'], split_untangled_states: bool, - log_of_measurement_results: Dict[str, Any], + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Initializes the class. @@ -65,11 +65,18 @@ def __init__( at the end. log_of_measurement_results: A mutable object that measurements are being recorded into. + classical_data: The shared classical data container for this + simulation. """ self.args = args self._qubits = tuple(qubits) self.split_untangled_states = split_untangled_states - self._log_of_measurement_results = log_of_measurement_results + self._classical_data = classical_data or value.ClassicalDataDictionaryStore( + _measurements={ + value.MeasurementKey.parse_serialized(k): tuple(v) + for k, v in (log_of_measurement_results or {}).items() + } + ) def create_merged_state(self) -> TActOnArgs: if not self.split_untangled_states: @@ -135,7 +142,7 @@ def _act_on_fallback_( return True def copy(self, deep_copy_buffers: bool = True) -> 'cirq.ActOnArgsContainer[TActOnArgs]': - logs = self.log_of_measurement_results.copy() + classical_data = self._classical_data.copy() copies = {} for act_on_args in set(self.args.values()): if 'deep_copy_buffers' in inspect.signature(act_on_args.copy).parameters: @@ -150,17 +157,19 @@ def copy(self, deep_copy_buffers: bool = True) -> 'cirq.ActOnArgsContainer[TActO ) copies[act_on_args] = act_on_args.copy() for copy in copies.values(): - copy._log_of_measurement_results = logs + copy._classical_data = classical_data args = {q: copies[a] for q, a in self.args.items()} - return ActOnArgsContainer(args, self.qubits, self.split_untangled_states, logs) + return ActOnArgsContainer( + args, self.qubits, self.split_untangled_states, classical_data=classical_data + ) @property def qubits(self) -> Tuple['cirq.Qid', ...]: return self._qubits @property - def log_of_measurement_results(self) -> Dict[str, Any]: - return self._log_of_measurement_results + def classical_data(self) -> 'cirq.ClassicalDataStoreReader': + return self._classical_data def sample( self, diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index 1660b2ca652..35ed79de68d 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -45,11 +45,12 @@ def __init__( available_buffer: Optional[List[np.ndarray]] = None, qid_shape: Optional[Tuple[int, ...]] = None, prng: Optional[np.random.RandomState] = None, - log_of_measurement_results: Optional[Dict[str, Any]] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, qubits: Optional[Sequence['cirq.Qid']] = None, ignore_measurement_results: bool = False, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, dtype: Type[np.number] = np.complex64, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Inits ActOnDensityMatrixArgs. @@ -78,12 +79,20 @@ def __init__( dtype: The `numpy.dtype` of the inferred state vector. One of `numpy.complex64` or `numpy.complex128`. Only used when `target_tenson` is None. + classical_data: The shared classical data container for this + simulation. Raises: ValueError: The dimension of `target_tensor` is not divisible by 2 and `qid_shape` is not provided. """ - super().__init__(prng, qubits, log_of_measurement_results, ignore_measurement_results) + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + ignore_measurement_results=ignore_measurement_results, + classical_data=classical_data, + ) if target_tensor is None: qubits_qid_shape = protocols.qid_shape(self.qubits) initial_matrix = qis.to_valid_density_matrix( diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 06f07ed6e37..a1c2618e66f 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -48,10 +48,11 @@ def __init__( target_tensor: Optional[np.ndarray] = None, available_buffer: Optional[np.ndarray] = None, prng: Optional[np.random.RandomState] = None, - log_of_measurement_results: Optional[Dict[str, Any]] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, qubits: Optional[Sequence['cirq.Qid']] = None, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, dtype: Type[np.number] = np.complex64, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Inits ActOnStateVectorArgs. @@ -76,8 +77,15 @@ def __init__( dtype: The `numpy.dtype` of the inferred state vector. One of `numpy.complex64` or `numpy.complex128`. Only used when `target_tenson` is None. + classical_data: The shared classical data container for this + simulation. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + classical_data=classical_data, + ) if target_tensor is None: qid_shape = protocols.qid_shape(self.qubits) state = qis.to_valid_state_vector( @@ -304,7 +312,7 @@ def _strat_act_on_state_vector_from_mixture( args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) - args.log_of_measurement_results[key] = [index] + args._classical_data.record_channel_measurement(key, index) return True @@ -353,5 +361,5 @@ def prepare_into_buffer(k: int): args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) - args.log_of_measurement_results[key] = [index] + args._classical_data.record_channel_measurement(key, index) return True diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 9629e7a7a9b..f2be23883ad 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -14,7 +14,7 @@ """A protocol for implementing high performance clifford tableau evolutions for Clifford Simulator.""" -from typing import Any, Dict, TYPE_CHECKING, List, Sequence +from typing import Dict, List, Optional, Sequence, TYPE_CHECKING import numpy as np @@ -32,9 +32,10 @@ class ActOnCliffordTableauArgs(ActOnStabilizerArgs): def __init__( self, tableau: 'cirq.CliffordTableau', - prng: np.random.RandomState, - log_of_measurement_results: Dict[str, Any], - qubits: Sequence['cirq.Qid'] = None, + prng: Optional[np.random.RandomState] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, + qubits: Optional[Sequence['cirq.Qid']] = None, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Inits ActOnCliffordTableauArgs. @@ -48,8 +49,15 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. + classical_data: The shared classical data container for this + simulation. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + classical_data=classical_data, + ) self.tableau = tableau def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 1c9e2c25b26..c84f0020406 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, TYPE_CHECKING, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, TYPE_CHECKING, Union import numpy as np @@ -42,6 +42,7 @@ def __init__( log_of_measurement_results: Optional[Dict[str, Any]] = None, qubits: Optional[Sequence['cirq.Qid']] = None, initial_state: Union[int, 'cirq.StabilizerStateChForm'] = 0, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Initializes with the given state and the axes for the operation. @@ -58,8 +59,15 @@ def __init__( initial_state: The initial state for the simulation. This can be a full CH form passed by reference which will be modified inplace, or a big-endian int in the computational basis. + classical_data: The shared classical data container for this + simulation. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + classical_data=classical_data, + ) initial_state = state or initial_state if isinstance(initial_state, int): qubit_map = {q: i for i, q in enumerate(self.qubits)} @@ -92,19 +100,19 @@ def sample( repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: - measurements: Dict[str, List[np.ndarray]] = {} + measurements = value.ClassicalDataDictionaryStore() prng = value.parse_random_state(seed) for i in range(repetitions): op = ops.measure(*qubits, key=str(i)) state = self.state.copy() ch_form_args = ActOnStabilizerCHFormArgs( + classical_data=measurements, prng=prng, - log_of_measurement_results=measurements, qubits=self.qubits, initial_state=state, ) protocols.act_on(op, ch_form_args) - return np.array(list(measurements.values()), dtype=bool) + return np.array(list(measurements.measurements.values()), dtype=bool) def _x(self, g: common_gates.XPowGate, axis: int): exponent = g.exponent diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 746758cfb95..e04559272d7 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -77,7 +77,7 @@ def _create_partial_act_on_args( self, initial_state: Union[int, 'cirq.ActOnStabilizerCHFormArgs'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: 'cirq.ClassicalDataStore', ) -> 'cirq.ActOnStabilizerCHFormArgs': """Creates the ActOnStabilizerChFormArgs for a circuit. @@ -88,6 +88,8 @@ def _create_partial_act_on_args( is often used in specifying the initial state, i.e. the ordering of the computational basis states. logs: A log of the results of measurement that is added to. + classical_data: The shared classical data container for this + simulation. Returns: ActOnStabilizerChFormArgs for the circuit. @@ -97,7 +99,7 @@ def _create_partial_act_on_args( return clifford.ActOnStabilizerCHFormArgs( prng=self._prng, - log_of_measurement_results=logs, + classical_data=classical_data, qubits=qubits, initial_state=initial_state, ) @@ -254,7 +256,6 @@ def state_vector(self): def apply_unitary(self, op: 'cirq.Operation'): ch_form_args = clifford.ActOnStabilizerCHFormArgs( prng=np.random.RandomState(), - log_of_measurement_results={}, qubits=self.qubit_map.keys(), initial_state=self.ch_form, ) @@ -284,10 +285,12 @@ def apply_measurement( else: state = self.copy() + classical_data = value.ClassicalDataDictionaryStore() ch_form_args = clifford.ActOnStabilizerCHFormArgs( prng=prng, - log_of_measurement_results=measurements, + classical_data=classical_data, qubits=self.qubit_map.keys(), initial_state=state.ch_form, ) act_on(op, ch_form_args) + measurements.update({str(k): list(v) for k, v in classical_data.measurements.items()}) diff --git a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py index d5c2d72fc0d..6f1dc38b4bd 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py @@ -64,14 +64,15 @@ def test_run(): ) for _ in range(10): state = cirq.StabilizerStateChForm(num_qubits=3) - measurements = {} + classical_data = cirq.ClassicalDataDictionaryStore() for op in circuit.all_operations(): args = cirq.ActOnStabilizerCHFormArgs( qubits=list(circuit.all_qubits()), prng=np.random.RandomState(), - log_of_measurement_results=measurements, + classical_data=classical_data, initial_state=state, ) cirq.act_on(op, args) + measurements = {str(k): list(v) for k, v in classical_data.measurements.items()} assert measurements['1'] == [1] assert measurements['0'] != measurements['2'] diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index b10875f2529..9408016e9f4 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -176,7 +176,7 @@ def _create_partial_act_on_args( self, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.ActOnDensityMatrixArgs'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: 'cirq.ClassicalDataStore', ) -> 'cirq.ActOnDensityMatrixArgs': """Creates the ActOnDensityMatrixArgs for a circuit. @@ -186,7 +186,8 @@ def _create_partial_act_on_args( qubits: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - logs: The log of measurement results that is added into. + classical_data: The shared classical data container for this + simulation. Returns: ActOnDensityMatrixArgs for the circuit. @@ -197,7 +198,7 @@ def _create_partial_act_on_args( return act_on_density_matrix_args.ActOnDensityMatrixArgs( qubits=qubits, prng=self._prng, - log_of_measurement_results=logs, + classical_data=classical_data, ignore_measurement_results=self._ignore_measurement_results, initial_state=initial_state, dtype=self._dtype, diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index 3b208c3e33a..e54916303cd 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -14,16 +14,16 @@ """An interface for quantum states as targets for operations.""" import abc from typing import ( - TypeVar, - TYPE_CHECKING, - Generic, - Dict, Any, - Tuple, - Optional, + Dict, + Generic, Iterator, List, + Optional, Sequence, + Tuple, + TYPE_CHECKING, + TypeVar, Union, ) @@ -86,9 +86,14 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: """Gets the qubit order maintained by this target.""" @property - @abc.abstractmethod def log_of_measurement_results(self) -> Dict[str, Any]: """Gets the log of measurement results.""" + return {str(k): list(self.classical_data.get_digits(k)) for k in self.classical_data.keys()} + + @property + @abc.abstractmethod + def classical_data(self) -> 'cirq.ClassicalDataStoreReader': + """The shared classical data container for this simulation..""" @abc.abstractmethod def sample( diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 29a24d2dfb3..610e369611c 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -17,6 +17,7 @@ import abc import collections import inspect +import warnings from typing import ( Any, Dict, @@ -31,7 +32,6 @@ Optional, TypeVar, ) -import warnings import numpy as np @@ -126,7 +126,7 @@ def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: 'cirq.ClassicalDataStore', ) -> TActOnArgs: """Creates an instance of the TActOnArgs class for the simulator. @@ -137,8 +137,8 @@ def _create_partial_act_on_args( understood to be a pure state. Other state representations are simulator-dependent. qubits: The sequence of qubits to represent. - logs: The structure to hold measurement logs. A single instance - should be shared among all ActOnArgs within the simulation. + classical_data: The shared classical data container for this + simulation. """ @abc.abstractmethod @@ -352,7 +352,7 @@ def _create_act_on_args( if isinstance(initial_state, OperationTarget): return initial_state - log: Dict[str, Any] = {} + classical_data = value.ClassicalDataDictionaryStore() if self._split_untangled_states: args_map: Dict[Optional['cirq.Qid'], TActOnArgs] = {} if isinstance(initial_state, int): @@ -360,24 +360,26 @@ def _create_act_on_args( args_map[q] = self._create_partial_act_on_args( initial_state=initial_state % q.dimension, qubits=[q], - logs=log, + classical_data=classical_data, ) initial_state = int(initial_state / q.dimension) else: args = self._create_partial_act_on_args( initial_state=initial_state, qubits=qubits, - logs=log, + classical_data=classical_data, ) for q in qubits: args_map[q] = args - args_map[None] = self._create_partial_act_on_args(0, (), log) - return ActOnArgsContainer(args_map, qubits, self._split_untangled_states, log) + args_map[None] = self._create_partial_act_on_args(0, (), classical_data) + return ActOnArgsContainer( + args_map, qubits, self._split_untangled_states, classical_data=classical_data + ) else: return self._create_partial_act_on_args( initial_state=initial_state, qubits=qubits, - logs=log, + classical_data=classical_data, ) diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index d848e2012d3..a99527f3722 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -25,10 +25,10 @@ class CountingActOnArgs(cirq.ActOnArgs): gate_count = 0 measurement_count = 0 - def __init__(self, state, qubits, logs): + def __init__(self, state, qubits, classical_data): super().__init__( qubits=qubits, - log_of_measurement_results=logs, + classical_data=classical_data, ) self.state = state @@ -39,7 +39,7 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: def copy(self, deep_copy_buffers: bool = True) -> 'CountingActOnArgs': args = CountingActOnArgs( qubits=self.qubits, - logs=self.log_of_measurement_results.copy(), + classical_data=self.classical_data.copy(), state=self.state, ) args.gate_count = self.gate_count @@ -115,9 +115,9 @@ def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: cirq.ClassicalDataStore, ) -> CountingActOnArgs: - return CountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) + return CountingActOnArgs(qubits=qubits, state=initial_state, classical_data=classical_data) def _create_simulator_trial_result( self, @@ -145,9 +145,11 @@ def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: cirq.ClassicalDataStore, ) -> CountingActOnArgs: - return SplittableCountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) + return SplittableCountingActOnArgs( + qubits=qubits, state=initial_state, classical_data=classical_data + ) q0, q1 = cirq.LineQubit.range(2) @@ -270,9 +272,11 @@ def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: cirq.ClassicalDataStore, ) -> MockCountingActOnArgs: - return MockCountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) + return MockCountingActOnArgs( + qubits=qubits, state=initial_state, classical_data=classical_data + ) def _create_simulator_trial_result( self, diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index d722970509d..7da4d31bab5 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -16,7 +16,6 @@ from typing import ( Any, - Dict, Iterator, List, Type, @@ -175,7 +174,7 @@ def _create_partial_act_on_args( self, initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.ActOnStateVectorArgs'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: 'cirq.ClassicalDataStore', ): """Creates the ActOnStateVectorArgs for a circuit. @@ -185,7 +184,8 @@ def _create_partial_act_on_args( qubits: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - logs: Log of the measurement results. + classical_data: The shared classical data container for this + simulation. Returns: ActOnStateVectorArgs for the circuit. @@ -196,7 +196,7 @@ def _create_partial_act_on_args( return act_on_state_vector_args.ActOnStateVectorArgs( qubits=qubits, prng=self._prng, - log_of_measurement_results=logs, + classical_data=classical_data, initial_state=initial_state, dtype=self._dtype, ) diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index bbf81d71817..bd34876530c 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -25,6 +25,13 @@ chosen_angle_to_half_turns, ) +from cirq.value.classical_data import ( + ClassicalDataDictionaryStore, + ClassicalDataStore, + ClassicalDataStoreReader, + MeasurementType, +) + from cirq.value.condition import ( Condition, KeyCondition, diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py new file mode 100644 index 00000000000..5596be02efd --- /dev/null +++ b/cirq-core/cirq/value/classical_data.py @@ -0,0 +1,245 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import enum +from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, TypeVar + +from cirq.value import digits, value_equality_attr + +if TYPE_CHECKING: + import cirq + + +class MeasurementType(enum.IntEnum): + MEASUREMENT = 1 + CHANNEL = 2 + + def __repr__(self): + return f'cirq.{str(self)}' + + +TSelf = TypeVar('TSelf', bound='ClassicalDataStoreReader') + + +class ClassicalDataStoreReader(abc.ABC): + @abc.abstractmethod + def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + """Gets the measurement keys in the order they were stored.""" + + @abc.abstractmethod + def get_int(self, key: 'cirq.MeasurementKey') -> int: + """Gets the integer corresponding to the measurement. + + The integer is determined by summing the qubit-dimensional basis value + of each measured qubit. For example, if the measurement of qubits + [q1, q0] produces [1, 0], then the corresponding integer is 2, the big- + endian equivalent. If they are qutrits and the measurement is [2, 1], + then the integer is 2 * 3 + 1 = 7. + + Args: + key: The measurement key. + + Raises: + KeyError: If the key has not been used. + """ + + @abc.abstractmethod + def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: + """Gets the values of the qubits that were measured into this key. + + For example, if the measurement of qubits [q0, q1] produces [0, 1], + this function will return (0, 1). + + Args: + key: The measurement key. + + Raises: + KeyError: If the key has not been used. + """ + + @abc.abstractmethod + def copy(self: TSelf) -> TSelf: + """Creates a copy of the object.""" + + +class ClassicalDataStore(ClassicalDataStoreReader, abc.ABC): + @abc.abstractmethod + def record_measurement( + self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] + ): + """Records a measurement. + + Args: + key: The measurement key to hold the measurement. + measurement: The measurement result. + qubits: The qubits that were measured. + + Raises: + ValueError: If the measurement shape does not match the qubits + measured or if the measurement key was already used. + """ + + @abc.abstractmethod + def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int): + """Records a channel measurement. + + Args: + key: The measurement key to hold the measurement. + measurement: The measurement result. + + Raises: + ValueError: If the measurement key was already used. + """ + + +@value_equality_attr.value_equality(unhashable=True) +class ClassicalDataDictionaryStore(ClassicalDataStore): + """Classical data representing measurements and metadata.""" + + def __init__( + self, + *, + _measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = None, + _measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = None, + _channel_measurements: Dict['cirq.MeasurementKey', int] = None, + _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] = None, + ): + """Initializes a `ClassicalDataDictionaryStore` object.""" + if not _measurement_types: + _measurement_types = {} + if _measurements: + _measurement_types.update( + {k: MeasurementType.MEASUREMENT for k, v in _measurements.items()} + ) + if _channel_measurements: + _measurement_types.update( + {k: MeasurementType.CHANNEL for k, v in _channel_measurements.items()} + ) + if _measurements is None: + _measurements = {} + if _measured_qubits is None: + _measured_qubits = {} + if _channel_measurements is None: + _channel_measurements = {} + self._measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = _measurements + self._measured_qubits: Dict[ + 'cirq.MeasurementKey', Tuple['cirq.Qid', ...] + ] = _measured_qubits + self._channel_measurements: Dict['cirq.MeasurementKey', int] = _channel_measurements + self._measurement_types: Dict[ + 'cirq.MeasurementKey', 'cirq.MeasurementType' + ] = _measurement_types + + @property + def measurements(self) -> Mapping['cirq.MeasurementKey', Tuple[int, ...]]: + """Gets the a mapping from measurement key to measurement.""" + return self._measurements + + @property + def channel_measurements(self) -> Mapping['cirq.MeasurementKey', int]: + """Gets the a mapping from measurement key to channel measurement.""" + return self._channel_measurements + + @property + def measured_qubits(self) -> Mapping['cirq.MeasurementKey', Tuple['cirq.Qid', ...]]: + """Gets the a mapping from measurement key to the qubits measured.""" + return self._measured_qubits + + @property + def measurement_types(self) -> Mapping['cirq.MeasurementKey', 'cirq.MeasurementType']: + """Gets the a mapping from measurement key to the measurement type.""" + return self._measurement_types + + def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + return tuple(self._measurement_types.keys()) + + def record_measurement( + self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] + ): + if len(measurement) != len(qubits): + raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') + if key in self._measurement_types: + raise ValueError(f"Measurement already logged to key {key}") + self._measurement_types[key] = MeasurementType.MEASUREMENT + self._measurements[key] = tuple(measurement) + self._measured_qubits[key] = tuple(qubits) + + def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int): + if key in self._measurement_types: + raise ValueError(f"Measurement already logged to key {key}") + self._measurement_types[key] = MeasurementType.CHANNEL + self._channel_measurements[key] = measurement + + def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: + return ( + self._measurements[key] + if self._measurement_types[key] == MeasurementType.MEASUREMENT + else (self._channel_measurements[key],) + ) + + def get_int(self, key: 'cirq.MeasurementKey') -> int: + if key not in self._measurement_types: + raise KeyError(f'The measurement key {key} is not in {self._measurements}') + measurement_type = self._measurement_types[key] + if measurement_type == MeasurementType.CHANNEL: + return self._channel_measurements[key] + if key not in self._measured_qubits: + return digits.big_endian_bits_to_int(self._measurements[key]) + return digits.big_endian_digits_to_int( + self._measurements[key], base=[q.dimension for q in self._measured_qubits[key]] + ) + + def copy(self): + return ClassicalDataDictionaryStore( + _measurements=self._measurements.copy(), + _measured_qubits=self._measured_qubits.copy(), + _channel_measurements=self._channel_measurements.copy(), + _measurement_types=self._measurement_types.copy(), + ) + + def _json_dict_(self): + return { + 'measurements': list(self.measurements.items()), + 'measured_qubits': list(self.measured_qubits.items()), + 'channel_measurements': list(self.channel_measurements.items()), + 'measurement_types': list(self.measurement_types.items()), + } + + @classmethod + def _from_json_dict_( + cls, measurements, measured_qubits, channel_measurements, measurement_types, **kwargs + ): + return cls( + _measurements=dict(measurements), + _measured_qubits=dict(measured_qubits), + _channel_measurements=dict(channel_measurements), + _measurement_types=dict(measurement_types), + ) + + def __repr__(self): + return ( + f'cirq.ClassicalDataDictionaryStore(_measurements={self.measurements!r},' + f' _measured_qubits={self.measured_qubits!r},' + f' _channel_measurements={self.channel_measurements!r},' + f' _measurement_types={self.measurement_types!r})' + ) + + def _value_equality_values_(self): + return ( + self._measurements, + self._channel_measurements, + self._measurement_types, + self._measured_qubits, + ) diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py new file mode 100644 index 00000000000..00cfe475d0e --- /dev/null +++ b/cirq-core/cirq/value/classical_data_test.py @@ -0,0 +1,136 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import cirq + +mkey_m = cirq.MeasurementKey('m') +mkey_c = cirq.MeasurementKey('c') +two_qubits = tuple(cirq.LineQubit.range(2)) + + +def test_init(): + cd = cirq.ClassicalDataDictionaryStore() + assert cd.measurements == {} + assert cd.keys() == () + assert cd.measured_qubits == {} + assert cd.channel_measurements == {} + assert cd.measurement_types == {} + cd = cirq.ClassicalDataDictionaryStore( + _measurements={mkey_m: (0, 1)}, + _measured_qubits={mkey_m: two_qubits}, + _channel_measurements={mkey_c: 3}, + ) + assert cd.measurements == {mkey_m: (0, 1)} + assert cd.keys() == (mkey_m, mkey_c) + assert cd.measured_qubits == {mkey_m: two_qubits} + assert cd.channel_measurements == {mkey_c: 3} + assert cd.measurement_types == { + mkey_m: cirq.MeasurementType.MEASUREMENT, + mkey_c: cirq.MeasurementType.CHANNEL, + } + + +def test_record_measurement(): + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (0, 1), two_qubits) + assert cd.measurements == {mkey_m: (0, 1)} + assert cd.keys() == (mkey_m,) + assert cd.measured_qubits == {mkey_m: two_qubits} + + +def test_record_measurement_errors(): + cd = cirq.ClassicalDataDictionaryStore() + with pytest.raises(ValueError, match='3 measurements but 2 qubits'): + cd.record_measurement(mkey_m, (0, 1, 2), two_qubits) + cd.record_measurement(mkey_m, (0, 1), two_qubits) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_measurement(mkey_m, (0, 1), two_qubits) + + +def test_record_channel_measurement(): + cd = cirq.ClassicalDataDictionaryStore() + cd.record_channel_measurement(mkey_m, 1) + assert cd.channel_measurements == {mkey_m: 1} + assert cd.keys() == (mkey_m,) + + +def test_record_channel_measurement_errors(): + cd = cirq.ClassicalDataDictionaryStore() + cd.record_channel_measurement(mkey_m, 1) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_channel_measurement(mkey_m, 1) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_measurement(mkey_m, (0, 1), two_qubits) + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (0, 1), two_qubits) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_channel_measurement(mkey_m, 1) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_measurement(mkey_m, (0, 1), two_qubits) + + +def test_get_int(): + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (0, 1), two_qubits) + assert cd.get_int(mkey_m) == 1 + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (1, 1), two_qubits) + assert cd.get_int(mkey_m) == 3 + cd = cirq.ClassicalDataDictionaryStore() + cd.record_channel_measurement(mkey_m, 1) + assert cd.get_int(mkey_m) == 1 + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (1, 1), (cirq.LineQid.range(2, dimension=3))) + assert cd.get_int(mkey_m) == 4 + cd = cirq.ClassicalDataDictionaryStore() + with pytest.raises(KeyError, match='The measurement key m is not in {}'): + cd.get_int(mkey_m) + + +def test_copy(): + cd = cirq.ClassicalDataDictionaryStore( + _measurements={mkey_m: (0, 1)}, + _measured_qubits={mkey_m: two_qubits}, + _channel_measurements={mkey_c: 3}, + _measurement_types={ + mkey_m: cirq.MeasurementType.MEASUREMENT, + mkey_c: cirq.MeasurementType.CHANNEL, + }, + ) + cd1 = cd.copy() + assert cd1 is not cd + assert cd1 == cd + assert cd1.measurements is not cd.measurements + assert cd1.measurements == cd.measurements + assert cd1.measured_qubits is not cd.measured_qubits + assert cd1.measured_qubits == cd.measured_qubits + assert cd1.channel_measurements is not cd.channel_measurements + assert cd1.channel_measurements == cd.channel_measurements + assert cd1.measurement_types is not cd.measurement_types + assert cd1.measurement_types == cd.measurement_types + + +def test_repr(): + cd = cirq.ClassicalDataDictionaryStore( + _measurements={mkey_m: (0, 1)}, + _measured_qubits={mkey_m: two_qubits}, + _channel_measurements={mkey_c: 3}, + _measurement_types={ + mkey_m: cirq.MeasurementType.MEASUREMENT, + mkey_c: cirq.MeasurementType.CHANNEL, + }, + ) + cirq.testing.assert_equivalent_repr(cd) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index ef432b7506f..7c594eb2d95 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -14,13 +14,13 @@ import abc import dataclasses -from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, FrozenSet +from typing import Dict, Tuple, TYPE_CHECKING, FrozenSet import sympy from cirq._compat import proper_repr from cirq.protocols import json_serialization, measurement_key_protocol as mkp -from cirq.value import digits, measurement_key +from cirq.value import measurement_key if TYPE_CHECKING: import cirq @@ -39,7 +39,10 @@ def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.Measure """Replaces the control keys.""" @abc.abstractmethod - def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + def resolve( + self, + classical_data: 'cirq.ClassicalDataStoreReader', + ) -> bool: """Resolves the condition based on the measurements.""" @property @@ -98,11 +101,13 @@ def __str__(self): def __repr__(self): return f'cirq.KeyCondition({self.key!r})' - def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: - key = str(self.key) - if key not in measurements: - raise ValueError(f'Measurement key {key} missing when testing classical control') - return any(measurements[key]) + def resolve( + self, + classical_data: 'cirq.ClassicalDataStoreReader', + ) -> bool: + if self.key not in classical_data.keys(): + raise ValueError(f'Measurement key {self.key} missing when testing classical control') + return classical_data.get_int(self.key) != 0 def _json_dict_(self): return json_serialization.dataclass_json_dict(self) @@ -143,15 +148,15 @@ def __str__(self): def __repr__(self): return f'cirq.SympyCondition({proper_repr(self.expr)})' - def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: - missing = [str(k) for k in self.keys if str(k) not in measurements] + def resolve( + self, + classical_data: 'cirq.ClassicalDataStoreReader', + ) -> bool: + missing = [str(k) for k in self.keys if k not in classical_data.keys()] if missing: raise ValueError(f'Measurement keys {missing} missing when testing classical control') - def value(k): - return digits.big_endian_bits_to_int(measurements[str(k)]) - - replacements = {str(k): value(k) for k in self.keys} + replacements = {str(k): classical_data.get_int(k) for k in self.keys} return bool(self.expr.subs(replacements)) def _json_dict_(self): diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index fd80033a29a..e92029b1bfb 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -42,22 +42,26 @@ def test_key_condition_repr(): def test_key_condition_resolve(): - assert init_key_condition.resolve({'0:a': [1]}) - assert init_key_condition.resolve({'0:a': [2]}) - assert init_key_condition.resolve({'0:a': [0, 1]}) - assert init_key_condition.resolve({'0:a': [1, 0]}) - assert not init_key_condition.resolve({'0:a': [0]}) - assert not init_key_condition.resolve({'0:a': [0, 0]}) - assert not init_key_condition.resolve({'0:a': []}) - assert not init_key_condition.resolve({'0:a': [0], 'b': [1]}) + def resolve(measurements): + classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) + return init_key_condition.resolve(classical_data) + + assert resolve({'0:a': [1]}) + assert resolve({'0:a': [2]}) + assert resolve({'0:a': [0, 1]}) + assert resolve({'0:a': [1, 0]}) + assert not resolve({'0:a': [0]}) + assert not resolve({'0:a': [0, 0]}) + assert not resolve({'0:a': []}) + assert not resolve({'0:a': [0], 'b': [1]}) with pytest.raises( ValueError, match='Measurement key 0:a missing when testing classical control' ): - _ = init_key_condition.resolve({}) + _ = resolve({}) with pytest.raises( ValueError, match='Measurement key 0:a missing when testing classical control' ): - _ = init_key_condition.resolve({'0:b': [1]}) + _ = resolve({'0:b': [1]}) def test_key_condition_qasm(): @@ -80,24 +84,28 @@ def test_sympy_condition_repr(): def test_sympy_condition_resolve(): - assert init_sympy_condition.resolve({'0:a': [1]}) - assert init_sympy_condition.resolve({'0:a': [2]}) - assert init_sympy_condition.resolve({'0:a': [0, 1]}) - assert init_sympy_condition.resolve({'0:a': [1, 0]}) - assert not init_sympy_condition.resolve({'0:a': [0]}) - assert not init_sympy_condition.resolve({'0:a': [0, 0]}) - assert not init_sympy_condition.resolve({'0:a': []}) - assert not init_sympy_condition.resolve({'0:a': [0], 'b': [1]}) + def resolve(measurements): + classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) + return init_sympy_condition.resolve(classical_data) + + assert resolve({'0:a': [1]}) + assert resolve({'0:a': [2]}) + assert resolve({'0:a': [0, 1]}) + assert resolve({'0:a': [1, 0]}) + assert not resolve({'0:a': [0]}) + assert not resolve({'0:a': [0, 0]}) + assert not resolve({'0:a': []}) + assert not resolve({'0:a': [0], 'b': [1]}) with pytest.raises( ValueError, match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), ): - _ = init_sympy_condition.resolve({}) + _ = resolve({}) with pytest.raises( ValueError, match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), ): - _ = init_sympy_condition.resolve({'0:b': [1]}) + _ = resolve({'0:b': [1]}) def test_sympy_condition_qasm(): diff --git a/cirq-core/cirq/value/measurement_key.py b/cirq-core/cirq/value/measurement_key.py index ee4c12bb051..e53eac47fde 100644 --- a/cirq-core/cirq/value/measurement_key.py +++ b/cirq-core/cirq/value/measurement_key.py @@ -77,6 +77,16 @@ def __hash__(self): object.__setattr__(self, '_hash', hash(str(self))) return self._hash + def __lt__(self, other): + if isinstance(other, MeasurementKey): + if self.path != other.path: + return self.path < other.path + return self.name < other.name + return NotImplemented + + def __le__(self, other): + return self == other or self < other + def _json_dict_(self): return { 'name': self.name, diff --git a/cirq-core/cirq/value/measurement_key_test.py b/cirq-core/cirq/value/measurement_key_test.py index c7f01de7d9a..e04a8be9c62 100644 --- a/cirq-core/cirq/value/measurement_key_test.py +++ b/cirq-core/cirq/value/measurement_key_test.py @@ -98,3 +98,22 @@ def test_with_measurement_key_mapping(): mkey3 = cirq.with_measurement_key_mapping(mkey3, {'new_key': 'newer_key'}) assert mkey3.name == 'newer_key' assert mkey3.path == ('a',) + + +def test_compare(): + assert cirq.MeasurementKey('a') < cirq.MeasurementKey('b') + assert cirq.MeasurementKey('a') <= cirq.MeasurementKey('b') + assert cirq.MeasurementKey('a') <= cirq.MeasurementKey('a') + assert cirq.MeasurementKey('b') > cirq.MeasurementKey('a') + assert cirq.MeasurementKey('b') >= cirq.MeasurementKey('a') + assert cirq.MeasurementKey('a') >= cirq.MeasurementKey('a') + assert not cirq.MeasurementKey('a') > cirq.MeasurementKey('b') + assert not cirq.MeasurementKey('a') >= cirq.MeasurementKey('b') + assert not cirq.MeasurementKey('b') < cirq.MeasurementKey('a') + assert not cirq.MeasurementKey('b') <= cirq.MeasurementKey('a') + assert cirq.MeasurementKey(path=(), name='b') < cirq.MeasurementKey(path=('0',), name='a') + assert cirq.MeasurementKey(path=('0',), name='n') < cirq.MeasurementKey(path=('1',), name='a') + with pytest.raises(TypeError): + _ = cirq.MeasurementKey('a') < 'b' + with pytest.raises(TypeError): + _ = cirq.MeasurementKey('a') <= 'b' diff --git a/cirq-google/cirq_google/calibration/engine_simulator.py b/cirq-google/cirq_google/calibration/engine_simulator.py index d2c383f059b..29bd90f0913 100644 --- a/cirq-google/cirq_google/calibration/engine_simulator.py +++ b/cirq-google/cirq_google/calibration/engine_simulator.py @@ -474,7 +474,7 @@ def _create_partial_act_on_args( self, initial_state: Union[int, cirq.ActOnStateVectorArgs], qubits: Sequence[cirq.Qid], - logs: Dict[str, Any], + classical_data: cirq.ClassicalDataStore, ) -> cirq.ActOnStateVectorArgs: # Needs an implementation since it's abstract but will never actually be called. raise NotImplementedError()