Skip to content

Commit

Permalink
refactor: Improve WebSocketWorker class and user connection management
Browse files Browse the repository at this point in the history
  • Loading branch information
crushr3sist committed Jun 5, 2024
1 parent 1e0e594 commit d4a730c
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions r3almX_backend/realtime_service/chat_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import sys
import traceback
from typing import Dict
from typing import Dict, Set

import aio_pika
from fastapi import Depends, WebSocket, WebSocketDisconnect
Expand All @@ -18,6 +18,10 @@


async def get_rabbit_connection():
"""
Establish a connection to the RabbitMQ server.
:return: RabbitMQ connection object.
"""
global rabbit_connection
if not rabbit_connection or rabbit_connection.is_closed:
rabbit_connection = await aio_pika.connect_robust(
Expand All @@ -28,27 +32,39 @@ async def get_rabbit_connection():

class RoomManager:
def __init__(self):
self.rooms: Dict[str, set] = {}
self.rooms: Dict[str, Set[WebSocket]] = {}
self.rabbit_queues: Dict[str, aio_pika.Queue] = {}
self.rabbit_channels: Dict[str, aio_pika.Channel] = {}
self.broadcast_tasks: Dict[str, asyncio.Task] = {}

async def broadcast(self, room_id: str):
"""
Broadcasts messages to all users in a room.
"""
try:
# Retrieve the RabbitMQ queue for the specified room
queue = self.rabbit_queues.get(room_id)
if queue is None:
print(f"Queue for room {room_id} is not initialized")
return

# Retrieve the set of WebSocket connections for the specified room
room = self.rooms[room_id]

# Iterate over messages in the RabbitMQ queue
async with queue.iterator() as queue_iter:
async for message in queue_iter:
# Process each message asynchronously
async with message.process():
# Decode the message body into user and data
user, data = message.body.decode().split(":", 1)

# Send the message to each WebSocket connection in the room
for websocket in room:
await websocket.send_text(f"{user}: {data}")
await websocket.send_text({user: data})
print(f"Sent message to websocket: {user}: {data}")
except Exception as e:
# Handle exceptions gracefully
exc_type, exc_value, exc_traceback = sys.exc_info()
print(f"Error in broadcast task for room {room_id}: {e}")
traceback.print_exception(
Expand Down Expand Up @@ -112,7 +128,6 @@ async def disconnect_user(self, room_id: str, websocket: WebSocket):


def get_user_from_token(token: str, db) -> User:

try:
payload = jwt.decode(
token, UsersConfig.SECRET_KEY, algorithms=[UsersConfig.ALGORITHM]
Expand All @@ -122,7 +137,7 @@ def get_user_from_token(token: str, db) -> User:
user = get_user_by_username(db, username=token_data.username)
return user
except JWTError as j:
return j
return None


@realtime.websocket("/message/{room_id}")
Expand Down

0 comments on commit d4a730c

Please sign in to comment.