From 893a8e6ecee3c45ee7b3c5ece2ec8e8f6081c39f Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Tue, 26 Dec 2023 00:30:51 +0000 Subject: [PATCH 1/5] refactor api --- src/fastserve/core.py | 8 ++++++-- src/fastserve/models/llama_cpp.py | 4 ++-- tests/test_core.py | 4 ++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/fastserve/core.py b/src/fastserve/core.py index a2d2483..396a225 100644 --- a/src/fastserve/core.py +++ b/src/fastserve/core.py @@ -30,7 +30,7 @@ async def lifespan(app: FastAPI): self._app = FastAPI(lifespan=lifespan, title="FastServe") - def _serve( + def register_router( self, ): INPUT_SCHEMA = self.input_schema @@ -45,11 +45,15 @@ def api(request: INPUT_SCHEMA): def run_server( self, ): - self._serve() + self.register_router() import uvicorn uvicorn.run(self._app) + @property + def app(self): + return self._app + @property def test_client(self): from fastapi.testclient import TestClient diff --git a/src/fastserve/models/llama_cpp.py b/src/fastserve/models/llama_cpp.py index 65c1c58..297a64f 100644 --- a/src/fastserve/models/llama_cpp.py +++ b/src/fastserve/models/llama_cpp.py @@ -5,7 +5,7 @@ from llama_cpp import Llama from pydantic import BaseModel -from fastserve.core import ParallelFastServe +from fastserve.core import FastServe logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ class ResponseModel(BaseModel): finished: bool # Whether the whole request is finished. -class ServeLlamaCpp(ParallelFastServe): +class ServeLlamaCpp(FastServe): def __init__( self, model_path=DEFAULT_MODEL, diff --git a/tests/test_core.py b/tests/test_core.py index 94ee6dc..34eba38 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -20,7 +20,7 @@ def test_handle(): def test_run_server(): serve = FakeServe() - serve._serve() + serve.register_router() test_client = serve.test_client data = BaseRequest(request=1).model_dump_json() response = test_client.post("/endpoint", data=data) @@ -30,7 +30,7 @@ def test_run_server(): def test_unprocessable_content(): serve = FakeServe() - serve._serve() + serve.register_router() test_client = serve.test_client data = {} # wrong data format response = test_client.post("/endpoint", data=data) From bb1703ce24d43d70387f728b157f8a43a9e9d768 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Tue, 26 Dec 2023 01:07:09 +0000 Subject: [PATCH 2/5] add middleware --- src/fastserve/core.py | 46 ++++++++++++++++++++----------- src/fastserve/middleware.py | 11 ++++++++ src/fastserve/models/llama_cpp.py | 14 ++++++++-- 3 files changed, 52 insertions(+), 19 deletions(-) create mode 100644 src/fastserve/middleware.py diff --git a/src/fastserve/core.py b/src/fastserve/core.py index 396a225..772f17c 100644 --- a/src/fastserve/core.py +++ b/src/fastserve/core.py @@ -6,6 +6,7 @@ from .batching import BatchProcessor from .handler import BaseHandler, ParallelHandler +from .middleware import register_default_middlewares from .utils import BaseRequest logging.basicConfig( @@ -13,11 +14,15 @@ format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler()], ) +logger = logging.getLogger(__name__) class BaseFastServe: - def __init__(self, handle: Callable, batch_size, timeout, input_schema) -> None: + def __init__( + self, handle: Callable, batch_size, timeout, input_schema, response_schema + ) -> None: self.input_schema = input_schema + self.response_schema = response_schema self.handle: Callable = handle self.batch_processing = BatchProcessor( func=self.handle, bs=batch_size, timeout=timeout @@ -28,31 +33,33 @@ async def 
lifespan(app: FastAPI): yield self.batch_processing.cancel() - self._app = FastAPI(lifespan=lifespan, title="FastServe") + self._app = FastAPI(lifespan=lifespan, title="FastServe", docs_url="/") + register_default_middlewares(self._app) + INPUT_SCHEMA = input_schema - def register_router( - self, - ): - INPUT_SCHEMA = self.input_schema - - @self._app.post(path="/endpoint") def api(request: INPUT_SCHEMA): print("incoming request") wait_obj = self.batch_processing.process(request) result = wait_obj.get() return result + self._app.add_api_route( + path="/endpoint", + endpoint=api, + methods=["post"], + response_model=response_schema, + ) + + @property + def app(self): + return self._app + def run_server( self, ): - self.register_router() import uvicorn - uvicorn.run(self._app) - - @property - def app(self): - return self._app + uvicorn.run(self.app) @property def test_client(self): @@ -62,7 +69,9 @@ def test_client(self): class FastServe(BaseFastServe, BaseHandler): - def __init__(self, batch_size=2, timeout=0.5, input_schema=None): + def __init__( + self, batch_size=2, timeout=0.5, input_schema=None, response_schema=None + ): if input_schema is None: input_schema = BaseRequest super().__init__( @@ -70,11 +79,14 @@ def __init__(self, batch_size=2, timeout=0.5, input_schema=None): batch_size=batch_size, timeout=timeout, input_schema=input_schema, + response_schema=response_schema, ) class ParallelFastServe(BaseFastServe, ParallelHandler): - def __init__(self, batch_size=2, timeout=0.5, input_schema=None): + def __init__( + self, batch_size=2, timeout=0.5, input_schema=None, response_schema=None + ): if input_schema is None: input_schema = BaseRequest super().__init__( @@ -82,4 +94,6 @@ def __init__(self, batch_size=2, timeout=0.5, input_schema=None): batch_size=batch_size, timeout=timeout, input_schema=input_schema, + response_schema=response_schema, ) + logger.info("Launching parallel handler!") diff --git a/src/fastserve/middleware.py b/src/fastserve/middleware.py new file mode 100644 index 0000000..bf3ae91 --- /dev/null +++ b/src/fastserve/middleware.py @@ -0,0 +1,11 @@ +from fastapi.middleware.cors import CORSMiddleware + + +def register_default_middlewares(app): + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) diff --git a/src/fastserve/models/llama_cpp.py b/src/fastserve/models/llama_cpp.py index 297a64f..550f7be 100644 --- a/src/fastserve/models/llama_cpp.py +++ b/src/fastserve/models/llama_cpp.py @@ -5,13 +5,16 @@ from llama_cpp import Llama from pydantic import BaseModel -from fastserve.core import FastServe +from fastserve.core import FastServe, ParallelFastServe logger = logging.getLogger(__name__) # https://huggingface.co/TheBloke/OpenHermes-2-Mistral-7B-GGUF DEFAULT_MODEL = "openhermes-2-mistral-7b.Q6_K.gguf" +FASTSERVE_PARALLEL_HANDLER = int(os.environ.get("FASTSERVE_PARALLEL_HANDLER", "0")) +FastServeMode = ParallelFastServe if FASTSERVE_PARALLEL_HANDLER == 1 else FastServe + class PromptRequest(BaseModel): prompt: str = "Llamas are cute animal" @@ -27,7 +30,7 @@ class ResponseModel(BaseModel): finished: bool # Whether the whole request is finished. 
-class ServeLlamaCpp(FastServe): +class ServeLlamaCpp(FastServeMode): def __init__( self, model_path=DEFAULT_MODEL, @@ -54,7 +57,12 @@ def __init__( self.main_gpu = main_gpu self.args = args self.kwargs = kwargs - super().__init__(batch_size, timeout, input_schema=PromptRequest) + super().__init__( + batch_size, + timeout, + input_schema=PromptRequest, + response_schema=ResponseModel, + ) def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: result = self.llm(prompt=prompt, *args, **kwargs) From cb9e18cd8b3c7f00f75de55f3729c2e0ffecd259 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Tue, 26 Dec 2023 01:08:26 +0000 Subject: [PATCH 3/5] rename --- src/fastserve/__main__.py | 4 ++-- src/fastserve/core.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/fastserve/__main__.py b/src/fastserve/__main__.py index 1ea4d09..259144f 100644 --- a/src/fastserve/__main__.py +++ b/src/fastserve/__main__.py @@ -1,9 +1,9 @@ -from fastserve.core import BaseFastServe +from fastserve.core import BaseServe from fastserve.handler import DummyHandler from fastserve.utils import BaseRequest handler = DummyHandler() -serve = BaseFastServe( +serve = BaseServe( handle=handler.handle, batch_size=1, timeout=0, input_schema=BaseRequest ) serve.run_server() diff --git a/src/fastserve/core.py b/src/fastserve/core.py index 772f17c..4a63fe5 100644 --- a/src/fastserve/core.py +++ b/src/fastserve/core.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -class BaseFastServe: +class BaseServe: def __init__( self, handle: Callable, batch_size, timeout, input_schema, response_schema ) -> None: @@ -68,7 +68,7 @@ def test_client(self): return TestClient(self._app) -class FastServe(BaseFastServe, BaseHandler): +class FastServe(BaseServe, BaseHandler): def __init__( self, batch_size=2, timeout=0.5, input_schema=None, response_schema=None ): @@ -83,7 +83,7 @@ def __init__( ) -class ParallelFastServe(BaseFastServe, ParallelHandler): +class ParallelFastServe(BaseServe, ParallelHandler): def __init__( self, batch_size=2, timeout=0.5, input_schema=None, response_schema=None ): From 4a71b894940850edf9747906923b7d2dedfcea4c Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Tue, 26 Dec 2023 01:12:48 +0000 Subject: [PATCH 4/5] fix tests --- tests/test_core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 34eba38..5636bd0 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -20,7 +20,6 @@ def test_handle(): def test_run_server(): serve = FakeServe() - serve.register_router() test_client = serve.test_client data = BaseRequest(request=1).model_dump_json() response = test_client.post("/endpoint", data=data) @@ -30,7 +29,6 @@ def test_run_server(): def test_unprocessable_content(): serve = FakeServe() - serve.register_router() test_client = serve.test_client data = {} # wrong data format response = test_client.post("/endpoint", data=data) From 4446518d0a8b034b472decf095bda8b71f3e977e Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Tue, 26 Dec 2023 01:15:02 +0000 Subject: [PATCH 5/5] add middleware --- src/fastserve/core.py | 10 ++++++++-- src/fastserve/middleware.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/fastserve/core.py b/src/fastserve/core.py index 4a63fe5..35d8786 100644 --- a/src/fastserve/core.py +++ b/src/fastserve/core.py @@ -1,8 +1,9 @@ import logging from contextlib import asynccontextmanager -from typing import Callable +from typing import Callable, Optional from fastapi import 
FastAPI
+from pydantic import BaseModel
 
 from .batching import BatchProcessor
 from .handler import BaseHandler, ParallelHandler
@@ -19,7 +20,12 @@
 
 class BaseServe:
     def __init__(
-        self, handle: Callable, batch_size, timeout, input_schema, response_schema
+        self,
+        handle: Callable,
+        batch_size: int,
+        timeout: float,
+        input_schema: Optional[BaseModel],
+        response_schema: Optional[BaseModel] = None,
     ) -> None:
         self.input_schema = input_schema
         self.response_schema = response_schema
diff --git a/src/fastserve/middleware.py b/src/fastserve/middleware.py
index bf3ae91..52a6276 100644
--- a/src/fastserve/middleware.py
+++ b/src/fastserve/middleware.py
@@ -1,3 +1,6 @@
+import time
+
+from fastapi import Request
 from fastapi.middleware.cors import CORSMiddleware
 
 
@@ -9,3 +12,11 @@ def register_default_middlewares(app):
         allow_methods=["*"],
         allow_headers=["*"],
     )
+
+    @app.middleware("http")
+    async def add_process_time_header(request: Request, call_next):
+        start_time = time.time()
+        response = await call_next(request)
+        process_time = time.time() - start_time
+        response.headers["X-Process-Time"] = str(process_time)
+        return response
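
Note: a minimal usage sketch of the serving API as it stands at the end of
this series. Everything here is inferred from the diffs above; EchoRequest,
EchoResponse, and ServeEcho are hypothetical names, and the handle(batch)
contract (one response per request, in order) is assumed from the
BaseHandler/BatchProcessor usage exercised in tests/test_core.py.

    from typing import List

    from pydantic import BaseModel

    from fastserve.core import FastServe


    class EchoRequest(BaseModel):
        request: str


    class EchoResponse(BaseModel):
        text: str


    class ServeEcho(FastServe):
        def handle(self, batch: List[EchoRequest]) -> List[EchoResponse]:
            # BatchProcessor collects up to batch_size requests (or whatever
            # arrives within `timeout` seconds) and passes them here as one
            # micro-batch; return one response per request, in order.
            return [EchoResponse(text=item.request) for item in batch]


    serve = ServeEcho(
        batch_size=2,
        timeout=0.5,
        input_schema=EchoRequest,
        response_schema=EchoResponse,  # new in this series: wired to response_model
    )
    # __init__ now registers the CORS and X-Process-Time middlewares and adds
    # POST /endpoint via add_api_route; the Swagger docs are served at "/".
    serve.run_server()

For ServeLlamaCpp, exporting FASTSERVE_PARALLEL_HANDLER=1 before import makes
FastServeMode resolve to ParallelFastServe instead of FastServe.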