Skip to content

Commit

Permalink
Change Sampler methods to return Sequence instead of List (quantumlib…
Browse files Browse the repository at this point in the history
…#4807)

Sequence is a covariant type (Sequence[A] is a subclass of Sequence[B]
when A is a subclass of B) while List is invariant (List[A] is not a
subclass of List[B] even in A is a subclass of B). Returning
Sequence[cirq.Result] thus allows returning subtypes of cirq.Result,
while List[cirq.Result] does not.

Review: @viathor
  • Loading branch information
maffoo authored and MichaelBroughton committed Jan 22, 2022
1 parent dd65104 commit 3759795
Show file tree
Hide file tree
Showing 18 changed files with 76 additions and 74 deletions.
4 changes: 2 additions & 2 deletions cirq-aqt/cirq_aqt/aqt_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import json
import time
import uuid
from typing import List, Union, Tuple, Dict, cast
from typing import cast, Dict, List, Sequence, Tuple, Union

import numpy as np
from requests import put
Expand Down Expand Up @@ -179,7 +179,7 @@ def _send_json(

def run_sweep(
self, program: cirq.AbstractCircuit, params: cirq.Sweepable, repetitions: int = 1
) -> List[cirq.Result]:
) -> Sequence[cirq.Result]:
"""Samples from the given Circuit.
In contrast to run, this allows for sweeping over different parameter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Sequence
import pytest

import numpy as np
Expand Down Expand Up @@ -48,7 +48,7 @@ def run_sweep(
program: 'cirq.AbstractCircuit',
params: cirq.Sweepable,
repetitions: int = 1,
) -> List[cirq.Result]:
) -> Sequence[cirq.Result]:
results = self.simulator.run_sweep(program, params, repetitions)
for result in results:
for bits in result.measurements.values():
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/clifford/stabilizer_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Sequence

import numpy as np

Expand Down Expand Up @@ -40,7 +40,7 @@ def run_sweep(
program: 'cirq.AbstractCircuit',
params: 'cirq.Sweepable',
repetitions: int = 1,
) -> List['cirq.Result']:
) -> Sequence['cirq.Result']:
results: List[cirq.Result] = []
for param_resolver in cirq.to_resolvers(params):
resolved_circuit = cirq.resolve_parameters(program, param_resolver)
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Filename is a reference to multiplexing.
"""

from typing import List, Optional, Type, Union, cast, TYPE_CHECKING
from typing import cast, List, Optional, Sequence, Type, TYPE_CHECKING, Union

import numpy as np

Expand Down Expand Up @@ -169,7 +169,7 @@ def sample_sweep(
repetitions: int = 1,
dtype: Type[np.number] = np.complex64,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> List['cirq.Result']:
) -> Sequence['cirq.Result']:
"""Runs the supplied Circuit, mimicking quantum hardware.
In contrast to run, this allows for sweeping over different parameter
Expand Down
14 changes: 7 additions & 7 deletions cirq-core/cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@
import collections
from typing import (
Any,
Callable,
cast,
Dict,
Generic,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
Optional,
TYPE_CHECKING,
Set,
cast,
Callable,
TypeVar,
Generic,
Union,
)

import numpy as np
Expand Down Expand Up @@ -73,7 +73,7 @@ def run_sweep(
program: 'cirq.AbstractCircuit',
params: 'cirq.Sweepable',
repetitions: int = 1,
) -> List['cirq.Result']:
) -> Sequence['cirq.Result']:
return list(self.run_sweep_iter(program, params, repetitions))

def run_sweep_iter(
Expand Down
12 changes: 6 additions & 6 deletions cirq-core/cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
"""Abstract base class for things sampling quantum circuits."""

import abc
from typing import List, Optional, TYPE_CHECKING, Union, Dict, FrozenSet, Tuple
from typing import Sequence
from typing import Dict, FrozenSet, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import pandas as pd

from cirq import study, ops
from cirq.work.observable_measurement import (
measure_observables,
Expand Down Expand Up @@ -148,7 +148,7 @@ def run_sweep(
program: 'cirq.AbstractCircuit',
params: 'cirq.Sweepable',
repetitions: int = 1,
) -> List['cirq.Result']:
) -> Sequence['cirq.Result']:
"""Samples from the given Circuit.
In contrast to run, this allows for sweeping over different parameter
Expand Down Expand Up @@ -187,7 +187,7 @@ async def run_sweep_async(
program: 'cirq.AbstractCircuit',
params: 'cirq.Sweepable',
repetitions: int = 1,
) -> List['cirq.Result']:
) -> Sequence['cirq.Result']:
"""Asynchronously sweeps and samples from the given Circuit.
By default, this method invokes `run_sweep` synchronously and simply
Expand All @@ -211,7 +211,7 @@ def run_batch(
programs: Sequence['cirq.AbstractCircuit'],
params_list: Optional[List['cirq.Sweepable']] = None,
repetitions: Union[int, List[int]] = 1,
) -> List[List['cirq.Result']]:
) -> Sequence[Sequence['cirq.Result']]:
"""Runs the supplied circuits.
Each circuit provided in `programs` will pair with the optional
Expand Down Expand Up @@ -277,7 +277,7 @@ def sample_expectation_values(
num_samples: int,
params: 'cirq.Sweepable' = None,
permit_terminal_measurements: bool = False,
) -> List[List[float]]:
) -> Sequence[Sequence[float]]:
"""Calculates estimated expectation values from samples of a circuit.
Please see also `cirq.work.measure_observables` for more control over how to measure
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/work/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for cirq.Sampler."""
from typing import List
from typing import Sequence

