diff --git a/cirq-core/cirq/work/observable_grouping.py b/cirq-core/cirq/work/observable_grouping.py index 2b6ab0e8b01..74f8e392477 100644 --- a/cirq-core/cirq/work/observable_grouping.py +++ b/cirq-core/cirq/work/observable_grouping.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Dict, List, TYPE_CHECKING, cast +from typing import Iterable, Dict, List, TYPE_CHECKING, cast, Callable from cirq import ops, value from cirq.work.observable_settings import InitObsSetting, _max_weight_state, _max_weight_observable @@ -20,6 +20,8 @@ if TYPE_CHECKING: pass +GROUPER_T = Callable[[Iterable[InitObsSetting]], Dict[InitObsSetting, List[InitObsSetting]]] + def group_settings_greedy( settings: Iterable[InitObsSetting], diff --git a/cirq-core/cirq/work/observable_measurement.py b/cirq-core/cirq/work/observable_measurement.py index 6df8dcc5ed2..d880cef70a0 100644 --- a/cirq-core/cirq/work/observable_measurement.py +++ b/cirq-core/cirq/work/observable_measurement.py @@ -18,16 +18,33 @@ import os import tempfile import warnings -from typing import Optional, Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence +from typing import ( + Optional, + Union, + Iterable, + Dict, + List, + Tuple, + TYPE_CHECKING, + Set, + Sequence, + Any, +) import numpy as np +import pandas as pd import sympy -from cirq import circuits, study, ops, value +from cirq import circuits, study, ops, value, protocols from cirq._doc import document -from cirq.protocols import json_serializable_dataclass, to_json -from cirq.work.observable_measurement_data import BitstringAccumulator +from cirq.work.observable_grouping import group_settings_greedy, GROUPER_T +from cirq.work.observable_measurement_data import ( + BitstringAccumulator, + ObservableMeasuredResult, + flatten_grouped_results, +) from cirq.work.observable_settings import ( InitObsSetting, + observables_to_settings, _MeasurementSpec, ) @@ -47,7 +64,7 @@ def _with_parameterized_layers( - circuit: 'cirq.Circuit', + circuit: 'cirq.AbstractCircuit', qubits: Sequence['cirq.Qid'], needs_init_layer: bool, ) -> 'cirq.Circuit': @@ -69,9 +86,9 @@ def _with_parameterized_layers( meas_mom = ops.Moment([ops.measure(*qubits, key='z')]) if needs_init_layer: total_circuit = circuits.Circuit([x_beg_mom, y_beg_mom]) - total_circuit += circuit.copy() + total_circuit += circuit.unfreeze() else: - total_circuit = circuit.copy() + total_circuit = circuit.unfreeze() total_circuit.append([x_end_mom, y_end_mom, meas_mom]) return total_circuit @@ -89,7 +106,7 @@ def more_repetitions(self, accumulator: BitstringAccumulator) -> int: """ -@json_serializable_dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class VarianceStoppingCriteria(StoppingCriteria): """Stop sampling when average variance per term drops below a variance bound.""" @@ -111,8 +128,11 @@ def more_repetitions(self, accumulator: BitstringAccumulator) -> int: return 0 return self.repetitions_per_chunk + def _json_dict_(self): + return protocols.dataclass_json_dict(self) -@json_serializable_dataclass(frozen=True) + +@dataclasses.dataclass(frozen=True) class RepetitionsStoppingCriteria(StoppingCriteria): """Stop sampling when the number of repetitions has been reached.""" @@ -128,6 +148,9 @@ def more_repetitions(self, accumulator: BitstringAccumulator) -> int: to_do_next = min(self.repetitions_per_chunk, todo) return to_do_next + def _json_dict_(self): + return protocols.dataclass_json_dict(self) + _OBS_TO_PARAM_VAL: Dict[Tuple['cirq.Pauli', bool], Tuple[float, float]] = { (ops.X, False): (0, -1 / 2), @@ -170,6 +193,7 @@ def _get_params_for_setting( if we know that `setting.init_state` is the all-zeros state and `needs_init_layer` is False. """ + setting = _pad_setting(setting, qubits) params = {} for qubit, flip in itertools.zip_longest(qubits, flips): if qubit is None or flip is None: @@ -196,7 +220,7 @@ def _get_params_for_setting( def _pad_setting( max_setting: InitObsSetting, - qubits: List['cirq.Qid'], + qubits: Sequence['cirq.Qid'], pad_init_state_with=value.KET_ZERO, pad_obs_with: 'cirq.Gate' = ops.Z, ) -> InitObsSetting: @@ -411,6 +435,52 @@ def _parse_checkpoint_options( return checkpoint_fn, checkpoint_other_fn +@dataclasses.dataclass(frozen=True) +class CheckpointFileOptions: + """Options to configure "checkpointing" to save intermediate results. + + Args: + checkpoint: If set to True, save cumulative raw results at the end + of each iteration of the sampling loop. Load in these results + with `cirq.read_json`. + checkpoint_fn: The filename for the checkpoint file. If `checkpoint` + is set to True and this is not specified, a file in a temporary + directory will be used. + checkpoint_other_fn: The filename for another checkpoint file, which + contains the previous checkpoint. This lets us avoid losing data if + a failure occurs during checkpoint writing. If `checkpoint` + is set to True and this is not specified, a file in a temporary + directory will be used. If `checkpoint` is set to True and + `checkpoint_fn` is specified but this argument is *not* specified, + "{checkpoint_fn}.prev.json" will be used. + """ + + checkpoint: bool = False + checkpoint_fn: Optional[str] = None + checkpoint_other_fn: Optional[str] = None + + def __post_init__(self): + fn, other_fn = _parse_checkpoint_options( + self.checkpoint, self.checkpoint_fn, self.checkpoint_other_fn + ) + object.__setattr__(self, 'checkpoint_fn', fn) + object.__setattr__(self, 'checkpoint_other_fn', other_fn) + + def maybe_to_json(self, obj: Any): + """Call `cirq.to_json with `value` according to the configuration options in this class. + + If `checkpoint=False`, nothing will happen. Otherwise, we will use `checkpoint_fn` and + `checkpoint_other_fn` as the destination JSON file as described in the class docstring. + """ + if not self.checkpoint: + return + assert self.checkpoint_fn is not None, 'mypy' + assert self.checkpoint_other_fn is not None, 'mypy' + if os.path.exists(self.checkpoint_fn): + os.replace(self.checkpoint_fn, self.checkpoint_other_fn) + protocols.to_json(obj, self.checkpoint_fn) + + # pylint: enable=missing-raises-doc def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting]]) -> bool: """Helper function to go through init_states and determine if any of them need an @@ -424,17 +494,15 @@ def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting # TODO(#3388) Add documentation for Raises. # pylint: disable=missing-raises-doc def measure_grouped_settings( - circuit: 'cirq.Circuit', + circuit: 'cirq.AbstractCircuit', grouped_settings: Dict[InitObsSetting, List[InitObsSetting]], sampler: 'cirq.Sampler', stopping_criteria: StoppingCriteria, *, readout_symmetrization: bool = False, - circuit_sweep: 'cirq.study.sweepable.SweepLike' = None, + circuit_sweep: 'cirq.Sweepable' = None, readout_calibrations: Optional[BitstringAccumulator] = None, - checkpoint: bool = False, - checkpoint_fn: Optional[str] = None, - checkpoint_other_fn: Optional[str] = None, + checkpoint: CheckpointFileOptions = CheckpointFileOptions(), ) -> List[BitstringAccumulator]: """Measure a suite of grouped InitObsSetting settings. @@ -463,49 +531,29 @@ def measure_grouped_settings( in `circuit`. The total sweep is the product of the circuit sweep with parameter settings for the single-qubit basis-change rotations. readout_calibrations: The result of `calibrate_readout_error`. - checkpoint: If set to True, save cumulative raw results at the end - of each iteration of the sampling loop. Load in these results - with `cirq.read_json`. - checkpoint_fn: The filename for the checkpoint file. If `checkpoint` - is set to True and this is not specified, a file in a temporary - directory will be used. - checkpoint_other_fn: The filename for another checkpoint file, which - contains the previous checkpoint. This lets us avoid losing data if - a failure occurs during checkpoint writing. If `checkpoint` - is set to True and this is not specified, a file in a temporary - directory will be used. If `checkpoint` is set to True and - `checkpoint_fn` is specified but this argument is *not* specified, - "{checkpoint_fn}.prev.json" will be used. + checkpoint: Options to set up optional checkpointing of intermediate + data for each iteration of the sampling loop. See the documentation + for `CheckpointFileOptions` for more. Load in these results with + `cirq.read_json`. """ if readout_calibrations is not None and not readout_symmetrization: raise ValueError("Readout calibration only works if `readout_symmetrization` is enabled.") - checkpoint_fn, checkpoint_other_fn = _parse_checkpoint_options( - checkpoint=checkpoint, checkpoint_fn=checkpoint_fn, checkpoint_other_fn=checkpoint_other_fn - ) qubits = sorted({q for ms in grouped_settings.keys() for q in ms.init_state.qubits}) qubit_to_index = {q: i for i, q in enumerate(qubits)} needs_init_layer = _needs_init_layer(grouped_settings) measurement_param_circuit = _with_parameterized_layers(circuit, qubits, needs_init_layer) - grouped_settings = { - _pad_setting(max_setting, qubits): settings - for max_setting, settings in grouped_settings.items() - } - circuit_sweep = study.UnitSweep if circuit_sweep is None else study.to_sweep(circuit_sweep) # meas_spec provides a key for accumulators. # meas_specs_todo is a mutable list. We will pop things from it as various # specs are measured to the satisfaction of the stopping criteria accumulators = {} meas_specs_todo = [] - for max_setting, circuit_params in itertools.product( - grouped_settings.keys(), circuit_sweep.param_tuples() + for max_setting, param_resolver in itertools.product( + grouped_settings.keys(), study.to_resolvers(circuit_sweep) ): - # The type annotation for Param is just `Iterable`. - # We make sure that it's truly a tuple. - circuit_params = dict(circuit_params) - + circuit_params = param_resolver.param_dict meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params) accumulator = BitstringAccumulator( meas_spec=meas_spec, @@ -551,14 +599,131 @@ def measure_grouped_settings( bitstrings = np.logical_xor(flippy_ms.flips, result.measurements['z']) accumulator.consume_results(bitstrings.astype(np.uint8, casting='safe')) - if checkpoint: - assert checkpoint_fn is not None, 'mypy' - assert checkpoint_other_fn is not None, 'mypy' - if os.path.exists(checkpoint_fn): - os.replace(checkpoint_fn, checkpoint_other_fn) - to_json(list(accumulators.values()), checkpoint_fn) + checkpoint.maybe_to_json(list(accumulators.values())) return list(accumulators.values()) # pylint: enable=missing-raises-doc + + +_GROUPING_FUNCS: Dict[str, GROUPER_T] = { + 'greedy': group_settings_greedy, +} + + +def _parse_grouper(grouper: Union[str, GROUPER_T] = group_settings_greedy) -> GROUPER_T: + """Logic for turning a named grouper into one of the build-in groupers in support of the + high-level `measure_observables` API.""" + if isinstance(grouper, str): + try: + grouper = _GROUPING_FUNCS[grouper.lower()] + except KeyError: + raise ValueError(f"Unknown grouping function {grouper}") + return grouper + + +def _get_all_qubits( + circuit: 'cirq.AbstractCircuit', + observables: Iterable['cirq.PauliString'], +) -> List['cirq.Qid']: + """Helper function for `measure_observables` to get all qubits from a circuit and a + collection of observables.""" + qubit_set = set() + for obs in observables: + qubit_set |= set(obs.qubits) + qubit_set |= circuit.all_qubits() + return sorted(qubit_set) + + +def measure_observables( + circuit: 'cirq.AbstractCircuit', + observables: Iterable['cirq.PauliString'], + sampler: Union['cirq.Simulator', 'cirq.Sampler'], + stopping_criteria: StoppingCriteria, + *, + readout_symmetrization: bool = False, + circuit_sweep: Optional['cirq.Sweepable'] = None, + grouper: Union[str, GROUPER_T] = group_settings_greedy, + readout_calibrations: Optional[BitstringAccumulator] = None, + checkpoint: CheckpointFileOptions = CheckpointFileOptions(), +) -> List[ObservableMeasuredResult]: + """Measure a collection of PauliString observables for a state prepared by a Circuit. + + If you need more control over the process, please see `measure_grouped_settings` for a + lower-level API. If you would like your results returned as a pandas DataFrame, + please see `measure_observables_df`. + + Args: + circuit: The circuit used to prepare the state to measure. This can contain parameters, + in which case you should also specify `circuit_sweep`. + observables: A collection of PauliString observables to measure. These will be grouped + into simultaneously-measurable groups, see `grouper` argument. + sampler: The sampler. + stopping_criteria: A StoppingCriteria object to indicate how precisely to sample + measurements for estimating observables. + readout_symmetrization: If set to True, each run will be split into two: one normal and + one where a bit flip is incorporated prior to measurement. In the latter case, the + measured bit will be flipped back classically and accumulated together. This causes + readout error to appear symmetric, p(0|0) = p(1|1). + circuit_sweep: Additional parameter sweeps for parameters contained in `circuit`. The + total sweep is the product of the circuit sweep with parameter settings for the + single-qubit basis-change rotations. + grouper: Either "greedy" or a function that groups lists of `InitObsSetting`. See the + documentation for the `grouped_settings` argument of `measure_grouped_settings` for + full details. + readout_calibrations: The result of `calibrate_readout_error`. + checkpoint: Options to set up optional checkpointing of intermediate data for each + iteration of the sampling loop. See the documentation for `CheckpointFileOptions` for + more. Load in these results with `cirq.read_json`. + + Returns: + A list of ObservableMeasuredResult; one for each input PauliString. + """ + qubits = _get_all_qubits(circuit, observables) + settings = list(observables_to_settings(observables, qubits)) + actual_grouper = _parse_grouper(grouper) + grouped_settings = actual_grouper(settings) + + accumulators = measure_grouped_settings( + circuit=circuit, + grouped_settings=grouped_settings, + sampler=sampler, + stopping_criteria=stopping_criteria, + circuit_sweep=circuit_sweep, + readout_symmetrization=readout_symmetrization, + readout_calibrations=readout_calibrations, + checkpoint=checkpoint, + ) + return flatten_grouped_results(accumulators) + + +def measure_observables_df( + circuit: 'cirq.AbstractCircuit', + observables: Iterable['cirq.PauliString'], + sampler: Union['cirq.Simulator', 'cirq.Sampler'], + stopping_criteria: StoppingCriteria, + *, + readout_symmetrization: bool = False, + circuit_sweep: Optional['cirq.Sweepable'] = None, + grouper: Union[str, GROUPER_T] = group_settings_greedy, + readout_calibrations: Optional[BitstringAccumulator] = None, + checkpoint: CheckpointFileOptions = CheckpointFileOptions(), +): + """Measure observables and return resulting data as a Pandas dataframe. + + Please see `measure_observables` for argument documentation. + """ + results = measure_observables( + circuit=circuit, + observables=observables, + sampler=sampler, + stopping_criteria=stopping_criteria, + readout_symmetrization=readout_symmetrization, + circuit_sweep=circuit_sweep, + grouper=grouper, + readout_calibrations=readout_calibrations, + checkpoint=checkpoint, + ) + df = pd.DataFrame(res.as_dict() for res in results) + return df diff --git a/cirq-core/cirq/work/observable_measurement_data.py b/cirq-core/cirq/work/observable_measurement_data.py index fa9f431e1b1..61f96f9faf7 100644 --- a/cirq-core/cirq/work/observable_measurement_data.py +++ b/cirq-core/cirq/work/observable_measurement_data.py @@ -14,11 +14,10 @@ import dataclasses import datetime -from typing import Dict, List, Tuple, TYPE_CHECKING +from typing import Dict, List, Tuple, TYPE_CHECKING, Iterable, Any import numpy as np - -from cirq import protocols, ops +from cirq import ops, protocols from cirq._compat import proper_repr from cirq.work.observable_settings import ( InitObsSetting, @@ -81,12 +80,12 @@ def _stats_from_measurements( return obs_mean.item(), obs_err.item() -@protocols.json_serializable_dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class ObservableMeasuredResult: """The result of an observable measurement. - Please see `flatten_grouped_results` or `BitstringAccumulator.results` for information on how - to get these from `measure_observables` return values. + A list of these is returned by `measure_observables`, or see `flatten_grouped_results` for + transformation of `measure_grouped_settings` BitstringAccumulators into these objects. This is a flattened form of the contents of a `BitstringAccumulator` which may group many simultaneously-observable settings into one object. As such, `BitstringAccumulator` has more @@ -110,7 +109,7 @@ class ObservableMeasuredResult: def __repr__(self): # I wish we could use the default dataclass __repr__ but - # we need to prefix our class name with `cirq.work.`A + # we need to prefix our class name with `cirq.work.` return ( f'cirq.work.ObservableMeasuredResult(' f'setting={self.setting!r}, ' @@ -132,6 +131,25 @@ def observable(self): def stddev(self): return np.sqrt(self.variance) + def as_dict(self) -> Dict[str, Any]: + """Return the contents of this class as a dictionary. + + This makes records suitable for construction of a Pandas dataframe. The circuit parameters + are flattened into the top-level of this dictionary. + """ + record = dataclasses.asdict(self) + del record['circuit_params'] + del record['setting'] + record['init_state'] = self.init_state + record['observable'] = self.observable + + circuit_param_dict = {f'param.{k}': v for k, v in self.circuit_params.items()} + record.update(**circuit_param_dict) + return record + + def _json_dict_(self): + return protocols.dataclass_json_dict(self) + def _setting_to_z_observable(setting: InitObsSetting): qubits = setting.observable.qubits @@ -271,7 +289,7 @@ def n_repetitions(self): return len(self.bitstrings) @property - def results(self): + def results(self) -> Iterable[ObservableMeasuredResult]: """Yield individual setting results as `ObservableMeasuredResult` objects.""" for setting in self._simul_settings: @@ -291,10 +309,7 @@ def records(self): after chaining these results with those from other BitstringAccumulators. """ for result in self.results: - record = dataclasses.asdict(result) - del record['circuit_params'] - record.update(**self._meas_spec.circuit_params) - yield record + yield result.as_dict() def _json_dict_(self): from cirq.study.result import _pack_digits diff --git a/cirq-core/cirq/work/observable_measurement_data_test.py b/cirq-core/cirq/work/observable_measurement_data_test.py index cc947f443a6..9eb344e5f8b 100644 --- a/cirq-core/cirq/work/observable_measurement_data_test.py +++ b/cirq-core/cirq/work/observable_measurement_data_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import datetime import time @@ -90,7 +91,7 @@ def test_observable_measured_result(): mean=0, variance=5 ** 2, repetitions=4, - circuit_params={}, + circuit_params={'phi': 52}, ) assert omr.stddev == 5 assert omr.observable == cirq.Y(a) * cirq.Y(b) @@ -98,6 +99,33 @@ def test_observable_measured_result(): cirq.testing.assert_equivalent_repr(omr) + assert omr.as_dict() == { + 'init_state': cirq.Z(a) * cirq.Z(b), + 'observable': cirq.Y(a) * cirq.Y(b), + 'mean': 0, + 'variance': 25, + 'repetitions': 4, + 'param.phi': 52, + } + omr2 = dataclasses.replace( + omr, + circuit_params={ + 'phi': 52, + 'observable': 3.14, # this would be a bad but legal parameter name + 'param.phi': -1, + }, + ) + assert omr2.as_dict() == { + 'init_state': cirq.Z(a) * cirq.Z(b), + 'observable': cirq.Y(a) * cirq.Y(b), + 'mean': 0, + 'variance': 25, + 'repetitions': 4, + 'param.phi': 52, + 'param.observable': 3.14, + 'param.param.phi': -1, + } + @pytest.fixture() def example_bsa() -> 'cw.BitstringAccumulator': diff --git a/cirq-core/cirq/work/observable_measurement_test.py b/cirq-core/cirq/work/observable_measurement_test.py index 51c19093522..d76eb7a1f89 100644 --- a/cirq-core/cirq/work/observable_measurement_test.py +++ b/cirq-core/cirq/work/observable_measurement_test.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import tempfile +from typing import Iterable, Dict, List import numpy as np import pytest import cirq import cirq.work as cw -from cirq.work import _MeasurementSpec, BitstringAccumulator +from cirq.work import _MeasurementSpec, BitstringAccumulator, group_settings_greedy, InitObsSetting from cirq.work.observable_measurement import ( _with_parameterized_layers, _get_params_for_setting, @@ -28,6 +29,11 @@ _check_meas_specs_still_todo, StoppingCriteria, _parse_checkpoint_options, + measure_observables_df, + CheckpointFileOptions, + VarianceStoppingCriteria, + measure_observables, + RepetitionsStoppingCriteria, ) @@ -448,8 +454,7 @@ def test_measure_grouped_settings(with_circuit_sweep, checkpoint, tmpdir): sampler=cirq.Simulator(), stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500), circuit_sweep=ss, - checkpoint=checkpoint, - checkpoint_fn=checkpoint_fn, + checkpoint=CheckpointFileOptions(checkpoint=checkpoint, checkpoint_fn=checkpoint_fn), ) if with_circuit_sweep: for result in results: @@ -504,20 +509,91 @@ def test_measure_grouped_settings_read_checkpoint(tmpdir): grouped_settings=grouped_settings, sampler=cirq.Simulator(), stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500), - checkpoint=True, - checkpoint_fn=f'{tmpdir}/obs.json', - checkpoint_other_fn=f'{tmpdir}/obs.json', # Same filename + checkpoint=CheckpointFileOptions( + checkpoint=True, + checkpoint_fn=f'{tmpdir}/obs.json', + checkpoint_other_fn=f'{tmpdir}/obs.json', # Same filename + ), ) _ = cw.measure_grouped_settings( circuit=circuit, grouped_settings=grouped_settings, sampler=cirq.Simulator(), stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500), - checkpoint=True, - checkpoint_fn=f'{tmpdir}/obs.json', - checkpoint_other_fn=f'{tmpdir}/obs.prev.json', + checkpoint=CheckpointFileOptions( + checkpoint=True, + checkpoint_fn=f'{tmpdir}/obs.json', + checkpoint_other_fn=f'{tmpdir}/obs.prev.json', + ), ) results = cirq.read_json(f'{tmpdir}/obs.json') (result,) = results # one group assert result.n_repetitions == 1_000 assert result.means() == [1.0] + + +Q = cirq.NamedQubit('q') + + +@pytest.mark.parametrize( + ['circuit', 'observable'], + [ + (cirq.Circuit(cirq.X(Q) ** 0.2), cirq.Z(Q)), + (cirq.Circuit(cirq.X(Q) ** -0.5, cirq.Z(Q) ** 0.2), cirq.Y(Q)), + (cirq.Circuit(cirq.Y(Q) ** 0.5, cirq.Z(Q) ** 0.2), cirq.X(Q)), + ], +) +def test_XYZ_point8(circuit, observable): + # each circuit, observable combination should result in the observable value of 0.8 + df = measure_observables_df( + circuit, + [observable], + cirq.Simulator(seed=52), + stopping_criteria=VarianceStoppingCriteria(1e-3 ** 2), + ) + assert len(df) == 1, 'one observable' + mean = df.loc[0]['mean'] + np.testing.assert_allclose(0.8, mean, atol=1e-2) + + +def _each_in_its_own_group_grouper( + settings: Iterable[InitObsSetting], +) -> Dict[InitObsSetting, List[InitObsSetting]]: + return {setting: [setting] for setting in settings} + + +@pytest.mark.parametrize( + 'grouper', ['greedy', group_settings_greedy, _each_in_its_own_group_grouper] +) +def test_measure_observable_grouper(grouper): + circuit = cirq.Circuit(cirq.X(Q) ** 0.2) + observables = [ + cirq.Z(Q), + cirq.Z(cirq.NamedQubit('q2')), + ] + results = measure_observables( + circuit, + observables, + cirq.Simulator(seed=52), + stopping_criteria=RepetitionsStoppingCriteria(50_000), + grouper=grouper, + ) + assert len(results) == 2, 'two observables' + np.testing.assert_allclose(0.8, results[0].mean, atol=0.05) + np.testing.assert_allclose(1, results[1].mean, atol=1e-9) + + +def test_measure_observable_bad_grouper(): + circuit = cirq.Circuit(cirq.X(Q) ** 0.2) + observables = [ + cirq.Z(Q), + cirq.Z(cirq.NamedQubit('q2')), + ] + with pytest.raises(ValueError, match=r'Unknown grouping function'): + _ = measure_observables( + circuit, + observables, + cirq.Simulator(seed=52), + stopping_criteria=RepetitionsStoppingCriteria(50_000), + grouper='super fancy grouper', + ) diff --git a/cirq-core/cirq/work/observable_settings.py b/cirq-core/cirq/work/observable_settings.py index c5378342001..e135f336c68 100644 --- a/cirq-core/cirq/work/observable_settings.py +++ b/cirq-core/cirq/work/observable_settings.py @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Iterable, Dict, TYPE_CHECKING, Tuple +import dataclasses +from typing import Union, Iterable, Dict, TYPE_CHECKING, ItemsView, Tuple, FrozenSet -from cirq import ops, value +from cirq import ops, value, protocols if TYPE_CHECKING: import cirq from cirq.value.product_state import _NamedOneQubitState - # Workaround for mypy custom dataclasses - from dataclasses import dataclass as json_serializable_dataclass -else: - from cirq.protocols import json_serializable_dataclass - -@json_serializable_dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class InitObsSetting: """A pair of initial state and observable. @@ -59,6 +55,9 @@ def __repr__(self): f'observable={self.observable!r})' ) + def _json_dict_(self): + return protocols.dataclass_json_dict(self) + def _max_weight_observable(observables: Iterable[ops.PauliString]) -> Union[None, ops.PauliString]: """Create a new observable that is compatible with all input observables @@ -135,7 +134,9 @@ def _fix_precision(val: float, precision) -> int: return int(val * precision) -def _hashable_param(param_tuples: Iterable[Tuple[str, float]], precision=1e7): +def _hashable_param( + param_tuples: ItemsView[str, float], precision=1e7 +) -> FrozenSet[Tuple[str, float]]: """Hash circuit parameters using fixed precision. Circuit parameters can be floats but we also need to use them as @@ -144,7 +145,7 @@ def _hashable_param(param_tuples: Iterable[Tuple[str, float]], precision=1e7): return frozenset((k, _fix_precision(v, precision)) for k, v in param_tuples) -@json_serializable_dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class _MeasurementSpec: """An encapsulation of all the specifications for one run of a quantum processor. @@ -165,3 +166,6 @@ def __repr__(self): f'cirq.work._MeasurementSpec(max_setting={self.max_setting!r}, ' f'circuit_params={self.circuit_params!r})' ) + + def _json_dict_(self): + return protocols.dataclass_json_dict(self) diff --git a/cirq-core/cirq/work/sampler.py b/cirq-core/cirq/work/sampler.py index e718cd75471..6df64b9aaa9 100644 --- a/cirq-core/cirq/work/sampler.py +++ b/cirq-core/cirq/work/sampler.py @@ -13,12 +13,18 @@ # limitations under the License. """Abstract base class for things sampling quantum circuits.""" -from typing import List, Optional, Sequence, TYPE_CHECKING, Union import abc +from typing import List, Optional, TYPE_CHECKING, Union, Dict, FrozenSet, Tuple +from typing import Sequence import pandas as pd - -from cirq import study +from cirq import study, ops +from cirq.work.observable_measurement import ( + measure_observables, + RepetitionsStoppingCriteria, + CheckpointFileOptions, +) +from cirq.work.observable_settings import _hashable_param if TYPE_CHECKING: import cirq @@ -261,4 +267,94 @@ def run_batch( for circuit, params, repetitions in zip(programs, params_list, repetitions) ] - # pylint: enable=missing-raises-doc + def sample_expectation_values( + self, + program: 'cirq.AbstractCircuit', + observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']], + *, + num_samples: int, + params: 'cirq.Sweepable' = None, + permit_terminal_measurements: bool = False, + ) -> List[List[float]]: + """Calculates estimated expectation values from samples of a circuit. + + Please see also `cirq.work.measure_observables` for more control over how to measure + a suite of observables. + + This method can be run on any device or simulator that supports circuit sampling. Compare + with `simulate_expectation_values` in simulator.py, which is limited to simulators + but provides exact results. + + Args: + program: The circuit which prepares a state from which we sample expectation values. + observables: A list of observables for which to calculate expectation values. + num_samples: The number of samples to take. Increasing this value increases the + statistical accuracy of the estimate. + params: Parameters to run with the program. + permit_terminal_measurements: If the provided circuit ends in a measurement, this + method will generate an error unless this is set to True. This is meant to + prevent measurements from ruining expectation value calculations. + + Returns: + A list of expectation-value lists. The outer index determines the sweep, and the inner + index determines the observable. For instance, results[1][3] would select the fourth + observable measured in the second sweep. + """ + if num_samples <= 0: + raise ValueError( + f'Expectation values require at least one sample. Received: {num_samples}.' + ) + if not observables: + raise ValueError('At least one observable must be provided.') + if not permit_terminal_measurements and program.are_any_measurements_terminal(): + raise ValueError( + 'Provided circuit has terminal measurements, which may ' + 'skew expectation values. If this is intentional, set ' + 'permit_terminal_measurements=True.' + ) + + # Wrap input into a list of pauli sum + pauli_sums: List['cirq.PauliSum'] = ( + [ops.PauliSum.wrap(o) for o in observables] + if isinstance(observables, List) + else [ops.PauliSum.wrap(observables)] + ) + del observables + + # Flatten Pauli Sum into one big list of Pauli String + # Keep track of which Pauli Sum each one was from. + flat_pstrings: List['cirq.PauliString'] = [] + pstring_to_psum_i: Dict['cirq.PauliString', int] = {} + for psum_i, pauli_sum in enumerate(pauli_sums): + for pstring in pauli_sum: + flat_pstrings.append(pstring) + pstring_to_psum_i[pstring] = psum_i + + # Flatten Circuit Sweep into one big list of Params. + # Keep track of their indices so we can map back. + flat_params: List[Dict[str, float]] = [pr.param_dict for pr in study.to_resolvers(params)] + circuit_param_to_sweep_i: Dict[FrozenSet[Tuple[str, float]], int] = { + _hashable_param(param.items()): i for i, param in enumerate(flat_params) + } + + obs_meas_results = measure_observables( + circuit=program, + observables=flat_pstrings, + sampler=self, + stopping_criteria=RepetitionsStoppingCriteria(total_repetitions=num_samples), + readout_symmetrization=False, + circuit_sweep=params, + checkpoint=CheckpointFileOptions(checkpoint=False), + ) + + # Results are ordered by how they're grouped. Since we want the (circuit_sweep, pauli_sum) + # nesting structure, we place the measured values according to the back-mappings we set up + # above. We also do the sum operation to aggregate multiple PauliString measured values + # for a given PauliSum. + nested_results: List[List[float]] = [[0] * len(pauli_sums) for _ in range(len(flat_params))] + for res in obs_meas_results: + param_i = circuit_param_to_sweep_i[_hashable_param(res.circuit_params.items())] + psum_i = pstring_to_psum_i[res.setting.observable] + nested_results[param_i][psum_i] += res.mean + + return nested_results diff --git a/cirq-core/cirq/work/sampler_test.py b/cirq-core/cirq/work/sampler_test.py index 3531bcff9f3..334c0b7a46a 100644 --- a/cirq-core/cirq/work/sampler_test.py +++ b/cirq-core/cirq/work/sampler_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for cirq.Sampler.""" +from typing import List + import pytest import duet @@ -199,3 +201,123 @@ def test_sampler_run_batch_bad_input_lengths(): _ = sampler.run_batch( [circuit1, circuit2], params_list=[params1, params2], repetitions=[1, 2, 3] ) + + +def test_sampler_simple_sample_expectation_values(): + a = cirq.LineQubit(0) + sampler = cirq.Simulator() + circuit = cirq.Circuit(cirq.H(a)) + obs = cirq.X(a) + results = sampler.sample_expectation_values(circuit, [obs], num_samples=1000) + + assert np.allclose(results, [[1]]) + + +def test_sampler_sample_expectation_values_calculation(): + class DeterministicImbalancedStateSampler(cirq.Sampler): + """A simple, deterministic mock sampler. + Pretends to sample from a state vector with a 3:1 balance between the + probabilities of the |0) and |1) state. + """ + + def run_sweep( + self, + program: 'cirq.AbstractCircuit', + params: 'cirq.Sweepable', + repetitions: int = 1, + ) -> List['cirq.Result']: + results = np.zeros((repetitions, 1), dtype=bool) + for idx in range(repetitions // 4): + results[idx][0] = 1 + return [ + cirq.Result(params=pr, measurements={'z': results}) + for pr in cirq.study.to_resolvers(params) + ] + + a = cirq.LineQubit(0) + sampler = DeterministicImbalancedStateSampler() + # This circuit is not actually sampled, but the mock sampler above gives + # a reasonable approximation of it. + circuit = cirq.Circuit(cirq.X(a) ** (1 / 3)) + obs = cirq.Z(a) + results = sampler.sample_expectation_values(circuit, [obs], num_samples=1000) + + # (0.75 * 1) + (0.25 * -1) = 0.5 + assert np.allclose(results, [[0.5]]) + + +def test_sampler_sample_expectation_values_multi_param(): + a = cirq.LineQubit(0) + t = sympy.Symbol('t') + sampler = cirq.Simulator(seed=1) + circuit = cirq.Circuit(cirq.X(a) ** t) + obs = cirq.Z(a) + results = sampler.sample_expectation_values( + circuit, [obs], num_samples=5, params=cirq.Linspace('t', 0, 2, 3) + ) + + assert np.allclose(results, [[1], [-1], [1]]) + + +def test_sampler_sample_expectation_values_multi_qubit(): + q = cirq.LineQubit.range(3) + sampler = cirq.Simulator(seed=1) + circuit = cirq.Circuit(cirq.X(q[0]), cirq.X(q[1]), cirq.X(q[2])) + obs = cirq.Z(q[0]) + cirq.Z(q[1]) + cirq.Z(q[2]) + results = sampler.sample_expectation_values(circuit, [obs], num_samples=5) + + assert np.allclose(results, [[-3]]) + + +def test_sampler_sample_expectation_values_composite(): + # Tests multi-{param,qubit} sampling together in one circuit. + q = cirq.LineQubit.range(3) + t = [sympy.Symbol(f't{x}') for x in range(3)] + + sampler = cirq.Simulator(seed=1) + circuit = cirq.Circuit( + cirq.X(q[0]) ** t[0], + cirq.X(q[1]) ** t[1], + cirq.X(q[2]) ** t[2], + ) + + obs = [cirq.Z(q[x]) for x in range(3)] + # t0 is in the inner loop to make bit-ordering easier below. + params = ([{'t0': t0, 't1': t1, 't2': t2} for t2 in [0, 1] for t1 in [0, 1] for t0 in [0, 1]],) + results = sampler.sample_expectation_values( + circuit, + obs, + num_samples=5, + params=params, + ) + + assert len(results) == 8 + assert np.allclose( + results, + [ + [+1, +1, +1], + [-1, +1, +1], + [+1, -1, +1], + [-1, -1, +1], + [+1, +1, -1], + [-1, +1, -1], + [+1, -1, -1], + [-1, -1, -1], + ], + ) + + +def test_sampler_simple_sample_expectation_requirements(): + a = cirq.LineQubit(0) + sampler = cirq.Simulator(seed=1) + circuit = cirq.Circuit(cirq.H(a)) + obs = cirq.X(a) + with pytest.raises(ValueError, match='at least one sample'): + _ = sampler.sample_expectation_values(circuit, [obs], num_samples=0) + + with pytest.raises(ValueError, match='At least one observable'): + _ = sampler.sample_expectation_values(circuit, [], num_samples=1) + + circuit.append(cirq.measure(a, key='out')) + with pytest.raises(ValueError, match='permit_terminal_measurements'): + _ = sampler.sample_expectation_values(circuit, [obs], num_samples=1)