🐛 make logstreaming callback safer #5633

6 changes: 4 additions & 2 deletions packages/service-library/src/servicelib/rabbitmq/_client.py
@@ -7,7 +7,7 @@
 import aio_pika
 from pydantic import NonNegativeInt

-from ..logging_utils import log_context
+from ..logging_utils import log_catch, log_context
 from ._client_base import RabbitMQClientBase
 from ._models import MessageHandler, RabbitMessage
 from ._utils import (
@@ -82,7 +82,9 @@ async def _on_message(
             if not await message_handler(message.body):
                 await _safe_nack(message_handler, max_retries_upon_error, message)
         except Exception:  # pylint: disable=broad-exception-caught
-            await _safe_nack(message_handler, max_retries_upon_error, message)
+            _logger.exception("Exception raised when handling message")
+            with log_catch(_logger, reraise=False):
+                await _safe_nack(message_handler, max_retries_upon_error, message)


 @dataclass
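
In plain terms, the handler now logs the original error first and only then attempts the nack, and the nack itself is wrapped so that a failure there is logged rather than raised out of the consumer callback. Below is a minimal sketch of that pattern, assuming `log_catch` from `servicelib.logging_utils` behaves like a context manager that logs the exception and, with `reraise=False`, suppresses it (the stand-in definition is an assumption, not the servicelib code):

```python
# Sketch only: an assumed stand-in for servicelib.logging_utils.log_catch,
# illustrating why wrapping the nack keeps the callback from raising.
import contextlib
import logging

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)


@contextlib.contextmanager
def log_catch(logger: logging.Logger, *, reraise: bool = True):
    try:
        yield
    except Exception:  # pylint: disable=broad-exception-caught
        logger.exception("Unhandled exception")
        if reraise:
            raise


def on_message(nack) -> None:
    try:
        raise RuntimeError("message handler failed")
    except Exception:  # pylint: disable=broad-exception-caught
        _logger.exception("Exception raised when handling message")
        with log_catch(_logger, reraise=False):
            nack()  # even if the nack itself fails, it is only logged


def flaky_nack() -> None:
    raise ConnectionError("broker unavailable while nacking")


on_message(flaky_nack)  # logs both errors, but never propagates an exception
```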
services/api-server/src/simcore_service_api_server/services/log_streaming.py
@@ -5,7 +5,8 @@

from models_library.rabbitmq_messages import LoggerRabbitMessage
from models_library.users import UserID
from pydantic import NonNegativeInt, ValidationError
from pydantic import NonNegativeInt
from servicelib.logging_utils import log_catch
from servicelib.rabbitmq import RabbitMQClient

from ..models.schemas.jobs import JobID, JobLog
@@ -31,7 +32,7 @@ class LogStreamerRegistionConflict(LogDistributionBaseException):
 class LogDistributor:
     def __init__(self, rabbitmq_client: RabbitMQClient):
         self._rabbit_client = rabbitmq_client
-        self._log_streamers: dict[JobID, Queue] = {}
+        self._log_streamers: dict[JobID, Queue[JobLog]] = {}
         self._queue_name: str

     async def setup(self):
@@ -53,34 +54,24 @@ async def __aexit__(self, exc_type, exc, tb):
         await self.teardown()

     async def _distribute_logs(self, data: bytes):
-        try:
-            got = LoggerRabbitMessage.parse_raw(
-                data
-            )  # rabbitmq client safe_nacks the message if this deserialization fails
-        except ValidationError as e:
-            _logger.debug(
-                "Could not parse log message from RabbitMQ in LogDistributor._distribute_logs"
-            )
-            raise e
-        _logger.debug(
-            "LogDistributor._distribute_logs received message message from RabbitMQ: %s",
-            got.json(),
-        )
-        item = JobLog(
-            job_id=got.project_id,
-            node_id=got.node_id,
-            log_level=got.log_level,
-            messages=got.messages,
-        )
-        queue = self._log_streamers.get(item.job_id)
-        if queue is None:
-            raise LogStreamerNotRegistered(
-                f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered"
-            )
-        await queue.put(item)
-        return True
+        with log_catch(_logger, reraise=False):
+            got = LoggerRabbitMessage.parse_raw(data)
+            item = JobLog(
+                job_id=got.project_id,
+                node_id=got.node_id,
+                log_level=got.log_level,
+                messages=got.messages,
+            )
+            queue = self._log_streamers.get(item.job_id)
+            if queue is None:
+                raise LogStreamerNotRegistered(
+                    f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered"
+                )
+            await queue.put(item)
+            return True
+        return False

-    async def register(self, job_id: JobID, queue: Queue):
+    async def register(self, job_id: JobID, queue: Queue[JobLog]):
         if job_id in self._log_streamers:
             raise LogStreamerRegistionConflict(
                 f"A stream was already connected to {job_id=}. Only a single stream can be connected at the time"
70 changes: 58 additions & 12 deletions services/api-server/tests/unit/test_services_rabbitmq.py
@@ -11,7 +11,7 @@
 from collections.abc import AsyncIterable, Callable
 from contextlib import asynccontextmanager
 from datetime import datetime, timedelta
-from typing import Final, Iterable
+from typing import Final, Iterable, Literal
 from unittest.mock import AsyncMock

 import httpx
@@ -24,9 +24,9 @@
 from models_library.projects import ProjectID
 from models_library.projects_nodes_io import NodeID
 from models_library.projects_state import RunningState
-from models_library.rabbitmq_messages import LoggerRabbitMessage
+from models_library.rabbitmq_messages import LoggerRabbitMessage, RabbitMessageBase
 from models_library.users import UserID
-from pydantic import parse_obj_as
+from pydantic import ValidationError, parse_obj_as
 from pytest_mock import MockerFixture, MockFixture
 from pytest_simcore.helpers.utils_envs import (
     EnvVarsDict,
@@ -170,15 +170,23 @@ def produce_logs(
     create_rabbitmq_client: Callable[[str], RabbitMQClient],
     user_id: UserID,
 ):
-    async def _go(name, project_id_=None, node_id_=None, messages_=None, level_=None):
+    async def _go(
+        name,
+        project_id_=None,
+        node_id_=None,
+        messages_=None,
+        level_=None,
+        log_message: RabbitMessageBase | None = None,
+    ):
         rabbitmq_producer = create_rabbitmq_client(f"pytest_producer_{name}")
-        log_message = LoggerRabbitMessage(
-            user_id=user_id,
-            project_id=project_id_ or faker.uuid4(),
-            node_id=node_id_,
-            messages=messages_ or [faker.text() for _ in range(10)],
-            log_level=level_ or logging.INFO,
-        )
+        if log_message is None:
+            log_message = LoggerRabbitMessage(
+                user_id=user_id,
+                project_id=project_id_ or faker.uuid4(),
+                node_id=node_id_,
+                messages=messages_ or [faker.text() for _ in range(10)],
+                log_level=level_ or logging.INFO,
+            )
         await rabbitmq_producer.publish(log_message.channel_name, log_message)

     return _go
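
As an illustration only (a hypothetical test, not part of this PR), the extended fixture can now publish a caller-built message verbatim through the new `log_message` parameter instead of assembling a `LoggerRabbitMessage` itself; the sketch relies on the module's existing imports and fixtures:

```python
# Hypothetical usage of the extended produce_logs fixture; names mirror the
# fixtures used in this test module, but this test does not exist in the PR.
async def test_publish_prebuilt_log_message(produce_logs, user_id, project_id, node_id):
    prebuilt = LoggerRabbitMessage(  # already imported at the top of the module
        user_id=user_id,
        project_id=project_id,
        node_id=node_id,
        messages=["hello from a prebuilt message"],
        log_level=logging.INFO,
    )
    await produce_logs("prebuilt", log_message=prebuilt)
```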
@@ -381,7 +389,6 @@ async def test_log_streamer_with_distributor(
     async def _log_publisher():
         while not computation_done():
             msg: str = faker.text()
-            await asyncio.sleep(0.2)
             await produce_logs("expected", project_id, node_id, [msg], logging.DEBUG)
             published_logs.append(msg)

@@ -399,6 +406,45 @@ async def _log_publisher():
     assert published_logs == collected_messages


+async def test_log_streamer_not_raise_with_distributor(
+    client: httpx.AsyncClient,
+    app: FastAPI,
+    user_id,
+    project_id: ProjectID,
+    node_id: NodeID,
+    produce_logs: Callable,
+    log_streamer_with_distributor: LogStreamer,
+    faker: Faker,
+    computation_done: Callable[[], bool],
+):
+    class InvalidLoggerRabbitMessage(LoggerRabbitMessage):
+        channel_name: Literal["simcore.services.logs.v2"] = "simcore.services.logs.v2"
+        node_id: NodeID | None
+        messages: int
+        log_level: int = logging.INFO
+
+        def routing_key(self) -> str:
+            return f"{self.project_id}.{self.log_level}"
+
+    log_rabbit_message = InvalidLoggerRabbitMessage(
+        user_id=user_id,
+        project_id=project_id,
+        node_id=node_id,
+        messages=100,
+        log_level=logging.INFO,
+    )
+    with pytest.raises(ValidationError):
+        LoggerRabbitMessage.parse_obj(log_rabbit_message.dict())
+
+    await produce_logs("expected", log_message=log_rabbit_message)
+
+    ii: int = 0
+    async for log in log_streamer_with_distributor.log_generator():
+        _ = JobLog.parse_raw(log)
+        ii += 1
+    assert ii == 0
+
+
 async def test_log_generator(mocker: MockFixture, faker: Faker):
     mocker.patch(
         "simcore_service_api_server.services.log_streaming.LogStreamer._project_done",