Skip to content

Commit

Permalink
upgrade: switched from in-memory queue to rabbitmq
Browse files Browse the repository at this point in the history
  • Loading branch information
crushr3sist committed Jun 5, 2024
1 parent 53854ef commit a7ae322
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 64 deletions.
50 changes: 34 additions & 16 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
version: "3.8"

volumes:
data:
rabbitmq-data:

networks:
my-network:
Expand All @@ -23,25 +26,40 @@ services:
image: redis:latest
ports:
- "6379:6379"
- target: 6379
networks:
- my-network
restart: always
privileged: true
cap_add:
- SYS_ADMIN


realmx:
build:
context: .
dockerfile: Dockerfile
ports:
- "8000:8000"
depends_on:
- postgres
- redis
rabbitmq:
image: "rabbitmq:3-management-alpine"
hostname: "rabbitmq"
environment:
DATABASE_URL: "postgresql://postgres:ronny@localhost:5432"
REDIS_URL: "redis://redis:6379"
RABBITMQ_ERLANG_COOKIE: "SWQOKODSQALRPCLNMEQG"
RABBITMQ_DEFAULT_USER: "rabbitmq"
RABBITMQ_DEFAULT_PASS: "rabbitmq"
ports:
- "15672:15672"
- "5672:5672"
volumes:
- rabbitmq-data:/var/lib/rabbitmq
- ~/.docker-conf/rabbitmq/data/:/var/lib/rabbitmq/
- ~/.docker-conf/rabbitmq/log/:/var/log/rabbitmq
networks:
- my-network
restart: always

# realmx:
# build:
# context: .
# dockerfile: Dockerfile
# ports:
# - "8000:8000"
# depends_on:
# - postgres
# - redis
# - rabbitmq
# environment:
# DATABASE_URL: "postgresql://postgres:ronny@postgres:5432"
# REDIS_URL: "redis://redis:6379"
# networks:
# - my-network
116 changes: 69 additions & 47 deletions r3almX_backend/realtime_service/chat_service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import sys
import traceback
from queue import Queue
from typing import Dict

import aio_pika
from fastapi import Depends, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt

Expand All @@ -14,41 +14,40 @@
from r3almX_backend.realtime_service.connection_service import NotificationSystem
from r3almX_backend.realtime_service.main import realtime

rabbit_connection = None

def get_user_from_token(token: str, db) -> User:

try:
payload = jwt.decode(
token, UsersConfig.SECRET_KEY, algorithms=[UsersConfig.ALGORITHM]
async def get_rabbit_connection():
global rabbit_connection
if not rabbit_connection or rabbit_connection.is_closed:
rabbit_connection = await aio_pika.connect_robust(
"amqp://rabbitmq:rabbitmq@localhost:5672/"
)
username: str = payload.get("sub")
token_data = TokenData(username=username)
user = get_user_by_username(db, username=token_data.username)
return user
except JWTError as j:
return j
return rabbit_connection


class RoomManager:

def __init__(self):

self.rooms: Dict[str, set] = {}
self.message_queues: Dict[str, Queue] = {}
self.rabbit_queues: Dict[str, aio_pika.Queue] = {}
self.rabbit_channels: Dict[str, aio_pika.Channel] = {}
self.broadcast_tasks: Dict[str, asyncio.Task] = {}

async def broadcast(self, room_id: str):

try:
queue = self.message_queues[room_id]
room = self.rooms[room_id]
while True:
if not queue.empty():
message, user = queue.get_nowait()
for websocket in room:
await websocket.send_text(f"{user}: {message}")
await asyncio.sleep(0.1)
queue = self.rabbit_queues.get(room_id)
if queue is None:
print(f"Queue for room {room_id} is not initialized")
return

room = self.rooms[room_id]
async with queue.iterator() as queue_iter:
async for message in queue_iter:
async with message.process():
user, data = message.body.decode().split(":", 1)
for websocket in room:
await websocket.send_text(f"{user}: {data}")
print(f"Sent message to websocket: {user}: {data}")
except Exception as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
print(f"Error in broadcast task for room {room_id}: {e}")
Expand All @@ -57,55 +56,75 @@ async def broadcast(self, room_id: str):
)

async def start_broadcast_task(self, room_id: str):

if room_id not in self.broadcast_tasks:
print(f"Starting broadcast task for room {room_id}")
self.broadcast_tasks[room_id] = asyncio.create_task(self.broadcast(room_id))

async def stop_broadcast_task(self, room_id: str):

if room_id in self.broadcast_tasks:
print(f"Stopping broadcast task for room {room_id}")
task = self.broadcast_tasks.pop(room_id)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

