diff --git a/src/_bentoml_sdk/validators.py b/src/_bentoml_sdk/validators.py
index a17fb5211c4..dcf2172c593 100644
--- a/src/_bentoml_sdk/validators.py
+++ b/src/_bentoml_sdk/validators.py
@@ -124,7 +124,7 @@ def encode(self, obj: Path) -> bytes:
         return obj.read_bytes()
 
     def decode(self, obj: bytes | t.BinaryIO | UploadFile | PurePath | str) -> t.Any:
-        from bentoml._internal.context import request_directory
+        from bentoml._internal.context import request_temp_dir
 
         media_type: str | None = None
 
@@ -156,7 +156,7 @@ def decode(self, obj: bytes | t.BinaryIO | UploadFile | PurePath | str) -> t.Any
                     f"Invalid content type {media_type}, expected {self.content_type}"
                 )
             with tempfile.NamedTemporaryFile(
-                suffix=filename, dir=request_directory.get(), delete=False
+                suffix=filename, dir=request_temp_dir(), delete=False
             ) as f:
                 f.write(body)
             return Path(f.name)
diff --git a/src/bentoml/_internal/context.py b/src/bentoml/_internal/context.py
index 2347a51e164..7bbd147ba9f 100644
--- a/src/bentoml/_internal/context.py
+++ b/src/bentoml/_internal/context.py
@@ -3,7 +3,6 @@
 import contextlib
 import contextvars
 import os
-import tempfile
 import typing as t
 from abc import ABC
 from abc import abstractmethod
@@ -13,15 +12,32 @@
 import starlette.datastructures
 
 from .utils.http import Cookie
+from .utils.temp import TempfilePool
 
 if TYPE_CHECKING:
     import starlette.requests
     import starlette.responses
 
