Skip to content

Commit

Permalink
✨ subscribe to socketio room based on the user_id (#5270)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrei Neagu <[email protected]>
  • Loading branch information
GitHK and Andrei Neagu authored Jan 26, 2024
1 parent 3c18555 commit 3693767
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from ..users import GroupID, UserID


class SocketIORoom(str):
__slots__ = ()

@classmethod
def from_socket_id(cls, socket_id: str) -> "SocketIORoom":
return cls(socket_id)

@classmethod
def from_group_id(cls, group_id: GroupID) -> "SocketIORoom":
return cls(f"group:{group_id}")

@classmethod
def from_user_id(cls, user_id: UserID) -> "SocketIORoom":
return cls(f"user:{user_id}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# pylint:disable=redefined-outer-name

import pytest
from faker import Faker
from models_library.api_schemas_webserver.socketio import SocketIORoom
from models_library.users import GroupID, UserID


@pytest.fixture
def user_id(faker: Faker) -> UserID:
return UserID(faker.pyint())


@pytest.fixture
def group_id(faker: Faker) -> GroupID:
return GroupID(faker.pyint())


@pytest.fixture
def socket_id(faker: Faker) -> str:
return faker.pystr()


def test_socketio_room(user_id: UserID, group_id: GroupID, socket_id: str):
assert SocketIORoom.from_user_id(user_id) == f"user:{user_id}"
assert SocketIORoom.from_group_id(group_id) == f"group:{group_id}"
assert SocketIORoom.from_socket_id(socket_id) == socket_id
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
SOCKET_IO_PAYMENT_COMPLETED_EVENT,
SOCKET_IO_PAYMENT_METHOD_ACKED_EVENT,
)
from models_library.api_schemas_webserver.socketio import SocketIORoom
from models_library.api_schemas_webserver.wallets import (
PaymentMethodTransaction,
PaymentTransaction,
Expand Down Expand Up @@ -48,7 +49,7 @@ async def notify_payment_completed(
return await self._sio_manager.emit(
SOCKET_IO_PAYMENT_COMPLETED_EVENT,
data=jsonable_encoder(payment, by_alias=True),
room=f"{user_primary_group_id}",
room=SocketIORoom.from_group_id(user_primary_group_id),
)

async def notify_payment_method_acked(
Expand All @@ -61,7 +62,7 @@ async def notify_payment_method_acked(
return await self._sio_manager.emit(
SOCKET_IO_PAYMENT_METHOD_ACKED_EVENT,
data=jsonable_encoder(payment_method, by_alias=True),
room=f"{user_primary_group_id}",
room=SocketIORoom.from_group_id(user_primary_group_id),
)


Expand All @@ -77,7 +78,6 @@ async def _on_startup() -> None:
assert Notifier.get_from_app_state(app) == notifier # nosec

async def _on_shutdown() -> None:

with contextlib.suppress(AttributeError):
Notifier.pop_from_app_state(app)

Expand Down
8 changes: 4 additions & 4 deletions services/payments/tests/unit/test_services_notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from models_library.api_schemas_payments.socketio import (
SOCKET_IO_PAYMENT_COMPLETED_EVENT,
)
from models_library.api_schemas_webserver.socketio import SocketIORoom
from models_library.api_schemas_webserver.wallets import PaymentTransaction
from models_library.users import GroupID, UserID
from pydantic import parse_obj_as
Expand Down Expand Up @@ -95,13 +96,12 @@ def socketio_server_events(
mocker: MockerFixture,
user_primary_group_id: GroupID,
) -> dict[str, AsyncMock]:

user_room_name = f"{user_primary_group_id}"
room_name = SocketIORoom.from_group_id(user_primary_group_id)

# handlers
async def connect(sid: str, environ):
print("connecting", sid)
await socketio_server.enter_room(sid, user_room_name)
await socketio_server.enter_room(sid, room_name)

async def on_check(sid, data):
print("check", sid, data)
Expand All @@ -111,7 +111,7 @@ async def on_payment(sid, data):

async def disconnect(sid: str):
print("disconnecting", sid)
await socketio_server.leave_room(sid, user_room_name)
await socketio_server.leave_room(sid, room_name)

# spies
spy_connect = mocker.AsyncMock(wraps=connect)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
WalletCreditsMessage,
)
from models_library.socketio import SocketMessageDict
from models_library.users import GroupID
from pydantic import parse_raw_as
from servicelib.logging_utils import log_catch, log_context
from servicelib.rabbitmq import RabbitMQClient
Expand Down Expand Up @@ -152,13 +153,13 @@ async def _osparc_credits_message_parser(app: web.Application, data: bytes) -> b
wallet_groups = await wallets_api.list_wallet_groups_with_read_access_by_wallet(
app, wallet_id=rabbit_message.wallet_id
)
rooms_to_notify = [f"{item.gid}" for item in wallet_groups]
rooms_to_notify: list[GroupID] = [item.gid for item in wallet_groups]
for room in rooms_to_notify:
await send_group_messages(app, room, socket_messages)
return True


_EXCHANGE_TO_PARSER_CONFIG: Final[tuple[SubcribeArgumentsTuple, ...,]] = (
_EXCHANGE_TO_PARSER_CONFIG: Final[tuple[SubcribeArgumentsTuple, ...]] = (
SubcribeArgumentsTuple(
LoggerRabbitMessage.get_channel_name(),
_log_message_parser,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
ServiceResourcesDictHelpers,
)
from models_library.socketio import SocketMessageDict
from models_library.users import UserID
from models_library.users import GroupID, UserID
from models_library.utils.fastapi_encoders import jsonable_encoder
from models_library.wallets import ZERO_CREDITS, WalletID, WalletInfo
from pydantic import ByteSize, parse_obj_as
Expand Down Expand Up @@ -1458,10 +1458,8 @@ async def notify_project_state_update(
if notify_only_user:
await send_messages(app, user_id=f"{notify_only_user}", messages=messages)
else:
rooms_to_notify = [
f"{gid}"
for gid, rights in project["accessRights"].items()
if rights["read"]
rooms_to_notify: list[GroupID] = [
gid for gid, rights in project["accessRights"].items() if rights["read"]
]
for room in rooms_to_notify:
await send_group_messages(app, room, messages)
Expand All @@ -1476,8 +1474,8 @@ async def notify_project_node_update(
if await is_project_hidden(app, ProjectID(project["uuid"])):
return

rooms_to_notify = [
f"{gid}" for gid, rights in project["accessRights"].items() if rights["read"]
rooms_to_notify: list[GroupID] = [
gid for gid, rights in project["accessRights"].items() if rights["read"]
]

messages: list[SocketMessageDict] = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any

from aiohttp import web
from models_library.api_schemas_webserver.socketio import SocketIORoom
from models_library.socketio import SocketMessageDict
from models_library.users import UserID
from servicelib.aiohttp.observer import emit
Expand Down Expand Up @@ -89,7 +90,10 @@ async def _set_user_in_group_rooms(

sio = get_socket_server(app)
for group in groups:
sio.enter_room(socket_id, f"{group['gid']}")
# NOTE socketio need to be upgraded that's why enter_room is not an awaitable
sio.enter_room(socket_id, SocketIORoom.from_group_id(group["gid"]))

sio.enter_room(socket_id, SocketIORoom.from_user_id(user_id))


#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from typing import Final

from aiohttp.web import Application
from models_library.api_schemas_webserver.socketio import SocketIORoom
from models_library.socketio import SocketMessageDict
from models_library.users import UserID
from servicelib.aiohttp.application_keys import APP_FIRE_AND_FORGET_TASKS_KEY
from models_library.users import GroupID, UserID
from servicelib.json_serialization import json_dumps
from servicelib.utils import fire_and_forget_task, logged_gather
from servicelib.utils import logged_gather
from socketio import AsyncServer

from ..resource_manager.user_sessions import managed_resource
Expand Down Expand Up @@ -44,7 +44,11 @@ async def send_messages(

await logged_gather(
*(
sio.emit(message["event_type"], json_dumps(message["data"]), room=sid)
sio.emit(
message["event_type"],
json_dumps(message["data"]),
room=SocketIORoom.from_socket_id(sid),
)
for message in messages
for sid in socket_ids
),
Expand All @@ -54,32 +58,16 @@ async def send_messages(
)


async def post_messages(
app: Application, user_id: UserID, messages: Sequence[SocketMessageDict]
) -> None:
fire_and_forget_task(
send_messages(app, user_id, messages),
task_suffix_name=f"post_message_{user_id=}",
fire_and_forget_tasks_collection=app[APP_FIRE_AND_FORGET_TASKS_KEY],
)


async def post_group_messages(
app: Application, room: str, messages: Sequence[SocketMessageDict]
) -> None:
fire_and_forget_task(
send_group_messages(app, room, messages),
task_suffix_name=f"post_group_messages_{room=}",
fire_and_forget_tasks_collection=app[APP_FIRE_AND_FORGET_TASKS_KEY],
)


async def send_group_messages(
app: Application, room: str, messages: Sequence[SocketMessageDict]
app: Application, group_id: GroupID, messages: Sequence[SocketMessageDict]
) -> None:
sio: AsyncServer = get_socket_server(app)
send_tasks = [
sio.emit(message["event_type"], json_dumps(message["data"]), room=room)
sio.emit(
message["event_type"],
json_dumps(message["data"]),
room=SocketIORoom.from_group_id(group_id),
)
for message in messages
]

Expand Down

0 comments on commit 3693767

Please sign in to comment.