Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create very simple RPC so the subprocess loads functions #584

Merged
merged 8 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_plans(runner: WorkerDispatcher = Depends(_runner)):
)
def get_plan_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about a plan by its (unique) name."""
return runner.run(interface.get_plan, [name])
return runner.run(interface.get_plan, name)


@app.get("/devices", response_model=DeviceResponse)
Expand All @@ -132,7 +132,7 @@ def get_devices(runner: WorkerDispatcher = Depends(_runner)):
)
def get_device_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about a devices by its (unique) name."""
return runner.run(interface.get_device, [name])
return runner.run(interface.get_device, name)


example_task = Task(name="count", params={"detectors": ["x"]})
Expand All @@ -151,8 +151,8 @@ def submit_task(
):
"""Submit a task to the worker."""
try:
plan_model = runner.run(interface.get_plan, [task.name])
task_id: str = runner.run(interface.submit_task, [task])
plan_model = runner.run(interface.get_plan, task.name)
task_id: str = runner.run(interface.submit_task, task)
response.headers["Location"] = f"{request.url}/{task_id}"
return TaskResponse(task_id=task_id)
except ValidationError as e:
Expand All @@ -176,7 +176,7 @@ def delete_submitted_task(
task_id: str,
runner: WorkerDispatcher = Depends(_runner),
) -> TaskResponse:
return TaskResponse(task_id=runner.run(interface.clear_task, [task_id]))
return TaskResponse(task_id=runner.run(interface.clear_task, task_id))


def validate_task_status(v: str) -> TaskStatusEnum:
Expand Down Expand Up @@ -205,7 +205,7 @@ def get_tasks(
detail="Invalid status query parameter",
) from e

tasks = runner.run(interface.get_tasks_by_status, [desired_status])
tasks = runner.run(interface.get_tasks_by_status, desired_status)
else:
tasks = runner.run(interface.get_tasks)
return TasksListResponse(tasks=tasks)
Expand All @@ -227,7 +227,7 @@ def set_active_task(
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="Worker already active"
)
runner.run(interface.begin_task, [task])
runner.run(interface.begin_task, task)
return task


Expand All @@ -240,7 +240,7 @@ def get_task(
runner: WorkerDispatcher = Depends(_runner),
) -> TrackableTask:
"""Retrieve a task"""
task = runner.run(interface.get_task_by_id, [task_id])
task = runner.run(interface.get_task_by_id, task_id)
if task is None:
raise KeyError
return task
Expand Down Expand Up @@ -313,17 +313,15 @@ def set_state(
and new_state in _ALLOWED_TRANSITIONS[current_state]
):
if new_state == WorkerState.PAUSED:
runner.run(interface.pause_worker, [state_change_request.defer])
runner.run(interface.pause_worker, state_change_request.defer)
elif new_state == WorkerState.RUNNING:
runner.run(interface.resume_worker)
elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}:
try:
runner.run(
interface.cancel_active_task,
[
state_change_request.new_state is WorkerState.ABORTING,
state_change_request.reason,
],
state_change_request.new_state is WorkerState.ABORTING,
state_change_request.reason,
)
except TransitionError:
response.status_code = status.HTTP_400_BAD_REQUEST
Expand Down
56 changes: 41 additions & 15 deletions src/blueapi/service/runner.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
import logging
import signal
from collections.abc import Callable, Iterable
from collections.abc import Callable
from importlib import import_module
from multiprocessing import Pool, set_start_method
from multiprocessing.pool import Pool as PoolClass
from typing import Any
from typing import Any, ParamSpec, TypeVar

from blueapi.config import ApplicationConfig
from blueapi.service.interface import (
setup,
teardown,
)
from blueapi.service.interface import setup, teardown
from blueapi.service.model import EnvironmentResponse

# The default multiprocessing start method is fork
set_start_method("spawn", force=True)

LOGGER = logging.getLogger(__name__)

P = ParamSpec("P")
T = TypeVar("T")


def _init_worker():
# Replace sigint to allow subprocess to be terminated
Expand Down Expand Up @@ -56,7 +57,7 @@ def start(self):
try:
if self._use_subprocess:
self._subprocess = Pool(initializer=_init_worker, processes=1)
self.run(setup, [self._config])
self.run(setup, self._config)
self._state = EnvironmentResponse(initialized=True)
except Exception as e:
self._state = EnvironmentResponse(
Expand All @@ -82,21 +83,25 @@ def stop(self):
)
LOGGER.exception(e)

