Skip to content

Commit

Permalink
Merge pull request #271 from BQSKit/todo-cleanup
Browse files Browse the repository at this point in the history
Todo cleanup
  • Loading branch information
edyounis authored Sep 8, 2024
2 parents 45293ae + f061bef commit e05f947
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 16 deletions.
23 changes: 19 additions & 4 deletions bqskit/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import atexit
import functools
import logging
import pickle
import signal
import subprocess
import sys
Expand Down Expand Up @@ -312,7 +313,15 @@ def submit(
return task.task_id

def status(self, task_id: uuid.UUID) -> CompilationStatus:
"""Retrieve the status of the specified task."""
"""
Retrieve the status of the specified task.
Args:
task_id (uuid.UUID): The ID of the task to check.
Returns:
CompilationStatus: The status of the task.
"""
msg, payload = self._send_recv(RuntimeMessage.STATUS, task_id)
if msg != RuntimeMessage.STATUS:
raise RuntimeError(f'Unexpected message type: {msg}.')
Expand Down Expand Up @@ -439,9 +448,15 @@ def _recv_handle_log_error(self) -> tuple[RuntimeMessage, Any]:
msg, payload = self.conn.recv()

if msg == RuntimeMessage.LOG:
logger = logging.getLogger(payload.name)
if logger.isEnabledFor(payload.levelno):
logger.handle(payload)
record = pickle.loads(payload)
if isinstance(record, logging.LogRecord):
logger = logging.getLogger(record.name)
if logger.isEnabledFor(record.levelno):
logger.handle(record)
else:
name, levelno, msg = record
logger = logging.getLogger(name)
logger.log(levelno, msg)

elif msg == RuntimeMessage.ERROR:
raise RuntimeError(payload)
Expand Down
9 changes: 7 additions & 2 deletions bqskit/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
:toctree: autogen
:recursive:
DiagonalSynthesisPass
WalshDiagonalSynthesisPass
LEAPSynthesisPass
QSearchSynthesisPass
QFASTDecompositionPass
Expand Down Expand Up @@ -138,6 +138,10 @@
These passes either perform upper-bound error analysis of the PAM process.
.. autosummary::
:toctree: autogen
:recursive:
TagPAMBlockDataPass
CalculatePAMErrorsPass
UnTagPAMBlockDataPass
Expand Down Expand Up @@ -285,6 +289,7 @@
from bqskit.passes.search.heuristics.astar import AStarHeuristic
from bqskit.passes.search.heuristics.dijkstra import DijkstraHeuristic
from bqskit.passes.search.heuristics.greedy import GreedyHeuristic
from bqskit.passes.synthesis.diagonal import WalshDiagonalSynthesisPass
from bqskit.passes.synthesis.leap import LEAPSynthesisPass
from bqskit.passes.synthesis.pas import PermutationAwareSynthesisPass
from bqskit.passes.synthesis.qfast import QFASTDecompositionPass
Expand Down Expand Up @@ -322,7 +327,7 @@
'ScanPartitioner',
'QuickPartitioner',
'SynthesisPass',
'DiagonalSynthesisPass',
'WalshDiagonalSynthesisPass',
'LEAPSynthesisPass',
'QSearchSynthesisPass',
'QFASTDecompositionPass',
Expand Down
6 changes: 4 additions & 2 deletions bqskit/passes/synthesis/leap.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def synthesize(

# Evalute initial layer
if best_dist < self.success_threshold:
_logger.debug('Successful synthesis.')
_logger.debug('Successful synthesis with 0 layers.')
return initial_layer

# Main loop
Expand All @@ -222,7 +222,9 @@ async def synthesize(
dist = self.cost.calc_cost(circuit, utry)

if dist < self.success_threshold:
_logger.debug('Successful synthesis.')
_logger.debug(
f'Successful synthesis with {layer + 1} layers.',
)
if self.store_partial_solutions:
data['psols'] = psols
return circuit
Expand Down
2 changes: 1 addition & 1 deletion bqskit/passes/synthesis/qfast.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async def synthesize(

if dist < self.success_threshold:
self.finalize(circuit, utry, instantiate_options)
_logger.info('Successful synthesis.')
_logger.info(f'Successful synthesis with {depth} layers.')
return circuit

# Expand or restrict head
Expand Down
8 changes: 5 additions & 3 deletions bqskit/passes/synthesis/qsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ async def synthesize(

# Evalute initial layer
if best_dist < self.success_threshold:
_logger.debug('Successful synthesis.')
_logger.debug('Successful synthesis with 0 layers.')
return initial_layer

# Main loop
Expand All @@ -197,7 +197,9 @@ async def synthesize(
dist = self.cost.calc_cost(circuit, utry)

if dist < self.success_threshold:
_logger.debug('Successful synthesis.')
_logger.debug(
f'Successful synthesis with {layer + 1} layers.',
)
if self.store_partial_solutions:
data['psols'] = psols
return circuit
Expand All @@ -210,7 +212,7 @@ async def synthesize(
)
best_dist = dist
best_circ = circuit
best_layer = layer
best_layer = layer + 1

if self.store_partial_solutions:
if layer not in psols:
Expand Down
8 changes: 7 additions & 1 deletion bqskit/runtime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,13 @@ def send_outgoing(self) -> None:
if outgoing[0].closed:
continue

outgoing[0].send((outgoing[1], outgoing[2]))
try:
outgoing[0].send((outgoing[1], outgoing[2]))
except (EOFError, ConnectionResetError):
self.handle_disconnect(outgoing[0])
_logger.warning('Connection reset while sending message.')
continue

if _logger.isEnabledFor(logging.DEBUG):
to = self.get_to_string(outgoing[0])
_logger.debug(f'Sent message {outgoing[1].name} to {to}.')
Expand Down
20 changes: 18 additions & 2 deletions bqskit/runtime/detached.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import time
import uuid
from dataclasses import dataclass
from logging import LogRecord
from multiprocessing.connection import Connection
from multiprocessing.connection import Listener
from threading import Thread
Expand Down Expand Up @@ -256,8 +255,19 @@ def handle_disconnect(self, conn: Connection) -> None:
"""Disconnect a client connection from the runtime."""
super().handle_disconnect(conn)
tasks = self.clients.pop(conn)

for task_id in tasks:
self.handle_cancel_comp_task(task_id)

tasks_to_pop = []
for (task, (tid, other_conn)) in self.tasks.items():
if other_conn == conn:
tasks_to_pop.append((task_id, tid))

for task_id, tid in tasks_to_pop:
self.tasks.pop(task_id)
self.mailbox_to_task_dict.pop(tid)

_logger.info('Unregistered client.')

def handle_new_comp_task(
Expand Down Expand Up @@ -386,6 +396,9 @@ def handle_error(self, error_payload: tuple[int, str]) -> None:
raise RuntimeError(error_payload)

tid = error_payload[0]
if tid not in self.mailbox_to_task_dict:
return # Silently discard errors from cancelled tasks

conn = self.tasks[self.mailbox_to_task_dict[tid]][1]
self.outgoing.put((conn, RuntimeMessage.ERROR, error_payload[1]))
# TODO: Broadcast cancel to all tasks with compilation task id tid
Expand All @@ -395,9 +408,12 @@ def handle_error(self, error_payload: tuple[int, str]) -> None:
# still cancel here incase the client catches the error and
# resubmits a job.

def handle_log(self, log_payload: tuple[int, LogRecord]) -> None:
def handle_log(self, log_payload: tuple[int, bytes]) -> None:
"""Forward logs to appropriate client."""
tid = log_payload[0]
if tid not in self.mailbox_to_task_dict:
return # Silently discard logs from cancelled tasks

conn = self.tasks[self.mailbox_to_task_dict[tid]][1]
self.outgoing.put((conn, RuntimeMessage.LOG, log_payload[1]))

Expand Down
11 changes: 10 additions & 1 deletion bqskit/runtime/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import argparse
import logging
import os
import pickle
import signal
import sys
import time
Expand Down Expand Up @@ -225,7 +226,15 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
record.msg += con_str
record.msg += ']'
tid = active_task.comp_task_id
self._conn.send((RuntimeMessage.LOG, (tid, record)))
try:
serial = pickle.dumps(record)
except (pickle.PicklingError, TypeError):
serial = pickle.dumps((
record.name,
record.levelno,
record.getMessage(),
))
self._conn.send((RuntimeMessage.LOG, (tid, serial)))
return record

logging.setLogRecordFactory(record_factory)
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
'myst_parser',
'jupyter_sphinx',
'nbsphinx',
'sphinx_autodoc_typehints',
]

# Add any paths that contain templates here, relative to this directory.
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ Sphinx>=4.5.0
sphinx-autodoc-typehints>=1.12.0
sphinx-rtd-theme>=1.0.0
sphinx-togglebutton>=0.2.3
sphinx-autodoc-typehints>=2.3.0
sphinxcontrib-applehelp>=1.0.2
sphinxcontrib-devhelp>=1.0.2
sphinxcontrib-htmlhelp>=2.0.0
Expand Down
51 changes: 51 additions & 0 deletions tests/runtime/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from __future__ import annotations

import logging
import pickle
from io import StringIO
from typing import Any

import pytest

Expand Down Expand Up @@ -143,6 +145,55 @@ def test_using_external_logging(server_compiler: Compiler) -> None:
logger.setLevel(logging.WARNING)


class ExternalWithArgsPass(BasePass):
async def run(self, circuit: Circuit, data: PassData) -> None:
logging.getLogger('dummy2').debug('int %d', 1)


def test_external_logging_with_args(server_compiler: Compiler) -> None:
logger = logging.getLogger('dummy2')
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(StringIO())
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)
server_compiler.compile(Circuit(1), [ExternalWithArgsPass()])
log = handler.stream.getvalue()
assert 'int 1' in log
logger.removeHandler(handler)
logger.setLevel(logging.WARNING)


class NonSerializable:
def __reduce__(self) -> str | tuple[Any, ...]:
raise pickle.PicklingError('This class is not serializable')

def __str__(self) -> str:
return 'NonSerializable'


class ExternalWithNonSerializableArgsPass(BasePass):
async def run(self, circuit: Circuit, data: PassData) -> None:
logging.getLogger('dummy2').debug(
'NonSerializable %s',
NonSerializable(),
)


def test_external_logging_with_nonserializable_args(
server_compiler: Compiler,
) -> None:
logger = logging.getLogger('dummy2')
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(StringIO())
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)
server_compiler.compile(Circuit(1), [ExternalWithNonSerializableArgsPass()])
log = handler.stream.getvalue()
assert 'NonSerializable NonSerializable' in log
logger.removeHandler(handler)
logger.setLevel(logging.WARNING)


@pytest.mark.parametrize('level', [-1, 0, 1, 2, 3, 4])
def test_limiting_nested_calls_enable_logging(
server_compiler: Compiler,
Expand Down

0 comments on commit e05f947

Please sign in to comment.