Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Obs] 4.5 - High-level API #4392

Merged
merged 18 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cirq-core/cirq/work/observable_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Dict, List, TYPE_CHECKING, cast
from typing import Iterable, Dict, List, TYPE_CHECKING, cast, Callable

from cirq import ops, value
from cirq.work.observable_settings import InitObsSetting, _max_weight_state, _max_weight_observable

if TYPE_CHECKING:
pass

GROUPER_T = Callable[[Iterable[InitObsSetting]], Dict[InitObsSetting, List[InitObsSetting]]]


def group_settings_greedy(
settings: Iterable[InitObsSetting],
Expand Down
254 changes: 208 additions & 46 deletions cirq-core/cirq/work/observable_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,42 @@
import os
import tempfile
import warnings
from typing import Optional, Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence
from typing import (
Optional,
Union,
Iterable,
Dict,
List,
Tuple,
TYPE_CHECKING,
Set,
Sequence,
Any,
)

import numpy as np
import pandas as pd
import sympy
from cirq import circuits, study, ops, value
from cirq import circuits, study, ops, value, protocols
from cirq._doc import document
from cirq.protocols import json_serializable_dataclass, to_json
from cirq.work.observable_measurement_data import BitstringAccumulator
from cirq.work.observable_grouping import group_settings_greedy, GROUPER_T
from cirq.work.observable_measurement_data import (
BitstringAccumulator,
ObservableMeasuredResult,
flatten_grouped_results,
)
from cirq.work.observable_settings import (
InitObsSetting,
observables_to_settings,
_MeasurementSpec,
)

if TYPE_CHECKING:
import cirq
from cirq.value.product_state import _NamedOneQubitState
from dataclasses import dataclass as json_serializable_dataclass
else:
from cirq.protocols import json_serializable_dataclass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the Cirq json_serializable_dataclass break type checking? Using two different classes under the same name feels really unstable to me.

Copy link
Collaborator Author

@mpharrigan mpharrigan Aug 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It breaks type checking. I introduced #4391 to fix this, but it isn't merged yet.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. I stand by my comment, though - this should wait until #4391 is merged and use the fix it provides.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. We definitely have this pattern in Cirq already as the linked pr demonstrates. It's really a drawback of mypy more than anything else; the decorator works fine


MAX_REPETITIONS_PER_JOB = 3_000_000
document(
Expand All @@ -47,7 +67,7 @@


def _with_parameterized_layers(
circuit: 'cirq.Circuit',
circuit: 'cirq.AbstractCircuit',
qubits: Sequence['cirq.Qid'],
needs_init_layer: bool,
) -> 'cirq.Circuit':
Expand All @@ -69,9 +89,9 @@ def _with_parameterized_layers(
meas_mom = ops.Moment([ops.measure(*qubits, key='z')])
if needs_init_layer:
total_circuit = circuits.Circuit([x_beg_mom, y_beg_mom])
total_circuit += circuit.copy()
total_circuit += circuit.unfreeze()
else:
total_circuit = circuit.copy()
total_circuit = circuit.unfreeze()
total_circuit.append([x_end_mom, y_end_mom, meas_mom])
return total_circuit

Expand Down Expand Up @@ -170,6 +190,7 @@ def _get_params_for_setting(
if we know that `setting.init_state` is the all-zeros state and
`needs_init_layer` is False.
"""
setting = _pad_setting(setting, qubits)
params = {}
for qubit, flip in itertools.zip_longest(qubits, flips):
if qubit is None or flip is None:
Expand All @@ -196,7 +217,7 @@ def _get_params_for_setting(

def _pad_setting(
max_setting: InitObsSetting,
qubits: List['cirq.Qid'],
qubits: Sequence['cirq.Qid'],
pad_init_state_with=value.KET_ZERO,
pad_obs_with: 'cirq.Gate' = ops.Z,
) -> InitObsSetting:
Expand Down Expand Up @@ -411,6 +432,52 @@ def _parse_checkpoint_options(
return checkpoint_fn, checkpoint_other_fn


@dataclasses.dataclass(frozen=True)
class CheckpointFileOptions:
"""Options to configure "checkpointing" to save intermediate results.

Args:
checkpoint: If set to True, save cumulative raw results at the end
of each iteration of the sampling loop. Load in these results
with `cirq.read_json`.
checkpoint_fn: The filename for the checkpoint file. If `checkpoint`
is set to True and this is not specified, a file in a temporary
directory will be used.
checkpoint_other_fn: The filename for another checkpoint file, which
contains the previous checkpoint. This lets us avoid losing data if
a failure occurs during checkpoint writing. If `checkpoint`
is set to True and this is not specified, a file in a temporary
directory will be used. If `checkpoint` is set to True and
`checkpoint_fn` is specified but this argument is *not* specified,
"{checkpoint_fn}.prev.json" will be used.
"""

checkpoint: bool = False
checkpoint_fn: Optional[str] = None
checkpoint_other_fn: Optional[str] = None

def __post_init__(self):
fn, other_fn = _parse_checkpoint_options(
self.checkpoint, self.checkpoint_fn, self.checkpoint_other_fn
)
object.__setattr__(self, 'checkpoint_fn', fn)
object.__setattr__(self, 'checkpoint_other_fn', other_fn)

def maybe_to_json(self, obj: Any):
"""Call `cirq.to_json with `value` according to the configuration options in this class.

If `checkpoint=False`, nothing will happen. Otherwise, we will use `checkpoint_fn` and
`checkpoint_other_fn` as the destination JSON file as described in the class docstring.
"""
if not self.checkpoint:
return
assert self.checkpoint_fn is not None, 'mypy'
assert self.checkpoint_other_fn is not None, 'mypy'
if os.path.exists(self.checkpoint_fn):
os.replace(self.checkpoint_fn, self.checkpoint_other_fn)
protocols.to_json(obj, self.checkpoint_fn)


# pylint: enable=missing-raises-doc
def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting]]) -> bool:
"""Helper function to go through init_states and determine if any of them need an
Expand All @@ -424,17 +491,15 @@ def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting
# TODO(#3388) Add documentation for Raises.
# pylint: disable=missing-raises-doc
def measure_grouped_settings(
circuit: 'cirq.Circuit',
circuit: 'cirq.AbstractCircuit',
grouped_settings: Dict[InitObsSetting, List[InitObsSetting]],
sampler: 'cirq.Sampler',
stopping_criteria: StoppingCriteria,
*,
readout_symmetrization: bool = False,
circuit_sweep: 'cirq.study.sweepable.SweepLike' = None,
circuit_sweep: 'cirq.Sweepable' = None,
readout_calibrations: Optional[BitstringAccumulator] = None,
checkpoint: bool = False,
checkpoint_fn: Optional[str] = None,
checkpoint_other_fn: Optional[str] = None,
checkpoint: CheckpointFileOptions = CheckpointFileOptions(),
) -> List[BitstringAccumulator]:
"""Measure a suite of grouped InitObsSetting settings.

Expand Down Expand Up @@ -463,49 +528,29 @@ def measure_grouped_settings(
in `circuit`. The total sweep is the product of the circuit sweep
with parameter settings for the single-qubit basis-change rotations.
readout_calibrations: The result of `calibrate_readout_error`.
checkpoint: If set to True, save cumulative raw results at the end
of each iteration of the sampling loop. Load in these results
with `cirq.read_json`.
checkpoint_fn: The filename for the checkpoint file. If `checkpoint`
is set to True and this is not specified, a file in a temporary
directory will be used.
checkpoint_other_fn: The filename for another checkpoint file, which
contains the previous checkpoint. This lets us avoid losing data if
a failure occurs during checkpoint writing. If `checkpoint`
is set to True and this is not specified, a file in a temporary
directory will be used. If `checkpoint` is set to True and
`checkpoint_fn` is specified but this argument is *not* specified,
"{checkpoint_fn}.prev.json" will be used.
checkpoint: Options to set up optional checkpointing of intermediate
data for each iteration of the sampling loop. See the documentation
for `CheckpointFileOptions` for more. Load in these results with
`cirq.read_json`.
"""
if readout_calibrations is not None and not readout_symmetrization:
raise ValueError("Readout calibration only works if `readout_symmetrization` is enabled.")

checkpoint_fn, checkpoint_other_fn = _parse_checkpoint_options(
checkpoint=checkpoint, checkpoint_fn=checkpoint_fn, checkpoint_other_fn=checkpoint_other_fn
)
qubits = sorted({q for ms in grouped_settings.keys() for q in ms.init_state.qubits})
qubit_to_index = {q: i for i, q in enumerate(qubits)}

needs_init_layer = _needs_init_layer(grouped_settings)
measurement_param_circuit = _with_parameterized_layers(circuit, qubits, needs_init_layer)
grouped_settings = {
_pad_setting(max_setting, qubits): settings
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the call to _pad_setting has been moved to right near where it's actually needed in _get_params_for_setting. This was causing a subtle bug with the new test I added that uses a particularly bad grouper: it puts each observable in its own group. However, if you pad the keys in this bad grouping you start to get collisions. We'll keep the user's input max_settings whenever we're using it as a key anywhere and just "pad" when we need the actual circuit parameters.

for max_setting, settings in grouped_settings.items()
}
circuit_sweep = study.UnitSweep if circuit_sweep is None else study.to_sweep(circuit_sweep)

# meas_spec provides a key for accumulators.
# meas_specs_todo is a mutable list. We will pop things from it as various
# specs are measured to the satisfaction of the stopping criteria
accumulators = {}
meas_specs_todo = []
for max_setting, circuit_params in itertools.product(
grouped_settings.keys(), circuit_sweep.param_tuples()
for max_setting, param_resolver in itertools.product(
grouped_settings.keys(), study.to_resolvers(circuit_sweep)
):
# The type annotation for Param is just `Iterable`.
# We make sure that it's truly a tuple.
circuit_params = dict(circuit_params)

circuit_params = param_resolver.param_dict
meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params)
accumulator = BitstringAccumulator(
meas_spec=meas_spec,
Expand Down Expand Up @@ -551,14 +596,131 @@ def measure_grouped_settings(
bitstrings = np.logical_xor(flippy_ms.flips, result.measurements['z'])
accumulator.consume_results(bitstrings.astype(np.uint8, casting='safe'))

if checkpoint:
assert checkpoint_fn is not None, 'mypy'
assert checkpoint_other_fn is not None, 'mypy'
if os.path.exists(checkpoint_fn):
os.replace(checkpoint_fn, checkpoint_other_fn)
to_json(list(accumulators.values()), checkpoint_fn)
checkpoint.maybe_to_json(list(accumulators.values()))

return list(accumulators.values())


# pylint: enable=missing-raises-doc


_GROUPING_FUNCS: Dict[str, GROUPER_T] = {
'greedy': group_settings_greedy,
}


def _parse_grouper(grouper: Union[str, GROUPER_T] = group_settings_greedy) -> GROUPER_T:
"""Logic for turning a named grouper into one of the build-in groupers in support of the
high-level `measure_observables` API."""
if isinstance(grouper, str):
try:
grouper = _GROUPING_FUNCS[grouper.lower()]
except KeyError:
raise ValueError(f"Unknown grouping function {grouper}")
return grouper


def _get_all_qubits(
circuit: 'cirq.AbstractCircuit',
observables: Iterable['cirq.PauliString'],
) -> List['cirq.Qid']:
"""Helper function for `measure_observables` to get all qubits from a circuit and a
collection of observables."""
qubit_set = set()
for obs in observables:
qubit_set |= set(obs.qubits)
qubit_set |= circuit.all_qubits()
return sorted(qubit_set)


def measure_observables(
mpharrigan marked this conversation as resolved.
Show resolved Hide resolved
circuit: 'cirq.AbstractCircuit',
observables: Iterable['cirq.PauliString'],
sampler: Union['cirq.Simulator', 'cirq.Sampler'],
stopping_criteria: StoppingCriteria,
*,
readout_symmetrization: bool = False,
circuit_sweep: Optional['cirq.Sweepable'] = None,
grouper: Union[str, GROUPER_T] = group_settings_greedy,
readout_calibrations: Optional[BitstringAccumulator] = None,
checkpoint: CheckpointFileOptions = CheckpointFileOptions(),
) -> List[ObservableMeasuredResult]:
"""Measure a collection of PauliString observables for a state prepared by a Circuit.

If you need more control over the process, please see `measure_grouped_settings` for a
lower-level API. If you would like your results returned as a pandas DataFrame,
please see `measure_observables_df`.

Args:
circuit: The circuit used to prepare the state to measure. This can contain parameters,
in which case you should also specify `circuit_sweep`.
observables: A collection of PauliString observables to measure. These will be grouped
into simultaneously-measurable groups, see `grouper` argument.
sampler: The sampler.
stopping_criteria: A StoppingCriteria object to indicate how precisely to sample
measurements for estimating observables.
readout_symmetrization: If set to True, each run will be split into two: one normal and
one where a bit flip is incorporated prior to measurement. In the latter case, the
measured bit will be flipped back classically and accumulated together. This causes
readout error to appear symmetric, p(0|0) = p(1|1).
circuit_sweep: Additional parameter sweeps for parameters contained in `circuit`. The
total sweep is the product of the circuit sweep with parameter settings for the
single-qubit basis-change rotations.
grouper: Either "greedy" or a function that groups lists of `InitObsSetting`. See the
documentation for the `grouped_settings` argument of `measure_grouped_settings` for
full details.
readout_calibrations: The result of `calibrate_readout_error`.
checkpoint: Options to set up optional checkpointing of intermediate data for each
iteration of the sampling loop. See the documentation for `CheckpointFileOptions` for
more. Load in these results with `cirq.read_json`.

Returns:
A list of ObservableMeasuredResult; one for each input PauliString.
"""
qubits = _get_all_qubits(circuit, observables)
settings = list(observables_to_settings(observables, qubits))
actual_grouper = _parse_grouper(grouper)
grouped_settings = actual_grouper(settings)

accumulators = measure_grouped_settings(
circuit=circuit,
grouped_settings=grouped_settings,
sampler=sampler,
stopping_criteria=stopping_criteria,
circuit_sweep=circuit_sweep,
readout_symmetrization=readout_symmetrization,
readout_calibrations=readout_calibrations,
checkpoint=checkpoint,
)
return flatten_grouped_results(accumulators)


def measure_observables_df(
circuit: 'cirq.AbstractCircuit',
observables: Iterable['cirq.PauliString'],
sampler: Union['cirq.Simulator', 'cirq.Sampler'],
stopping_criteria: StoppingCriteria,
*,
readout_symmetrization: bool = False,
circuit_sweep: Optional['cirq.Sweepable'] = None,
grouper: Union[str, GROUPER_T] = group_settings_greedy,
readout_calibrations: Optional[BitstringAccumulator] = None,
checkpoint: CheckpointFileOptions = CheckpointFileOptions(),
):
"""Measure observables and return resulting data as a Pandas dataframe.

Please see `measure_observables` for argument documentation.
"""
results = measure_observables(
circuit=circuit,
observables=observables,
sampler=sampler,
stopping_criteria=stopping_criteria,
readout_symmetrization=readout_symmetrization,
circuit_sweep=circuit_sweep,
grouper=grouper,
readout_calibrations=readout_calibrations,
checkpoint=checkpoint,
)
df = pd.DataFrame(res.as_dict() for res in results)
return df
Loading