def run(self, function: Callable, arguments: Iterable | None = None) -> Any:
arguments = arguments or []
def run(self, function: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
if self._use_subprocess:
return self._run_in_subprocess(function, arguments)
return self._run_in_subprocess(function, *args, **kwargs)
else:
return function(*arguments)
return function(*args, **kwargs)

def _run_in_subprocess(
self,
function: Callable,
arguments: Iterable,
) -> Any:
function: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> T:
if self._subprocess is None:
raise InvalidRunnerStateError("Subprocess runner has not been started")
return self._subprocess.apply(function, arguments)
if not (hasattr(function, "__name__") and hasattr(function, "__module__")):
raise RpcError(f"Target {function} invalid for running in subprocess")
return self._subprocess.apply(
_rpc, (function.__module__, function.__name__, *args), kwargs
)

@property
def state(self) -> EnvironmentResponse:
Expand All @@ -106,3 +111,24 @@ def state(self) -> EnvironmentResponse:
class InvalidRunnerStateError(Exception):
def __init__(self, message):
super().__init__(message)


class RpcError(Exception): ...


def _rpc(
module_name: str, function_name: str, *args: P.args, **kwargs: P.kwargs
) -> Any:
mod = import_module(module_name)
func: Callable[P, T] = _validate_function(
mod.__dict__.get(function_name, None), function_name
)
return func(*args, **kwargs)
DiamondJoseph marked this conversation as resolved.
Show resolved Hide resolved


def _validate_function(func: Any, function_name: str) -> Callable:
if func is None:
raise RpcError(f"{function_name}: No such function in subprocess API")
elif not callable(func):
raise RpcError(f"{function_name}: Object in subprocess is not a function")
return func
8 changes: 7 additions & 1 deletion tests/core/fake_device_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, NonCallableMock

from ophyd import EpicsMotor

Expand Down Expand Up @@ -29,3 +29,9 @@ def _mock_with_name(name: str) -> MagicMock:
mock = MagicMock()
mock.name = name
return mock


FOO = NonCallableMock()
BAR = NonCallableMock()
BAR.__name__ = "BAR"
BAR.__module__ = fake_motor_bundle_a.__module__
2 changes: 1 addition & 1 deletion tests/messaging/test_stomptemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def server(ctx: MessageContext, message: message_type) -> None: # type: ignore
reply = template.send_and_receive(test_queue, message, message_type).result(
timeout=_TIMEOUT
)
if type(message) == np.ndarray:
if type(message) is np.ndarray:
message = message.tolist()
assert reply == message

Expand Down
73 changes: 61 additions & 12 deletions tests/service/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,31 @@

from blueapi.service import interface
from blueapi.service.model import EnvironmentResponse
from blueapi.service.runner import InvalidRunnerStateError, WorkerDispatcher
from blueapi.service.runner import (
InvalidRunnerStateError,
RpcError,
WorkerDispatcher,
)


def test_initialize():
runner = WorkerDispatcher()
@pytest.fixture
def local_runner():
return WorkerDispatcher(use_subprocess=False)


@pytest.fixture
def runner():
return WorkerDispatcher()


@pytest.fixture
def started_runner(runner: WorkerDispatcher):
runner.start()
yield runner
runner.stop()


def test_initialize(runner: WorkerDispatcher):
assert not runner.state.initialized
runner.start()
assert runner.state.initialized
Expand All @@ -19,23 +39,20 @@ def test_initialize():
assert not runner.state.initialized


def test_reload():
runner = WorkerDispatcher()
def test_reload(runner: WorkerDispatcher):
runner.start()
assert runner.state.initialized
runner.reload()
assert runner.state.initialized
runner.stop()


def test_raises_if_used_before_started():
runner = WorkerDispatcher()
def test_raises_if_used_before_started(runner: WorkerDispatcher):
with pytest.raises(InvalidRunnerStateError):
assert runner.run(interface.get_plans) is None
runner.run(interface.get_plans)


def test_error_on_runner_setup():
runner = WorkerDispatcher(use_subprocess=False)
def test_error_on_runner_setup(local_runner: WorkerDispatcher):
expected_state = EnvironmentResponse(
initialized=False,
error_message="Intentional start_worker exception",
Expand All @@ -48,8 +65,8 @@ def test_error_on_runner_setup():
# Calling reload here instead of start also indirectly
# tests that stop() doesn't raise if there is no error message
# and the runner is not yet initialised
runner.reload()
state = runner.state
local_runner.reload()
state = local_runner.state
assert state == expected_state


Expand Down Expand Up @@ -85,3 +102,35 @@ def test_can_reload_after_an_error(pool_mock: MagicMock):
runner.reload()

assert runner.state == EnvironmentResponse(initialized=True, error_message=None)


def test_clear_message_for_not_found(started_runner: WorkerDispatcher):
from tests.core.fake_device_module import fake_motor_y

# Change in this process not reflected in subprocess
fake_motor_y.__name__ = "not_exported"

with pytest.raises(
RpcError, match="not_exported: No such function in subprocess API"
):
started_runner.run(fake_motor_y)


def test_clear_message_for_non_function(started_runner: WorkerDispatcher):
from tests.core.fake_device_module import FOO

with pytest.raises(
RpcError,
match="Target <NonCallableMock id='[0-9]+'> invalid for running in subprocess",
):
started_runner.run(FOO)


def test_clear_message_for_invalid_function(started_runner: WorkerDispatcher):
from tests.core.fake_device_module import BAR

with pytest.raises(
RpcError,
match="BAR: Object in subprocess is not a function",
):
started_runner.run(BAR)
Loading