diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index 8f58f3c35..6127c3745 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -4,6 +4,7 @@ import atexit import functools import logging +import pickle import signal import subprocess import sys @@ -312,7 +313,15 @@ def submit( return task.task_id def status(self, task_id: uuid.UUID) -> CompilationStatus: - """Retrieve the status of the specified task.""" + """ + Retrieve the status of the specified task. + + Args: + task_id (uuid.UUID): The ID of the task to check. + + Returns: + CompilationStatus: The status of the task. + """ msg, payload = self._send_recv(RuntimeMessage.STATUS, task_id) if msg != RuntimeMessage.STATUS: raise RuntimeError(f'Unexpected message type: {msg}.') @@ -439,9 +448,15 @@ def _recv_handle_log_error(self) -> tuple[RuntimeMessage, Any]: msg, payload = self.conn.recv() if msg == RuntimeMessage.LOG: - logger = logging.getLogger(payload.name) - if logger.isEnabledFor(payload.levelno): - logger.handle(payload) + record = pickle.loads(payload) + if isinstance(record, logging.LogRecord): + logger = logging.getLogger(record.name) + if logger.isEnabledFor(record.levelno): + logger.handle(record) + else: + name, levelno, msg = record + logger = logging.getLogger(name) + logger.log(levelno, msg) elif msg == RuntimeMessage.ERROR: raise RuntimeError(payload) diff --git a/bqskit/passes/__init__.py b/bqskit/passes/__init__.py index 9a386f6fb..389d42ec5 100644 --- a/bqskit/passes/__init__.py +++ b/bqskit/passes/__init__.py @@ -28,7 +28,7 @@ :toctree: autogen :recursive: - DiagonalSynthesisPass + WalshDiagonalSynthesisPass LEAPSynthesisPass QSearchSynthesisPass QFASTDecompositionPass @@ -138,6 +138,10 @@ These passes either perform upper-bound error analysis of the PAM process. +.. autosummary:: + :toctree: autogen + :recursive: + TagPAMBlockDataPass CalculatePAMErrorsPass UnTagPAMBlockDataPass @@ -285,6 +289,7 @@ from bqskit.passes.search.heuristics.astar import AStarHeuristic from bqskit.passes.search.heuristics.dijkstra import DijkstraHeuristic from bqskit.passes.search.heuristics.greedy import GreedyHeuristic +from bqskit.passes.synthesis.diagonal import WalshDiagonalSynthesisPass from bqskit.passes.synthesis.leap import LEAPSynthesisPass from bqskit.passes.synthesis.pas import PermutationAwareSynthesisPass from bqskit.passes.synthesis.qfast import QFASTDecompositionPass @@ -322,7 +327,7 @@ 'ScanPartitioner', 'QuickPartitioner', 'SynthesisPass', - 'DiagonalSynthesisPass', + 'WalshDiagonalSynthesisPass', 'LEAPSynthesisPass', 'QSearchSynthesisPass', 'QFASTDecompositionPass', diff --git a/bqskit/passes/synthesis/leap.py b/bqskit/passes/synthesis/leap.py index da7475fae..f05300eef 100644 --- a/bqskit/passes/synthesis/leap.py +++ b/bqskit/passes/synthesis/leap.py @@ -196,7 +196,7 @@ async def synthesize( # Evalute initial layer if best_dist < self.success_threshold: - _logger.debug('Successful synthesis.') + _logger.debug('Successful synthesis with 0 layers.') return initial_layer # Main loop @@ -222,7 +222,9 @@ async def synthesize( dist = self.cost.calc_cost(circuit, utry) if dist < self.success_threshold: - _logger.debug('Successful synthesis.') + _logger.debug( + f'Successful synthesis with {layer + 1} layers.', + ) if self.store_partial_solutions: data['psols'] = psols return circuit diff --git a/bqskit/passes/synthesis/qfast.py b/bqskit/passes/synthesis/qfast.py index 9a72bd79c..e4e036fb1 100644 --- a/bqskit/passes/synthesis/qfast.py +++ b/bqskit/passes/synthesis/qfast.py @@ -164,7 +164,7 @@ async def synthesize( if dist < self.success_threshold: self.finalize(circuit, utry, instantiate_options) - _logger.info('Successful synthesis.') + _logger.info(f'Successful synthesis with {depth} layers.') return circuit # Expand or restrict head diff --git a/bqskit/passes/synthesis/qsearch.py b/bqskit/passes/synthesis/qsearch.py index 13276ad82..9cad4fc44 100644 --- a/bqskit/passes/synthesis/qsearch.py +++ b/bqskit/passes/synthesis/qsearch.py @@ -171,7 +171,7 @@ async def synthesize( # Evalute initial layer if best_dist < self.success_threshold: - _logger.debug('Successful synthesis.') + _logger.debug('Successful synthesis with 0 layers.') return initial_layer # Main loop @@ -197,7 +197,9 @@ async def synthesize( dist = self.cost.calc_cost(circuit, utry) if dist < self.success_threshold: - _logger.debug('Successful synthesis.') + _logger.debug( + f'Successful synthesis with {layer + 1} layers.', + ) if self.store_partial_solutions: data['psols'] = psols return circuit @@ -210,7 +212,7 @@ async def synthesize( ) best_dist = dist best_circ = circuit - best_layer = layer + best_layer = layer + 1 if self.store_partial_solutions: if layer not in psols: diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index d10996e49..bcc9f03fc 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -410,7 +410,13 @@ def send_outgoing(self) -> None: if outgoing[0].closed: continue - outgoing[0].send((outgoing[1], outgoing[2])) + try: + outgoing[0].send((outgoing[1], outgoing[2])) + except (EOFError, ConnectionResetError): + self.handle_disconnect(outgoing[0]) + _logger.warning('Connection reset while sending message.') + continue + if _logger.isEnabledFor(logging.DEBUG): to = self.get_to_string(outgoing[0]) _logger.debug(f'Sent message {outgoing[1].name} to {to}.') diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 8740c7170..ea32afbd6 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -8,7 +8,6 @@ import time import uuid from dataclasses import dataclass -from logging import LogRecord from multiprocessing.connection import Connection from multiprocessing.connection import Listener from threading import Thread @@ -256,8 +255,19 @@ def handle_disconnect(self, conn: Connection) -> None: """Disconnect a client connection from the runtime.""" super().handle_disconnect(conn) tasks = self.clients.pop(conn) + for task_id in tasks: self.handle_cancel_comp_task(task_id) + + tasks_to_pop = [] + for (task, (tid, other_conn)) in self.tasks.items(): + if other_conn == conn: + tasks_to_pop.append((task_id, tid)) + + for task_id, tid in tasks_to_pop: + self.tasks.pop(task_id) + self.mailbox_to_task_dict.pop(tid) + _logger.info('Unregistered client.') def handle_new_comp_task( @@ -386,6 +396,9 @@ def handle_error(self, error_payload: tuple[int, str]) -> None: raise RuntimeError(error_payload) tid = error_payload[0] + if tid not in self.mailbox_to_task_dict: + return # Silently discard errors from cancelled tasks + conn = self.tasks[self.mailbox_to_task_dict[tid]][1] self.outgoing.put((conn, RuntimeMessage.ERROR, error_payload[1])) # TODO: Broadcast cancel to all tasks with compilation task id tid @@ -395,9 +408,12 @@ def handle_error(self, error_payload: tuple[int, str]) -> None: # still cancel here incase the client catches the error and # resubmits a job. - def handle_log(self, log_payload: tuple[int, LogRecord]) -> None: + def handle_log(self, log_payload: tuple[int, bytes]) -> None: """Forward logs to appropriate client.""" tid = log_payload[0] + if tid not in self.mailbox_to_task_dict: + return # Silently discard logs from cancelled tasks + conn = self.tasks[self.mailbox_to_task_dict[tid]][1] self.outgoing.put((conn, RuntimeMessage.LOG, log_payload[1])) diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 1684f7dbc..e61b13009 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -4,6 +4,7 @@ import argparse import logging import os +import pickle import signal import sys import time @@ -225,7 +226,15 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: record.msg += con_str record.msg += ']' tid = active_task.comp_task_id - self._conn.send((RuntimeMessage.LOG, (tid, record))) + try: + serial = pickle.dumps(record) + except (pickle.PicklingError, TypeError): + serial = pickle.dumps(( + record.name, + record.levelno, + record.getMessage(), + )) + self._conn.send((RuntimeMessage.LOG, (tid, serial))) return record logging.setLogRecordFactory(record_factory) diff --git a/docs/conf.py b/docs/conf.py index f48262b93..45036da82 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -38,6 +38,7 @@ 'myst_parser', 'jupyter_sphinx', 'nbsphinx', + 'sphinx_autodoc_typehints', ] # Add any paths that contain templates here, relative to this directory. diff --git a/docs/requirements.txt b/docs/requirements.txt index 45bb175e7..a88d4eaa7 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,6 +2,7 @@ Sphinx>=4.5.0 sphinx-autodoc-typehints>=1.12.0 sphinx-rtd-theme>=1.0.0 sphinx-togglebutton>=0.2.3 +sphinx-autodoc-typehints>=2.3.0 sphinxcontrib-applehelp>=1.0.2 sphinxcontrib-devhelp>=1.0.2 sphinxcontrib-htmlhelp>=2.0.0 diff --git a/tests/runtime/test_logging.py b/tests/runtime/test_logging.py index bee4bfb0a..4c8694439 100644 --- a/tests/runtime/test_logging.py +++ b/tests/runtime/test_logging.py @@ -2,7 +2,9 @@ from __future__ import annotations import logging +import pickle from io import StringIO +from typing import Any import pytest @@ -143,6 +145,55 @@ def test_using_external_logging(server_compiler: Compiler) -> None: logger.setLevel(logging.WARNING) +class ExternalWithArgsPass(BasePass): + async def run(self, circuit: Circuit, data: PassData) -> None: + logging.getLogger('dummy2').debug('int %d', 1) + + +def test_external_logging_with_args(server_compiler: Compiler) -> None: + logger = logging.getLogger('dummy2') + logger.setLevel(logging.DEBUG) + handler = logging.StreamHandler(StringIO()) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) + server_compiler.compile(Circuit(1), [ExternalWithArgsPass()]) + log = handler.stream.getvalue() + assert 'int 1' in log + logger.removeHandler(handler) + logger.setLevel(logging.WARNING) + + +class NonSerializable: + def __reduce__(self) -> str | tuple[Any, ...]: + raise pickle.PicklingError('This class is not serializable') + + def __str__(self) -> str: + return 'NonSerializable' + + +class ExternalWithNonSerializableArgsPass(BasePass): + async def run(self, circuit: Circuit, data: PassData) -> None: + logging.getLogger('dummy2').debug( + 'NonSerializable %s', + NonSerializable(), + ) + + +def test_external_logging_with_nonserializable_args( + server_compiler: Compiler, +) -> None: + logger = logging.getLogger('dummy2') + logger.setLevel(logging.DEBUG) + handler = logging.StreamHandler(StringIO()) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) + server_compiler.compile(Circuit(1), [ExternalWithNonSerializableArgsPass()]) + log = handler.stream.getvalue() + assert 'NonSerializable NonSerializable' in log + logger.removeHandler(handler) + logger.setLevel(logging.WARNING) + + @pytest.mark.parametrize('level', [-1, 0, 1, 2, 3, 4]) def test_limiting_nested_calls_enable_logging( server_compiler: Compiler,