Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create very simple RPC so the subprocess loads functions #584

Merged
merged 8 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_plans(runner: WorkerDispatcher = Depends(_runner)):
)
def get_plan_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about a plan by its (unique) name."""
return runner.run(interface.get_plan, [name])
return runner.run(interface.get_plan, name)


@app.get("/devices", response_model=DeviceResponse)
Expand All @@ -132,7 +132,7 @@ def get_devices(runner: WorkerDispatcher = Depends(_runner)):
)
def get_device_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about a devices by its (unique) name."""
return runner.run(interface.get_device, [name])
return runner.run(interface.get_device, name)


example_task = Task(name="count", params={"detectors": ["x"]})
Expand All @@ -151,8 +151,8 @@ def submit_task(
):
"""Submit a task to the worker."""
try:
plan_model = runner.run(interface.get_plan, [task.name])
task_id: str = runner.run(interface.submit_task, [task])
plan_model = runner.run(interface.get_plan, task.name)
task_id: str = runner.run(interface.submit_task, task)
response.headers["Location"] = f"{request.url}/{task_id}"
return TaskResponse(task_id=task_id)
except ValidationError as e:
Expand All @@ -176,7 +176,7 @@ def delete_submitted_task(
task_id: str,
runner: WorkerDispatcher = Depends(_runner),
) -> TaskResponse:
return TaskResponse(task_id=runner.run(interface.clear_task, [task_id]))
return TaskResponse(task_id=runner.run(interface.clear_task, task_id))


def validate_task_status(v: str) -> TaskStatusEnum:
Expand Down Expand Up @@ -205,7 +205,7 @@ def get_tasks(
detail="Invalid status query parameter",
) from e

tasks = runner.run(interface.get_tasks_by_status, [desired_status])
tasks = runner.run(interface.get_tasks_by_status, desired_status)
else:
tasks = runner.run(interface.get_tasks)
return TasksListResponse(tasks=tasks)
Expand All @@ -227,7 +227,7 @@ def set_active_task(
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="Worker already active"
)
runner.run(interface.begin_task, [task])
runner.run(interface.begin_task, task)
return task


Expand All @@ -240,7 +240,7 @@ def get_task(
runner: WorkerDispatcher = Depends(_runner),
) -> TrackableTask:
"""Retrieve a task"""
task = runner.run(interface.get_task_by_id, [task_id])
task = runner.run(interface.get_task_by_id, task_id)
if task is None:
raise KeyError
return task
Expand Down Expand Up @@ -313,17 +313,15 @@ def set_state(
and new_state in _ALLOWED_TRANSITIONS[current_state]
):
if new_state == WorkerState.PAUSED:
runner.run(interface.pause_worker, [state_change_request.defer])
runner.run(interface.pause_worker, state_change_request.defer)
elif new_state == WorkerState.RUNNING:
runner.run(interface.resume_worker)
elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}:
try:
runner.run(
interface.cancel_active_task,
[
state_change_request.new_state is WorkerState.ABORTING,
state_change_request.reason,
],
state_change_request.new_state is WorkerState.ABORTING,
state_change_request.reason,
)
except TransitionError:
response.status_code = status.HTTP_400_BAD_REQUEST
Expand Down
82 changes: 67 additions & 15 deletions src/blueapi/service/runner.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import inspect
import logging
import signal
from collections.abc import Callable, Iterable
from collections.abc import Callable
from importlib import import_module
from multiprocessing import Pool, set_start_method
from multiprocessing.pool import Pool as PoolClass
from typing import Any
from typing import Any, ParamSpec, TypeVar

from blueapi.config import ApplicationConfig
from blueapi.service.interface import (
setup,
teardown,
)
from blueapi.service.interface import setup, teardown
from blueapi.service.model import EnvironmentResponse

# The default multiprocessing start method is fork
set_start_method("spawn", force=True)

LOGGER = logging.getLogger(__name__)

P = ParamSpec("P")
T = TypeVar("T")


