Rabbitmq/only listens to logs if needed (#4180)
sanderegg authored May 8, 2023
1 parent ad4d076 commit 2e8d8e4
Showing 15 changed files with 238 additions and 36 deletions.
3 changes: 0 additions & 3 deletions packages/pytest-simcore/src/pytest_simcore/rabbit_service.py
@@ -142,6 +142,3 @@ def _creator(client_name: str) -> RabbitMQClient:
yield _creator
# cleanup, properly close the clients
await asyncio.gather(*(client.close() for client in created_clients))
for client in created_clients:
assert client._channel_pool # pylint: disable=protected-access
assert client._channel_pool.is_closed # pylint: disable=protected-access
8 changes: 5 additions & 3 deletions packages/service-library/src/servicelib/rabbitmq.py
@@ -107,9 +107,11 @@ async def rpc_initialize(self) -> None:
await self._rpc.initialize()

async def close(self) -> None:
with log_context(_logger, logging.INFO, msg="Closing connection to RabbitMQ"):
assert self._channel_pool # nosec
await self._channel_pool.close()
with log_context(
_logger,
logging.INFO,
msg=f"{self.client_name} closing connection to RabbitMQ",
):
assert self._connection_pool # nosec
await self._connection_pool.close()

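The reworked close() wraps the shutdown in servicelib's log_context helper and now closes the connection pool, which owns the channels. A minimal sketch of a log_context-style helper, assuming it emits paired start/finish log lines; the real servicelib.logging_utils implementation may differ:

# Sketch of a log_context-style helper (assumed behavior, not the actual
# servicelib.logging_utils implementation).
import logging
from collections.abc import Iterator
from contextlib import contextmanager

@contextmanager
def log_context(logger: logging.Logger, level: int, msg: str) -> Iterator[None]:
    logger.log(level, "Starting %s ...", msg)
    yield
    logger.log(level, "Finished %s", msg)

logging.basicConfig(level=logging.INFO)
with log_context(logging.getLogger(__name__), logging.INFO, msg="closing connection to RabbitMQ"):
    pass  # the real client awaits self._connection_pool.close() here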
2 changes: 0 additions & 2 deletions packages/service-library/tests/test_rabbitmq.py
@@ -44,8 +44,6 @@ async def test_rabbit_client(rabbit_client_name: str, rabbit_service: RabbitSett
await client.close()
assert client._connection_pool
assert client._connection_pool.is_closed
assert client._channel_pool
assert client._channel_pool.is_closed


@pytest.fixture
2 changes: 2 additions & 0 deletions services/docker-compose.yml
@@ -266,6 +266,8 @@ services:
- traefik.http.services.${SWARM_STACK_NAME}_webserver.loadbalancer.healthcheck.path=/v0/
- traefik.http.services.${SWARM_STACK_NAME}_webserver.loadbalancer.healthcheck.interval=2000ms
- traefik.http.services.${SWARM_STACK_NAME}_webserver.loadbalancer.healthcheck.timeout=1000ms
# NOTE: stickiness must remain until the long running tasks in the webserver are removed
# and also https://github.com/ITISFoundation/osparc-simcore/pull/4180 is resolved.
- traefik.http.services.${SWARM_STACK_NAME}_webserver.loadbalancer.sticky.cookie=true
- traefik.http.services.${SWARM_STACK_NAME}_webserver.loadbalancer.sticky.cookie.samesite=lax
- traefik.http.services.${SWARM_STACK_NAME}_webserver.loadbalancer.sticky.cookie.httponly=true
15 changes: 7 additions & 8 deletions services/web/server/src/simcore_service_webserver/groups_api.py
@@ -1,6 +1,6 @@
import logging
import re
from typing import Any, Optional
from typing import Any

import sqlalchemy as sa
from aiohttp import web
@@ -139,7 +139,7 @@ async def get_product_group_for_user(

async def create_user_group(
app: web.Application, user_id: int, new_group: dict
) -> dict[str, str]:
) -> dict[str, Any]:
engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
result = await conn.execute(
@@ -281,9 +281,9 @@ async def add_user_in_group(
user_id: int,
gid: int,
*,
new_user_id: Optional[int] = None,
new_user_email: Optional[str] = None,
access_rights: Optional[AccessRightsDict] = None,
new_user_id: int | None = None,
new_user_email: str | None = None,
access_rights: AccessRightsDict | None = None,
) -> None:
"""
adds new_user (either by id or email) in group (with gid) owned by user_id
@@ -323,14 +323,13 @@ async def add_user_in_group(
async def _get_user_in_group_permissions(
conn: SAConnection, gid: int, the_user_id_in_group: int
) -> RowProxy:

# now get the user
result = await conn.execute(
sa.select([users, user_to_groups.c.access_rights])
.select_from(users.join(user_to_groups, users.c.id == user_to_groups.c.uid))
.where(and_(user_to_groups.c.gid == gid, users.c.id == the_user_id_in_group))
)
the_user: RowProxy = await result.fetchone()
the_user: RowProxy | None = await result.fetchone()
if not the_user:
raise UserInGroupNotFoundError(the_user_id_in_group, gid)
return the_user
@@ -410,7 +409,7 @@ async def delete_user_in_group(
)


async def get_group_from_gid(app: web.Application, gid: int) -> Optional[RowProxy]:
async def get_group_from_gid(app: web.Application, gid: int) -> RowProxy | None:
engine: Engine = app[APP_DB_ENGINE_KEY]

async with engine.acquire() as conn:
@@ -0,0 +1,3 @@
from typing import Final

APP_RABBITMQ_CONSUMERS_KEY: Final[str] = f"{__name__}.rabbit_consumers"
@@ -20,7 +20,7 @@
)
from servicelib.json_serialization import json_dumps
from servicelib.logging_utils import log_context
from servicelib.rabbitmq import BIND_TO_ALL_TOPICS, RabbitMQClient
from servicelib.rabbitmq import RabbitMQClient
from servicelib.utils import logged_gather

from ..projects import projects_api
@@ -35,6 +35,7 @@
SocketMessageDict,
send_messages,
)
from ._constants import APP_RABBITMQ_CONSUMERS_KEY

_logger = logging.getLogger(__name__)

@@ -170,7 +171,7 @@ async def _events_message_parser(app: web.Application, data: bytes) -> bool:
(
LoggerRabbitMessage.get_channel_name(),
_log_message_parser,
dict(topics=[BIND_TO_ALL_TOPICS]),
dict(topics=[]),
),
(
ProgressRabbitMessageNode.get_channel_name(),
@@ -201,6 +202,12 @@ async def setup_rabbitmq_consumers(app: web.Application) -> AsyncIterator[None]:
for exchange_name, parser_fct, queue_kwargs in EXCHANGE_TO_PARSER_CONFIG
)
)
app[APP_RABBITMQ_CONSUMERS_KEY] = {
exchange_name: queue_name
for (exchange_name, *_), queue_name in zip(
EXCHANGE_TO_PARSER_CONFIG, subscribed_queues
)
}

yield

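The new app[APP_RABBITMQ_CONSUMERS_KEY] mapping records which queue each exchange subscription created, so the project_logs module below can bind and unbind topics on the right queue. A standalone sketch of that bookkeeping, with invented exchange names and a stub standing in for RabbitMQClient.subscribe (assumed here to return the queue name):

# Sketch: pair each exchange with the queue its subscription returned.
import asyncio

EXCHANGE_TO_PARSER_CONFIG = (  # invented entries, for illustration only
    ("simcore.services.logs", "log_parser", dict(topics=[])),
    ("simcore.services.progress", "progress_parser", dict(topics=[])),
)

async def fake_subscribe(exchange_name: str, *_args, **_kwargs) -> str:
    return f"queue-for-{exchange_name}"  # stands in for RabbitMQClient.subscribe

async def main() -> None:
    subscribed_queues = await asyncio.gather(
        *(fake_subscribe(exchange) for exchange, *_ in EXCHANGE_TO_PARSER_CONFIG)
    )
    exchange_to_queue = {
        exchange_name: queue_name
        for (exchange_name, *_), queue_name in zip(
            EXCHANGE_TO_PARSER_CONFIG, subscribed_queues
        )
    }
    assert exchange_to_queue["simcore.services.logs"] == "queue-for-simcore.services.logs"

asyncio.run(main())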
@@ -0,0 +1,31 @@
from aiohttp import web
from models_library.projects import ProjectID
from models_library.rabbitmq_messages import LoggerRabbitMessage
from servicelib.rabbitmq import RabbitMQClient

from ..rabbitmq import get_rabbitmq_client
from ._constants import APP_RABBITMQ_CONSUMERS_KEY


def _get_queue_name_from_exchange_name(app: web.Application, exchange_name: str) -> str:
exchange_to_queues = app[APP_RABBITMQ_CONSUMERS_KEY]
queue_name = exchange_to_queues[exchange_name]
return queue_name


async def subscribe(app: web.Application, project_id: ProjectID) -> None:
rabbit_client: RabbitMQClient = get_rabbitmq_client(app)
exchange_name = LoggerRabbitMessage.get_channel_name()
queue_name = _get_queue_name_from_exchange_name(app, exchange_name)
await rabbit_client.add_topics(
exchange_name, queue_name, topics=[f"{project_id}.*"]
)


async def unsubscribe(app: web.Application, project_id: ProjectID) -> None:
rabbit_client: RabbitMQClient = get_rabbitmq_client(app)
exchange_name = LoggerRabbitMessage.get_channel_name()
queue_name = _get_queue_name_from_exchange_name(app, exchange_name)
await rabbit_client.remove_topics(
exchange_name, queue_name, topics=[f"{project_id}.*"]
)
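subscribe and unsubscribe work by binding and unbinding the topic pattern f"{project_id}.*" on the already-subscribed log queue. Under RabbitMQ topic semantics, * matches exactly one dot-separated word, so the binding selects every node of a single project and nothing else, assuming log messages are routed with a "<project_id>.<node_id>" key. A simplified, runnable illustration of that matching rule:

# Simplified RabbitMQ topic matching: "*" matches exactly one dot-separated
# word. The "<project_id>.<node_id>" routing-key layout is an assumption.
import re

def topic_matches(binding: str, routing_key: str) -> bool:
    pattern = "^" + re.escape(binding).replace(r"\*", r"[^.]+") + "$"
    return re.match(pattern, routing_key) is not None

project_id = "11111111-2222-3333-4444-555555555555"
assert topic_matches(f"{project_id}.*", f"{project_id}.node-42")
assert not topic_matches(f"{project_id}.*", f"other-project.node-42")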
@@ -5,10 +5,12 @@
import logging

from aiohttp import web
from models_library.projects import ProjectID
from servicelib.observer import event_registry as _event_registry
from servicelib.observer import observe
from servicelib.utils import logged_gather

from ..notifications import project_logs
from ..resource_manager.websocket_manager import PROJECT_ID_KEY, managed_resource
from .projects_api import retrieve_and_notify_project_locked_state

@@ -23,6 +25,10 @@ async def _on_user_disconnected(
with managed_resource(user_id, client_session_id, app) as rt:
list_projects: list[str] = await rt.find(PROJECT_ID_KEY)

await logged_gather(
*[project_logs.unsubscribe(app, ProjectID(prj)) for prj in list_projects]
)

await logged_gather(
*[
retrieve_and_notify_project_locked_state(
@@ -33,7 +39,7 @@
)


def setup_project_events(_app: web.Application):
def setup_project_events(_app: web.Application) -> None:
# For the moment, this is only used as a placeholder to import this file
# This way the functions above are registered as handlers of a given event
# using the @observe decorator
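As the comment says, setup_project_events exists only to force the module import; the @observe decorator registers the handlers as a side effect. A minimal sketch of that registration-by-import pattern, with an illustrative event name and registry rather than the actual servicelib.observer API:

# Registration-by-import, sketched; event name and registry are illustrative.
from collections import defaultdict
from collections.abc import Awaitable, Callable

event_registry: dict[str, list[Callable[..., Awaitable[None]]]] = defaultdict(list)

def observe(event_name: str):
    def _decorator(func):
        event_registry[event_name].append(func)  # side effect at import time
        return func
    return _decorator

@observe("SIGNAL_USER_DISCONNECTED")  # illustrative event name
async def _on_user_disconnected(*args, **kwargs) -> None:
    ...

def setup_project_events(_app=None) -> None:
    # placeholder: merely importing this module registered the handler above
    assert event_registry["SIGNAL_USER_DISCONNECTED"]

setup_project_events()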
@@ -29,6 +29,7 @@
from .._meta import api_version_prefix as VTAG
from ..director_v2.exceptions import DirectorServiceError
from ..login.decorators import login_required
from ..notifications import project_logs
from ..products.plugin import Product, get_current_product
from ..security.decorators import permission_required
from . import projects_api
@@ -101,6 +102,9 @@ async def open_project(request: web.Request) -> web.Response:
request.app, path_params.project_id, req_ctx.product_name
)

# we now need to receive logs for that project
await project_logs.subscribe(request.app, path_params.project_id)

# user id opened project uuid
if not query_params.disable_service_auto_start:
with contextlib.suppress(ProjectStartsTooManyDynamicNodes):
Expand Down Expand Up @@ -189,6 +193,7 @@ async def close_project(request: web.Request) -> web.Response:
X_SIMCORE_USER_AGENT, UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
),
)
await project_logs.unsubscribe(request.app, path_params.project_id)
raise web.HTTPNoContent(content_type=MIMETYPE_APPLICATION_JSON)
except ProjectNotFoundError as exc:
raise web.HTTPNotFound(
@@ -1,6 +1,7 @@
# pylint: disable=redefined-outer-name
# pylint: disable=unused-argument
# pylint: disable=unused-variable
# pylint: disable=too-many-arguments

import asyncio
import json
@@ -45,6 +46,7 @@
from simcore_service_webserver.diagnostics.plugin import setup_diagnostics
from simcore_service_webserver.director_v2.plugin import setup_director_v2
from simcore_service_webserver.login.plugin import setup_login
from simcore_service_webserver.notifications import project_logs
from simcore_service_webserver.notifications.plugin import setup_notifications
from simcore_service_webserver.projects.plugin import setup_projects
from simcore_service_webserver.projects.project_models import ProjectDict
@@ -73,20 +75,23 @@

pytest_simcore_ops_services_selection = []

_STABLE_DELAY_S = 3


async def _assert_no_handler_not_called(handler: mock.Mock) -> None:
with pytest.raises(RetryError):
async for attempt in AsyncRetrying(
retry=retry_always,
stop=stop_after_delay(5),
stop=stop_after_delay(_STABLE_DELAY_S),
reraise=True,
wait=wait_fixed(1),
):
with attempt:
print(
f"--> checking no mesage reached webclient... {attempt.retry_state.attempt_number} attempt"
f"--> checking no message reached webclient for {attempt.retry_state.attempt_number}/{_STABLE_DELAY_S}s..."
)
handler.assert_not_called()
print(f"no calls received for {_STABLE_DELAY_S}s. very good.")


async def _assert_handler_called(handler: mock.Mock, expected_call: mock._Call) -> None:
@@ -177,6 +182,9 @@ async def rabbitmq_publisher(
@pytest.mark.parametrize(
"sender_same_user_id", [True, False], ids=lambda id: f"same_sender_id={id}"
)
@pytest.mark.parametrize(
"subscribe_to_logs", [True, False], ids=lambda id: f"subscribed={id}"
)
async def test_log_workflow(
client: TestClient,
rabbitmq_publisher: RabbitMQClient,
@@ -190,9 +198,10 @@ async def test_log_workflow(
project_hidden: bool,
aiopg_engine: aiopg.sa.Engine,
sender_same_user_id: bool,
subscribe_to_logs: bool,
):
"""
RabbitMQ --> Webserver --> Redis --> webclient (socketio)
RabbitMQ (TOPIC) --> Webserver --> Redis --> webclient (socketio)
"""
socket_io_conn = await socketio_client_factory(None, client)
@@ -208,19 +217,25 @@ async def test_log_workflow(
mock_log_handler = mocker.MagicMock()
socket_io_conn.on(SOCKET_IO_LOG_EVENT, handler=mock_log_handler)

project_id = ProjectID(user_project["uuid"])
random_node_id_in_project = NodeID(choice(list(user_project["workbench"])))
sender_user_id = UserID(logged_user["id"])
if sender_same_user_id is False:
sender_user_id = UserID(faker.pyint(min_value=logged_user["id"] + 1))

if subscribe_to_logs:
assert client.app
await project_logs.subscribe(client.app, project_id)

log_message = LoggerRabbitMessage(
user_id=sender_user_id,
project_id=ProjectID(user_project["uuid"]),
project_id=project_id,
node_id=random_node_id_in_project,
messages=[faker.text() for _ in range(10)],
)
await rabbitmq_publisher.publish(log_message.channel_name, log_message)

call_expected = not project_hidden and sender_same_user_id
call_expected = not project_hidden and sender_same_user_id and subscribe_to_logs
if call_expected:
expected_call = jsonable_encoder(
log_message, exclude={"user_id", "channel_name"}
@@ -230,6 +245,58 @@ async def test_log_workflow(
await _assert_no_handler_not_called(mock_log_handler)


@pytest.mark.parametrize("user_role", [UserRole.GUEST], ids=str)
async def test_log_workflow_only_receives_messages_if_subscribed(
client: TestClient,
rabbitmq_publisher: RabbitMQClient,
logged_user: UserInfoDict,
user_project: ProjectDict,
faker: Faker,
mocker: MockerFixture,
):
"""
RabbitMQ (TOPIC) --> Webserver
"""
mocked_send_messages = mocker.patch(
"simcore_service_webserver.notifications._rabbitmq_consumers.send_messages",
autospec=True,
)

project_id = ProjectID(user_project["uuid"])
random_node_id_in_project = NodeID(choice(list(user_project["workbench"])))
sender_user_id = UserID(logged_user["id"])

assert client.app
await project_logs.subscribe(client.app, project_id)

log_message = LoggerRabbitMessage(
user_id=sender_user_id,
project_id=project_id,
node_id=random_node_id_in_project,
messages=[faker.text() for _ in range(10)],
)
await rabbitmq_publisher.publish(log_message.channel_name, log_message)
await _assert_handler_called(
mocked_send_messages,
mock.call(
client.app,
f"{log_message.user_id}",
[
{
"event_type": SOCKET_IO_LOG_EVENT,
"data": log_message.dict(exclude={"user_id", "channel_name"}),
}
],
),
)
mocked_send_messages.reset_mock()

# when unsubscribed, we do not receive the messages anymore
await project_logs.unsubscribe(client.app, project_id)
await _assert_no_handler_not_called(mocked_send_messages)


@pytest.mark.parametrize("user_role", [UserRole.GUEST], ids=str)
@pytest.mark.parametrize(
"progress_type",