Skip to content
This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

Commit

Permalink
feat: replace with metrics middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
zac-li committed May 19, 2023
1 parent f26bf76 commit 3875be0
Showing 1 changed file with 70 additions and 71 deletions.
141 changes: 70 additions & 71 deletions lcserve/backend/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@
from importlib import import_module
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)

from docarray import Document, DocumentArray
from jina import Gateway
Expand All @@ -18,6 +28,7 @@
from jina.serve.runtimes.gateway.composite import CompositeGateway
from jina.serve.runtimes.gateway.http.fastapi import FastAPIBaseGateway
from pydantic import BaseModel, Field, ValidationError, create_model
from starlette.types import ASGIApp, Receive, Scope, Send
from websockets.exceptions import ConnectionClosed

from .playground.utils.helper import (
Expand Down Expand Up @@ -273,11 +284,10 @@ def __init__(
self._fastapi_app_str = fastapi_app_str
self._fix_sys_path()
self._init_fastapi_app()
self._setup_metrics()
self._configure_cors()
self._register_healthz()
self._register_modules()
# self._register_counters()
self._setup_metrics()

@property
def app(self) -> 'FastAPI':
Expand Down Expand Up @@ -342,35 +352,12 @@ def _setup_metrics(self):
unit="s",
)

def _register_counters(self):
# TODO: doesn't work for now
from fastapi.routing import APIRoute, APIWebSocketRoute
from starlette.routing import Route, WebSocketRoute

ignore_paths = {'/healthz', '/dry_run', '/metrics'}
measured_routes = []

for route in self.app.routes:
route: Union[Route, WebSocketRoute, APIWebSocketRoute, APIRoute]

if route.path in ignore_paths or hasattr(route.endpoint, '__decorated__'):
measured_routes.append(route)
continue

if isinstance(route, (WebSocketRoute, APIWebSocketRoute)):
route.endpoint = measure_duration(self.ws_duration_counter)(
route.endpoint
)
elif isinstance(route, (Route, APIRoute)):
route.endpoint = measure_duration(self.http_duration_counter)(
route.endpoint
)
else:
self.logger.warning(f'Unknown route type: {type(route)}')

measured_routes.append(route)

self.app.router.routes = measured_routes
self.app.add_middleware(
MeasureDurationHTTPMiddleware, counter=self.http_duration_counter
)
self.app.add_middleware(
MeasureDurationWebSocketMiddleware, counter=self.ws_duration_counter
)

def _register_healthz(self):
@self.app.get("/healthz")
Expand Down Expand Up @@ -524,7 +511,6 @@ class Config:
'tags': [SERVING],
},
logger=self.logger,
duration_counter=self.http_duration_counter,
)

elif route_type == RouteType.WEBSOCKET:
Expand All @@ -543,7 +529,6 @@ class Config:
},
include_callback_handlers=include_callback_handlers,
logger=self.logger,
duration_counter=self.ws_duration_counter,
)


Expand Down Expand Up @@ -627,7 +612,6 @@ def create_http_route(
output_model: BaseModel,
post_kwargs: Dict,
logger: JinaLogger,
duration_counter: 'Counter',
):
from fastapi import Depends, Form, HTTPException, Security, UploadFile, status
from fastapi.encoders import jsonable_encoder
Expand Down Expand Up @@ -694,7 +678,6 @@ def _the_parser(data: str = Form(...)) -> input_model:
# If file params are present, we need to use a custom parser to make sure that
# the input data included in the Form and parsed correctly.

