-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Obs] 4.5 - High-level API #4392
Changes from 10 commits
dcb59f3
f92fc65
0aa1e51
66fe36a
45e9d74
ea6bbfa
a65e143
500aa75
2b94b7b
b78d7ee
25cdc1c
deb9ea6
7cd0c12
6974b7a
6d42a61
307c58e
9bc0a66
c8e9bc1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,22 +18,42 @@ | |
import os | ||
import tempfile | ||
import warnings | ||
from typing import Optional, Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence | ||
from typing import ( | ||
Optional, | ||
Union, | ||
Iterable, | ||
Dict, | ||
List, | ||
Tuple, | ||
TYPE_CHECKING, | ||
Set, | ||
Sequence, | ||
Any, | ||
) | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import sympy | ||
from cirq import circuits, study, ops, value | ||
from cirq import circuits, study, ops, value, protocols | ||
from cirq._doc import document | ||
from cirq.protocols import json_serializable_dataclass, to_json | ||
from cirq.work.observable_measurement_data import BitstringAccumulator | ||
from cirq.work.observable_grouping import group_settings_greedy, GROUPER_T | ||
from cirq.work.observable_measurement_data import ( | ||
BitstringAccumulator, | ||
ObservableMeasuredResult, | ||
flatten_grouped_results, | ||
) | ||
from cirq.work.observable_settings import ( | ||
InitObsSetting, | ||
observables_to_settings, | ||
_MeasurementSpec, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
import cirq | ||
from cirq.value.product_state import _NamedOneQubitState | ||
from dataclasses import dataclass as json_serializable_dataclass | ||
else: | ||
from cirq.protocols import json_serializable_dataclass | ||
|
||
MAX_REPETITIONS_PER_JOB = 3_000_000 | ||
document( | ||
|
@@ -47,7 +67,7 @@ | |
|
||
|
||
def _with_parameterized_layers( | ||
circuit: 'cirq.Circuit', | ||
circuit: 'cirq.AbstractCircuit', | ||
qubits: Sequence['cirq.Qid'], | ||
needs_init_layer: bool, | ||
) -> 'cirq.Circuit': | ||
|
@@ -69,9 +89,9 @@ def _with_parameterized_layers( | |
meas_mom = ops.Moment([ops.measure(*qubits, key='z')]) | ||
if needs_init_layer: | ||
total_circuit = circuits.Circuit([x_beg_mom, y_beg_mom]) | ||
total_circuit += circuit.copy() | ||
total_circuit += circuit.unfreeze() | ||
else: | ||
total_circuit = circuit.copy() | ||
total_circuit = circuit.unfreeze() | ||
total_circuit.append([x_end_mom, y_end_mom, meas_mom]) | ||
return total_circuit | ||
|
||
|
@@ -170,6 +190,7 @@ def _get_params_for_setting( | |
if we know that `setting.init_state` is the all-zeros state and | ||
`needs_init_layer` is False. | ||
""" | ||
setting = _pad_setting(setting, qubits) | ||
params = {} | ||
for qubit, flip in itertools.zip_longest(qubits, flips): | ||
if qubit is None or flip is None: | ||
|
@@ -196,7 +217,7 @@ def _get_params_for_setting( | |
|
||
def _pad_setting( | ||
max_setting: InitObsSetting, | ||
qubits: List['cirq.Qid'], | ||
qubits: Sequence['cirq.Qid'], | ||
pad_init_state_with=value.KET_ZERO, | ||
pad_obs_with: 'cirq.Gate' = ops.Z, | ||
) -> InitObsSetting: | ||
|
@@ -411,6 +432,52 @@ def _parse_checkpoint_options( | |
return checkpoint_fn, checkpoint_other_fn | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class CheckpointFileOptions: | ||
"""Options to configure "checkpointing" to save intermediate results. | ||
|
||
Args: | ||
checkpoint: If set to True, save cumulative raw results at the end | ||
of each iteration of the sampling loop. Load in these results | ||
with `cirq.read_json`. | ||
checkpoint_fn: The filename for the checkpoint file. If `checkpoint` | ||
is set to True and this is not specified, a file in a temporary | ||
directory will be used. | ||
checkpoint_other_fn: The filename for another checkpoint file, which | ||
contains the previous checkpoint. This lets us avoid losing data if | ||
a failure occurs during checkpoint writing. If `checkpoint` | ||
is set to True and this is not specified, a file in a temporary | ||
directory will be used. If `checkpoint` is set to True and | ||
`checkpoint_fn` is specified but this argument is *not* specified, | ||
"{checkpoint_fn}.prev.json" will be used. | ||
""" | ||
|
||
checkpoint: bool = False | ||
checkpoint_fn: Optional[str] = None | ||
checkpoint_other_fn: Optional[str] = None | ||
|
||
def __post_init__(self): | ||
fn, other_fn = _parse_checkpoint_options( | ||
self.checkpoint, self.checkpoint_fn, self.checkpoint_other_fn | ||
) | ||
object.__setattr__(self, 'checkpoint_fn', fn) | ||
object.__setattr__(self, 'checkpoint_other_fn', other_fn) | ||
|
||
def maybe_to_json(self, obj: Any): | ||
"""Call `cirq.to_json with `value` according to the configuration options in this class. | ||
|
||
If `checkpoint=False`, nothing will happen. Otherwise, we will use `checkpoint_fn` and | ||
`checkpoint_other_fn` as the destination JSON file as described in the class docstring. | ||
""" | ||
if not self.checkpoint: | ||
return | ||
assert self.checkpoint_fn is not None, 'mypy' | ||
assert self.checkpoint_other_fn is not None, 'mypy' | ||
if os.path.exists(self.checkpoint_fn): | ||
os.replace(self.checkpoint_fn, self.checkpoint_other_fn) | ||
protocols.to_json(obj, self.checkpoint_fn) | ||
|
||
|
||
# pylint: enable=missing-raises-doc | ||
def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting]]) -> bool: | ||
"""Helper function to go through init_states and determine if any of them need an | ||
|
@@ -424,17 +491,15 @@ def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting | |
# TODO(#3388) Add documentation for Raises. | ||
# pylint: disable=missing-raises-doc | ||
def measure_grouped_settings( | ||
circuit: 'cirq.Circuit', | ||
circuit: 'cirq.AbstractCircuit', | ||
grouped_settings: Dict[InitObsSetting, List[InitObsSetting]], | ||
sampler: 'cirq.Sampler', | ||
stopping_criteria: StoppingCriteria, | ||
*, | ||
readout_symmetrization: bool = False, | ||
circuit_sweep: 'cirq.study.sweepable.SweepLike' = None, | ||
circuit_sweep: 'cirq.Sweepable' = None, | ||
readout_calibrations: Optional[BitstringAccumulator] = None, | ||
checkpoint: bool = False, | ||
checkpoint_fn: Optional[str] = None, | ||
checkpoint_other_fn: Optional[str] = None, | ||
checkpoint: CheckpointFileOptions = CheckpointFileOptions(), | ||
) -> List[BitstringAccumulator]: | ||
"""Measure a suite of grouped InitObsSetting settings. | ||
|
||
|
@@ -463,49 +528,29 @@ def measure_grouped_settings( | |
in `circuit`. The total sweep is the product of the circuit sweep | ||
with parameter settings for the single-qubit basis-change rotations. | ||
readout_calibrations: The result of `calibrate_readout_error`. | ||
checkpoint: If set to True, save cumulative raw results at the end | ||
of each iteration of the sampling loop. Load in these results | ||
with `cirq.read_json`. | ||
checkpoint_fn: The filename for the checkpoint file. If `checkpoint` | ||
is set to True and this is not specified, a file in a temporary | ||
directory will be used. | ||
checkpoint_other_fn: The filename for another checkpoint file, which | ||
contains the previous checkpoint. This lets us avoid losing data if | ||
a failure occurs during checkpoint writing. If `checkpoint` | ||
is set to True and this is not specified, a file in a temporary | ||
directory will be used. If `checkpoint` is set to True and | ||
`checkpoint_fn` is specified but this argument is *not* specified, | ||
"{checkpoint_fn}.prev.json" will be used. | ||
checkpoint: Options to set up optional checkpointing of intermediate | ||
data for each iteration of the sampling loop. See the documentation | ||
for `CheckpointFileOptions` for more. Load in these results with | ||
`cirq.read_json`. | ||
""" | ||
if readout_calibrations is not None and not readout_symmetrization: | ||
raise ValueError("Readout calibration only works if `readout_symmetrization` is enabled.") | ||
|
||
checkpoint_fn, checkpoint_other_fn = _parse_checkpoint_options( | ||
checkpoint=checkpoint, checkpoint_fn=checkpoint_fn, checkpoint_other_fn=checkpoint_other_fn | ||
) | ||
qubits = sorted({q for ms in grouped_settings.keys() for q in ms.init_state.qubits}) | ||
qubit_to_index = {q: i for i, q in enumerate(qubits)} | ||
|
||
needs_init_layer = _needs_init_layer(grouped_settings) | ||
measurement_param_circuit = _with_parameterized_layers(circuit, qubits, needs_init_layer) | ||
grouped_settings = { | ||
_pad_setting(max_setting, qubits): settings | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the call to |
||
for max_setting, settings in grouped_settings.items() | ||
} | ||
circuit_sweep = study.UnitSweep if circuit_sweep is None else study.to_sweep(circuit_sweep) | ||
|
||
# meas_spec provides a key for accumulators. | ||
# meas_specs_todo is a mutable list. We will pop things from it as various | ||
# specs are measured to the satisfaction of the stopping criteria | ||
accumulators = {} | ||
meas_specs_todo = [] | ||
for max_setting, circuit_params in itertools.product( | ||
grouped_settings.keys(), circuit_sweep.param_tuples() | ||
for max_setting, param_resolver in itertools.product( | ||
grouped_settings.keys(), study.to_resolvers(circuit_sweep) | ||
): | ||
# The type annotation for Param is just `Iterable`. | ||
# We make sure that it's truly a tuple. | ||
circuit_params = dict(circuit_params) | ||
|
||
circuit_params = param_resolver.param_dict | ||
meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params) | ||
accumulator = BitstringAccumulator( | ||
meas_spec=meas_spec, | ||
|
@@ -551,14 +596,131 @@ def measure_grouped_settings( | |
bitstrings = np.logical_xor(flippy_ms.flips, result.measurements['z']) | ||
accumulator.consume_results(bitstrings.astype(np.uint8, casting='safe')) | ||
|
||
if checkpoint: | ||
assert checkpoint_fn is not None, 'mypy' | ||
assert checkpoint_other_fn is not None, 'mypy' | ||
if os.path.exists(checkpoint_fn): | ||
os.replace(checkpoint_fn, checkpoint_other_fn) | ||
to_json(list(accumulators.values()), checkpoint_fn) | ||
checkpoint.maybe_to_json(list(accumulators.values())) | ||
|
||
return list(accumulators.values()) | ||
|
||
|
||
# pylint: enable=missing-raises-doc | ||
|
||
|
||
_GROUPING_FUNCS: Dict[str, GROUPER_T] = { | ||
'greedy': group_settings_greedy, | ||
} | ||
|
||
|
||
def _parse_grouper(grouper: Union[str, GROUPER_T] = group_settings_greedy) -> GROUPER_T: | ||
"""Logic for turning a named grouper into one of the build-in groupers in support of the | ||
high-level `measure_observables` API.""" | ||
if isinstance(grouper, str): | ||
try: | ||
grouper = _GROUPING_FUNCS[grouper.lower()] | ||
except KeyError: | ||
raise ValueError(f"Unknown grouping function {grouper}") | ||
return grouper | ||
|
||
|
||
def _get_all_qubits( | ||
circuit: 'cirq.AbstractCircuit', | ||
observables: Iterable['cirq.PauliString'], | ||
) -> List['cirq.Qid']: | ||
"""Helper function for `measure_observables` to get all qubits from a circuit and a | ||
collection of observables.""" | ||
qubit_set = set() | ||
for obs in observables: | ||
qubit_set |= set(obs.qubits) | ||
qubit_set |= circuit.all_qubits() | ||
return sorted(qubit_set) | ||
|
||
|
||
def measure_observables( | ||
mpharrigan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
circuit: 'cirq.AbstractCircuit', | ||
observables: Iterable['cirq.PauliString'], | ||
sampler: Union['cirq.Simulator', 'cirq.Sampler'], | ||
stopping_criteria: StoppingCriteria, | ||
*, | ||
readout_symmetrization: bool = False, | ||
circuit_sweep: Optional['cirq.Sweepable'] = None, | ||
grouper: Union[str, GROUPER_T] = group_settings_greedy, | ||
readout_calibrations: Optional[BitstringAccumulator] = None, | ||
checkpoint: CheckpointFileOptions = CheckpointFileOptions(), | ||
) -> List[ObservableMeasuredResult]: | ||
"""Measure a collection of PauliString observables for a state prepared by a Circuit. | ||
|
||
If you need more control over the process, please see `measure_grouped_settings` for a | ||
lower-level API. If you would like your results returned as a pandas DataFrame, | ||
please see `measure_observables_df`. | ||
|
||
Args: | ||
circuit: The circuit used to prepare the state to measure. This can contain parameters, | ||
in which case you should also specify `circuit_sweep`. | ||
observables: A collection of PauliString observables to measure. These will be grouped | ||
into simultaneously-measurable groups, see `grouper` argument. | ||
sampler: The sampler. | ||
stopping_criteria: A StoppingCriteria object to indicate how precisely to sample | ||
measurements for estimating observables. | ||
readout_symmetrization: If set to True, each run will be split into two: one normal and | ||
one where a bit flip is incorporated prior to measurement. In the latter case, the | ||
measured bit will be flipped back classically and accumulated together. This causes | ||
readout error to appear symmetric, p(0|0) = p(1|1). | ||
circuit_sweep: Additional parameter sweeps for parameters contained in `circuit`. The | ||
total sweep is the product of the circuit sweep with parameter settings for the | ||
single-qubit basis-change rotations. | ||
grouper: Either "greedy" or a function that groups lists of `InitObsSetting`. See the | ||
documentation for the `grouped_settings` argument of `measure_grouped_settings` for | ||
full details. | ||
readout_calibrations: The result of `calibrate_readout_error`. | ||
checkpoint: Options to set up optional checkpointing of intermediate data for each | ||
iteration of the sampling loop. See the documentation for `CheckpointFileOptions` for | ||
more. Load in these results with `cirq.read_json`. | ||
|
||
Returns: | ||
A list of ObservableMeasuredResult; one for each input PauliString. | ||
""" | ||
qubits = _get_all_qubits(circuit, observables) | ||
settings = list(observables_to_settings(observables, qubits)) | ||
actual_grouper = _parse_grouper(grouper) | ||
grouped_settings = actual_grouper(settings) | ||
|
||
accumulators = measure_grouped_settings( | ||
circuit=circuit, | ||
grouped_settings=grouped_settings, | ||
sampler=sampler, | ||
stopping_criteria=stopping_criteria, | ||
circuit_sweep=circuit_sweep, | ||
readout_symmetrization=readout_symmetrization, | ||
readout_calibrations=readout_calibrations, | ||
checkpoint=checkpoint, | ||
) | ||
return flatten_grouped_results(accumulators) | ||
|
||
|
||
def measure_observables_df( | ||
circuit: 'cirq.AbstractCircuit', | ||
observables: Iterable['cirq.PauliString'], | ||
sampler: Union['cirq.Simulator', 'cirq.Sampler'], | ||
stopping_criteria: StoppingCriteria, | ||
*, | ||
readout_symmetrization: bool = False, | ||
circuit_sweep: Optional['cirq.Sweepable'] = None, | ||
grouper: Union[str, GROUPER_T] = group_settings_greedy, | ||
readout_calibrations: Optional[BitstringAccumulator] = None, | ||
checkpoint: CheckpointFileOptions = CheckpointFileOptions(), | ||
): | ||
"""Measure observables and return resulting data as a Pandas dataframe. | ||
|
||
Please see `measure_observables` for argument documentation. | ||
""" | ||
results = measure_observables( | ||
circuit=circuit, | ||
observables=observables, | ||
sampler=sampler, | ||
stopping_criteria=stopping_criteria, | ||
readout_symmetrization=readout_symmetrization, | ||
circuit_sweep=circuit_sweep, | ||
grouper=grouper, | ||
readout_calibrations=readout_calibrations, | ||
checkpoint=checkpoint, | ||
) | ||
df = pd.DataFrame(res.as_dict() for res in results) | ||
return df |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the Cirq
json_serializable_dataclass
break type checking? Using two different classes under the same name feels really unstable to me.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It breaks type checking. I introduced #4391 to fix this, but it isn't merged yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. I stand by my comment, though - this should wait until #4391 is merged and use the fix it provides.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. We definitely have this pattern in Cirq already as the linked pr demonstrates. It's really a drawback of mypy more than anything else; the decorator works fine