Skip to content

Commit

Permalink
fix: bug: memory leak when using bentoml>=1.2
Browse files Browse the repository at this point in the history
Fixes bentoml#4760

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming committed Jun 4, 2024
1 parent 58c0882 commit 15dc29b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/_bentoml_sdk/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def encode(self, obj: Path) -> bytes:
return obj.read_bytes()

def decode(self, obj: bytes | t.BinaryIO | UploadFile | PurePath | str) -> t.Any:
from bentoml._internal.context import request_directory
from bentoml._internal.context import request_temp_dir

media_type: str | None = None

Expand Down Expand Up @@ -156,7 +156,7 @@ def decode(self, obj: bytes | t.BinaryIO | UploadFile | PurePath | str) -> t.Any
f"Invalid content type {media_type}, expected {self.content_type}"
)
with tempfile.NamedTemporaryFile(
suffix=filename, dir=request_directory.get(), delete=False
suffix=filename, dir=request_temp_dir(), delete=False
) as f:
f.write(body)
return Path(f.name)
Expand Down
49 changes: 27 additions & 22 deletions src/bentoml/_internal/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import contextvars
import os
import shutil
import tempfile
import typing as t
from abc import ABC
Expand All @@ -18,10 +19,19 @@
import starlette.requests
import starlette.responses

# A request-unique directory for storing temporary files
request_directory: contextvars.ContextVar[str] = contextvars.ContextVar(
"request_directory"
_request_var: contextvars.ContextVar[starlette.requests.Request] = (
contextvars.ContextVar("request")
)
_response_var: contextvars.ContextVar[ServiceContext.ResponseContext] = (
contextvars.ContextVar("response")
)


def request_temp_dir() -> str:
request = _request_var.get()
if "temp_dir" not in request.state:
request.state["temp_dir"] = tempfile.mkdtemp(prefix="bentoml-request-")
return t.cast(str, request.state["temp_dir"])


class Metadata(t.Mapping[str, str], ABC):
Expand Down Expand Up @@ -81,12 +91,6 @@ def mutablecopy(self) -> Metadata:

class ServiceContext:
def __init__(self) -> None:
self._request_var: contextvars.ContextVar[starlette.requests.Request] = (
contextvars.ContextVar("request")
)
self._response_var: contextvars.ContextVar[ServiceContext.ResponseContext] = (
contextvars.ContextVar("response")
)
# A dictionary for storing global state shared by the process
self.state: dict[str, t.Any] = {}

Expand All @@ -95,28 +99,29 @@ def in_request(
self, request: starlette.requests.Request
) -> t.Generator[ServiceContext, None, None]:
request.metadata = request.headers # type: ignore[attr-defined]
request_token = self._request_var.set(request)
response_token = self._response_var.set(ServiceContext.ResponseContext())
with tempfile.TemporaryDirectory(prefix="bentoml-request-") as temp_dir:
dir_token = request_directory.set(temp_dir)
try:
yield self
finally:
self._request_var.reset(request_token)
self._response_var.reset(response_token)
request_directory.reset(dir_token)
request_token = _request_var.set(request)
response_token = _response_var.set(ServiceContext.ResponseContext())
try:
yield self
finally:
if "temp_dir" in request.state:
shutil.rmtree(
t.cast(str, request.state["temp_dir"]), ignore_errors=True
)
_request_var.reset(request_token)
_response_var.reset(response_token)

@property
def request(self) -> starlette.requests.Request:
return self._request_var.get()
return _request_var.get()

@property
def response(self) -> ResponseContext:
return self._response_var.get()
return _response_var.get()

@property
def temp_dir(self) -> str:
return request_directory.get()
return request_temp_dir()

@attr.define
class ResponseContext:
Expand Down

0 comments on commit 15dc29b

Please sign in to comment.