From 3b4ab8cfda6e915754171bab3360e3c661d49ebb Mon Sep 17 00:00:00 2001 From: Callum Forrester <29771545+callumforrester@users.noreply.github.com> Date: Tue, 23 May 2023 15:17:27 +0100 Subject: [PATCH] Add transaction mode to worker (#202) --- .../explanations/decisions/0002-no-queues.rst | 27 ++ src/blueapi/cli/amq.py | 2 +- src/blueapi/cli/updates.py | 12 +- src/blueapi/service/main.py | 5 +- src/blueapi/service/model.py | 2 +- src/blueapi/worker/__init__.py | 5 +- src/blueapi/worker/event.py | 4 +- src/blueapi/worker/reworker.py | 60 +++- src/blueapi/worker/task.py | 9 - src/blueapi/worker/worker.py | 56 +++- tests/service/test_rest_api.py | 6 +- tests/worker/test_reworker.py | 296 +++++++++++++----- 12 files changed, 351 insertions(+), 133 deletions(-) create mode 100644 docs/developer/explanations/decisions/0002-no-queues.rst diff --git a/docs/developer/explanations/decisions/0002-no-queues.rst b/docs/developer/explanations/decisions/0002-no-queues.rst new file mode 100644 index 000000000..ebc5ef9c8 --- /dev/null +++ b/docs/developer/explanations/decisions/0002-no-queues.rst @@ -0,0 +1,27 @@ +2. No Queues +============ + +Date: 2023-05-22 + +Status +------ + +Accepted + +Context +------- + +In asking whether this service should hold and execute a queue of tasks. + +Decision +-------- + +We will not hold any queues. The worker can execute one task at a time and will return +an error if asked to execute one task while another is running. Queueing should be the +responsibility of a different service. + +Consequences +------------ + +The API must be kept queue-free, although transactions are permitted where the server +caches requests. diff --git a/src/blueapi/cli/amq.py b/src/blueapi/cli/amq.py index 8006bee4e..9cce7b2b5 100644 --- a/src/blueapi/cli/amq.py +++ b/src/blueapi/cli/amq.py @@ -61,7 +61,7 @@ def on_progress_event_wrapper( task_response = self.app.send_and_receive( "worker.run", {"name": name, "params": params}, reply_type=TaskResponse ).result(5.0) - task_id = task_response.task_name + task_id = task_response.task_id if timeout is not None: complete.wait(timeout) diff --git a/src/blueapi/cli/updates.py b/src/blueapi/cli/updates.py index 51a7b4f9f..d9279b5b6 100644 --- a/src/blueapi/cli/updates.py +++ b/src/blueapi/cli/updates.py @@ -43,15 +43,15 @@ def _update(self, name: str, view: StatusView) -> None: class CliEventRenderer: - _task_name: Optional[str] + _task_id: Optional[str] _pbar_renderer: ProgressBarRenderer def __init__( self, - task_name: Optional[str] = None, + task_id: Optional[str] = None, pbar_renderer: Optional[ProgressBarRenderer] = None, ) -> None: - self._task_name = task_name + self._task_id = task_id if pbar_renderer is None: pbar_renderer = ProgressBarRenderer() self._pbar_renderer = pbar_renderer @@ -65,14 +65,14 @@ def on_worker_event(self, event: WorkerEvent) -> None: print(str(event.state)) def _relates_to_task(self, event: Union[WorkerEvent, ProgressEvent]) -> bool: - if self._task_name is None: + if self._task_id is None: return True elif isinstance(event, WorkerEvent): return ( event.task_status is not None - and event.task_status.task_name == self._task_name + and event.task_status.task_id == self._task_id ) elif isinstance(event, ProgressEvent): - return event.task_name == self._task_name + return event.task_id == self._task_id else: return False diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index ae50857bc..dc24be009 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -70,8 +70,9 @@ def submit_task( handler: Handler = Depends(get_handler), ): """Submit a task onto the worker queue.""" - handler.worker.submit_task(name, RunPlan(name=name, params=task)) - return TaskResponse(task_name=name) + task_id = handler.worker.submit_task(RunPlan(name=name, params=task)) + handler.worker.begin_task(task_id) + return TaskResponse(task_id=task_id) @app.get("/worker/state") diff --git a/src/blueapi/service/model.py b/src/blueapi/service/model.py index 17cd57c06..b5599fcd0 100644 --- a/src/blueapi/service/model.py +++ b/src/blueapi/service/model.py @@ -80,4 +80,4 @@ class TaskResponse(BlueapiBaseModel): Acknowledgement that a task has started, includes its ID """ - task_name: str = Field(description="Unique identifier for the task") + task_id: str = Field(description="Unique identifier for the task") diff --git a/src/blueapi/worker/__init__.py b/src/blueapi/worker/__init__.py index bde984831..78309230e 100644 --- a/src/blueapi/worker/__init__.py +++ b/src/blueapi/worker/__init__.py @@ -2,7 +2,8 @@ from .multithread import run_worker_in_own_thread from .reworker import RunEngineWorker from .task import RunPlan, Task -from .worker import Worker +from .worker import TrackableTask, Worker +from .worker_busy_error import WorkerBusyError __all__ = [ "run_worker_in_own_thread", @@ -15,4 +16,6 @@ "StatusView", "ProgressEvent", "TaskStatus", + "TrackableTask", + "WorkerBusyError", ] diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 6193f4068..9e9b7e8e3 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -88,7 +88,7 @@ class ProgressEvent(BlueapiBaseModel): such as moving motors and exposing detectors. """ - task_name: str + task_id: str statuses: Mapping[str, StatusView] = Field(default_factory=dict) @@ -97,7 +97,7 @@ class TaskStatus(BlueapiBaseModel): Status of a task the worker is running. """ - task_name: str + task_id: str task_complete: bool task_failed: bool diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index b1330e7e7..bec23f768 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -25,8 +25,8 @@ WorkerState, ) from .multithread import run_worker_in_own_thread -from .task import ActiveTask, Task -from .worker import Worker +from .task import Task +from .worker import TrackableTask, Worker from .worker_busy_error import WorkerBusyError LOGGER = logging.getLogger(__name__) @@ -47,11 +47,13 @@ class RunEngineWorker(Worker[Task]): _ctx: BlueskyContext _stop_timeout: float + _pending_tasks: Dict[str, TrackableTask] + _state: WorkerState _errors: List[str] _warnings: List[str] - _task_queue: Queue # type: ignore - _current: Optional[ActiveTask] + _task_channel: Queue # type: ignore + _current: Optional[TrackableTask] _status_lock: RLock _status_snapshot: Dict[str, StatusView] _completed_statuses: Set[str] @@ -70,10 +72,12 @@ def __init__( self._ctx = ctx self._stop_timeout = stop_timeout + self._pending_tasks = {} + self._state = WorkerState.from_bluesky_state(ctx.run_engine.state) self._errors = [] self._warnings = [] - self._task_queue = Queue(maxsize=1) + self._task_channel = Queue(maxsize=1) self._current = None self._worker_events = EventPublisher() self._progress_events = EventPublisher() @@ -85,11 +89,33 @@ def __init__( self._stopping = Event() self._stopped = Event() - def submit_task(self, name: str, task: Task) -> None: - active_task = ActiveTask(name, task) - LOGGER.info(f"Submitting: {active_task}") + def clear_task(self, task_id: str) -> bool: + if task_id in self._pending_tasks: + del self._pending_tasks[task_id] + return True + else: + return False + + def get_pending_tasks(self) -> List[TrackableTask[Task]]: + return list(self._pending_tasks.values()) + + def begin_task(self, task_id: str) -> None: + task = self._pending_tasks.get(task_id) + if task is not None: + self._submit_trackable_task(task) + else: + raise KeyError(f"No pending task with ID {task_id}") + + def submit_task(self, task: Task) -> str: + task_id: str = str(uuid.uuid4()) + trackable_task = TrackableTask(task_id=task_id, task=task) + self._pending_tasks[task_id] = trackable_task + return task_id + + def _submit_trackable_task(self, trackable_task: TrackableTask) -> None: + LOGGER.info(f"Submitting: {trackable_task}") try: - self._task_queue.put_nowait(active_task) + self._task_channel.put_nowait(trackable_task) except Full: LOGGER.error("Cannot submit task while another is running") raise WorkerBusyError("Cannot submit task while another is running") @@ -104,7 +130,7 @@ def stop(self) -> None: # If the worker has not yet started there is nothing to do. if self._started.is_set(): - self._task_queue.put(KillSignal()) + self._task_channel.put(KillSignal()) self._stopped.wait(timeout=self._stop_timeout) # Event timeouts do not actually raise errors if not self._stopped.is_set(): @@ -138,8 +164,8 @@ def _cycle_with_error_handling(self) -> None: def _cycle(self) -> None: try: LOGGER.info("Awaiting task") - next_task: Union[ActiveTask, KillSignal] = self._task_queue.get() - if isinstance(next_task, ActiveTask): + next_task: Union[TrackableTask, KillSignal] = self._task_channel.get() + if isinstance(next_task, TrackableTask): LOGGER.info(f"Got new task: {next_task}") self._current = next_task # Informing mypy that the task is not None self._current.task.do_task(self._ctx) @@ -200,11 +226,11 @@ def _report_status( warnings = self._warnings if self._current is not None: task_status = TaskStatus( - task_name=self._current.name, + task_id=self._current.task_id, task_complete=self._current.is_complete, task_failed=self._current.is_error or bool(errors), ) - correlation_id = self._current.name + correlation_id = self._current.task_id else: task_status = None correlation_id = None @@ -219,7 +245,7 @@ def _report_status( def _on_document(self, name: str, document: Mapping[str, Any]) -> None: if self._current is not None: - correlation_id = self._current.name + correlation_id = self._current.task_id self._data_events.publish( DataEvent(name=name, doc=document), correlation_id ) @@ -293,10 +319,10 @@ def _publish_status_snapshot(self) -> None: else: self._progress_events.publish( ProgressEvent( - task_name=self._current.name, + task_id=self._current.task_id, statuses=self._status_snapshot, ), - self._current.name, + self._current.task_id, ) diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index 2a77a18ba..fdec51202 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -1,6 +1,5 @@ import logging from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Any, Mapping from pydantic import BaseModel, Field, parse_obj_as @@ -65,11 +64,3 @@ def _lookup_params( model = plan.model return parse_obj_as(model, params) - - -@dataclass -class ActiveTask: - name: str - task: Task - is_complete: bool = False - is_error: bool = False diff --git a/src/blueapi/worker/worker.py b/src/blueapi/worker/worker.py index cee67df7a..d935ee6a7 100644 --- a/src/blueapi/worker/worker.py +++ b/src/blueapi/worker/worker.py @@ -1,13 +1,25 @@ from abc import ABC, abstractmethod -from typing import Generic, TypeVar +from typing import Generic, List, TypeVar from blueapi.core import DataEvent, EventStream +from blueapi.utils import BlueapiBaseModel from .event import ProgressEvent, WorkerEvent, WorkerState T = TypeVar("T") +class TrackableTask(BlueapiBaseModel, Generic[T]): + """ + A representation of a task that the worker recognizes + """ + + task_id: str + task: T + is_complete: bool = False + is_error: bool = False + + class Worker(ABC, Generic[T]): """ Entity that takes and runs tasks. Intended to be a central, @@ -15,13 +27,47 @@ class Worker(ABC, Generic[T]): """ @abstractmethod - def submit_task(self, __name: str, __task: T) -> None: + def get_pending_tasks(self) -> List[TrackableTask[T]]: """ - Submit a task to be run + Return a list of all tasks pending on the worker, + any one of which can be triggered with begin_task. + + Returns: + List[TrackableTask[T]]: List of task objects + """ + + @abstractmethod + def clear_task(self, task_id: str) -> bool: + """ + Remove a pending task from the worker Args: - __name (str): A unique name to identify this task - __task (T): The task to run + task_id: The ID of the task to be removed + Returns: + bool: True if the task existed in the first place + """ + + @abstractmethod + def begin_task(self, task_id: str) -> None: + """ + Trigger a pending task. Will fail if the worker is busy. + + Args: + task_id: The ID of the task to be triggered + Throws: + WorkerBusyError: If the worker is already running a task. + KeyError: If the task ID does not exist + """ + + @abstractmethod + def submit_task(self, task: T) -> str: + """ + Submit a task to be run on begin_task + + Args: + task: A description of the task + Returns: + str: A unique ID to refer to this task """ @abstractmethod diff --git a/tests/service/test_rest_api.py b/tests/service/test_rest_api.py index 2bc3c3c46..748b9ff0c 100644 --- a/tests/service/test_rest_api.py +++ b/tests/service/test_rest_api.py @@ -80,9 +80,9 @@ def test_put_plan_submits_task(handler: Handler, client: TestClient) -> None: client.put(f"/task/{task_name}", json=task_json) - task_queue = handler.worker._task_queue.queue # type: ignore - assert len(task_queue) == 1 - assert task_queue[0].task == RunPlan(name=task_name, params=task_json) + assert handler.worker.get_pending_tasks()[0].task == RunPlan( + name=task_name, params=task_json + ) def test_get_state_updates(handler: Handler, client: TestClient) -> None: diff --git a/tests/worker/test_reworker.py b/tests/worker/test_reworker.py index d04fd42ff..bdda025f3 100644 --- a/tests/worker/test_reworker.py +++ b/tests/worker/test_reworker.py @@ -1,4 +1,5 @@ import itertools +import threading from concurrent.futures import Future from typing import Callable, Iterable, List, Optional, TypeVar @@ -7,106 +8,260 @@ from blueapi.config import EnvironmentConfig, Source, SourceKind from blueapi.core import BlueskyContext, EventStream from blueapi.worker import ( + ProgressEvent, RunEngineWorker, RunPlan, Task, TaskStatus, + TrackableTask, Worker, + WorkerBusyError, WorkerEvent, WorkerState, ) -from blueapi.worker.event import ProgressEvent -from blueapi.worker.worker_busy_error import WorkerBusyError + +_SIMPLE_TASK = RunPlan(name="sleep", params={"time": 0.0}) +_LONG_TASK = RunPlan(name="sleep", params={"time": 1.0}) +_INDEFINITE_TASK = RunPlan( + name="set_absolute", + params={"movable": "fake_device", "value": 4.0}, +) + + +class FakeDevice: + event: threading.Event + + @property + def name(self) -> str: + return "fake_device" + + def __init__(self) -> None: + self.event = threading.Event() + + def set(self, pos: float) -> None: + self.event.wait() + self.event.clear() + + +@pytest.fixture +def fake_device() -> FakeDevice: + return FakeDevice() @pytest.fixture -def context() -> BlueskyContext: +def context(fake_device: FakeDevice) -> BlueskyContext: ctx = BlueskyContext() ctx_config = EnvironmentConfig() ctx_config.sources.append( Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices") ) + ctx.device(fake_device) ctx.with_config(ctx_config) return ctx @pytest.fixture -def worker(context: BlueskyContext) -> Iterable[Worker[Task]]: - worker = RunEngineWorker(context) - yield worker - worker.stop() +def inert_worker(context: BlueskyContext) -> Worker[Task]: + return RunEngineWorker(context, stop_timeout=2.0) -def test_stop_doesnt_hang(worker: Worker) -> None: - worker.start() +@pytest.fixture +def worker(inert_worker: Worker[Task]) -> Iterable[Worker[Task]]: + inert_worker.start() + yield inert_worker + inert_worker.stop() -def test_stop_is_idempontent_if_worker_not_started(worker: Worker) -> None: - ... +def test_stop_doesnt_hang(inert_worker: Worker) -> None: + inert_worker.start() + inert_worker.stop() -def test_multi_stop(worker: Worker) -> None: - worker.start() - worker.stop() +def test_stop_is_idempontent_if_worker_not_started(inert_worker: Worker) -> None: + inert_worker.stop() -def test_multi_start(worker: Worker) -> None: - worker.start() +def test_multi_stop(inert_worker: Worker) -> None: + inert_worker.start() + inert_worker.stop() + inert_worker.stop() + + +def test_multi_start(inert_worker: Worker) -> None: + inert_worker.start() with pytest.raises(Exception): - worker.start() - - -def test_runs_plan(worker: Worker) -> None: - assert_run_produces_worker_events( - [ - WorkerEvent( - state=WorkerState.RUNNING, - task_status=TaskStatus( - task_name="test", task_complete=False, task_failed=False - ), - errors=[], - warnings=[], + inert_worker.start() + inert_worker.stop() + + +def test_submit_task(worker: Worker) -> None: + assert worker.get_pending_tasks() == [] + task_id = worker.submit_task(_SIMPLE_TASK) + assert worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id, task=_SIMPLE_TASK) + ] + + +def test_submit_multiple_tasks(worker: Worker) -> None: + assert worker.get_pending_tasks() == [] + task_id_1 = worker.submit_task(_SIMPLE_TASK) + assert worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id_1, task=_SIMPLE_TASK) + ] + task_id_2 = worker.submit_task(_LONG_TASK) + assert worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id_1, task=_SIMPLE_TASK), + TrackableTask(task_id=task_id_2, task=_LONG_TASK), + ] + + +def test_stop_with_task_pending(inert_worker: Worker) -> None: + inert_worker.start() + inert_worker.submit_task(_SIMPLE_TASK) + inert_worker.stop() + + +def test_restart_leaves_task_pending(worker: Worker) -> None: + task_id = worker.submit_task(_SIMPLE_TASK) + assert worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id, task=_SIMPLE_TASK) + ] + worker.stop() + worker.start() + assert worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id, task=_SIMPLE_TASK) + ] + + +def test_submit_before_start_pending(inert_worker: Worker) -> None: + task_id = inert_worker.submit_task(_SIMPLE_TASK) + inert_worker.start() + assert inert_worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id, task=_SIMPLE_TASK) + ] + inert_worker.stop() + assert inert_worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id, task=_SIMPLE_TASK) + ] + + +def test_clear_task(worker: Worker) -> None: + task_id = worker.submit_task(_SIMPLE_TASK) + assert worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id, task=_SIMPLE_TASK) + ] + assert worker.clear_task(task_id) + assert worker.get_pending_tasks() == [] + + +def test_clear_nonexistant_task(worker: Worker) -> None: + assert not worker.clear_task("foo") + + +def test_does_not_allow_simultaneous_running_tasks( + worker: Worker, + fake_device: FakeDevice, +) -> None: + task_ids = [ + worker.submit_task(_INDEFINITE_TASK), + worker.submit_task(_INDEFINITE_TASK), + ] + with pytest.raises(WorkerBusyError): + for task_id in task_ids: + worker.begin_task(task_id) + fake_device.event.set() + + +@pytest.mark.parametrize("num_runs", [0, 1, 2]) +def test_produces_worker_events(worker: Worker, num_runs: int) -> None: + task_ids = [worker.submit_task(_SIMPLE_TASK) for _ in range(num_runs)] + event_sequences = [_sleep_events(task_id) for task_id in task_ids] + + for task_id, events in zip(task_ids, event_sequences): + assert_run_produces_worker_events(events, worker, task_id) + + +def _sleep_events(task_id: str) -> List[WorkerEvent]: + return [ + WorkerEvent( + state=WorkerState.RUNNING, + task_status=TaskStatus( + task_id=task_id, task_complete=False, task_failed=False ), - WorkerEvent( - state=WorkerState.IDLE, - task_status=TaskStatus( - task_name="test", task_complete=False, task_failed=False - ), - errors=[], - warnings=[], + errors=[], + warnings=[], + ), + WorkerEvent( + state=WorkerState.IDLE, + task_status=TaskStatus( + task_id=task_id, task_complete=False, task_failed=False ), - WorkerEvent( - state=WorkerState.IDLE, - task_status=TaskStatus( - task_name="test", task_complete=True, task_failed=False - ), - errors=[], - warnings=[], + errors=[], + warnings=[], + ), + WorkerEvent( + state=WorkerState.IDLE, + task_status=TaskStatus( + task_id=task_id, task_complete=True, task_failed=False ), - ], - worker, + errors=[], + warnings=[], + ), + ] + + +def test_no_additional_progress_events_after_complete(worker: Worker): + """ + See https://github.com/bluesky/ophyd/issues/1115 + """ + + progress_events: List[ProgressEvent] = [] + worker.progress_events.subscribe(lambda event, id: progress_events.append(event)) + + task: Task = RunPlan( + name="move", params={"moves": {"additional_status_device": 5.0}} ) + task_id = worker.submit_task(task) + begin_task_and_wait_until_complete(worker, task_id) + + # Extract all the display_name fields from the events + list_of_dict_keys = [pe.statuses.values() for pe in progress_events] + status_views = [item for sublist in list_of_dict_keys for item in sublist] + display_names = [view.display_name for view in status_views] + + assert "STATUS_AFTER_FINISH" not in display_names + + +# +# Worker helpers +# -def submit_task_and_wait_until_complete( - worker: Worker, task: Task, timeout: float = 5.0 +def assert_run_produces_worker_events( + expected_events: List[WorkerEvent], + worker: Worker, + task_id: str, +) -> None: + assert begin_task_and_wait_until_complete(worker, task_id) == expected_events + + +def begin_task_and_wait_until_complete( + worker: Worker, + task_id: str, + timeout: float = 5.0, ) -> List[WorkerEvent]: events: "Future[List[WorkerEvent]]" = take_events( worker.worker_events, lambda event: event.is_complete(), ) - worker.submit_task("test", task) + worker.begin_task(task_id) return events.result(timeout=timeout) -def assert_run_produces_worker_events( - expected_events: List[WorkerEvent], - worker: Worker, - task: Task = RunPlan(name="sleep", params={"time": 0.0}), -) -> None: - worker.start() - assert submit_task_and_wait_until_complete(worker, task) == expected_events +# +# Event stream helpers +# E = TypeVar("E") @@ -136,34 +291,3 @@ def on_event(event: E, event_id: Optional[str]) -> None: sub = stream.subscribe(on_event) future.add_done_callback(lambda _: stream.unsubscribe(sub)) return future - - -def test_worker_only_accepts_one_task_on_queue(worker: Worker): - worker.start() - task: Task = RunPlan(name="sleep", params={"time": 1.0}) - - worker.submit_task("first_task", task) - with pytest.raises(WorkerBusyError): - worker.submit_task("second_task", task) - - -def test_no_additional_progress_events_after_complete(worker: Worker): - """ - See https://github.com/bluesky/ophyd/issues/1115 - """ - worker.start() - - progress_events: List[ProgressEvent] = [] - worker.progress_events.subscribe(lambda event, id: progress_events.append(event)) - - task: Task = RunPlan( - name="move", params={"moves": {"additional_status_device": 5.0}} - ) - submit_task_and_wait_until_complete(worker, task) - - # Exctract all the display_name fields from the events - list_of_dict_keys = [pe.statuses.values() for pe in progress_events] - status_views = [item for sublist in list_of_dict_keys for item in sublist] - display_names = [view.display_name for view in status_views] - - assert "STATUS_AFTER_FINISH" not in display_names