Skip to content

Commit

Permalink
Refactor client logic out of CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed Jul 8, 2024
1 parent eb9a5d4 commit b4c90bd
Show file tree
Hide file tree
Showing 8 changed files with 483 additions and 221 deletions.
145 changes: 48 additions & 97 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,32 @@
import json
import logging
from collections import deque
from functools import wraps
from pathlib import Path
from pprint import pprint
from time import sleep

import click
from bluesky.callbacks.best_effort import BestEffortCallback
from pydantic import ValidationError
from requests.exceptions import ConnectionError

from blueapi import __version__
from blueapi.cli.event_bus_client import BlueskyRemoteError, EventBusClient
from blueapi.cli.format import OutputFormat
from blueapi.client.client import BlueapiClient
from blueapi.client.event_bus import AnyEvent, BlueskyRemoteError, EventBusClient
from blueapi.config import ApplicationConfig, ConfigLoader
from blueapi.core import DataEvent
from blueapi.messaging import MessageContext
from blueapi.messaging.stomptemplate import StompMessagingTemplate
from blueapi.service.main import start
from blueapi.service.model import WorkerTask
from blueapi.service.openapi import (
DOCS_SCHEMA_LOCATION,
generate_schema,
print_schema_as_yaml,
write_schema_as_yaml,
)
from blueapi.worker import ProgressEvent, Task, WorkerEvent, WorkerState
from blueapi.worker import ProgressEvent, Task, WorkerEvent

from .rest import BlueapiRestClient
from .updates import CliEventRenderer


@click.group(invoke_without_command=True)
Expand Down Expand Up @@ -106,7 +105,7 @@ def controller(ctx: click.Context, output: str) -> None:
ctx.ensure_object(dict)
config: ApplicationConfig = ctx.obj["config"]
ctx.obj["fmt"] = OutputFormat(output)
ctx.obj["rest_client"] = BlueapiRestClient(config.api)
ctx.obj["client"] = BlueapiClient.from_config(config)


def check_connection(func):
Expand All @@ -125,7 +124,7 @@ def wrapper(*args, **kwargs):
@click.pass_obj
def get_plans(obj: dict) -> None:
"""Get a list of plans available for the worker to use"""
client: BlueapiRestClient = obj["rest_client"]
client: BlueapiClient = obj["client"]
obj["fmt"].display(client.get_plans())


Expand All @@ -134,7 +133,7 @@ def get_plans(obj: dict) -> None:
@click.pass_obj
def get_devices(obj: dict) -> None:
"""Get a list of devices available for the worker to use"""
client: BlueapiRestClient = obj["rest_client"]
client: BlueapiClient = obj["client"]
obj["fmt"].display(client.get_devices())


Expand Down Expand Up @@ -183,30 +182,24 @@ def run_plan(
obj: dict, name: str, parameters: str | None, timeout: float | None
) -> None:
"""Run a plan with parameters"""
config: ApplicationConfig = obj["config"]
client: BlueapiRestClient = obj["rest_client"]

logger = logging.getLogger(__name__)
if config.stomp is not None:
_message_template = StompMessagingTemplate.autoconfigured(config.stomp)
else:
raise RuntimeError(
"Cannot run plans without Stomp configuration to track progress"
)
event_bus_client = EventBusClient(_message_template)
finished_event: deque[WorkerEvent] = deque()

def store_finished_event(event: WorkerEvent) -> None:
if event.is_complete():
finished_event.append(event)
client: BlueapiClient = obj["client"]

parameters = parameters or "{}"
task_id = ""
parsed_params = json.loads(parameters) if isinstance(parameters, str) else {}

progress_bar = CliEventRenderer()
callback = BestEffortCallback()

def on_event(event: AnyEvent) -> None:
if isinstance(event, ProgressEvent):
progress_bar.on_progress_event(event)
elif isinstance(event, DataEvent):
callback(event.name, event.doc)

Check warning on line 198 in src/blueapi/cli/cli.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/cli/cli.py#L195-L198

Added lines #L195 - L198 were not covered by tests

try:
task = Task(name=name, params=parsed_params)
resp = client.create_task(task)
task_id = resp.task_id
resp = client.run_task(task, on_event=on_event)
except ValidationError as e:
pprint(f"failed to validate the task parameters, {task_id}, error: {e}")
return
Expand All @@ -217,18 +210,7 @@ def store_finished_event(event: WorkerEvent) -> None:
pprint("task could not run")
return

with event_bus_client:
event_bus_client.subscribe_to_topics(task_id, on_event=store_finished_event)
updated = client.update_worker_task(WorkerTask(task_id=task_id))

event_bus_client.wait_for_complete(timeout=timeout)

if event_bus_client.timed_out:
logger.error(f"Plan did not complete within {timeout} seconds")
return

process_event_after_finished(finished_event.pop(), logger)
pprint(updated.dict())
pprint(resp.dict())

Check warning on line 213 in src/blueapi/cli/cli.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/cli/cli.py#L213

Added line #L213 was not covered by tests


@controller.command(name="state")
Expand All @@ -237,7 +219,7 @@ def store_finished_event(event: WorkerEvent) -> None:
def get_state(obj: dict) -> None:
"""Print the current state of the worker"""

client: BlueapiRestClient = obj["rest_client"]
client: BlueapiClient = obj["client"]

Check warning on line 222 in src/blueapi/cli/cli.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/cli/cli.py#L222

Added line #L222 was not covered by tests
print(client.get_state().name)


Expand All @@ -248,8 +230,8 @@ def get_state(obj: dict) -> None:
def pause(obj: dict, defer: bool = False) -> None:
"""Pause the execution of the current task"""

