Skip to content

Commit

Permalink
Sampler API
Browse files Browse the repository at this point in the history
  • Loading branch information
mpharrigan committed Aug 17, 2021
1 parent dcb59f3 commit f92fc65
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 19 deletions.
25 changes: 11 additions & 14 deletions cirq-core/cirq/work/observable_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@


def _with_parameterized_layers(
circuit: 'cirq.Circuit',
circuit: 'cirq.AbstractCircuit',
qubits: Sequence['cirq.Qid'],
needs_init_layer: bool,
) -> 'cirq.Circuit':
Expand All @@ -84,9 +84,9 @@ def _with_parameterized_layers(
meas_mom = ops.Moment([ops.measure(*qubits, key='z')])
if needs_init_layer:
total_circuit = circuits.Circuit([x_beg_mom, y_beg_mom])
total_circuit += circuit.copy()
total_circuit += circuit.unfreeze()
else:
total_circuit = circuit.copy()
total_circuit = circuit.unfreeze()
total_circuit.append([x_end_mom, y_end_mom, meas_mom])
return total_circuit

Expand Down Expand Up @@ -445,7 +445,7 @@ def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting


def measure_grouped_settings(
circuit: 'cirq.Circuit',
circuit: 'cirq.AbstractCircuit',
grouped_settings: Dict[InitObsSetting, List[InitObsSetting]],
sampler: 'cirq.Sampler',
stopping_criteria: StoppingCriteria,
Expand Down Expand Up @@ -523,10 +523,7 @@ def measure_grouped_settings(
for max_setting, circuit_params in itertools.product(
grouped_settings.keys(), circuit_sweep.param_tuples()
):
# The type annotation for Param is just `Iterable`.
# We make sure that it's truly a tuple.
circuit_params = dict(circuit_params)

meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params)
accumulator = BitstringAccumulator(
meas_spec=meas_spec,
Expand Down Expand Up @@ -616,8 +613,8 @@ def _parse_grouper(grouper: Union[str, GROUPER_T] = group_settings_greedy) -> GR


def _get_all_qubits(
circuit: circuits.Circuit,
observables: Iterable[ops.PauliString],
circuit: 'cirq.AbstractCircuit',
observables: Iterable['cirq.PauliString'],
) -> List['cirq.Qid']:
"""Helper function for `measure_observables` to get all qubits from a circuit and a
collection of observables."""
Expand All @@ -629,8 +626,8 @@ def _get_all_qubits(


def measure_observables(
circuit: circuits.Circuit,
observables: Iterable[ops.PauliString],
circuit: 'cirq.AbstractCircuit',
observables: Iterable['cirq.PauliString'],
sampler: Union['cirq.Simulator', 'cirq.Sampler'],
stopping_criteria: Union[str, StoppingCriteria],
stopping_criteria_val: Optional[float] = None,
Expand All @@ -642,7 +639,7 @@ def measure_observables(
checkpoint: bool = False,
checkpoint_fn: Optional[str] = None,
checkpoint_other_fn: Optional[str] = None,
):
) -> List[BitstringAccumulator]:
"""Measure a collection of PauliString observables for a state prepared by a Circuit.
If you need more control over the process, please see `measure_grouped_settings` for a
Expand Down Expand Up @@ -708,8 +705,8 @@ def measure_observables(


def measure_observables_df(
circuit: circuits.Circuit,
observables: Iterable[ops.PauliString],
circuit: 'cirq.AbstractCircuit',
observables: Iterable['cirq.PauliString'],
sampler: Union['cirq.Simulator', 'cirq.Sampler'],
stopping_criteria: Union[str, StoppingCriteria],
stopping_criteria_val: Optional[float] = None,
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/work/observable_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, Iterable, Dict, TYPE_CHECKING, Tuple
from typing import Union, Iterable, Dict, TYPE_CHECKING, ItemsView

from cirq import ops, value

Expand Down Expand Up @@ -135,7 +135,7 @@ def _fix_precision(val: float, precision) -> int:
return int(val * precision)


def _hashable_param(param_tuples: Iterable[Tuple[str, float]], precision=1e7):
def _hashable_param(param_tuples: ItemsView[str, float], precision=1e7):
"""Hash circuit parameters using fixed precision.
Circuit parameters can be floats but we also need to use them as
Expand Down
111 changes: 108 additions & 3 deletions cirq-core/cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.
"""Abstract base class for things sampling quantum circuits."""

from typing import List, Optional, TYPE_CHECKING, Union
import abc
from typing import List, Optional, TYPE_CHECKING, Union, Dict, FrozenSet

import pandas as pd

from cirq import study
from cirq import study, ops
from cirq.work.observable_measurement import measure_observables, RepetitionsStoppingCriteria
from cirq.work.observable_settings import _hashable_param

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -253,3 +254,107 @@ def run_batch(
self.run_sweep(circuit, params=params, repetitions=repetitions)
for circuit, params, repetitions in zip(programs, params_list, repetitions)
]

def sample_expectation_values(
self,
program: 'cirq.AbstractCircuit',
observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']],
*,
num_samples: int,
params: 'cirq.Sweepable' = None,
permit_terminal_measurements: bool = False,
) -> List[List[float]]:
"""Calculates estimated expectation values from samples of a circuit.
This is a minimal implementation for measuring observables, and is best reserved for
simple use cases. For more complex use cases, consider upgrading to
`cirq.work.observable_measurement`. Additional features provided by that toolkit include:
- Chunking of submissions to support more than (max_shots) from
Quantum Engine
- Checkpointing so you don't lose your work halfway through a job
- Measuring to a variance tolerance rather than a pre-specified
number of repetitions
- Readout error symmetrization and mitigation
This method can be run on any device or simulator that supports circuit sampling. Compare
with `simulate_expectation_values` in simulator.py, which is limited to simulators
but provides exact results.
Args:
program: The circuit which prepares a state from which we sample expectation values.
observables: A list of observables for which to calculate expectation values.
num_samples: The number of samples to take. Increasing this value increases the
statistical accuracy of the estimate.
params: Parameters to run with the program.
permit_terminal_measurements: If the provided circuit ends in a measurement, this
method will generate an error unless this is set to True. This is meant to
prevent measurements from ruining expectation value calculations.
Returns:
A list of expectation-value lists. The outer index determines the sweep, and the inner
index determines the observable. For instance, results[1][3] would select the fourth
observable measured in the second sweep.
"""
if num_samples <= 0:
raise ValueError(
f'Expectation values require at least one sample. Received: {num_samples}.'
)
if not observables:
raise ValueError('At least one observable must be provided.')
if not permit_terminal_measurements and program.are_any_measurements_terminal():
raise ValueError(
'Provided circuit has terminal measurements, which may '
'skew expectation values. If this is intentional, set '
'permit_terminal_measurements=True.'
)

# Wrap input into a list of pauli sum
pauli_sums: List['cirq.PauliSum'] = (
[ops.PauliSum.wrap(o) for o in observables]
if isinstance(observables, List)
else [ops.PauliSum.wrap(observables)]
)
del observables

# Flatten Pauli Sum into one big list of Pauli String
# Keep track of which Pauli Sum each one was from.
flat_pstrings: List['cirq.PauliString'] = []
pstring_to_psum_i: Dict['cirq.PauliString', int] = {}
for psum_i, pauli_sum in enumerate(pauli_sums):
for pstring in pauli_sum:
flat_pstrings.append(pstring)
pstring_to_psum_i[pstring] = psum_i

# Flatten Circuit Sweep into one big list of Params.
# Keep track of their indices so we can map back.
circuit_sweep = study.UnitSweep if params is None else study.to_sweep(params)
all_circuit_params: List[Dict[str, float]] = [
dict(circuit_params) for circuit_params in circuit_sweep.param_tuples()
]
circuit_param_to_sweep_i: Dict[FrozenSet[str, float], int] = {
_hashable_param(param.items()): i for i, param in enumerate(all_circuit_params)
}
del params

accumulators = measure_observables(
circuit=program,
observables=flat_pstrings,
sampler=self,
stopping_criteria=RepetitionsStoppingCriteria(total_repetitions=num_samples),
readout_symmetrization=False,
circuit_sweep=circuit_sweep,
checkpoint=False,
)

# Results are ordered by how they're grouped. Since we want the (circuit_sweep, pauli_sum)
# nesting structure, we place the measured values according to the back-mappings we set up
# above. We also do the sum operation to aggregate multiple PauliString measured values
# for a given PauliSum.
results: List[List[float]] = [[0] * len(pauli_sums) for _ in range(len(all_circuit_params))]
for acc in accumulators:
for res in acc.results:
param_i = circuit_param_to_sweep_i[_hashable_param(res.circuit_params.items())]
psum_i = pstring_to_psum_i[res.setting.observable]
results[param_i][psum_i] += res.mean

return results
123 changes: 123 additions & 0 deletions cirq-core/cirq/work/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for cirq.Sampler."""
from typing import List

import pytest

import numpy as np
Expand Down Expand Up @@ -198,3 +200,124 @@ def test_sampler_run_batch_bad_input_lengths():
_ = sampler.run_batch(
[circuit1, circuit2], params_list=[params1, params2], repetitions=[1, 2, 3]
)


def test_sampler_simple_sample_expectation_values():
a = cirq.LineQubit(0)
sampler = cirq.Simulator()
circuit = cirq.Circuit(cirq.H(a))
obs = cirq.X(a)
results = sampler.sample_expectation_values(circuit, [obs], num_samples=1000)

assert np.allclose(results, [[1]])


def test_sampler_sample_expectation_values_calculation():
class DeterministicImbalancedStateSampler(cirq.Sampler):
"""A simple, deterministic mock sampler.
Pretends to sample from a state vector with a 3:1 balance between the
probabilities of the |0) and |1) state.
"""

def run_sweep(
self,
program: 'cirq.Circuit',
params: 'cirq.Sweepable',
repetitions: int = 1,
) -> List['cirq.Result']:
results = np.zeros((repetitions, 1), dtype=bool)
for idx in range(repetitions // 4):
results[idx][0] = 1
return [
cirq.Result(params=pr, measurements={'z': results})
for pr in cirq.study.to_resolvers(params)
]

a = cirq.LineQubit(0)
sampler = DeterministicImbalancedStateSampler()
# This circuit is not actually sampled, but the mock sampler above gives
# a reasonable approximation of it.
circuit = cirq.Circuit(cirq.X(a) ** (1 / 3))
obs = cirq.Z(a)
results = sampler.sample_expectation_values(circuit, [obs], num_samples=1000)

# (0.75 * 1) + (0.25 * -1) = 0.5
assert np.allclose(results, [[0.5]])


def test_sampler_sample_expectation_values_multi_param():
a = cirq.LineQubit(0)
t = sympy.Symbol('t')
sampler = cirq.Simulator(seed=1)
circuit = cirq.Circuit(cirq.X(a) ** t)
obs = cirq.Z(a)
results = sampler.sample_expectation_values(
circuit, [obs], num_samples=5, params=cirq.Linspace('t', 0, 2, 3)
)

assert np.allclose(results, [[1], [-1], [1]])


def test_sampler_sample_expectation_values_multi_qubit():
q = cirq.LineQubit.range(3)
sampler = cirq.Simulator(seed=1)
circuit = cirq.Circuit(cirq.X(q[0]), cirq.X(q[1]), cirq.X(q[2]))
obs = cirq.Z(q[0]) + cirq.Z(q[1]) + cirq.Z(q[2])
results = sampler.sample_expectation_values(circuit, [obs], num_samples=5)

assert np.allclose(results, [[-3]])


def test_sampler_sample_expectation_values_composite():
# Tests multi-{param,qubit} sampling together in one circuit.
q = cirq.LineQubit.range(3)
t = [sympy.Symbol(f't{x}') for x in range(3)]

sampler = cirq.Simulator(seed=1)
circuit = cirq.Circuit(
cirq.X(q[0]) ** t[0],
cirq.X(q[1]) ** t[1],
cirq.X(q[2]) ** t[2],
)

obs = [cirq.Z(q[x]) for x in range(3)]
# t0 is in the inner loop to make bit-ordering easier below.
params = ([{'t0': t0, 't1': t1, 't2': t2} for t2 in [0, 1] for t1 in [0, 1] for t0 in [0, 1]],)
results = sampler.sample_expectation_values(
circuit,
obs,
num_samples=5,
params=params,
)
print('\n'.join(str(r) for r in results))

assert len(results) == 8
assert np.allclose(
results,
[
[+1, +1, +1],
[-1, +1, +1],
[+1, -1, +1],
[-1, -1, +1],
[+1, +1, -1],
[-1, +1, -1],
[+1, -1, -1],
[-1, -1, -1],
],
)


def test_sampler_simple_sample_expectation_requirements():
a = cirq.LineQubit(0)
sampler = cirq.Simulator(seed=1)
circuit = cirq.Circuit(cirq.H(a))
obs = cirq.X(a)
with pytest.raises(ValueError, match='at least one sample'):
_ = sampler.sample_expectation_values(circuit, [obs], num_samples=0)

with pytest.raises(ValueError, match='At least one observable'):
_ = sampler.sample_expectation_values(circuit, [], num_samples=1)

circuit.append(cirq.measure(a, key='out'))
with pytest.raises(ValueError, match='permit_terminal_measurements'):
_ = sampler.sample_expectation_values(circuit, [obs], num_samples=1)

0 comments on commit f92fc65

Please sign in to comment.