Skip to content

Commit

Permalink
Add transaction mode to worker (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester authored May 23, 2023
1 parent febaf8e commit 3b4ab8c
Show file tree
Hide file tree
Showing 12 changed files with 351 additions and 133 deletions.
27 changes: 27 additions & 0 deletions docs/developer/explanations/decisions/0002-no-queues.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
2. No Queues
============

Date: 2023-05-22

Status
------

Accepted

Context
-------

In asking whether this service should hold and execute a queue of tasks.

Decision
--------

We will not hold any queues. The worker can execute one task at a time and will return
an error if asked to execute one task while another is running. Queueing should be the
responsibility of a different service.

Consequences
------------

The API must be kept queue-free, although transactions are permitted where the server
caches requests.
2 changes: 1 addition & 1 deletion src/blueapi/cli/amq.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def on_progress_event_wrapper(
task_response = self.app.send_and_receive(
"worker.run", {"name": name, "params": params}, reply_type=TaskResponse
).result(5.0)
task_id = task_response.task_name
task_id = task_response.task_id

if timeout is not None:
complete.wait(timeout)
Expand Down
12 changes: 6 additions & 6 deletions src/blueapi/cli/updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ def _update(self, name: str, view: StatusView) -> None:


class CliEventRenderer:
_task_name: Optional[str]
_task_id: Optional[str]
_pbar_renderer: ProgressBarRenderer

def __init__(
self,
task_name: Optional[str] = None,
task_id: Optional[str] = None,
pbar_renderer: Optional[ProgressBarRenderer] = None,
) -> None:
self._task_name = task_name
self._task_id = task_id
if pbar_renderer is None:
pbar_renderer = ProgressBarRenderer()
self._pbar_renderer = pbar_renderer
Expand All @@ -65,14 +65,14 @@ def on_worker_event(self, event: WorkerEvent) -> None:
print(str(event.state))

def _relates_to_task(self, event: Union[WorkerEvent, ProgressEvent]) -> bool:
if self._task_name is None:
if self._task_id is None:
return True
elif isinstance(event, WorkerEvent):
return (
event.task_status is not None
and event.task_status.task_name == self._task_name
and event.task_status.task_id == self._task_id
)
elif isinstance(event, ProgressEvent):
return event.task_name == self._task_name
return event.task_id == self._task_id
else:
return False
5 changes: 3 additions & 2 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ def submit_task(
handler: Handler = Depends(get_handler),
):
"""Submit a task onto the worker queue."""
handler.worker.submit_task(name, RunPlan(name=name, params=task))
return TaskResponse(task_name=name)
task_id = handler.worker.submit_task(RunPlan(name=name, params=task))
handler.worker.begin_task(task_id)
return TaskResponse(task_id=task_id)


@app.get("/worker/state")
Expand Down
2 changes: 1 addition & 1 deletion src/blueapi/service/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,4 @@ class TaskResponse(BlueapiBaseModel):
Acknowledgement that a task has started, includes its ID
"""

task_name: str = Field(description="Unique identifier for the task")
task_id: str = Field(description="Unique identifier for the task")
5 changes: 4 additions & 1 deletion src/blueapi/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from .multithread import run_worker_in_own_thread
from .reworker import RunEngineWorker
from .task import RunPlan, Task
from .worker import Worker
from .worker import TrackableTask, Worker
from .worker_busy_error import WorkerBusyError

__all__ = [
"run_worker_in_own_thread",
Expand All @@ -15,4 +16,6 @@
"StatusView",
"ProgressEvent",
"TaskStatus",
"TrackableTask",
"WorkerBusyError",
]
4 changes: 2 additions & 2 deletions src/blueapi/worker/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class ProgressEvent(BlueapiBaseModel):
such as moving motors and exposing detectors.
"""

task_name: str
task_id: str
statuses: Mapping[str, StatusView] = Field(default_factory=dict)


Expand All @@ -97,7 +97,7 @@ class TaskStatus(BlueapiBaseModel):
Status of a task the worker is running.
"""

task_name: str
task_id: str
task_complete: bool
task_failed: bool

Expand Down
60 changes: 43 additions & 17 deletions src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
WorkerState,
)
from .multithread import run_worker_in_own_thread
from .task import ActiveTask, Task
from .worker import Worker
from .task import Task
from .worker import TrackableTask, Worker
from .worker_busy_error import WorkerBusyError

LOGGER = logging.getLogger(__name__)
Expand All @@ -47,11 +47,13 @@ class RunEngineWorker(Worker[Task]):
_ctx: BlueskyContext
_stop_timeout: float

_pending_tasks: Dict[str, TrackableTask]

_state: WorkerState
_errors: List[str]
_warnings: List[str]
_task_queue: Queue # type: ignore
_current: Optional[ActiveTask]
_task_channel: Queue # type: ignore
_current: Optional[TrackableTask]
_status_lock: RLock
_status_snapshot: Dict[str, StatusView]
_completed_statuses: Set[str]
Expand All @@ -70,10 +72,12 @@ def __init__(
self._ctx = ctx
self._stop_timeout = stop_timeout

self._pending_tasks = {}

self._state = WorkerState.from_bluesky_state(ctx.run_engine.state)
self._errors = []
self._warnings = []
self._task_queue = Queue(maxsize=1)
self._task_channel = Queue(maxsize=1)
self._current = None
self._worker_events = EventPublisher()
self._progress_events = EventPublisher()
Expand All @@ -85,11 +89,33 @@ def __init__(
self._stopping = Event()
self._stopped = Event()

def submit_task(self, name: str, task: Task) -> None:
active_task = ActiveTask(name, task)
LOGGER.info(f"Submitting: {active_task}")
def clear_task(self, task_id: str) -> bool:
if task_id in self._pending_tasks:
del self._pending_tasks[task_id]
return True
else:
return False

def get_pending_tasks(self) -> List[TrackableTask[Task]]:
return list(self._pending_tasks.values())

def begin_task(self, task_id: str) -> None:
task = self._pending_tasks.get(task_id)
if task is not None:
self._submit_trackable_task(task)
else:
raise KeyError(f"No pending task with ID {task_id}")

def submit_task(self, task: Task) -> str:
task_id: str = str(uuid.uuid4())
trackable_task = TrackableTask(task_id=task_id, task=task)
self._pending_tasks[task_id] = trackable_task
return task_id

def _submit_trackable_task(self, trackable_task: TrackableTask) -> None:
LOGGER.info(f"Submitting: {trackable_task}")
try:
self._task_queue.put_nowait(active_task)
self._task_channel.put_nowait(trackable_task)
except Full:
LOGGER.error("Cannot submit task while another is running")
raise WorkerBusyError("Cannot submit task while another is running")
Expand All @@ -104,7 +130,7 @@ def stop(self) -> None:

# If the worker has not yet started there is nothing to do.
if self._started.is_set():
self._task_queue.put(KillSignal())
self._task_channel.put(KillSignal())
self._stopped.wait(timeout=self._stop_timeout)
# Event timeouts do not actually raise errors
if not self._stopped.is_set():
Expand Down Expand Up @@ -138,8 +164,8 @@ def _cycle_with_error_handling(self) -> None:
def _cycle(self) -> None:
try:
LOGGER.info("Awaiting task")
next_task: Union[ActiveTask, KillSignal] = self._task_queue.get()
if isinstance(next_task, ActiveTask):
next_task: Union[TrackableTask, KillSignal] = self._task_channel.get()
if isinstance(next_task, TrackableTask):
LOGGER.info(f"Got new task: {next_task}")
self._current = next_task # Informing mypy that the task is not None
self._current.task.do_task(self._ctx)
Expand Down Expand Up @@ -200,11 +226,11 @@ def _report_status(
warnings = self._warnings
if self._current is not None:
task_status = TaskStatus(
task_name=self._current.name,
task_id=self._current.task_id,
task_complete=self._current.is_complete,
task_failed=self._current.is_error or bool(errors),
)
correlation_id = self._current.name
correlation_id = self._current.task_id
else:
task_status = None
correlation_id = None
Expand All @@ -219,7 +245,7 @@ def _report_status(

def _on_document(self, name: str, document: Mapping[str, Any]) -> None:
if self._current is not None:
correlation_id = self._current.name
correlation_id = self._current.task_id
self._data_events.publish(
DataEvent(name=name, doc=document), correlation_id
)
Expand Down Expand Up @@ -293,10 +319,10 @@ def _publish_status_snapshot(self) -> None:
else:
self._progress_events.publish(
ProgressEvent(
task_name=self._current.name,
task_id=self._current.task_id,
statuses=self._status_snapshot,
),
self._current.name,
self._current.task_id,
)


Expand Down
9 changes: 0 additions & 9 deletions src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Mapping

from pydantic import BaseModel, Field, parse_obj_as
Expand Down Expand Up @@ -65,11 +64,3 @@ def _lookup_params(

model = plan.model
return parse_obj_as(model, params)


@dataclass
class ActiveTask:
name: str
task: Task
is_complete: bool = False
is_error: bool = False
56 changes: 51 additions & 5 deletions src/blueapi/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,73 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import Generic, List, TypeVar

from blueapi.core import DataEvent, EventStream
from blueapi.utils import BlueapiBaseModel

from .event import ProgressEvent, WorkerEvent, WorkerState

T = TypeVar("T")


class TrackableTask(BlueapiBaseModel, Generic[T]):
"""
A representation of a task that the worker recognizes
"""

task_id: str
task: T
is_complete: bool = False
is_error: bool = False


class Worker(ABC, Generic[T]):
"""
Entity that takes and runs tasks. Intended to be a central,
atomic worker rather than a load distributor
"""

@abstractmethod
def submit_task(self, __name: str, __task: T) -> None:
def get_pending_tasks(self) -> List[TrackableTask[T]]:
"""
Submit a task to be run
Return a list of all tasks pending on the worker,
any one of which can be triggered with begin_task.
Returns:
List[TrackableTask[T]]: List of task objects
"""

@abstractmethod
def clear_task(self, task_id: str) -> bool:
"""
Remove a pending task from the worker
Args:
__name (str): A unique name to identify this task
__task (T): The task to run
task_id: The ID of the task to be removed
Returns:
bool: True if the task existed in the first place
"""

@abstractmethod
def begin_task(self, task_id: str) -> None:
"""
Trigger a pending task. Will fail if the worker is busy.
Args:
task_id: The ID of the task to be triggered
Throws:
WorkerBusyError: If the worker is already running a task.
KeyError: If the task ID does not exist
"""

@abstractmethod
def submit_task(self, task: T) -> str:
"""
Submit a task to be run on begin_task
Args:
task: A description of the task
Returns:
str: A unique ID to refer to this task
"""

@abstractmethod
Expand Down
6 changes: 3 additions & 3 deletions tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def test_put_plan_submits_task(handler: Handler, client: TestClient) -> None:

client.put(f"/task/{task_name}", json=task_json)

task_queue = handler.worker._task_queue.queue # type: ignore
assert len(task_queue) == 1
assert task_queue[0].task == RunPlan(name=task_name, params=task_json)
assert handler.worker.get_pending_tasks()[0].task == RunPlan(
name=task_name, params=task_json
)


def test_get_state_updates(handler: Handler, client: TestClient) -> None:
Expand Down
Loading

0 comments on commit 3b4ab8c

Please sign in to comment.