feat: Added thread pool limiter
- Configurable via `chroma_server_thread_pool_size` setting
tazarov committed Feb 22, 2024
1 parent f62a5c8 commit ae4733c
Showing 2 changed files with 27 additions and 5 deletions.
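As the commit message notes, the new pool size is exposed as a plain Settings field (see the config.py hunk below). A rough configuration sketch, not part of this commit — the environment-variable spelling is an assumption based on pydantic BaseSettings' case-insensitive field matching:

# Hypothetical configuration sketch (not part of this commit).
# chroma_server_thread_pool_size is the setting added below; the
# CHROMA_SERVER_THREAD_POOL_SIZE env var name is assumed from pydantic's
# case-insensitive field matching in BaseSettings.
import os

from chromadb.config import Settings

# Option 1: construct Settings directly (e.g. when embedding the server).
settings = Settings(chroma_server_thread_pool_size=100)  # default is 40

# Option 2: export an environment variable before starting the server process.
os.environ["CHROMA_SERVER_THREAD_POOL_SIZE"] = "100"
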
12 changes: 9 additions & 3 deletions chromadb/config.py
@@ -82,6 +82,7 @@
DEFAULT_TENANT = "default_tenant"
DEFAULT_DATABASE = "default_database"


class Settings(BaseSettings): # type: ignore
environment: str = ""

@@ -139,6 +140,8 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
return v

chroma_server_nofile: Optional[int] = None
# the number of maximum threads to handle synchronous tasks in the FastAPI server
chroma_server_thread_pool_size: Optional[int] = 40

pulsar_broker_url: Optional[str] = None
pulsar_admin_port: Optional[int] = 8080
@@ -320,13 +323,16 @@ def __init__(self, settings: Settings):
            if settings[key] is not None:
                raise ValueError(LEGACY_ERROR)

-        if settings["chroma_segment_cache_policy"] is not None and settings["chroma_segment_cache_policy"] != "LRU":
+        if (
+            settings["chroma_segment_cache_policy"] is not None
+            and settings["chroma_segment_cache_policy"] != "LRU"
+        ):
            logger.error(
-                f"Failed to set chroma_segment_cache_policy: Only LRU is available."
+                "Failed to set chroma_segment_cache_policy: Only LRU is available."
            )
        if settings["chroma_memory_limit_bytes"] == 0:
            logger.error(
-                f"Failed to set chroma_segment_cache_policy: chroma_memory_limit_bytes is require."
+                "Failed to set chroma_segment_cache_policy: chroma_memory_limit_bytes is require."
            )

        # Apply the nofile limit if set
20 changes: 18 additions & 2 deletions chromadb/server/fastapi/__init__.py
@@ -5,6 +5,7 @@
# https://anyio.readthedocs.io/en/stable/threads.html#adjusting-the-default-maximum-worker-thread-count
from anyio import (
to_thread,
CapacityLimiter,
) # this is used to transform sync code to async. By default, AnyIO uses 40 threads pool
from fastapi import FastAPI as _FastAPI, Response, Request
from fastapi.responses import JSONResponse
@@ -141,6 +142,8 @@ def __init__(self, settings: Settings):
self._system = System(settings)
self._api: ServerAPI = self._system.instance(ServerAPI)
self._opentelemetry_client = self._api.require(OpenTelemetryClient)
self._thread_pool_size = settings.chroma_server_thread_pool_size
self._capacity_limiter = CapacityLimiter(self._thread_pool_size)
self._system.start()

self._app.middleware("http")(check_http_version_middleware)
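
The two lines added to __init__ above create a single anyio CapacityLimiter sized from chroma_server_thread_pool_size; the hunks that follow pass it to every to_thread.run_sync call. A minimal, self-contained sketch of that AnyIO mechanism (illustrative only, not Chroma code; blocking_work and POOL_SIZE are made-up names):

# Standalone sketch of the AnyIO primitive used in this commit: a shared
# CapacityLimiter caps how many synchronous calls may occupy worker threads
# at once.
import time

import anyio
from anyio import CapacityLimiter, to_thread

POOL_SIZE = 2  # stand-in for settings.chroma_server_thread_pool_size
limiter = CapacityLimiter(POOL_SIZE)

def blocking_work(i: int) -> int:
    time.sleep(0.1)  # stands in for a blocking Chroma API call
    return i

async def handle(i: int) -> None:
    # Only POOL_SIZE of these run concurrently; the rest wait on the limiter.
    result = await to_thread.run_sync(blocking_work, i, limiter=limiter)
    print("done", result)

async def main() -> None:
    async with anyio.create_task_group() as tg:
        for i in range(8):
            tg.start_soon(handle, i)

anyio.run(main)

With POOL_SIZE = 2, only two of the eight calls occupy worker threads at any moment; the rest queue on the limiter, which is exactly what bounds the server's sync workload below.
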
@@ -487,7 +490,11 @@ async def delete_collection(
        database: str = DEFAULT_DATABASE,
    ) -> None:
        return await to_thread.run_sync(  # type: ignore
-            self._api.delete_collection, collection_name, tenant, database
+            self._api.delete_collection,
+            collection_name,
+            tenant,
+            database,
+            limiter=self._capacity_limiter,
        )

@trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION)
@@ -511,6 +518,7 @@ async def add(self, request: Request, collection_id: str) -> None:
add.metadatas,
add.documents,
add.uris,
limiter=self._capacity_limiter,
)
except InvalidDimensionException as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -536,6 +544,7 @@ async def update(self, request: Request, collection_id: str) -> None:
add.metadatas,
add.documents,
add.uris,
limiter=self._capacity_limiter,
)

@trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION)
@@ -559,6 +568,7 @@ async def upsert(self, request: Request, collection_id: str) -> None:
upsert.metadatas,
upsert.documents,
upsert.uris,
limiter=self._capacity_limiter,
)

@trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION)
@@ -585,6 +595,7 @@ async def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
None,
get.where_document,
get.include,
limiter=self._capacity_limiter,
),
)

@@ -606,6 +617,7 @@ async def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID
delete.ids,
delete.where,
delete.where_document,
limiter=self._capacity_limiter,
),
)

@@ -620,7 +632,10 @@ async def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID
    )
    async def count(self, collection_id: str) -> int:
        return cast(
-            int, await to_thread.run_sync(self._api._count, _uuid(collection_id))
+            int,
+            await to_thread.run_sync(
+                self._api._count, _uuid(collection_id), limiter=self._capacity_limiter
+            ),
        )

@trace_method("FastAPI.reset", OpenTelemetryGranularity.OPERATION)
@@ -656,6 +671,7 @@ async def get_nearest_neighbors(
query.where,
query.where_document,
query.include,
limiter=self._capacity_limiter,
),
)
return nnresult
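
Taken together, every handler in this file now awaits to_thread.run_sync(..., limiter=self._capacity_limiter), so all synchronous endpoint work shares one pool bounded by chroma_server_thread_pool_size. A condensed, hypothetical sketch of that pattern outside of Chroma (the route, count_sync, and the literal 40 are placeholders, not the commit's code):

# Condensed sketch of the pattern used throughout this file: one limiter,
# sized from configuration, shared by every sync-to-async bridge so the
# amount of concurrent blocking work stays bounded.
from anyio import CapacityLimiter, to_thread
from fastapi import FastAPI

app = FastAPI()
capacity_limiter = CapacityLimiter(40)  # e.g. settings.chroma_server_thread_pool_size

def count_sync(collection_id: str) -> int:
    # Placeholder for a blocking call such as self._api._count(...)
    return 0

@app.get("/count/{collection_id}")
async def count(collection_id: str) -> int:
    # Each handler funnels its blocking call through the same limiter,
    # so at most 40 requests occupy worker threads at any moment.
    return await to_thread.run_sync(count_sync, collection_id, limiter=capacity_limiter)
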
