Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ subscribe to socketio room based on the user_id #5270

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from ..users import GroupID, UserID


class SocketIORoom(str):
GitHK marked this conversation as resolved.
Show resolved Hide resolved
__slots__ = ()

@classmethod
def from_socket_id(cls, socket_id: str) -> "SocketIORoom":
GitHK marked this conversation as resolved.
Show resolved Hide resolved
return cls(socket_id)

@classmethod
def from_group_id(cls, group_id: GroupID) -> "SocketIORoom":
return cls(group_id)
GitHK marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def from_user_id(cls, user_id: UserID) -> "SocketIORoom":
return cls(f"user:{user_id}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# pylint:disable=redefined-outer-name

import pytest
from faker import Faker
from models_library.api_schemas_webserver.socketio import SocketIORoom
from models_library.users import GroupID, UserID


@pytest.fixture
def user_id(faker: Faker) -> UserID:
return UserID(faker.pyint())


@pytest.fixture
def group_id(faker: Faker) -> GroupID:
return GroupID(faker.pyint())


@pytest.fixture
def socket_id(faker: Faker) -> str:
return faker.pystr()


def test_socketio_room(user_id: UserID, group_id: GroupID, socket_id: str):
assert SocketIORoom.from_user_id(user_id) == f"user:{user_id}"
assert SocketIORoom.from_group_id(group_id) == f"{group_id}"
assert SocketIORoom.from_socket_id(socket_id) == socket_id
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
SOCKET_IO_PAYMENT_COMPLETED_EVENT,
SOCKET_IO_PAYMENT_METHOD_ACKED_EVENT,
)
from models_library.api_schemas_webserver.socketio import SocketIORoom
from models_library.api_schemas_webserver.wallets import (
PaymentMethodTransaction,
PaymentTransaction,
Expand Down Expand Up @@ -48,7 +49,7 @@ async def notify_payment_completed(
return await self._sio_manager.emit(
SOCKET_IO_PAYMENT_COMPLETED_EVENT,
data=jsonable_encoder(payment, by_alias=True),
room=f"{user_primary_group_id}",
room=SocketIORoom.from_group_id(user_primary_group_id),
)

async def notify_payment_method_acked(
Expand All @@ -61,7 +62,7 @@ async def notify_payment_method_acked(
return await self._sio_manager.emit(
SOCKET_IO_PAYMENT_METHOD_ACKED_EVENT,
data=jsonable_encoder(payment_method, by_alias=True),
room=f"{user_primary_group_id}",
room=SocketIORoom.from_group_id(user_primary_group_id),
)


Expand All @@ -77,7 +78,6 @@ async def _on_startup() -> None:
assert Notifier.get_from_app_state(app) == notifier # nosec

async def _on_shutdown() -> None:

with contextlib.suppress(AttributeError):
Notifier.pop_from_app_state(app)

Expand Down
8 changes: 4 additions & 4 deletions services/payments/tests/unit/test_services_notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from models_library.api_schemas_payments.socketio import (
SOCKET_IO_PAYMENT_COMPLETED_EVENT,
)
from models_library.api_schemas_webserver.socketio import SocketIORoom
from models_library.api_schemas_webserver.wallets import PaymentTransaction
from models_library.users import GroupID, UserID
from pydantic import parse_obj_as
Expand Down Expand Up @@ -95,13 +96,12 @@ def socketio_server_events(
mocker: MockerFixture,
user_primary_group_id: GroupID,
) -> dict[str, AsyncMock]:

user_room_name = f"{user_primary_group_id}"
room_name = SocketIORoom.from_group_id(user_primary_group_id)

# handlers
async def connect(sid: str, environ):
print("connecting", sid)
await socketio_server.enter_room(sid, user_room_name)
await socketio_server.enter_room(sid, room_name)

async def on_check(sid, data):
print("check", sid, data)
Expand All @@ -111,7 +111,7 @@ async def on_payment(sid, data):

async def disconnect(sid: str):
print("disconnecting", sid)
await socketio_server.leave_room(sid, user_room_name)
await socketio_server.leave_room(sid, room_name)

# spies
spy_connect = mocker.AsyncMock(wraps=connect)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
WalletCreditsMessage,
)
from models_library.socketio import SocketMessageDict
from models_library.users import GroupID
from pydantic import parse_raw_as
from servicelib.logging_utils import log_catch, log_context
from servicelib.rabbitmq import RabbitMQClient
Expand Down Expand Up @@ -152,13 +153,13 @@ async def _osparc_credits_message_parser(app: web.Application, data: bytes) -> b
wallet_groups = await wallets_api.list_wallet_groups_with_read_access_by_wallet(
app, wallet_id=rabbit_message.wallet_id
)
rooms_to_notify = [f"{item.gid}" for item in wallet_groups]
rooms_to_notify: list[GroupID] = [item.gid for item in wallet_groups]
for room in rooms_to_notify:
await send_group_messages(app, room, socket_messages)
return True


_EXCHANGE_TO_PARSER_CONFIG: Final[tuple[SubcribeArgumentsTuple, ...,]] = (
_EXCHANGE_TO_PARSER_CONFIG: Final[tuple[SubcribeArgumentsTuple, ...]] = (
SubcribeArgumentsTuple(
LoggerRabbitMessage.get_channel_name(),
_log_message_parser,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
ServiceResourcesDictHelpers,
)
from models_library.socketio import SocketMessageDict
from models_library.users import UserID
from models_library.users import GroupID, UserID
from models_library.utils.fastapi_encoders import jsonable_encoder
from models_library.wallets import ZERO_CREDITS, WalletID, WalletInfo
from pydantic import ByteSize, parse_obj_as
Expand Down Expand Up @@ -1458,10 +1458,8 @@ async def notify_project_state_update(
if notify_only_user:
await send_messages(app, user_id=f"{notify_only_user}", messages=messages)
else:
rooms_to_notify = [
f"{gid}"
for gid, rights in project["accessRights"].items()
if rights["read"]
rooms_to_notify: list[GroupID] = [
gid for gid, rights in project["accessRights"].items() if rights["read"]
]
for room in rooms_to_notify:
await send_group_messages(app, room, messages)
Expand All @@ -1476,8 +1474,8 @@ async def notify_project_node_update(
if await is_project_hidden(app, ProjectID(project["uuid"])):
return

rooms_to_notify = [
f"{gid}" for gid, rights in project["accessRights"].items() if rights["read"]
rooms_to_notify: list[GroupID] = [
gid for gid, rights in project["accessRights"].items() if rights["read"]
]

messages: list[SocketMessageDict] = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any

from aiohttp import web
from models_library.api_schemas_webserver.socketio import SocketIORoom
from models_library.socketio import SocketMessageDict
from models_library.users import UserID
from servicelib.aiohttp.observer import emit
Expand Down Expand Up @@ -89,7 +90,10 @@ async def _set_user_in_group_rooms(

sio = get_socket_server(app)
for group in groups:
sio.enter_room(socket_id, f"{group['gid']}")
# NOTE socketio need to be upgraded that's why enter_room is not an awaitable
sio.enter_room(socket_id, SocketIORoom.from_group_id(group["gid"]))

sio.enter_room(socket_id, SocketIORoom.from_user_id(user_id))


#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from typing import Final

from aiohttp.web import Application
from models_library.api_schemas_webserver.socketio import SocketIORoom
from models_library.socketio import SocketMessageDict
from models_library.users import UserID
from servicelib.aiohttp.application_keys import APP_FIRE_AND_FORGET_TASKS_KEY
from models_library.users import GroupID, UserID
from servicelib.json_serialization import json_dumps
from servicelib.utils import fire_and_forget_task, logged_gather
from servicelib.utils import logged_gather
from socketio import AsyncServer

from ..resource_manager.user_sessions import managed_resource
Expand Down Expand Up @@ -44,7 +44,11 @@ async def send_messages(

await logged_gather(
*(
sio.emit(message["event_type"], json_dumps(message["data"]), room=sid)
sio.emit(
message["event_type"],
json_dumps(message["data"]),
room=SocketIORoom.from_socket_id(sid),
)
for message in messages
for sid in socket_ids
),
Expand All @@ -54,32 +58,16 @@ async def send_messages(
)


async def post_messages(
GitHK marked this conversation as resolved.
Show resolved Hide resolved
app: Application, user_id: UserID, messages: Sequence[SocketMessageDict]
) -> None:
fire_and_forget_task(
send_messages(app, user_id, messages),
task_suffix_name=f"post_message_{user_id=}",
fire_and_forget_tasks_collection=app[APP_FIRE_AND_FORGET_TASKS_KEY],
)


async def post_group_messages(
app: Application, room: str, messages: Sequence[SocketMessageDict]
) -> None:
fire_and_forget_task(
send_group_messages(app, room, messages),
task_suffix_name=f"post_group_messages_{room=}",
fire_and_forget_tasks_collection=app[APP_FIRE_AND_FORGET_TASKS_KEY],
)


async def send_group_messages(
app: Application, room: str, messages: Sequence[SocketMessageDict]
app: Application, group_id: GroupID, messages: Sequence[SocketMessageDict]
) -> None:
sio: AsyncServer = get_socket_server(app)
send_tasks = [
sio.emit(message["event_type"], json_dumps(message["data"]), room=room)
sio.emit(
message["event_type"],
json_dumps(message["data"]),
room=SocketIORoom.from_group_id(group_id),
)
for message in messages
]

Expand Down
Loading