✨Comp backend: disconnect progress update in webserver (#4273)
sanderegg authored May 26, 2023
1 parent 0e8309f commit d8698f4
Showing 12 changed files with 196 additions and 247 deletions.
@@ -142,7 +142,7 @@ class Node(BaseModel):
default=None,
ge=0,
le=100,
description="the node progress value",
description="the node progress value (deprecated in DB, still used for API only)",
deprecated=True,
)
thumbnail: HttpUrlWithCustomMinLength | None = Field(
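For context, the `progress` field above is a plain pydantic constrained field; a minimal sketch (the model name is illustrative, not the real `Node` model) of how the `ge`/`le` bounds and the `deprecated` flag from the diff behave under validation:

from pydantic import BaseModel, Field, ValidationError

class NodeSketch(BaseModel):
    # optional progress in percent, bounded to [0, 100]; flagged as deprecated
    progress: float | None = Field(
        default=None,
        ge=0,
        le=100,
        deprecated=True,
        description="the node progress value (deprecated in DB, still used for API only)",
    )

NodeSketch(progress=55.0)  # accepted
try:
    NodeSketch(progress=150.0)  # rejected: violates le=100
except ValidationError as exc:
    print(exc)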
10 changes: 7 additions & 3 deletions packages/models-library/src/models_library/rabbitmq_messages.py
@@ -82,7 +82,9 @@ class ProgressType(StrAutoEnum):


class ProgressMessageMixin(RabbitMessageBase):
channel_name: Literal["simcore.services.progress"] = "simcore.services.progress"
channel_name: Literal[
"simcore.services.progress.v2"
] = "simcore.services.progress.v2"
progress_type: ProgressType = (
ProgressType.COMPUTATION_RUNNING
) # NOTE: backwards compatible
@@ -93,11 +95,13 @@ def routing_key(self) -> str | None:


class ProgressRabbitMessageNode(ProgressMessageMixin, NodeMessageBase):
...
def routing_key(self) -> str | None:
return f"{self.project_id}.{self.node_id}"


class ProgressRabbitMessageProject(ProgressMessageMixin, ProjectMessageBase):
...
def routing_key(self) -> str | None:
return f"{self.project_id}.all_nodes"


class InstrumentationRabbitMessage(RabbitMessageBase, NodeMessageBase):
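With this change the progress messages move to a versioned topic channel ("simcore.services.progress.v2") and each message type supplies its own routing key, so consumers can bind to a single node, to a whole project, or to everything. A minimal sketch of that contract using stand-in dataclasses (the real classes are the pydantic models above):

import uuid
from dataclasses import dataclass

@dataclass
class NodeProgressSketch:
    project_id: uuid.UUID
    node_id: uuid.UUID

    def routing_key(self) -> str:
        # node-scoped progress: one topic per (project, node) pair
        return f"{self.project_id}.{self.node_id}"

@dataclass
class ProjectProgressSketch:
    project_id: uuid.UUID

    def routing_key(self) -> str:
        # project-scoped progress: a single "all_nodes" topic per project
        return f"{self.project_id}.all_nodes"

message = NodeProgressSketch(project_id=uuid.uuid4(), node_id=uuid.uuid4())
print(message.routing_key())  # "<project-uuid>.<node-uuid>"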
18 changes: 13 additions & 5 deletions packages/service-library/src/servicelib/rabbitmq.py
@@ -198,11 +198,19 @@ async def _on_message(
message: aio_pika.abc.AbstractIncomingMessage,
) -> None:
async with message.process(requeue=True):
with log_context(
_logger, logging.DEBUG, msg=f"Message received {message}"
):
if not await message_handler(message.body):
await message.nack()
try:
with log_context(
_logger, logging.DEBUG, msg=f"Message received {message}"
):
if not await message_handler(message.body):
await message.nack()
except Exception: # pylint: disable=broad-exception-caught
_logger.exception(
"unhandled exception when consuming RabbitMQ message, "
"this is catched but should not happen. "
"Please check, message will be queued back!"
)
await message.nack()

await queue.consume(_on_message)
return queue.name
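The subscriber now wraps the message handler so that any unexpected exception is logged and the message is nacked (and thus requeued) instead of being lost. A self-contained sketch of the same pattern with plain aio_pika; the URL, queue name, and handler are placeholders, and `ignore_processed=True` is added here so the manual `nack()` inside the context manager does not clash with the automatic ack:

import asyncio
import logging

import aio_pika

_logger = logging.getLogger(__name__)

async def message_handler(body: bytes) -> bool:
    # return False to signal that the message could not be processed
    return True

async def consume_forever(amqp_url: str, queue_name: str) -> None:
    connection = await aio_pika.connect_robust(amqp_url)
    channel = await connection.channel()
    queue = await channel.declare_queue(queue_name, durable=True)

    async def _on_message(message: aio_pika.abc.AbstractIncomingMessage) -> None:
        # ack on clean exit, nack (and requeue) on handler failure or unexpected error
        async with message.process(requeue=True, ignore_processed=True):
            try:
                if not await message_handler(message.body):
                    await message.nack()
            except Exception:  # pylint: disable=broad-exception-caught
                _logger.exception("unhandled error while consuming, requeueing message")
                await message.nack()

    await queue.consume(_on_message)
    await asyncio.Future()  # keep the consumer alive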
@@ -17,14 +17,24 @@
TaskNotCompletedError,
TaskNotFoundError,
)
from servicelib.long_running_tasks._models import TaskProgress, TaskResult, TaskStatus
from servicelib.long_running_tasks._models import (
ProgressPercent,
TaskProgress,
TaskResult,
TaskStatus,
)
from servicelib.long_running_tasks._task import TasksManager, start_task
from tenacity._asyncio import AsyncRetrying
from tenacity.retry import retry_if_exception_type
from tenacity.stop import stop_after_delay
from tenacity.wait import wait_fixed

# UTILS
_RETRY_PARAMS = dict(
reraise=True,
wait=wait_fixed(0.1),
stop=stop_after_delay(60),
retry=retry_if_exception_type(AssertionError),
)


async def a_background_task(
@@ -35,7 +45,7 @@ async def a_background_task(
"""sleeps and raises an error or returns 42"""
for i in range(total_sleep):
await asyncio.sleep(1)
task_progress.update(percent=float((i + 1) / total_sleep))
task_progress.update(percent=ProgressPercent((i + 1) / total_sleep))
if raise_when_finished:
raise RuntimeError("raised this error as instructed")

@@ -73,12 +83,14 @@ async def test_unchecked_task_is_auto_removed(tasks_manager: TasksManager):
total_sleep=10 * TEST_CHECK_STALE_INTERVAL_S,
)
await asyncio.sleep(2 * TEST_CHECK_STALE_INTERVAL_S + 1)
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_status(task_id, with_task_context=None)
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_result(task_id, with_task_context=None)
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_result_old(task_id)
async for attempt in AsyncRetrying(**_RETRY_PARAMS):
with attempt:
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_status(task_id, with_task_context=None)
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_result(task_id, with_task_context=None)
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_result_old(task_id)


async def test_checked_once_task_is_auto_removed(tasks_manager: TasksManager):
@@ -91,12 +103,14 @@ async def test_checked_once_task_is_auto_removed(tasks_manager: TasksManager):
# check once (different branch in code)
tasks_manager.get_task_status(task_id, with_task_context=None)
await asyncio.sleep(2 * TEST_CHECK_STALE_INTERVAL_S + 1)
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_status(task_id, with_task_context=None)
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_result(task_id, with_task_context=None)
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_result_old(task_id)
async for attempt in AsyncRetrying(**_RETRY_PARAMS):
with attempt:
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_status(task_id, with_task_context=None)
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_result(task_id, with_task_context=None)
with pytest.raises(TaskNotFoundError):
tasks_manager.get_task_result_old(task_id)


async def test_checked_task_is_not_auto_removed(tasks_manager: TasksManager):
Expand All @@ -106,12 +120,7 @@ async def test_checked_task_is_not_auto_removed(tasks_manager: TasksManager):
raise_when_finished=False,
total_sleep=5 * TEST_CHECK_STALE_INTERVAL_S,
)
async for attempt in AsyncRetrying(
reraise=True,
wait=wait_fixed(TEST_CHECK_STALE_INTERVAL_S / 10.0),
stop=stop_after_delay(60),
retry=retry_if_exception_type(AssertionError),
):
async for attempt in AsyncRetrying(**_RETRY_PARAMS):
with attempt:
status = tasks_manager.get_task_status(task_id, with_task_context=None)
assert status.done, f"task {task_id} not complete"
@@ -220,12 +229,7 @@ async def test_get_result_missing(tasks_manager: TasksManager):
async def test_get_result_finished_with_error(tasks_manager: TasksManager):
task_id = start_task(tasks_manager=tasks_manager, task=failing_background_task)
# wait for result
async for attempt in AsyncRetrying(
reraise=True,
wait=wait_fixed(0.1),
stop=stop_after_delay(60),
retry=retry_if_exception_type(AssertionError),
):
async for attempt in AsyncRetrying(**_RETRY_PARAMS):
with attempt:
assert tasks_manager.get_task_status(task_id, with_task_context=None).done

@@ -236,12 +240,7 @@ async def test_get_result_finished_with_error(tasks_manager: TasksManager):
async def test_get_result_old_finished_with_error(tasks_manager: TasksManager):
task_id = start_task(tasks_manager=tasks_manager, task=failing_background_task)
# wait for result
async for attempt in AsyncRetrying(
reraise=True,
wait=wait_fixed(0.1),
stop=stop_after_delay(60),
retry=retry_if_exception_type(AssertionError),
):
async for attempt in AsyncRetrying(**_RETRY_PARAMS):
with attempt:
assert tasks_manager.get_task_status(task_id, with_task_context=None).done

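The repeated `AsyncRetrying` keyword arguments are factored into the shared `_RETRY_PARAMS` dict, and the auto-removal tests now poll until their assertions hold instead of asserting once after a fixed sleep. A standalone sketch of that polling pattern (the condition below is a placeholder):

import asyncio

from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_delay, wait_fixed

# shared policy: check every 0.1 s, give up after 60 s, retry only on failed assertions
_RETRY_PARAMS = dict(
    reraise=True,
    wait=wait_fixed(0.1),
    stop=stop_after_delay(60),
    retry=retry_if_exception_type(AssertionError),
)

async def wait_until(condition) -> None:
    # condition: a synchronous callable returning True once the expected state is reached
    async for attempt in AsyncRetrying(**_RETRY_PARAMS):
        with attempt:
            assert condition(), "condition not met yet"

async def main() -> None:
    deadline = asyncio.get_running_loop().time() + 0.5
    await wait_until(lambda: asyncio.get_running_loop().time() >= deadline)

asyncio.run(main())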
2 changes: 1 addition & 1 deletion services/autoscaling/tests/unit/test_utils_rabbitmq.py
@@ -150,7 +150,7 @@ async def test_post_task_progress_message(
await client.subscribe(
ProgressRabbitMessageNode.get_channel_name(),
mocked_message_handler,
topics=None,
topics=[BIND_TO_ALL_TOPICS],
)

service_with_labels = await create_service(
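Because the progress channel is now topic-routed, a subscription without any binding would no longer see the node-scoped messages, so the test binds to all topics. A hedged sketch of the underlying AMQP bindings with plain aio_pika, assuming the channel is declared as a topic exchange and that `BIND_TO_ALL_TOPICS` is the "#" wildcard:

import aio_pika

async def bind_progress_queue(
    channel: aio_pika.abc.AbstractChannel, project_id: str, node_id: str
) -> aio_pika.abc.AbstractQueue:
    exchange = await channel.declare_exchange(
        "simcore.services.progress.v2", aio_pika.ExchangeType.TOPIC
    )
    queue = await channel.declare_queue(exclusive=True)
    await queue.bind(exchange, routing_key="#")  # everything (BIND_TO_ALL_TOPICS)
    # narrower bindings matching the routing keys defined in rabbitmq_messages.py:
    # await queue.bind(exchange, routing_key=f"{project_id}.*")           # all nodes of one project
    # await queue.bind(exchange, routing_key=f"{project_id}.{node_id}")   # a single node
    return queue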
3 changes: 0 additions & 3 deletions services/dask-sidecar/docker/boot.sh
@@ -35,9 +35,6 @@ fi

if [ ${DASK_START_AS_SCHEDULER+x} ]; then
scheduler_version=$(dask scheduler --version)
mkdir --parents /home/scu/.config/dask
dask_logging=$(printf "logging:\n distributed: %s\n distributed.scheduler: %s" "${LOG_LEVEL:-warning}" "${LOG_LEVEL:-warning}")
echo "$dask_logging" >> /home/scu/.config/dask/distributed.yaml

echo "$INFO" "Starting as dask scheduler:${scheduler_version}..."
if [ "${SC_BOOT_MODE}" = "debug-ptvsd" ]; then
@@ -18,13 +18,12 @@
service_started,
service_stopped,
)
from servicelib.json_serialization import json_dumps
from servicelib.logging_utils import log_context
from servicelib.rabbitmq import RabbitMQClient
from servicelib.utils import logged_gather

from ..projects import projects_api
from ..projects.exceptions import NodeNotFoundError, ProjectNotFoundError
from ..projects.exceptions import ProjectNotFoundError
from ..rabbitmq import get_rabbitmq_client
from ..socketio.messages import (
SOCKET_IO_EVENT,
@@ -40,92 +39,87 @@
_logger = logging.getLogger(__name__)


async def _handle_computation_running_progress(
def _convert_to_project_progress_event(
message: ProgressRabbitMessageProject,
) -> SocketMessageDict:
return SocketMessageDict(
event_type=SOCKET_IO_PROJECT_PROGRESS_EVENT,
data={
"project_id": message.project_id,
"user_id": message.user_id,
"progress_type": message.progress_type,
"progress": message.progress,
},
)


async def _convert_to_node_update_event(
app: web.Application, message: ProgressRabbitMessageNode
) -> bool:
) -> SocketMessageDict | None:
try:
project = await projects_api.update_project_node_progress(
app,
message.user_id,
f"{message.project_id}",
f"{message.node_id}",
progress=message.progress,
project = await projects_api.get_project_for_user(
app, f"{message.project_id}", message.user_id
)
if project and not await projects_api.is_project_hidden(
app, message.project_id
):
messages: list[SocketMessageDict] = [
{
"event_type": SOCKET_IO_NODE_UPDATED_EVENT,
"data": {
"project_id": message.project_id,
"node_id": message.node_id,
"data": project["workbench"][f"{message.node_id}"],
},
}
]
await send_messages(app, f"{message.user_id}", messages)
return True
if f"{message.node_id}" in project["workbench"]:
# update the project node progress with the latest value
project["workbench"][f"{message.node_id}"].update(
{"progress": round(message.progress * 100.0)}
)
return SocketMessageDict(
event_type=SOCKET_IO_NODE_UPDATED_EVENT,
data={
"project_id": message.project_id,
"node_id": message.node_id,
"data": project["workbench"][f"{message.node_id}"],
},
)
_logger.warning("node not found: '%s'", message.dict())
except ProjectNotFoundError:
_logger.warning(
"project related to received rabbitMQ progress message not found: '%s'",
json_dumps(message, indent=2),
)
return True
except NodeNotFoundError:
_logger.warning(
"node related to received rabbitMQ progress message not found: '%s'",
json_dumps(message, indent=2),
)
return True
return False
_logger.warning("project not found: '%s'", message.dict())
return None


def _convert_to_node_progress_event(
message: ProgressRabbitMessageNode,
) -> SocketMessageDict:
return SocketMessageDict(
event_type=SOCKET_IO_NODE_PROGRESS_EVENT,
data={
"project_id": message.project_id,
"node_id": message.node_id,
"user_id": message.user_id,
"progress_type": message.progress_type,
"progress": message.progress,
},
)


async def _progress_message_parser(app: web.Application, data: bytes) -> bool:
# update corresponding project, node, progress value
rabbit_message: (
ProgressRabbitMessageNode | ProgressRabbitMessageProject
) = parse_raw_as(
rabbit_message = parse_raw_as(
Union[ProgressRabbitMessageNode, ProgressRabbitMessageProject], data
)
socket_message: SocketMessageDict | None = None
if isinstance(rabbit_message, ProgressRabbitMessageProject):
socket_message = _convert_to_project_progress_event(rabbit_message)
elif rabbit_message.progress_type is ProgressType.COMPUTATION_RUNNING:
socket_message = await _convert_to_node_update_event(app, rabbit_message)
else:
socket_message = _convert_to_node_progress_event(rabbit_message)
if socket_message:
await send_messages(app, rabbit_message.user_id, [socket_message])

if rabbit_message.progress_type is ProgressType.COMPUTATION_RUNNING:
# NOTE: backward compatibility, this progress is kept in the project
assert isinstance(rabbit_message, ProgressRabbitMessageNode) # nosec
return await _handle_computation_running_progress(app, rabbit_message)

# NOTE: other types of progress are transient
is_type_message_node = type(rabbit_message) == ProgressRabbitMessageNode
socket_message: SocketMessageDict = {
"event_type": (
SOCKET_IO_NODE_PROGRESS_EVENT
if is_type_message_node
else SOCKET_IO_PROJECT_PROGRESS_EVENT
),
"data": {
"project_id": rabbit_message.project_id,
"user_id": rabbit_message.user_id,
"progress_type": rabbit_message.progress_type,
"progress": rabbit_message.progress,
},
}
if is_type_message_node:
socket_message["data"]["node_id"] = rabbit_message.node_id
await send_messages(app, f"{rabbit_message.user_id}", [socket_message])
return True


async def _log_message_parser(app: web.Application, data: bytes) -> bool:
rabbit_message = LoggerRabbitMessage.parse_raw(data)

if not await projects_api.is_project_hidden(app, rabbit_message.project_id):
socket_messages: list[SocketMessageDict] = [
{
"event_type": SOCKET_IO_LOG_EVENT,
"data": rabbit_message.dict(exclude={"user_id", "channel_name"}),
}
]
await send_messages(app, f"{rabbit_message.user_id}", socket_messages)
socket_messages: list[SocketMessageDict] = [
{
"event_type": SOCKET_IO_LOG_EVENT,
"data": rabbit_message.dict(exclude={"user_id", "channel_name"}),
}
]
await send_messages(app, rabbit_message.user_id, socket_messages)
return True


@@ -154,7 +148,7 @@ async def _events_message_parser(app: web.Application, data: bytes) -> bool:
},
}
]
await send_messages(app, f"{rabbit_message.user_id}", socket_messages)
await send_messages(app, rabbit_message.user_id, socket_messages)
return True


@@ -176,7 +170,7 @@ async def _events_message_parser(app: web.Application, data: bytes) -> bool:
(
ProgressRabbitMessageNode.get_channel_name(),
_progress_message_parser,
{},
dict(topics=[]),
),
(
InstrumentationRabbitMessage.get_channel_name(),
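The webserver no longer persists computational progress via `update_project_node_progress`; it only converts each RabbitMQ message into a single socket.io event (a project-progress event, a node-update event with the progress patched into the workbench payload, or a transient node-progress event) and forwards it to the user. A hedged, simplified sketch of the parse-and-dispatch shape with stand-in pydantic v1 models (the real models, event-name constants, and `SocketMessageDict` live in the simcore packages):

from typing import Union

from pydantic import BaseModel, parse_raw_as

class NodeProgress(BaseModel):
    project_id: str
    node_id: str
    user_id: int
    progress: float

class ProjectProgress(BaseModel):
    project_id: str
    user_id: int
    progress: float

def to_socket_event(data: bytes) -> dict:
    # pydantic v1 tries the Union members in order and returns the first that validates;
    # a payload without "node_id" therefore falls through to ProjectProgress
    message = parse_raw_as(Union[NodeProgress, ProjectProgress], data)
    if isinstance(message, NodeProgress):
        return {"event_type": "nodeProgress", "data": message.dict()}  # placeholder event name
    return {"event_type": "projectProgress", "data": message.dict()}  # placeholder event name

print(to_socket_event(b'{"project_id": "p1", "user_id": 3, "progress": 0.5}'))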