-# A request-unique directory for storing temporary files
-request_directory: contextvars.ContextVar[str] = contextvars.ContextVar(
-    "request_directory"
+_request_var: contextvars.ContextVar[starlette.requests.Request] = (
+    contextvars.ContextVar("request")
 )
+_response_var: contextvars.ContextVar[ServiceContext.ResponseContext] = (
+    contextvars.ContextVar("response")
+)
+
+request_tempdir_pool = TempfilePool(prefix="bentoml-request-")
+
+
+def request_temp_dir() -> str:
+    """Return a request-unique directory for storing temporary files.
+
+    The directory is acquired from the pool on first use and cached on
+    ``request.state``; it raises ``LookupError`` outside of a request.
+    """
+    request = _request_var.get()
+    if not hasattr(request.state, "temp_dir"):
+        request.state.temp_dir = request_tempdir_pool.acquire()
+    return request.state.temp_dir
 
 
 class Metadata(t.Mapping[str, str], ABC):
@@ -81,12 +97,6 @@ def mutablecopy(self) -> Metadata:
 
 class ServiceContext:
     def __init__(self) -> None:
-        self._request_var: contextvars.ContextVar[starlette.requests.Request] = (
-            contextvars.ContextVar("request")
-        )
-        self._response_var: contextvars.ContextVar[ServiceContext.ResponseContext] = (
-            contextvars.ContextVar("response")
-        )
         # A dictionary for storing global state shared by the process
         self.state: dict[str, t.Any] = {}
 
@@ -95,28 +105,31 @@ def in_request(
         self, request: starlette.requests.Request
     ) -> t.Generator[ServiceContext, None, None]:
         request.metadata = request.headers  # type: ignore[attr-defined]
-        request_token = self._request_var.set(request)
-        response_token = self._response_var.set(ServiceContext.ResponseContext())
-        with tempfile.TemporaryDirectory(prefix="bentoml-request-") as temp_dir:
-            dir_token = request_directory.set(temp_dir)
-            try:
-                yield self
-            finally:
-                self._request_var.reset(request_token)
-                self._response_var.reset(response_token)
-                request_directory.reset(dir_token)
+        request_token = _request_var.set(request)
+        response_token = _response_var.set(ServiceContext.ResponseContext())
+        try:
+            yield self
+        finally:
+            try:
+                if hasattr(request.state, "temp_dir"):
+                    request_tempdir_pool.release(request.state.temp_dir)
+            finally:
+                # Restore the context variables even if releasing the
+                # temp directory fails.
+                _request_var.reset(request_token)
+                _response_var.reset(response_token)
 
     @property
     def request(self) -> starlette.requests.Request:
-        return self._request_var.get()
+        return _request_var.get()
 
     @property
     def response(self) -> ResponseContext:
-        return self._response_var.get()
+        return _response_var.get()
 
     @property
     def temp_dir(self) -> str:
-        return request_directory.get()
+        return request_temp_dir()
 
     @attr.define
     class ResponseContext:
diff --git a/src/bentoml/_internal/server/base_app.py b/src/bentoml/_internal/server/base_app.py
index ff254d0da5c..5d0ce5480d0 100644
--- a/src/bentoml/_internal/server/base_app.py
+++ b/src/bentoml/_internal/server/base_app.py
@@ -43,7 +43,9 @@ def on_startup(self) -> list[LifecycleHook]:
 
     @property
     def on_shutdown(self) -> list[LifecycleHook]:
-        return []
+        from ..context import request_tempdir_pool
+
+        return [request_tempdir_pool.cleanup]
 
     def mark_as_ready(self) -> None:
         self._is_ready = True
diff --git a/src/bentoml/_internal/utils/temp.py b/src/bentoml/_internal/utils/temp.py
new file mode 100644
index 00000000000..17049cb81f3
--- /dev/null
+++ b/src/bentoml/_internal/utils/temp.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+import logging
+import shutil
+import tempfile
+from collections import deque
+from functools import partial
+from pathlib import Path
+from threading import Lock
+
+logger = logging.getLogger(__name__)
+
+
+class TempfilePool:
+    """A simple pool of temp directories, so they are reused as much as possible.
+
+    Directories are created on demand by ``acquire``, emptied and put back by
+    ``release``, and deleted by ``cleanup``.
+    """
+
+    def __init__(
+        self,
+        suffix: str | None = None,
+        prefix: str | None = None,
+        dir: str | None = None,
+    ) -> None:
+        """Store the ``tempfile.mkdtemp`` arguments used for new directories."""
+        self._pool: deque[str] = deque()
+        self._lock = Lock()
+        self._new = partial(tempfile.mkdtemp, suffix=suffix, prefix=prefix, dir=dir)
+
+    def cleanup(self) -> None:
+        """Delete every directory currently sitting idle in the pool."""
+        # Detach the idle directories under the lock so a concurrent
+        # acquire() cannot be handed a directory that is about to be deleted.
+        with self._lock:
+            idle = list(self._pool)
+            self._pool.clear()
+        for path in idle:
+            shutil.rmtree(path, ignore_errors=True)
+
+    def acquire(self) -> str:
+        """Return an idle pooled directory, or create a new one if none is idle."""
+        with self._lock:
+            if self._pool:
+                return self._pool.popleft()
+        return self._new()
+
+    def release(self, dir: str) -> None:
+        """Empty ``dir`` and return it to the pool for reuse.
+
+        If the directory cannot be emptied it is deleted (best effort) instead
+        of being pooled, so a later request never sees leftover files.
+        """
+        try:
+            for child in Path(dir).iterdir():
+                # rmtree() refuses symlinks, so only recurse into real dirs.
+                if child.is_dir() and not child.is_symlink():
+                    shutil.rmtree(child)
+                else:
+                    child.unlink()
+        except OSError:
+            logger.warning("Failed to empty temp directory %s; discarding it", dir)
+            shutil.rmtree(dir, ignore_errors=True)
+            return
+        with self._lock:
+            self._pool.append(dir)