diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index fdb910b29..9b08c4752 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -1,33 +1,32 @@ import json import logging -from collections import deque from functools import wraps from pathlib import Path from pprint import pprint -from time import sleep import click +from bluesky.callbacks.best_effort import BestEffortCallback from pydantic import ValidationError from requests.exceptions import ConnectionError from blueapi import __version__ -from blueapi.cli.event_bus_client import BlueskyRemoteError, EventBusClient from blueapi.cli.format import OutputFormat +from blueapi.client.client import BlueapiClient +from blueapi.client.event_bus import AnyEvent, BlueskyRemoteError, EventBusClient from blueapi.config import ApplicationConfig, ConfigLoader from blueapi.core import DataEvent from blueapi.messaging import MessageContext from blueapi.messaging.stomptemplate import StompMessagingTemplate from blueapi.service.main import start -from blueapi.service.model import WorkerTask from blueapi.service.openapi import ( DOCS_SCHEMA_LOCATION, generate_schema, print_schema_as_yaml, write_schema_as_yaml, ) -from blueapi.worker import ProgressEvent, Task, WorkerEvent, WorkerState +from blueapi.worker import ProgressEvent, Task, WorkerEvent -from .rest import BlueapiRestClient +from .updates import CliEventRenderer @click.group(invoke_without_command=True) @@ -106,7 +105,7 @@ def controller(ctx: click.Context, output: str) -> None: ctx.ensure_object(dict) config: ApplicationConfig = ctx.obj["config"] ctx.obj["fmt"] = OutputFormat(output) - ctx.obj["rest_client"] = BlueapiRestClient(config.api) + ctx.obj["client"] = BlueapiClient.from_config(config) def check_connection(func): @@ -125,7 +124,7 @@ def wrapper(*args, **kwargs): @click.pass_obj def get_plans(obj: dict) -> None: """Get a list of plans available for the worker to use""" - client: BlueapiRestClient = obj["rest_client"] + client: BlueapiClient = obj["client"] obj["fmt"].display(client.get_plans()) @@ -134,7 +133,7 @@ def get_plans(obj: dict) -> None: @click.pass_obj def get_devices(obj: dict) -> None: """Get a list of devices available for the worker to use""" - client: BlueapiRestClient = obj["rest_client"] + client: BlueapiClient = obj["client"] obj["fmt"].display(client.get_devices()) @@ -183,30 +182,24 @@ def run_plan( obj: dict, name: str, parameters: str | None, timeout: float | None ) -> None: """Run a plan with parameters""" - config: ApplicationConfig = obj["config"] - client: BlueapiRestClient = obj["rest_client"] - - logger = logging.getLogger(__name__) - if config.stomp is not None: - _message_template = StompMessagingTemplate.autoconfigured(config.stomp) - else: - raise RuntimeError( - "Cannot run plans without Stomp configuration to track progress" - ) - event_bus_client = EventBusClient(_message_template) - finished_event: deque[WorkerEvent] = deque() - - def store_finished_event(event: WorkerEvent) -> None: - if event.is_complete(): - finished_event.append(event) + client: BlueapiClient = obj["client"] parameters = parameters or "{}" task_id = "" parsed_params = json.loads(parameters) if isinstance(parameters, str) else {} + + progress_bar = CliEventRenderer() + callback = BestEffortCallback() + + def on_event(event: AnyEvent) -> None: + if isinstance(event, ProgressEvent): + progress_bar.on_progress_event(event) + elif isinstance(event, DataEvent): + callback(event.name, event.doc) + try: task = Task(name=name, params=parsed_params) - resp = client.create_task(task) - task_id = resp.task_id + resp = client.run_task(task, on_event=on_event) except ValidationError as e: pprint(f"failed to validate the task parameters, {task_id}, error: {e}") return @@ -217,18 +210,7 @@ def store_finished_event(event: WorkerEvent) -> None: pprint("task could not run") return - with event_bus_client: - event_bus_client.subscribe_to_topics(task_id, on_event=store_finished_event) - updated = client.update_worker_task(WorkerTask(task_id=task_id)) - - event_bus_client.wait_for_complete(timeout=timeout) - - if event_bus_client.timed_out: - logger.error(f"Plan did not complete within {timeout} seconds") - return - - process_event_after_finished(finished_event.pop(), logger) - pprint(updated.dict()) + pprint(resp.dict()) @controller.command(name="state") @@ -237,7 +219,7 @@ def store_finished_event(event: WorkerEvent) -> None: def get_state(obj: dict) -> None: """Print the current state of the worker""" - client: BlueapiRestClient = obj["rest_client"] + client: BlueapiClient = obj["client"] print(client.get_state().name) @@ -248,8 +230,8 @@ def get_state(obj: dict) -> None: def pause(obj: dict, defer: bool = False) -> None: """Pause the execution of the current task""" - client: BlueapiRestClient = obj["rest_client"] - pprint(client.set_state(WorkerState.PAUSED, defer=defer)) + client: BlueapiClient = obj["client"] + pprint(client.pause(defer=defer)) @controller.command(name="resume") @@ -258,8 +240,8 @@ def pause(obj: dict, defer: bool = False) -> None: def resume(obj: dict) -> None: """Resume the execution of the current task""" - client: BlueapiRestClient = obj["rest_client"] - pprint(client.set_state(WorkerState.RUNNING)) + client: BlueapiClient = obj["client"] + pprint(client.resume()) @controller.command(name="abort") @@ -272,8 +254,8 @@ def abort(obj: dict, reason: str | None = None) -> None: with optional reason """ - client: BlueapiRestClient = obj["rest_client"] - pprint(client.cancel_current_task(state=WorkerState.ABORTING, reason=reason)) + client: BlueapiClient = obj["client"] + pprint(client.abort(reason=reason)) @controller.command(name="stop") @@ -284,8 +266,8 @@ def stop(obj: dict) -> None: Stop the execution of the current task, marking as ongoing runs as success """ - client: BlueapiRestClient = obj["rest_client"] - pprint(client.cancel_current_task(state=WorkerState.STOPPING)) + client: BlueapiClient = obj["client"] + pprint(client.stop()) @controller.command(name="env") @@ -298,60 +280,29 @@ def stop(obj: dict) -> None: help="Reload the current environment", default=False, ) +@click.option( + "-t", + "--timeout", + type=float, + help="Timeout to wait for reload in seconds, defaults to 10", + default=10.0, +) @click.pass_obj -def env(obj: dict, reload: bool | None) -> None: +def env( + obj: dict, + reload: bool | None, + timeout: float | None, +) -> None: """ Inspect or restart the environment """ - assert isinstance(client := obj["rest_client"], BlueapiRestClient) + assert isinstance(client := obj["client"], BlueapiClient) if reload: # Reload the environment if needed - print("Reloading the environment...") - try: - deserialized = client.reload_environment() - print(deserialized) - - except BlueskyRemoteError as e: - raise BlueskyRemoteError("Failed to reload the environment") from e - - # Initialize a variable to keep track of the environment status - environment_initialized = False - polling_count = 0 - max_polling_count = 10 - # Use a while loop to keep checking until the environment is initialized - while not environment_initialized and polling_count < max_polling_count: - # Fetch the current environment status - environment_status = client.get_environment() - - # Check if the environment is initialized - if environment_status.initialized: - print("Environment is initialized.") - environment_initialized = True - else: - print("Waiting for environment to initialize...") - polling_count += 1 - sleep(1) # Wait for 1 seconds before checking again - if polling_count == max_polling_count: - raise TimeoutError("Environment initialization timed out.") - - # Once out of the loop, print the initialized environment status - print(environment_status) + print("Reloading environment") + status = client.reload_environment(timeout=timeout) + print("Environment is initialized.") else: - print(client.get_environment()) - - -# helper function -def process_event_after_finished(event: WorkerEvent, logger: logging.Logger): - if event.is_error(): - logger.info("Failed with errors: \n") - for error in event.errors: - logger.error(error) - return - if len(event.warnings) != 0: - logger.info("Passed with warnings: \n") - for warning in event.warnings: - logger.warn(warning) - return - - logger.info("Plan passed") + status = client.get_environment() + print(status) diff --git a/src/blueapi/cli/event_bus_client.py b/src/blueapi/cli/event_bus_client.py deleted file mode 100644 index afa2e4416..000000000 --- a/src/blueapi/cli/event_bus_client.py +++ /dev/null @@ -1,76 +0,0 @@ -import threading -from collections.abc import Callable - -from bluesky.callbacks.best_effort import BestEffortCallback - -from blueapi.core import DataEvent -from blueapi.messaging import MessageContext, MessagingTemplate -from blueapi.worker import ProgressEvent, WorkerEvent - -from .updates import CliEventRenderer - - -class BlueskyRemoteError(Exception): - def __init__(self, message: str) -> None: - super().__init__(message) - - -_Event = WorkerEvent | ProgressEvent | DataEvent - - -class EventBusClient: - app: MessagingTemplate - complete: threading.Event - timed_out: bool | None - - def __init__(self, app: MessagingTemplate) -> None: - self.app = app - self.complete = threading.Event() - self.timed_out = None - - def __enter__(self) -> None: - self.app.connect() - - def __exit__(self, exc_type, exc_value, exc_traceback) -> None: - self.app.disconnect() - - def subscribe_to_topics( - self, - correlation_id: str, - on_event: Callable[[WorkerEvent], None] | None = None, - ) -> None: - """Run callbacks on events/progress events with a given correlation id.""" - - progress_bar = CliEventRenderer(correlation_id) - callback = BestEffortCallback() - - def on_event_wrapper( - ctx: MessageContext, - event: _Event, - ) -> None: - if isinstance(event, WorkerEvent): - if (on_event is not None) and (ctx.correlation_id == correlation_id): - on_event(event) - - if (event.is_complete()) and (ctx.correlation_id == correlation_id): - self.complete.set() - elif isinstance(event, ProgressEvent): - progress_bar.on_progress_event(event) - elif isinstance(event, DataEvent): - callback(event.name, event.doc) - - self.subscribe_to_all_events(on_event_wrapper) - - def subscribe_to_all_events( - self, - on_event: Callable[[MessageContext, _Event], None], - ) -> None: - self.app.subscribe( - self.app.destinations.topic("public.worker.event"), - on_event, - ) - - def wait_for_complete(self, timeout: float | None = None) -> None: - self.timed_out = not self.complete.wait(timeout=timeout) - - self.complete.clear() diff --git a/src/blueapi/cli/updates.py b/src/blueapi/cli/updates.py index 7b7d73dd0..c3d922c6b 100644 --- a/src/blueapi/cli/updates.py +++ b/src/blueapi/cli/updates.py @@ -43,36 +43,18 @@ def _update(self, name: str, view: StatusView) -> None: class CliEventRenderer: - _task_id: str | None _pbar_renderer: ProgressBarRenderer def __init__( self, - task_id: str | None = None, pbar_renderer: ProgressBarRenderer | None = None, ) -> None: - self._task_id = task_id if pbar_renderer is None: pbar_renderer = ProgressBarRenderer() self._pbar_renderer = pbar_renderer def on_progress_event(self, event: ProgressEvent) -> None: - if self._relates_to_task(event): - self._pbar_renderer.update(event.statuses) + self._pbar_renderer.update(event.statuses) def on_worker_event(self, event: WorkerEvent) -> None: - if self._relates_to_task(event): - print(str(event.state)) - - def _relates_to_task(self, event: WorkerEvent | ProgressEvent) -> bool: - if self._task_id is None: - return True - elif isinstance(event, WorkerEvent): - return ( - event.task_status is not None - and event.task_status.task_id == self._task_id - ) - elif isinstance(event, ProgressEvent): - return event.task_id == self._task_id - else: - return False + print(str(event.state)) diff --git a/src/blueapi/client/__init__.py b/src/blueapi/client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py new file mode 100644 index 000000000..eeabccc1b --- /dev/null +++ b/src/blueapi/client/client.py @@ -0,0 +1,366 @@ +import time +from concurrent.futures import Future + +from blueapi.config import ApplicationConfig +from blueapi.core.bluesky_types import DataEvent +from blueapi.messaging import MessageContext, StompMessagingTemplate +from blueapi.service.model import ( + DeviceModel, + DeviceResponse, + EnvironmentResponse, + PlanModel, + PlanResponse, + TaskResponse, + WorkerTask, +) +from blueapi.worker import Task, TrackableTask, WorkerEvent, WorkerState +from blueapi.worker.event import ProgressEvent, TaskStatus + +from .event_bus import AnyEvent, BlueskyRemoteError, EventBusClient, OnAnyEvent +from .rest import BlueapiRestClient + + +class BlueapiClient: + """Unified client for controlling blueapi""" + + _rest: BlueapiRestClient + _events: EventBusClient | None + + def __init__( + self, + rest: BlueapiRestClient, + events: EventBusClient | None = None, + ): + self._rest = rest + self._events = events + + @classmethod + def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": + rest = BlueapiRestClient(config.api) + if config.stomp is not None: + template = StompMessagingTemplate.autoconfigured(config.stomp) + events = EventBusClient(template) + else: + events = None + return cls(rest, events) + + def get_plans(self) -> PlanResponse: + """ + List plans available + + Returns: + PlanResponse: Plans that can be run + """ + return self._rest.get_plans() + + def get_plan(self, name: str) -> PlanModel: + """ + Get details of a single plan + + Args: + name: Plan name + + Returns: + PlanModel: Details of the plan if found + """ + return self._rest.get_plan(name) + + def get_devices(self) -> DeviceResponse: + """ + List devices available + + Returns: + DeviceResponse: Devices that can be used in plans + """ + + return self._rest.get_devices() + + def get_device(self, name: str) -> DeviceModel: + """ + Get details of a single device + + Args: + name: Device name + + Returns: + DeviceModel: Details of the device if found + """ + + return self._rest.get_device(name) + + def get_state(self) -> WorkerState: + """ + Get current state of the blueapi worker + + Returns: + WorkerState: Current state + """ + + return self._rest.get_state() + + def pause(self, defer: bool = False) -> WorkerState: + """ + Pause execution of the current task, if any + + Args: + defer: Wait until the next checkpoint to pause. + Defaults to False. + + Returns: + WorkerState: Final state of the worker following + pause operation + """ + + return self._rest.set_state(WorkerState.PAUSED, defer=defer) + + def resume(self) -> WorkerState: + """ + Resume plan execution if previously paused + + Returns: + WorkerState: Final state of the worker following + resume operation + """ + + return self._rest.set_state(WorkerState.RUNNING, defer=False) + + def get_task(self, task_id: str) -> TrackableTask[Task]: + """ + Get a task stored by the worker + + Args: + task_id: Unique ID for the task + + Returns: + TrackableTask[Task]: Task details + """ + + return self._rest.get_task(task_id) + + def get_active_task(self) -> WorkerTask: + """ + Get the currently active task, if any + + Returns: + WorkerTask: The currently active task, the task the worker + is executing right now. + """ + + return self._rest.get_active_task() + + def run_task( + self, + task: Task, + on_event: OnAnyEvent | None = None, + timeout: float | None = None, + ) -> WorkerEvent: + """ + Synchronously run a task, requires a message bus connection + + Args: + task: Task to run + on_event: Callback for each event. Defaults to None. + timeout: Time to wait until the task is finished. + Defaults to None, so waits forever. + + Returns: + WorkerEvent: The final event, which includes final details + of task execution. + """ + + if self._events is None: + raise RuntimeError( + "Cannot run plans without Stomp configuration to track progress" + ) + + task_response = self.create_task(task) + task_id = task_response.task_id + + complete: Future[WorkerEvent] = Future() + + def inner_on_event(ctx: MessageContext, event: AnyEvent) -> None: + match event: + case WorkerEvent(task_status=TaskStatus(task_id=test_id)): + relates_to_task = test_id == task_id + case ProgressEvent(task_id=test_id): + relates_to_task = test_id == task_id + case DataEvent(): + relates_to_task = True + case _: + relates_to_task = False + if relates_to_task: + if on_event is not None: + on_event(event) + if isinstance(event, WorkerEvent) and ( + (event.is_complete()) and (ctx.correlation_id == task_id) + ): + if len(event.errors) > 0: + complete.set_exception( + BlueskyRemoteError("\n".join(event.errors)) + ) + else: + complete.set_result(event) + + with self._events: + self._events.subscribe_to_all_events(inner_on_event) + self.start_task(WorkerTask(task_id=task_id)) + return complete.result(timeout=timeout) + + def create_and_start_task(self, task: Task) -> TaskResponse: + """ + Create a new task and instruct the worker to start it + immediately. + + Args: + task: The task to create on the worker + + Returns: + TaskResponse: Acknowledgement of request + """ + + response = self.create_task(task) + worker_response = self.start_task(WorkerTask(task_id=response.task_id)) + if worker_response.task_id == response.task_id: + return response + else: + raise BlueskyRemoteError( + f"Tried to create and start task {response.task_id} " + f"but {worker_response.task_id} was started instead" + ) + + def create_task(self, task: Task) -> TaskResponse: + """ + Create a new task, does not start execution + + Args: + task: The task to create on the worker + + Returns: + TaskResponse: Acknowledgement of request + """ + + return self._rest.create_task(task) + + def clear_task(self, task_id: str) -> TaskResponse: + """ + Delete a stored task on the worker + + Args: + task_id: ID for the task + + Returns: + TaskResponse: Acknowledgement of request + """ + + return self._rest.clear_task(task_id) + + def start_task(self, task: WorkerTask) -> WorkerTask: + """ + Instruct the worker to start a stored task immediately + + Args: + task_id: ID for the task + + Returns: + WorkerTask: Acknowledgement of request + """ + + return self._rest.update_worker_task(task) + + def abort(self, reason: str | None = None) -> WorkerState: + """ + Abort the plan currently being executed, if any. + Stop execution, perform cleanup steps, mark the plan + as failed. + + Args: + reason: Reason for abort to include in the documents. + Defaults to None. + + Returns: + WorkerState: Final state of the worker following the + abort operation. + """ + + return self._rest.cancel_current_task( + WorkerState.ABORTING, + reason=reason, + ) + + def stop(self) -> WorkerState: + """ + Stop execution of the current plan early. + Stop execution, perform cleanup steps, but still mark the plan + as successful. + + Returns: + WorkerState: Final state of the worker following the + stop operation. + """ + + return self._rest.cancel_current_task(WorkerState.STOPPING) + + def get_environment(self) -> EnvironmentResponse: + """ + Get details of the worker environment + + Returns: + EnvironmentResponse: Details of the worker + environment. + """ + + return self._rest.get_environment() + + def reload_environment( + self, + timeout: float | None = None, + polling_interval: float = 0.5, + ) -> EnvironmentResponse: + """ + Teardown the worker environment and create a new one + + Args: + timeout: Time to wait for teardown. Defaults to None, + so waits forever. + polling_interval: If there is a timeout, the number of + seconds to wait between checking whether the environment + has been successfully reloaded. Defaults to 0.5. + + Returns: + EnvironmentResponse: Details of the new worker + environment. + """ + + try: + status = self._rest.delete_environment() + except Exception as e: + raise BlueskyRemoteError("Failed to tear down the environment") from e + return self._wait_for_reload( + status, + timeout, + polling_interval, + ) + + def _wait_for_reload( + self, + status: EnvironmentResponse, + timeout: float | None, + polling_interval: float = 0.5, + ) -> EnvironmentResponse: + teardown_complete_time = time.time() + too_late = teardown_complete_time + timeout if timeout is not None else None + + # Wait forever if there was no timeout + while too_late is None or time.time() < too_late: + # Poll until the environment is restarted or the timeout is reached + status = self._rest.get_environment() + if status.error_message is not None: + raise BlueskyRemoteError(status.error_message) + elif status.initialized: + return status + time.sleep(polling_interval) + # If the function did not raise or return early, it timed out + raise TimeoutError( + f"Failed to reload the environment within {timeout} " + "seconds, a server restart is recommended" + ) diff --git a/src/blueapi/client/event_bus.py b/src/blueapi/client/event_bus.py new file mode 100644 index 000000000..a6d48f687 --- /dev/null +++ b/src/blueapi/client/event_bus.py @@ -0,0 +1,46 @@ +import threading +from collections.abc import Callable + +from blueapi.core import DataEvent +from blueapi.messaging import MessageContext, MessagingTemplate +from blueapi.worker import ProgressEvent, WorkerEvent + + +class BlueskyRemoteError(Exception): + def __init__(self, message: str) -> None: + super().__init__(message) + + +AnyEvent = WorkerEvent | ProgressEvent | DataEvent +OnAnyEvent = Callable[[AnyEvent], None] + + +class EventBusClient: + app: MessagingTemplate + complete: threading.Event + timed_out: bool | None + + def __init__(self, app: MessagingTemplate) -> None: + self.app = app + self.complete = threading.Event() + self.timed_out = None + + def __enter__(self) -> None: + self.app.connect() + + def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + self.app.disconnect() + + def subscribe_to_all_events( + self, + on_event: Callable[[MessageContext, AnyEvent], None], + ) -> None: + self.app.subscribe( + self.app.destinations.topic("public.worker.event"), + on_event, + ) + + def wait_for_complete(self, timeout: float | None = None) -> None: + self.timed_out = not self.complete.wait(timeout=timeout) + + self.complete.clear() diff --git a/src/blueapi/cli/rest.py b/src/blueapi/client/rest.py similarity index 97% rename from src/blueapi/cli/rest.py rename to src/blueapi/client/rest.py index 9be4fd4c2..20e93dae4 100644 --- a/src/blueapi/cli/rest.py +++ b/src/blueapi/client/rest.py @@ -17,7 +17,7 @@ ) from blueapi.worker import Task, TrackableTask, WorkerState -from .event_bus_client import BlueskyRemoteError +from .event_bus import BlueskyRemoteError T = TypeVar("T") @@ -132,7 +132,7 @@ def _url(self, suffix: str) -> str: def get_environment(self) -> EnvironmentResponse: return self._request_and_deserialize("/environment", EnvironmentResponse) - def reload_environment(self) -> EnvironmentResponse: + def delete_environment(self) -> EnvironmentResponse: return self._request_and_deserialize( "/environment", EnvironmentResponse, method="DELETE" ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 074efb732..3b9a4fb91 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -14,8 +14,8 @@ from blueapi import __version__ from blueapi.cli.cli import main -from blueapi.cli.event_bus_client import BlueskyRemoteError from blueapi.cli.format import OutputFormat +from blueapi.client.event_bus import BlueskyRemoteError from blueapi.core.bluesky_types import Plan from blueapi.service.handler import Handler, teardown_handler from blueapi.service.model import ( @@ -239,9 +239,9 @@ def test_get_env( @pytest.mark.handler @patch("blueapi.service.handler.Handler") -@patch("blueapi.cli.rest.BlueapiRestClient.get_environment") -@patch("blueapi.cli.rest.BlueapiRestClient.reload_environment") -@patch("blueapi.cli.cli.sleep", return_value=None) +@patch("blueapi.client.rest.BlueapiRestClient.get_environment") +@patch("blueapi.client.rest.BlueapiRestClient.delete_environment") +@patch("blueapi.client.client.time.sleep", return_value=None) def test_reset_env_client_behavior( mock_sleep: MagicMock, mock_reload_environment: Mock, @@ -268,23 +268,18 @@ def test_reset_env_client_behavior( # Invoke the CLI command that would trigger the environment initialization check reload_result = runner.invoke(main, ["controller", "env", "-r"]) - assert mock_get_environment.call_count == 3 - - # Verify if sleep was called between polling iterations - assert mock_sleep.call_count == 2 # Since the last check doesn't require a sleep - # Check if the final environment status is printed correctly # assert "Environment is initialized." in result.output assert ( reload_result.output - == "Reloading the environment...\nEnvironment reload initiated.\nWaiting for environment to initialize...\nWaiting for environment to initialize...\nEnvironment is initialized.\ninitialized=True error_message=None\n" # noqa: E501 + == "Reloading environment\nEnvironment is initialized.\ninitialized=True error_message=None\n" # noqa: E501 ) @responses.activate @pytest.mark.handler @patch("blueapi.service.handler.Handler") -@patch("blueapi.cli.cli.sleep", return_value=None) +@patch("blueapi.client.client.time.sleep", return_value=None) def test_env_endpoint_interaction( mock_sleep: MagicMock, mock_handler: Mock, handler: Handler, runner: CliRunner ): @@ -317,9 +312,6 @@ def test_env_endpoint_interaction( # Run the command that should interact with these endpoints result = runner.invoke(main, ["controller", "env", "-r"]) - # Check if the endpoints were hit as expected - assert len(responses.calls) == 4 # Ensures that all expected calls were made - for index, call in enumerate(responses.calls): if index == 0: assert call.request.method == "DELETE" @@ -336,12 +328,10 @@ def test_env_endpoint_interaction( @pytest.mark.handler @responses.activate @patch("blueapi.service.handler.Handler") -@patch("blueapi.cli.cli.sleep", return_value=None) +@patch("blueapi.client.client.time.sleep", return_value=None) def test_env_timeout( mock_sleep: MagicMock, mock_handler: Mock, handler: Handler, runner: CliRunner ): - max_polling_count = 10 # Assuming this is your max polling count in the command - # Setup mocked responses for the REST endpoints responses.add( responses.DELETE, @@ -350,7 +340,7 @@ def test_env_timeout( json=EnvironmentResponse(initialized=False).dict(), ) # Add responses for each polling attempt, all indicating not initialized - for _ in range(max_polling_count): + for _ in range(10): responses.add( responses.GET, "http://localhost:8000/environment", @@ -359,16 +349,17 @@ def test_env_timeout( ) # Run the command that should interact with these endpoints - result = runner.invoke(main, ["controller", "env", "-r"]) + result = runner.invoke(main, ["controller", "env", "-r", "-t", "0.1"]) if result.exception is not None: assert isinstance(result.exception, TimeoutError), "Expected a TimeoutError" - assert result.exception.args[0] == "Environment initialization timed out." + assert ( + result.exception.args[0] + == "Failed to reload the environment within 0.1 seconds, " + "a server restart is recommended" + ) else: raise AssertionError("Expected an exception but got None") - # Check if the endpoints were hit as expected - assert len(responses.calls) == max_polling_count + 1 # +1 for the DELETE call - # First call should be DELETE assert responses.calls[0].request.method == "DELETE" assert responses.calls[0].request.url == "http://localhost:8000/environment" @@ -387,7 +378,7 @@ def test_env_timeout( @pytest.mark.handler @responses.activate @patch("blueapi.service.handler.Handler") -@patch("blueapi.cli.cli.sleep", return_value=None) +@patch("blueapi.client.client.time.sleep", return_value=None) def test_env_reload_server_side_error( mock_sleep: MagicMock, mock_handler: Mock, handler: Handler, runner: CliRunner ): @@ -404,7 +395,7 @@ def test_env_reload_server_side_error( assert isinstance( result.exception, BlueskyRemoteError ), "Expected a BlueskyRemoteError" - assert result.exception.args[0] == "Failed to reload the environment" + assert result.exception.args[0] == "Failed to tear down the environment" else: raise AssertionError("Expected an exception but got None") @@ -439,7 +430,9 @@ def mock_config(): ) def test_error_handling(mock_config, exception, expected_exit_code, runner: CliRunner): # Patching the create_task method to raise different exceptions - with patch("blueapi.cli.rest.BlueapiRestClient.create_task", side_effect=exception): + with patch( + "blueapi.client.rest.BlueapiRestClient.create_task", side_effect=exception + ): result = runner.invoke( main, [