From c1f00575ca2cb4bb05875e0b300f9108494e2878 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Thu, 28 Mar 2024 08:12:40 -0400 Subject: [PATCH 01/44] Split worker into two threads --- bqskit/__init__.py | 1 + bqskit/runtime/__init__.py | 2 +- bqskit/runtime/base.py | 20 ++++- bqskit/runtime/manager.py | 8 +- bqskit/runtime/task.py | 10 ++- bqskit/runtime/worker.py | 155 +++++++++++++++++++++---------------- 6 files changed, 122 insertions(+), 74 deletions(-) diff --git a/bqskit/__init__.py b/bqskit/__init__.py index dd6a87128..da01938ba 100644 --- a/bqskit/__init__.py +++ b/bqskit/__init__.py @@ -4,6 +4,7 @@ import logging from sys import stdout as _stdout +import bqskit.runtime from .version import __version__ # noqa: F401 from .version import __version_info__ # noqa: F401 from bqskit.compiler.compile import compile diff --git a/bqskit/runtime/__init__.py b/bqskit/runtime/__init__.py index 31764cb46..3477c3572 100644 --- a/bqskit/runtime/__init__.py +++ b/bqskit/runtime/__init__.py @@ -111,7 +111,7 @@ os.environ['NUMEXPR_NUM_THREADS'] = '1' os.environ['VECLIB_MAXIMUM_THREADS'] = '1' os.environ['RUST_BACKTRACE'] = '1' - +print("SETTING THREADS TO 1") if TYPE_CHECKING: from bqskit.runtime.future import RuntimeFuture diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 17cdf2747..f46b95187 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -71,6 +71,7 @@ def has_idle_resources(self) -> bool: def send_outgoing(node: ServerBase) -> None: """Outgoing thread forwards messages as they are created.""" while True: + node.logger.debug('Waiting to send outgoing message...') outgoing = node.outgoing.get() if not node.running: @@ -80,9 +81,16 @@ def send_outgoing(node: ServerBase) -> None: # while condition. 
break + node.logger.debug(f'Sending message {outgoing[1].name}...') outgoing[0].send((outgoing[1], outgoing[2])) node.logger.debug(f'Sent message {outgoing[1].name}.') - node.logger.log(1, f'{outgoing[2]}\n') + + if outgoing[1] == RuntimeMessage.SUBMIT_BATCH: + node.logger.log(1, f'{len(outgoing[2])}\n') + else: + node.logger.log(1, f'{outgoing[2]}\n') + + node.outgoing.task_done() def sigint_handler(signum: int, _: FrameType | None, node: ServerBase) -> None: @@ -347,6 +355,7 @@ def run(self) -> None: try: while self.running: # Wait for messages + self.logger.debug('Waiting for messages...') events = self.sel.select() # Say that 5 times fast for key, _ in events: @@ -368,7 +377,10 @@ def run(self) -> None: continue log = f'Received message {msg.name} from {direction.name}.' self.logger.debug(log) - self.logger.log(1, f'{payload}\n') + if msg == RuntimeMessage.SUBMIT_BATCH: + self.logger.log(1, f'{len(payload)}\n') + else: + self.logger.log(1, f'{payload}\n') # Handle message self.handle_message(msg, direction, conn, payload) @@ -513,9 +525,10 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: """Schedule tasks between this node's employees.""" if len(tasks) == 0: return - + self.logger.info(f'Scheduling {len(tasks)} tasks with {self.num_idle_workers} idle workers.') assignments = self.assign_tasks(tasks) + # for e, assignment in sorted(zip(self.employees, assignments), key=lambda x: x[0].num_idle_workers, reverse=True): for e, assignment in zip(self.employees, assignments): num_tasks = len(assignment) @@ -528,6 +541,7 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: e.num_idle_workers -= min(num_tasks, e.num_idle_workers) self.num_idle_workers = sum(e.num_idle_workers for e in self.employees) + self.logger.info(f'Finished scheduling {len(tasks)} tasks with now {self.num_idle_workers} idle workers.') def send_result_down(self, result: RuntimeResult) -> None: """Send the `result` to the appropriate employee.""" diff --git 
a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 7779f47ca..476f73c32 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -95,14 +95,16 @@ def __init__( MessageDirection.ABOVE, ) - # Case 1: spawn and manage workers + # Case 1: spawn and/or manage workers if ipports is None: if only_connect: self.connect_to_workers(num_workers, worker_port) else: + print('Spawning workers...') + print(f'Number of workers: {num_workers}') self.spawn_workers(num_workers, worker_port) - # Case 2: Connect to managers at ipports + # Case 2: Connect to detached managers at ipports else: self.connect_to_managers(ipports) @@ -122,6 +124,7 @@ def handle_message( payload: Any, ) -> None: """Process the message coming from `direction`.""" + self.logger.debug(f'Manager handling message {msg.name} from {direction.name}.') if direction == MessageDirection.ABOVE: if msg == RuntimeMessage.SUBMIT: @@ -133,6 +136,7 @@ def handle_message( rtasks = cast(List[RuntimeTask], payload) self.schedule_tasks(rtasks) self.update_upstream_idle_workers() + self.logger.debug(f'Finished handling submit batch from above.') elif msg == RuntimeMessage.RESULT: result = cast(RuntimeResult, payload) diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index b55037a87..7983633ee 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -28,7 +28,7 @@ class RuntimeTask: def __init__( self, - fnargs: tuple[Any, Any, Any], + fnargs: tuple[Any, Any, Any], # TODO: Look into retyping this return_address: RuntimeAddress, comp_task_id: int, breadcrumbs: tuple[RuntimeAddress, ...], @@ -110,3 +110,11 @@ async def run(self) -> Any: def is_descendant_of(self, addr: RuntimeAddress) -> bool: """Return true if `addr` identifies a parent (or this) task.""" return addr == self.return_address or addr in self.breadcrumbs + + def __str__(self) -> str: + """Return a string representation of the task.""" + return f'{self.fnargs[0].__name__}' + + def __repr__(self) -> str: + """Return a 
string representation of the task.""" + return f'' diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 5c4b3ca18..1c92c5f27 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -14,6 +14,9 @@ from multiprocessing.connection import Client from multiprocessing.connection import Connection from multiprocessing.connection import wait +from threading import Thread +from queue import Queue +from queue import Empty from typing import Any from typing import Callable from typing import cast @@ -122,15 +125,50 @@ def deposit_result(self, result: RuntimeResult) -> None: self.result[slot_id] = result.result +def handle_incoming_comms(worker: Worker) -> None: + """Handle all incoming messages.""" + while True: + # Handle incomming communication + msg, payload = worker._conn.recv() + + # Process message + if msg == RuntimeMessage.SHUTDOWN: + worker._running = False + return + + elif msg == RuntimeMessage.SUBMIT: + task = cast(RuntimeTask, payload) + worker._add_task(task) + + elif msg == RuntimeMessage.SUBMIT_BATCH: + tasks = cast(List[RuntimeTask], payload) + worker._add_task(tasks.pop()) # Submit one task + worker._delayed_tasks.extend(tasks) # Delay rest + # Delayed tasks have no context and are stored (more-or-less) + # as a function pointer together with the arguments. + # When it gets started, it consumes much more memory, + # so we delay the task start until necessary (at no cost) + + elif msg == RuntimeMessage.RESULT: + result = cast(RuntimeResult, payload) + worker._handle_result(result) + + elif msg == RuntimeMessage.CANCEL: + addr = cast(RuntimeAddress, payload) + worker._handle_cancel(addr) + + class Worker: """ BQSKit Runtime's Worker. - BQSKit Runtime utilizes a single-threaded worker to accept, execute, + BQSKit Runtime utilizes a dual-threaded worker to accept, execute, pause, spawn, resume, and complete tasks in a custom event loop built with python's async await mechanisms. 
Each worker receives and sends tasks and results to the greater system through a single duplex - connection with a runtime server or manager. + connection with a runtime server or manager. One thread performs + work and sends outgoing messages, while the other thread handles + incoming messages. At start-up, the worker receives an ID and waits for its first task. An executing task may use the `submit` and `map` methods to spawn child @@ -178,8 +216,9 @@ def __init__(self, id: int, conn: Connection) -> None: self._id = id self._conn = conn - self._outgoing: list[tuple[RuntimeMessage, Any]] = [] - """Stores outgoing messages to be handled by the event loop.""" + # self._outgoing: list[tuple[RuntimeMessage, Any]] = [] + # self._outgoing: Queue[tuple[RuntimeMessage, Any]] = Queue() + # """Stores outgoing messages to be handled by the event loop.""" self._tasks: dict[RuntimeAddress, RuntimeTask] = {} """Tracks all started, unfinished tasks on this worker.""" @@ -187,7 +226,8 @@ def __init__(self, id: int, conn: Connection) -> None: self._delayed_tasks: list[RuntimeTask] = [] """Store all delayed tasks in LIFO order.""" - self._ready_task_ids: WorkerQueue = WorkerQueue() + # self._ready_task_ids: WorkerQueue = WorkerQueue() + self._ready_task_ids: Queue[RuntimeAddress] = Queue() """Tasks queued up for execution.""" self._cancelled_task_ids: set[RuntimeAddress] = set() @@ -208,7 +248,7 @@ def __init__(self, id: int, conn: Connection) -> None: self._cache: dict[str, Any] = {} """Local worker cache.""" - # Send out every emitted log message upstream + # Send out every client emitted log message upstream old_factory = logging.getLogRecordFactory() def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: @@ -218,11 +258,19 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: lvl = active_task.logging_level if lvl is None or lvl <= record.levelno: tid = active_task.comp_task_id - self._outgoing.append((RuntimeMessage.LOG, (tid, record))) + 
self._conn.send((RuntimeMessage.LOG, (tid, record))) return record logging.setLogRecordFactory(record_factory) + # Start incoming thread + self.incomming_thread = Thread( + target=handle_incoming_comms, + args=(self,), + ) + self.incomming_thread.start() + # self.logger.info('Started incoming thread.') + # Communicate that this worker is ready self._conn.send((RuntimeMessage.STARTED, self._id)) @@ -231,8 +279,8 @@ def _loop(self) -> None: self._running = True while self._running: self._try_step_next_ready_task() - self._try_idle() - self._handle_comms() + # self._try_idle() + # self._handle_comms() def _try_idle(self) -> None: """If there is nothing to do, wait until we receive a message.""" @@ -244,44 +292,12 @@ def _try_idle(self) -> None: self._conn.send((RuntimeMessage.WAITING, 1)) wait([self._conn]) - def _handle_comms(self) -> None: - """Handle all incoming and outgoing messages.""" - - # Handle outgoing communication + def _flush_outgoing_comms(self) -> None: + """Handle all outgoing messages.""" for out_msg in self._outgoing: self._conn.send(out_msg) self._outgoing.clear() - # Handle incomming communication - while self._conn.poll(): - msg, payload = self._conn.recv() - - # Process message - if msg == RuntimeMessage.SHUTDOWN: - self._running = False - return - - elif msg == RuntimeMessage.SUBMIT: - task = cast(RuntimeTask, payload) - self._add_task(task) - - elif msg == RuntimeMessage.SUBMIT_BATCH: - tasks = cast(List[RuntimeTask], payload) - self._add_task(tasks.pop()) # Submit one task - self._delayed_tasks.extend(tasks) # Delay rest - # Delayed tasks have no context and are stored (more-or-less) - # as a function pointer together with the arguments. 
- # When it gets started, it consumes much more memory, - # so we delay the task start until necessary (at no cost) - - elif msg == RuntimeMessage.RESULT: - result = cast(RuntimeResult, payload) - self._handle_result(result) - - elif msg == RuntimeMessage.CANCEL: - addr = cast(RuntimeAddress, payload) - self._handle_cancel(addr) - def _add_task(self, task: RuntimeTask) -> None: """Start a task and add it to the loop.""" self._tasks[task.return_address] = task @@ -290,8 +306,9 @@ def _add_task(self, task: RuntimeTask) -> None: def _handle_result(self, result: RuntimeResult) -> None: """Insert result into appropriate mailbox and wake waiting task.""" - mailbox_id = result.return_address.mailbox_index assert result.return_address.worker_id == self._id + + mailbox_id = result.return_address.mailbox_index if mailbox_id not in self._mailboxes: # If the mailbox has been dropped due to a cancel, ignore result return @@ -338,16 +355,23 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None: if not t.is_descendant_of(addr) ] - def _get_next_ready_task(self) -> RuntimeTask | None: - """Return the next ready task if one exists, otherwise None.""" + def _get_next_ready_task(self) -> RuntimeTask: + """Return the next ready task if one exists, otherwise block.""" while True: - if self._ready_task_ids.empty(): - if len(self._delayed_tasks) > 0: - self._add_task(self._delayed_tasks.pop()) - continue - return None + if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0: + self._add_task(self._delayed_tasks.pop()) + continue - addr = self._ready_task_ids.get() + try: + addr = self._ready_task_ids.get_nowait() + except Empty: + # TODO: evaluate race condition here: + # If the incoming comms thread adds a task to the ready queue + # after this check, then the worker will have incorrectly + # sent a waiting message to the manager. + # TODO: consider some lock mechanism to prevent this? 
+ self._conn.send((RuntimeMessage.WAITING, 1)) + addr = self._ready_task_ids.get() if addr in self._cancelled_task_ids or addr not in self._tasks: # When a task is cancelled on the worker it is not removed @@ -362,6 +386,7 @@ def _get_next_ready_task(self) -> RuntimeTask | None: # then discard this one too. Each breadcrumb (bcb) is a # task address (unique system-wide task id) of an ancestor # task. + # TODO: do I need to manually remove addr from self._tasks? continue return task @@ -370,10 +395,6 @@ def _try_step_next_ready_task(self) -> None: """Select a task to run, and advance it one step.""" task = self._get_next_ready_task() - if task is None: - # Nothing to do - return - try: self._active_task = task @@ -392,7 +413,7 @@ def _try_step_next_ready_task(self) -> None: exc_info = sys.exc_info() error_str = ''.join(traceback.format_exception(*exc_info)) error_payload = (self._active_task.comp_task_id, error_str) - self._outgoing.append((RuntimeMessage.ERROR, error_payload)) + self._conn.send((RuntimeMessage.ERROR, error_payload)) finally: self._active_task = None @@ -428,11 +449,11 @@ def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: if task.return_address.worker_id == self._id: self._handle_result(packaged_result) - self._outgoing.append((RuntimeMessage.UPDATE, -1)) + self._conn.send((RuntimeMessage.UPDATE, -1)) # Let manager know this worker has one less task # without sending a result else: - self._outgoing.append((RuntimeMessage.RESULT, packaged_result)) + self._conn.send((RuntimeMessage.RESULT, packaged_result)) # Remove task self._tasks.pop(task.return_address) @@ -448,10 +469,6 @@ def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: # Otherwise send a cancel message self.cancel(RuntimeFuture(mailbox_id)) - # Start delayed task - if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0: - self._add_task(self._delayed_tasks.pop()) - def _get_desired_result(self, task: RuntimeTask) -> Any: """Retrieve 
the task's desired result from the mailboxes.""" if task.desired_box_id is None: @@ -501,7 +518,7 @@ def submit( ) # Submit the task (on the next cycle) - self._outgoing.append((RuntimeMessage.SUBMIT, task)) + self._conn.send((RuntimeMessage.SUBMIT, task)) # Return future pointing to the mailbox return RuntimeFuture(mailbox_id) @@ -548,7 +565,7 @@ def map( ] # Submit the tasks - self._outgoing.append((RuntimeMessage.SUBMIT_BATCH, tasks)) + self._conn.send((RuntimeMessage.SUBMIT_BATCH, tasks)) # Return future pointing to the mailbox return RuntimeFuture(mailbox_id) @@ -563,8 +580,8 @@ def cancel(self, future: RuntimeFuture) -> None: RuntimeAddress(self._id, future.mailbox_id, slot_id) for slot_id in range(num_slots) ] - msgs = [(RuntimeMessage.CANCEL, addr) for addr in addrs] - self._outgoing.extend(msgs) + for addr in addrs: + self._conn.send((RuntimeMessage.CANCEL, addr)) def get_cache(self) -> dict[str, Any]: """ @@ -618,11 +635,13 @@ def start_worker(w_id: int | None, port: int, cpu: int | None = None) -> None: logger.handlers.clear() logging.Logger.manager.loggerDict = {} + # Pin worker to cpu if cpu is not None: if sys.platform == 'win32': raise RuntimeError('Cannot pin worker to cpu on windows.') os.sched_setaffinity(0, [cpu]) + # Connect to manager max_retries = 7 wait_time = .1 conn: Connection | None = None @@ -639,10 +658,12 @@ def start_worker(w_id: int | None, port: int, cpu: int | None = None) -> None: if conn is None: raise RuntimeError('Unable to establish connection with manager.') + # If id isn't provided, wait for assignment if w_id is None: msg, w_id = conn.recv() assert msg == RuntimeMessage.STARTED + # Build and start worker global _worker _worker = Worker(w_id, conn) _worker._loop() From 4156b8cb757bb8e6463ba4239780efb9c5638061 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Thu, 28 Mar 2024 17:47:00 -0400 Subject: [PATCH 02/44] Removed race condition in idle tracking --- bqskit/runtime/base.py | 63 ++++++++++++++++++++++++++------------ 
bqskit/runtime/detached.py | 4 +-- bqskit/runtime/manager.py | 25 ++++++++------- bqskit/runtime/task.py | 10 +++++- bqskit/runtime/worker.py | 26 ++++++++++++---- 5 files changed, 89 insertions(+), 39 deletions(-) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index f46b95187..3e6bd47b2 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -41,15 +41,21 @@ def __init__( conn: Connection, total_workers: int, process: Process | None = None, - num_tasks: int = 0, ) -> None: """Construct an employee with all resources idle.""" self.conn: Connection = conn self.total_workers = total_workers self.process = process - self.num_tasks = num_tasks + self.num_tasks = 0 self.num_idle_workers = total_workers + self.submit_cache: list[tuple[RuntimeAddress, int]] = [] + """ + Tracks recently submitted tasks by id and count. + This is used to adjust the idle worker count when + the employee sends a waiting message. + """ + def shutdown(self) -> None: """Shutdown the employee.""" try: @@ -67,11 +73,25 @@ def shutdown(self) -> None: def has_idle_resources(self) -> bool: return self.num_idle_workers > 0 + def get_num_of_tasks_sent_since( + self, + read_receipt: RuntimeAddress | None, + ) -> int: + """Return the number of tasks sent since the read receipt.""" + if read_receipt is None: + return sum(count for _, count in self.submit_cache) + + for i, (addr, _) in enumerate(self.submit_cache): + if addr == read_receipt: + self.submit_cache = self.submit_cache[:i] + return sum(count for _, count in self.submit_cache[1:]) + + raise RuntimeError('Read receipt not found in submit cache.') + def send_outgoing(node: ServerBase) -> None: """Outgoing thread forwards messages as they are created.""" while True: - node.logger.debug('Waiting to send outgoing message...') outgoing = node.outgoing.get() if not node.running: @@ -81,7 +101,6 @@ def send_outgoing(node: ServerBase) -> None: # while condition. 
break - node.logger.debug(f'Sending message {outgoing[1].name}...') outgoing[0].send((outgoing[1], outgoing[2])) node.logger.debug(f'Sent message {outgoing[1].name}.') @@ -525,7 +544,6 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: """Schedule tasks between this node's employees.""" if len(tasks) == 0: return - self.logger.info(f'Scheduling {len(tasks)} tasks with {self.num_idle_workers} idle workers.') assignments = self.assign_tasks(tasks) # for e, assignment in sorted(zip(self.employees, assignments), key=lambda x: x[0].num_idle_workers, reverse=True): @@ -539,9 +557,9 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: e.num_tasks += num_tasks e.num_idle_workers -= min(num_tasks, e.num_idle_workers) + e.submit_cache.append((assignment[0].unique_id, num_tasks)) self.num_idle_workers = sum(e.num_idle_workers for e in self.employees) - self.logger.info(f'Finished scheduling {len(tasks)} tasks with now {self.num_idle_workers} idle workers.') def send_result_down(self, result: RuntimeResult) -> None: """Send the `result` to the appropriate employee.""" @@ -568,23 +586,30 @@ def broadcast_cancel(self, addr: RuntimeAddress) -> None: for employee in self.employees: self.outgoing.put((employee.conn, RuntimeMessage.CANCEL, addr)) - def handle_waiting(self, conn: Connection, new_idle_count: int) -> None: + def handle_waiting( + self, + conn: Connection, + new_idle_count: int, + read_receipt: RuntimeAddress | None, + ) -> None: """ Record that an employee is idle with nothing to do. - There is a race condition here that is allowed. If an employee - sends a waiting message at the same time that this sends it a - task, it will still be marked waiting even though it is running - a task. We allow this for two reasons. First, the consequences are - minimal: this situation can only lead to one extra task assigned - to the worker that could otherwise go to a truly idle worker. 
- Second, it is unlikely in the common BQSKit workflows, which have - wide and shallow task graphs and each leaf task can require seconds - of runtime. + There is a race condition that is corrected here. If an employee + sends a waiting message at the same time that its boss sends it a + task, the boss's idle count will eventually be incorrect. To fix + this, every waiting message sent by an employee is accompanied by + a read receipt of the latest batch of tasks it has processed. The + boss can then adjust the idle count by the number of tasks sent + since the read receipt. """ - old_count = self.conn_to_employee_dict[conn].num_idle_workers - self.conn_to_employee_dict[conn].num_idle_workers = new_idle_count - self.num_idle_workers += (new_idle_count - old_count) + employee = self.conn_to_employee_dict[conn] + unaccounted_task = employee.get_num_of_tasks_sent_since(read_receipt) + adjusted_idle_count = max(new_idle_count - unaccounted_task, 0) + + old_count = employee.num_idle_workers + employee.num_idle_workers = adjusted_idle_count + self.num_idle_workers += (adjusted_idle_count - old_count) assert 0 <= self.num_idle_workers <= self.total_workers diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 3b817ef1a..878fb4847 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -182,8 +182,8 @@ def handle_message( self.handle_shutdown() elif msg == RuntimeMessage.WAITING: - num_idle = cast(int, payload) - self.handle_waiting(conn, num_idle) + num_idle, read_receipt = cast(int, payload) + self.handle_waiting(conn, num_idle, read_receipt) elif msg == RuntimeMessage.UPDATE: task_diff = cast(int, payload) diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 476f73c32..0e40a052f 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -100,8 +100,6 @@ def __init__( if only_connect: self.connect_to_workers(num_workers, worker_port) else: - print('Spawning workers...') - print(f'Number of 
workers: {num_workers}') self.spawn_workers(num_workers, worker_port) # Case 2: Connect to detached managers at ipports @@ -111,6 +109,9 @@ def __init__( # Track info on sent messages to reduce redundant messages: self.last_num_idle_sent_up = self.total_workers + # Track info on received messages to report read receipts: + self.most_recent_read_submit: RuntimeAddress | None = None + # Inform upstream we are starting msg = (self.upstream, RuntimeMessage.STARTED, self.total_workers) self.outgoing.put(msg) @@ -124,19 +125,19 @@ def handle_message( payload: Any, ) -> None: """Process the message coming from `direction`.""" - self.logger.debug(f'Manager handling message {msg.name} from {direction.name}.') if direction == MessageDirection.ABOVE: if msg == RuntimeMessage.SUBMIT: rtask = cast(RuntimeTask, payload) + self.most_recent_read_submit = rtask.unique_id self.schedule_tasks([rtask]) - self.update_upstream_idle_workers() + # self.update_upstream_idle_workers() elif msg == RuntimeMessage.SUBMIT_BATCH: rtasks = cast(List[RuntimeTask], payload) + self.most_recent_read_submit = rtasks[0].unique_id self.schedule_tasks(rtasks) - self.update_upstream_idle_workers() - self.logger.debug(f'Finished handling submit batch from above.') + # self.update_upstream_idle_workers() elif msg == RuntimeMessage.RESULT: result = cast(RuntimeResult, payload) @@ -157,20 +158,20 @@ def handle_message( if msg == RuntimeMessage.SUBMIT: rtask = cast(RuntimeTask, payload) self.send_up_or_schedule_tasks([rtask]) - self.update_upstream_idle_workers() + # self.update_upstream_idle_workers() elif msg == RuntimeMessage.SUBMIT_BATCH: rtasks = cast(List[RuntimeTask], payload) self.send_up_or_schedule_tasks(rtasks) - self.update_upstream_idle_workers() + # self.update_upstream_idle_workers() elif msg == RuntimeMessage.RESULT: result = cast(RuntimeResult, payload) self.handle_result_from_below(result) elif msg == RuntimeMessage.WAITING: - num_idle = cast(int, payload) - self.handle_waiting(conn, 
num_idle) + num_idle, read_receipt = cast(int, payload) + self.handle_waiting(conn, num_idle, read_receipt) self.update_upstream_idle_workers() elif msg == RuntimeMessage.UPDATE: @@ -221,6 +222,7 @@ def send_up_or_schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: if num_idle != 0: self.outgoing.put((self.upstream, RuntimeMessage.UPDATE, num_idle)) self.schedule_tasks(tasks[:num_idle]) + self.update_upstream_idle_workers() if len(tasks) > num_idle: self.outgoing.put(( @@ -248,7 +250,8 @@ def update_upstream_idle_workers(self) -> None: """Update the total number of idle workers upstream.""" if self.num_idle_workers != self.last_num_idle_sent_up: self.last_num_idle_sent_up = self.num_idle_workers - m = (self.upstream, RuntimeMessage.WAITING, self.num_idle_workers) + payload = (self.num_idle_workers, self.most_recent_read_submit) + m = (self.upstream, RuntimeMessage.WAITING, payload) self.outgoing.put(m) def handle_update(self, conn: Connection, task_diff: int) -> None: diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index 7983633ee..676f5338b 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -43,7 +43,10 @@ def __init__( """Tuple of function pointer, arguments, and keyword arguments.""" self.return_address = return_address - """Where the result of this task should be sent.""" + """ + Where the result of this task should be sent. + This doubles as a unique system-wide id for the task. 
+ """ self.logging_level = logging_level """Logs with levels >= to this get emitted, if None always emit.""" @@ -97,6 +100,11 @@ def step(self, send_val: Any = None) -> Any: return to_return + @property + def unique_id(self) -> RuntimeAddress: + """Return the task's system-wide unique id.""" + return self.return_address + def start(self) -> None: """Initialize the task.""" self.coro = self.run() diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 1c92c5f27..15f02f267 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -15,6 +15,7 @@ from multiprocessing.connection import Connection from multiprocessing.connection import wait from threading import Thread +from threading import Lock from queue import Queue from queue import Empty from typing import Any @@ -137,17 +138,23 @@ def handle_incoming_comms(worker: Worker) -> None: return elif msg == RuntimeMessage.SUBMIT: + worker.read_receipt_mutex.acquire() task = cast(RuntimeTask, payload) + worker._most_recent_read_submit = task.unique_id worker._add_task(task) + worker.read_receipt_mutex.release() elif msg == RuntimeMessage.SUBMIT_BATCH: + worker.read_receipt_mutex.acquire() tasks = cast(List[RuntimeTask], payload) + worker._most_recent_read_submit = tasks[0].unique_id worker._add_task(tasks.pop()) # Submit one task worker._delayed_tasks.extend(tasks) # Delay rest # Delayed tasks have no context and are stored (more-or-less) # as a function pointer together with the arguments. 
# When it gets started, it consumes much more memory, # so we delay the task start until necessary (at no cost) + worker.read_receipt_mutex.release() elif msg == RuntimeMessage.RESULT: result = cast(RuntimeResult, payload) @@ -248,6 +255,12 @@ def __init__(self, id: int, conn: Connection) -> None: self._cache: dict[str, Any] = {} """Local worker cache.""" + self.most_recent_read_submit: RuntimeAddress | None = None + """Tracks the most recently processed submit message from above.""" + + self.read_receipt_mutex = Lock() + """A lock to ensure waiting messages's read receipt is correct.""" + # Send out every client emitted log message upstream old_factory = logging.getLogRecordFactory() @@ -362,17 +375,18 @@ def _get_next_ready_task(self) -> RuntimeTask: self._add_task(self._delayed_tasks.pop()) continue + self.read_receipt_mutex.acquire() try: addr = self._ready_task_ids.get_nowait() except Empty: - # TODO: evaluate race condition here: - # If the incoming comms thread adds a task to the ready queue - # after this check, then the worker will have incorrectly - # sent a waiting message to the manager. - # TODO: consider some lock mechanism to prevent this? 
- self._conn.send((RuntimeMessage.WAITING, 1)) + payload = (1, self.most_recent_read_submit) + self._conn.send((RuntimeMessage.WAITING, payload)) + self.read_receipt_mutex.release() addr = self._ready_task_ids.get() + if self.read_receipt_mutex.locked(): + self.read_receipt_mutex.release() + if addr in self._cancelled_task_ids or addr not in self._tasks: # When a task is cancelled on the worker it is not removed # from the ready queue because it is much cheaper to just From 7c2edc7dc3d1433a15bef904633eb3d40123688f Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Thu, 28 Mar 2024 18:03:24 -0400 Subject: [PATCH 03/44] pre-commit-ish --- bqskit/runtime/__init__.py | 2 +- bqskit/runtime/base.py | 18 ++++++++--------- bqskit/runtime/detached.py | 4 +++- bqskit/runtime/manager.py | 4 +++- bqskit/runtime/task.py | 1 + bqskit/runtime/worker.py | 41 +++++++++++++++++++------------------- 6 files changed, 37 insertions(+), 33 deletions(-) diff --git a/bqskit/runtime/__init__.py b/bqskit/runtime/__init__.py index 3477c3572..0443671b6 100644 --- a/bqskit/runtime/__init__.py +++ b/bqskit/runtime/__init__.py @@ -111,7 +111,7 @@ os.environ['NUMEXPR_NUM_THREADS'] = '1' os.environ['VECLIB_MAXIMUM_THREADS'] = '1' os.environ['RUST_BACKTRACE'] = '1' -print("SETTING THREADS TO 1") +print('SETTING THREADS TO 1') if TYPE_CHECKING: from bqskit.runtime.future import RuntimeFuture diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 3e6bd47b2..52b50deca 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -52,8 +52,9 @@ def __init__( self.submit_cache: list[tuple[RuntimeAddress, int]] = [] """ Tracks recently submitted tasks by id and count. - This is used to adjust the idle worker count when - the employee sends a waiting message. + + This is used to adjust the idle worker count when the employee sends a + waiting message. """ def shutdown(self) -> None: @@ -595,13 +596,12 @@ def handle_waiting( """ Record that an employee is idle with nothing to do. 
- There is a race condition that is corrected here. If an employee - sends a waiting message at the same time that its boss sends it a - task, the boss's idle count will eventually be incorrect. To fix - this, every waiting message sent by an employee is accompanied by - a read receipt of the latest batch of tasks it has processed. The - boss can then adjust the idle count by the number of tasks sent - since the read receipt. + There is a race condition that is corrected here. If an employee sends a + waiting message at the same time that its boss sends it a task, the + boss's idle count will eventually be incorrect. To fix this, every + waiting message sent by an employee is accompanied by a read receipt of + the latest batch of tasks it has processed. The boss can then adjust the + idle count by the number of tasks sent since the read receipt. """ employee = self.conn_to_employee_dict[conn] unaccounted_task = employee.get_num_of_tasks_sent_since(read_receipt) diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 878fb4847..2bc015cc1 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -15,6 +15,7 @@ from typing import Any from typing import cast from typing import List +from typing import Optional from typing import Sequence from bqskit.compiler.status import CompilationStatus @@ -182,7 +183,8 @@ def handle_message( self.handle_shutdown() elif msg == RuntimeMessage.WAITING: - num_idle, read_receipt = cast(int, payload) + p = cast(tuple[int, Optional[RuntimeAddress]], payload) + num_idle, read_receipt = p self.handle_waiting(conn, num_idle, read_receipt) elif msg == RuntimeMessage.UPDATE: diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 0e40a052f..4e0f13c76 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -9,6 +9,7 @@ from typing import Any from typing import cast from typing import List +from typing import Optional from typing import Sequence from bqskit.runtime import 
default_manager_port @@ -170,7 +171,8 @@ def handle_message( self.handle_result_from_below(result) elif msg == RuntimeMessage.WAITING: - num_idle, read_receipt = cast(int, payload) + p = cast(tuple[int, Optional[RuntimeAddress]], payload) + num_idle, read_receipt = p self.handle_waiting(conn, num_idle, read_receipt) self.update_upstream_idle_workers() diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index 676f5338b..c9f582804 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -45,6 +45,7 @@ def __init__( self.return_address = return_address """ Where the result of this task should be sent. + This doubles as a unique system-wide id for the task. """ diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 15f02f267..af7d86b9c 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -13,11 +13,10 @@ from multiprocessing import Process from multiprocessing.connection import Client from multiprocessing.connection import Connection -from multiprocessing.connection import wait -from threading import Thread -from threading import Lock -from queue import Queue from queue import Empty +from queue import Queue +from threading import Lock +from threading import Thread from typing import Any from typing import Callable from typing import cast @@ -140,14 +139,14 @@ def handle_incoming_comms(worker: Worker) -> None: elif msg == RuntimeMessage.SUBMIT: worker.read_receipt_mutex.acquire() task = cast(RuntimeTask, payload) - worker._most_recent_read_submit = task.unique_id + worker.most_recent_read_submit = task.unique_id worker._add_task(task) worker.read_receipt_mutex.release() elif msg == RuntimeMessage.SUBMIT_BATCH: worker.read_receipt_mutex.acquire() tasks = cast(List[RuntimeTask], payload) - worker._most_recent_read_submit = tasks[0].unique_id + worker.most_recent_read_submit = tasks[0].unique_id worker._add_task(tasks.pop()) # Submit one task worker._delayed_tasks.extend(tasks) # Delay rest # Delayed tasks have no 
context and are stored (more-or-less) @@ -295,21 +294,21 @@ def _loop(self) -> None: # self._try_idle() # self._handle_comms() - def _try_idle(self) -> None: - """If there is nothing to do, wait until we receive a message.""" - empty_outgoing = len(self._outgoing) == 0 - no_ready_tasks = self._ready_task_ids.empty() - no_delayed_tasks = len(self._delayed_tasks) == 0 - - if empty_outgoing and no_ready_tasks and no_delayed_tasks: - self._conn.send((RuntimeMessage.WAITING, 1)) - wait([self._conn]) - - def _flush_outgoing_comms(self) -> None: - """Handle all outgoing messages.""" - for out_msg in self._outgoing: - self._conn.send(out_msg) - self._outgoing.clear() + # def _try_idle(self) -> None: + # """If there is nothing to do, wait until we receive a message.""" + # empty_outgoing = len(self._outgoing) == 0 + # no_ready_tasks = self._ready_task_ids.empty() + # no_delayed_tasks = len(self._delayed_tasks) == 0 + + # if empty_outgoing and no_ready_tasks and no_delayed_tasks: + # self._conn.send((RuntimeMessage.WAITING, 1)) + # wait([self._conn]) + + # def _flush_outgoing_comms(self) -> None: + # """Handle all outgoing messages.""" + # for out_msg in self._outgoing: + # self._conn.send(out_msg) + # self._outgoing.clear() def _add_task(self, task: RuntimeTask) -> None: """Start a task and add it to the loop.""" From ead3663e5b2ac749e6b87d500e743903fd601690 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Sat, 30 Mar 2024 15:54:30 -0400 Subject: [PATCH 04/44] Runtime tests passing local --- bqskit/runtime/base.py | 3 +-- bqskit/runtime/detached.py | 1 + bqskit/runtime/future.py | 4 ++-- bqskit/runtime/worker.py | 31 +++++++++++++++++++++++++++---- tests/runtime/test_attached.py | 4 ++-- tests/runtime/test_next.py | 2 +- 6 files changed, 34 insertions(+), 11 deletions(-) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 52b50deca..0a1fd0a31 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -84,7 +84,7 @@ def get_num_of_tasks_sent_since( for 
i, (addr, _) in enumerate(self.submit_cache): if addr == read_receipt: - self.submit_cache = self.submit_cache[:i] + self.submit_cache = self.submit_cache[i:] return sum(count for _, count in self.submit_cache[1:]) raise RuntimeError('Read receipt not found in submit cache.') @@ -375,7 +375,6 @@ def run(self) -> None: try: while self.running: # Wait for messages - self.logger.debug('Waiting for messages...') events = self.sel.select() # Say that 5 times fast for key, _ in events: diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 2bc015cc1..d86e94cfa 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -369,6 +369,7 @@ def handle_error(self, error_payload: tuple[int, str]) -> None: tid = error_payload[0] conn = self.tasks[self.mailbox_to_task_dict[tid]][1] self.outgoing.put((conn, RuntimeMessage.ERROR, error_payload[1])) + # TODO: Broadcast cancel to all tasks with compilation task id tid def handle_log(self, log_payload: tuple[int, LogRecord]) -> None: """Forward logs to appropriate client.""" diff --git a/bqskit/runtime/future.py b/bqskit/runtime/future.py index 70f6ac2cc..ab69d5911 100644 --- a/bqskit/runtime/future.py +++ b/bqskit/runtime/future.py @@ -27,8 +27,8 @@ def __await__(self) -> Any: Informs the event loop which mailbox this is waiting on. 
""" - if self._next_flag: - return (yield self) + # if self._next_flag: + # return (yield self) return (yield self) diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index af7d86b9c..103f9eba6 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -134,6 +134,8 @@ def handle_incoming_comms(worker: Worker) -> None: # Process message if msg == RuntimeMessage.SHUTDOWN: worker._running = False + worker._ready_task_ids.put(RuntimeAddress(-1, -1, -1)) + # TODO: Interupt main, maybe even kill it return elif msg == RuntimeMessage.SUBMIT: @@ -162,6 +164,7 @@ def handle_incoming_comms(worker: Worker) -> None: elif msg == RuntimeMessage.CANCEL: addr = cast(RuntimeAddress, payload) worker._handle_cancel(addr) + # TODO: preempt? class Worker: @@ -280,6 +283,7 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: target=handle_incoming_comms, args=(self,), ) + self.incomming_thread.daemon = True self.incomming_thread.start() # self.logger.info('Started incoming thread.') @@ -334,6 +338,7 @@ def _handle_result(self, result: RuntimeResult) -> None: if task.wake_on_next or box.ready: self._ready_task_ids.put(box.dest_addr) # Wake it + box.dest_addr = None # Prevent double wake def _handle_cancel(self, addr: RuntimeAddress) -> None: """ @@ -349,6 +354,7 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None: for themselves using breadcrumbs and the original `addr` cancel message. """ + # TODO: Send update message? 
self._cancelled_task_ids.add(addr) # Remove all tasks that are children of `addr` from initialized tasks @@ -382,6 +388,8 @@ def _get_next_ready_task(self) -> RuntimeTask: self._conn.send((RuntimeMessage.WAITING, payload)) self.read_receipt_mutex.release() addr = self._ready_task_ids.get() + if addr == RuntimeAddress(-1, -1, -1): + return None if self.read_receipt_mutex.locked(): self.read_receipt_mutex.release() @@ -408,6 +416,9 @@ def _try_step_next_ready_task(self) -> None: """Select a task to run, and advance it one step.""" task = self._get_next_ready_task() + if task is None: + return + try: self._active_task = task @@ -451,8 +462,15 @@ def _process_await(self, task: RuntimeTask, future: RuntimeFuture) -> None: m = 'Cannot wait for next results on a complete task.' raise RuntimeError(m) task.wake_on_next = True - - elif box.ready: + # if future._next_flag: + # # Set from Worker.next, implies the task wants the next result + # # if box.ready: + # # m = 'Cannot wait for next results on a complete task.' 
+ # # raise RuntimeError(m) + # task.wake_on_next = True + task.wake_on_next = future._next_flag + + if box.ready: self._ready_task_ids.put(task.return_address) def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: @@ -460,6 +478,10 @@ def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: assert task is self._active_task packaged_result = RuntimeResult(task.return_address, result, self._id) + if task.return_address not in self._tasks: + print(f'Task was cancelled: {task.return_address}, {task.fnargs[0].__name__}') + return + if task.return_address.worker_id == self._id: self._handle_result(packaged_result) self._conn.send((RuntimeMessage.UPDATE, -1)) @@ -491,7 +513,7 @@ def _get_desired_result(self, task: RuntimeTask) -> Any: if task.wake_on_next: fresh_results = box.get_new_results() - assert len(fresh_results) > 0 + # assert len(fresh_results) > 0 return fresh_results assert box.ready @@ -621,7 +643,8 @@ async def next(self, future: RuntimeFuture) -> list[tuple[int, Any]]: returned. Each result is paired with the index of its arguments in the original map call. 
""" - if future._done: + # if future._done: + if future.mailbox_id not in self._mailboxes: raise RuntimeError('Cannot wait on an already completed result.') future._next_flag = True diff --git a/tests/runtime/test_attached.py b/tests/runtime/test_attached.py index 0c2ecb67e..eb708ff9f 100644 --- a/tests/runtime/test_attached.py +++ b/tests/runtime/test_attached.py @@ -60,7 +60,7 @@ def test_create_workers(num_workers: int) -> None: compiler.close() -def test_one_thread_per_worker() -> None: +def test_two_thread_per_worker() -> None: # On windows we aren't sure how the threads are handeled if sys.platform == 'win32': return @@ -68,7 +68,7 @@ def test_one_thread_per_worker() -> None: compiler = Compiler(num_workers=1) assert compiler.p is not None assert len(psutil.Process(compiler.p.pid).children()) in [1, 2] - assert psutil.Process(compiler.p.pid).children()[0].num_threads() == 1 + assert psutil.Process(compiler.p.pid).children()[0].num_threads() == 2 compiler.close() diff --git a/tests/runtime/test_next.py b/tests/runtime/test_next.py index 30642a7d7..81c45da5a 100644 --- a/tests/runtime/test_next.py +++ b/tests/runtime/test_next.py @@ -29,7 +29,7 @@ async def run(self, circuit: Circuit, data: PassData) -> None: class TestNoDuplicateResultsInTwoNexts(BasePass): async def run(self, circuit: Circuit, data: PassData) -> None: - future = get_runtime().map(sleepi, [0.3, 0.4, 0.1, 0.2]) + future = get_runtime().map(sleepi, [0.3, 0.4, 0.1, 0.2, 5]) seen = [0] int_ids = await get_runtime().next(future) From 9ea8880bf074b1b64f905a29e32b374050e76969 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Mon, 1 Apr 2024 09:33:01 -0400 Subject: [PATCH 05/44] Removed deprecated CompilationTask feature --- bqskit/compiler/compiler.py | 105 +++++++----------------------------- bqskit/ir/circuit.py | 5 +- 2 files changed, 21 insertions(+), 89 deletions(-) diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index 55bb50a1d..6a6d233e6 100644 --- 
a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -229,20 +229,18 @@ def __del__(self) -> None: def submit( self, - task_or_circuit: CompilationTask | Circuit, - workflow: WorkflowLike | None = None, + circuit: Circuit, + workflow: WorkflowLike, request_data: bool = False, logging_level: int | None = None, max_logging_depth: int = -1, + data: dict[str, Any] | None = None, ) -> uuid.UUID: """ Submit a compilation job to the Compiler. Args: - task_or_circuit (CompilationTask | Circuit): The task to compile, - or the input circuit. If a task is specified, no other - argument should be specified. If a task is not specified, - the circuit must be paired with a workflow argument. + circuit (Circuit): The input circuit to be compiled. workflow (WorkflowLike): The compilation job submitted is defined by executing this workflow on the input circuit. @@ -267,86 +265,35 @@ def submit( the result of the task. """ # Build CompilationTask - if isinstance(task_or_circuit, CompilationTask): - if workflow is not None: - raise ValueError( - 'Cannot specify workflow and task.' - ' Either specify a workflow and circuit or a task alone.', - ) - - task = task_or_circuit - - else: - if workflow is None: - m = 'Must specify workflow when providing a circuit to submit.' 
- raise TypeError(m) - - task = CompilationTask(task_or_circuit, Workflow(workflow)) + task = CompilationTask(circuit, Workflow(workflow)) # Set task configuration task.request_data = request_data task.logging_level = logging_level or self._discover_lowest_log_level() task.max_logging_depth = max_logging_depth + if data is not None: + task.data = data # Submit task to runtime self._send(RuntimeMessage.SUBMIT, task) return task.task_id - def status(self, task_id: CompilationTask | uuid.UUID) -> CompilationStatus: + def status(self, task_id: uuid.UUID) -> CompilationStatus: """Retrieve the status of the specified task.""" - if isinstance(task_id, CompilationTask): - warnings.warn( - 'Request a status from a CompilationTask is deprecated.\n' - ' Instead, pass a task ID to request a status.\n' - ' `compiler.submit` returns a task id, and you can get an\n' - ' ID from a task via `task.task_id`.\n' - ' This warning will turn into an error in a future update.', - DeprecationWarning, - ) - task_id = task_id.task_id - assert isinstance(task_id, uuid.UUID) - msg, payload = self._send_recv(RuntimeMessage.STATUS, task_id) if msg != RuntimeMessage.STATUS: raise RuntimeError(f'Unexpected message type: {msg}.') return payload - def result( - self, - task_id: CompilationTask | uuid.UUID, - ) -> Circuit | tuple[Circuit, PassData]: + def result(self, task_id: uuid.UUID) -> Circuit | tuple[Circuit, PassData]: """Block until the task is finished, return its result.""" - if isinstance(task_id, CompilationTask): - warnings.warn( - 'Request a result from a CompilationTask is deprecated.' 
- ' Instead, pass a task ID to request a result.\n' - ' `compiler.submit` returns a task id, and you can get an\n' - ' ID from a task via `task.task_id`.\n' - ' This warning will turn into an error in a future update.', - DeprecationWarning, - ) - task_id = task_id.task_id - assert isinstance(task_id, uuid.UUID) - msg, payload = self._send_recv(RuntimeMessage.REQUEST, task_id) if msg != RuntimeMessage.RESULT: raise RuntimeError(f'Unexpected message type: {msg}.') return payload - def cancel(self, task_id: CompilationTask | uuid.UUID) -> bool: + def cancel(self, task_id: uuid.UUID) -> bool: """Cancel the execution of a task in the system.""" - if isinstance(task_id, CompilationTask): - warnings.warn( - 'Cancelling a CompilationTask is deprecated. Instead,' - ' Instead, pass a task ID to cancel a task.\n' - ' `compiler.submit` returns a task id, and you can get an\n' - ' ID from a task via `task.task_id`.\n' - ' This warning will turn into an error in a future update.', - DeprecationWarning, - ) - task_id = task_id.task_id - assert isinstance(task_id, uuid.UUID) - msg, _ = self._send_recv(RuntimeMessage.CANCEL, task_id) if msg != RuntimeMessage.CANCEL: raise RuntimeError(f'Unexpected message type: {msg}.') @@ -355,63 +302,51 @@ def cancel(self, task_id: CompilationTask | uuid.UUID) -> bool: @overload def compile( self, - task_or_circuit: CompilationTask, - ) -> Circuit | tuple[Circuit, PassData]: - ... - - @overload - def compile( - self, - task_or_circuit: Circuit, + circuit: Circuit, workflow: WorkflowLike, request_data: Literal[False] = ..., logging_level: int | None = ..., max_logging_depth: int = ..., + data: dict[str, Any] | None = ..., ) -> Circuit: ... @overload def compile( self, - task_or_circuit: Circuit, + circuit: Circuit, workflow: WorkflowLike, request_data: Literal[True], logging_level: int | None = ..., max_logging_depth: int = ..., + data: dict[str, Any] | None = ..., ) -> tuple[Circuit, PassData]: ... 
@overload def compile( self, - task_or_circuit: Circuit, + circuit: Circuit, workflow: WorkflowLike, request_data: bool, logging_level: int | None = ..., max_logging_depth: int = ..., + data: dict[str, Any] | None = ..., ) -> Circuit | tuple[Circuit, PassData]: ... def compile( self, - task_or_circuit: CompilationTask | Circuit, - workflow: WorkflowLike | None = None, + circuit: Circuit, + workflow: WorkflowLike, request_data: bool = False, logging_level: int | None = None, max_logging_depth: int = -1, + data: dict[str, Any] | None = None, ) -> Circuit | tuple[Circuit, PassData]: """Submit a task, wait for its results; see :func:`submit` for more.""" - if isinstance(task_or_circuit, CompilationTask): - warnings.warn( - 'Manually constructing and compiling CompilationTasks' - ' is deprecated. Instead, call compile directly with' - ' your input circuit and workflow. This warning will' - ' turn into an error in a future update.', - DeprecationWarning, - ) - task_id = self.submit( - task_or_circuit, + circuit, workflow, request_data, logging_level, diff --git a/bqskit/ir/circuit.py b/bqskit/ir/circuit.py index d58b299a7..49e53cdf5 100644 --- a/bqskit/ir/circuit.py +++ b/bqskit/ir/circuit.py @@ -2718,16 +2718,13 @@ def perform( """ from bqskit.compiler.compiler import Compiler from bqskit.compiler.passdata import PassData - from bqskit.compiler.task import CompilationTask pass_data = PassData(self) if data is not None: pass_data.update(data) with Compiler() as compiler: - task = CompilationTask(self, [compiler_pass]) - task.data = pass_data - task_id = compiler.submit(task) + task_id = compiler.submit(self, [compiler_pass], data=pass_data) self.become(compiler.result(task_id)) # type: ignore def instantiate( From 3dd5a245bf0c70a465c904046f19d55436830fc0 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Mon, 1 Apr 2024 10:15:20 -0400 Subject: [PATCH 06/44] Fix bug in workflow copy constructor --- bqskit/compiler/workflow.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/bqskit/compiler/workflow.py b/bqskit/compiler/workflow.py index 6134d07aa..3c0cdaf4f 100644 --- a/bqskit/compiler/workflow.py +++ b/bqskit/compiler/workflow.py @@ -39,6 +39,7 @@ def __init__(self, passes: WorkflowLike, name: str = '') -> None: """ if isinstance(passes, Workflow): self._passes: list[BasePass] = copy.deepcopy(passes._passes) + self._name = copy.deepcopy(passes._name) if name == '' else name return if isinstance(passes, BasePass): From 4261294e6fc7c678482cf38d56f1a2d4bae31808 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Mon, 1 Apr 2024 10:18:59 -0400 Subject: [PATCH 07/44] Better circuit and workflow serialization with dill --- bqskit/compiler/workflow.py | 7 +++ bqskit/ir/circuit.py | 117 +++++++++++++++++++++++++++--------- setup.py | 1 + 3 files changed, 98 insertions(+), 27 deletions(-) diff --git a/bqskit/compiler/workflow.py b/bqskit/compiler/workflow.py index 3c0cdaf4f..65399ff87 100644 --- a/bqskit/compiler/workflow.py +++ b/bqskit/compiler/workflow.py @@ -3,6 +3,7 @@ import copy import logging +import dill from typing import Iterable from typing import Iterator from typing import overload @@ -119,5 +120,11 @@ def __getitem__(self, _key: slice, /) -> list[BasePass]: def __getitem__(self, _key: int | slice) -> BasePass | list[BasePass]: return self._passes.__getitem__(_key) + def __getstate__(self) -> bytes: + return dill.dumps(self.__dict__, recurse=True) + + def __setstate__(self, state: bytes) -> None: + self.__dict__.update(dill.loads(state)) + WorkflowLike = Union[Workflow, Iterable[BasePass], BasePass] diff --git a/bqskit/ir/circuit.py b/bqskit/ir/circuit.py index 49e53cdf5..2fd69f846 100644 --- a/bqskit/ir/circuit.py +++ b/bqskit/ir/circuit.py @@ -4,6 +4,8 @@ import copy import logging import warnings +import pickle +import dill from typing import Any from typing import cast from typing import Collection @@ -1035,33 +1037,8 @@ def point( raise ValueError('No such operation exists in the circuit.') - def append(self, op: Operation) -> 
int: - """ - Append `op` to the end of the circuit and return its cycle index. - - Args: - op (Operation): The operation to append. - - Returns: - int: The cycle index of the appended operation. - - Raises: - ValueError: If `op` cannot be placed on the circuit due to - either an invalid location or gate radix mismatch. - - Notes: - Due to the circuit being represented as a matrix, - `circuit.append(op)` does not imply `op` is last in simulation - order but it implies `op` is in the last cycle of circuit. - - Examples: - >>> from bqskit.ir.gates import HGate - >>> circ = Circuit(1) - >>> op = Operation(HGate(), [0]) - >>> circ.append(op) # Appends a Hadamard gate to qudit 0. - """ - self.check_valid_operation(op) - cycle_index = self._find_available_or_append_cycle(op.location) + def _append(self, op: Operation, cycle_index: int) -> None: + """Append the operation to the circuit at the specified cycle.""" point = CircuitPoint(cycle_index, op.location[0]) prevs: dict[int, CircuitPoint | None] = {i: None for i in op.location} @@ -1096,6 +1073,34 @@ def append(self, op: Operation) -> int: self._gate_info[op.gate] = 0 self._gate_info[op.gate] += 1 + def append(self, op: Operation) -> int: + """ + Append `op` to the end of the circuit and return its cycle index. + + Args: + op (Operation): The operation to append. + + Returns: + int: The cycle index of the appended operation. + + Raises: + ValueError: If `op` cannot be placed on the circuit due to + either an invalid location or gate radix mismatch. + + Notes: + Due to the circuit being represented as a matrix, + `circuit.append(op)` does not imply `op` is last in simulation + order but it implies `op` is in the last cycle of circuit. + + Examples: + >>> from bqskit.ir.gates import HGate + >>> circ = Circuit(1) + >>> op = Operation(HGate(), [0]) + >>> circ.append(op) # Appends a Hadamard gate to qudit 0. 
+ """ + self.check_valid_operation(op) + cycle_index = self._find_available_or_append_cycle(op.location) + self._append(op, cycle_index) return cycle_index def append_gate( @@ -3238,4 +3243,62 @@ def from_operation(op: Operation) -> Circuit: circuit.append_gate(op.gate, list(range(circuit.num_qudits)), op.params) return circuit + def __reduce__(self): + """Return the pickle state of the circuit.""" + serialized_gates = [] + gate_table = {} + for gate in self.gate_set: + gate_table[gate] = len(serialized_gates) + if gate.__class__.__module__.startswith('bqskit'): + serialized_gates.append((False, pickle.dumps(gate))) + else: + serialized_gates.append((True, dill.dumps(gate, recurse=True))) + + cycles = [] + last_cycle = -1 + for cycle, op in self.operations_with_cycles(): + + if cycle != last_cycle: + last_cycle = cycle + cycles.append([]) + + marshalled_op = ( + gate_table[op.gate], + op.location._location, + op.params + ) + cycles[-1].append(marshalled_op) + + data = ( + self.num_qudits, + self.radixes, + serialized_gates, + pickle.dumps(cycles), + ) + return (rebuild_circuit, data) + # endregion + + +def rebuild_circuit(num_qudits, radixes, serialized_gates, serialized_cycles) -> Circuit: + """Rebuild a circuit from a pickle state.""" + circuit = Circuit(num_qudits, radixes) + + gate_table = {} + for i, (is_dill, serialized_gate) in enumerate(serialized_gates): + if is_dill: + gate = dill.loads(serialized_gate) + else: + gate = pickle.loads(serialized_gate) + gate_table[i] = gate + + cycles = pickle.loads(serialized_cycles) + for i, cycle in enumerate(cycles): + circuit._append_cycle() + for marshalled_op in cycle: + gate = gate_table[marshalled_op[0]] + location = marshalled_op[1] + params = marshalled_op[2] + circuit._append(Operation(gate, location, params), i) + + return circuit diff --git a/setup.py b/setup.py index 935ba74ad..298874919 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,7 @@ 'numpy>=1.22.0', 'scipy>=1.8.0', 'typing-extensions>=4.0.0', + 
'dill>=0.3.8' ], python_requires='>=3.8, <4', entry_points={ From 99fce350121f3b9c1631d4d9ac7918489b69e658 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Mon, 1 Apr 2024 10:19:48 -0400 Subject: [PATCH 08/44] Faster node shutdown procedure --- bqskit/runtime/base.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 0a1fd0a31..3a7072273 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -57,19 +57,26 @@ def __init__( waiting message. """ - def shutdown(self) -> None: - """Shutdown the employee.""" + def initiate_shutdown(self) -> None: + """Instruct employee to shutdown.""" try: self.conn.send((RuntimeMessage.SHUTDOWN, None)) except Exception: pass + def complete_shutdown(self) -> None: + """Ensure employee is shutdown and clean up resources.""" if self.process is not None: self.process.join() self.process = None self.conn.close() + def shutdown(self) -> None: + """Initiate and complete shutdown.""" + self.initiate_shutdown() + self.complete_shutdown() + @property def has_idle_resources(self) -> bool: return self.num_idle_workers > 0 @@ -269,6 +276,7 @@ def spawn_workers( for i in range(num_workers): w_id = self.lower_id_bound + i procs[w_id] = Process(target=start_worker, args=(w_id, port)) + procs[w_id].daemon = True procs[w_id].start() self.logger.debug(f'Stated worker process {i}.') @@ -451,7 +459,11 @@ def handle_shutdown(self) -> None: # Instruct employees to shutdown for employee in self.employees: - employee.shutdown() + employee.initiate_shutdown() + + for employee in self.employees: + employee.complete_shutdown() + self.employees.clear() self.logger.debug('Shutdown employees.') From d4d6122570d33aaca0b114ecadb2e06fb056137b Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Mon, 1 Apr 2024 10:32:35 -0400 Subject: [PATCH 09/44] Better and lazy de/serialization of RuntimeTasks --- bqskit/runtime/task.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 
3 deletions(-) diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index c9f582804..f36d540dc 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -3,6 +3,7 @@ import inspect import logging +import dill from typing import Any from typing import Coroutine @@ -39,7 +40,9 @@ def __init__( RuntimeTask.task_counter += 1 self.task_id = RuntimeTask.task_counter - self.fnargs = fnargs + self.serialized_fnargs = dill.dumps(fnargs) + self._fnargs = None + self._name = fnargs[0].__name__ """Tuple of function pointer, arguments, and keyword arguments.""" self.return_address = return_address @@ -76,6 +79,13 @@ def __init__( self.wake_on_next: bool = False """Set to true if this task should wake immediately on a result.""" + @property + def fnargs(self) -> tuple[Any, Any, Any]: + """Return the function pointer, arguments, and keyword arguments.""" + if self._fnargs is None: + self._fnargs = dill.loads(self.serialized_fnargs) + return self._fnargs + def step(self, send_val: Any = None) -> Any: """Execute one step of the task.""" if self.coro is None: @@ -122,8 +132,8 @@ def is_descendant_of(self, addr: RuntimeAddress) -> bool: def __str__(self) -> str: """Return a string representation of the task.""" - return f'{self.fnargs[0].__name__}' + return f'{self._name}' def __repr__(self) -> str: """Return a string representation of the task.""" - return f'' + return f'' From 66d38da0008ffa0fc7989413ae4f74e3f398cf1d Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Mon, 1 Apr 2024 10:41:09 -0400 Subject: [PATCH 10/44] Somewhat better error handling in worker --- bqskit/runtime/worker.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 103f9eba6..0b9066b90 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -129,16 +129,24 @@ def handle_incoming_comms(worker: Worker) -> None: """Handle all incoming messages.""" while True: # Handle 
incomming communication - msg, payload = worker._conn.recv() + try: + msg, payload = worker._conn.recv() + except Exception: + print(f'Worker {worker._id} crashed due to lost connection') + worker._running = False + worker._ready_task_ids.put(RuntimeAddress(-1, -1, -1)) + break # Process message if msg == RuntimeMessage.SHUTDOWN: + print(f'Worker {worker._id} received shutdown message') worker._running = False worker._ready_task_ids.put(RuntimeAddress(-1, -1, -1)) # TODO: Interupt main, maybe even kill it - return + break elif msg == RuntimeMessage.SUBMIT: + # print('Worker received submit message') worker.read_receipt_mutex.acquire() task = cast(RuntimeTask, payload) worker.most_recent_read_submit = task.unique_id @@ -294,7 +302,16 @@ def _loop(self) -> None: """Main worker event loop.""" self._running = True while self._running: - self._try_step_next_ready_task() + try: + self._try_step_next_ready_task() + except Exception: + self._running = False + exc_info = sys.exc_info() + error_str = ''.join(traceback.format_exception(*exc_info)) + try: + self._conn.send((RuntimeMessage.ERROR, error_str)) + except Exception: + pass # self._try_idle() # self._handle_comms() @@ -383,17 +400,20 @@ def _get_next_ready_task(self) -> RuntimeTask: self.read_receipt_mutex.acquire() try: addr = self._ready_task_ids.get_nowait() + except Empty: payload = (1, self.most_recent_read_submit) self._conn.send((RuntimeMessage.WAITING, payload)) self.read_receipt_mutex.release() addr = self._ready_task_ids.get() - if addr == RuntimeAddress(-1, -1, -1): - return None - if self.read_receipt_mutex.locked(): + else: self.read_receipt_mutex.release() + # Handle a shutdown request that occured while waiting + if not self._running: + return None + if addr in self._cancelled_task_ids or addr not in self._tasks: # When a task is cancelled on the worker it is not removed # from the ready queue because it is much cheaper to just @@ -456,12 +476,6 @@ def _process_await(self, task: RuntimeTask, future: 
RuntimeFuture) -> None: box.dest_addr = task.return_address task.desired_box_id = future.mailbox_id - if future._next_flag: - # Set from Worker.next, implies the task wants the next result - if box.ready: - m = 'Cannot wait for next results on a complete task.' - raise RuntimeError(m) - task.wake_on_next = True # if future._next_flag: # # Set from Worker.next, implies the task wants the next result # # if box.ready: From 72c073dc815d1e112da02eaa969b72955423b033 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Mon, 1 Apr 2024 10:42:31 -0400 Subject: [PATCH 11/44] Export sys.path on client connection --- bqskit/compiler/compiler.py | 3 +-- bqskit/runtime/detached.py | 16 ++++++++++++++-- bqskit/runtime/manager.py | 7 +++++++ bqskit/runtime/message.py | 1 + bqskit/runtime/worker.py | 4 ++++ 5 files changed, 27 insertions(+), 4 deletions(-) diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index 6a6d233e6..778521d00 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -9,7 +9,6 @@ import sys import time import uuid -import warnings from multiprocessing.connection import Client from multiprocessing.connection import Connection from subprocess import Popen @@ -149,7 +148,7 @@ def _connect_to_server(self, ip: str, port: int) -> None: self.old_signal = signal.signal(signal.SIGINT, handle) if self.conn is None: raise RuntimeError('Connection unexpectedly none.') - self.conn.send((RuntimeMessage.CONNECT, None)) + self.conn.send((RuntimeMessage.CONNECT, sys.path)) _logger.debug('Successfully connected to runtime server.') return raise RuntimeError('Client connection refused') diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index d86e94cfa..02d2bf25b 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -30,7 +30,6 @@ from bqskit.runtime.result import RuntimeResult from bqskit.runtime.task import RuntimeTask - def listen(server: DetachedServer, port: int) -> None: """Listening thread listens 
for client connections.""" listener = Listener(('0.0.0.0', port)) @@ -131,7 +130,15 @@ def handle_message( if direction == MessageDirection.CLIENT: if msg == RuntimeMessage.CONNECT: - pass + # paths, serialized_defintions = cast(List[str], payload) + paths = cast(List[str], payload) + import sys + for path in paths: + if path not in sys.path: + sys.path.append(path) + for employee in self.employees: + employee.conn.send((RuntimeMessage.IMPORTPATH, path)) + elif msg == RuntimeMessage.DISCONNECT: self.handle_disconnect(conn) @@ -370,6 +377,11 @@ def handle_error(self, error_payload: tuple[int, str]) -> None: conn = self.tasks[self.mailbox_to_task_dict[tid]][1] self.outgoing.put((conn, RuntimeMessage.ERROR, error_payload[1])) # TODO: Broadcast cancel to all tasks with compilation task id tid + # But avoid double broadcasting it. If the client crashes due to + # this error, which it may not, then we will quickly process + # a handle_disconnect and call the cancel anyways. We should + # still cancel here incase the client catches the error and + # resubmits a job. 
def handle_log(self, log_payload: tuple[int, LogRecord]) -> None: """Forward logs to appropriate client.""" diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 4e0f13c76..974ddbae0 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -151,6 +151,13 @@ def handle_message( elif msg == RuntimeMessage.SHUTDOWN: self.handle_shutdown() + elif msg == RuntimeMessage.IMPORTPATH: + import_path = cast(str, payload) + import sys + sys.path.append(import_path) + for employee in self.employees: + employee.conn.send((RuntimeMessage.IMPORTPATH, import_path)) + else: raise RuntimeError(f'Unexpected message type: {msg.name}') diff --git a/bqskit/runtime/message.py b/bqskit/runtime/message.py index 63f687048..aed9f6cfa 100644 --- a/bqskit/runtime/message.py +++ b/bqskit/runtime/message.py @@ -20,3 +20,4 @@ class RuntimeMessage(IntEnum): CANCEL = 11 WAITING = 12 UPDATE = 13 + IMPORTPATH = 14 diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 0b9066b90..575d10864 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -174,6 +174,10 @@ def handle_incoming_comms(worker: Worker) -> None: worker._handle_cancel(addr) # TODO: preempt? 
+ elif msg == RuntimeMessage.IMPORTPATH: + import_path = cast(str, payload) + sys.path.append(import_path) + class Worker: """ From 67871008cfe8f5263c9a5d3fae2a67c9d10e56cd Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Tue, 2 Apr 2024 08:29:28 -0400 Subject: [PATCH 12/44] pre-commit (ish) --- bqskit/compiler/compiler.py | 13 +++++++------ bqskit/compiler/workflow.py | 5 +++-- bqskit/ir/circuit.py | 26 +++++++++++++++++++------- bqskit/runtime/base.py | 11 +++++++---- bqskit/runtime/detached.py | 6 ++++-- bqskit/runtime/task.py | 6 ++++-- bqskit/runtime/worker.py | 12 +++++++++--- setup.py | 2 +- 8 files changed, 54 insertions(+), 27 deletions(-) diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index 778521d00..b306004dd 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -14,6 +14,7 @@ from subprocess import Popen from types import FrameType from typing import Literal +from typing import MutableMapping from typing import overload from typing import TYPE_CHECKING @@ -233,7 +234,7 @@ def submit( request_data: bool = False, logging_level: int | None = None, max_logging_depth: int = -1, - data: dict[str, Any] | None = None, + data: MutableMapping[str, Any] | None = None, ) -> uuid.UUID: """ Submit a compilation job to the Compiler. @@ -271,7 +272,7 @@ def submit( task.logging_level = logging_level or self._discover_lowest_log_level() task.max_logging_depth = max_logging_depth if data is not None: - task.data = data + task.data.update(data) # Submit task to runtime self._send(RuntimeMessage.SUBMIT, task) @@ -306,7 +307,7 @@ def compile( request_data: Literal[False] = ..., logging_level: int | None = ..., max_logging_depth: int = ..., - data: dict[str, Any] | None = ..., + data: MutableMapping[str, Any] | None = ..., ) -> Circuit: ... 
@@ -318,7 +319,7 @@ def compile( request_data: Literal[True], logging_level: int | None = ..., max_logging_depth: int = ..., - data: dict[str, Any] | None = ..., + data: MutableMapping[str, Any] | None = ..., ) -> tuple[Circuit, PassData]: ... @@ -330,7 +331,7 @@ def compile( request_data: bool, logging_level: int | None = ..., max_logging_depth: int = ..., - data: dict[str, Any] | None = ..., + data: MutableMapping[str, Any] | None = ..., ) -> Circuit | tuple[Circuit, PassData]: ... @@ -341,7 +342,7 @@ def compile( request_data: bool = False, logging_level: int | None = None, max_logging_depth: int = -1, - data: dict[str, Any] | None = None, + data: MutableMapping[str, Any] | None = None, ) -> Circuit | tuple[Circuit, PassData]: """Submit a task, wait for its results; see :func:`submit` for more.""" task_id = self.submit( diff --git a/bqskit/compiler/workflow.py b/bqskit/compiler/workflow.py index 65399ff87..9a0ad2677 100644 --- a/bqskit/compiler/workflow.py +++ b/bqskit/compiler/workflow.py @@ -3,7 +3,6 @@ import copy import logging -import dill from typing import Iterable from typing import Iterator from typing import overload @@ -11,6 +10,8 @@ from typing import TYPE_CHECKING from typing import Union +import dill + from bqskit.compiler.basepass import BasePass from bqskit.utils.random import seed_random_sources from bqskit.utils.typing import is_iterable @@ -40,7 +41,7 @@ def __init__(self, passes: WorkflowLike, name: str = '') -> None: """ if isinstance(passes, Workflow): self._passes: list[BasePass] = copy.deepcopy(passes._passes) - self._name = copy.deepcopy(passes._name) if name == '' else name + self._name: str = name if name else copy.deepcopy(passes._name) return if isinstance(passes, BasePass): diff --git a/bqskit/ir/circuit.py b/bqskit/ir/circuit.py index 2fd69f846..1c628125c 100644 --- a/bqskit/ir/circuit.py +++ b/bqskit/ir/circuit.py @@ -3,10 +3,10 @@ import copy import logging -import warnings import pickle -import dill +import warnings from typing 
import Any +from typing import Callable from typing import cast from typing import Collection from typing import Dict @@ -20,6 +20,7 @@ from typing import Tuple from typing import TYPE_CHECKING +import dill import numpy as np import numpy.typing as npt @@ -3243,9 +3244,15 @@ def from_operation(op: Operation) -> Circuit: circuit.append_gate(op.gate, list(range(circuit.num_qudits)), op.params) return circuit - def __reduce__(self): + def __reduce__(self) -> tuple[ + Callable[ + [int, tuple[int, ...], list[tuple[bool, bytes]], bytes], + Circuit, + ], + tuple[int, tuple[int, ...], list[tuple[bool, bytes]], bytes], + ]: """Return the pickle state of the circuit.""" - serialized_gates = [] + serialized_gates: list[tuple[bool, bytes]] = [] gate_table = {} for gate in self.gate_set: gate_table[gate] = len(serialized_gates) @@ -3254,7 +3261,7 @@ def __reduce__(self): else: serialized_gates.append((True, dill.dumps(gate, recurse=True))) - cycles = [] + cycles: list[list[tuple[int, tuple[int, ...], list[float]]]] = [] last_cycle = -1 for cycle, op in self.operations_with_cycles(): @@ -3265,7 +3272,7 @@ def __reduce__(self): marshalled_op = ( gate_table[op.gate], op.location._location, - op.params + op.params, ) cycles[-1].append(marshalled_op) @@ -3280,7 +3287,12 @@ def __reduce__(self): # endregion -def rebuild_circuit(num_qudits, radixes, serialized_gates, serialized_cycles) -> Circuit: +def rebuild_circuit( + num_qudits: int, + radixes: tuple[int, ...], + serialized_gates: list[tuple[bool, bytes]], + serialized_cycles: bytes, +) -> Circuit: """Rebuild a circuit from a pickle state.""" circuit = Circuit(num_qudits, radixes) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 3a7072273..f4897b5ec 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -556,10 +556,13 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: """Schedule tasks between this node's employees.""" if len(tasks) == 0: return - assignments = 
self.assign_tasks(tasks) - - # for e, assignment in sorted(zip(self.employees, assignments), key=lambda x: x[0].num_idle_workers, reverse=True): - for e, assignment in zip(self.employees, assignments): + assignments = zip(self.employees, self.assign_tasks(tasks)) + sorted_assignments = sorted( + assignments, + key=lambda x: x[0].num_idle_workers, + reverse=True, + ) + for e, assignment in sorted_assignments: num_tasks = len(assignment) if num_tasks == 0: diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 02d2bf25b..1c0a9fdf6 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -30,6 +30,7 @@ from bqskit.runtime.result import RuntimeResult from bqskit.runtime.task import RuntimeTask + def listen(server: DetachedServer, port: int) -> None: """Listening thread listens for client connections.""" listener = Listener(('0.0.0.0', port)) @@ -137,8 +138,9 @@ def handle_message( if path not in sys.path: sys.path.append(path) for employee in self.employees: - employee.conn.send((RuntimeMessage.IMPORTPATH, path)) - + employee.conn.send( + (RuntimeMessage.IMPORTPATH, path), + ) elif msg == RuntimeMessage.DISCONNECT: self.handle_disconnect(conn) diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index f36d540dc..ccffa3b9a 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -3,10 +3,11 @@ import inspect import logging -import dill from typing import Any from typing import Coroutine +import dill + from bqskit.runtime.address import RuntimeAddress @@ -41,7 +42,7 @@ def __init__( self.task_id = RuntimeTask.task_counter self.serialized_fnargs = dill.dumps(fnargs) - self._fnargs = None + self._fnargs: tuple[Any, Any, Any] | None = None self._name = fnargs[0].__name__ """Tuple of function pointer, arguments, and keyword arguments.""" @@ -84,6 +85,7 @@ def fnargs(self) -> tuple[Any, Any, Any]: """Return the function pointer, arguments, and keyword arguments.""" if self._fnargs is None: self._fnargs = 
dill.loads(self.serialized_fnargs) + assert self._fnargs is not None # for type checker return self._fnargs def step(self, send_val: Any = None) -> Any: diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 575d10864..7c4c8434d 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -358,8 +358,11 @@ def _handle_result(self, result: RuntimeResult) -> None: task = self._tasks[box.dest_addr] if task.wake_on_next or box.ready: + # print(f'Worker {self._id} is waking task + # {task.return_address}, with {task.wake_on_next=}, + # {box.ready=}') self._ready_task_ids.put(box.dest_addr) # Wake it - box.dest_addr = None # Prevent double wake + box.dest_addr = None # Prevent double wake def _handle_cancel(self, addr: RuntimeAddress) -> None: """ @@ -394,7 +397,7 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None: if not t.is_descendant_of(addr) ] - def _get_next_ready_task(self) -> RuntimeTask: + def _get_next_ready_task(self) -> RuntimeTask | None: """Return the next ready task if one exists, otherwise block.""" while True: if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0: @@ -487,6 +490,8 @@ def _process_await(self, task: RuntimeTask, future: RuntimeFuture) -> None: # # raise RuntimeError(m) # task.wake_on_next = True task.wake_on_next = future._next_flag + # print(f'Worker {self._id} is waiting on task + # {task.return_address}, with {task.wake_on_next=}') if box.ready: self._ready_task_ids.put(task.return_address) @@ -497,7 +502,8 @@ def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: packaged_result = RuntimeResult(task.return_address, result, self._id) if task.return_address not in self._tasks: - print(f'Task was cancelled: {task.return_address}, {task.fnargs[0].__name__}') + # print(f'Task was cancelled: {task.return_address}, + # {task.fnargs[0].__name__}') return if task.return_address.worker_id == self._id: diff --git a/setup.py b/setup.py index 298874919..a546d6267 100644 --- 
a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ 'numpy>=1.22.0', 'scipy>=1.8.0', 'typing-extensions>=4.0.0', - 'dill>=0.3.8' + 'dill>=0.3.8', ], python_requires='>=3.8, <4', entry_points={ From 886a034b28851bb154787e5fbb67b298033f70df Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Tue, 2 Apr 2024 10:50:06 -0400 Subject: [PATCH 13/44] Avoid unnecessary circuit evals in data update --- bqskit/compiler/passdata.py | 18 ++++++++++++++++++ bqskit/ir/circuit.py | 7 +------ tests/compiler/test_data.py | 6 ++++++ 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/bqskit/compiler/passdata.py b/bqskit/compiler/passdata.py index 160d4f44e..23d584712 100644 --- a/bqskit/compiler/passdata.py +++ b/bqskit/compiler/passdata.py @@ -252,6 +252,24 @@ def __contains__(self, _o: object) -> bool: in_data = self._data.__contains__(_o) return in_resv or in_data + def update(self, other: Any = (), /, **kwds: Any) -> None: + """Update the data with key-values pairs from `other` and `kwds`.""" + if isinstance(other, PassData): + for key in other: + # Handle target specially to avoid circuit evaluation + if key == 'target': + self._target = other._target + continue + + self[key] = other[key] + + for key, value in kwds.items(): + self[key] = value + + return + + super().update(other, **kwds) + def copy(self) -> PassData: """Returns a deep copy of the data.""" return copy.deepcopy(self) diff --git a/bqskit/ir/circuit.py b/bqskit/ir/circuit.py index 1c628125c..2f1e8758b 100644 --- a/bqskit/ir/circuit.py +++ b/bqskit/ir/circuit.py @@ -2723,14 +2723,9 @@ def perform( :class:`~bqskit.compiler.compiler.Compiler` directly. 
""" from bqskit.compiler.compiler import Compiler - from bqskit.compiler.passdata import PassData - - pass_data = PassData(self) - if data is not None: - pass_data.update(data) with Compiler() as compiler: - task_id = compiler.submit(self, [compiler_pass], data=pass_data) + task_id = compiler.submit(self, [compiler_pass], data=data) self.become(compiler.result(task_id)) # type: ignore def instantiate( diff --git a/tests/compiler/test_data.py b/tests/compiler/test_data.py index 075a95c82..934215db6 100644 --- a/tests/compiler/test_data.py +++ b/tests/compiler/test_data.py @@ -26,3 +26,9 @@ def test_update_error_mul() -> None: assert data.error == 0.75 data.update_error_mul(0.5) assert data.error == 0.875 + + +def test_target_doesnt_get_expanded_on_update() -> None: + data = PassData(Circuit(64)) + data2 = PassData(Circuit(64)) + data.update(data2) # Should not crash From 9942a0a57e2e3a8298dfac71e70e789320e4a7c1 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Tue, 2 Apr 2024 15:46:02 -0400 Subject: [PATCH 14/44] Mock dill for docs --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index cfd238319..f48262b93 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -100,7 +100,7 @@ 'pytket', 'cirq', 'qutip', - 'qiskit', + 'dill', ] nbsphinx_allow_errors = True nbsphinx_execute = 'never' From 50f3ad8191e9a8f9c14b5c17d3c92af95bf32120 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Tue, 2 Apr 2024 15:50:02 -0400 Subject: [PATCH 15/44] Fixed python 3.8 type issue --- bqskit/runtime/detached.py | 3 ++- bqskit/runtime/manager.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 1c0a9fdf6..607af21ed 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -17,6 +17,7 @@ from typing import List from typing import Optional from typing import Sequence +from typing import Tuple from bqskit.compiler.status import CompilationStatus 
from bqskit.compiler.task import CompilationTask @@ -192,7 +193,7 @@ def handle_message( self.handle_shutdown() elif msg == RuntimeMessage.WAITING: - p = cast(tuple[int, Optional[RuntimeAddress]], payload) + p = cast(Tuple[int, Optional[RuntimeAddress]], payload) num_idle, read_receipt = p self.handle_waiting(conn, num_idle, read_receipt) diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 974ddbae0..c8725bf1c 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -11,6 +11,7 @@ from typing import List from typing import Optional from typing import Sequence +from typing import Tuple from bqskit.runtime import default_manager_port from bqskit.runtime import default_worker_port @@ -178,7 +179,7 @@ def handle_message( self.handle_result_from_below(result) elif msg == RuntimeMessage.WAITING: - p = cast(tuple[int, Optional[RuntimeAddress]], payload) + p = cast(Tuple[int, Optional[RuntimeAddress]], payload) num_idle, read_receipt = p self.handle_waiting(conn, num_idle, read_receipt) self.update_upstream_idle_workers() From a47be9476a6e689eaaeb5d2dd848a5ebf30f3485 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Tue, 2 Apr 2024 16:29:15 -0400 Subject: [PATCH 16/44] Test mac through CI :( --- tests/runtime/test_attached.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/runtime/test_attached.py b/tests/runtime/test_attached.py index eb708ff9f..3996007fe 100644 --- a/tests/runtime/test_attached.py +++ b/tests/runtime/test_attached.py @@ -68,6 +68,8 @@ def test_two_thread_per_worker() -> None: compiler = Compiler(num_workers=1) assert compiler.p is not None assert len(psutil.Process(compiler.p.pid).children()) in [1, 2] + if sys.platform == 'darwin': + print(psutil.Process(compiler.p.pid).children()[0].threads()) assert psutil.Process(compiler.p.pid).children()[0].num_threads() == 2 compiler.close() From 363c7e5818cac9ed341b6abf3c1b29ae922621e1 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Wed, 3 Apr 2024 07:22:24 -0400 
Subject: [PATCH 17/44] Lazy imports in __init__ for runtime thread ctrl --- bqskit/__init__.py | 100 +++++------------------------ bqskit/_logging.py | 36 +++++++++++ bqskit/{version.py => _version.py} | 0 bqskit/compiler/compile.py | 7 +- bqskit/ir/__init__.py | 7 ++ setup.py | 2 +- 6 files changed, 64 insertions(+), 88 deletions(-) create mode 100644 bqskit/_logging.py rename bqskit/{version.py => _version.py} (100%) diff --git a/bqskit/__init__.py b/bqskit/__init__.py index da01938ba..83ca808bb 100644 --- a/bqskit/__init__.py +++ b/bqskit/__init__.py @@ -1,92 +1,29 @@ """The Berkeley Quantum Synthesis Toolkit Python Package.""" from __future__ import annotations -import logging -from sys import stdout as _stdout +from typing import Any -import bqskit.runtime -from .version import __version__ # noqa: F401 -from .version import __version_info__ # noqa: F401 -from bqskit.compiler.compile import compile -from bqskit.compiler.machine import MachineModel -from bqskit.ir.circuit import Circuit -from bqskit.ir.lang import register_language as _register_language -from bqskit.ir.lang.qasm2 import OPENQASM2Language as _qasm +from bqskit._logging import disable_logging +from bqskit._logging import enable_logging +from bqskit._version import __version__ # noqa: F401 +from bqskit._version import __version_info__ # noqa: F401 -# Initialize Logging -_logging_initialized = False +def __getattr__(name: str) -> Any: + # Lazy imports + if name == 'compile': + from bqskit.compiler.compile import compile + return compile -def enable_logging(verbose: bool = False) -> None: - """ - Enable logging for BQSKit. + if name == 'Circuit': + from bqskit.ir.circuit import Circuit + return Circuit - Args: - verbose (bool): If set to True, will print more verbose messages. - Defaults to False. 
- """ - global _logging_initialized - if not _logging_initialized: - _logger = logging.getLogger('bqskit') - _handler = logging.StreamHandler(_stdout) - _handler.setLevel(0) - _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' - _fmt_message = ' %(name)s: %(message)s' - _fmt = _fmt_header + _fmt_message - _formatter = logging.Formatter(_fmt, '%H:%M:%S') - _handler.setFormatter(_formatter) - _logger.addHandler(_handler) - _logging_initialized = True + if name == 'MachineModel': + from bqskit.compiler.machine import MachineModel + return MachineModel - level = logging.DEBUG if verbose else logging.INFO - logging.getLogger('bqskit').setLevel(level) - - -def disable_logging() -> None: - """Disable logging for BQSKit.""" - logging.getLogger('bqskit').setLevel(logging.CRITICAL) - - -def enable_dashboard() -> None: - import warnings - warnings.warn( - 'Dask has been removed from BQSKit. As a result, the' - ' enable_dashboard method has been removed.' - 'This warning will turn into an error in a future update.', - DeprecationWarning, - ) - - -def disable_dashboard() -> None: - import warnings - warnings.warn( - 'Dask has been removed from BQSKit. As a result, the' - ' disable_dashboard method has been removed.' - 'This warning will turn into an error in a future update.', - DeprecationWarning, - ) - - -def disable_parallelism() -> None: - import warnings - warnings.warn( - 'The disable_parallelism method has been removed.' - ' Instead, set the "num_workers" parameter to 1 during ' - 'Compiler construction. This warning will turn into' - 'an error in a future update.', - DeprecationWarning, - ) - - -def enable_parallelism() -> None: - import warnings - warnings.warn( - 'The enable_parallelism method has been removed.' - ' Instead, set the "num_workers" parameter to 1 during ' - 'Compiler construction. 
This warning will turn into' - 'an error in a future update.', - DeprecationWarning, - ) + raise AttributeError(f'module {__name__} has no attribute {name}') __all__ = [ @@ -96,6 +33,3 @@ def enable_parallelism() -> None: 'enable_logging', 'disable_logging', ] - -# Register supported languages -_register_language('qasm', _qasm()) diff --git a/bqskit/_logging.py b/bqskit/_logging.py new file mode 100644 index 000000000..73ec1bef8 --- /dev/null +++ b/bqskit/_logging.py @@ -0,0 +1,36 @@ +"""This module contains the logging configuration and methods for BQSKit.""" +import logging +from sys import stdout as _stdout + + +_logging_initialized = False + + +def enable_logging(verbose: bool = False) -> None: + """ + Enable logging for BQSKit. + + Args: + verbose (bool): If set to True, will print more verbose messages. + Defaults to False. + """ + global _logging_initialized + if not _logging_initialized: + _logger = logging.getLogger('bqskit') + _handler = logging.StreamHandler(_stdout) + _handler.setLevel(0) + _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' + _fmt_message = ' %(name)s: %(message)s' + _fmt = _fmt_header + _fmt_message + _formatter = logging.Formatter(_fmt, '%H:%M:%S') + _handler.setFormatter(_formatter) + _logger.addHandler(_handler) + _logging_initialized = True + + level = logging.DEBUG if verbose else logging.INFO + logging.getLogger('bqskit').setLevel(level) + + +def disable_logging() -> None: + """Disable logging for BQSKit.""" + logging.getLogger('bqskit').setLevel(logging.CRITICAL) diff --git a/bqskit/version.py b/bqskit/_version.py similarity index 100% rename from bqskit/version.py rename to bqskit/_version.py diff --git a/bqskit/compiler/compile.py b/bqskit/compiler/compile.py index 87db9db93..e8d4aa90d 100644 --- a/bqskit/compiler/compile.py +++ b/bqskit/compiler/compile.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +import math import warnings from typing import Any from typing import Literal @@ -10,8 +11,6 @@ 
from typing import TYPE_CHECKING from typing import Union -import numpy as np - from bqskit.compiler.compiler import Compiler from bqskit.compiler.machine import MachineModel from bqskit.compiler.passdata import PassData @@ -582,7 +581,7 @@ def type_and_check_input(input: CompilationInputLike) -> CompilationInput: if error_threshold is not None: for i, data in enumerate(datas): error = data.error - nonsq_error = 1 - np.sqrt(max(1 - (error * error), 0)) + nonsq_error = 1 - math.sqrt(max(1 - (error * error), 0)) if nonsq_error > error_threshold: warnings.warn( 'Upper bound on error is greater than set threshold:' @@ -631,7 +630,7 @@ def type_and_check_input(input: CompilationInputLike) -> CompilationInput: # Log error if necessary if error_threshold is not None: error = data.error - nonsq_error = 1 - np.sqrt(max(1 - (error * error), 0)) + nonsq_error = 1 - math.sqrt(max(1 - (error * error), 0)) if nonsq_error > error_threshold: warnings.warn( 'Upper bound on error is greater than set threshold:' diff --git a/bqskit/ir/__init__.py b/bqskit/ir/__init__.py index 10d0e4342..9959e3f04 100644 --- a/bqskit/ir/__init__.py +++ b/bqskit/ir/__init__.py @@ -62,6 +62,8 @@ from bqskit.ir.interval import CycleInterval from bqskit.ir.interval import IntervalLike from bqskit.ir.iterator import CircuitIterator +from bqskit.ir.lang import register_language as _register_language +from bqskit.ir.lang.qasm2 import OPENQASM2Language as _qasm from bqskit.ir.location import CircuitLocation from bqskit.ir.location import CircuitLocationLike from bqskit.ir.operation import Operation @@ -71,6 +73,11 @@ from bqskit.ir.region import CircuitRegionLike from bqskit.ir.structure import CircuitStructure + +# Register supported languages +_register_language('qasm', _qasm()) + + __all__ = [ 'Operation', 'Circuit', diff --git a/setup.py b/setup.py index a546d6267..0c02a9dec 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ root_dir_path = os.path.abspath(os.path.dirname(__file__)) pkg_dir_path = 
os.path.join(root_dir_path, 'bqskit') readme_path = os.path.join(root_dir_path, 'README.md') -version_path = os.path.join(pkg_dir_path, 'version.py') +version_path = os.path.join(pkg_dir_path, '_version.py') # Load Version Number with open(version_path) as version_file: From 011142a7bacacaf545781974b9400b19899539e5 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Wed, 3 Apr 2024 07:26:33 -0400 Subject: [PATCH 18/44] Skip thread counting test on macOS --- tests/runtime/test_attached.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/runtime/test_attached.py b/tests/runtime/test_attached.py index 3996007fe..db70ec8bc 100644 --- a/tests/runtime/test_attached.py +++ b/tests/runtime/test_attached.py @@ -61,15 +61,15 @@ def test_create_workers(num_workers: int) -> None: def test_two_thread_per_worker() -> None: - # On windows we aren't sure how the threads are handeled if sys.platform == 'win32': - return + pytest.skip('Not sure how to count threads on Windows.') + + if sys.platform == 'darwin': + pytest.skip('MacOS requires permissions to count threads.') compiler = Compiler(num_workers=1) assert compiler.p is not None assert len(psutil.Process(compiler.p.pid).children()) in [1, 2] - if sys.platform == 'darwin': - print(psutil.Process(compiler.p.pid).children()[0].threads()) assert psutil.Process(compiler.p.pid).children()[0].num_threads() == 2 compiler.close() From 7571d0118ab320e8a4d4ab8da4ec76529b104efa Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Wed, 3 Apr 2024 07:27:39 -0400 Subject: [PATCH 19/44] pre-commit --- bqskit/_logging.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bqskit/_logging.py b/bqskit/_logging.py index 73ec1bef8..59b079156 100644 --- a/bqskit/_logging.py +++ b/bqskit/_logging.py @@ -1,4 +1,6 @@ """This module contains the logging configuration and methods for BQSKit.""" +from __future__ import annotations + import logging from sys import stdout as _stdout From d4ff91726f274961be68799b0be4a183f8d3e8e0 Mon Sep 
17 00:00:00 2001 From: Ed Younis Date: Wed, 3 Apr 2024 08:31:41 -0400 Subject: [PATCH 20/44] Addresses #211 --- bqskit/compiler/compiler.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index b306004dd..00f2b54ab 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -109,7 +109,7 @@ def __init__( ip = 'localhost' self._start_server(num_workers, runtime_log_level, worker_port) - self._connect_to_server(ip, port) + self._connect_to_server(ip, port, self.p is not None) def _start_server( self, @@ -132,17 +132,36 @@ def _start_server( self.p = Popen([sys.executable, '-c', launch_str], creationflags=flags) _logger.debug('Starting runtime server process.') - def _connect_to_server(self, ip: str, port: int) -> None: + def _connect_to_server(self, ip: str, port: int, attached: bool) -> None: """Connect to a runtime server at `ip` and `port`.""" max_retries = 8 wait_time = .25 - for _ in range(max_retries): + current_retry = 0 + while current_retry < max_retries or attached: try: family = 'AF_INET' if sys.platform == 'win32' else None conn = Client((ip, port), family) except ConnectionRefusedError: + if wait_time > 4: + _logger.warning( + 'Connection refused by runtime server.' + ' Retrying in %s seconds.', wait_time, + ) + if wait_time > 16 and attached: + _logger.warning( + 'Connection is still refused by runtime server.' + ' This may be due to the server not being started.' + ' You may want to check the server logs, by starting' + ' the compiler with "runtime_log_level" set. You' + ' can also try launching the bqskit runtime in' + ' detached mode. 
See the bqskit runtime documentation' + ' for more information:' + ' https://bqskit.readthedocs.io/en/latest/guides/' + 'distributing.html', + ) time.sleep(wait_time) wait_time *= 2 + current_retry += 1 else: self.conn = conn handle = functools.partial(sigint_handler, compiler=self) From a1b106b738f6b7aab899cb3f3add1f4442c91936 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Wed, 3 Apr 2024 08:47:20 -0400 Subject: [PATCH 21/44] Fixes #181 --- bqskit/runtime/base.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index f4897b5ec..52fed8556 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -487,10 +487,6 @@ def handle_disconnect(self, conn: Connection) -> None: if conn in self.conn_to_employee_dict: self.handle_shutdown() - def __del__(self) -> None: - """Ensure resources are cleaned up.""" - self.handle_shutdown() - def assign_tasks( self, tasks: Sequence[RuntimeTask], From 1004c05de4e8b53fdce35dcb33261954fba396bb Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Wed, 3 Apr 2024 10:30:59 -0400 Subject: [PATCH 22/44] Attempt to Fix #213 --- bqskit/compiler/compiler.py | 18 +++++++++++++----- tests/runtime/test_attached.py | 20 ++++++++++---------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index 00f2b54ab..019a8fc10 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -103,7 +103,7 @@ def __init__( self.p: Popen | None = None # type: ignore self.conn: Connection | None = None - atexit.register(self.close) + _compiler_instances.add(self) if ip is None: ip = 'localhost' @@ -240,11 +240,10 @@ def close(self) -> None: # Reset interrupt signal handler and remove exit handler if hasattr(self, 'old_signal'): signal.signal(signal.SIGINT, self.old_signal) + del self.old_signal - def __del__(self) -> None: - self.close() - atexit.unregister(self.close) - _logger.debug('Compiler successfully shutdown.') + 
_compiler_instances.discard(self) + _logger.debug('Compiler has been closed.') def submit( self, @@ -484,3 +483,12 @@ def sigint_handler(signum: int, frame: FrameType, compiler: Compiler) -> None: _logger.critical('Compiler interrupted.') compiler.close() raise KeyboardInterrupt + + +_compiler_instances: set[Compiler] = set() + + +@atexit.register +def _cleanup_compiler_instances() -> None: + for compiler in list(_compiler_instances): + compiler.close() diff --git a/tests/runtime/test_attached.py b/tests/runtime/test_attached.py index db70ec8bc..aa155dc3c 100644 --- a/tests/runtime/test_attached.py +++ b/tests/runtime/test_attached.py @@ -17,16 +17,16 @@ from bqskit.runtime import get_runtime -@pytest.mark.parametrize('num_workers', [1, -1]) -def test_startup_shutdown_transparently(num_workers: int) -> None: - in_num_childs = len(psutil.Process(os.getpid()).children(recursive=True)) - compiler = Compiler(num_workers=num_workers) - assert compiler.p is not None - compiler.__del__() - if sys.platform == 'win32': - time.sleep(1) - out_num_childs = len(psutil.Process(os.getpid()).children(recursive=True)) - assert in_num_childs == out_num_childs +# @pytest.mark.parametrize('num_workers', [1, -1]) +# def test_startup_shutdown_transparently(num_workers: int) -> None: +# in_num_childs = len(psutil.Process(os.getpid()).children(recursive=True)) +# compiler = Compiler(num_workers=num_workers) +# assert compiler.p is not None +# compiler.__del__() +# if sys.platform == 'win32': +# time.sleep(1) +# out_num_childs = len(psutil.Process(os.getpid()).children(recursive=True)) +# assert in_num_childs == out_num_childs @pytest.mark.parametrize('num_workers', [1, -1]) From 20859d022256a4e71184e186ecc768d9ca8a8123 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Thu, 4 Apr 2024 08:01:47 -0400 Subject: [PATCH 23/44] More robust startup and a little cleanup --- bqskit/compiler/compiler.py | 4 +- bqskit/runtime/attached.py | 6 +- bqskit/runtime/base.py | 63 ++++++++------ 
bqskit/runtime/detached.py | 61 +++++++------- bqskit/runtime/manager.py | 10 +-- bqskit/runtime/message.py | 1 + bqskit/runtime/worker.py | 161 ++++++++++++------------------------ 7 files changed, 129 insertions(+), 177 deletions(-) diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index 019a8fc10..ba19fbd16 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -168,7 +168,9 @@ def _connect_to_server(self, ip: str, port: int, attached: bool) -> None: self.old_signal = signal.signal(signal.SIGINT, handle) if self.conn is None: raise RuntimeError('Connection unexpectedly none.') - self.conn.send((RuntimeMessage.CONNECT, sys.path)) + msg, payload = self._send_recv(RuntimeMessage.CONNECT, sys.path) + if msg != RuntimeMessage.READY: + raise RuntimeError(f'Unexpected message type: {msg}.') _logger.debug('Successfully connected to runtime server.') return raise RuntimeError('Client connection refused') diff --git a/bqskit/runtime/attached.py b/bqskit/runtime/attached.py index 6cd6a6a98..27a6f7983 100644 --- a/bqskit/runtime/attached.py +++ b/bqskit/runtime/attached.py @@ -59,9 +59,6 @@ def __init__( self.mailboxes: dict[int, ServerMailbox] = {} self.mailbox_counter = 0 - # Start workers - self.spawn_workers(num_workers, worker_port) - # Connect to client client_conn = self.listen_once('localhost', port) self.clients[client_conn] = set() @@ -72,6 +69,9 @@ def __init__( ) self.logger.info('Connected to client.') + # Start workers + self.spawn_workers(num_workers, worker_port, log_level) + def handle_disconnect(self, conn: Connection) -> None: """A client disconnect in attached mode is equal to a shutdown.""" self.handle_shutdown() diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 52fed8556..206a66fc2 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -97,29 +97,6 @@ def get_num_of_tasks_sent_since( raise RuntimeError('Read receipt not found in submit cache.') -def send_outgoing(node: 
ServerBase) -> None: - """Outgoing thread forwards messages as they are created.""" - while True: - outgoing = node.outgoing.get() - - if not node.running: - # NodeBase's handle_shutdown will put a dummy value in the - # queue to wake the thread up so it can exit safely. - # Hence the node.running check now rather than in the - # while condition. - break - - outgoing[0].send((outgoing[1], outgoing[2])) - node.logger.debug(f'Sent message {outgoing[1].name}.') - - if outgoing[1] == RuntimeMessage.SUBMIT_BATCH: - node.logger.log(1, f'{len(outgoing[2])}\n') - else: - node.logger.log(1, f'{outgoing[2]}\n') - - node.outgoing.task_done() - - def sigint_handler(signum: int, _: FrameType | None, node: ServerBase) -> None: """Interrupt the node.""" if not node.running: @@ -172,7 +149,7 @@ def __init__(self) -> None: # Start outgoing thread self.outgoing: Queue[tuple[Connection, RuntimeMessage, Any]] = Queue() - self.outgoing_thread = Thread(target=send_outgoing, args=(self,)) + self.outgoing_thread = Thread(target=self.send_outgoing, daemon=True) self.outgoing_thread.start() self.logger.info('Started outgoing thread.') @@ -376,9 +353,34 @@ def listen_once(self, ip: str, port: int) -> Connection: listener.close() return conn + def send_outgoing(self) -> None: + """Outgoing thread forwards messages as they are created.""" + while True: + outgoing = self.outgoing.get() + + if not self.running: + # NodeBase's handle_shutdown will put a dummy value in the + # queue to wake the thread up so it can exit safely. + # Hence the node.running check now rather than in the + # while condition. 
+ break + + if outgoing[0].closed: + continue + + outgoing[0].send((outgoing[1], outgoing[2])) + _logger.debug(f'Sent message {outgoing[1].name}.') + + if outgoing[1] == RuntimeMessage.SUBMIT_BATCH: + _logger.log(1, f'[{outgoing[2][0]}] * {len(outgoing[2])}\n') + else: + _logger.log(1, f'{outgoing[2]}\n') + + self.outgoing.task_done() + def run(self) -> None: """Main loop.""" - self.logger.info(f'{self.__class__.__name__} running...') + _logger.info(f'{self.__class__.__name__} running...') try: while self.running: @@ -592,10 +594,17 @@ def get_employee_responsible_for(self, worker_id: int) -> RuntimeEmployee: employee_id = (worker_id - self.lower_id_bound) // self.step_size return self.employees[employee_id] - def broadcast_cancel(self, addr: RuntimeAddress) -> None: + def broadcast(self, msg: RuntimeMessage, payload: Any) -> None: """Broadcast a cancel message to my employees.""" for employee in self.employees: - self.outgoing.put((employee.conn, RuntimeMessage.CANCEL, addr)) + self.outgoing.put((employee.conn, msg, payload)) + + def handle_importpath(self, paths: list[str]) -> None: + """Update the system path with the given paths.""" + for path in paths: + if path not in sys.path: + sys.path.append(path) + self.broadcast(RuntimeMessage.IMPORTPATH, paths) def handle_waiting( self, diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 607af21ed..86d69f322 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -1,6 +1,7 @@ """This module implements the DetachedServer runtime.""" from __future__ import annotations +import sys import argparse import logging import selectors @@ -32,25 +33,6 @@ from bqskit.runtime.task import RuntimeTask -def listen(server: DetachedServer, port: int) -> None: - """Listening thread listens for client connections.""" - listener = Listener(('0.0.0.0', port)) - while server.running: - client = listener.accept() - - if server.running: - # We check again that the server is running before registering 
- # the client because dummy data is sent to unblock - # listener.accept() during server shutdown - server.clients[client] = set() - server.sel.register( - client, - selectors.EVENT_READ, - MessageDirection.CLIENT, - ) - server.logger.debug('Connected and registered new client.') - - listener.close() @dataclass @@ -117,9 +99,29 @@ def __init__( # Start client listener self.port = port - self.listen_thread = Thread(target=listen, args=(self, port)) + self.listen_thread = Thread(target=self.listen, args=(port,)) + self.listen_thread.daemon = True self.listen_thread.start() self.logger.info(f'Started client listener on port {self.port}.') + def listen(self, port: int) -> None: + """Listening thread listens for client connections.""" + listener = Listener(('0.0.0.0', port)) + while self.running: + client = listener.accept() + + if self.running: + # We check again that the server is running before registering + # the client because dummy data is sent to unblock + # listener.accept() during server shutdown + self.clients[client] = set() + self.sel.register( + client, + selectors.EVENT_READ, + MessageDirection.CLIENT, + ) + _logger.debug('Connected and registered new client.') + + listener.close() def handle_message( self, @@ -132,16 +134,8 @@ def handle_message( if direction == MessageDirection.CLIENT: if msg == RuntimeMessage.CONNECT: - # paths, serialized_defintions = cast(List[str], payload) paths = cast(List[str], payload) - import sys - for path in paths: - if path not in sys.path: - sys.path.append(path) - for employee in self.employees: - employee.conn.send( - (RuntimeMessage.IMPORTPATH, path), - ) + self.handle_connect(conn, paths) elif msg == RuntimeMessage.DISCONNECT: self.handle_disconnect(conn) @@ -187,7 +181,7 @@ def handle_message( self.handle_log(payload) elif msg == RuntimeMessage.CANCEL: - self.broadcast_cancel(payload) + self.broadcast(msg, payload) elif msg == RuntimeMessage.SHUTDOWN: self.handle_shutdown() @@ -207,6 +201,11 @@ def handle_message( 
else: raise RuntimeError(f'Unexpected message from {direction.name}.') + def handle_connect(self, conn: Connection, paths: list[str]) -> None: + """Handle a client connection request.""" + self.handle_importpath(paths) + self.outgoing.put((conn, RuntimeMessage.READY, None)) + def handle_system_error(self, error_str: str) -> None: """ Handle an error in runtime code as opposed to client code. @@ -331,7 +330,7 @@ def handle_cancel_comp_task(self, request: uuid.UUID) -> None: # Forward internal cancel messages addr = RuntimeAddress(-1, mailbox_id, 0) - self.broadcast_cancel(addr) + self.broadcast(RuntimeMessage.CANCEL, addr) # Acknowledge the client's cancel request if not client_conn.closed: diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index c8725bf1c..72fdabf25 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -146,18 +146,14 @@ def handle_message( self.send_result_down(result) elif msg == RuntimeMessage.CANCEL: - addr = cast(RuntimeAddress, payload) - self.broadcast_cancel(addr) + self.broadcast(RuntimeMessage.CANCEL, payload) elif msg == RuntimeMessage.SHUTDOWN: self.handle_shutdown() elif msg == RuntimeMessage.IMPORTPATH: - import_path = cast(str, payload) - import sys - sys.path.append(import_path) - for employee in self.employees: - employee.conn.send((RuntimeMessage.IMPORTPATH, import_path)) + paths = cast(List[str], payload) + self.handle_importpath(paths) else: raise RuntimeError(f'Unexpected message type: {msg.name}') diff --git a/bqskit/runtime/message.py b/bqskit/runtime/message.py index aed9f6cfa..c975099c8 100644 --- a/bqskit/runtime/message.py +++ b/bqskit/runtime/message.py @@ -21,3 +21,4 @@ class RuntimeMessage(IntEnum): WAITING = 12 UPDATE = 13 IMPORTPATH = 14 + READY = 15 diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 7c4c8434d..8c96f0d6a 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -30,31 +30,6 @@ from bqskit.runtime.task import RuntimeTask -class 
WorkerQueue(): - """The worker's task FIFO queue.""" - - def __init__(self) -> None: - """ - Initialize the worker queue. - - An OrderedDict is used to internally store the task. This prevents the - same task appearing multiple times in the queue, while also ensuring - O(1) operations. - """ - self._queue: OrderedDict[RuntimeAddress, None] = OrderedDict() - - def put(self, addr: RuntimeAddress) -> None: - """Enqueue a task by its address.""" - if addr not in self._queue: - self._queue[addr] = None - - def get(self) -> RuntimeAddress: - """Get the next task to run.""" - return self._queue.popitem(last=False)[0] - - def empty(self) -> bool: - """Check if the queue is empty.""" - return len(self._queue) == 0 @dataclass @@ -125,59 +100,6 @@ def deposit_result(self, result: RuntimeResult) -> None: self.result[slot_id] = result.result -def handle_incoming_comms(worker: Worker) -> None: - """Handle all incoming messages.""" - while True: - # Handle incomming communication - try: - msg, payload = worker._conn.recv() - except Exception: - print(f'Worker {worker._id} crashed due to lost connection') - worker._running = False - worker._ready_task_ids.put(RuntimeAddress(-1, -1, -1)) - break - - # Process message - if msg == RuntimeMessage.SHUTDOWN: - print(f'Worker {worker._id} received shutdown message') - worker._running = False - worker._ready_task_ids.put(RuntimeAddress(-1, -1, -1)) - # TODO: Interupt main, maybe even kill it - break - - elif msg == RuntimeMessage.SUBMIT: - # print('Worker received submit message') - worker.read_receipt_mutex.acquire() - task = cast(RuntimeTask, payload) - worker.most_recent_read_submit = task.unique_id - worker._add_task(task) - worker.read_receipt_mutex.release() - - elif msg == RuntimeMessage.SUBMIT_BATCH: - worker.read_receipt_mutex.acquire() - tasks = cast(List[RuntimeTask], payload) - worker.most_recent_read_submit = tasks[0].unique_id - worker._add_task(tasks.pop()) # Submit one task - worker._delayed_tasks.extend(tasks) # Delay 
rest - # Delayed tasks have no context and are stored (more-or-less) - # as a function pointer together with the arguments. - # When it gets started, it consumes much more memory, - # so we delay the task start until necessary (at no cost) - worker.read_receipt_mutex.release() - - elif msg == RuntimeMessage.RESULT: - result = cast(RuntimeResult, payload) - worker._handle_result(result) - - elif msg == RuntimeMessage.CANCEL: - addr = cast(RuntimeAddress, payload) - worker._handle_cancel(addr) - # TODO: preempt? - - elif msg == RuntimeMessage.IMPORTPATH: - import_path = cast(str, payload) - sys.path.append(import_path) - class Worker: """ @@ -237,17 +159,12 @@ def __init__(self, id: int, conn: Connection) -> None: self._id = id self._conn = conn - # self._outgoing: list[tuple[RuntimeMessage, Any]] = [] - # self._outgoing: Queue[tuple[RuntimeMessage, Any]] = Queue() - # """Stores outgoing messages to be handled by the event loop.""" - self._tasks: dict[RuntimeAddress, RuntimeTask] = {} """Tracks all started, unfinished tasks on this worker.""" self._delayed_tasks: list[RuntimeTask] = [] """Store all delayed tasks in LIFO order.""" - # self._ready_task_ids: WorkerQueue = WorkerQueue() self._ready_task_ids: Queue[RuntimeAddress] = Queue() """Tasks queued up for execution.""" @@ -257,7 +174,7 @@ def __init__(self, id: int, conn: Connection) -> None: self._active_task: RuntimeTask | None = None """The currently executing task if one is running.""" - self._running = False + self._running = True """Controls if the event loop is running.""" self._mailboxes: dict[int, WorkerMailbox] = {} @@ -291,20 +208,15 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: logging.setLogRecordFactory(record_factory) # Start incoming thread - self.incomming_thread = Thread( - target=handle_incoming_comms, - args=(self,), - ) + self.incomming_thread = Thread(target=self.recv_incoming) self.incomming_thread.daemon = True self.incomming_thread.start() - # 
self.logger.info('Started incoming thread.') # Communicate that this worker is ready self._conn.send((RuntimeMessage.STARTED, self._id)) def _loop(self) -> None: """Main worker event loop.""" - self._running = True while self._running: try: self._try_step_next_ready_task() @@ -316,24 +228,57 @@ def _loop(self) -> None: self._conn.send((RuntimeMessage.ERROR, error_str)) except Exception: pass - # self._try_idle() - # self._handle_comms() - - # def _try_idle(self) -> None: - # """If there is nothing to do, wait until we receive a message.""" - # empty_outgoing = len(self._outgoing) == 0 - # no_ready_tasks = self._ready_task_ids.empty() - # no_delayed_tasks = len(self._delayed_tasks) == 0 - - # if empty_outgoing and no_ready_tasks and no_delayed_tasks: - # self._conn.send((RuntimeMessage.WAITING, 1)) - # wait([self._conn]) - - # def _flush_outgoing_comms(self) -> None: - # """Handle all outgoing messages.""" - # for out_msg in self._outgoing: - # self._conn.send(out_msg) - # self._outgoing.clear() + + def recv_incoming(self) -> None: + """Continuously receive all incoming messages.""" + while self._running: + # Receive message + try: + msg, payload = self._conn.recv() + except Exception: + _logger.debug('Crashed due to lost connection') + os.kill(os.getpid(), signal.SIGKILL) + + _logger.debug(f'Received message {msg.name}.') + _logger.log(1, f'Payload: {payload}') + + # Process message + if msg == RuntimeMessage.SHUTDOWN: + os.kill(os.getpid(), signal.SIGKILL) + + elif msg == RuntimeMessage.SUBMIT: + self.read_receipt_mutex.acquire() + task = cast(RuntimeTask, payload) + self.most_recent_read_submit = task.unique_id + self._add_task(task) + self.read_receipt_mutex.release() + + elif msg == RuntimeMessage.SUBMIT_BATCH: + self.read_receipt_mutex.acquire() + tasks = cast(List[RuntimeTask], payload) + self.most_recent_read_submit = tasks[0].unique_id + self._add_task(tasks.pop()) # Submit one task + self._delayed_tasks.extend(tasks) # Delay rest + # Delayed tasks have no 
context and are stored (more-or-less) + # as a function pointer together with the arguments. + # When it gets started, it consumes much more memory, + # so we delay the task start until necessary (at no cost) + self.read_receipt_mutex.release() + + elif msg == RuntimeMessage.RESULT: + result = cast(RuntimeResult, payload) + self._handle_result(result) + + elif msg == RuntimeMessage.CANCEL: + addr = cast(RuntimeAddress, payload) + self._handle_cancel(addr) + # TODO: preempt? + + elif msg == RuntimeMessage.IMPORTPATH: + paths = cast(List[str], payload) + for path in paths: + if path not in sys.path: + sys.path.append(path) def _add_task(self, task: RuntimeTask) -> None: """Start a task and add it to the loop.""" From f9edea39396ea7b0a8eaf8c848361092a1ef2dec Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Thu, 4 Apr 2024 08:02:22 -0400 Subject: [PATCH 24/44] Logging overhaul start --- bqskit/compiler/compiler.py | 2 +- bqskit/runtime/attached.py | 28 +++++++++++-------- bqskit/runtime/base.py | 55 +++++++++++++++++++----------------- bqskit/runtime/detached.py | 33 ++++++++++++++-------- bqskit/runtime/manager.py | 18 +++++++++--- bqskit/runtime/task.py | 4 ++- bqskit/runtime/worker.py | 56 +++++++++++++++++++++++++++---------- 7 files changed, 128 insertions(+), 68 deletions(-) diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index ba19fbd16..ec03d5019 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -122,7 +122,7 @@ def _start_server( See :obj:`~bqskit.runtime.attached.AttachedServer` for more info. 
""" - params = f'{num_workers}, {runtime_log_level}, {worker_port=}' + params = f'{num_workers}, log_level={runtime_log_level}, {worker_port=}' import_str = 'from bqskit.runtime.attached import start_attached_server' launch_str = f'{import_str}; start_attached_server({params})' if sys.platform == 'win32': diff --git a/bqskit/runtime/attached.py b/bqskit/runtime/attached.py index 27a6f7983..a1067d0fa 100644 --- a/bqskit/runtime/attached.py +++ b/bqskit/runtime/attached.py @@ -15,6 +15,9 @@ from bqskit.runtime.direction import MessageDirection +_logger = logging.getLogger(__name__) + + class AttachedServer(DetachedServer): """ BQSKit Runtime Server in attached mode. @@ -33,6 +36,7 @@ def __init__( num_workers: int = -1, port: int = default_server_port, worker_port: int = default_worker_port, + log_level: int = logging.WARNING, ) -> None: """ Create a server with `num_workers` workers. @@ -50,6 +54,17 @@ def __init__( on. Default can be found in the :obj:`~bqskit.runtime.default_worker_port` global variable. 
""" + # Initialize runtime logging + logging.getLogger().setLevel(log_level) + _handler = logging.StreamHandler() + _handler.setLevel(0) + _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' + _fmt_message = ' %(module)s: %(message)s' + _fmt = _fmt_header + _fmt_message + _formatter = logging.Formatter(_fmt, '%H:%M:%S') + _handler.setFormatter(_formatter) + logging.getLogger().addHandler(_handler) + ServerBase.__init__(self) # See DetachedServer for more info on the following fields: @@ -67,7 +82,7 @@ def __init__( selectors.EVENT_READ, MessageDirection.CLIENT, ) - self.logger.info('Connected to client.') + _logger.info('Connected to client.') # Start workers self.spawn_workers(num_workers, worker_port, log_level) @@ -77,17 +92,8 @@ def handle_disconnect(self, conn: Connection) -> None: self.handle_shutdown() -def start_attached_server( - num_workers: int, - log_level: int, - **kwargs: Any, -) -> None: +def start_attached_server(num_workers: int, **kwargs: Any) -> None: """Start a runtime server in attached mode.""" - # Initialize runtime logging - _logger = logging.getLogger('bqskit-runtime') - _logger.setLevel(log_level) - _logger.addHandler(logging.StreamHandler()) - # Initialize the server server = AttachedServer(num_workers, **kwargs) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 206a66fc2..020a17fa5 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -33,6 +33,9 @@ from bqskit.runtime.worker import start_worker +_logger = logging.getLogger(__name__) + + class RuntimeEmployee: """Data structure for a boss's view of an employee.""" @@ -104,7 +107,7 @@ def sigint_handler(signum: int, _: FrameType | None, node: ServerBase) -> None: node.running = False node.terminate_hotline.send(b'\0') - node.logger.info('Server interrupted.') + _logger.info('Server interrupted.') class ServerBase: @@ -134,9 +137,6 @@ def __init__(self) -> None: self.sel.register(p, selectors.EVENT_READ, MessageDirection.SIGNAL) """Terminate hotline 
is used to unblock select while running.""" - self.logger = logging.getLogger('bqskit-runtime') - """Logger used to print operational log messages.""" - self.employees: list[RuntimeEmployee] = [] """Tracks this node's employees, which are managers or workers.""" @@ -151,7 +151,7 @@ def __init__(self) -> None: self.outgoing: Queue[tuple[Connection, RuntimeMessage, Any]] = Queue() self.outgoing_thread = Thread(target=self.send_outgoing, daemon=True) self.outgoing_thread.start() - self.logger.info('Started outgoing thread.') + _logger.info('Started outgoing thread.') def connect_to_managers(self, ipports: Sequence[tuple[str, int]]) -> None: """Connect to all managers given by endpoints in `ipports`.""" @@ -167,8 +167,8 @@ def connect_to_managers(self, ipports: Sequence[tuple[str, int]]) -> None: self.upper_id_bound, ) manager_conns.append(self.connect_to_manager(ip, port, lb, ub)) - self.logger.info(f'Connected to manager {i} at {ip}:{port}.') - self.logger.debug(f'Gave bounds {lb=} and {ub=} to manager {i}.') + _logger.info(f'Connected to manager {i} at {ip}:{port}.') + _logger.debug(f'Gave bounds {lb=} and {ub=} to manager {i}.') # Wait for started messages from all managers and register them self.total_workers = 0 @@ -182,11 +182,11 @@ def connect_to_managers(self, ipports: Sequence[tuple[str, int]]) -> None: selectors.EVENT_READ, MessageDirection.BELOW, ) - self.logger.info(f'Registered manager {i} with {num_workers=}.') + _logger.info(f'Registered manager {i} with {num_workers=}.') self.total_workers += num_workers self.num_idle_workers = self.total_workers - self.logger.info(f'Node has {self.total_workers} total workers.') + _logger.info(f'Node has {self.total_workers} total workers.') def connect_to_manager( self, @@ -228,6 +228,7 @@ def spawn_workers( self, num_workers: int = -1, port: int = default_worker_port, + logging_level: int = logging.WARNING, ) -> None: """ Spawn worker processes. 
@@ -252,10 +253,14 @@ def spawn_workers( procs = {} for i in range(num_workers): w_id = self.lower_id_bound + i - procs[w_id] = Process(target=start_worker, args=(w_id, port)) + procs[w_id] = Process( + target=start_worker, + args=(w_id, port), + kwargs={'logging_level': logging_level} + ) procs[w_id].daemon = True procs[w_id].start() - self.logger.debug(f'Stated worker process {i}.') + _logger.debug(f'Stated worker process {i}.') # Listen for the worker connections family = 'AF_INET' if sys.platform == 'win32' else None @@ -283,12 +288,12 @@ def spawn_workers( selectors.EVENT_READ, MessageDirection.BELOW, ) - self.logger.info(f'Registered worker {i}.') + _logger.debug(f'Registered worker {i}.') self.step_size = 1 self.total_workers = num_workers self.num_idle_workers = num_workers - self.logger.info(f'Node has spawned {num_workers} workers.') + _logger.info(f'Node has spawned {num_workers} workers.') def connect_to_workers( self, @@ -311,7 +316,7 @@ def connect_to_workers( oscount = os.cpu_count() num_workers = oscount if oscount else 1 - self.logger.info(f'Expecting {num_workers} worker connections.') + _logger.info(f'Expecting {num_workers} worker connections.') if self.lower_id_bound + num_workers >= self.upper_id_bound: raise RuntimeError('Insufficient id range for workers.') @@ -338,12 +343,12 @@ def connect_to_workers( selectors.EVENT_READ, MessageDirection.BELOW, ) - self.logger.info(f'Registered worker {i}.') + _logger.info(f'Registered worker {i}.') self.step_size = 1 self.total_workers = num_workers self.num_idle_workers = num_workers - self.logger.info(f'Node has connected to {num_workers} workers.') + _logger.info(f'Node has connected to {num_workers} workers.') def listen_once(self, ip: str, port: int) -> Connection: """Listen on `ip`:`port` for a connection and return on first one.""" @@ -394,7 +399,7 @@ def run(self) -> None: # If interrupted by signal, shutdown and exit if direction == MessageDirection.SIGNAL: - self.logger.debug('Received 
interrupt signal.') + _logger.debug('Received interrupt signal.') self.handle_shutdown() return @@ -405,11 +410,11 @@ def run(self) -> None: self.handle_disconnect(conn) continue log = f'Received message {msg.name} from {direction.name}.' - self.logger.debug(log) + _logger.debug(log) if msg == RuntimeMessage.SUBMIT_BATCH: - self.logger.log(1, f'{len(payload)}\n') + _logger.log(1, f'[{payload[0]}] * {len(payload)}\n') else: - self.logger.log(1, f'{payload}\n') + _logger.log(1, f'{payload}\n') # Handle message self.handle_message(msg, direction, conn, payload) @@ -417,7 +422,7 @@ def run(self) -> None: except Exception: exc_info = sys.exc_info() error_str = ''.join(traceback.format_exception(*exc_info)) - self.logger.error(error_str) + _logger.error(error_str) self.handle_system_error(error_str) finally: @@ -456,7 +461,7 @@ def handle_system_error(self, error_str: str) -> None: def handle_shutdown(self) -> None: """Shutdown the node and release resources.""" # Stop running - self.logger.info('Shutting down node.') + _logger.info('Shutting down node.') self.running = False # Instruct employees to shutdown @@ -467,17 +472,17 @@ def handle_shutdown(self) -> None: employee.complete_shutdown() self.employees.clear() - self.logger.debug('Shutdown employees.') + _logger.debug('Shutdown employees.') # Close selector self.sel.close() - self.logger.debug('Cleared selector.') + _logger.debug('Cleared selector.') # Close outgoing thread if self.outgoing_thread.is_alive(): self.outgoing.put(b'\0') # type: ignore self.outgoing_thread.join() - self.logger.debug('Joined outgoing thread.') + _logger.debug('Joined outgoing thread.') assert not self.outgoing_thread.is_alive() def handle_disconnect(self, conn: Connection) -> None: diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 86d69f322..1c447eda4 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -33,6 +33,7 @@ from bqskit.runtime.task import RuntimeTask +_logger = 
logging.getLogger(__name__) @dataclass @@ -102,7 +103,8 @@ def __init__( self.listen_thread = Thread(target=self.listen, args=(port,)) self.listen_thread.daemon = True self.listen_thread.start() - self.logger.info(f'Started client listener on port {self.port}.') + _logger.info(f'Started client listener on port {self.port}.') + def listen(self, port: int) -> None: """Listening thread listens for client connections.""" listener = Listener(('0.0.0.0', port)) @@ -230,7 +232,7 @@ def handle_shutdown(self) -> None: except Exception: pass self.clients.clear() - self.logger.debug('Cleared clients.') + _logger.debug('Cleared clients.') # Close listener (hasattr checked for attachedserver shutdown) if hasattr(self, 'listen_thread') and self.listen_thread.is_alive(): @@ -242,7 +244,7 @@ def handle_shutdown(self) -> None: dummy_socket.connect(('localhost', self.port)) dummy_socket.close() self.listen_thread.join() - self.logger.debug('Joined listening thread.') + _logger.debug('Joined listening thread.') def handle_disconnect(self, conn: Connection) -> None: """Disconnect a client connection from the runtime.""" @@ -250,7 +252,7 @@ def handle_disconnect(self, conn: Connection) -> None: tasks = self.clients.pop(conn) for task_id in tasks: self.handle_cancel_comp_task(task_id) - self.logger.info('Unregistered client.') + _logger.info('Unregistered client.') def handle_new_comp_task( self, @@ -262,7 +264,7 @@ def handle_new_comp_task( self.tasks[task.task_id] = (mailbox_id, conn) self.mailbox_to_task_dict[mailbox_id] = task.task_id self.mailboxes[mailbox_id] = ServerMailbox() - self.logger.info(f'New CompilationTask: {task.task_id}.') + _logger.info(f'New CompilationTask: {task.task_id}.') self.clients[conn].add(task.task_id) @@ -290,7 +292,7 @@ def handle_request(self, conn: Connection, request: uuid.UUID) -> None: if box.ready: # If the result has already arrived, ship it to the client. 
- self.logger.info(f'Responding to request for task {request}.') + _logger.info(f'Responding to request for task {request}.') self.outgoing.put((conn, RuntimeMessage.RESULT, box.result)) self.mailboxes.pop(mailbox_id) self.clients[conn].remove(request) @@ -320,7 +322,7 @@ def handle_status(self, conn: Connection, request: uuid.UUID) -> None: def handle_cancel_comp_task(self, request: uuid.UUID) -> None: """Cancel a compilation task in the system.""" - self.logger.info(f'Cancelling: {request}.') + _logger.info(f'Cancelling: {request}.') # Remove task from server data mailbox_id, client_conn = self.tasks[request] @@ -351,10 +353,10 @@ def handle_result(self, result: RuntimeResult) -> None: box = self.mailboxes[mailbox_id] box.result = result.result t_id = self.mailbox_to_task_dict[mailbox_id] - self.logger.info(f'Finished: {t_id}.') + _logger.info(f'Finished: {t_id}.') if box.client_waiting: - self.logger.info(f'Responding to request for task {t_id}.') + _logger.info(f'Responding to request for task {t_id}.') m = (self.tasks[t_id][1], RuntimeMessage.RESULT, box.result) self.outgoing.put(m) self.clients[self.tasks[t_id][1]].remove(t_id) @@ -430,9 +432,16 @@ def start_server() -> None: ipports = parse_ipports(args.managers) # Set up logging - _logger = logging.getLogger('bqskit-runtime') - _logger.setLevel([30, 20, 10, 1][min(args.verbose, 3)]) - _logger.addHandler(logging.StreamHandler()) + log_level = [30, 20, 10, 1][min(args.verbose, 3)] + logging.getLogger().setLevel(log_level) + _handler = logging.StreamHandler() + _handler.setLevel(0) + _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' + _fmt_message = ' %(module)s: %(message)s' + _fmt = _fmt_header + _fmt_message + _formatter = logging.Formatter(_fmt, '%H:%M:%S') + _handler.setFormatter(_formatter) + logging.getLogger().addHandler(_handler) # Import tests package recursively if args.import_tests: diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 72fdabf25..6dcbd4f5c 100644 --- 
a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -25,6 +25,9 @@ from bqskit.runtime.task import RuntimeTask +_logger = logging.getLogger(__name__) + + class Manager(ServerBase): """ BQSKit Runtime Manager. @@ -117,7 +120,7 @@ def __init__( # Inform upstream we are starting msg = (self.upstream, RuntimeMessage.STARTED, self.total_workers) self.outgoing.put(msg) - self.logger.info('Sent start message upstream.') + _logger.info('Sent start message upstream.') def handle_message( self, @@ -318,9 +321,16 @@ def start_manager() -> None: ipports = None if args.managers is None else parse_ipports(args.managers) # Set up logging - _logger = logging.getLogger('bqskit-runtime') - _logger.setLevel([30, 20, 10, 1][min(args.verbose, 3)]) - _logger.addHandler(logging.StreamHandler()) + log_level = [30, 20, 10, 1][min(args.verbose, 3)] + logging.getLogger().setLevel(log_level) + _handler = logging.StreamHandler() + _handler.setLevel(0) + _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' + _fmt_message = ' %(module)s: %(message)s' + _fmt = _fmt_header + _fmt_message + _formatter = logging.Formatter(_fmt, '%H:%M:%S') + _handler.setFormatter(_formatter) + logging.getLogger().addHandler(_handler) # Import tests package recursively if args.import_tests: diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index ccffa3b9a..9ac612c6c 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -103,7 +103,9 @@ def step(self, send_val: Any = None) -> Any: self.max_logging_depth < 0 or len(self.breadcrumbs) <= self.max_logging_depth ): - logging.getLogger().setLevel(0) + logging.getLogger().setLevel(self.logging_level) + else: + logging.getLogger().setLevel(100) # Execute a task step to_return = self.coro.send(send_val) diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 8c96f0d6a..c0de57a8e 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -30,6 +30,7 @@ from bqskit.runtime.task import RuntimeTask +_logger = 
logging.getLogger(__name__) @dataclass @@ -197,12 +198,13 @@ def __init__(self, id: int, conn: Connection) -> None: def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: record = old_factory(*args, **kwargs) - active_task = get_worker()._active_task - if active_task is not None: - lvl = active_task.logging_level - if lvl is None or lvl <= record.levelno: - tid = active_task.comp_task_id - self._conn.send((RuntimeMessage.LOG, (tid, record))) + active_task = self._active_task + if not record.name.startswith('bqskit.runtime'): + if active_task is not None: + lvl = active_task.logging_level + if lvl is None or lvl <= record.levelno: + tid = active_task.comp_task_id + self._conn.send((RuntimeMessage.LOG, (tid, record))) return record logging.setLogRecordFactory(record_factory) @@ -211,6 +213,7 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: self.incomming_thread = Thread(target=self.recv_incoming) self.incomming_thread.daemon = True self.incomming_thread.start() + _logger.debug('Started incoming thread.') # Communicate that this worker is ready self._conn.send((RuntimeMessage.STARTED, self._id)) @@ -224,6 +227,7 @@ def _loop(self) -> None: self._running = False exc_info = sys.exc_info() error_str = ''.join(traceback.format_exception(*exc_info)) + _logger.error(error_str) try: self._conn.send((RuntimeMessage.ERROR, error_str)) except Exception: @@ -626,19 +630,21 @@ async def next(self, future: RuntimeFuture) -> list[tuple[int, Any]]: _worker = None -def start_worker(w_id: int | None, port: int, cpu: int | None = None) -> None: +def start_worker( + w_id: int | None, + port: int, + cpu: int | None = None, + logging_level: int = logging.WARNING, +) -> None: """Start this process's worker.""" if w_id is not None: # Ignore interrupt signals on workers, boss will handle it for us # If w_id is None, then we are being spawned separately. 
signal.signal(signal.SIGINT, signal.SIG_IGN) - # Purge all standard python logging configurations - for _, logger in logging.Logger.manager.loggerDict.items(): - if isinstance(logger, logging.PlaceHolder): - continue - logger.handlers.clear() - logging.Logger.manager.loggerDict = {} + # Enforce no default logging + # logging.lastResort.setLevel(100) + # logging.getLogger().handlers.clear() # Pin worker to cpu if cpu is not None: @@ -668,6 +674,19 @@ def start_worker(w_id: int | None, port: int, cpu: int | None = None) -> None: msg, w_id = conn.recv() assert msg == RuntimeMessage.STARTED + # Set up runtime logging + _runtime_logger = logging.getLogger('bqskit.runtime') + _runtime_logger.propagate = False + _runtime_logger.setLevel(logging_level) + _handler = logging.StreamHandler() + _handler.setLevel(0) + _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' + _fmt_message = ' [wid=%(wid)s]: %(message)s' + _fmt = _fmt_header + _fmt_message + _formatter = logging.Formatter(_fmt, '%H:%M:%S', defaults={'wid': w_id}) + _handler.setFormatter(_formatter) + _runtime_logger.addHandler(_handler) + # Build and start worker global _worker _worker = Worker(w_id, conn) @@ -716,6 +735,12 @@ def start_worker_rank() -> None: default=default_worker_port, help='The port the workers will try to connect to a manager on.', ) + parser.add_argument( + '-v', '--verbose', + action='count', + default=0, + help='Enable logging of increasing verbosity, either -v, -vv, or -vvv.', + ) args = parser.parse_args() if args.cpus is not None: @@ -735,10 +760,13 @@ def start_worker_rank() -> None: else: cpus = [None for _ in range(args.num_workers)] + logging_level = [30, 20, 10, 1][min(args.verbose, 3)] + # Spawn worker process procs = [] for cpu in cpus: - procs.append(Process(target=start_worker, args=(None, args.port, cpu))) + args = (None, args.port, cpu, logging_level) + procs.append(Process(target=start_worker, args=args)) procs[-1].start() # Join them From 
f549a4f9bd58e5036d661ba614f6c7c2a95db2c5 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Thu, 4 Apr 2024 08:07:14 -0400 Subject: [PATCH 25/44] pre-commit --- bqskit/runtime/base.py | 2 +- bqskit/runtime/detached.py | 1 - bqskit/runtime/task.py | 2 +- bqskit/runtime/worker.py | 6 ++---- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 020a17fa5..d98e7c672 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -256,7 +256,7 @@ def spawn_workers( procs[w_id] = Process( target=start_worker, args=(w_id, port), - kwargs={'logging_level': logging_level} + kwargs={'logging_level': logging_level}, ) procs[w_id].daemon = True procs[w_id].start() diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 1c447eda4..310072275 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -1,7 +1,6 @@ """This module implements the DetachedServer runtime.""" from __future__ import annotations -import sys import argparse import logging import selectors diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index 9ac612c6c..962ca1ff6 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -53,7 +53,7 @@ def __init__( This doubles as a unique system-wide id for the task. 
""" - self.logging_level = logging_level + self.logging_level = logging_level or 0 """Logs with levels >= to this get emitted, if None always emit.""" self.comp_task_id = comp_task_id diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index c0de57a8e..fbfd1c5c2 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -8,7 +8,6 @@ import sys import time import traceback -from collections import OrderedDict from dataclasses import dataclass from multiprocessing import Process from multiprocessing.connection import Client @@ -101,7 +100,6 @@ def deposit_result(self, result: RuntimeResult) -> None: self.result[slot_id] = result.result - class Worker: """ BQSKit Runtime's Worker. @@ -765,8 +763,8 @@ def start_worker_rank() -> None: # Spawn worker process procs = [] for cpu in cpus: - args = (None, args.port, cpu, logging_level) - procs.append(Process(target=start_worker, args=args)) + pargs = (None, args.port, cpu, logging_level) + procs.append(Process(target=start_worker, args=pargs)) procs[-1].start() # Join them From c6441465892898d640271ecccac11222a77fba5a Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Thu, 4 Apr 2024 09:35:47 -0400 Subject: [PATCH 26/44] BLAS thread control in runtime --- bqskit/compiler/compiler.py | 17 +++++++++++++++-- bqskit/runtime/__init__.py | 26 ++++++++++++++++---------- bqskit/runtime/attached.py | 14 +++++++++++++- bqskit/runtime/base.py | 15 ++++++++++++++- bqskit/runtime/detached.py | 9 ++++----- bqskit/runtime/manager.py | 19 ++++++++++++++++++- bqskit/runtime/worker.py | 17 ++++++++++++++--- 7 files changed, 94 insertions(+), 23 deletions(-) diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index ec03d5019..22fab9feb 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -71,6 +71,7 @@ def __init__( num_workers: int = -1, runtime_log_level: int = logging.WARNING, worker_port: int = default_worker_port, + num_blas_threads: int = 1, ) -> None: """ Construct a 
Compiler object. @@ -99,6 +100,9 @@ def __init__( worker_port (int): The optional port to pass to an attached runtime. See :obj:`~bqskit.runtime.attached.AttachedServer` for more info. + + num_blas_threads (int): The number of threads to use in the + BLAS libraries on the worker nodes. (Defaults to 1) """ self.p: Popen | None = None # type: ignore self.conn: Connection | None = None @@ -107,7 +111,12 @@ def __init__( if ip is None: ip = 'localhost' - self._start_server(num_workers, runtime_log_level, worker_port) + self._start_server( + num_workers, + runtime_log_level, + worker_port, + num_blas_threads, + ) self._connect_to_server(ip, port, self.p is not None) @@ -116,13 +125,17 @@ def _start_server( num_workers: int, runtime_log_level: int, worker_port: int, + num_blas_threads: int, ) -> None: """ Start an attached serer with `num_workers` workers. See :obj:`~bqskit.runtime.attached.AttachedServer` for more info. """ - params = f'{num_workers}, log_level={runtime_log_level}, {worker_port=}' + params = f'{num_workers}, ' + params += f'log_level={runtime_log_level}, ' + params += f'{worker_port=}, ' + params += f'{num_blas_threads=}, ' import_str = 'from bqskit.runtime.attached import start_attached_server' launch_str = f'{import_str}; start_attached_server({params})' if sys.platform == 'win32': diff --git a/bqskit/runtime/__init__.py b/bqskit/runtime/__init__.py index 0443671b6..2c01c4ddd 100644 --- a/bqskit/runtime/__init__.py +++ b/bqskit/runtime/__init__.py @@ -100,21 +100,27 @@ from typing import Protocol from typing import TYPE_CHECKING +if TYPE_CHECKING: + from bqskit.runtime.future import RuntimeFuture + + # Enable low-level fault handling: system crashes print a minimal trace. faulthandler.enable() +os.environ['RUST_BACKTRACE'] = '1' -# Disable multi-threading in BLAS libraries. 
-os.environ['OMP_NUM_THREADS'] = '1' -os.environ['OPENBLAS_NUM_THREADS'] = '1' -os.environ['MKL_NUM_THREADS'] = '1' -os.environ['NUMEXPR_NUM_THREADS'] = '1' -os.environ['VECLIB_MAXIMUM_THREADS'] = '1' -os.environ['RUST_BACKTRACE'] = '1' -print('SETTING THREADS TO 1') +# Control multi-threading in BLAS libraries. +def set_blas_thread_counts(i: int = 1) -> None: + """ + Control number of threads used by numpy and others. -if TYPE_CHECKING: - from bqskit.runtime.future import RuntimeFuture + Must be called before any numpy or other BLAS libraries are loaded. + """ + os.environ['OMP_NUM_THREADS'] = str(i) + os.environ['OPENBLAS_NUM_THREADS'] = str(i) + os.environ['MKL_NUM_THREADS'] = str(i) + os.environ['NUMEXPR_NUM_THREADS'] = str(i) + os.environ['VECLIB_MAXIMUM_THREADS'] = str(i) class RuntimeHandle(Protocol): diff --git a/bqskit/runtime/attached.py b/bqskit/runtime/attached.py index a1067d0fa..f68be0ee3 100644 --- a/bqskit/runtime/attached.py +++ b/bqskit/runtime/attached.py @@ -37,6 +37,7 @@ def __init__( port: int = default_server_port, worker_port: int = default_worker_port, log_level: int = logging.WARNING, + num_blas_threads: int = 1, ) -> None: """ Create a server with `num_workers` workers. @@ -53,6 +54,12 @@ def __init__( worker_port (int): The port this server will listen for workers on. Default can be found in the :obj:`~bqskit.runtime.default_worker_port` global variable. + + log_level (int): The logging level for the server and workers. + (Default: logging.WARNING). + + num_blas_threads (int): The number of threads to use in BLAS + libraries. (Default: 1). 
""" # Initialize runtime logging logging.getLogger().setLevel(log_level) @@ -85,7 +92,12 @@ def __init__( _logger.info('Connected to client.') # Start workers - self.spawn_workers(num_workers, worker_port, log_level) + self.spawn_workers( + num_workers, + worker_port, + log_level, + num_blas_threads, + ) def handle_disconnect(self, conn: Connection) -> None: """A client disconnect in attached mode is equal to a shutdown.""" diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index d98e7c672..1cb53c494 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -25,6 +25,7 @@ from bqskit.runtime import default_manager_port from bqskit.runtime import default_worker_port +from bqskit.runtime import set_blas_thread_counts from bqskit.runtime.address import RuntimeAddress from bqskit.runtime.direction import MessageDirection from bqskit.runtime.message import RuntimeMessage @@ -143,6 +144,9 @@ def __init__(self) -> None: self.conn_to_employee_dict: dict[Connection, RuntimeEmployee] = {} """Used to find the employee associated with a message.""" + # Servers do not need blas threads + set_blas_thread_counts(1) + # Safely and immediately exit on interrupt signals handle = functools.partial(sigint_handler, node=self) signal.signal(signal.SIGINT, handle) @@ -229,6 +233,7 @@ def spawn_workers( num_workers: int = -1, port: int = default_worker_port, logging_level: int = logging.WARNING, + num_blas_threads: int = 1, ) -> None: """ Spawn worker processes. @@ -241,6 +246,11 @@ def spawn_workers( port (int): The port this server will listen for workers on. Default can be found in the :obj:`~bqskit.runtime.default_worker_port` global variable. + + logging_level (int): The logging level for the workers. + + num_blas_threads (int): The number of threads to use in BLAS + libraries. (Default: 1). 
""" if num_workers == -1: oscount = os.cpu_count() @@ -256,7 +266,10 @@ def spawn_workers( procs[w_id] = Process( target=start_worker, args=(w_id, port), - kwargs={'logging_level': logging_level}, + kwargs={ + 'logging_level': logging_level, + 'num_blas_threads': num_blas_threads, + }, ) procs[w_id].daemon = True procs[w_id].start() diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 310072275..230a5949d 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -19,8 +19,6 @@ from typing import Sequence from typing import Tuple -from bqskit.compiler.status import CompilationStatus -from bqskit.compiler.task import CompilationTask from bqskit.runtime import default_server_port from bqskit.runtime.address import RuntimeAddress from bqskit.runtime.base import import_tests_package @@ -142,8 +140,7 @@ def handle_message( self.handle_disconnect(conn) elif msg == RuntimeMessage.SUBMIT: - ctask = cast(CompilationTask, payload) - self.handle_new_comp_task(conn, ctask) + self.handle_new_comp_task(conn, payload) elif msg == RuntimeMessage.REQUEST: request = cast(uuid.UUID, payload) @@ -256,9 +253,10 @@ def handle_disconnect(self, conn: Connection) -> None: def handle_new_comp_task( self, conn: Connection, - task: CompilationTask, + task: Any, # Explicitly not CompilationTask to avoid early import ) -> None: """Convert a :class:`CompilationTask` into an internal one.""" + from bqskit.compiler.task import CompilationTask mailbox_id = self._get_new_mailbox_id() self.tasks[task.task_id] = (mailbox_id, conn) self.mailbox_to_task_dict[mailbox_id] = task.task_id @@ -306,6 +304,7 @@ def handle_request(self, conn: Connection, request: uuid.UUID) -> None: def handle_status(self, conn: Connection, request: uuid.UUID) -> None: """Inform the client if the task is finished or not.""" + from bqskit.compiler.status import CompilationStatus if request not in self.clients[conn] or request not in self.tasks: # This task is unknown to the system m = 
(conn, RuntimeMessage.STATUS, CompilationStatus.UNKNOWN) diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 6dcbd4f5c..507cdf9a3 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -47,6 +47,8 @@ def __init__( ipports: list[tuple[str, int]] | None = None, worker_port: int = default_worker_port, only_connect: bool = False, + log_level: int = logging.WARNING, + num_blas_threads: int = 1, ) -> None: """ Create a manager instance in one of two ways: @@ -83,6 +85,16 @@ def __init__( only_connect (bool): If true, do not spawn workers, only connect to already spawned workers. + + log_level (int): The logging level for the manager and workers. + If `only_connect` is True, doesn't set worker's log level. + In that case, set the worker's log level when spawning them. + (Default: logging.WARNING). + + num_blas_threads (int): The number of threads to use in BLAS + libraries. If `only_connect` is True this is ignored. In + that case, set the thread count when spawning workers. + (Default: 1). 
""" super().__init__() @@ -105,7 +117,12 @@ def __init__( if only_connect: self.connect_to_workers(num_workers, worker_port) else: - self.spawn_workers(num_workers, worker_port) + self.spawn_workers( + num_workers, + worker_port, + log_level, + num_blas_threads, + ) # Case 2: Connect to detached managers at ipports else: diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index fbfd1c5c2..c034c2ce2 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -22,6 +22,7 @@ from typing import List from bqskit.runtime import default_worker_port +from bqskit.runtime import set_blas_thread_counts from bqskit.runtime.address import RuntimeAddress from bqskit.runtime.future import RuntimeFuture from bqskit.runtime.message import RuntimeMessage @@ -633,6 +634,7 @@ def start_worker( port: int, cpu: int | None = None, logging_level: int = logging.WARNING, + num_blas_threads: int = 1, ) -> None: """Start this process's worker.""" if w_id is not None: @@ -640,9 +642,12 @@ def start_worker( # If w_id is None, then we are being spawned separately. signal.signal(signal.SIGINT, signal.SIG_IGN) + # Set number of BLAS threads + set_blas_thread_counts(num_blas_threads) + # Enforce no default logging - # logging.lastResort.setLevel(100) - # logging.getLogger().handlers.clear() + logging.lastResort = logging.NullHandler() # type: ignore # TODO: should I report this as a type bug? 
# noqa: E501 + logging.getLogger().handlers.clear() # Pin worker to cpu if cpu is not None: @@ -739,6 +744,12 @@ def start_worker_rank() -> None: default=0, help='Enable logging of increasing verbosity, either -v, -vv, or -vvv.', ) + parser.add_argument( + '-t', '--num_blas_threads', + type=int, + default=1, + help='The number of threads to use in BLAS libraries.', + ) args = parser.parse_args() if args.cpus is not None: @@ -763,7 +774,7 @@ def start_worker_rank() -> None: # Spawn worker process procs = [] for cpu in cpus: - pargs = (None, args.port, cpu, logging_level) + pargs = (None, args.port, cpu, logging_level, args.num_blas_threads) procs.append(Process(target=start_worker, args=pargs)) procs[-1].start() From f71e63d8fc837748614de63e01e9b537c219bfc5 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Thu, 4 Apr 2024 21:46:24 -0400 Subject: [PATCH 27/44] Scheduling change --- bqskit/runtime/attached.py | 3 + bqskit/runtime/base.py | 115 +++++++++++++++++++++++++++++++++--- bqskit/runtime/manager.py | 3 + bqskit/runtime/task.py | 14 ++++- bqskit/runtime/worker.py | 66 +++++++++++---------- bqskit/utils/cachedclass.py | 3 +- 6 files changed, 162 insertions(+), 42 deletions(-) diff --git a/bqskit/runtime/attached.py b/bqskit/runtime/attached.py index f68be0ee3..1590614c6 100644 --- a/bqskit/runtime/attached.py +++ b/bqskit/runtime/attached.py @@ -99,6 +99,9 @@ def __init__( num_blas_threads, ) + self.schedule_tasks = self.schedule_for_workers # type: ignore + self.handle_waiting = self.handle_direct_worker_waiting # type: ignore + def handle_disconnect(self, conn: Connection) -> None: """A client disconnect in attached mode is equal to a shutdown.""" self.handle_shutdown() diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 1cb53c494..0d0f010eb 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -101,6 +101,49 @@ def get_num_of_tasks_sent_since( raise RuntimeError('Read receipt not found in submit cache.') +class MultiLevelQueue: + """A 
multi-level queue for delaying submitted tasks.""" + + def __init__(self) -> None: + """Initialize the multi-level queue.""" + self.queue: dict[int, list[RuntimeTask]] = {} + self.levels: int | None = None + + def delay(self, tasks: Sequence[RuntimeTask]) -> None: + """Update the multi-level queue with tasks.""" + for task in tasks: + task_depth = len(task.breadcrumbs) + + if task_depth not in self.queue: + if self.levels is None or task_depth > self.levels: + self.levels = task_depth + self.queue[task_depth] = [] + + self.queue[task_depth].append(task) + + def empty(self) -> bool: + """Return True if the multi-level queue is empty.""" + return self.levels is None + + def pop(self) -> RuntimeTask: + """Pop the next task from the multi-level queue.""" + if self.empty(): + raise RuntimeError('Cannot pop from an empty multi-level queue.') + + task = self.queue[self.levels].pop() # type: ignore # checked above + + while self.levels is not None: + if self.levels in self.queue: + if len(self.queue[self.levels]) != 0: + break + self.queue.pop(self.levels) + self.levels -= 1 + if self.levels < 0: + self.levels = None + + return task + + def sigint_handler(signum: int, _: FrameType | None, node: ServerBase) -> None: """Interrupt the node.""" if not node.running: @@ -144,6 +187,9 @@ def __init__(self) -> None: self.conn_to_employee_dict: dict[Connection, RuntimeEmployee] = {} """Used to find the employee associated with a message.""" + self.multi_level_queue = MultiLevelQueue() + """Used to delay tasks until they can be scheduled.""" + # Servers do not need blas threads set_blas_thread_counts(1) @@ -579,18 +625,47 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: reverse=True, ) for e, assignment in sorted_assignments: - num_tasks = len(assignment) + self.send_tasks_to_employee(e, assignment) - if num_tasks == 0: - continue + self.num_idle_workers = sum(e.num_idle_workers for e in self.employees) - self.outgoing.put((e.conn, RuntimeMessage.SUBMIT_BATCH, 
assignment)) + def schedule_for_workers(self, tasks: Sequence[RuntimeTask]) -> None: + """Schedule tasks for workers, return the amount assigned.""" + if len(tasks) == 0: + return - e.num_tasks += num_tasks - e.num_idle_workers -= min(num_tasks, e.num_idle_workers) - e.submit_cache.append((assignment[0].unique_id, num_tasks)) + num_assigned_tasks = 0 + for e in self.employees: + # TODO: randomize? prioritize workers idle but with less tasks? + if num_assigned_tasks >= len(tasks): + break + + if e.num_idle_workers > 0: + task = tasks[num_assigned_tasks] + self.send_tasks_to_employee(e, [task]) + num_assigned_tasks += 1 self.num_idle_workers = sum(e.num_idle_workers for e in self.employees) + self.multi_level_queue.delay(tasks[num_assigned_tasks:]) + + def send_tasks_to_employee( + self, + e: RuntimeEmployee, + tasks: Sequence[RuntimeTask], + ) -> None: + """Send the `task` to the employee responsible for `worker_id`.""" + num_tasks = len(tasks) + + if num_tasks == 0: + return + + e.num_tasks += num_tasks + e.num_idle_workers -= min(num_tasks, e.num_idle_workers) + e.submit_cache.append((tasks[0].unique_id, num_tasks)) + if num_tasks == 1: + self.outgoing.put((e.conn, RuntimeMessage.SUBMIT, tasks[0])) + else: + self.outgoing.put((e.conn, RuntimeMessage.SUBMIT_BATCH, tasks)) def send_result_down(self, result: RuntimeResult) -> None: """Send the `result` to the appropriate employee.""" @@ -641,14 +716,36 @@ def handle_waiting( idle count by the number of tasks sent since the read receipt. 
""" employee = self.conn_to_employee_dict[conn] - unaccounted_task = employee.get_num_of_tasks_sent_since(read_receipt) - adjusted_idle_count = max(new_idle_count - unaccounted_task, 0) + unaccounted_tasks = employee.get_num_of_tasks_sent_since(read_receipt) + adjusted_idle_count = max(new_idle_count - unaccounted_tasks, 0) old_count = employee.num_idle_workers employee.num_idle_workers = adjusted_idle_count self.num_idle_workers += (adjusted_idle_count - old_count) assert 0 <= self.num_idle_workers <= self.total_workers + def handle_direct_worker_waiting( + self, + conn: Connection, + new_idle_count: int, + read_receipt: RuntimeAddress | None, + ) -> None: + """ + Record that a worker is idle with nothing to do. + + Schedule tasks from the multi-level queue to the worker. + """ + ServerBase.handle_waiting(self, conn, new_idle_count, read_receipt) + + if self.multi_level_queue.empty(): + return + + employee = self.conn_to_employee_dict[conn] + if employee.num_idle_workers > 0: + task = self.multi_level_queue.pop() + self.send_tasks_to_employee(employee, [task]) + self.num_idle_workers -= 1 + def parse_ipports(ipports_str: Sequence[str]) -> list[tuple[str, int]]: """Parse command line ip and port inputs.""" diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 507cdf9a3..0283be6c4 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -124,6 +124,9 @@ def __init__( num_blas_threads, ) + self.schedule_tasks = self.schedule_for_workers # type: ignore + self.handle_waiting = self.handle_direct_worker_waiting # type: ignore # noqa: E501 + # Case 2: Connect to detached managers at ipports else: self.connect_to_managers(ipports) diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index 962ca1ff6..ca0feefa7 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -3,6 +3,7 @@ import inspect import logging +import pickle from typing import Any from typing import Coroutine @@ -41,7 +42,13 @@ def __init__( 
RuntimeTask.task_counter += 1 self.task_id = RuntimeTask.task_counter - self.serialized_fnargs = dill.dumps(fnargs) + try: + self.serialized_fnargs = pickle.dumps(fnargs) + self.serialized_with_pickle = True + except Exception: + self.serialized_fnargs = dill.dumps(fnargs) + self.serialized_with_pickle = False + self._fnargs: tuple[Any, Any, Any] | None = None self._name = fnargs[0].__name__ """Tuple of function pointer, arguments, and keyword arguments.""" @@ -84,7 +91,10 @@ def __init__( def fnargs(self) -> tuple[Any, Any, Any]: """Return the function pointer, arguments, and keyword arguments.""" if self._fnargs is None: - self._fnargs = dill.loads(self.serialized_fnargs) + if self.serialized_with_pickle: + self._fnargs = pickle.loads(self.serialized_fnargs) + else: + self._fnargs = dill.loads(self.serialized_fnargs) assert self._fnargs is not None # for type checker return self._fnargs diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index c034c2ce2..f2590bfaf 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -162,8 +162,8 @@ def __init__(self, id: int, conn: Connection) -> None: self._tasks: dict[RuntimeAddress, RuntimeTask] = {} """Tracks all started, unfinished tasks on this worker.""" - self._delayed_tasks: list[RuntimeTask] = [] - """Store all delayed tasks in LIFO order.""" + # self._delayed_tasks: list[RuntimeTask] = [] + # """Store all delayed tasks in LIFO order.""" self._ready_task_ids: Queue[RuntimeAddress] = Queue() """Tasks queued up for execution.""" @@ -203,7 +203,7 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: lvl = active_task.logging_level if lvl is None or lvl <= record.levelno: tid = active_task.comp_task_id - self._conn.send((RuntimeMessage.LOG, (tid, record))) + self._send(RuntimeMessage.LOG, (tid, record)) return record logging.setLogRecordFactory(record_factory) @@ -215,7 +215,7 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: _logger.debug('Started 
incoming thread.') # Communicate that this worker is ready - self._conn.send((RuntimeMessage.STARTED, self._id)) + self._send(RuntimeMessage.STARTED, self._id) def _loop(self) -> None: """Main worker event loop.""" @@ -228,7 +228,7 @@ def _loop(self) -> None: error_str = ''.join(traceback.format_exception(*exc_info)) _logger.error(error_str) try: - self._conn.send((RuntimeMessage.ERROR, error_str)) + self._send(RuntimeMessage.ERROR, error_str) except Exception: pass @@ -256,17 +256,17 @@ def recv_incoming(self) -> None: self._add_task(task) self.read_receipt_mutex.release() - elif msg == RuntimeMessage.SUBMIT_BATCH: - self.read_receipt_mutex.acquire() - tasks = cast(List[RuntimeTask], payload) - self.most_recent_read_submit = tasks[0].unique_id - self._add_task(tasks.pop()) # Submit one task - self._delayed_tasks.extend(tasks) # Delay rest - # Delayed tasks have no context and are stored (more-or-less) - # as a function pointer together with the arguments. - # When it gets started, it consumes much more memory, - # so we delay the task start until necessary (at no cost) - self.read_receipt_mutex.release() + # elif msg == RuntimeMessage.SUBMIT_BATCH: + # self.read_receipt_mutex.acquire() + # tasks = cast(List[RuntimeTask], payload) + # self.most_recent_read_submit = tasks[0].unique_id + # self._add_task(tasks.pop()) # Submit one task + # self._delayed_tasks.extend(tasks) # Delay rest + # # Delayed tasks have no context and are stored (more-or-less) + # # as a function pointer together with the arguments. 
+ # # When it gets started, it consumes much more memory, + # # so we delay the task start until necessary (at no cost) + # self.read_receipt_mutex.release() elif msg == RuntimeMessage.RESULT: result = cast(RuntimeResult, payload) @@ -340,17 +340,17 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None: } # Remove all tasks that are children of `addr` from delayed tasks - self._delayed_tasks = [ - t for t in self._delayed_tasks - if not t.is_descendant_of(addr) - ] + # self._delayed_tasks = [ + # t for t in self._delayed_tasks + # if not t.is_descendant_of(addr) + # ] def _get_next_ready_task(self) -> RuntimeTask | None: """Return the next ready task if one exists, otherwise block.""" while True: - if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0: - self._add_task(self._delayed_tasks.pop()) - continue + # if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0: + # self._add_task(self._delayed_tasks.pop()) + # continue self.read_receipt_mutex.acquire() try: @@ -358,7 +358,7 @@ def _get_next_ready_task(self) -> RuntimeTask | None: except Empty: payload = (1, self.most_recent_read_submit) - self._conn.send((RuntimeMessage.WAITING, payload)) + self._send(RuntimeMessage.WAITING, payload) self.read_receipt_mutex.release() addr = self._ready_task_ids.get() @@ -412,7 +412,7 @@ def _try_step_next_ready_task(self) -> None: exc_info = sys.exc_info() error_str = ''.join(traceback.format_exception(*exc_info)) error_payload = (self._active_task.comp_task_id, error_str) - self._conn.send((RuntimeMessage.ERROR, error_payload)) + self._send(RuntimeMessage.ERROR, error_payload) finally: self._active_task = None @@ -456,11 +456,11 @@ def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: if task.return_address.worker_id == self._id: self._handle_result(packaged_result) - self._conn.send((RuntimeMessage.UPDATE, -1)) + self._send(RuntimeMessage.UPDATE, -1) # Let manager know this worker has one less task # without sending a result 
else: - self._conn.send((RuntimeMessage.RESULT, packaged_result)) + self._send(RuntimeMessage.RESULT, packaged_result) # Remove task self._tasks.pop(task.return_address) @@ -498,6 +498,12 @@ def _get_new_mailbox_id(self) -> int: self._mailbox_counter += 1 return new_id + def _send(self, msg: RuntimeMessage, payload: Any) -> None: + """Send a message to the boss.""" + _logger.debug(f'Sending message {msg.name}.') + _logger.log(1, f'Payload: {payload}') + self._conn.send((msg, payload)) + def submit( self, fn: Callable[..., Any], @@ -525,7 +531,7 @@ def submit( ) # Submit the task (on the next cycle) - self._conn.send((RuntimeMessage.SUBMIT, task)) + self._send(RuntimeMessage.SUBMIT, task) # Return future pointing to the mailbox return RuntimeFuture(mailbox_id) @@ -572,7 +578,7 @@ def map( ] # Submit the tasks - self._conn.send((RuntimeMessage.SUBMIT_BATCH, tasks)) + self._send(RuntimeMessage.SUBMIT_BATCH, tasks) # Return future pointing to the mailbox return RuntimeFuture(mailbox_id) @@ -588,7 +594,7 @@ def cancel(self, future: RuntimeFuture) -> None: for slot_id in range(num_slots) ] for addr in addrs: - self._conn.send((RuntimeMessage.CANCEL, addr)) + self._send(RuntimeMessage.CANCEL, addr) def get_cache(self) -> dict[str, Any]: """ diff --git a/bqskit/utils/cachedclass.py b/bqskit/utils/cachedclass.py index 751a361b6..7edac47b9 100644 --- a/bqskit/utils/cachedclass.py +++ b/bqskit/utils/cachedclass.py @@ -63,7 +63,8 @@ def __new__(cls: type[T], *args: Any, **kwargs: Any) -> T: _instances = cls._instances # type: ignore if _instances.get(key, None) is None: - _logger.debug( + _logger.log( + 1, ( 'Creating cached instance for class: %s,' ' with args %s, and kwargs %s' From 401a4ca9cd6490134cb38036f98f865fab52f561 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Thu, 11 Apr 2024 15:44:56 -0400 Subject: [PATCH 28/44] Revert "Scheduling change" This reverts commit f71e63d8fc837748614de63e01e9b537c219bfc5. 
--- bqskit/runtime/attached.py | 3 - bqskit/runtime/base.py | 115 +++--------------------------------- bqskit/runtime/manager.py | 3 - bqskit/runtime/task.py | 14 +---- bqskit/runtime/worker.py | 66 ++++++++++----------- bqskit/utils/cachedclass.py | 3 +- 6 files changed, 42 insertions(+), 162 deletions(-) diff --git a/bqskit/runtime/attached.py b/bqskit/runtime/attached.py index 1590614c6..f68be0ee3 100644 --- a/bqskit/runtime/attached.py +++ b/bqskit/runtime/attached.py @@ -99,9 +99,6 @@ def __init__( num_blas_threads, ) - self.schedule_tasks = self.schedule_for_workers # type: ignore - self.handle_waiting = self.handle_direct_worker_waiting # type: ignore - def handle_disconnect(self, conn: Connection) -> None: """A client disconnect in attached mode is equal to a shutdown.""" self.handle_shutdown() diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 0d0f010eb..1cb53c494 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -101,49 +101,6 @@ def get_num_of_tasks_sent_since( raise RuntimeError('Read receipt not found in submit cache.') -class MultiLevelQueue: - """A multi-level queue for delaying submitted tasks.""" - - def __init__(self) -> None: - """Initialize the multi-level queue.""" - self.queue: dict[int, list[RuntimeTask]] = {} - self.levels: int | None = None - - def delay(self, tasks: Sequence[RuntimeTask]) -> None: - """Update the multi-level queue with tasks.""" - for task in tasks: - task_depth = len(task.breadcrumbs) - - if task_depth not in self.queue: - if self.levels is None or task_depth > self.levels: - self.levels = task_depth - self.queue[task_depth] = [] - - self.queue[task_depth].append(task) - - def empty(self) -> bool: - """Return True if the multi-level queue is empty.""" - return self.levels is None - - def pop(self) -> RuntimeTask: - """Pop the next task from the multi-level queue.""" - if self.empty(): - raise RuntimeError('Cannot pop from an empty multi-level queue.') - - task = 
self.queue[self.levels].pop() # type: ignore # checked above - - while self.levels is not None: - if self.levels in self.queue: - if len(self.queue[self.levels]) != 0: - break - self.queue.pop(self.levels) - self.levels -= 1 - if self.levels < 0: - self.levels = None - - return task - - def sigint_handler(signum: int, _: FrameType | None, node: ServerBase) -> None: """Interrupt the node.""" if not node.running: @@ -187,9 +144,6 @@ def __init__(self) -> None: self.conn_to_employee_dict: dict[Connection, RuntimeEmployee] = {} """Used to find the employee associated with a message.""" - self.multi_level_queue = MultiLevelQueue() - """Used to delay tasks until they can be scheduled.""" - # Servers do not need blas threads set_blas_thread_counts(1) @@ -625,47 +579,18 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: reverse=True, ) for e, assignment in sorted_assignments: - self.send_tasks_to_employee(e, assignment) - - self.num_idle_workers = sum(e.num_idle_workers for e in self.employees) + num_tasks = len(assignment) - def schedule_for_workers(self, tasks: Sequence[RuntimeTask]) -> None: - """Schedule tasks for workers, return the amount assigned.""" - if len(tasks) == 0: - return + if num_tasks == 0: + continue - num_assigned_tasks = 0 - for e in self.employees: - # TODO: randomize? prioritize workers idle but with less tasks? 
- if num_assigned_tasks >= len(tasks): - break + self.outgoing.put((e.conn, RuntimeMessage.SUBMIT_BATCH, assignment)) - if e.num_idle_workers > 0: - task = tasks[num_assigned_tasks] - self.send_tasks_to_employee(e, [task]) - num_assigned_tasks += 1 + e.num_tasks += num_tasks + e.num_idle_workers -= min(num_tasks, e.num_idle_workers) + e.submit_cache.append((assignment[0].unique_id, num_tasks)) self.num_idle_workers = sum(e.num_idle_workers for e in self.employees) - self.multi_level_queue.delay(tasks[num_assigned_tasks:]) - - def send_tasks_to_employee( - self, - e: RuntimeEmployee, - tasks: Sequence[RuntimeTask], - ) -> None: - """Send the `task` to the employee responsible for `worker_id`.""" - num_tasks = len(tasks) - - if num_tasks == 0: - return - - e.num_tasks += num_tasks - e.num_idle_workers -= min(num_tasks, e.num_idle_workers) - e.submit_cache.append((tasks[0].unique_id, num_tasks)) - if num_tasks == 1: - self.outgoing.put((e.conn, RuntimeMessage.SUBMIT, tasks[0])) - else: - self.outgoing.put((e.conn, RuntimeMessage.SUBMIT_BATCH, tasks)) def send_result_down(self, result: RuntimeResult) -> None: """Send the `result` to the appropriate employee.""" @@ -716,36 +641,14 @@ def handle_waiting( idle count by the number of tasks sent since the read receipt. 
""" employee = self.conn_to_employee_dict[conn] - unaccounted_tasks = employee.get_num_of_tasks_sent_since(read_receipt) - adjusted_idle_count = max(new_idle_count - unaccounted_tasks, 0) + unaccounted_task = employee.get_num_of_tasks_sent_since(read_receipt) + adjusted_idle_count = max(new_idle_count - unaccounted_task, 0) old_count = employee.num_idle_workers employee.num_idle_workers = adjusted_idle_count self.num_idle_workers += (adjusted_idle_count - old_count) assert 0 <= self.num_idle_workers <= self.total_workers - def handle_direct_worker_waiting( - self, - conn: Connection, - new_idle_count: int, - read_receipt: RuntimeAddress | None, - ) -> None: - """ - Record that a worker is idle with nothing to do. - - Schedule tasks from the multi-level queue to the worker. - """ - ServerBase.handle_waiting(self, conn, new_idle_count, read_receipt) - - if self.multi_level_queue.empty(): - return - - employee = self.conn_to_employee_dict[conn] - if employee.num_idle_workers > 0: - task = self.multi_level_queue.pop() - self.send_tasks_to_employee(employee, [task]) - self.num_idle_workers -= 1 - def parse_ipports(ipports_str: Sequence[str]) -> list[tuple[str, int]]: """Parse command line ip and port inputs.""" diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 0283be6c4..507cdf9a3 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -124,9 +124,6 @@ def __init__( num_blas_threads, ) - self.schedule_tasks = self.schedule_for_workers # type: ignore - self.handle_waiting = self.handle_direct_worker_waiting # type: ignore # noqa: E501 - # Case 2: Connect to detached managers at ipports else: self.connect_to_managers(ipports) diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index ca0feefa7..962ca1ff6 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -3,7 +3,6 @@ import inspect import logging -import pickle from typing import Any from typing import Coroutine @@ -42,13 +41,7 @@ def __init__( 
RuntimeTask.task_counter += 1 self.task_id = RuntimeTask.task_counter - try: - self.serialized_fnargs = pickle.dumps(fnargs) - self.serialized_with_pickle = True - except Exception: - self.serialized_fnargs = dill.dumps(fnargs) - self.serialized_with_pickle = False - + self.serialized_fnargs = dill.dumps(fnargs) self._fnargs: tuple[Any, Any, Any] | None = None self._name = fnargs[0].__name__ """Tuple of function pointer, arguments, and keyword arguments.""" @@ -91,10 +84,7 @@ def __init__( def fnargs(self) -> tuple[Any, Any, Any]: """Return the function pointer, arguments, and keyword arguments.""" if self._fnargs is None: - if self.serialized_with_pickle: - self._fnargs = pickle.loads(self.serialized_fnargs) - else: - self._fnargs = dill.loads(self.serialized_fnargs) + self._fnargs = dill.loads(self.serialized_fnargs) assert self._fnargs is not None # for type checker return self._fnargs diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index f2590bfaf..c034c2ce2 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -162,8 +162,8 @@ def __init__(self, id: int, conn: Connection) -> None: self._tasks: dict[RuntimeAddress, RuntimeTask] = {} """Tracks all started, unfinished tasks on this worker.""" - # self._delayed_tasks: list[RuntimeTask] = [] - # """Store all delayed tasks in LIFO order.""" + self._delayed_tasks: list[RuntimeTask] = [] + """Store all delayed tasks in LIFO order.""" self._ready_task_ids: Queue[RuntimeAddress] = Queue() """Tasks queued up for execution.""" @@ -203,7 +203,7 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: lvl = active_task.logging_level if lvl is None or lvl <= record.levelno: tid = active_task.comp_task_id - self._send(RuntimeMessage.LOG, (tid, record)) + self._conn.send((RuntimeMessage.LOG, (tid, record))) return record logging.setLogRecordFactory(record_factory) @@ -215,7 +215,7 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: _logger.debug('Started 
incoming thread.') # Communicate that this worker is ready - self._send(RuntimeMessage.STARTED, self._id) + self._conn.send((RuntimeMessage.STARTED, self._id)) def _loop(self) -> None: """Main worker event loop.""" @@ -228,7 +228,7 @@ def _loop(self) -> None: error_str = ''.join(traceback.format_exception(*exc_info)) _logger.error(error_str) try: - self._send(RuntimeMessage.ERROR, error_str) + self._conn.send((RuntimeMessage.ERROR, error_str)) except Exception: pass @@ -256,17 +256,17 @@ def recv_incoming(self) -> None: self._add_task(task) self.read_receipt_mutex.release() - # elif msg == RuntimeMessage.SUBMIT_BATCH: - # self.read_receipt_mutex.acquire() - # tasks = cast(List[RuntimeTask], payload) - # self.most_recent_read_submit = tasks[0].unique_id - # self._add_task(tasks.pop()) # Submit one task - # self._delayed_tasks.extend(tasks) # Delay rest - # # Delayed tasks have no context and are stored (more-or-less) - # # as a function pointer together with the arguments. - # # When it gets started, it consumes much more memory, - # # so we delay the task start until necessary (at no cost) - # self.read_receipt_mutex.release() + elif msg == RuntimeMessage.SUBMIT_BATCH: + self.read_receipt_mutex.acquire() + tasks = cast(List[RuntimeTask], payload) + self.most_recent_read_submit = tasks[0].unique_id + self._add_task(tasks.pop()) # Submit one task + self._delayed_tasks.extend(tasks) # Delay rest + # Delayed tasks have no context and are stored (more-or-less) + # as a function pointer together with the arguments. 
+ # When it gets started, it consumes much more memory, + # so we delay the task start until necessary (at no cost) + self.read_receipt_mutex.release() elif msg == RuntimeMessage.RESULT: result = cast(RuntimeResult, payload) @@ -340,17 +340,17 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None: } # Remove all tasks that are children of `addr` from delayed tasks - # self._delayed_tasks = [ - # t for t in self._delayed_tasks - # if not t.is_descendant_of(addr) - # ] + self._delayed_tasks = [ + t for t in self._delayed_tasks + if not t.is_descendant_of(addr) + ] def _get_next_ready_task(self) -> RuntimeTask | None: """Return the next ready task if one exists, otherwise block.""" while True: - # if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0: - # self._add_task(self._delayed_tasks.pop()) - # continue + if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0: + self._add_task(self._delayed_tasks.pop()) + continue self.read_receipt_mutex.acquire() try: @@ -358,7 +358,7 @@ def _get_next_ready_task(self) -> RuntimeTask | None: except Empty: payload = (1, self.most_recent_read_submit) - self._send(RuntimeMessage.WAITING, payload) + self._conn.send((RuntimeMessage.WAITING, payload)) self.read_receipt_mutex.release() addr = self._ready_task_ids.get() @@ -412,7 +412,7 @@ def _try_step_next_ready_task(self) -> None: exc_info = sys.exc_info() error_str = ''.join(traceback.format_exception(*exc_info)) error_payload = (self._active_task.comp_task_id, error_str) - self._send(RuntimeMessage.ERROR, error_payload) + self._conn.send((RuntimeMessage.ERROR, error_payload)) finally: self._active_task = None @@ -456,11 +456,11 @@ def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: if task.return_address.worker_id == self._id: self._handle_result(packaged_result) - self._send(RuntimeMessage.UPDATE, -1) + self._conn.send((RuntimeMessage.UPDATE, -1)) # Let manager know this worker has one less task # without sending a result else: - 
self._send(RuntimeMessage.RESULT, packaged_result) + self._conn.send((RuntimeMessage.RESULT, packaged_result)) # Remove task self._tasks.pop(task.return_address) @@ -498,12 +498,6 @@ def _get_new_mailbox_id(self) -> int: self._mailbox_counter += 1 return new_id - def _send(self, msg: RuntimeMessage, payload: Any) -> None: - """Send a message to the boss.""" - _logger.debug(f'Sending message {msg.name}.') - _logger.log(1, f'Payload: {payload}') - self._conn.send((msg, payload)) - def submit( self, fn: Callable[..., Any], @@ -531,7 +525,7 @@ def submit( ) # Submit the task (on the next cycle) - self._send(RuntimeMessage.SUBMIT, task) + self._conn.send((RuntimeMessage.SUBMIT, task)) # Return future pointing to the mailbox return RuntimeFuture(mailbox_id) @@ -578,7 +572,7 @@ def map( ] # Submit the tasks - self._send(RuntimeMessage.SUBMIT_BATCH, tasks) + self._conn.send((RuntimeMessage.SUBMIT_BATCH, tasks)) # Return future pointing to the mailbox return RuntimeFuture(mailbox_id) @@ -594,7 +588,7 @@ def cancel(self, future: RuntimeFuture) -> None: for slot_id in range(num_slots) ] for addr in addrs: - self._send(RuntimeMessage.CANCEL, addr) + self._conn.send((RuntimeMessage.CANCEL, addr)) def get_cache(self) -> dict[str, Any]: """ diff --git a/bqskit/utils/cachedclass.py b/bqskit/utils/cachedclass.py index 7edac47b9..751a361b6 100644 --- a/bqskit/utils/cachedclass.py +++ b/bqskit/utils/cachedclass.py @@ -63,8 +63,7 @@ def __new__(cls: type[T], *args: Any, **kwargs: Any) -> T: _instances = cls._instances # type: ignore if _instances.get(key, None) is None: - _logger.log( - 1, + _logger.debug( ( 'Creating cached instance for class: %s,' ' with args %s, and kwargs %s' From a5bb9effcfadf2115057dcc2e85a80196138a5db Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Fri, 12 Apr 2024 09:14:22 -0400 Subject: [PATCH 29/44] Fixed 3.8 and 3.9 worker log issue --- bqskit/runtime/worker.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git 
a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index c034c2ce2..8aa0b4ce7 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -209,9 +209,9 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: logging.setLogRecordFactory(record_factory) # Start incoming thread - self.incomming_thread = Thread(target=self.recv_incoming) - self.incomming_thread.daemon = True - self.incomming_thread.start() + self.incoming_thread = Thread(target=self.recv_incoming) + self.incoming_thread.daemon = True + self.incoming_thread.start() _logger.debug('Started incoming thread.') # Communicate that this worker is ready @@ -241,6 +241,7 @@ def recv_incoming(self) -> None: except Exception: _logger.debug('Crashed due to lost connection') os.kill(os.getpid(), signal.SIGKILL) + exit() _logger.debug(f'Received message {msg.name}.') _logger.log(1, f'Payload: {payload}') @@ -675,6 +676,7 @@ def start_worker( # If id isn't provided, wait for assignment if w_id is None: msg, w_id = conn.recv() + assert isinstance(w_id, int) assert msg == RuntimeMessage.STARTED # Set up runtime logging @@ -684,9 +686,9 @@ def start_worker( _handler = logging.StreamHandler() _handler.setLevel(0) _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' - _fmt_message = ' [wid=%(wid)s]: %(message)s' + _fmt_message = f' [wid={w_id}]: %(message)s' _fmt = _fmt_header + _fmt_message - _formatter = logging.Formatter(_fmt, '%H:%M:%S', defaults={'wid': w_id}) + _formatter = logging.Formatter(_fmt, '%H:%M:%S') _handler.setFormatter(_formatter) _runtime_logger.addHandler(_handler) From 4189b237a34cbaf3e474753a20be167b39b9bc1c Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Fri, 12 Apr 2024 13:04:46 -0400 Subject: [PATCH 30/44] Reduce CachedClass logging --- bqskit/utils/cachedclass.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bqskit/utils/cachedclass.py b/bqskit/utils/cachedclass.py index 751a361b6..7edac47b9 100644 --- a/bqskit/utils/cachedclass.py +++ 
b/bqskit/utils/cachedclass.py @@ -63,7 +63,8 @@ def __new__(cls: type[T], *args: Any, **kwargs: Any) -> T: _instances = cls._instances # type: ignore if _instances.get(key, None) is None: - _logger.debug( + _logger.log( + 1, ( 'Creating cached instance for class: %s,' ' with args %s, and kwargs %s' From d6eb2be702bcc8f6a9d27a4c0e8ec5cdac79c982 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Fri, 12 Apr 2024 13:05:05 -0400 Subject: [PATCH 31/44] Add log context capabilities to runtime --- bqskit/runtime/__init__.py | 116 ++++++++++++++++++++++++++++++++++++- bqskit/runtime/task.py | 9 +-- bqskit/runtime/worker.py | 56 ++++++++++++++++++ 3 files changed, 175 insertions(+), 6 deletions(-) diff --git a/bqskit/runtime/__init__.py b/bqskit/runtime/__init__.py index 2c01c4ddd..371427ffd 100644 --- a/bqskit/runtime/__init__.py +++ b/bqskit/runtime/__init__.py @@ -98,6 +98,7 @@ from typing import Any from typing import Callable from typing import Protocol +from typing import Sequence from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -143,18 +144,129 @@ def submit( self, fn: Callable[..., Any], *args: Any, + task_name: str | None = None, + log_context: dict[str, str] = {}, **kwargs: Any, ) -> RuntimeFuture: - """Submit a `fn` to the runtime.""" + """ + Submit a function to the runtime for execution. + + This method schedules the function `fn` to be executed by the + runtime with the provided arguments `args` and keyword arguments + `kwargs`. The execution may happen asynchronously. + + Args: + fn (Callable[..., Any]): The function to be executed. + + *args (Any): Variable length argument list to be passed to + the function `fn`. + + task_name (str | None): An optional name for the task, which + can be used for logging or tracking purposes. Defaults to + None, which will use the function name as the task name. + + log_context (dict[str, str]): A dictionary containing logging + context information. 
All log messages produced by the fn + and any children tasks will contain this context if the + appropriate level (logging.DEBUG) is set on the logger. + Defaults to an empty dictionary for no added context. + + **kwargs (Any): Arbitrary keyword arguments to be passed to + the function `fn`. + + Returns: + RuntimeFuture: An object representing the future result of + the function execution. This can be used to retrieve the + result by `await`ing it. + + Example: + >>> from bqskit.runtime import get_runtime + >>> + >>> def add(x, y): + ... return x + y + >>> + >>> future = get_runtime().submit(add, 1, 2) + >>> result = await future + >>> print(result) + 3 + + See Also: + - :func:`map` for submitting multiple tasks in parallel. + - :func:`cancel` for cancelling tasks. + - :class:`~bqskit.runtime.future.RuntimeFuture` for more + information on how to interact with the future object. + """ ... def map( self, fn: Callable[..., Any], *args: Any, + task_name: Sequence[str | None] | str | None = None, + log_context: Sequence[dict[str, str]] | dict[str, str] = {}, **kwargs: Any, ) -> RuntimeFuture: - """Map `fn` over the input arguments distributed across the runtime.""" + """ + Map a function over a sequence of arguments and execute in parallel. + + This method schedules the function `fn` to be executed by the runtime + for each set of arguments provided in `args`. Each invocation of `fn` + will be executed potentially in parallel, depending on the runtime's + capabilities and current load. + + Args: + fn (Callable[..., Any]): The function to be executed. + + *args (Any): Variable length argument list to be passed to + the function `fn`. Each argument is expected to be a + sequence of arguments to be passed to a separate + invocation. The sequences should be of equal length. + + task_name (Sequence[str | None] | str | None): An optional + name for the task group, which can be used for logging + or tracking purposes. 
Defaults to None, which will use + the function name as the task name. If a string is + provided, it will be used as the prefix for all task + names. If a sequence of strings is provided, each task + will be named with the corresponding string in the + sequence. + + log_context (Sequence[dict[str, str]]) | dict[str, str]): A + dictionary containing logging context information. All + log messages produced by the `fn` and any children tasks + will contain this context if the appropriate level + (logging.DEBUG) is set on the logger. Defaults to an + empty dictionary for no added context. Can be a sequence + of contexts, one for each task, or a single context to be + used for all tasks. + + **kwargs (Any): Arbitrary keyword arguments to be passed to + each invocation of the function `fn`. + + Returns: + RuntimeFuture: An object representing the future result of + the function executions. This can be used to retrieve the + results by `await`ing it, which will return a list. + + Example: + >>> from bqskit.runtime import get_runtime + >>> + >>> def add(x, y): + ... return x + y + >>> + >>> args_list = [(1, 2, 3), (4, 5, 6)] + >>> future = get_runtime().map(add, *args_list) + >>> results = await future + >>> print(results) + [5, 7, 9] + + See Also: + - :func:`submit` for submitting a single task. + - :func:`cancel` for cancelling tasks. + - :func:`next` for retrieving results incrementally. + - :class:`~bqskit.runtime.future.RuntimeFuture` for more + information on how to interact with the future object. + """ ... 
def cancel(self, future: RuntimeFuture) -> None: diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index 962ca1ff6..d74d79ec4 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -36,6 +36,8 @@ def __init__( breadcrumbs: tuple[RuntimeAddress, ...], logging_level: int | None = None, max_logging_depth: int = -1, + task_name: str | None = None, + log_context: dict[str, str] = {}, ) -> None: """Create the task with a new id and return address.""" RuntimeTask.task_counter += 1 @@ -43,7 +45,7 @@ def __init__( self.serialized_fnargs = dill.dumps(fnargs) self._fnargs: tuple[Any, Any, Any] | None = None - self._name = fnargs[0].__name__ + self._name = fnargs[0].__name__ if task_name is None else task_name """Tuple of function pointer, arguments, and keyword arguments.""" self.return_address = return_address @@ -68,9 +70,6 @@ def __init__( self.coro: Coroutine[Any, Any, Any] | None = None """The coroutine containing this tasks code.""" - # self.send: Any = None - # """A register that both the coroutine and task have access to.""" - self.desired_box_id: int | None = None """When waiting on a mailbox, this stores that mailbox's id.""" @@ -80,6 +79,8 @@ def __init__( self.wake_on_next: bool = False """Set to true if this task should wake immediately on a result.""" + self.log_context: dict[str, str] = log_context + @property def fnargs(self) -> tuple[Any, Any, Any]: """Return the function pointer, arguments, and keyword arguments.""" diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 8aa0b4ce7..dee3bdca7 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -20,6 +20,7 @@ from typing import Callable from typing import cast from typing import List +from typing import Sequence from bqskit.runtime import default_worker_port from bqskit.runtime import set_blas_thread_counts @@ -202,6 +203,14 @@ def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: if active_task is not None: lvl = 
active_task.logging_level if lvl is None or lvl <= record.levelno: + if lvl <= logging.DEBUG: + record.msg += f' [wid={self._id}' + items = active_task.log_context.items() + if len(items) > 0: + record.msg += ', ' + con_str = ', '.join(f'{k}={v}' for k, v in items) + record.msg += con_str + record.msg += ']' tid = active_task.comp_task_id self._conn.send((RuntimeMessage.LOG, (tid, record))) return record @@ -503,10 +512,25 @@ def submit( self, fn: Callable[..., Any], *args: Any, + task_name: str | None = None, + log_context: dict[str, str] = {}, **kwargs: Any, ) -> RuntimeFuture: """Submit `fn` as a task to the runtime.""" assert self._active_task is not None + + if task_name is not None and not isinstance(task_name, str): + raise RuntimeError('task_name must be a string.') + + if not isinstance(log_context, dict): + raise RuntimeError('log_context must be a dictionary.') + + for k, v in log_context.items(): + if not isinstance(k, str) or not isinstance(v, str): + raise RuntimeError( + 'log_context must be a map from strings to strings.', + ) + # Group fnargs together fnarg = (fn, args, kwargs) @@ -523,6 +547,8 @@ def submit( self._active_task.breadcrumbs + (self._active_task.return_address,), self._active_task.logging_level, self._active_task.max_logging_depth, + task_name, + self._active_task.log_context | log_context, ) # Submit the task (on the next cycle) @@ -535,10 +561,38 @@ def map( self, fn: Callable[..., Any], *args: Any, + task_name: Sequence[str | None] | str | None = None, + log_context: Sequence[dict[str, str]] | dict[str, str] = {}, **kwargs: Any, ) -> RuntimeFuture: """Map `fn` over the input arguments distributed across the runtime.""" assert self._active_task is not None + + if task_name is None or isinstance(task_name, str): + task_name = [task_name] * len(args[0]) + + if len(task_name) != len(args[0]): + raise RuntimeError( + 'task_name must be a string or a list of strings equal' + 'in length to the number of tasks.', + ) + + if 
isinstance(log_context, dict): + log_context = [log_context] * len(args[0]) + + if len(log_context) != len(args[0]): + raise RuntimeError( + 'log_context must be a dictionary or a list of dictionaries' + ' equal in length to the number of tasks.', + ) + + for context in log_context: + for k, v in context.items(): + if not isinstance(k, str) or not isinstance(v, str): + raise RuntimeError( + 'log_context must be a map from strings to strings.', + ) + # Group fnargs together fnargs = [] if len(args) == 1: @@ -568,6 +622,8 @@ def map( breadcrumbs, self._active_task.logging_level, self._active_task.max_logging_depth, + task_name[i], + self._active_task.log_context | log_context[i], ) for i, fnarg in enumerate(fnargs) ] From 264337cd297540d12d8389867bec002d11cfe3c1 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Fri, 12 Apr 2024 14:54:07 -0400 Subject: [PATCH 32/44] Fix python 3.8 dict update --- bqskit/runtime/worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index dee3bdca7..be7dd07b9 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -548,7 +548,7 @@ def submit( self._active_task.logging_level, self._active_task.max_logging_depth, task_name, - self._active_task.log_context | log_context, + {**self._active_task.log_context, **log_context}, ) # Submit the task (on the next cycle) @@ -623,7 +623,7 @@ def map( self._active_task.logging_level, self._active_task.max_logging_depth, task_name[i], - self._active_task.log_context | log_context[i], + {**self._active_task.log_context, **log_context[i]}, ) for i, fnarg in enumerate(fnargs) ] From 2065a5f6792ca5e0d7eedf5241b1812b1aa39db4 Mon Sep 17 00:00:00 2001 From: alonkukl Date: Fri, 12 Apr 2024 22:18:15 -0700 Subject: [PATCH 33/44] Fixing the way we kill the worker on windows --- bqskit/runtime/worker.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/bqskit/runtime/worker.py 
b/bqskit/runtime/worker.py index dee3bdca7..168a4adcb 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -249,7 +249,10 @@ def recv_incoming(self) -> None: msg, payload = self._conn.recv() except Exception: _logger.debug('Crashed due to lost connection') - os.kill(os.getpid(), signal.SIGKILL) + if sys.platform == 'win32': + os.kill(os.getpid(), 9) + else: + os.kill(os.getpid(), signal.SIGKILL) exit() _logger.debug(f'Received message {msg.name}.') @@ -257,7 +260,10 @@ def recv_incoming(self) -> None: # Process message if msg == RuntimeMessage.SHUTDOWN: - os.kill(os.getpid(), signal.SIGKILL) + if sys.platform == 'win32': + os.kill(os.getpid(), 9) + else: + os.kill(os.getpid(), signal.SIGKILL) elif msg == RuntimeMessage.SUBMIT: self.read_receipt_mutex.acquire() @@ -698,6 +704,7 @@ def start_worker( # Ignore interrupt signals on workers, boss will handle it for us # If w_id is None, then we are being spawned separately. signal.signal(signal.SIGINT, signal.SIG_IGN) + # TODO: check what needs to be done on win # Set number of BLAS threads set_blas_thread_counts(num_blas_threads) From 851ddf35f6cac816ec0ca64f1c8333e800f70698 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Sat, 13 Apr 2024 12:46:56 -0400 Subject: [PATCH 34/44] Fixed passes imports --- bqskit/compiler/__init__.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/bqskit/compiler/__init__.py b/bqskit/compiler/__init__.py index f7048aa56..cd47fec74 100644 --- a/bqskit/compiler/__init__.py +++ b/bqskit/compiler/__init__.py @@ -36,9 +36,9 @@ WorkflowLike """ from __future__ import annotations +from typing import Any from bqskit.compiler.basepass import BasePass -from bqskit.compiler.compile import compile from bqskit.compiler.compiler import Compiler from bqskit.compiler.gateset import GateSet from bqskit.compiler.gateset import GateSetLike @@ -49,6 +49,16 @@ from bqskit.compiler.workflow import Workflow from bqskit.compiler.workflow import WorkflowLike +def 
__getattr__(name: str) -> Any: + # Lazy imports + if name == 'compile': + from bqskit.compiler.compile import compile + return compile + + # TODO: Move compile to a different subpackage and deprecate import + + raise AttributeError(f'module {__name__} has no attribute {name}') + __all__ = [ 'BasePass', 'compile', From 2a3b94140e9d9aa9d481185aee6dae0cd74015e4 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Sat, 13 Apr 2024 12:48:07 -0400 Subject: [PATCH 35/44] Implements #165 --- bqskit/passes/synthesis/pas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bqskit/passes/synthesis/pas.py b/bqskit/passes/synthesis/pas.py index 2cad278d2..9a109f15a 100644 --- a/bqskit/passes/synthesis/pas.py +++ b/bqskit/passes/synthesis/pas.py @@ -113,6 +113,7 @@ async def synthesize( self.inner_synthesis.synthesize, targets, [data] * len(targets), + log_context=[{'perm': str(perm)} for perm in permsbyperms], ) # Return best circuit From 4682e098c6560fd48f2cc1ecc49f27d0b5dedbe7 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Sat, 13 Apr 2024 12:49:04 -0400 Subject: [PATCH 36/44] pre-commit --- bqskit/compiler/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bqskit/compiler/__init__.py b/bqskit/compiler/__init__.py index cd47fec74..c0ef1b3bf 100644 --- a/bqskit/compiler/__init__.py +++ b/bqskit/compiler/__init__.py @@ -36,6 +36,7 @@ WorkflowLike """ from __future__ import annotations + from typing import Any from bqskit.compiler.basepass import BasePass @@ -49,6 +50,7 @@ from bqskit.compiler.workflow import Workflow from bqskit.compiler.workflow import WorkflowLike + def __getattr__(name: str) -> Any: # Lazy imports if name == 'compile': @@ -59,6 +61,7 @@ def __getattr__(name: str) -> Any: raise AttributeError(f'module {__name__} has no attribute {name}') + __all__ = [ 'BasePass', 'compile', From 75d42f9315258542a29c9d6d2659b04129585330 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Mon, 15 Apr 2024 16:34:28 -0400 Subject: [PATCH 37/44] Small doc fix --- 
bqskit/compiler/compiler.py | 2 +- bqskit/ext/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index 22fab9feb..8f58f3c35 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -293,7 +293,7 @@ def submit( tasks equal opportunity to log. Returns: - (uuid.UUID): The ID of the generated task in the system. This + uuid.UUID: The ID of the generated task in the system. This ID can be used to check the status of, cancel, and request the result of the task. """ diff --git a/bqskit/ext/__init__.py b/bqskit/ext/__init__.py index a6a4ee456..4cd607e74 100644 --- a/bqskit/ext/__init__.py +++ b/bqskit/ext/__init__.py @@ -53,6 +53,7 @@ """ from __future__ import annotations +# TODO: Deprecate imports from __init__, use lazy import to deprecate from bqskit.ext.cirq.models import Sycamore23Model from bqskit.ext.cirq.models import SycamoreModel from bqskit.ext.cirq.translate import bqskit_to_cirq From 95b06d299541f1b1eeaf083a282543a2ac91c650 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Tue, 16 Apr 2024 08:34:14 -0400 Subject: [PATCH 38/44] Add toggle for workers to print client log msgs --- bqskit/compiler/__init__.py | 1 + bqskit/ext/__init__.py | 2 +- bqskit/runtime/worker.py | 25 ++++++++++++++++++++++--- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/bqskit/compiler/__init__.py b/bqskit/compiler/__init__.py index c0ef1b3bf..fde8f87ee 100644 --- a/bqskit/compiler/__init__.py +++ b/bqskit/compiler/__init__.py @@ -54,6 +54,7 @@ def __getattr__(name: str) -> Any: # Lazy imports if name == 'compile': + # TODO: fix this (high-priority), overlap between module and function from bqskit.compiler.compile import compile return compile diff --git a/bqskit/ext/__init__.py b/bqskit/ext/__init__.py index 4cd607e74..131eec208 100644 --- a/bqskit/ext/__init__.py +++ b/bqskit/ext/__init__.py @@ -53,7 +53,6 @@ """ from __future__ import annotations -# TODO: Deprecate 
imports from __init__, use lazy import to deprecate from bqskit.ext.cirq.models import Sycamore23Model from bqskit.ext.cirq.models import SycamoreModel from bqskit.ext.cirq.translate import bqskit_to_cirq @@ -74,6 +73,7 @@ from bqskit.ext.supermarq import supermarq_liveness from bqskit.ext.supermarq import supermarq_parallelism from bqskit.ext.supermarq import supermarq_program_communication +# TODO: Deprecate imports from __init__, use lazy import to deprecate __all__ = [ diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 0738b02dd..ca7c87a4d 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -698,6 +698,7 @@ def start_worker( cpu: int | None = None, logging_level: int = logging.WARNING, num_blas_threads: int = 1, + log_client: bool = False, ) -> None: """Start this process's worker.""" if w_id is not None: @@ -710,7 +711,7 @@ def start_worker( set_blas_thread_counts(num_blas_threads) # Enforce no default logging - logging.lastResort = logging.NullHandler() # type: ignore # TODO: should I report this as a type bug? 
# noqa: E501 + logging.lastResort = logging.NullHandler() # type: ignore # typeshed#11770 logging.getLogger().handlers.clear() # Pin worker to cpu @@ -743,7 +744,10 @@ def start_worker( assert msg == RuntimeMessage.STARTED # Set up runtime logging - _runtime_logger = logging.getLogger('bqskit.runtime') + if not log_client: + _runtime_logger = logging.getLogger('bqskit.runtime') + else: + _runtime_logger = logging.getLogger() _runtime_logger.propagate = False _runtime_logger.setLevel(logging_level) _handler = logging.StreamHandler() @@ -809,6 +813,11 @@ def start_worker_rank() -> None: default=0, help='Enable logging of increasing verbosity, either -v, -vv, or -vvv.', ) + parser.add_argument( + '-l', '--log-client', + action='store_true', + help='Log messages from the client process.', + ) parser.add_argument( '-t', '--num_blas_threads', type=int, @@ -836,10 +845,20 @@ def start_worker_rank() -> None: logging_level = [30, 20, 10, 1][min(args.verbose, 3)] + if args.log_client and logging_level > 10: + raise RuntimeError('Cannot log client messages without at least -vv.') + # Spawn worker process procs = [] for cpu in cpus: - pargs = (None, args.port, cpu, logging_level, args.num_blas_threads) + pargs = ( + None, + args.port, + cpu, + logging_level, + args.num_blas_threads, + args.log_client, + ) procs.append(Process(target=start_worker, args=pargs)) procs[-1].start() From e5f67f6359c62d8aebf8d0755482b90a815994a2 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Tue, 16 Apr 2024 10:39:54 -0400 Subject: [PATCH 39/44] Added direction to runtime sent log --- bqskit/runtime/base.py | 8 +++++++- bqskit/runtime/detached.py | 7 +++++++ bqskit/runtime/manager.py | 7 +++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 1cb53c494..9f9dd2a27 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -387,7 +387,9 @@ def send_outgoing(self) -> None: continue outgoing[0].send((outgoing[1], outgoing[2])) - 
_logger.debug(f'Sent message {outgoing[1].name}.') + if _logger.isEnabledFor(logging.DEBUG): + to = self.get_to_string(outgoing[0]) + _logger.debug(f'Sent message {outgoing[1].name} to {to}.') if outgoing[1] == RuntimeMessage.SUBMIT_BATCH: _logger.log(1, f'[{outgoing[2][0]}] * {len(outgoing[2])}\n') @@ -471,6 +473,10 @@ def handle_system_error(self, error_str: str) -> None: RuntimeTask's coroutine code. """ + @abc.abstractmethod + def get_to_string(self, conn: Connection) -> str: + """Return a string representation of the connection.""" + def handle_shutdown(self) -> None: """Shutdown the node and release resources.""" # Stop running diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 230a5949d..6c0bcd51b 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -217,6 +217,13 @@ def handle_system_error(self, error_str: str) -> None: # Sleep to ensure clients receive error message before shutdown time.sleep(1) + def get_to_string(self, conn: Connection) -> str: + """Return a string representation of the connection.""" + if conn in self.clients: + return 'Client' + + return 'Employee' + def handle_shutdown(self) -> None: """Shutdown the runtime.""" super().handle_shutdown() diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 507cdf9a3..00bb4b2b7 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -228,6 +228,13 @@ def handle_system_error(self, error_str: str) -> None: # If server has crashed then just exit pass + def get_to_string(self, conn: Connection) -> str: + """Return a string representation of the connection.""" + if conn == self.upstream: + return 'Boss' + + return 'Employee' + def handle_shutdown(self) -> None: """Shutdown the manager and clean up spawned processes.""" super().handle_shutdown() From 104d35306de2afa8db25c0e51de5e73689a70d3d Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Wed, 17 Apr 2024 07:54:48 -0400 Subject: [PATCH 40/44] Address comments and some clean up 
--- bqskit/runtime/base.py | 42 ++++++++++++++++++++++++++++++-------- bqskit/runtime/detached.py | 4 ++-- bqskit/runtime/manager.py | 6 ++---- bqskit/runtime/worker.py | 29 ++++++++++++++++++++------ 4 files changed, 60 insertions(+), 21 deletions(-) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 9f9dd2a27..d10996e49 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -42,16 +42,28 @@ class RuntimeEmployee: def __init__( self, + id: int, conn: Connection, total_workers: int, process: Process | None = None, + is_manager: bool = False, ) -> None: """Construct an employee with all resources idle.""" + + self.id = id + """ + The ID of the employee. + + If this is a worker, then their unique worker id. If this is a manager, + then their local id. + """ + self.conn: Connection = conn self.total_workers = total_workers self.process = process self.num_tasks = 0 self.num_idle_workers = total_workers + self.is_manager = is_manager self.submit_cache: list[tuple[RuntimeAddress, int]] = [] """ @@ -81,6 +93,11 @@ def shutdown(self) -> None: self.initiate_shutdown() self.complete_shutdown() + @property + def recipient_string(self) -> str: + """Return a string representation of the employee.""" + return f'{"Manager" if self.is_manager else "Worker"} {self.id}' + @property def has_idle_resources(self) -> bool: return self.num_idle_workers > 0 @@ -179,7 +196,14 @@ def connect_to_managers(self, ipports: Sequence[tuple[str, int]]) -> None: for i, conn in enumerate(manager_conns): msg, num_workers = conn.recv() assert msg == RuntimeMessage.STARTED - self.employees.append(RuntimeEmployee(conn, num_workers)) + self.employees.append( + RuntimeEmployee( + i, + conn, + num_workers, + is_manager=True, + ), + ) self.conn_to_employee_dict[conn] = self.employees[-1] self.sel.register( conn, @@ -286,7 +310,7 @@ def spawn_workers( for i, conn in enumerate(conns): msg, w_id = conn.recv() assert msg == RuntimeMessage.STARTED - employee = RuntimeEmployee(conn, 1, 
procs[w_id]) + employee = RuntimeEmployee(w_id, conn, 1, procs[w_id]) temp_reorder[w_id - self.lower_id_bound] = employee self.conn_to_employee_dict[conn] = employee @@ -295,13 +319,13 @@ def spawn_workers( self.employees.append(temp_reorder[i]) # Register employee communication - for i, employee in enumerate(self.employees): + for employee in self.employees: self.sel.register( employee.conn, selectors.EVENT_READ, MessageDirection.BELOW, ) - _logger.debug(f'Registered worker {i}.') + _logger.debug(f'Registered worker {employee.id}.') self.step_size = 1 self.total_workers = num_workers @@ -343,20 +367,20 @@ def connect_to_workers( for i, conn in enumerate(conns): w_id = self.lower_id_bound + i self.outgoing.put((conn, RuntimeMessage.STARTED, w_id)) - employee = RuntimeEmployee(conn, 1) + employee = RuntimeEmployee(w_id, conn, 1) self.employees.append(employee) self.conn_to_employee_dict[conn] = employee # Register employee communication - for i, employee in enumerate(self.employees): - w_id = self.lower_id_bound + i + for employee in self.employees: + w_id = employee.id assert employee.conn.recv() == (RuntimeMessage.STARTED, w_id) self.sel.register( employee.conn, selectors.EVENT_READ, MessageDirection.BELOW, ) - _logger.info(f'Registered worker {i}.') + _logger.info(f'Registered worker {w_id}.') self.step_size = 1 self.total_workers = num_workers @@ -583,7 +607,7 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: assignments, key=lambda x: x[0].num_idle_workers, reverse=True, - ) + ) # Employees with the most idle workers get assignments first for e, assignment in sorted_assignments: num_tasks = len(assignment) diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 6c0bcd51b..90ad2b964 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -220,9 +220,9 @@ def handle_system_error(self, error_str: str) -> None: def get_to_string(self, conn: Connection) -> str: """Return a string representation of the 
connection.""" if conn in self.clients: - return 'Client' + return 'CLIENT' - return 'Employee' + return self.conn_to_employee_dict[conn].recipient_string def handle_shutdown(self) -> None: """Shutdown the runtime.""" diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 00bb4b2b7..0edd64f8f 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -183,12 +183,10 @@ def handle_message( if msg == RuntimeMessage.SUBMIT: rtask = cast(RuntimeTask, payload) self.send_up_or_schedule_tasks([rtask]) - # self.update_upstream_idle_workers() elif msg == RuntimeMessage.SUBMIT_BATCH: rtasks = cast(List[RuntimeTask], payload) self.send_up_or_schedule_tasks(rtasks) - # self.update_upstream_idle_workers() elif msg == RuntimeMessage.RESULT: result = cast(RuntimeResult, payload) @@ -231,9 +229,9 @@ def handle_system_error(self, error_str: str) -> None: def get_to_string(self, conn: Connection) -> str: """Return a string representation of the connection.""" if conn == self.upstream: - return 'Boss' + return 'BOSS' - return 'Employee' + return self.conn_to_employee_dict[conn].recipient_string def handle_shutdown(self) -> None: """Shutdown the manager and clean up spawned processes.""" diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index ca7c87a4d..82cf1ab73 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -164,7 +164,14 @@ def __init__(self, id: int, conn: Connection) -> None: """Tracks all started, unfinished tasks on this worker.""" self._delayed_tasks: list[RuntimeTask] = [] - """Store all delayed tasks in LIFO order.""" + """ + Store all delayed tasks in LIFO order. + + Delayed tasks have no context and are stored (more-or-less) as a + function pointer together with the arguments. 
When it gets started, it + consumes much more memory, so we delay the task start until necessary + (at no cost) + """ self._ready_task_ids: Queue[RuntimeAddress] = Queue() """Tasks queued up for execution.""" @@ -191,7 +198,13 @@ def __init__(self, id: int, conn: Connection) -> None: """Tracks the most recently processed submit message from above.""" self.read_receipt_mutex = Lock() - """A lock to ensure waiting messages's read receipt is correct.""" + """ + A lock to ensure waiting messages' read receipt is correct. + + This lock enforces atomic update of `most_recent_read_submit` and + task addition/enqueueing. This is necessary to ensure that the + idle status is always correct. + """ # Send out every client emitted log message upstream old_factory = logging.getLogRecordFactory() @@ -278,10 +291,6 @@ def recv_incoming(self) -> None: self.most_recent_read_submit = tasks[0].unique_id self._add_task(tasks.pop()) # Submit one task self._delayed_tasks.extend(tasks) # Delay rest - # Delayed tasks have no context and are stored (more-or-less) - # as a function pointer together with the arguments. - # When it gets started, it consumes much more memory, - # so we delay the task start until necessary (at no cost) self.read_receipt_mutex.release() elif msg == RuntimeMessage.RESULT: @@ -368,6 +377,12 @@ def _get_next_ready_task(self) -> RuntimeTask | None: self._add_task(self._delayed_tasks.pop()) continue + # Critical section + # Attempt to get a ready task. If none are available, message + # the manager/server with a waiting message letting them + # know the worker is idle. This needs to be atomic to prevent + # the self.most_recent_read_submit from being updated after + # catching the Empty exception, but before forming the payload.
self.read_receipt_mutex.acquire() try: addr = self._ready_task_ids.get_nowait() @@ -376,6 +391,8 @@ def _get_next_ready_task(self) -> RuntimeTask | None: payload = (1, self.most_recent_read_submit) self._conn.send((RuntimeMessage.WAITING, payload)) self.read_receipt_mutex.release() + # Block for new message. Can release lock here since + # the `self.most_recent_read_submit` has been used. addr = self._ready_task_ids.get() else: From ce64cabd7081888332195b9198c2ccd3306d8f75 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Fri, 19 Apr 2024 09:05:44 -0400 Subject: [PATCH 41/44] Some TODO cleanup --- bqskit/runtime/__init__.py | 6 ++---- bqskit/runtime/task.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/bqskit/runtime/__init__.py b/bqskit/runtime/__init__.py index 371427ffd..75e3691c7 100644 --- a/bqskit/runtime/__init__.py +++ b/bqskit/runtime/__init__.py @@ -70,10 +70,8 @@ :class:`RuntimeHandle`, which you can use to submit, map, wait on, and cancel tasks in the execution environment. -For more information on how to design a custom pass, see this (TODO, sorry, -you can look at the source code of existing -`passes `_ -for a good example for the time being). +For more information on how to design a custom pass, see the following +guide: :doc:`guides/custompass.md`. ..
autosummary:: :toctree: autogen diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index d74d79ec4..3d81f1bbd 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -30,7 +30,7 @@ class RuntimeTask: def __init__( self, - fnargs: tuple[Any, Any, Any], # TODO: Look into retyping this + fnargs: tuple[Any, Any, Any], return_address: RuntimeAddress, comp_task_id: int, breadcrumbs: tuple[RuntimeAddress, ...], From e72a6e46de88d8e0441fb514f9e4da3970adaead Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Tue, 9 Jul 2024 10:37:06 -0400 Subject: [PATCH 42/44] Implemented Communicate Feature --- bqskit/runtime/detached.py | 3 +++ bqskit/runtime/manager.py | 3 +++ bqskit/runtime/message.py | 1 + bqskit/runtime/task.py | 3 +++ bqskit/runtime/worker.py | 30 ++++++++++++++++++++++++++++++ 5 files changed, 40 insertions(+) diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 90ad2b964..8740c7170 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -193,6 +193,9 @@ def handle_message( task_diff = cast(int, payload) self.conn_to_employee_dict[conn].num_tasks += task_diff + elif msg == RuntimeMessage.COMMUNICATE: + self.broadcast(msg, payload) + else: raise RuntimeError(f'Unexpected message type: {msg.name}') diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 0edd64f8f..14827af9e 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -175,6 +175,9 @@ def handle_message( paths = cast(List[str], payload) self.handle_importpath(paths) + elif msg == RuntimeMessage.COMMUNICATE: + self.broadcast(RuntimeMessage.COMMUNICATE, payload) + else: raise RuntimeError(f'Unexpected message type: {msg.name}') diff --git a/bqskit/runtime/message.py b/bqskit/runtime/message.py index c975099c8..d2585aef2 100644 --- a/bqskit/runtime/message.py +++ b/bqskit/runtime/message.py @@ -22,3 +22,4 @@ class RuntimeMessage(IntEnum): UPDATE = 13 IMPORTPATH = 14 READY = 15 + COMMUNICATE = 16 diff --git 
a/bqskit/runtime/task.py b/bqskit/runtime/task.py index 3d81f1bbd..329e48bd9 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -80,6 +80,9 @@ def __init__( """Set to true if this task should wake immediately on a result.""" self.log_context: dict[str, str] = log_context + """Additional context to be logged with this task.""" + + self.msg_buffer: list[Any] = [] @property def fnargs(self) -> tuple[Any, Any, Any]: diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 82cf1ab73..41779c99a 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -302,6 +302,10 @@ def recv_incoming(self) -> None: self._handle_cancel(addr) # TODO: preempt? + elif msg == RuntimeMessage.COMMUNICATE: + addrs, msg = cast(tuple[list[RuntimeAddress], Any], payload) + self._handle_communicate(addrs, msg) + elif msg == RuntimeMessage.IMPORTPATH: paths = cast(List[str], payload) for path in paths: @@ -370,6 +374,13 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None: if not t.is_descendant_of(addr) ] + def _handle_communicate(self, addrs: list[RuntimeAddress], msg: Any) -> None: + for task_addr in addrs: + if task_addr not in self._tasks: + continue + + self._tasks[task_addr].msg_buffer.append(msg) + def _get_next_ready_task(self) -> RuntimeTask | None: """Return the next ready task if one exists, otherwise block.""" while True: @@ -657,6 +668,25 @@ def map( # Return future pointing to the mailbox return RuntimeFuture(mailbox_id) + def communicate(self, future: RuntimeFuture, msg: Any) -> None: + """Send a message to the task associated with `future`.""" + assert self._active_task is not None + assert future.mailbox_id in self._mailboxes + + num_slots = self._mailboxes[future.mailbox_id].expected_num_results + addrs = [ + RuntimeAddress(self._id, future.mailbox_id, slot_id) + for slot_id in range(num_slots) + ] + self._conn.send((RuntimeMessage.COMMUNICATE, (addrs, msg))) + + def get_messages(self) -> list[Any]: + """Return all messages 
received by the worker for this task.""" + assert self._active_task is not None + x = self._active_task.msg_buffer + self._active_task.msg_buffer = [] + return x + def cancel(self, future: RuntimeFuture) -> None: """Cancel all tasks associated with `future`.""" assert self._active_task is not None From 7246e371b5bcf169ccf4d814836825f21ed0e1d2 Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Tue, 27 Aug 2024 11:07:18 -0400 Subject: [PATCH 43/44] Update --- bqskit/runtime/worker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 254d686f5..1684f7dbc 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -378,7 +378,11 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None: if not t.is_descendant_of(addr) ] - def _handle_communicate(self, addrs: list[RuntimeAddress], msg: Any) -> None: + def _handle_communicate( + self, + addrs: list[RuntimeAddress], + msg: Any, + ) -> None: for task_addr in addrs: if task_addr not in self._tasks: continue @@ -763,7 +767,7 @@ def start_worker( set_blas_thread_counts(num_blas_threads) # Enforce no default logging - logging.lastResort = logging.NullHandler() # type: ignore # typeshed#11770 + logging.lastResort = logging.NullHandler() logging.getLogger().handlers.clear() # Pin worker to cpu From 47490492add1e5f40ad764bdb710cbb8b8d3d29c Mon Sep 17 00:00:00 2001 From: Ed Younis Date: Wed, 28 Aug 2024 11:28:18 -0400 Subject: [PATCH 44/44] Fixed failing test from merge --- bqskit/passes/control/paralleldo.py | 4 ++-- bqskit/runtime/task.py | 6 +++++- tests/passes/control/test_paralleldo.py | 8 ++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/bqskit/passes/control/paralleldo.py b/bqskit/passes/control/paralleldo.py index 42b9bbeee..a95168058 100644 --- a/bqskit/passes/control/paralleldo.py +++ b/bqskit/passes/control/paralleldo.py @@ -34,7 +34,7 @@ def __init__( self, pass_sequences: Iterable[WorkflowLike], less_than: 
Callable[[Circuit, Circuit], bool], - pick_fisrt: bool = False, + pick_first: bool = False, ) -> None: """ Construct a ParallelDo. @@ -63,7 +63,7 @@ def __init__( self.workflows = [Workflow(p) for p in pass_sequences] self.less_than = less_than - self.pick_first = pick_fisrt + self.pick_first = pick_first if len(self.workflows) == 0: raise ValueError('Must specify at least one workflow.') diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index 6a07d174a..d8cef7855 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -135,7 +135,11 @@ def cancel(self) -> None: # it is likely a blanket try/accept catching the # error used to stop the coroutine, preventing # it from stopping correctly. - self.coro.close() + try: + self.coro.close() + except ValueError: + # Coroutine is running and cannot be closed. + pass else: raise RuntimeError('Task was cancelled with None coroutine.') diff --git a/tests/passes/control/test_paralleldo.py b/tests/passes/control/test_paralleldo.py index 2d69b5e70..91a5c4f72 100644 --- a/tests/passes/control/test_paralleldo.py +++ b/tests/passes/control/test_paralleldo.py @@ -38,11 +38,11 @@ async def run(self, circuit: Circuit, data: PassData) -> None: data['key'] = '1' -class Sleep3Pass(BasePass): +class Sleep9Pass(BasePass): async def run(self, circuit: Circuit, data: PassData) -> None: circuit.append_gate(ZGate(), 0) - time.sleep(0.3) - data['key'] = '3' + time.sleep(0.9) + data['key'] = '9' def pick_z(c1: Circuit, c2: Circuit) -> bool: @@ -66,7 +66,7 @@ def test_parallel_do_no_passes() -> None: def test_parallel_do_pick_first(compiler: Compiler) -> None: - passes: list[list[BasePass]] = [[Sleep3Pass()], [Sleep1Pass()]] + passes: list[list[BasePass]] = [[Sleep9Pass()], [Sleep1Pass()]] pd_pass = ParallelDo(passes, pick_z, True) _, data = compiler.compile(Circuit(1), pd_pass, True) assert data['key'] == '1'