diff --git a/api/routes/ws.py b/api/routes/ws.py index 095433f5..27721cde 100644 --- a/api/routes/ws.py +++ b/api/routes/ws.py @@ -1,10 +1,13 @@ +import logging + from fastapi import APIRouter from fastapi.websockets import WebSocket, WebSocketDisconnect from api import websocket_manager from api.websockets.data import Data -router = APIRouter() +router = APIRouter(tags=["websockets"]) +logger = logging.getLogger(__name__) @router.websocket("/master") @@ -23,9 +26,21 @@ async def master_endpoint(websocket: WebSocket): @router.post("/progress") -async def set_progress(progress: int): +def set_progress(progress: int): "Set the progress of the progress bar" websocket_manager.broadcast_sync( data=Data(data_type="progress", data={"progress": progress}) ) + + +@router.get("/get-active-connetions") +def get_active_connections(): + connections = websocket_manager.get_active_connections() + converted_connections = [ + f"{connection.client.host}:{connection.client.port}-{connection.client_state.name}" + for connection in connections + if connection.client is not None + ] + + return converted_connections diff --git a/api/websockets/manager.py b/api/websockets/manager.py index c4195274..2bff8563 100644 --- a/api/websockets/manager.py +++ b/api/websockets/manager.py @@ -185,3 +185,8 @@ async def close_all(self): torch.cuda.ipc_collect() self.active_connections = [] + + def get_active_connections(self): + "Returns the number of active websocket connections" + + return self.active_connections diff --git a/frontend/dist/assets/index.js b/frontend/dist/assets/index.js index 884f3ff6..b45c09da 100644 --- a/frontend/dist/assets/index.js +++ b/frontend/dist/assets/index.js @@ -41759,7 +41759,8 @@ const useWebsocket = defineStore("websocket", () => { const websocket = useWebSocket(`${webSocketUrl}/api/websockets/master`, { heartbeat: { message: "ping", - interval: 3e4 + interval: 1e3, + pongTimeout: 5e3 }, immediate: false, onMessage: (ws, event2) => { @@ -41779,7 +41780,6 @@ const useWebsocket = defineStore("websocket", () => { onConnectedCallbacks.forEach((callback) => callback()); }, onDisconnected: () => { - messageProvider.error("Disconnected from server"); onDisconnectedCallbacks.forEach((callback) => callback()); } }); diff --git a/frontend/src/store/websockets.ts b/frontend/src/store/websockets.ts index cfa44b2e..45fbd87d 100644 --- a/frontend/src/store/websockets.ts +++ b/frontend/src/store/websockets.ts @@ -21,7 +21,8 @@ export const useWebsocket = defineStore("websocket", () => { const websocket = useWebSocket(`${webSocketUrl}/api/websockets/master`, { heartbeat: { message: "ping", - interval: 30000, + interval: 1000, + pongTimeout: 5000, }, immediate: false, onMessage: (ws: WebSocket, event: MessageEvent) => { @@ -42,7 +43,6 @@ export const useWebsocket = defineStore("websocket", () => { onConnectedCallbacks.forEach((callback) => callback()); }, onDisconnected: () => { - messageProvider.error("Disconnected from server"); onDisconnectedCallbacks.forEach((callback) => callback()); }, });