diff --git a/cirq-core/cirq/work/observable_measurement.py b/cirq-core/cirq/work/observable_measurement.py index d0acd99a4e7..b0888110a6e 100644 --- a/cirq-core/cirq/work/observable_measurement.py +++ b/cirq-core/cirq/work/observable_measurement.py @@ -62,7 +62,7 @@ def _with_parameterized_layers( - circuit: 'cirq.Circuit', + circuit: 'cirq.AbstractCircuit', qubits: Sequence['cirq.Qid'], needs_init_layer: bool, ) -> 'cirq.Circuit': @@ -84,9 +84,9 @@ def _with_parameterized_layers( meas_mom = ops.Moment([ops.measure(*qubits, key='z')]) if needs_init_layer: total_circuit = circuits.Circuit([x_beg_mom, y_beg_mom]) - total_circuit += circuit.copy() + total_circuit += circuit.unfreeze() else: - total_circuit = circuit.copy() + total_circuit = circuit.unfreeze() total_circuit.append([x_end_mom, y_end_mom, meas_mom]) return total_circuit @@ -445,7 +445,7 @@ def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting def measure_grouped_settings( - circuit: 'cirq.Circuit', + circuit: 'cirq.AbstractCircuit', grouped_settings: Dict[InitObsSetting, List[InitObsSetting]], sampler: 'cirq.Sampler', stopping_criteria: StoppingCriteria, @@ -523,10 +523,7 @@ def measure_grouped_settings( for max_setting, circuit_params in itertools.product( grouped_settings.keys(), circuit_sweep.param_tuples() ): - # The type annotation for Param is just `Iterable`. - # We make sure that it's truly a tuple. circuit_params = dict(circuit_params) - meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params) accumulator = BitstringAccumulator( meas_spec=meas_spec, @@ -616,8 +613,8 @@ def _parse_grouper(grouper: Union[str, GROUPER_T] = group_settings_greedy) -> GR def _get_all_qubits( - circuit: circuits.Circuit, - observables: Iterable[ops.PauliString], + circuit: 'cirq.AbstractCircuit', + observables: Iterable['cirq.PauliString'], ) -> List['cirq.Qid']: """Helper function for `measure_observables` to get all qubits from a circuit and a collection of observables.""" @@ -629,8 +626,8 @@ def _get_all_qubits( def measure_observables( - circuit: circuits.Circuit, - observables: Iterable[ops.PauliString], + circuit: 'cirq.AbstractCircuit', + observables: Iterable['cirq.PauliString'], sampler: Union['cirq.Simulator', 'cirq.Sampler'], stopping_criteria: Union[str, StoppingCriteria], stopping_criteria_val: Optional[float] = None, @@ -642,7 +639,7 @@ def measure_observables( checkpoint: bool = False, checkpoint_fn: Optional[str] = None, checkpoint_other_fn: Optional[str] = None, -): +) -> List[BitstringAccumulator]: """Measure a collection of PauliString observables for a state prepared by a Circuit. If you need more control over the process, please see `measure_grouped_settings` for a @@ -708,8 +705,8 @@ def measure_observables( def measure_observables_df( - circuit: circuits.Circuit, - observables: Iterable[ops.PauliString], + circuit: 'cirq.AbstractCircuit', + observables: Iterable['cirq.PauliString'], sampler: Union['cirq.Simulator', 'cirq.Sampler'], stopping_criteria: Union[str, StoppingCriteria], stopping_criteria_val: Optional[float] = None, diff --git a/cirq-core/cirq/work/observable_settings.py b/cirq-core/cirq/work/observable_settings.py index c5378342001..3101ccef6fc 100644 --- a/cirq-core/cirq/work/observable_settings.py +++ b/cirq-core/cirq/work/observable_settings.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Iterable, Dict, TYPE_CHECKING, Tuple +from typing import Union, Iterable, Dict, TYPE_CHECKING, ItemsView from cirq import ops, value @@ -135,7 +135,7 @@ def _fix_precision(val: float, precision) -> int: return int(val * precision) -def _hashable_param(param_tuples: Iterable[Tuple[str, float]], precision=1e7): +def _hashable_param(param_tuples: ItemsView[str, float], precision=1e7): """Hash circuit parameters using fixed precision. Circuit parameters can be floats but we also need to use them as diff --git a/cirq-core/cirq/work/sampler.py b/cirq-core/cirq/work/sampler.py index 425a1a1c7ff..6bfef3cef41 100644 --- a/cirq-core/cirq/work/sampler.py +++ b/cirq-core/cirq/work/sampler.py @@ -13,12 +13,13 @@ # limitations under the License. """Abstract base class for things sampling quantum circuits.""" -from typing import List, Optional, TYPE_CHECKING, Union import abc +from typing import List, Optional, TYPE_CHECKING, Union, Dict, FrozenSet import pandas as pd - -from cirq import study +from cirq import study, ops +from cirq.work.observable_measurement import measure_observables, RepetitionsStoppingCriteria +from cirq.work.observable_settings import _hashable_param if TYPE_CHECKING: import cirq @@ -253,3 +254,107 @@ def run_batch( self.run_sweep(circuit, params=params, repetitions=repetitions) for circuit, params, repetitions in zip(programs, params_list, repetitions) ] + + def sample_expectation_values( + self, + program: 'cirq.AbstractCircuit', + observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']], + *, + num_samples: int, + params: 'cirq.Sweepable' = None, + permit_terminal_measurements: bool = False, + ) -> List[List[float]]: + """Calculates estimated expectation values from samples of a circuit. + + This is a minimal implementation for measuring observables, and is best reserved for + simple use cases. For more complex use cases, consider upgrading to + `cirq.work.observable_measurement`. Additional features provided by that toolkit include: + - Chunking of submissions to support more than (max_shots) from + Quantum Engine + - Checkpointing so you don't lose your work halfway through a job + - Measuring to a variance tolerance rather than a pre-specified + number of repetitions + - Readout error symmetrization and mitigation + + This method can be run on any device or simulator that supports circuit sampling. Compare + with `simulate_expectation_values` in simulator.py, which is limited to simulators + but provides exact results. + + Args: + program: The circuit which prepares a state from which we sample expectation values. + observables: A list of observables for which to calculate expectation values. + num_samples: The number of samples to take. Increasing this value increases the + statistical accuracy of the estimate. + params: Parameters to run with the program. + permit_terminal_measurements: If the provided circuit ends in a measurement, this + method will generate an error unless this is set to True. This is meant to + prevent measurements from ruining expectation value calculations. + + Returns: + A list of expectation-value lists. The outer index determines the sweep, and the inner + index determines the observable. For instance, results[1][3] would select the fourth + observable measured in the second sweep. + """ + if num_samples <= 0: + raise ValueError( + f'Expectation values require at least one sample. Received: {num_samples}.' + ) + if not observables: + raise ValueError('At least one observable must be provided.') + if not permit_terminal_measurements and program.are_any_measurements_terminal(): + raise ValueError( + 'Provided circuit has terminal measurements, which may ' + 'skew expectation values. If this is intentional, set ' + 'permit_terminal_measurements=True.' + ) + + # Wrap input into a list of pauli sum + pauli_sums: List['cirq.PauliSum'] = ( + [ops.PauliSum.wrap(o) for o in observables] + if isinstance(observables, List) + else [ops.PauliSum.wrap(observables)] + ) + del observables + + # Flatten Pauli Sum into one big list of Pauli String + # Keep track of which Pauli Sum each one was from. + flat_pstrings: List['cirq.PauliString'] = [] + pstring_to_psum_i: Dict['cirq.PauliString', int] = {} + for psum_i, pauli_sum in enumerate(pauli_sums): + for pstring in pauli_sum: + flat_pstrings.append(pstring) + pstring_to_psum_i[pstring] = psum_i + + # Flatten Circuit Sweep into one big list of Params. + # Keep track of their indices so we can map back. + circuit_sweep = study.UnitSweep if params is None else study.to_sweep(params) + all_circuit_params: List[Dict[str, float]] = [ + dict(circuit_params) for circuit_params in circuit_sweep.param_tuples() + ] + circuit_param_to_sweep_i: Dict[FrozenSet[str, float], int] = { + _hashable_param(param.items()): i for i, param in enumerate(all_circuit_params) + } + del params + + accumulators = measure_observables( + circuit=program, + observables=flat_pstrings, + sampler=self, + stopping_criteria=RepetitionsStoppingCriteria(total_repetitions=num_samples), + readout_symmetrization=False, + circuit_sweep=circuit_sweep, + checkpoint=False, + ) + + # Results are ordered by how they're grouped. Since we want the (circuit_sweep, pauli_sum) + # nesting structure, we place the measured values according to the back-mappings we set up + # above. We also do the sum operation to aggregate multiple PauliString measured values + # for a given PauliSum. + results: List[List[float]] = [[0] * len(pauli_sums) for _ in range(len(all_circuit_params))] + for acc in accumulators: + for res in acc.results: + param_i = circuit_param_to_sweep_i[_hashable_param(res.circuit_params.items())] + psum_i = pstring_to_psum_i[res.setting.observable] + results[param_i][psum_i] += res.mean + + return results diff --git a/cirq-core/cirq/work/sampler_test.py b/cirq-core/cirq/work/sampler_test.py index 84119beb397..5bf45d35570 100644 --- a/cirq-core/cirq/work/sampler_test.py +++ b/cirq-core/cirq/work/sampler_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for cirq.Sampler.""" +from typing import List + import pytest import numpy as np @@ -198,3 +200,124 @@ def test_sampler_run_batch_bad_input_lengths(): _ = sampler.run_batch( [circuit1, circuit2], params_list=[params1, params2], repetitions=[1, 2, 3] ) + + +def test_sampler_simple_sample_expectation_values(): + a = cirq.LineQubit(0) + sampler = cirq.Simulator() + circuit = cirq.Circuit(cirq.H(a)) + obs = cirq.X(a) + results = sampler.sample_expectation_values(circuit, [obs], num_samples=1000) + + assert np.allclose(results, [[1]]) + + +def test_sampler_sample_expectation_values_calculation(): + class DeterministicImbalancedStateSampler(cirq.Sampler): + """A simple, deterministic mock sampler. + Pretends to sample from a state vector with a 3:1 balance between the + probabilities of the |0) and |1) state. + """ + + def run_sweep( + self, + program: 'cirq.Circuit', + params: 'cirq.Sweepable', + repetitions: int = 1, + ) -> List['cirq.Result']: + results = np.zeros((repetitions, 1), dtype=bool) + for idx in range(repetitions // 4): + results[idx][0] = 1 + return [ + cirq.Result(params=pr, measurements={'z': results}) + for pr in cirq.study.to_resolvers(params) + ] + + a = cirq.LineQubit(0) + sampler = DeterministicImbalancedStateSampler() + # This circuit is not actually sampled, but the mock sampler above gives + # a reasonable approximation of it. + circuit = cirq.Circuit(cirq.X(a) ** (1 / 3)) + obs = cirq.Z(a) + results = sampler.sample_expectation_values(circuit, [obs], num_samples=1000) + + # (0.75 * 1) + (0.25 * -1) = 0.5 + assert np.allclose(results, [[0.5]]) + + +def test_sampler_sample_expectation_values_multi_param(): + a = cirq.LineQubit(0) + t = sympy.Symbol('t') + sampler = cirq.Simulator(seed=1) + circuit = cirq.Circuit(cirq.X(a) ** t) + obs = cirq.Z(a) + results = sampler.sample_expectation_values( + circuit, [obs], num_samples=5, params=cirq.Linspace('t', 0, 2, 3) + ) + + assert np.allclose(results, [[1], [-1], [1]]) + + +def test_sampler_sample_expectation_values_multi_qubit(): + q = cirq.LineQubit.range(3) + sampler = cirq.Simulator(seed=1) + circuit = cirq.Circuit(cirq.X(q[0]), cirq.X(q[1]), cirq.X(q[2])) + obs = cirq.Z(q[0]) + cirq.Z(q[1]) + cirq.Z(q[2]) + results = sampler.sample_expectation_values(circuit, [obs], num_samples=5) + + assert np.allclose(results, [[-3]]) + + +def test_sampler_sample_expectation_values_composite(): + # Tests multi-{param,qubit} sampling together in one circuit. + q = cirq.LineQubit.range(3) + t = [sympy.Symbol(f't{x}') for x in range(3)] + + sampler = cirq.Simulator(seed=1) + circuit = cirq.Circuit( + cirq.X(q[0]) ** t[0], + cirq.X(q[1]) ** t[1], + cirq.X(q[2]) ** t[2], + ) + + obs = [cirq.Z(q[x]) for x in range(3)] + # t0 is in the inner loop to make bit-ordering easier below. + params = ([{'t0': t0, 't1': t1, 't2': t2} for t2 in [0, 1] for t1 in [0, 1] for t0 in [0, 1]],) + results = sampler.sample_expectation_values( + circuit, + obs, + num_samples=5, + params=params, + ) + print('\n'.join(str(r) for r in results)) + + assert len(results) == 8 + assert np.allclose( + results, + [ + [+1, +1, +1], + [-1, +1, +1], + [+1, -1, +1], + [-1, -1, +1], + [+1, +1, -1], + [-1, +1, -1], + [+1, -1, -1], + [-1, -1, -1], + ], + ) + + +def test_sampler_simple_sample_expectation_requirements(): + a = cirq.LineQubit(0) + sampler = cirq.Simulator(seed=1) + circuit = cirq.Circuit(cirq.H(a)) + obs = cirq.X(a) + with pytest.raises(ValueError, match='at least one sample'): + _ = sampler.sample_expectation_values(circuit, [obs], num_samples=0) + + with pytest.raises(ValueError, match='At least one observable'): + _ = sampler.sample_expectation_values(circuit, [], num_samples=1) + + circuit.append(cirq.measure(a, key='out')) + with pytest.raises(ValueError, match='permit_terminal_measurements'): + _ = sampler.sample_expectation_values(circuit, [obs], num_samples=1)