Skip to content

Commit

Permalink
Fix tests after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed Jul 24, 2024
1 parent 5937021 commit e6d8cf8
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 83 deletions.
15 changes: 6 additions & 9 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,17 @@
from blueapi import __version__
from blueapi.cli.format import OutputFormat
from blueapi.client.client import BlueapiClient
from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient
from blueapi.client.event_bus import (AnyEvent, BlueskyStreamingError,
EventBusClient)
from blueapi.client.rest import BlueskyRemoteControlError
from blueapi.config import ApplicationConfig, ConfigLoader
from blueapi.core import DataEvent
from blueapi.messaging import MessageContext
from blueapi.messaging.stomptemplate import StompMessagingTemplate
from blueapi.service.main import start
from blueapi.service.openapi import (
DOCS_SCHEMA_LOCATION,
generate_schema,
print_schema_as_yaml,
write_schema_as_yaml,
)
from blueapi.service.openapi import (DOCS_SCHEMA_LOCATION, generate_schema,
print_schema_as_yaml,
write_schema_as_yaml)
from blueapi.worker import ProgressEvent, Task, WorkerEvent

from .scratch import setup_scratch
Expand Down Expand Up @@ -280,7 +278,6 @@ def stop(obj: dict) -> None:
"-r",
"--reload",
is_flag=True,
type=bool,
help="Reload the current environment",
default=False,
)
Expand All @@ -294,7 +291,7 @@ def stop(obj: dict) -> None:
@click.pass_obj
def env(
obj: dict,
reload: bool | None,
reload: bool,
timeout: float | None,
) -> None:
"""
Expand Down
17 changes: 6 additions & 11 deletions src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,14 @@
from blueapi.config import ApplicationConfig
from blueapi.core.bluesky_types import DataEvent
from blueapi.messaging import MessageContext, StompMessagingTemplate
from blueapi.service.model import (
DeviceModel,
DeviceResponse,
EnvironmentResponse,
PlanModel,
PlanResponse,
TaskResponse,
WorkerTask,
)
from blueapi.service.model import (DeviceModel, DeviceResponse,
EnvironmentResponse, PlanModel,
PlanResponse, TaskResponse, WorkerTask)
from blueapi.worker import Task, TrackableTask, WorkerEvent, WorkerState
from blueapi.worker.event import ProgressEvent, TaskStatus

from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent
from .event_bus import (AnyEvent, BlueskyStreamingError, EventBusClient,
OnAnyEvent)
from .rest import BlueapiRestClient, BlueskyRemoteControlError


Expand Down Expand Up @@ -361,7 +356,7 @@ def _wait_for_reload(
# Poll until the environment is restarted or the timeout is reached
status = self._rest.get_environment()
if status.error_message is not None:
raise BlueskyRemoteControlError(status.error_message)
raise BlueskyRemoteControlError(f"Error reloading environment: {status.error_message}")
elif status.initialized:
return status
time.sleep(polling_interval)
Expand Down
26 changes: 5 additions & 21 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
from contextlib import asynccontextmanager

from fastapi import (
BackgroundTasks,
Body,
Depends,
FastAPI,
HTTPException,
Request,
Response,
status,
)
from fastapi import (BackgroundTasks, Body, Depends, FastAPI, HTTPException,
Request, Response, status)
from pydantic import ValidationError
from starlette.responses import JSONResponse
from super_state_machine.errors import TransitionError
Expand All @@ -19,17 +11,9 @@
from blueapi.worker import Task, TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum

from .model import (
DeviceModel,
DeviceResponse,
EnvironmentResponse,
PlanModel,
PlanResponse,
StateChangeRequest,
TaskResponse,
TasksListResponse,
WorkerTask,
)
from .model import (DeviceModel, DeviceResponse, EnvironmentResponse,
PlanModel, PlanResponse, StateChangeRequest, TaskResponse,
TasksListResponse, WorkerTask)
from .runner import WorkerDispatcher

REST_API_VERSION = "0.0.5"
Expand Down
15 changes: 7 additions & 8 deletions src/blueapi/service/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from typing import Any

from blueapi.config import ApplicationConfig
from blueapi.service.interface import InitialisationException, start_worker, stop_worker
from blueapi.service.model import (
EnvironmentResponse,
)
from blueapi.service.interface import (InitialisationException, start_worker,
stop_worker)
from blueapi.service.model import EnvironmentResponse

# The default multiprocessing start method is fork
set_start_method("spawn", force=True)
Expand Down Expand Up @@ -39,7 +38,7 @@ def __init__(
self._config = config or ApplicationConfig()
self._subprocess = None
self._use_subprocess = use_subprocess
self._state = EnvironmentResponse(initialized=False, error_message="")
self._state = EnvironmentResponse(initialized=False)

def start(self):
if self._subprocess is None and self._use_subprocess:
Expand All @@ -55,11 +54,11 @@ def start(self):
)
LOGGER.exception(self._state.error_message)
return
self._state = EnvironmentResponse(initialized=True, error_message="")
self._state = EnvironmentResponse(initialized=True)

def stop(self):
if self._subprocess is not None:
self._state = EnvironmentResponse(initialized=False, error_message="")
self._state = EnvironmentResponse(initialized=False)
try:
self._subprocess.apply(stop_worker)
except InitialisationException:
Expand All @@ -71,7 +70,7 @@ def stop(self):
if (not self._use_subprocess) and (
self._state.initialized or self._state.error_message
):
self._state = EnvironmentResponse(initialized=False, error_message="")
self._state = EnvironmentResponse(initialized=False)
stop_worker()

def reload_context(self):
Expand Down
10 changes: 3 additions & 7 deletions tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@

from blueapi.core.bluesky_types import Plan
from blueapi.service import main
from blueapi.service.model import (
DeviceModel,
PlanModel,
StateChangeRequest,
WorkerTask,
)
from blueapi.service.model import (DeviceModel, PlanModel, StateChangeRequest,
WorkerTask)
from blueapi.worker.event import WorkerState
from blueapi.worker.task import Task
from blueapi.worker.worker import TrackableTask
Expand Down Expand Up @@ -546,7 +542,7 @@ def test_set_state_invalid_transition(
def test_get_environment_idle(client: TestClient) -> None:
assert client.get("/environment").json() == {
"initialized": True,
"error_message": "",
"error_message": None,
}


Expand Down
43 changes: 16 additions & 27 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,9 @@
from blueapi.client.rest import BlueskyRemoteControlError
from blueapi.config import ScratchConfig, ScratchRepository
from blueapi.core.bluesky_types import Plan
from blueapi.service.model import (
DeviceModel,
DeviceResponse,
EnvironmentResponse,
PlanModel,
PlanResponse,
)
from blueapi.service.model import (DeviceModel, DeviceResponse,
EnvironmentResponse, PlanModel,
PlanResponse)


@pytest.fixture
Expand Down Expand Up @@ -171,7 +167,7 @@ def test_get_env(


@responses.activate(assert_all_requests_are_fired=True)
@patch("blueapi.cli.cli.sleep", return_value=None)
@patch("blueapi.client.client.time.sleep", return_value=None)
def test_reset_env_client_behavior(
mock_sleep: Mock,
runner: CliRunner,
Expand Down Expand Up @@ -210,20 +206,15 @@ def test_reset_env_client_behavior(
# Check if the final environment status is printed correctly
# assert "Environment is initialized." in result.output
assert reload_result.output == dedent("""\
Reloading the environment...
initialized=False error_message=None
Waiting for environment to initialize...
Waiting for environment to initialize...
Environment is initialized.
Reloading environment
Environment is initialized
initialized=True error_message=None
""")


@responses.activate
@patch("blueapi.cli.cli.sleep", return_value=None)
@patch("blueapi.client.client.time.sleep", return_value=None)
def test_env_timeout(mock_sleep: Mock, runner: CliRunner):
max_polling_count = 10 # Assuming this is your max polling count in the command

# Setup mocked responses for the REST endpoints
responses.add(
responses.DELETE,
Expand All @@ -232,13 +223,12 @@ def test_env_timeout(mock_sleep: Mock, runner: CliRunner):
json=EnvironmentResponse(initialized=False).dict(),
)
# Add responses for each polling attempt, all indicating not initialized
for _ in range(max_polling_count):
responses.add(
responses.GET,
"http://localhost:8000/environment",
json=EnvironmentResponse(initialized=False).dict(),
status=200,
)
responses.add(
responses.GET,
"http://localhost:8000/environment",
json=EnvironmentResponse(initialized=False).dict(),
status=200,
)

# Run the command that should interact with these endpoints
result = runner.invoke(main, ["controller", "env", "-r", "-t", "0.1"])
Expand All @@ -264,8 +254,7 @@ def test_env_timeout(mock_sleep: Mock, runner: CliRunner):
# Check the output for the timeout message
assert (
result.output
== "Reloading the environment...\ninitialized=False error_message=None\n"
+ "Waiting for environment to initialize...\n" * 10
== "Reloading environment\n"
)
assert (
result.exit_code == 1
Expand All @@ -283,7 +272,7 @@ def test_env_reload_server_side_error(runner: CliRunner):
assert isinstance(
result.exception, BlueskyRemoteControlError
), "Expected a BlueskyRemoteError from cli runner"
assert result.exception.args[0] == "Failed to reload the environment"
assert result.exception.args[0] == "Failed to tear down the environment"

# Check if the endpoints were hit as expected
assert len(responses.calls) == 1 # +1 for the DELETE call
Expand All @@ -295,7 +284,7 @@ def test_env_reload_server_side_error(runner: CliRunner):
# Check the output for the timeout message
# TODO this seems wrong but this is the current behaviour
# There should be an error message
assert result.output == "Reloading the environment...\n"
assert result.output == "Reloading environment\n"

assert result.exit_code == 1

Expand Down

0 comments on commit e6d8cf8

Please sign in to comment.