diff --git a/packages/service-library/src/servicelib/rabbitmq/_client.py b/packages/service-library/src/servicelib/rabbitmq/_client.py index 00fe0bb1b2a..278d2f93724 100644 --- a/packages/service-library/src/servicelib/rabbitmq/_client.py +++ b/packages/service-library/src/servicelib/rabbitmq/_client.py @@ -7,7 +7,7 @@ import aio_pika from pydantic import NonNegativeInt -from ..logging_utils import log_context +from ..logging_utils import log_catch, log_context from ._client_base import RabbitMQClientBase from ._models import MessageHandler, RabbitMessage from ._utils import ( @@ -82,7 +82,9 @@ async def _on_message( if not await message_handler(message.body): await _safe_nack(message_handler, max_retries_upon_error, message) except Exception: # pylint: disable=broad-exception-caught - await _safe_nack(message_handler, max_retries_upon_error, message) + _logger.exception("Exception raised when handling message") + with log_catch(_logger, reraise=False): + await _safe_nack(message_handler, max_retries_upon_error, message) @dataclass diff --git a/services/api-server/src/simcore_service_api_server/services/log_streaming.py b/services/api-server/src/simcore_service_api_server/services/log_streaming.py index 457a196f7db..6080427c922 100644 --- a/services/api-server/src/simcore_service_api_server/services/log_streaming.py +++ b/services/api-server/src/simcore_service_api_server/services/log_streaming.py @@ -5,7 +5,8 @@ from models_library.rabbitmq_messages import LoggerRabbitMessage from models_library.users import UserID -from pydantic import NonNegativeInt, ValidationError +from pydantic import NonNegativeInt +from servicelib.logging_utils import log_catch from servicelib.rabbitmq import RabbitMQClient from ..models.schemas.jobs import JobID, JobLog @@ -31,7 +32,7 @@ class LogStreamerRegistionConflict(LogDistributionBaseException): class LogDistributor: def __init__(self, rabbitmq_client: RabbitMQClient): self._rabbit_client = rabbitmq_client - self._log_streamers: dict[JobID, Queue] = {} + self._log_streamers: dict[JobID, Queue[JobLog]] = {} self._queue_name: str async def setup(self): @@ -53,34 +54,24 @@ async def __aexit__(self, exc_type, exc, tb): await self.teardown() async def _distribute_logs(self, data: bytes): - try: - got = LoggerRabbitMessage.parse_raw( - data - ) # rabbitmq client safe_nacks the message if this deserialization fails - except ValidationError as e: - _logger.debug( - "Could not parse log message from RabbitMQ in LogDistributor._distribute_logs" + with log_catch(_logger, reraise=False): + got = LoggerRabbitMessage.parse_raw(data) + item = JobLog( + job_id=got.project_id, + node_id=got.node_id, + log_level=got.log_level, + messages=got.messages, ) - raise e - _logger.debug( - "LogDistributor._distribute_logs received message message from RabbitMQ: %s", - got.json(), - ) - item = JobLog( - job_id=got.project_id, - node_id=got.node_id, - log_level=got.log_level, - messages=got.messages, - ) - queue = self._log_streamers.get(item.job_id) - if queue is None: - raise LogStreamerNotRegistered( - f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered" - ) - await queue.put(item) - return True + queue = self._log_streamers.get(item.job_id) + if queue is None: + raise LogStreamerNotRegistered( + f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered" + ) + await queue.put(item) + return True + return False - async def register(self, job_id: JobID, queue: Queue): + async def register(self, job_id: JobID, queue: Queue[JobLog]): if job_id in self._log_streamers: raise LogStreamerRegistionConflict( f"A stream was already connected to {job_id=}. Only a single stream can be connected at the time" diff --git a/services/api-server/tests/unit/test_services_rabbitmq.py b/services/api-server/tests/unit/test_services_rabbitmq.py index a58d99b54e2..ccadeb119c7 100644 --- a/services/api-server/tests/unit/test_services_rabbitmq.py +++ b/services/api-server/tests/unit/test_services_rabbitmq.py @@ -11,7 +11,7 @@ from collections.abc import AsyncIterable, Callable from contextlib import asynccontextmanager from datetime import datetime, timedelta -from typing import Final, Iterable +from typing import Final, Iterable, Literal from unittest.mock import AsyncMock import httpx @@ -24,9 +24,9 @@ from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID from models_library.projects_state import RunningState -from models_library.rabbitmq_messages import LoggerRabbitMessage +from models_library.rabbitmq_messages import LoggerRabbitMessage, RabbitMessageBase from models_library.users import UserID -from pydantic import parse_obj_as +from pydantic import ValidationError, parse_obj_as from pytest_mock import MockerFixture, MockFixture from pytest_simcore.helpers.utils_envs import ( EnvVarsDict, @@ -170,15 +170,23 @@ def produce_logs( create_rabbitmq_client: Callable[[str], RabbitMQClient], user_id: UserID, ): - async def _go(name, project_id_=None, node_id_=None, messages_=None, level_=None): + async def _go( + name, + project_id_=None, + node_id_=None, + messages_=None, + level_=None, + log_message: RabbitMessageBase | None = None, + ): rabbitmq_producer = create_rabbitmq_client(f"pytest_producer_{name}") - log_message = LoggerRabbitMessage( - user_id=user_id, - project_id=project_id_ or faker.uuid4(), - node_id=node_id_, - messages=messages_ or [faker.text() for _ in range(10)], - log_level=level_ or logging.INFO, - ) + if log_message is None: + log_message = LoggerRabbitMessage( + user_id=user_id, + project_id=project_id_ or faker.uuid4(), + node_id=node_id_, + messages=messages_ or [faker.text() for _ in range(10)], + log_level=level_ or logging.INFO, + ) await rabbitmq_producer.publish(log_message.channel_name, log_message) return _go @@ -381,7 +389,6 @@ async def test_log_streamer_with_distributor( async def _log_publisher(): while not computation_done(): msg: str = faker.text() - await asyncio.sleep(0.2) await produce_logs("expected", project_id, node_id, [msg], logging.DEBUG) published_logs.append(msg) @@ -399,6 +406,45 @@ async def _log_publisher(): assert published_logs == collected_messages +async def test_log_streamer_not_raise_with_distributor( + client: httpx.AsyncClient, + app: FastAPI, + user_id, + project_id: ProjectID, + node_id: NodeID, + produce_logs: Callable, + log_streamer_with_distributor: LogStreamer, + faker: Faker, + computation_done: Callable[[], bool], +): + class InvalidLoggerRabbitMessage(LoggerRabbitMessage): + channel_name: Literal["simcore.services.logs.v2"] = "simcore.services.logs.v2" + node_id: NodeID | None + messages: int + log_level: int = logging.INFO + + def routing_key(self) -> str: + return f"{self.project_id}.{self.log_level}" + + log_rabbit_message = InvalidLoggerRabbitMessage( + user_id=user_id, + project_id=project_id, + node_id=node_id, + messages=100, + log_level=logging.INFO, + ) + with pytest.raises(ValidationError): + LoggerRabbitMessage.parse_obj(log_rabbit_message.dict()) + + await produce_logs("expected", log_message=log_rabbit_message) + + ii: int = 0 + async for log in log_streamer_with_distributor.log_generator(): + _ = JobLog.parse_raw(log) + ii += 1 + assert ii == 0 + + async def test_log_generator(mocker: MockFixture, faker: Faker): mocker.patch( "simcore_service_api_server.services.log_streaming.LogStreamer._project_done",