def add_message_to_queue(self, room_id: str, message: str, user: str):

queue = self.message_queues.get(room_id)
if queue:
queue.put_nowait((message, user))
async def add_message_to_queue(self, room_id: str, message: str, user: str):
channel = self.rabbit_channels.get(room_id)
if channel:
await channel.default_exchange.publish(
aio_pika.Message(body=f"{user}:{message}".encode()),
routing_key=self.rabbit_queues[room_id].name,
)
print(
f"Added message to queue {self.rabbit_queues[room_id].name}: {user}:{message}"
)

async def connect_user(self, room_id: str, websocket: WebSocket):

room = self.rooms.get(room_id)
if room is None:
self.rooms[room_id] = set()
self.message_queues[room_id] = Queue()
self.broadcast_tasks[room_id] = asyncio.create_task(self.broadcast(room_id))
connection = await get_rabbit_connection()
channel = await connection.channel()
queue = await channel.declare_queue(room_id, auto_delete=True)
self.rabbit_queues[room_id] = queue
self.rabbit_channels[room_id] = channel
print(f"Declared queue for room {room_id}")
await self.start_broadcast_task(room_id)
self.rooms[room_id].add(websocket)
print(f"User connected to room {room_id}")

async def disconnect_user(self, room_id: str, websocket: WebSocket):

room = self.rooms.get(room_id)
if room:
room.remove(websocket)
if not room:
del self.rooms[room_id]
del self.message_queues[room_id]
task = self.broadcast_tasks.pop(room_id)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
await self.rabbit_queues[room_id].delete()
del self.rabbit_queues[room_id]
await self.stop_broadcast_task(room_id)
print(f"Deleted queue and stopped task for room {room_id}")


room_manager = RoomManager()
notification_system = NotificationSystem()


def get_user_from_token(token: str, db) -> User:

try:
payload = jwt.decode(
token, UsersConfig.SECRET_KEY, algorithms=[UsersConfig.ALGORITHM]
)
username: str = payload.get("sub")
token_data = TokenData(username=username)
user = get_user_by_username(db, username=token_data.username)
return user
except JWTError as j:
return j


@realtime.websocket("/message/{room_id}")
async def websocket_endpoint(
websocket: WebSocket, room_id: str, token: str, db=Depends(get_db)
Expand All @@ -117,10 +136,9 @@ async def websocket_endpoint(
try:
while True:
data = await websocket.receive_text()
room_manager.add_message_to_queue(room_id, data, user.id)
await room_manager.start_broadcast_task(room_id)
await room_manager.add_message_to_queue(room_id, data, user.id)
await notification_system.send_notification_to_user(
user.id, f"New message in room {room_id}: {data}"
user.id, {"room_id": room_id, "data": data}
)
except WebSocketDisconnect:
await room_manager.disconnect_user(room_id, websocket)
Expand All @@ -129,11 +147,15 @@ async def websocket_endpoint(


@realtime.get("/message/rooms/")
def get_all_connections():

async def get_all_connections():
data = {}
for room_id, room in room_manager.rooms.items():
queue_size = room_manager.message_queues[room_id].qsize()
queue_size = 0
if room_id in room_manager.rabbit_queues:
queue = room_manager.rabbit_queues[room_id]
channel = room_manager.rabbit_channels[room_id]
queue_state = await channel.declare_queue(room_id, passive=True)
queue_size = queue_state.message_count
users = [str(websocket) for websocket in room]
data[room_id] = {
"queue_size": queue_size,
Expand Down
7 changes: 6 additions & 1 deletion r3almX_backend/realtime_service/connection_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def set_status_cache(self, user_id, status):
async def send_notification(self, user_id, message):
websocket = self.connection_sockets.get(user_id)
if websocket:
await websocket.send_text(message)
await websocket.send_json({"sender": str(user_id), "message": message})


connection_manager = Connection()
Expand Down Expand Up @@ -114,11 +114,16 @@ async def connect(websocket: WebSocket, token: str, db=Depends(get_db)):
await websocket.accept()
connection_manager.connect(user.id)
connection_manager.connection_sockets[user.id] = websocket

last_activity = datetime.datetime.now()
heartbeat_interval = 30
expiry_timeout = 100

try:
while True:
await websocket.send_json(
{"status": "200", "connection": "established"}
)
try:
if connection_manager.is_connected(user.id) is False:
connection_manager.connection_sockets[str(user.id)] = websocket
Expand Down
Loading

0 comments on commit a7ae322

Please sign in to comment.