Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 make logstreaming callback safer #5633

6 changes: 4 additions & 2 deletions packages/service-library/src/servicelib/rabbitmq/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import aio_pika
from pydantic import NonNegativeInt

from ..logging_utils import log_context
from ..logging_utils import log_catch, log_context
from ._client_base import RabbitMQClientBase
from ._models import MessageHandler, RabbitMessage
from ._utils import (
Expand Down Expand Up @@ -82,7 +82,9 @@ async def _on_message(
if not await message_handler(message.body):
await _safe_nack(message_handler, max_retries_upon_error, message)
except Exception: # pylint: disable=broad-exception-caught
await _safe_nack(message_handler, max_retries_upon_error, message)
_logger.exception("Exception raised when handling message")
with log_catch(_logger):
bisgaard-itis marked this conversation as resolved.
Show resolved Hide resolved
await _safe_nack(message_handler, max_retries_upon_error, message)


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from models_library.rabbitmq_messages import LoggerRabbitMessage
from models_library.users import UserID
from pydantic import NonNegativeInt, ValidationError
from pydantic import NonNegativeInt
from servicelib.rabbitmq import RabbitMQClient

from ..models.schemas.jobs import JobID, JobLog
Expand All @@ -31,7 +31,7 @@ class LogStreamerRegistionConflict(LogDistributionBaseException):
class LogDistributor:
def __init__(self, rabbitmq_client: RabbitMQClient):
    """Store the RabbitMQ client and prepare the per-job log-queue registry.

    Each JobID maps to the single asyncio ``Queue[JobLog]`` that a connected
    log streamer registered for that job; incoming RabbitMQ log messages are
    forwarded to the matching queue.
    """
    self._rabbit_client = rabbitmq_client
    # NOTE: the diff artifact duplicating this line (untyped `Queue`) is
    # removed; only the parametrized annotation from the new revision is kept.
    self._log_streamers: dict[JobID, Queue[JobLog]] = {}
    # Declared (not assigned) here for typing only — presumably assigned in
    # setup(); accessing it before then raises AttributeError. TODO confirm.
    self._queue_name: str

async def setup(self):
Expand All @@ -53,34 +53,27 @@ async def __aexit__(self, exc_type, exc, tb):
await self.teardown()

async def _distribute_logs(self, data: bytes):
queue: Queue | None = None
try:
got = LoggerRabbitMessage.parse_raw(
data
) # rabbitmq client safe_nacks the message if this deserialization fails
except ValidationError as e:
_logger.debug(
"Could not parse log message from RabbitMQ in LogDistributor._distribute_logs"
got = LoggerRabbitMessage.parse_raw(data)
item = JobLog(
job_id=got.project_id,
node_id=got.node_id,
log_level=got.log_level,
messages=got.messages,
)
raise e
_logger.debug(
"LogDistributor._distribute_logs received message message from RabbitMQ: %s",
got.json(),
)
item = JobLog(
job_id=got.project_id,
node_id=got.node_id,
log_level=got.log_level,
messages=got.messages,
)
queue = self._log_streamers.get(item.job_id)
if queue is None:
raise LogStreamerNotRegistered(
f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered"
)
await queue.put(item)
queue = self._log_streamers.get(item.job_id)
if queue is None:
raise LogStreamerNotRegistered(
f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered"
)
await queue.put(item)
except Exception as exc: # pylint: disable=broad-except
bisgaard-itis marked this conversation as resolved.
Show resolved Hide resolved
bisgaard-itis marked this conversation as resolved.
Show resolved Hide resolved
_logger.exception("Exception raised in log distributor callback")
return False
return True

async def register(self, job_id: JobID, queue: Queue):
async def register(self, job_id: JobID, queue: Queue[JobLog]):
if job_id in self._log_streamers:
raise LogStreamerRegistionConflict(
f"A stream was already connected to {job_id=}. Only a single stream can be connected at the time"
Expand Down
70 changes: 58 additions & 12 deletions services/api-server/tests/unit/test_services_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import AsyncIterable, Callable
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Final, Iterable
from typing import Final, Iterable, Literal
from unittest.mock import AsyncMock

import httpx
Expand All @@ -24,9 +24,9 @@
from models_library.projects import ProjectID
from models_library.projects_nodes_io import NodeID
from models_library.projects_state import RunningState
from models_library.rabbitmq_messages import LoggerRabbitMessage
from models_library.rabbitmq_messages import LoggerRabbitMessage, RabbitMessageBase
from models_library.users import UserID
from pydantic import parse_obj_as
from pydantic import ValidationError, parse_obj_as
from pytest_mock import MockerFixture, MockFixture
from pytest_simcore.helpers.utils_envs import (
EnvVarsDict,
Expand Down Expand Up @@ -170,15 +170,23 @@ def produce_logs(
create_rabbitmq_client: Callable[[str], RabbitMQClient],
user_id: UserID,
):
async def _go(name, project_id_=None, node_id_=None, messages_=None, level_=None):
async def _go(
name,
project_id_=None,
node_id_=None,
messages_=None,
level_=None,
log_message: RabbitMessageBase | None = None,
):
rabbitmq_producer = create_rabbitmq_client(f"pytest_producer_{name}")
log_message = LoggerRabbitMessage(
user_id=user_id,
project_id=project_id_ or faker.uuid4(),
node_id=node_id_,
messages=messages_ or [faker.text() for _ in range(10)],
log_level=level_ or logging.INFO,
)
if log_message is None:
log_message = LoggerRabbitMessage(
user_id=user_id,
project_id=project_id_ or faker.uuid4(),
node_id=node_id_,
messages=messages_ or [faker.text() for _ in range(10)],
log_level=level_ or logging.INFO,
)
await rabbitmq_producer.publish(log_message.channel_name, log_message)

return _go
Expand Down Expand Up @@ -381,7 +389,6 @@ async def test_log_streamer_with_distributor(
async def _log_publisher():
while not computation_done():
msg: str = faker.text()
await asyncio.sleep(0.2)
await produce_logs("expected", project_id, node_id, [msg], logging.DEBUG)
published_logs.append(msg)

Expand All @@ -399,6 +406,45 @@ async def _log_publisher():
assert published_logs == collected_messages


async def test_log_streamer_not_raise_with_distributor(
    client: httpx.AsyncClient,
    app: FastAPI,
    user_id,
    project_id: ProjectID,
    node_id: NodeID,
    produce_logs: Callable,
    log_streamer_with_distributor: LogStreamer,
    faker: Faker,
    computation_done: Callable[[], bool],
):
    """Publish a log message that fails ``LoggerRabbitMessage`` validation and
    verify the distributor swallows it: the log stream yields nothing and no
    exception propagates out of the consumer callback.

    NOTE(review): ``client``, ``app``, ``faker`` and ``computation_done`` are
    requested but not referenced in the body — presumably needed for their
    fixture side effects; confirm before removing.
    """

    class InvalidLoggerRabbitMessage(LoggerRabbitMessage):
        # Overrides ``messages`` with an incompatible type (int): the instance
        # can be published, but the parent model rejects its payload — proven
        # by the pytest.raises(ValidationError) check below.
        channel_name: Literal["simcore.services.logs.v2"] = "simcore.services.logs.v2"
        node_id: NodeID | None
        messages: int
        log_level: int = logging.INFO

        def routing_key(self) -> str:
            return f"{self.project_id}.{self.log_level}"

    log_rabbit_message = InvalidLoggerRabbitMessage(
        user_id=user_id,
        project_id=project_id,
        node_id=node_id,
        messages=100,
        log_level=logging.INFO,
    )
    # Sanity check: the crafted payload really is invalid for the production
    # model, so the distributor's parse step is guaranteed to fail on it.
    with pytest.raises(ValidationError):
        LoggerRabbitMessage.parse_obj(log_rabbit_message.dict())

    await produce_logs("expected", log_message=log_rabbit_message)

    # The invalid message must be dropped inside the distributor callback, so
    # the generator completes without yielding a single JobLog.
    ii: int = 0
    async for log in log_streamer_with_distributor.log_generator():
        _ = JobLog.parse_raw(log)
        ii += 1
    assert ii == 0


async def test_log_generator(mocker: MockFixture, faker: Faker):
mocker.patch(
"simcore_service_api_server.services.log_streaming.LogStreamer._project_done",
Expand Down
Loading