Skip to content

Commit

Permalink
🐛 make logstreaming callback safer (#5633)
Browse files Browse the repository at this point in the history
  • Loading branch information
bisgaard-itis authored Apr 8, 2024
1 parent c5aa314 commit 5c95fe2
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 42 deletions.
6 changes: 4 additions & 2 deletions packages/service-library/src/servicelib/rabbitmq/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import aio_pika
from pydantic import NonNegativeInt

from ..logging_utils import log_context
from ..logging_utils import log_catch, log_context
from ._client_base import RabbitMQClientBase
from ._models import MessageHandler, RabbitMessage
from ._utils import (
Expand Down Expand Up @@ -82,7 +82,9 @@ async def _on_message(
if not await message_handler(message.body):
await _safe_nack(message_handler, max_retries_upon_error, message)
except Exception: # pylint: disable=broad-exception-caught
await _safe_nack(message_handler, max_retries_upon_error, message)
_logger.exception("Exception raised when handling message")
with log_catch(_logger, reraise=False):
await _safe_nack(message_handler, max_retries_upon_error, message)


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from models_library.rabbitmq_messages import LoggerRabbitMessage
from models_library.users import UserID
from pydantic import NonNegativeInt, ValidationError
from pydantic import NonNegativeInt
from servicelib.logging_utils import log_catch
from servicelib.rabbitmq import RabbitMQClient

from ..models.schemas.jobs import JobID, JobLog
Expand All @@ -31,7 +32,7 @@ class LogStreamerRegistionConflict(LogDistributionBaseException):
class LogDistributor:
def __init__(self, rabbitmq_client: RabbitMQClient):
self._rabbit_client = rabbitmq_client
self._log_streamers: dict[JobID, Queue] = {}
self._log_streamers: dict[JobID, Queue[JobLog]] = {}
self._queue_name: str

async def setup(self):
Expand All @@ -53,34 +54,24 @@ async def __aexit__(self, exc_type, exc, tb):
await self.teardown()

async def _distribute_logs(self, data: bytes):
try:
got = LoggerRabbitMessage.parse_raw(
data
) # rabbitmq client safe_nacks the message if this deserialization fails
except ValidationError as e:
_logger.debug(
"Could not parse log message from RabbitMQ in LogDistributor._distribute_logs"
with log_catch(_logger, reraise=False):
got = LoggerRabbitMessage.parse_raw(data)
item = JobLog(
job_id=got.project_id,
node_id=got.node_id,
log_level=got.log_level,
messages=got.messages,
)
raise e
_logger.debug(
"LogDistributor._distribute_logs received message message from RabbitMQ: %s",
got.json(),
)
item = JobLog(
job_id=got.project_id,
node_id=got.node_id,
log_level=got.log_level,
messages=got.messages,
)
queue = self._log_streamers.get(item.job_id)
if queue is None:
raise LogStreamerNotRegistered(
f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered"
)
await queue.put(item)
return True
queue = self._log_streamers.get(item.job_id)
if queue is None:
raise LogStreamerNotRegistered(
f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered"
)
await queue.put(item)
return True
return False

async def register(self, job_id: JobID, queue: Queue):
async def register(self, job_id: JobID, queue: Queue[JobLog]):
if job_id in self._log_streamers:
raise LogStreamerRegistionConflict(
f"A stream was already connected to {job_id=}. Only a single stream can be connected at the time"
Expand Down
70 changes: 58 additions & 12 deletions services/api-server/tests/unit/test_services_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import AsyncIterable, Callable
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Final, Iterable
from typing import Final, Iterable, Literal
from unittest.mock import AsyncMock

import httpx
Expand All @@ -24,9 +24,9 @@
from models_library.projects import ProjectID
from models_library.projects_nodes_io import NodeID
from models_library.projects_state import RunningState
from models_library.rabbitmq_messages import LoggerRabbitMessage
from models_library.rabbitmq_messages import LoggerRabbitMessage, RabbitMessageBase
from models_library.users import UserID
from pydantic import parse_obj_as
from pydantic import ValidationError, parse_obj_as
from pytest_mock import MockerFixture, MockFixture
from pytest_simcore.helpers.utils_envs import (
EnvVarsDict,
Expand Down Expand Up @@ -170,15 +170,23 @@ def produce_logs(
create_rabbitmq_client: Callable[[str], RabbitMQClient],
user_id: UserID,
):
async def _go(name, project_id_=None, node_id_=None, messages_=None, level_=None):
async def _go(
name,
project_id_=None,
node_id_=None,
messages_=None,
level_=None,
log_message: RabbitMessageBase | None = None,
):
rabbitmq_producer = create_rabbitmq_client(f"pytest_producer_{name}")
log_message = LoggerRabbitMessage(
user_id=user_id,
project_id=project_id_ or faker.uuid4(),
node_id=node_id_,
messages=messages_ or [faker.text() for _ in range(10)],
log_level=level_ or logging.INFO,
)
if log_message is None:
log_message = LoggerRabbitMessage(
user_id=user_id,
project_id=project_id_ or faker.uuid4(),
node_id=node_id_,
messages=messages_ or [faker.text() for _ in range(10)],
log_level=level_ or logging.INFO,
)
await rabbitmq_producer.publish(log_message.channel_name, log_message)

return _go
Expand Down Expand Up @@ -381,7 +389,6 @@ async def test_log_streamer_with_distributor(
async def _log_publisher():
while not computation_done():
msg: str = faker.text()
await asyncio.sleep(0.2)
await produce_logs("expected", project_id, node_id, [msg], logging.DEBUG)
published_logs.append(msg)

Expand All @@ -399,6 +406,45 @@ async def _log_publisher():
assert published_logs == collected_messages


async def test_log_streamer_not_raise_with_distributor(
client: httpx.AsyncClient,
app: FastAPI,
user_id,
project_id: ProjectID,
node_id: NodeID,
produce_logs: Callable,
log_streamer_with_distributor: LogStreamer,
faker: Faker,
computation_done: Callable[[], bool],
):
class InvalidLoggerRabbitMessage(LoggerRabbitMessage):
channel_name: Literal["simcore.services.logs.v2"] = "simcore.services.logs.v2"
node_id: NodeID | None
messages: int
log_level: int = logging.INFO

def routing_key(self) -> str:
return f"{self.project_id}.{self.log_level}"

log_rabbit_message = InvalidLoggerRabbitMessage(
user_id=user_id,
project_id=project_id,
node_id=node_id,
messages=100,
log_level=logging.INFO,
)
with pytest.raises(ValidationError):
LoggerRabbitMessage.parse_obj(log_rabbit_message.dict())

await produce_logs("expected", log_message=log_rabbit_message)

ii: int = 0
async for log in log_streamer_with_distributor.log_generator():
_ = JobLog.parse_raw(log)
ii += 1
assert ii == 0


async def test_log_generator(mocker: MockFixture, faker: Faker):
mocker.patch(
"simcore_service_api_server.services.log_streaming.LogStreamer._project_done",
Expand Down

0 comments on commit 5c95fe2

Please sign in to comment.