Feat/refactor api #8

Merged · 5 commits · Dec 26, 2023
src/fastserve/__main__.py — 4 changes: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
-from fastserve.core import BaseFastServe
+from fastserve.core import BaseServe
 from fastserve.handler import DummyHandler
 from fastserve.utils import BaseRequest
 
 handler = DummyHandler()
-serve = BaseFastServe(
+serve = BaseServe(
     handle=handler.handle, batch_size=1, timeout=0, input_schema=BaseRequest
 )
 serve.run_server()
src/fastserve/core.py — 56 changes: 40 additions & 16 deletions
@@ -1,23 +1,34 @@
 import logging
 from contextlib import asynccontextmanager
-from typing import Callable
+from typing import Callable, Optional
 
 from fastapi import FastAPI
 from pydantic import BaseModel
 
 from .batching import BatchProcessor
 from .handler import BaseHandler, ParallelHandler
+from .middleware import register_default_middlewares
 from .utils import BaseRequest
 
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s [%(levelname)s] %(message)s",
     handlers=[logging.StreamHandler()],
 )
 logger = logging.getLogger(__name__)
 
 
-class BaseFastServe:
-    def __init__(self, handle: Callable, batch_size, timeout, input_schema) -> None:
+class BaseServe:
+    def __init__(
+        self,
+        handle: Callable,
+        batch_size: int,
+        timeout: float,
+        input_schema: Optional[BaseModel],
+        response_schema: Optional[BaseModel],
+    ) -> None:
         self.input_schema = input_schema
+        self.response_schema = response_schema
         self.handle: Callable = handle
         self.batch_processing = BatchProcessor(
             func=self.handle, bs=batch_size, timeout=timeout
@@ -28,27 +39,33 @@ async def lifespan(app: FastAPI):
             yield
             self.batch_processing.cancel()
 
-        self._app = FastAPI(lifespan=lifespan, title="FastServe")
+        self._app = FastAPI(lifespan=lifespan, title="FastServe", docs_url="/")
+        register_default_middlewares(self._app)
+        INPUT_SCHEMA = input_schema
 
-    def _serve(
-        self,
-    ):
-        INPUT_SCHEMA = self.input_schema
-
-        @self._app.post(path="/endpoint")
         def api(request: INPUT_SCHEMA):
-            print("incoming request")
             wait_obj = self.batch_processing.process(request)
             result = wait_obj.get()
             return result
 
+        self._app.add_api_route(
+            path="/endpoint",
+            endpoint=api,
+            methods=["post"],
+            response_model=response_schema,
+        )
+
     @property
     def app(self):
         return self._app
 
     def run_server(
         self,
    ):
-        self._serve()
         import uvicorn
 
-        uvicorn.run(self._app)
+        uvicorn.run(self.app)
 
     @property
     def test_client(self):
@@ -57,25 +74,32 @@ def test_client(self):
         return TestClient(self._app)
 
 
-class FastServe(BaseFastServe, BaseHandler):
-    def __init__(self, batch_size=2, timeout=0.5, input_schema=None):
+class FastServe(BaseServe, BaseHandler):
+    def __init__(
+        self, batch_size=2, timeout=0.5, input_schema=None, response_schema=None
+    ):
         if input_schema is None:
             input_schema = BaseRequest
         super().__init__(
             handle=self.handle,
             batch_size=batch_size,
             timeout=timeout,
             input_schema=input_schema,
+            response_schema=response_schema,
         )
 
 
-class ParallelFastServe(BaseFastServe, ParallelHandler):
-    def __init__(self, batch_size=2, timeout=0.5, input_schema=None):
+class ParallelFastServe(BaseServe, ParallelHandler):
+    def __init__(
+        self, batch_size=2, timeout=0.5, input_schema=None, response_schema=None
+    ):
         if input_schema is None:
             input_schema = BaseRequest
         super().__init__(
             handle=self.handle,
             batch_size=batch_size,
             timeout=timeout,
             input_schema=input_schema,
+            response_schema=response_schema,
        )
         logger.info("Launching parallel handler!")
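
The heart of the refactor: the `/endpoint` route is now registered eagerly in `__init__` via `app.add_api_route(...)` instead of by a decorator inside the removed `_serve()`, which is what lets `response_model` arrive as a constructor argument. A minimal standalone sketch of the same pattern — the echo names are illustrative, not from this repo:

    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from pydantic import BaseModel


    class EchoRequest(BaseModel):
        text: str


    class EchoResponse(BaseModel):
        text: str


    app = FastAPI()


    def echo(request: EchoRequest):
        # The handler is a plain function; FastAPI derives the request body
        # schema from the annotation, and the response schema from the
        # response_model passed to add_api_route below.
        return {"text": request.text}


    # Registering imperatively (rather than with @app.post) means the
    # response model can be chosen at runtime, e.g. from a constructor arg.
    app.add_api_route(
        path="/endpoint",
        endpoint=echo,
        methods=["POST"],
        response_model=EchoResponse,
    )

    client = TestClient(app)
    print(client.post("/endpoint", json={"text": "hi"}).json())  # {'text': 'hi'}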
src/fastserve/middleware.py — 22 changes: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import time
+
+from fastapi import Request
+from fastapi.middleware.cors import CORSMiddleware
+
+
+def register_default_middlewares(app):
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
+    @app.middleware("http")
+    async def add_process_time_header(request: Request, call_next):
+        start_time = time.time()
+        response = await call_next(request)
+        process_time = time.time() - start_time
+        response.headers["X-Process-Time"] = str(process_time)
+        return response
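
A quick way to see the new defaults in action: wire a hypothetical app through `register_default_middlewares` and probe it with a test client (the `/ping` route is invented for illustration; only the middleware module comes from this PR):

    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from fastserve.middleware import register_default_middlewares

    app = FastAPI()
    register_default_middlewares(app)


    @app.get("/ping")
    def ping():
        return {"ok": True}


    client = TestClient(app)
    response = client.get("/ping")
    # Every response now carries the wall-clock processing time in seconds,
    # added by the timing middleware above.
    print(response.headers["X-Process-Time"])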
src/fastserve/models/llama_cpp.py — 14 changes: 11 additions & 3 deletions
@@ -5,13 +5,16 @@
 from llama_cpp import Llama
 from pydantic import BaseModel
 
-from fastserve.core import ParallelFastServe
+from fastserve.core import FastServe, ParallelFastServe
 
 logger = logging.getLogger(__name__)
 
 # https://huggingface.co/TheBloke/OpenHermes-2-Mistral-7B-GGUF
 DEFAULT_MODEL = "openhermes-2-mistral-7b.Q6_K.gguf"
 
+FASTSERVE_PARALLEL_HANDLER = int(os.environ.get("FASTSERVE_PARALLEL_HANDLER", "0"))
+FastServeMode = ParallelFastServe if FASTSERVE_PARALLEL_HANDLER == 1 else FastServe
+
 
 class PromptRequest(BaseModel):
     prompt: str = "Llamas are cute animal"
@@ -27,7 +30,7 @@ class ResponseModel(BaseModel):
     finished: bool  # Whether the whole request is finished.
 
 
-class ServeLlamaCpp(ParallelFastServe):
+class ServeLlamaCpp(FastServeMode):
     def __init__(
         self,
         model_path=DEFAULT_MODEL,
@@ -54,7 +57,12 @@ def __init__(
         self.main_gpu = main_gpu
         self.args = args
         self.kwargs = kwargs
-        super().__init__(batch_size, timeout, input_schema=PromptRequest)
+        super().__init__(
+            batch_size,
+            timeout,
+            input_schema=PromptRequest,
+            response_schema=ResponseModel,
+        )
 
     def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
         result = self.llm(prompt=prompt, *args, **kwargs)
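
Worth noting: `FastServeMode` is resolved at import time, so the environment variable must be set before the module is imported. A small sketch of switching modes (assumes `llama-cpp-python` and the GGUF weights are available locally):

    import os

    # Select the parallel handler before importing the module, since the
    # FastServeMode switch is evaluated once, at import time.
    os.environ["FASTSERVE_PARALLEL_HANDLER"] = "1"

    from fastserve.models.llama_cpp import ServeLlamaCpp

    serve = ServeLlamaCpp(model_path="openhermes-2-mistral-7b.Q6_K.gguf")
    serve.run_server()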
tests/test_core.py — 2 changes: 0 additions & 2 deletions
@@ -20,7 +20,6 @@ def test_handle():
 
 def test_run_server():
     serve = FakeServe()
-    serve._serve()
     test_client = serve.test_client
     data = BaseRequest(request=1).model_dump_json()
     response = test_client.post("/endpoint", data=data)
@@ -30,7 +29,6 @@ def test_run_server():
 
 def test_unprocessable_content():
     serve = FakeServe()
-    serve._serve()
     test_client = serve.test_client
     data = {}  # wrong data format
     response = test_client.post("/endpoint", data=data)
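
The `serve._serve()` calls go away because the endpoint is now registered in `__init__`. A sketch of what such a test reduces to — assuming `FakeServe` subclasses `FastServe` with a trivial batch handler, which this diff doesn't show:

    from fastserve.core import FastServe
    from fastserve.utils import BaseRequest


    class FakeServe(FastServe):
        def handle(self, batch):
            # Hypothetical handler: echo one result per batched request.
            return [item.request for item in batch]


    def test_endpoint_available_without_serve():
        serve = FakeServe()
        test_client = serve.test_client  # routes already registered by __init__
        data = BaseRequest(request=1).model_dump_json()
        response = test_client.post("/endpoint", data=data)
        assert response.status_code == 200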