From 3875be038c520a658ab7fd2ee4c49ec36731e718 Mon Sep 17 00:00:00 2001 From: Zac Li Date: Fri, 19 May 2023 11:33:07 +0800 Subject: [PATCH] feat: replace with metrics middleware --- lcserve/backend/gateway.py | 141 ++++++++++++++++++------------------- 1 file changed, 70 insertions(+), 71 deletions(-) diff --git a/lcserve/backend/gateway.py b/lcserve/backend/gateway.py index 8bcc890e..9b21807e 100644 --- a/lcserve/backend/gateway.py +++ b/lcserve/backend/gateway.py @@ -9,7 +9,17 @@ from importlib import import_module from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) from docarray import Document, DocumentArray from jina import Gateway @@ -18,6 +28,7 @@ from jina.serve.runtimes.gateway.composite import CompositeGateway from jina.serve.runtimes.gateway.http.fastapi import FastAPIBaseGateway from pydantic import BaseModel, Field, ValidationError, create_model +from starlette.types import ASGIApp, Receive, Scope, Send from websockets.exceptions import ConnectionClosed from .playground.utils.helper import ( @@ -273,11 +284,10 @@ def __init__( self._fastapi_app_str = fastapi_app_str self._fix_sys_path() self._init_fastapi_app() - self._setup_metrics() self._configure_cors() self._register_healthz() self._register_modules() - # self._register_counters() + self._setup_metrics() @property def app(self) -> 'FastAPI': @@ -342,35 +352,12 @@ def _setup_metrics(self): unit="s", ) - def _register_counters(self): - # TODO: doesn't work for now - from fastapi.routing import APIRoute, APIWebSocketRoute - from starlette.routing import Route, WebSocketRoute - - ignore_paths = {'/healthz', '/dry_run', '/metrics'} - measured_routes = [] - - for route in self.app.routes: - route: Union[Route, WebSocketRoute, APIWebSocketRoute, APIRoute] - - if route.path in ignore_paths or hasattr(route.endpoint, '__decorated__'): - measured_routes.append(route) - continue - - if isinstance(route, (WebSocketRoute, APIWebSocketRoute)): - route.endpoint = measure_duration(self.ws_duration_counter)( - route.endpoint - ) - elif isinstance(route, (Route, APIRoute)): - route.endpoint = measure_duration(self.http_duration_counter)( - route.endpoint - ) - else: - self.logger.warning(f'Unknown route type: {type(route)}') - - measured_routes.append(route) - - self.app.router.routes = measured_routes + self.app.add_middleware( + MeasureDurationHTTPMiddleware, counter=self.http_duration_counter + ) + self.app.add_middleware( + MeasureDurationWebSocketMiddleware, counter=self.ws_duration_counter + ) def _register_healthz(self): @self.app.get("/healthz") @@ -524,7 +511,6 @@ class Config: 'tags': [SERVING], }, logger=self.logger, - duration_counter=self.http_duration_counter, ) elif route_type == RouteType.WEBSOCKET: @@ -543,7 +529,6 @@ class Config: }, include_callback_handlers=include_callback_handlers, logger=self.logger, - duration_counter=self.ws_duration_counter, ) @@ -627,7 +612,6 @@ def create_http_route( output_model: BaseModel, post_kwargs: Dict, logger: JinaLogger, - duration_counter: 'Counter', ): from fastapi import Depends, Form, HTTPException, Security, UploadFile, status from fastapi.encoders import jsonable_encoder @@ -694,7 +678,6 @@ def _the_parser(data: str = Form(...)) -> input_model: # If file params are present, we need to use a custom parser to make sure that # the input data included in the Form and parsed correctly. - @measure_duration(duration_counter) async def _the_http_route( input_data: input_model = Depends(_the_parser), auth_response: Any = Depends(_the_authorizer), @@ -713,7 +696,6 @@ async def _the_http_route( else: # If no file params are present, we include the input args in the Body. - @measure_duration(duration_counter) async def _the_http_route( input_data: input_model, auth_response: Any = Depends(_the_authorizer) ) -> output_model: @@ -730,7 +712,6 @@ async def _the_http_route( # If file params are present, we need to use a custom parser to make sure that # the input data included in the Form and parsed correctly. - @measure_duration(duration_counter) async def _the_http_route( input_data: input_model = Depends(_the_parser), **kwargs ) -> output_model: @@ -747,7 +728,6 @@ async def _the_http_route( else: # If no file params are present, we include the input args in the Body. - @measure_duration(duration_counter) async def _the_http_route(input_data: input_model) -> output_model: return await _the_route( input_data=input_data, @@ -769,7 +749,6 @@ def create_websocket_route( include_callback_handlers: bool, ws_kwargs: Dict, logger: JinaLogger, - duration_counter: 'Counter', ): from fastapi import ( Depends, @@ -927,7 +906,6 @@ def _get_error_msg(e: Union[WebSocketDisconnect, ConnectionClosed]) -> str: logger.info(f'Auth enabled for `{func.__name__}`') @app.websocket(**ws_kwargs) - @measure_duration(duration_counter) async def _create_ws_route( websocket: WebSocket, auth_response: Any = Depends(_the_authorizer) ) -> output_model: @@ -936,7 +914,6 @@ async def _create_ws_route( else: @app.websocket(**ws_kwargs) - @measure_duration(duration_counter) async def _create_ws_route(websocket: WebSocket) -> output_model: return await _the_route(websocket=websocket, auth_response=None) @@ -1011,46 +988,68 @@ def _get_result_type(): return _output_model_fields -def measure_duration(duration_counter): +class Timer: class SharedData: def __init__(self, last_reported_time): self.last_reported_time = last_reported_time - async def send_metrics_periodically( - duration_counter, interval, route_name, shared_data - ): + def __init__(self, interval: int): + self.interval = interval + + async def send_duration_periodically(self, shared_data, counter): while True: - await asyncio.sleep(interval) + await asyncio.sleep(self.interval) current_time = time.perf_counter() - if duration_counter: - duration_counter.add( - current_time - shared_data.last_reported_time, {"route": route_name} - ) + duration = current_time - shared_data.last_reported_time + print(f"Duration: {duration} seconds") + if counter: + counter.add(current_time - shared_data.last_reported_time) + shared_data.last_reported_time = current_time - def decorator(func): - @wraps(func) - async def wrapped(*args, **kwargs): - shared_data = SharedData(last_reported_time=time.perf_counter()) - # Start the async task which reports the metrics every 5s - send_metrics_task = asyncio.create_task( - send_metrics_periodically( - duration_counter, 5, func.__name__, shared_data - ) + +class BaseMeasureDurationMiddleware: + def __init__( + self, app: ASGIApp, scope_type: str, counter: Optional['Counter'] = None + ): + self.app = app + self.scope_type = scope_type + self.counter = counter + # TODO: figure out solution for static assets + self.skip_routes = [ + '/docs', + '/redoc', + '/openapi.json', + '/healthz', + '/dry_run', + '/metrics', + '/favicon.ico', + ] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == self.scope_type and scope['path'] not in self.skip_routes: + timer = Timer(5) + shared_data = timer.SharedData(last_reported_time=time.perf_counter()) + send_duration_task = asyncio.create_task( + timer.send_duration_periodically(shared_data, self.counter) ) try: - result = await func(*args, **kwargs) - return result + await self.app(scope, receive, send) finally: - send_metrics_task.cancel() - # Final metrics update to wrap up the untracked duration in the end - if duration_counter: - duration_counter.add( - time.perf_counter() - shared_data.last_reported_time, - {"route": func.__name__}, + send_duration_task.cancel() + if self.counter: + self.counter.add( + time.perf_counter() - shared_data.last_reported_time ) + else: + await self.app(scope, receive, send) + + +class MeasureDurationHTTPMiddleware(BaseMeasureDurationMiddleware): + def __init__(self, app: ASGIApp, counter: Optional['Counter'] = None): + super().__init__(app, "http", counter) - wrapped.__decorated__ = True - return wrapped - return decorator +class MeasureDurationWebSocketMiddleware(BaseMeasureDurationMiddleware): + def __init__(self, app: ASGIApp, counter: Optional['Counter'] = None): + super().__init__(app, "websocket", counter)