import pytest

Expand Down Expand Up @@ -225,7 +225,7 @@ def run_sweep(
program: 'cirq.AbstractCircuit',
params: 'cirq.Sweepable',
repetitions: int = 1,
) -> List['cirq.Result']:
) -> Sequence['cirq.Result']:
results = np.zeros((repetitions, 1), dtype=bool)
for idx in range(repetitions // 4):
results[idx][0] = 1
Expand Down
5 changes: 3 additions & 2 deletions cirq-google/cirq_google/api/v2/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Iterator,
List,
Optional,
Sequence,
Set,
)
from collections import OrderedDict
Expand Down Expand Up @@ -162,7 +163,7 @@ def results_to_proto(
def results_from_proto(
msg: result_pb2.Result,
measurements: List[MeasureInfo] = None,
) -> List[List[cirq.Result]]:
) -> Sequence[Sequence[cirq.Result]]:
"""Converts a v2 result proto into List of list of trial results.
Args:
Expand All @@ -185,7 +186,7 @@ def results_from_proto(
def _trial_sweep_from_proto(
msg: result_pb2.SweepResult,
measure_map: Dict[str, MeasureInfo] = None,
) -> List[cirq.Result]:
) -> Sequence[cirq.Result]:
"""Converts a SweepResult proto into List of list of trial results.
Args:
Expand Down
10 changes: 5 additions & 5 deletions cirq-google/cirq_google/engine/abstract_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""A helper for jobs that have been created on the Quantum Engine."""

import abc
from typing import Dict, Iterator, List, Optional, overload, Tuple, TYPE_CHECKING
from typing import Dict, Iterator, List, Optional, overload, Sequence, Tuple, TYPE_CHECKING

import cirq
from cirq_google.engine.client import quantum
Expand Down Expand Up @@ -161,7 +161,7 @@ def delete(self) -> Optional[bool]:
"""Deletes the job and result, if any."""

@abc.abstractmethod
def batched_results(self) -> List[List[cirq.Result]]:
def batched_results(self) -> Sequence[Sequence[cirq.Result]]:
"""Returns the job results, blocking until the job is complete.
This method is intended for batched jobs. Instead of flattening
Expand All @@ -170,11 +170,11 @@ def batched_results(self) -> List[List[cirq.Result]]:
"""

@abc.abstractmethod
def results(self) -> List[cirq.Result]:
def results(self) -> Sequence[cirq.Result]:
"""Returns the job results, blocking until the job is complete."""

@abc.abstractmethod
def calibration_results(self) -> List['calibration_result.CalibrationResult']:
def calibration_results(self) -> Sequence['calibration_result.CalibrationResult']:
"""Returns the results of a run_calibration() call.
This function will fail if any other type of results were returned.
Expand All @@ -189,7 +189,7 @@ def __getitem__(self, item: int) -> cirq.Result:
pass

@overload
def __getitem__(self, item: slice) -> List[cirq.Result]:
def __getitem__(self, item: slice) -> Sequence[cirq.Result]:
pass

