diff --git a/cirq-google/cirq_google/workflow/progress.py b/cirq-google/cirq_google/workflow/progress.py new file mode 100644 index 00000000000..07a0830a19b --- /dev/null +++ b/cirq-google/cirq_google/workflow/progress.py @@ -0,0 +1,64 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Progress and logging facilities for the quantum runtime.""" + +import abc +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import cirq_google as cg + + +class _WorkflowLogger(abc.ABC): + """Implementers of this class can provide logging and progress information + for execution loops.""" + + def initialize(self): + """Initialization logic at the start of an execution loop.""" + + def consume_result( + self, exe_result: 'cg.ExecutableResult', shared_rt_info: 'cg.SharedRuntimeInfo' + ): + """Consume executable results as they are completed. + + Args: + exe_result: The completed `cg.ExecutableResult`. + shared_rt_info: A reference to the `cg.SharedRuntimeInfo` for this + execution at this point. + """ + + def finalize(self): + """Finalization logic at the end of an execution loop.""" + + +class _PrintLogger(_WorkflowLogger): + def __init__(self, n_total: int): + self.n_total = n_total + self.i = 0 + + def initialize(self): + """Write a newline at the start of an execution loop.""" + print() + + def consume_result( + self, exe_result: 'cg.ExecutableResult', shared_rt_info: 'cg.SharedRuntimeInfo' + ): + """Print a simple count of completed executables.""" + print(f'\r{self.i + 1} / {self.n_total}', end='', flush=True) + self.i += 1 + + def finalize(self): + """Write a newline at the end of an execution loop.""" + print() diff --git a/cirq-google/cirq_google/workflow/progress_test.py b/cirq-google/cirq_google/workflow/progress_test.py new file mode 100644 index 00000000000..f6159c77603 --- /dev/null +++ b/cirq-google/cirq_google/workflow/progress_test.py @@ -0,0 +1,43 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cirq +import cirq_google as cg +from cirq_google.workflow.progress import _PrintLogger + + +def test_print_logger(capsys): + pl = _PrintLogger(n_total=10) + shared_rt_info = cg.SharedRuntimeInfo(run_id='hi mom') + pl.initialize() + for i in range(10): + exe_result = cg.ExecutableResult( + spec=None, + runtime_info=cg.RuntimeInfo(execution_index=i), + raw_data=cirq.Result(params=cirq.ParamResolver({}), measurements={}), + ) + pl.consume_result(exe_result, shared_rt_info) + pl.finalize() + assert capsys.readouterr().out == ( + '\n\r1 / 10' + '\r2 / 10' + '\r3 / 10' + '\r4 / 10' + '\r5 / 10' + '\r6 / 10' + '\r7 / 10' + '\r8 / 10' + '\r9 / 10' + '\r10 / 10\n' + ) diff --git a/cirq-google/cirq_google/workflow/quantum_runtime.py b/cirq-google/cirq_google/workflow/quantum_runtime.py index f05bd75ab4a..35e69372bf1 100644 --- a/cirq-google/cirq_google/workflow/quantum_runtime.py +++ b/cirq-google/cirq_google/workflow/quantum_runtime.py @@ -23,6 +23,7 @@ from cirq import _compat from cirq.protocols import dataclass_json_dict from cirq_google.workflow._abstract_engine_processor_shim import AbstractEngineProcessorShim +from cirq_google.workflow.progress import _PrintLogger from cirq_google.workflow.quantum_executable import ( ExecutableSpec, QuantumExecutableGroup, @@ -276,10 +277,9 @@ def execute( _update_updatable_files(egr_record, shared_rt_info, data_dir) executable_results = [] - # Loop over executables. sampler = rt_config.processor.get_sampler() - n_executables = len(executable_group) - print() + logger = _PrintLogger(n_total=len(executable_group)) + logger.initialize() for i, exe in enumerate(executable_group): runtime_info = RuntimeInfo(execution_index=i) @@ -305,8 +305,8 @@ def execute( egr_record.executable_result_paths.append(exe_result_path) _update_updatable_files(egr_record, shared_rt_info, data_dir) - print(f'\r{i + 1} / {n_executables}', end='', flush=True) - print() + logger.consume_result(exe_result, shared_rt_info) + logger.finalize() return ExecutableGroupResult( runtime_configuration=rt_config,