def _init_worker():
# Replace sigint to allow subprocess to be terminated
Expand Down Expand Up @@ -56,7 +58,7 @@
try:
if self._use_subprocess:
self._subprocess = Pool(initializer=_init_worker, processes=1)
self.run(setup, [self._config])
self.run(setup, self._config)
self._state = EnvironmentResponse(initialized=True)
except Exception as e:
self._state = EnvironmentResponse(
Expand All @@ -82,21 +84,39 @@
)
LOGGER.exception(e)

def run(self, function: Callable, arguments: Iterable | None = None) -> Any:
arguments = arguments or []
def run(self, function: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
if self._use_subprocess:
return self._run_in_subprocess(function, arguments)
return self._run_in_subprocess(function, *args, **kwargs)
else:
return function(*arguments)
return function(*args, **kwargs)

def _run_in_subprocess(
self,
function: Callable,
arguments: Iterable,
) -> Any:
function: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> T:
if self._subprocess is None:
raise InvalidRunnerStateError("Subprocess runner has not been started")
return self._subprocess.apply(function, arguments)
if not (hasattr(function, "__name__") and hasattr(function, "__module__")):
raise RpcError(f"{function} is anonymous, cannot be run in subprocess")
if not callable(function):
raise RpcError(f"{function} is not Callable, cannot be run in subprocess")
try:
return_type = inspect.signature(function).return_annotation
except TypeError:
return_type = None

Check warning on line 108 in src/blueapi/service/runner.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/runner.py#L107-L108

Added lines #L107 - L108 were not covered by tests

return self._subprocess.apply(
_rpc,
(
function.__module__,
function.__name__,
return_type,
*args,
),
kwargs,
)

@property
def state(self) -> EnvironmentResponse:
Expand All @@ -106,3 +126,35 @@
class InvalidRunnerStateError(Exception):
def __init__(self, message):
super().__init__(message)


class RpcError(Exception): ...


def _rpc(
module_name: str,
function_name: str,
expected_type: type[T] | None,
*args: Any,
**kwargs: Any,
) -> T:
mod = import_module(module_name)
func: Callable[P, T] = _validate_function(
mod.__dict__.get(function_name, None), function_name
)
value = func(*args, **kwargs)
if expected_type is None or isinstance(value, expected_type):
return value
else:
raise TypeError(
f"{function_name} returned value of type {type(value)}"
+ f" which is incompatible with expected {expected_type}"
)


def _validate_function(func: Any, function_name: str) -> Callable:
if func is None:
raise RpcError(f"{function_name}: No such function in subprocess API")
elif not callable(func):
raise RpcError(f"{function_name}: Object in subprocess is not a function")
return func
16 changes: 15 additions & 1 deletion tests/core/fake_device_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, NonCallableMock

from ophyd import EpicsMotor

Expand Down Expand Up @@ -26,6 +26,20 @@ def fake_motor_bundle_a(


def _mock_with_name(name: str) -> MagicMock:
# mock.name must return str, cannot MagicMock(name=name)
mock = MagicMock()
mock.name = name
return mock


def wrong_return_type() -> int:
return "0" # type: ignore


fetchable_non_callable = NonCallableMock()
fetchable_callable = MagicMock(return_value="string")

fetchable_non_callable.__name__ = "fetchable_non_callable"
fetchable_non_callable.__module__ = fake_motor_bundle_a.__module__
fetchable_callable.__name__ = "fetchable_callable"
fetchable_callable.__module__ = fake_motor_bundle_a.__module__
2 changes: 1 addition & 1 deletion tests/messaging/test_stomptemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def server(ctx: MessageContext, message: message_type) -> None: # type: ignore
reply = template.send_and_receive(test_queue, message, message_type).result(
timeout=_TIMEOUT
)
if type(message) == np.ndarray:
if type(message) is np.ndarray:
message = message.tolist()
assert reply == message

Expand Down
100 changes: 88 additions & 12 deletions tests/service/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,31 @@

from blueapi.service import interface
from blueapi.service.model import EnvironmentResponse
from blueapi.service.runner import InvalidRunnerStateError, WorkerDispatcher
from blueapi.service.runner import (
InvalidRunnerStateError,
RpcError,
WorkerDispatcher,
)


def test_initialize():
runner = WorkerDispatcher()
@pytest.fixture
def local_runner():
return WorkerDispatcher(use_subprocess=False)


@pytest.fixture
def runner():
return WorkerDispatcher()


@pytest.fixture
def started_runner(runner: WorkerDispatcher):
runner.start()
yield runner
runner.stop()


def test_initialize(runner: WorkerDispatcher):
assert not runner.state.initialized
runner.start()
assert runner.state.initialized
Expand All @@ -19,23 +39,20 @@ def test_initialize():
assert not runner.state.initialized


def test_reload():
runner = WorkerDispatcher()
def test_reload(runner: WorkerDispatcher):
runner.start()
assert runner.state.initialized
runner.reload()
assert runner.state.initialized
runner.stop()


def test_raises_if_used_before_started():
runner = WorkerDispatcher()
def test_raises_if_used_before_started(runner: WorkerDispatcher):
with pytest.raises(InvalidRunnerStateError):
assert runner.run(interface.get_plans) is None
runner.run(interface.get_plans)


def test_error_on_runner_setup():
runner = WorkerDispatcher(use_subprocess=False)
def test_error_on_runner_setup(local_runner: WorkerDispatcher):
expected_state = EnvironmentResponse(
initialized=False,
error_message="Intentional start_worker exception",
Expand All @@ -48,8 +65,8 @@ def test_error_on_runner_setup():
# Calling reload here instead of start also indirectly
# tests that stop() doesn't raise if there is no error message
# and the runner is not yet initialised
runner.reload()
state = runner.state
local_runner.reload()
state = local_runner.state
assert state == expected_state


Expand Down Expand Up @@ -85,3 +102,62 @@ def test_can_reload_after_an_error(pool_mock: MagicMock):
runner.reload()

assert runner.state == EnvironmentResponse(initialized=True, error_message=None)


def test_function_not_findable_on_subprocess(started_runner: WorkerDispatcher):
from tests.core.fake_device_module import fake_motor_y

# Valid target on main but not sub process
# Change in this process not reflected in subprocess
fake_motor_y.__name__ = "not_exported"

with pytest.raises(
RpcError, match="not_exported: No such function in subprocess API"
):
started_runner.run(fake_motor_y)


def test_non_callable_excepts_in_main_process(started_runner: WorkerDispatcher):
# Not a valid target on main or sub process
from tests.core.fake_device_module import fetchable_non_callable

with pytest.raises(
RpcError,
match="<NonCallableMock id='[0-9]+'> is not Callable, "
+ "cannot be run in subprocess",
):
started_runner.run(fetchable_non_callable)


def test_non_callable_excepts_in_sub_process(started_runner: WorkerDispatcher):
# Valid target on main but finds non-callable in sub process
from tests.core.fake_device_module import fetchable_callable, fetchable_non_callable

fetchable_callable.__name__ = fetchable_non_callable.__name__

with pytest.raises(
RpcError,
match="fetchable_non_callable: Object in subprocess is not a function",
):
started_runner.run(fetchable_callable)


def test_clear_message_for_anonymous_function(started_runner: WorkerDispatcher):
non_fetchable_callable = MagicMock()

with pytest.raises(
RpcError,
match="<MagicMock id='[0-9]+'> is anonymous, cannot be run in subprocess",
):
started_runner.run(non_fetchable_callable)


def test_clear_message_for_wrong_return(started_runner: WorkerDispatcher):
from tests.core.fake_device_module import wrong_return_type

with pytest.raises(
TypeError,
match="wrong_return_type returned value of type <class 'str'>"
+ " which is incompatible with expected <class 'int'>",
):
started_runner.run(wrong_return_type)
Loading