diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 2f9075f9b..2aef04413 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -1,4 +1,6 @@ +import logging from collections.abc import Mapping +from functools import lru_cache from typing import Any from blueapi.config import ApplicationConfig @@ -6,11 +8,7 @@ from blueapi.core.event import EventStream from blueapi.messaging.base import MessagingTemplate from blueapi.messaging.stomptemplate import StompMessagingTemplate -from blueapi.service.model import ( - DeviceModel, - PlanModel, - WorkerTask, -) +from blueapi.service.model import DeviceModel, PlanModel, WorkerTask from blueapi.worker.event import TaskStatusEnum, WorkerState from blueapi.worker.reworker import TaskWorker from blueapi.worker.task import Task @@ -20,61 +18,77 @@ context and worker""" -class InitialisationException(Exception): - pass +_CONFIG: ApplicationConfig = ApplicationConfig() -class _Singleton: - context: BlueskyContext - worker: Worker - messaging_template: MessagingTemplate | None = None - initialized = False +def config() -> ApplicationConfig: + return _CONFIG -def start_worker( - config: ApplicationConfig, - bluesky_context: BlueskyContext | None = None, - worker: TaskWorker | None = None, -) -> None: - """Creates and starts a worker with supplied config""" - if _Singleton.initialized: - raise InitialisationException( - "Worker is already running. To reload call stop first" - ) - if bluesky_context is None: - _Singleton.context = BlueskyContext() - _Singleton.context.with_config(config.env) - else: - _Singleton.context = bluesky_context +def set_config(new_config: ApplicationConfig): + global _CONFIG - if worker is None: - _Singleton.worker = TaskWorker( - _Singleton.context, - broadcast_statuses=config.env.events.broadcast_status_events, - ) - else: - _Singleton.worker = worker - if config.stomp is not None: - _Singleton.messaging_template = StompMessagingTemplate.autoconfigured( - config.stomp - ) + _CONFIG = new_config - # Start worker and setup events - _Singleton.worker.start() - if _Singleton.messaging_template is not None: - event_topic = _Singleton.messaging_template.destinations.topic( - "public.worker.event" - ) + +@lru_cache +def context() -> BlueskyContext: + ctx = BlueskyContext() + ctx.with_config(config().env) + return ctx + + +@lru_cache +def worker() -> Worker: + worker = TaskWorker( + context(), + broadcast_statuses=config().env.events.broadcast_status_events, + ) + worker.start() + return worker + + +@lru_cache +def messaging_template() -> MessagingTemplate | None: + stomp_config = config().stomp + if stomp_config is not None: + template = StompMessagingTemplate.autoconfigured(stomp_config) + + task_worker = worker() + event_topic = template.destinations.topic("public.worker.event") _publish_event_streams( { - _Singleton.worker.worker_events: event_topic, - _Singleton.worker.progress_events: event_topic, - _Singleton.worker.data_events: event_topic, + task_worker.worker_events: event_topic, + task_worker.progress_events: event_topic, + task_worker.data_events: event_topic, } ) - _Singleton.messaging_template.connect() - _Singleton.initialized = True + template.connect() + return template + else: + return None + + +def setup(config: ApplicationConfig) -> None: + """Creates and starts a worker with supplied config""" + + set_config(config) + + # Eagerly initialize worker and messaging connection + + logging.basicConfig(level=config.logging.level) + worker() + messaging_template() + + +def teardown() -> None: + worker().stop() + if (template := messaging_template()) is not None: + template.disconnect() + context.cache_clear() + worker.cache_clear() + messaging_template.cache_clear() def _publish_event_streams(streams_to_destinations: Mapping[EventStream, str]) -> None: @@ -84,130 +98,87 @@ def _publish_event_streams(streams_to_destinations: Mapping[EventStream, str]) - def _publish_event_stream(stream: EventStream, destination: str) -> None: def forward_message(event: Any, correlation_id: str | None) -> None: - if _Singleton.messaging_template is not None: - _Singleton.messaging_template.send(destination, event, None, correlation_id) + if (template := messaging_template()) is not None: + template.send(destination, event, None, correlation_id) stream.subscribe(forward_message) -def stop_worker() -> None: - if not _Singleton.initialized: - raise InitialisationException( - "Cannot stop worker as it hasn't been started yet" - ) - _Singleton.initialized = False - _Singleton.worker.stop() - if ( - _Singleton.messaging_template is not None - and _Singleton.messaging_template.is_connected() - ): - _Singleton.messaging_template.disconnect() - - def get_plans() -> list[PlanModel]: """Get all available plans in the BlueskyContext""" - _ensure_worker_started() - return [PlanModel.from_plan(plan) for plan in _Singleton.context.plans.values()] + return [PlanModel.from_plan(plan) for plan in context().plans.values()] def get_plan(name: str) -> PlanModel: """Get plan by name from the BlueskyContext""" - _ensure_worker_started() - return PlanModel.from_plan(_Singleton.context.plans[name]) + return PlanModel.from_plan(context().plans[name]) def get_devices() -> list[DeviceModel]: """Get all available devices in the BlueskyContext""" - _ensure_worker_started() - return [ - DeviceModel.from_device(device) - for device in _Singleton.context.devices.values() - ] + return [DeviceModel.from_device(device) for device in context().devices.values()] def get_device(name: str) -> DeviceModel: """Retrieve device by name from the BlueskyContext""" - _ensure_worker_started() - return DeviceModel.from_device(_Singleton.context.devices[name]) + return DeviceModel.from_device(context().devices[name]) def submit_task(task: Task) -> str: """Submit a task to be run on begin_task""" - _ensure_worker_started() - return _Singleton.worker.submit_task(task) + return worker().submit_task(task) def clear_task(task_id: str) -> str: """Remove a task from the worker""" - _ensure_worker_started() - return _Singleton.worker.clear_task(task_id) + return worker().clear_task(task_id) def begin_task(task: WorkerTask) -> WorkerTask: """Trigger a task. Will fail if the worker is busy""" - _ensure_worker_started() if task.task_id is not None: - _Singleton.worker.begin_task(task.task_id) + worker().begin_task(task.task_id) return task def get_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]: """Retrieve a list of tasks based on their status.""" - _ensure_worker_started() - return _Singleton.worker.get_tasks_by_status(status) + return worker().get_tasks_by_status(status) def get_active_task() -> TrackableTask | None: """Task the worker is currently running""" - _ensure_worker_started() - return _Singleton.worker.get_active_task() + return worker().get_active_task() def get_worker_state() -> WorkerState: """State of the worker""" - _ensure_worker_started() - return _Singleton.worker.state + return worker().state def pause_worker(defer: bool | None) -> None: """Command the worker to pause""" - _ensure_worker_started() - _Singleton.worker.pause(defer) + worker().pause(defer) def resume_worker() -> None: """Command the worker to resume""" - _ensure_worker_started() - _Singleton.worker.resume() + worker().resume() def cancel_active_task(failure: bool, reason: str | None) -> str: """Remove the currently active task from the worker if there is one Returns the task_id of the active task""" - _ensure_worker_started() - return _Singleton.worker.cancel_active_task(failure, reason) + return worker().cancel_active_task(failure, reason) def get_tasks() -> list[TrackableTask]: """Return a list of all tasks on the worker, any one of which can be triggered with begin_task""" - _ensure_worker_started() - return _Singleton.worker.get_tasks() + return worker().get_tasks() def get_task_by_id(task_id: str) -> TrackableTask | None: """Returns a task matching the task ID supplied, if the worker knows of it""" - _ensure_worker_started() - return _Singleton.worker.get_task_by_id(task_id) - - -def get_state() -> bool: - """Initialization state""" - return _Singleton.initialized - - -def _ensure_worker_started() -> None: - if _Singleton.initialized: - return - raise InitialisationException("Worker must be stared before it is used") + return worker().get_task_by_id(task_id) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index b8670b277..f386483c4 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -100,12 +100,8 @@ async def delete_environment( ) -> EnvironmentResponse: """Delete the current environment, causing internal components to be reloaded.""" - def restart_runner(runner: WorkerDispatcher): - runner.stop() - runner.start() - if runner.state.initialized or runner.state.error_message is not None: - background_tasks.add_task(restart_runner, runner) + background_tasks.add_task(runner.reload) return EnvironmentResponse(initialized=False) diff --git a/src/blueapi/service/runner.py b/src/blueapi/service/runner.py index 055dd3130..8d6fb6b2d 100644 --- a/src/blueapi/service/runner.py +++ b/src/blueapi/service/runner.py @@ -6,7 +6,10 @@ from typing import Any from blueapi.config import ApplicationConfig -from blueapi.service.interface import InitialisationException, start_worker, stop_worker +from blueapi.service.interface import ( + setup, + teardown, +) from blueapi.service.model import EnvironmentResponse # The default multiprocessing start method is fork @@ -32,51 +35,52 @@ class WorkerDispatcher: _state: EnvironmentResponse def __init__( - self, config: ApplicationConfig | None = None, use_subprocess: bool = True + self, + config: ApplicationConfig | None = None, + use_subprocess: bool = True, ) -> None: self._config = config or ApplicationConfig() self._subprocess = None self._use_subprocess = use_subprocess - self._state = EnvironmentResponse(initialized=False) + self._state = EnvironmentResponse( + initialized=False, + ) + + def reload(self): + """Reload the subprocess to account for any changes in python modules""" + self.stop() + self.start() + LOGGER.info("Runner reloaded") def start(self): - if self._subprocess is None and self._use_subprocess: - self._subprocess = Pool(initializer=_init_worker, processes=1) - self._subprocess.apply( - logging.basicConfig, kwds={"level": self._config.logging.level} - ) try: - self.run(start_worker, [self._config]) + if self._use_subprocess: + self._subprocess = Pool(initializer=_init_worker, processes=1) + self.run(setup, [self._config]) + self._state = EnvironmentResponse(initialized=True) except Exception as e: self._state = EnvironmentResponse( - initialized=False, error_message=f"Error configuring blueapi: {e}" + initialized=False, + error_message=str(e), ) - LOGGER.exception(self._state.error_message) - return - self._state = EnvironmentResponse(initialized=True) + LOGGER.exception(e) def stop(self): - if self._subprocess is not None: - self._state = EnvironmentResponse(initialized=False) - try: - self._subprocess.apply(stop_worker) - except InitialisationException: - # There was an error initialising - pass - self._subprocess.close() - self._subprocess.join() - self._subprocess = None - if (not self._use_subprocess) and ( - self._state.initialized or self._state.error_message - ): - self._state = EnvironmentResponse(initialized=False) - stop_worker() - - def reload_context(self): - """Reload the subprocess to account for any changes in python modules""" - self.stop() - self.start() - LOGGER.info("Context reloaded") + try: + self.run(teardown) + if self._subprocess is not None: + self._subprocess.close() + self._subprocess.join() + self._state = EnvironmentResponse( + initialized=False, + error_message=self._state.error_message, + ) + except Exception as e: + self._state = EnvironmentResponse( + initialized=False, + error_message=str(e), + ) + LOGGER.exception(e) def run(self, function: Callable, arguments: Iterable | None = None) -> Any: arguments = arguments or [] @@ -85,9 +89,13 @@ def run(self, function: Callable, arguments: Iterable | None = None) -> Any: else: return function(*arguments) - def _run_in_subprocess(self, function: Callable, arguments: Iterable) -> Any: + def _run_in_subprocess( + self, + function: Callable, + arguments: Iterable, + ) -> Any: if self._subprocess is None: - raise RunnerNotStartedError("Subprocess runner has not been started") + raise InvalidRunnerStateError("Subprocess runner has not been started") return self._subprocess.apply(function, arguments) @property @@ -95,6 +103,6 @@ def state(self) -> EnvironmentResponse: return self._state -class RunnerNotStartedError(Exception): +class InvalidRunnerStateError(Exception): def __init__(self, message): super().__init__(message) diff --git a/tests/service/test_interface.py b/tests/service/test_interface.py index 82d35cf84..bab9a1826 100644 --- a/tests/service/test_interface.py +++ b/tests/service/test_interface.py @@ -22,33 +22,7 @@ def ensure_worker_stopped(): of an assertion error. The start_worker method is not managed by a fixture as some of the tests require it to be customised.""" yield - if interface.get_state(): - interface.stop_worker() - - -def test_start_worker_raises_if_already_started(): - interface.start_worker(ApplicationConfig()) - with pytest.raises(interface.InitialisationException): - interface.start_worker(ApplicationConfig()) - - -def test_stop_worker_raises_if_already_started(): - interface.start_worker(ApplicationConfig()) - interface.stop_worker() - with pytest.raises(interface.InitialisationException): - interface.stop_worker() - - -def test_exception_if_used_before_started(): - with pytest.raises(interface.InitialisationException): - interface.get_active_task() - - -def test_stomp_config(): - stomp_config = StompConfig() - config = ApplicationConfig() - config.stomp = stomp_config - interface.start_worker(config) + interface.teardown() def my_plan() -> MsgGenerator: @@ -61,17 +35,18 @@ def my_second_plan(repeats: int) -> MsgGenerator: yield from {} -def test_get_plans(): +@patch("blueapi.service.interface.context") +def test_get_plans(context_mock: MagicMock): context = BlueskyContext() context.plan(my_plan) context.plan(my_second_plan) - interface.start_worker(ApplicationConfig(), bluesky_context=context) + context_mock.return_value = context assert interface.get_plans() == [ PlanModel( name="my_plan", description="My plan does cool stuff.", - parameter_schema={ + schema={ "title": "my_plan", "type": "object", "properties": {}, @@ -81,7 +56,7 @@ def test_get_plans(): PlanModel( name="my_second_plan", description="Plan B.", - parameter_schema={ + schema={ "title": "my_second_plan", "type": "object", "properties": {"repeats": {"title": "Repeats", "type": "integer"}}, @@ -92,16 +67,17 @@ def test_get_plans(): ] -def test_get_plan(): +@patch("blueapi.service.interface.context") +def test_get_plan(context_mock: MagicMock): context = BlueskyContext() context.plan(my_plan) context.plan(my_second_plan) - interface.start_worker(ApplicationConfig(), bluesky_context=context) + context_mock.return_value = context assert interface.get_plan("my_plan") == PlanModel( name="my_plan", description="My plan does cool stuff.", - parameter_schema={ + schema={ "title": "my_plan", "type": "object", "properties": {}, @@ -118,11 +94,12 @@ class MyDevice: name: str -def test_get_devices(): +@patch("blueapi.service.interface.context") +def test_get_devices(context_mock: MagicMock): context = BlueskyContext() context.device(MyDevice(name="my_device")) context.device(SynAxis(name="my_axis")) - interface.start_worker(ApplicationConfig(), bluesky_context=context) + context_mock.return_value = context assert interface.get_devices() == [ DeviceModel(name="my_device", protocols=["HasName"]), @@ -146,10 +123,12 @@ def test_get_devices(): ] -def test_get_device(): +@patch("blueapi.service.interface.context") +def test_get_device(context_mock: MagicMock): context = BlueskyContext() context.device(MyDevice(name="my_device")) - interface.start_worker(ApplicationConfig(), bluesky_context=context) + context_mock.return_value = context + assert interface.get_device("my_device") == DeviceModel( name="my_device", protocols=["HasName"] ) @@ -158,11 +137,12 @@ def test_get_device(): assert interface.get_device("non_existing_device") -def test_submit_task(): +@patch("blueapi.service.interface.context") +def test_submit_task(context_mock: MagicMock): context = BlueskyContext() context.plan(my_plan) task = Task(name="my_plan") - interface.start_worker(ApplicationConfig(), bluesky_context=context) + context_mock.return_value = context mock_uuid_value = "8dfbb9c2-7a15-47b6-bea8-b6b77c31d3d9" with patch.object(uuid, "uuid4") as uuid_mock: uuid_mock.return_value = uuid.UUID(mock_uuid_value) @@ -170,11 +150,12 @@ def test_submit_task(): assert task_uuid == mock_uuid_value -def test_clear_task(): +@patch("blueapi.service.interface.context") +def test_clear_task(context_mock: MagicMock): context = BlueskyContext() context.plan(my_plan) task = Task(name="my_plan") - interface.start_worker(ApplicationConfig(), bluesky_context=context) + context_mock.return_value = context mock_uuid_value = "3d858a62-b40a-400f-82af-8d2603a4e59a" with patch.object(uuid, "uuid4") as uuid_mock: uuid_mock.return_value = uuid.UUID(mock_uuid_value) @@ -186,7 +167,6 @@ def test_clear_task(): @patch("blueapi.service.interface.TaskWorker.begin_task") def test_begin_task(worker_mock: MagicMock): - interface.start_worker(ApplicationConfig()) uuid_value = "350043fd-597e-41a7-9a92-5d5478232cf7" task = WorkerTask(task_id=uuid_value) returned_task = interface.begin_task(task) @@ -196,7 +176,6 @@ def test_begin_task(worker_mock: MagicMock): @patch("blueapi.service.interface.TaskWorker.begin_task") def test_begin_task_no_task_id(worker_mock: MagicMock): - interface.start_worker(ApplicationConfig()) task = WorkerTask(task_id=None) returned_task = interface.begin_task(task) assert task == returned_task @@ -219,8 +198,6 @@ def mock_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]: get_tasks_by_status_mock.side_effect = mock_tasks_by_status - interface.start_worker(ApplicationConfig()) - assert interface.get_tasks_by_status(TaskStatusEnum.PENDING) == [ pending_task1, pending_task2, @@ -230,18 +207,15 @@ def mock_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]: def test_get_active_task(): - interface.start_worker(ApplicationConfig()) assert interface.get_active_task() is None def test_get_worker_state(): - interface.start_worker(ApplicationConfig()) assert interface.get_worker_state() == WorkerState.IDLE @patch("blueapi.service.interface.TaskWorker.pause") def test_pause_worker(pause_worker_mock: MagicMock): - interface.start_worker(ApplicationConfig()) interface.pause_worker(False) pause_worker_mock.assert_called_once_with(False) @@ -252,14 +226,12 @@ def test_pause_worker(pause_worker_mock: MagicMock): @patch("blueapi.service.interface.TaskWorker.resume") def test_resume_worker(resume_worker_mock: MagicMock): - interface.start_worker(ApplicationConfig()) interface.resume_worker() resume_worker_mock.assert_called_once() @patch("blueapi.service.interface.TaskWorker.cancel_active_task") def test_cancel_active_task(cancel_active_task_mock: MagicMock): - interface.start_worker(ApplicationConfig()) fail = True reason = "End of session" task_id = "789" @@ -277,15 +249,14 @@ def test_get_tasks(get_tasks_mock: MagicMock): ] get_tasks_mock.return_value = tasks - interface.start_worker(ApplicationConfig()) - assert interface.get_tasks() == tasks -def test_get_task_by_id(): +@patch("blueapi.service.interface.context") +def test_get_task_by_id(context_mock: MagicMock): context = BlueskyContext() context.plan(my_plan) - interface.start_worker(ApplicationConfig(), bluesky_context=context) + context_mock.return_value = context task_id = interface.submit_task(Task(name="my_plan")) @@ -296,3 +267,9 @@ def test_get_task_by_id(): is_pending=True, errors=[], ) + + +@pytest.mark.stomp +def test_stomp_config(): + interface.set_config(ApplicationConfig(stomp=StompConfig())) + assert interface.messaging_template() is not None diff --git a/tests/service/test_rest_api.py b/tests/service/test_rest_api.py index 65cddd9bb..068cf40b9 100644 --- a/tests/service/test_rest_api.py +++ b/tests/service/test_rest_api.py @@ -21,8 +21,7 @@ @pytest.fixture def client() -> Iterator[TestClient]: with ( - patch("blueapi.service.runner.start_worker"), - patch("blueapi.service.runner.stop_worker"), + patch("blueapi.service.interface.worker"), ): main.setup_runner(use_subprocess=False) yield TestClient(main.app) diff --git a/tests/service/test_runner.py b/tests/service/test_runner.py index bf5457113..4a0e9f74a 100644 --- a/tests/service/test_runner.py +++ b/tests/service/test_runner.py @@ -5,7 +5,7 @@ from blueapi.service import interface from blueapi.service.model import EnvironmentResponse -from blueapi.service.runner import RunnerNotStartedError, WorkerDispatcher +from blueapi.service.runner import InvalidRunnerStateError, WorkerDispatcher def test_initialize(): @@ -14,7 +14,7 @@ def test_initialize(): runner.start() assert runner.state.initialized # Run a single call to the runner for coverage of dispatch to subprocess - assert runner.run(interface.get_state) + assert runner.run(interface.get_worker_state) runner.stop() assert not runner.state.initialized @@ -23,14 +23,14 @@ def test_reload(): runner = WorkerDispatcher() runner.start() assert runner.state.initialized - runner.reload_context() + runner.reload() assert runner.state.initialized runner.stop() def test_raises_if_used_before_started(): runner = WorkerDispatcher() - with pytest.raises(RunnerNotStartedError): + with pytest.raises(InvalidRunnerStateError): assert runner.run(interface.get_plans) is None @@ -38,17 +38,17 @@ def test_error_on_runner_setup(): runner = WorkerDispatcher(use_subprocess=False) expected_state = EnvironmentResponse( initialized=False, - error_message="Error configuring blueapi: Intentional start_worker exception", + error_message="Intentional start_worker exception", ) with mock.patch( - "blueapi.service.runner.start_worker", + "blueapi.service.runner.setup", side_effect=Exception("Intentional start_worker exception"), ): # Calling reload here instead of start also indirectly # tests that stop() doesn't raise if there is no error message # and the runner is not yet initialised - runner.reload_context() + runner.reload() state = runner.state assert state == expected_state @@ -68,10 +68,8 @@ def test_can_reload_after_an_error(pool_mock: MagicMock): # all calls to subprocess (poll::apply) are mocked subprocess_calls_return_values = [ - None, # logging setup SyntaxError("invalid code"), # start_worker None, # stop_worker - None, # logging_setup None, # start_worker ] @@ -81,9 +79,9 @@ def test_can_reload_after_an_error(pool_mock: MagicMock): runner.start() assert runner.state == EnvironmentResponse( - initialized=False, error_message="Error configuring blueapi: invalid code" + initialized=False, error_message="invalid code" ) - runner.reload_context() + runner.reload() assert runner.state == EnvironmentResponse(initialized=True, error_message=None)