def __getitem__(self, item):
Expand Down
8 changes: 4 additions & 4 deletions cirq-google/cirq_google/engine/abstract_local_job_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""A helper for jobs that have been created on the Quantum Engine."""
from typing import List, Optional, Tuple
from typing import Optional, Sequence, Tuple
import datetime
import cirq

Expand Down Expand Up @@ -40,13 +40,13 @@ def cancel(self) -> None:
def delete(self) -> None:
pass

def batched_results(self) -> List[List[cirq.Result]]:
def batched_results(self) -> Sequence[Sequence[cirq.Result]]:
return [] # coverage: ignore

def results(self) -> List[cirq.Result]:
def results(self) -> Sequence[cirq.Result]:
return [] # coverage: ignore

def calibration_results(self) -> List[CalibrationResult]:
def calibration_results(self) -> Sequence[CalibrationResult]:
return [] # coverage: ignore


Expand Down
31 changes: 16 additions & 15 deletions cirq-google/cirq_google/engine/engine_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import datetime
import time

from typing import Dict, Iterator, List, Optional, overload, Tuple, TYPE_CHECKING
from typing import Dict, Iterator, List, Optional, overload, Sequence, Tuple, TYPE_CHECKING

import cirq
from cirq_google.engine import abstract_job, calibration, engine_client
Expand Down Expand Up @@ -79,9 +79,9 @@ def __init__(
self.job_id = job_id
self.context = context
self._job = _job
self._results: Optional[List[cirq.Result]] = None
self._calibration_results: Optional[CalibrationResult] = None
self._batched_results: Optional[List[List[cirq.Result]]] = None
self._results: Optional[Sequence[cirq.Result]] = None
self._calibration_results: Optional[Sequence[CalibrationResult]] = None
self._batched_results: Optional[Sequence[Sequence[cirq.Result]]] = None
self.result_type = result_type

def id(self) -> str:
Expand Down Expand Up @@ -258,11 +258,11 @@ def delete(self) -> None:
"""Deletes the job and result, if any."""
self.context.client.delete_job(self.project_id, self.program_id, self.job_id)

def batched_results(self) -> List[List[cirq.Result]]:
def batched_results(self) -> Sequence[Sequence[cirq.Result]]:
"""Returns the job results, blocking until the job is complete.
This method is intended for batched jobs. Instead of flattening
results into a single list, this will return a List[Result]
results into a single list, this will return a Sequence[Result]
for each circuit in the batch.
"""
self.results()
Expand All @@ -288,7 +288,7 @@ def _wait_for_result(self):
)
return response.result

def results(self) -> List[cirq.Result]:
def results(self) -> Sequence[cirq.Result]:
"""Returns the job results, blocking until the job is complete."""
import cirq_google.engine.engine as engine_base

Expand All @@ -315,7 +315,7 @@ def results(self) -> List[cirq.Result]:
raise ValueError(f'invalid result proto version: {result_type}')
return self._results

def calibration_results(self):
def calibration_results(self) -> Sequence[CalibrationResult]:
"""Returns the results of a run_calibration() call.
This function will fail if any other type of results were returned
Expand All @@ -334,24 +334,25 @@ def calibration_results(self):
metrics = calibration.Calibration(layer.metrics)
message = layer.error_message or None
token = layer.token or None
ts: Optional[datetime.datetime] = None
if layer.valid_until_ms > 0:
ts = datetime.datetime.fromtimestamp(layer.valid_until_ms / 1000)
else:
ts = None
cal_results.append(CalibrationResult(layer.code, message, token, ts, metrics))
self._calibration_results = cal_results
return self._calibration_results

@classmethod
def _get_batch_results_v2(cls, results: v2.batch_pb2.BatchResult) -> List[List[cirq.Result]]:
def _get_batch_results_v2(
cls, results: v2.batch_pb2.BatchResult
) -> Sequence[Sequence[cirq.Result]]:
trial_results = []
for result in results.results:
# Add a new list for the result
trial_results.append(_get_job_results_v2(result))
return trial_results

@classmethod
def _flatten(cls, result) -> List[cirq.Result]:
def _flatten(cls, result) -> Sequence[cirq.Result]:
return [res for result_list in result for res in result_list]

def __iter__(self) -> Iterator[cirq.Result]:
Expand All @@ -363,7 +364,7 @@ def __getitem__(self, item: int) -> cirq.Result:
pass

@overload
def __getitem__(self, item: slice) -> List[cirq.Result]:
def __getitem__(self, item: slice) -> Sequence[cirq.Result]:
pass

def __getitem__(self, item):
Expand Down Expand Up @@ -403,7 +404,7 @@ def _deserialize_run_context(
raise ValueError(f'unsupported run_context type: {run_context_type}')


def _get_job_results_v1(result: v1.program_pb2.Result) -> List[cirq.Result]:
def _get_job_results_v1(result: v1.program_pb2.Result) -> Sequence[cirq.Result]:
trial_results = []
for sweep_result in result.sweep_results:
sweep_repetitions = sweep_result.repetitions
Expand All @@ -421,7 +422,7 @@ def _get_job_results_v1(result: v1.program_pb2.Result) -> List[cirq.Result]:
return trial_results


def _get_job_results_v2(result: v2.result_pb2.Result) -> List[cirq.Result]:
def _get_job_results_v2(result: v2.result_pb2.Result) -> Sequence[cirq.Result]:
sweep_results = v2.results_from_proto(result)
# Flatten to single list to match to sampler api.
return [trial_result for sweep_result in sweep_results for trial_result in sweep_result]
Expand Down
6 changes: 3 additions & 3 deletions cirq-google/cirq_google/engine/engine_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run_sweep(
program: Union[cirq.AbstractCircuit, 'cirq_google.EngineProgram'],
params: cirq.Sweepable,
repetitions: int = 1,
) -> List[cirq.Result]:
) -> Sequence[cirq.Result]:
if isinstance(program, engine.EngineProgram):
job = program.run_sweep(
params=params, repetitions=repetitions, processor_ids=self._processor_ids
Expand All @@ -70,10 +70,10 @@ def run_sweep(

def run_batch(
self,
programs: Sequence['cirq.AbstractCircuit'],
programs: Sequence[cirq.AbstractCircuit],
params_list: Optional[List[cirq.Sweepable]] = None,
repetitions: Union[int, List[int]] = 1,
) -> List[List[cirq.Result]]:
) -> Sequence[Sequence[cirq.Result]]:
"""Runs the supplied circuits.
In order to gain a speedup from using this method instead of other run
Expand Down
Loading

0 comments on commit 3759795

Please sign in to comment.