diff --git a/cirq-google/cirq_google/engine/abstract_job.py b/cirq-google/cirq_google/engine/abstract_job.py index abbc21e70ed..122d27d34ab 100644 --- a/cirq-google/cirq_google/engine/abstract_job.py +++ b/cirq-google/cirq_google/engine/abstract_job.py @@ -181,7 +181,7 @@ def calibration_results(self) -> Sequence['calibration_result.CalibrationResult' """ def __iter__(self) -> Iterator[cirq.Result]: - return iter(self.results()) + yield from self.results() # pylint: disable=function-redefined @overload diff --git a/cirq-google/cirq_google/engine/abstract_job_test.py b/cirq-google/cirq_google/engine/abstract_job_test.py index fe723d6a655..0dcb5e17f83 100644 --- a/cirq-google/cirq_google/engine/abstract_job_test.py +++ b/cirq-google/cirq_google/engine/abstract_job_test.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, List, TYPE_CHECKING +import pytest +import numpy as np +import cirq from cirq_google.engine.abstract_job import AbstractJob if TYPE_CHECKING: @@ -83,7 +86,9 @@ def batched_results(self): pass def results(self): - return list(range(5)) + return list( + cirq.ResultDict(params={}, measurements={'a': np.asarray([t])}) for t in range(5) + ) def calibration_results(self): pass @@ -91,9 +96,30 @@ def calibration_results(self): def test_instantiation_and_iteration(): job = MockJob() + + # Test length assert len(job) == 5 - assert job[3] == 3 + + # Test direct indexing + assert job[3].measurements['a'][0] == 3 + + # Test iterating through for loop count = 0 - for num in job: - assert num == count + for result in job: + assert result.measurements['a'][0] == count count += 1 + + # Test iterator using iterator + iterator = iter(job) + result = next(iterator) + assert result.measurements['a'][0] == 0 + result = next(iterator) + assert result.measurements['a'][0] == 1 + result = next(iterator) + assert result.measurements['a'][0] == 2 + result = next(iterator) + assert result.measurements['a'][0] == 3 + result = next(iterator) + assert result.measurements['a'][0] == 4 + with pytest.raises(StopIteration): + next(iterator)