Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support adding response background tasks via bentoml.Context #4754

Merged
merged 4 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/_bentoml_impl/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def _add_response_headers(
) -> None:
from bentoml._internal.context import trace_context

resp.headers.update({"Server": f"BentoML Service/{self.service.name}"})
if trace_context.request_id is not None:
resp.headers["X-BentoML-Request-ID"] = format(
trace_context.request_id, logging_format["span_id"]
Expand Down Expand Up @@ -449,11 +450,15 @@ async def api_endpoint_wrapper(self, name: str, request: Request) -> Response:
status_code=500,
)
self._add_response_headers(resp)
ctx = self.service.context
if resp.background is not None:
ctx.response.background.tasks.append(resp.background)
# clean the request resources after the response is consumed.
ctx.response.background.add_task(request.close)
resp.background = ctx.response.background
return resp

async def api_endpoint(self, name: str, request: Request) -> Response:
from starlette.background import BackgroundTask

from _bentoml_sdk.io_models import ARGS
from _bentoml_sdk.io_models import KWARGS
from bentoml._internal.utils import get_original_func
Expand Down Expand Up @@ -514,12 +519,9 @@ async def inner() -> t.AsyncGenerator[t.Any, None]:
response = output
else:
response = await method.output_spec.to_http_response(output, serde)
response.headers.update({"Server": f"BentoML Service/{self.service.name}"})

if method.ctx_param is not None:
response.status_code = ctx.response.status_code
response.headers.update(ctx.response.metadata)
set_cookies(response, ctx.response.cookies)
# clean the request resources after the response is consumed.
response.background = BackgroundTask(request.close)
return response
23 changes: 13 additions & 10 deletions src/bentoml/_internal/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import attr
import starlette.datastructures
from starlette.background import BackgroundTasks

from .utils.http import Cookie

Expand Down Expand Up @@ -120,16 +121,18 @@ def temp_dir(self) -> str:

@attr.define
class ResponseContext:
metadata: Metadata
cookies: list[Cookie]
headers: Metadata
status_code: int

def __init__(self):
self.metadata = starlette.datastructures.MutableHeaders() # type: ignore (coercing Starlette headers to Metadata)
self.headers = self.metadata # type: ignore (coercing Starlette headers to Metadata)
self.cookies = []
self.status_code = 200
metadata: Metadata = attr.field(factory=starlette.datastructures.MutableHeaders)
cookies: list[Cookie] = attr.field(factory=list)
status_code: int = 200
background: BackgroundTasks = attr.field(factory=BackgroundTasks)

@property
def headers(self) -> Metadata:
return self.metadata

@headers.setter
def headers(self, headers: Metadata) -> None:
self.metadata = headers

def set_cookie(
self,
Expand Down
33 changes: 27 additions & 6 deletions src/bentoml/_internal/server/http/traffic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
from typing import TYPE_CHECKING
from typing import Any

from starlette.responses import JSONResponse

Expand All @@ -14,19 +15,39 @@ def __init__(self, app: ext.ASGIApp, timeout: float) -> None:
self.app = app
self.timeout = timeout

def _set_timer_out(self, waiter: asyncio.Future[Any]) -> None:
if not waiter.done():
waiter.set_exception(asyncio.TimeoutError)

async def __call__(
self, scope: ext.ASGIScope, receive: ext.ASGIReceive, send: ext.ASGISend
) -> None:
if scope["type"] not in ("http", "websocket"):
return await self.app(scope, receive, send)
loop = asyncio.get_running_loop()
waiter = loop.create_future()
loop.call_later(self.timeout, self._set_timer_out, waiter)

async def _send(message: ext.ASGIMessage) -> None:
if not waiter.done():
waiter.set_result(None)
await send(message)

fut = asyncio.ensure_future(self.app(scope, receive, _send), loop=loop)

try:
await asyncio.wait_for(self.app(scope, receive, send), timeout=self.timeout)
await waiter
except asyncio.TimeoutError:
resp = JSONResponse(
{"error": f"Not able to process the request in {self.timeout} seconds"},
status_code=504,
)
await resp(scope, receive, send)
if fut.cancel():
resp = JSONResponse(
{
"error": f"Not able to process the request in {self.timeout} seconds"
},
status_code=504,
)
await resp(scope, receive, send)
else:
await fut # wait for the future to finish


class MaxConcurrencyMiddleware:
Expand Down
Loading