Skip to content

Commit

Permalink
[cirqflow] Factor out PrintLogger (#4613)
Browse files Browse the repository at this point in the history
In anticipation of making this more configurable, define an abstract base class for logging/progress. This factors out the existing lightweight "print" based logging. In the future, we can accept objects satisfying this interface in the signature for `execute`.
  • Loading branch information
mpharrigan authored Nov 9, 2021
1 parent 7c36cef commit cf06adf
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 5 deletions.
64 changes: 64 additions & 0 deletions cirq-google/cirq_google/workflow/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Progress and logging facilities for the quantum runtime."""

import abc
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import cirq_google as cg


class _WorkflowLogger(abc.ABC):
"""Implementers of this class can provide logging and progress information
for execution loops."""

def initialize(self):
"""Initialization logic at the start of an execution loop."""

def consume_result(
self, exe_result: 'cg.ExecutableResult', shared_rt_info: 'cg.SharedRuntimeInfo'
):
"""Consume executable results as they are completed.
Args:
exe_result: The completed `cg.ExecutableResult`.
shared_rt_info: A reference to the `cg.SharedRuntimeInfo` for this
execution at this point.
"""

def finalize(self):
"""Finalization logic at the end of an execution loop."""


class _PrintLogger(_WorkflowLogger):
def __init__(self, n_total: int):
self.n_total = n_total
self.i = 0

def initialize(self):
"""Write a newline at the start of an execution loop."""
print()

def consume_result(
self, exe_result: 'cg.ExecutableResult', shared_rt_info: 'cg.SharedRuntimeInfo'
):
"""Print a simple count of completed executables."""
print(f'\r{self.i + 1} / {self.n_total}', end='', flush=True)
self.i += 1

def finalize(self):
"""Write a newline at the end of an execution loop."""
print()
43 changes: 43 additions & 0 deletions cirq-google/cirq_google/workflow/progress_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq
import cirq_google as cg
from cirq_google.workflow.progress import _PrintLogger


def test_print_logger(capsys):
pl = _PrintLogger(n_total=10)
shared_rt_info = cg.SharedRuntimeInfo(run_id='hi mom')
pl.initialize()
for i in range(10):
exe_result = cg.ExecutableResult(
spec=None,
runtime_info=cg.RuntimeInfo(execution_index=i),
raw_data=cirq.Result(params=cirq.ParamResolver({}), measurements={}),
)
pl.consume_result(exe_result, shared_rt_info)
pl.finalize()
assert capsys.readouterr().out == (
'\n\r1 / 10'
'\r2 / 10'
'\r3 / 10'
'\r4 / 10'
'\r5 / 10'
'\r6 / 10'
'\r7 / 10'
'\r8 / 10'
'\r9 / 10'
'\r10 / 10\n'
)
10 changes: 5 additions & 5 deletions cirq-google/cirq_google/workflow/quantum_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from cirq import _compat
from cirq.protocols import dataclass_json_dict
from cirq_google.workflow._abstract_engine_processor_shim import AbstractEngineProcessorShim
from cirq_google.workflow.progress import _PrintLogger
from cirq_google.workflow.quantum_executable import (
ExecutableSpec,
QuantumExecutableGroup,
Expand Down Expand Up @@ -276,10 +277,9 @@ def execute(
_update_updatable_files(egr_record, shared_rt_info, data_dir)
executable_results = []

# Loop over executables.
sampler = rt_config.processor.get_sampler()
n_executables = len(executable_group)
print()
logger = _PrintLogger(n_total=len(executable_group))
logger.initialize()
for i, exe in enumerate(executable_group):
runtime_info = RuntimeInfo(execution_index=i)

Expand All @@ -305,8 +305,8 @@ def execute(
egr_record.executable_result_paths.append(exe_result_path)

_update_updatable_files(egr_record, shared_rt_info, data_dir)
print(f'\r{i + 1} / {n_executables}', end='', flush=True)
print()
logger.consume_result(exe_result, shared_rt_info)
logger.finalize()

return ExecutableGroupResult(
runtime_configuration=rt_config,
Expand Down

0 comments on commit cf06adf

Please sign in to comment.