From d27e3056f98e34951653d74d3163c44174e38535 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Wed, 19 Apr 2023 19:50:01 +0530 Subject: [PATCH] fix: handle connection closed --- lcserve/backend/gateway.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/lcserve/backend/gateway.py b/lcserve/backend/gateway.py index 03297d04..d2841da6 100644 --- a/lcserve/backend/gateway.py +++ b/lcserve/backend/gateway.py @@ -8,7 +8,7 @@ from importlib import import_module from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Type, Any +from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Type, Any, Union from docarray import Document, DocumentArray from jina import Gateway @@ -16,6 +16,7 @@ from jina.serve.runtimes.gateway.composite import CompositeGateway from jina.serve.runtimes.gateway.http.fastapi import FastAPIBaseGateway from pydantic import Field, ValidationError, create_model +from websockets.exceptions import ConnectionClosed from .playground.utils.helper import ( AGENT_OUTPUT, @@ -462,6 +463,15 @@ async def _create_ws_route(websocket: WebSocket): with BuiltinsWrapper( websocket=websocket, output_model=output_model, wrap_print=False ): + + def _get_error_msg( + e: Union[WebSocketDisconnect, ConnectionClosed] + ) -> str: + return ( + f'Client {websocket.client} disconnected from `{func.__name__}` with code {e.code}' + + (f' and reason {e.reason}' if e.reason else '') + ) + await websocket.accept() _ws_recv_lock = asyncio.Lock() try: @@ -546,10 +556,8 @@ async def _create_ws_route(websocket: WebSocket): await websocket.close() break - except WebSocketDisconnect as e: - self.logger.info( - f'Client {websocket.client} disconnected from `{func.__name__}` with code {e.code} and reason {e.reason}' - ) + except (WebSocketDisconnect, ConnectionClosed) as e: + self.logger.info(_get_error_msg(e)) break except Exception as e: @@ -565,8 +573,6 @@ async def _create_ws_route(websocket: WebSocket): if _ws_serving_error != '': print(f'Error: {_ws_serving_error}') - except WebSocketDisconnect as e: - self.logger.info( - f'Client {websocket.client} disconnected from `{func.__name__}` with code {e.code} and reason {e.reason}' - ) + except (WebSocketDisconnect, ConnectionClosed) as e: + self.logger.info(_get_error_msg(e)) return