@measure_duration(duration_counter)
async def _the_http_route(
input_data: input_model = Depends(_the_parser),
auth_response: Any = Depends(_the_authorizer),
Expand All @@ -713,7 +696,6 @@ async def _the_http_route(
else:
# If no file params are present, we include the input args in the Body.

@measure_duration(duration_counter)
async def _the_http_route(
input_data: input_model, auth_response: Any = Depends(_the_authorizer)
) -> output_model:
Expand All @@ -730,7 +712,6 @@ async def _the_http_route(
# If file params are present, we need to use a custom parser to make sure that
# the input data included in the Form and parsed correctly.

@measure_duration(duration_counter)
async def _the_http_route(
input_data: input_model = Depends(_the_parser), **kwargs
) -> output_model:
Expand All @@ -747,7 +728,6 @@ async def _the_http_route(
else:
# If no file params are present, we include the input args in the Body.

@measure_duration(duration_counter)
async def _the_http_route(input_data: input_model) -> output_model:
return await _the_route(
input_data=input_data,
Expand All @@ -769,7 +749,6 @@ def create_websocket_route(
include_callback_handlers: bool,
ws_kwargs: Dict,
logger: JinaLogger,
duration_counter: 'Counter',
):
from fastapi import (
Depends,
Expand Down Expand Up @@ -927,7 +906,6 @@ def _get_error_msg(e: Union[WebSocketDisconnect, ConnectionClosed]) -> str:
logger.info(f'Auth enabled for `{func.__name__}`')

@app.websocket(**ws_kwargs)
@measure_duration(duration_counter)
async def _create_ws_route(
websocket: WebSocket, auth_response: Any = Depends(_the_authorizer)
) -> output_model:
Expand All @@ -936,7 +914,6 @@ async def _create_ws_route(
else:

@app.websocket(**ws_kwargs)
@measure_duration(duration_counter)
async def _create_ws_route(websocket: WebSocket) -> output_model:
return await _the_route(websocket=websocket, auth_response=None)

Expand Down Expand Up @@ -1011,46 +988,68 @@ def _get_result_type():
return _output_model_fields


def measure_duration(duration_counter):
class Timer:
class SharedData:
def __init__(self, last_reported_time):
self.last_reported_time = last_reported_time

async def send_metrics_periodically(
duration_counter, interval, route_name, shared_data
):
def __init__(self, interval: int):
self.interval = interval

async def send_duration_periodically(self, shared_data, counter):
while True:
await asyncio.sleep(interval)
await asyncio.sleep(self.interval)
current_time = time.perf_counter()
if duration_counter:
duration_counter.add(
current_time - shared_data.last_reported_time, {"route": route_name}
)
duration = current_time - shared_data.last_reported_time
print(f"Duration: {duration} seconds")
if counter:
counter.add(current_time - shared_data.last_reported_time)

shared_data.last_reported_time = current_time

def decorator(func):
@wraps(func)
async def wrapped(*args, **kwargs):
shared_data = SharedData(last_reported_time=time.perf_counter())
# Start the async task which reports the metrics every 5s
send_metrics_task = asyncio.create_task(
send_metrics_periodically(
duration_counter, 5, func.__name__, shared_data
)

class BaseMeasureDurationMiddleware:
def __init__(
self, app: ASGIApp, scope_type: str, counter: Optional['Counter'] = None
):
self.app = app
self.scope_type = scope_type
self.counter = counter
# TODO: figure out solution for static assets
self.skip_routes = [
'/docs',
'/redoc',
'/openapi.json',
'/healthz',
'/dry_run',
'/metrics',
'/favicon.ico',
]

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == self.scope_type and scope['path'] not in self.skip_routes:
timer = Timer(5)
shared_data = timer.SharedData(last_reported_time=time.perf_counter())
send_duration_task = asyncio.create_task(
timer.send_duration_periodically(shared_data, self.counter)
)
try:
result = await func(*args, **kwargs)
return result
await self.app(scope, receive, send)
finally:
send_metrics_task.cancel()
# Final metrics update to wrap up the untracked duration in the end
if duration_counter:
duration_counter.add(
time.perf_counter() - shared_data.last_reported_time,
{"route": func.__name__},
send_duration_task.cancel()
if self.counter:
self.counter.add(
time.perf_counter() - shared_data.last_reported_time
)
else:
await self.app(scope, receive, send)


class MeasureDurationHTTPMiddleware(BaseMeasureDurationMiddleware):
def __init__(self, app: ASGIApp, counter: Optional['Counter'] = None):
super().__init__(app, "http", counter)

wrapped.__decorated__ = True
return wrapped

return decorator
class MeasureDurationWebSocketMiddleware(BaseMeasureDurationMiddleware):
def __init__(self, app: ASGIApp, counter: Optional['Counter'] = None):
super().__init__(app, "websocket", counter)

0 comments on commit 3875be0

Please sign in to comment.