client: BlueapiRestClient = obj["rest_client"]
pprint(client.set_state(WorkerState.PAUSED, defer=defer))
client: BlueapiClient = obj["client"]
pprint(client.pause(defer=defer))

Check warning on line 234 in src/blueapi/cli/cli.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/cli/cli.py#L233-L234

Added lines #L233 - L234 were not covered by tests


@controller.command(name="resume")
Expand All @@ -258,8 +240,8 @@ def pause(obj: dict, defer: bool = False) -> None:
def resume(obj: dict) -> None:
"""Resume the execution of the current task"""

client: BlueapiRestClient = obj["rest_client"]
pprint(client.set_state(WorkerState.RUNNING))
client: BlueapiClient = obj["client"]
pprint(client.resume())

Check warning on line 244 in src/blueapi/cli/cli.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/cli/cli.py#L243-L244

Added lines #L243 - L244 were not covered by tests


@controller.command(name="abort")
Expand All @@ -272,8 +254,8 @@ def abort(obj: dict, reason: str | None = None) -> None:
with optional reason
"""

client: BlueapiRestClient = obj["rest_client"]
pprint(client.cancel_current_task(state=WorkerState.ABORTING, reason=reason))
client: BlueapiClient = obj["client"]
pprint(client.abort(reason=reason))

Check warning on line 258 in src/blueapi/cli/cli.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/cli/cli.py#L257-L258

Added lines #L257 - L258 were not covered by tests


@controller.command(name="stop")
Expand All @@ -284,8 +266,8 @@ def stop(obj: dict) -> None:
Stop the execution of the current task, marking as ongoing runs as success
"""

client: BlueapiRestClient = obj["rest_client"]
pprint(client.cancel_current_task(state=WorkerState.STOPPING))
client: BlueapiClient = obj["client"]
pprint(client.stop())

Check warning on line 270 in src/blueapi/cli/cli.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/cli/cli.py#L269-L270

Added lines #L269 - L270 were not covered by tests


@controller.command(name="env")
Expand All @@ -298,60 +280,29 @@ def stop(obj: dict) -> None:
help="Reload the current environment",
default=False,
)
@click.option(
"-t",
"--timeout",
type=float,
help="Timeout to wait for reload in seconds, defaults to 10",
default=10.0,
)
@click.pass_obj
def env(obj: dict, reload: bool | None) -> None:
def env(
obj: dict,
reload: bool | None,
timeout: float | None,
) -> None:
"""
Inspect or restart the environment
"""

assert isinstance(client := obj["rest_client"], BlueapiRestClient)
assert isinstance(client := obj["client"], BlueapiClient)
if reload:
# Reload the environment if needed
print("Reloading the environment...")
try:
deserialized = client.reload_environment()
print(deserialized)

except BlueskyRemoteError as e:
raise BlueskyRemoteError("Failed to reload the environment") from e

# Initialize a variable to keep track of the environment status
environment_initialized = False
polling_count = 0
max_polling_count = 10
# Use a while loop to keep checking until the environment is initialized
while not environment_initialized and polling_count < max_polling_count:
# Fetch the current environment status
environment_status = client.get_environment()

# Check if the environment is initialized
if environment_status.initialized:
print("Environment is initialized.")
environment_initialized = True
else:
print("Waiting for environment to initialize...")
polling_count += 1
sleep(1) # Wait for 1 seconds before checking again
if polling_count == max_polling_count:
raise TimeoutError("Environment initialization timed out.")

# Once out of the loop, print the initialized environment status
print(environment_status)
print("Reloading environment")
status = client.reload_environment(timeout=timeout)
print("Environment is initialized.")
else:
print(client.get_environment())


# helper function
def process_event_after_finished(event: WorkerEvent, logger: logging.Logger):
if event.is_error():
logger.info("Failed with errors: \n")
for error in event.errors:
logger.error(error)
return
if len(event.warnings) != 0:
logger.info("Passed with warnings: \n")
for warning in event.warnings:
logger.warn(warning)
return

logger.info("Plan passed")
status = client.get_environment()
print(status)
76 changes: 0 additions & 76 deletions src/blueapi/cli/event_bus_client.py

This file was deleted.

22 changes: 2 additions & 20 deletions src/blueapi/cli/updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,36 +43,18 @@ def _update(self, name: str, view: StatusView) -> None:


class CliEventRenderer:
_task_id: str | None
_pbar_renderer: ProgressBarRenderer

def __init__(
self,
task_id: str | None = None,
pbar_renderer: ProgressBarRenderer | None = None,
) -> None:
self._task_id = task_id
if pbar_renderer is None:
pbar_renderer = ProgressBarRenderer()
self._pbar_renderer = pbar_renderer

def on_progress_event(self, event: ProgressEvent) -> None:
if self._relates_to_task(event):
self._pbar_renderer.update(event.statuses)
self._pbar_renderer.update(event.statuses)

Check warning on line 57 in src/blueapi/cli/updates.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/cli/updates.py#L57

Added line #L57 was not covered by tests

def on_worker_event(self, event: WorkerEvent) -> None:
if self._relates_to_task(event):
print(str(event.state))

def _relates_to_task(self, event: WorkerEvent | ProgressEvent) -> bool:
if self._task_id is None:
return True
elif isinstance(event, WorkerEvent):
return (
event.task_status is not None
and event.task_status.task_id == self._task_id
)
elif isinstance(event, ProgressEvent):
return event.task_id == self._task_id
else:
return False
print(str(event.state))

Check warning on line 60 in src/blueapi/cli/updates.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/cli/updates.py#L60

Added line #L60 was not covered by tests
Empty file added src/blueapi/client/__init__.py
Empty file.
Loading

0 comments on commit b4c90bd

Please sign in to comment.