diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py
index ce19f0a379c..1440a843fc9 100644
--- a/chromadb/api/segment.py
+++ b/chromadb/api/segment.py
@@ -1,3 +1,5 @@
+from functools import cached_property
+
 from chromadb.api import ServerAPI
 from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
 from chromadb.db.system import SysDB
@@ -789,7 +791,7 @@ def reset(self) -> bool:
     def get_settings(self) -> Settings:
         return self._settings

-    @property
+    @cached_property
     @override
     def max_batch_size(self) -> int:
         return self._producer.max_batch_size
diff --git a/chromadb/auth/fastapi.py b/chromadb/auth/fastapi.py
index 85c5b803135..494d7361dcf 100644
--- a/chromadb/auth/fastapi.py
+++ b/chromadb/auth/fastapi.py
@@ -1,3 +1,6 @@
+import asyncio
+
+import chromadb
 from contextvars import ContextVar
 from functools import wraps
 import logging
@@ -173,7 +176,7 @@ def authz_context(
 ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
     def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
         @wraps(f)
-        def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
+        async def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
             _dynamic_kwargs = {
                 "api": args[0]._api,
                 "function": f,
@@ -213,6 +216,7 @@ def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
                 )

             if _provider:
+                # TODO this will block the event loop if it takes too long - refactor for async
                 a_authz_responses.append(_provider.authorize(_context))
             if not any(a_authz_responses):
                 raise AuthorizationError("Unauthorized")
@@ -239,6 +243,8 @@ def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
             ):
                 kwargs["database"].name = desired_database

+            if asyncio.iscoroutinefunction(f):
+                return await f(*args, **kwargs)
             return f(*args, **kwargs)

         return wrapped
diff --git a/chromadb/config.py b/chromadb/config.py
index 7fd8e6d8981..611a7c5087b 100644
--- a/chromadb/config.py
+++ b/chromadb/config.py
@@ -144,6 +144,8 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
         return v

     chroma_server_nofile: Optional[int] = None
+    # the number of maximum threads to handle synchronous tasks in the FastAPI server
+    chroma_server_thread_pool_size: int = 40

     chroma_server_auth_provider: Optional[str] = None
diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py
index c7cdb306324..8549bc36207 100644
--- a/chromadb/db/impl/sqlite.py
+++ b/chromadb/db/impl/sqlite.py
@@ -34,6 +34,7 @@ def __init__(self, conn_pool: Pool, stack: local):
     @override
     def __enter__(self) -> base.Cursor:
         if len(self._tx_stack.stack) == 0:
+            self._conn.execute("PRAGMA case_sensitive_like = ON")
             self._conn.execute("BEGIN;")
         self._tx_stack.stack.append(self)
         return self._conn.cursor()  # type: ignore
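[Note: illustrative sketch, not part of the patch] The authz_context change above keeps a single wrapper that serves both sync and async endpoints by checking asyncio.iscoroutinefunction(f) before calling into the wrapped function; the tracing decorator later in this diff uses the same check. A minimal standalone version of that dispatch pattern (the decorator name here is hypothetical):

import asyncio
from functools import wraps
from typing import Any, Callable

def sync_or_async(f: Callable[..., Any]) -> Callable[..., Any]:
    # The wrapper is always a coroutine function: it awaits f when f is a
    # coroutine function and calls it directly otherwise.
    @wraps(f)
    async def wrapped(*args: Any, **kwargs: Any) -> Any:
        if asyncio.iscoroutinefunction(f):
            return await f(*args, **kwargs)
        return f(*args, **kwargs)

    return wrapped

@sync_or_async
def plain() -> str:
    return "sync"

@sync_or_async
async def coro() -> str:
    return "async"

print(asyncio.run(plain()), asyncio.run(coro()))  # -> sync async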
diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py
index 292f5038dea..0094e0061b6 100644
--- a/chromadb/server/fastapi/__init__.py
+++ b/chromadb/server/fastapi/__init__.py
@@ -1,7 +1,13 @@
-from typing import Any, Callable, Dict, List, Sequence, Optional
+from typing import Any, Callable, Dict, List, Sequence, Optional, cast
 import fastapi
-from fastapi import FastAPI as _FastAPI, Response
-from fastapi.responses import JSONResponse
+import orjson
+
+from anyio import (
+    to_thread,
+    CapacityLimiter,
+)
+from fastapi import FastAPI as _FastAPI, Response, Request
+from fastapi.responses import JSONResponse, ORJSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.routing import APIRoute
@@ -48,7 +54,6 @@
     UpdateCollection,
     UpdateEmbedding,
 )
-from starlette.requests import Request

 import logging
@@ -128,10 +133,14 @@ class FastAPI(chromadb.server.Server):
     def __init__(self, settings: Settings):
         super().__init__(settings)
         ProductTelemetryClient.SERVER_CONTEXT = ServerContext.FASTAPI
-        self._app = fastapi.FastAPI(debug=True)
+        # https://fastapi.tiangolo.com/advanced/custom-response/#use-orjsonresponse
+        self._app = fastapi.FastAPI(debug=True, default_response_class=ORJSONResponse)
         self._system = System(settings)
         self._api: ServerAPI = self._system.instance(ServerAPI)
         self._opentelemetry_client = self._api.require(OpenTelemetryClient)
+        self._capacity_limiter = CapacityLimiter(
+            settings.chroma_server_thread_pool_size
+        )
         self._system.start()

         self._app.middleware("http")(check_http_version_middleware)
@@ -143,7 +152,9 @@ def __init__(self, settings: Settings):
             allow_methods=["*"],
         )
         self._app.add_exception_handler(QuotaError, self.quota_exception_handler)
-        self._app.add_exception_handler(RateLimitError, self.rate_limit_exception_handler)
+        self._app.add_exception_handler(
+            RateLimitError, self.rate_limit_exception_handler
+        )

         self._app.on_event("shutdown")(self.shutdown)
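[Note: illustrative sketch, not part of the patch] Every endpoint below follows the same pattern: the synchronous Chroma API call is pushed onto a worker thread with anyio's to_thread.run_sync, and the CapacityLimiter created in __init__ caps how many of those calls run at once (chroma_server_thread_pool_size, 40 by default), so a slow request no longer blocks the event loop. A minimal standalone example of the mechanism:

import anyio
from anyio import CapacityLimiter, to_thread

def blocking_call(x: int) -> int:
    # stand-in for a synchronous call such as self._api.heartbeat()
    return x * 2

async def main() -> None:
    limiter = CapacityLimiter(40)  # plays the role of self._capacity_limiter
    result = await to_thread.run_sync(blocking_call, 21, limiter=limiter)
    print(result)  # 42

anyio.run(main)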
@@ -295,23 +306,26 @@ def app(self) -> fastapi.FastAPI:
     async def rate_limit_exception_handler(self, request: Request, exc: RateLimitError):
         return JSONResponse(
             status_code=429,
-            content={"message": f"rate limit. resource: {exc.resource} quota: {exc.quota}"},
+            content={
+                "message": f"rate limit. resource: {exc.resource} quota: {exc.quota}"
+            },
         )
-
     def root(self) -> Dict[str, int]:
         return {"nanosecond heartbeat": self._api.heartbeat()}

     async def quota_exception_handler(self, request: Request, exc: QuotaError):
         return JSONResponse(
             status_code=429,
-            content={"message": f"quota error. resource: {exc.resource} quota: {exc.quota} actual: {exc.actual}"},
+            content={
+                "message": f"quota error. resource: {exc.resource} quota: {exc.quota} actual: {exc.actual}"
+            },
         )

-    def heartbeat(self) -> Dict[str, int]:
+    async def heartbeat(self) -> Dict[str, int]:
         return self.root()

-    def version(self) -> str:
+    async def version(self) -> str:
         return self._api.get_version()

@@ -324,10 +338,18 @@ def version(self) -> str:
     @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.CREATE_DATABASE,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.DB,
             attributes=attr_from_resource_object(
                 type=AuthzResourceTypes.DB, additional_attrs={"tenant": "tenant"}
             ),
         ),
     )
-    def create_database(
-        self, database: CreateDatabase, tenant: str = DEFAULT_TENANT
+    async def create_database(
+        self, request: Request, tenant: str = DEFAULT_TENANT
     ) -> None:
-        return self._api.create_database(database.name, tenant)
+        def process_create_database(raw_body: bytes) -> None:
+            create = CreateDatabase.model_validate(orjson.loads(raw_body))
+            return self._api.create_database(create.name, tenant)
+
+        await to_thread.run_sync(
+            process_create_database,
+            await request.body(),
+            limiter=self._capacity_limiter,
+        )

@@ -340,8 +362,18 @@ def create_database(
     @trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.GET_DATABASE,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.DB,
             attributes=attr_from_resource_object(
                 type=AuthzResourceTypes.DB, additional_attrs={"tenant": "tenant"}
             ),
         ),
     )
-    def get_database(self, database: str, tenant: str = DEFAULT_TENANT) -> Database:
-        return self._api.get_database(database, tenant)
+    async def get_database(
+        self, database: str, tenant: str = DEFAULT_TENANT
+    ) -> Database:
+        return cast(
+            Database,
+            await to_thread.run_sync(
+                self._api.get_database,
+                database,
+                tenant,
+                limiter=self._capacity_limiter,
+            ),
+        )

@@ -350,8 +382,16 @@ def get_database(self, database: str, tenant: str = DEFAULT_TENANT) -> Database:
     @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.CREATE_TENANT,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.TENANT,
         ),
     )
-    def create_tenant(self, tenant: CreateTenant) -> None:
-        return self._api.create_tenant(tenant.name)
+    async def create_tenant(self, request: Request) -> None:
+        def process_create_tenant(raw_body: bytes) -> None:
+            create = CreateTenant.model_validate(orjson.loads(raw_body))
+            return self._api.create_tenant(create.name)
+
+        await to_thread.run_sync(
+            process_create_tenant,
+            await request.body(),
+            limiter=self._capacity_limiter,
+        )

@@ -361,8 +401,15 @@ def create_tenant(self, tenant: CreateTenant) -> None:
     @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.GET_TENANT,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.TENANT,
         ),
     )
-    def get_tenant(self, tenant: str) -> Tenant:
-        return self._api.get_tenant(tenant)
+    async def get_tenant(self, tenant: str) -> Tenant:
+        return cast(
+            Tenant,
+            await to_thread.run_sync(
+                self._api.get_tenant,
+                tenant,
+                limiter=self._capacity_limiter,
+            ),
+        )

@@ -375,15 +422,23 @@ def get_tenant(self, tenant: str) -> Tenant:
     @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.LIST_COLLECTIONS,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.DB,
             attributes=attr_from_resource_object(
                 type=AuthzResourceTypes.DB, additional_attrs={"tenant": "tenant"}
             ),
         ),
     )
-    def list_collections(
+    async def list_collections(
         self,
         limit: Optional[int] = None,
         offset: Optional[int] = None,
         tenant: str = DEFAULT_TENANT,
         database: str = DEFAULT_DATABASE,
     ) -> Sequence[Collection]:
-        return self._api.list_collections(
-            limit=limit, offset=offset, tenant=tenant, database=database
+        return cast(
+            Sequence[Collection],
+            await to_thread.run_sync(
+                self._api.list_collections,
+                limit,
+                offset,
+                tenant,
+                database,
+                limiter=self._capacity_limiter,
+            ),
         )

@@ -397,12 +452,20 @@ def list_collections(
     @trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.COUNT_COLLECTIONS,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.DB,
             attributes=attr_from_resource_object(
                 type=AuthzResourceTypes.DB, additional_attrs={"tenant": "tenant"}
             ),
         ),
     )
-    def count_collections(
+    async def count_collections(
         self,
         tenant: str = DEFAULT_TENANT,
         database: str = DEFAULT_DATABASE,
     ) -> int:
-        return self._api.count_collections(tenant=tenant, database=database)
+        return cast(
+            int,
+            await to_thread.run_sync(
+                self._api.count_collections,
+                tenant,
+                database,
+                limiter=self._capacity_limiter,
+            ),
+        )
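[Note: illustrative sketch, not part of the patch] anyio's to_thread.run_sync only forwards positional arguments (its own keyword arguments are limiter and cancellable), which is why the calls above pass limit/offset/tenant/database positionally rather than as keywords. When keyword arguments are unavoidable, functools.partial is one way to bind them first:

from functools import partial

import anyio
from anyio import to_thread

def list_collections(limit=None, offset=None, tenant="default", database="default"):
    # stand-in for the synchronous ServerAPI.list_collections
    return {"limit": limit, "offset": offset, "tenant": tenant, "database": database}

async def main() -> None:
    # bind keyword arguments up front, then run the partial in a worker thread
    fn = partial(list_collections, limit=10, offset=0, tenant="t1", database="db1")
    print(await to_thread.run_sync(fn))

anyio.run(main)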
@@ -415,18 +478,29 @@ def count_collections(
     @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.CREATE_COLLECTION,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.DB,
             attributes=attr_from_resource_object(
                 type=AuthzResourceTypes.DB, additional_attrs={"tenant": "tenant"}
             ),
         ),
     )
-    def create_collection(
+    async def create_collection(
         self,
-        collection: CreateCollection,
+        request: Request,
         tenant: str = DEFAULT_TENANT,
         database: str = DEFAULT_DATABASE,
     ) -> Collection:
-        return self._api.create_collection(
-            name=collection.name,
-            metadata=collection.metadata,
-            get_or_create=collection.get_or_create,
-            tenant=tenant,
-            database=database,
+        def process_create_collection(raw_body: bytes) -> Collection:
+            create = CreateCollection.model_validate(orjson.loads(raw_body))
+            return self._api.create_collection(
+                name=create.name,
+                metadata=create.metadata,
+                get_or_create=create.get_or_create,
+                tenant=tenant,
+                database=database,
+            )
+
+        return cast(
+            Collection,
+            await to_thread.run_sync(
+                process_create_collection,
+                await request.body(),
+                limiter=self._capacity_limiter,
+            ),
         )

     @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
     @authz_context(
@@ -440,14 +514,24 @@ def create_collection(
         ),
     )
-    def get_collection(
+    async def get_collection(
         self,
         collection_name: str,
         tenant: str = DEFAULT_TENANT,
         database: str = DEFAULT_DATABASE,
     ) -> Collection:
-        return self._api.get_collection(
-            collection_name, tenant=tenant, database=database
+        return cast(
+            Collection,
+            await to_thread.run_sync(
+                self._api.get_collection,
+                collection_name,
+                None,
+                None,
+                None,
+                tenant,
+                database,
+                limiter=self._capacity_limiter,
+            ),
         )

@@ -459,13 +543,19 @@ def get_collection(
     @trace_method("FastAPI.update_collection", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.UPDATE_COLLECTION,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.COLLECTION,
             attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
         ),
     )
-    def update_collection(
-        self, collection_id: str, collection: UpdateCollection
-    ) -> None:
-        return self._api._modify(
-            id=_uuid(collection_id),
-            new_name=collection.new_name,
-            new_metadata=collection.new_metadata,
+    async def update_collection(self, collection_id: str, request: Request) -> None:
+        def process_update_collection(raw_body: bytes) -> None:
+            update = UpdateCollection.model_validate(orjson.loads(raw_body))
+            return self._api._modify(
+                id=_uuid(collection_id),
+                new_name=update.new_name,
+                new_metadata=update.new_metadata,
+            )
+
+        await to_thread.run_sync(
+            process_update_collection,
+            await request.body(),
+            limiter=self._capacity_limiter,
         )

     @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
     @authz_context(
@@ -479,14 +569,18 @@ def update_collection(
         ),
     )
-    def delete_collection(
+    async def delete_collection(
         self,
         collection_name: str,
         tenant: str = DEFAULT_TENANT,
         database: str = DEFAULT_DATABASE,
     ) -> None:
-        return self._api.delete_collection(
-            collection_name, tenant=tenant, database=database
+        await to_thread.run_sync(
+            self._api.delete_collection,
+            collection_name,
+            tenant,
+            database,
+            limiter=self._capacity_limiter,
         )

@@ -498,19 +592,30 @@ def delete_collection(
     @trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.ADD,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.COLLECTION,
             attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
         ),
     )
-    def add(self, collection_id: str, add: AddEmbedding) -> None:
+    async def add(self, request: Request, collection_id: str) -> bool:
         try:
-            result = self._api._add(
-                collection_id=_uuid(collection_id),
-                embeddings=add.embeddings,  # type: ignore
-                metadatas=add.metadatas,  # type: ignore
-                documents=add.documents,  # type: ignore
-                uris=add.uris,  # type: ignore
-                ids=add.ids,
+
+            def process_add(raw_body: bytes) -> bool:
+                add = AddEmbedding.model_validate(orjson.loads(raw_body))
+                return self._api._add(
+                    collection_id=_uuid(collection_id),
+                    ids=add.ids,
+                    embeddings=add.embeddings,  # type: ignore
+                    metadatas=add.metadatas,  # type: ignore
+                    documents=add.documents,  # type: ignore
+                    uris=add.uris,  # type: ignore
+                )
+
+            return cast(
+                bool,
+                await to_thread.run_sync(
+                    process_add,
+                    await request.body(),
+                    limiter=self._capacity_limiter,
+                ),
             )
         except InvalidDimensionException as e:
             raise HTTPException(status_code=500, detail=str(e))
-        return result  # type: ignore
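[Note: illustrative sketch, not part of the patch] The write endpoints no longer let FastAPI parse the request body on the event loop: they read the raw bytes with await request.body() and move orjson decoding plus Pydantic validation into the worker thread, so large payloads are parsed off the event loop. Roughly what the process_* helpers do, using a simplified stand-in for the real request models (model_validate assumes Pydantic v2):

from typing import List, Optional

import orjson
from pydantic import BaseModel

class AddPayload(BaseModel):
    # simplified stand-in for chromadb.server.fastapi.types.AddEmbedding
    ids: List[str]
    embeddings: Optional[List[List[float]]] = None

def process_add(raw_body: bytes) -> AddPayload:
    # runs inside the worker thread: decode with orjson, then validate
    return AddPayload.model_validate(orjson.loads(raw_body))

print(process_add(b'{"ids": ["a"], "embeddings": [[0.1, 0.2]]}'))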
@@ -521,14 +626,22 @@ def add(self, collection_id: str, add: AddEmbedding) -> None:
     @trace_method("FastAPI.update", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.UPDATE,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.COLLECTION,
             attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
         ),
     )
-    def update(self, collection_id: str, add: UpdateEmbedding) -> None:
-        self._api._update(
-            ids=add.ids,
-            collection_id=_uuid(collection_id),
-            embeddings=add.embeddings,
-            documents=add.documents,  # type: ignore
-            uris=add.uris,  # type: ignore
-            metadatas=add.metadatas,  # type: ignore
+    async def update(self, request: Request, collection_id: str) -> None:
+        def process_update(raw_body: bytes) -> bool:
+            update = UpdateEmbedding.model_validate(orjson.loads(raw_body))
+            return self._api._update(
+                collection_id=_uuid(collection_id),
+                ids=update.ids,
+                embeddings=update.embeddings,
+                metadatas=update.metadatas,  # type: ignore
+                documents=update.documents,  # type: ignore
+                uris=update.uris,  # type: ignore
+            )
+
+        await to_thread.run_sync(
+            process_update,
+            await request.body(),
+            limiter=self._capacity_limiter,
         )

@@ -540,14 +653,22 @@ def update(self, collection_id: str, add: UpdateEmbedding) -> None:
     @trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.UPSERT,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.COLLECTION,
             attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
         ),
     )
-    def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
-        self._api._upsert(
-            collection_id=_uuid(collection_id),
-            ids=upsert.ids,
-            embeddings=upsert.embeddings,  # type: ignore
-            documents=upsert.documents,  # type: ignore
-            uris=upsert.uris,  # type: ignore
-            metadatas=upsert.metadatas,  # type: ignore
+    async def upsert(self, request: Request, collection_id: str) -> None:
+        def process_upsert(raw_body: bytes) -> bool:
+            upsert = AddEmbedding.model_validate(orjson.loads(raw_body))
+            return self._api._upsert(
+                collection_id=_uuid(collection_id),
+                ids=upsert.ids,
+                embeddings=upsert.embeddings,  # type: ignore
+                metadatas=upsert.metadatas,  # type: ignore
+                documents=upsert.documents,  # type: ignore
+                uris=upsert.uris,  # type: ignore
+            )
+
+        await to_thread.run_sync(
+            process_upsert,
+            await request.body(),
+            limiter=self._capacity_limiter,
         )

@@ -559,16 +680,27 @@ def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
     @trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.GET,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.COLLECTION,
             attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
         ),
     )
-    def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
-        return self._api._get(
-            collection_id=_uuid(collection_id),
-            ids=get.ids,
-            where=get.where,
-            where_document=get.where_document,
-            sort=get.sort,
-            limit=get.limit,
-            offset=get.offset,
-            include=get.include,
+    async def get(self, collection_id: str, request: Request) -> GetResult:
+        def process_get(raw_body: bytes) -> GetResult:
+            get = GetEmbedding.model_validate(orjson.loads(raw_body))
+            return self._api._get(
+                collection_id=_uuid(collection_id),
+                ids=get.ids,
+                where=get.where,
+                sort=get.sort,
+                limit=get.limit,
+                offset=get.offset,
+                where_document=get.where_document,
+                include=get.include,
+            )
+
+        return cast(
+            GetResult,
+            await to_thread.run_sync(
+                process_get,
+                await request.body(),
+                limiter=self._capacity_limiter,
+            ),
         )
@@ -580,12 +712,23 @@ def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
     @trace_method("FastAPI.delete", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.DELETE,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.COLLECTION,
             attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
         ),
     )
-    def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
-        return self._api._delete(
-            where=delete.where,  # type: ignore
-            ids=delete.ids,
-            collection_id=_uuid(collection_id),
-            where_document=delete.where_document,
+    async def delete(self, collection_id: str, request: Request) -> List[UUID]:
+        def process_delete(raw_body: bytes) -> List[str]:
+            delete = DeleteEmbedding.model_validate(orjson.loads(raw_body))
+            return self._api._delete(
+                collection_id=_uuid(collection_id),
+                ids=delete.ids,
+                where=delete.where,
+                where_document=delete.where_document,
+            )
+
+        return cast(
+            List[UUID],
+            await to_thread.run_sync(
+                process_delete,
+                await request.body(),
+                limiter=self._capacity_limiter,
+            ),
         )

@@ -597,8 +740,15 @@ def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
     @trace_method("FastAPI.count", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.COUNT,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.COLLECTION,
             attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
         ),
     )
-    def count(self, collection_id: str) -> int:
-        return self._api._count(_uuid(collection_id))
+    async def count(self, collection_id: str) -> int:
+        return cast(
+            int,
+            await to_thread.run_sync(
+                self._api._count,
+                _uuid(collection_id),
+                limiter=self._capacity_limiter,
+            ),
+        )

@@ -608,8 +758,14 @@ def count(self, collection_id: str) -> int:
     @trace_method("FastAPI.reset", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.RESET,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.DB,
         ),
     )
-    def reset(self) -> bool:
-        return self._api.reset()
+    async def reset(self) -> bool:
+        return cast(
+            bool,
+            await to_thread.run_sync(
+                self._api.reset,
+                limiter=self._capacity_limiter,
+            ),
+        )

@@ -620,20 +776,40 @@ def reset(self) -> bool:
     @trace_method("FastAPI.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION)
     @authz_context(
         action=AuthzResourceActions.QUERY,
         resource=AuthzDynamicParams(
             type=AuthzResourceTypes.COLLECTION,
             attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
         ),
     )
-    def get_nearest_neighbors(
-        self, collection_id: str, query: QueryEmbedding
+    async def get_nearest_neighbors(
+        self, collection_id: str, request: Request
     ) -> QueryResult:
-        nnresult = self._api._query(
-            collection_id=_uuid(collection_id),
-            where=query.where,  # type: ignore
-            where_document=query.where_document,  # type: ignore
-            query_embeddings=query.query_embeddings,
-            n_results=query.n_results,
-            include=query.include,
+        def process_query(raw_body: bytes) -> QueryResult:
+            query = QueryEmbedding.model_validate(orjson.loads(raw_body))
+            return self._api._query(
+                collection_id=_uuid(collection_id),
+                query_embeddings=query.query_embeddings,
+                n_results=query.n_results,
+                where=query.where,  # type: ignore
+                where_document=query.where_document,  # type: ignore
+                include=query.include,
+            )
+
+        nnresult = cast(
+            QueryResult,
+            await to_thread.run_sync(
+                process_query,
+                await request.body(),
+                limiter=self._capacity_limiter,
+            ),
         )
         return nnresult
-    def pre_flight_checks(self) -> Dict[str, Any]:
-        return {
-            "max_batch_size": self._api.max_batch_size,
-        }
+    async def pre_flight_checks(self) -> Dict[str, Any]:
+        def process_pre_flight_checks() -> Dict[str, Any]:
+            return {
+                "max_batch_size": self._api.max_batch_size,
+            }
+
+        return cast(
+            Dict[str, Any],
+            await to_thread.run_sync(
+                process_pre_flight_checks,
+                limiter=self._capacity_limiter,
+            ),
+        )
diff --git a/chromadb/telemetry/opentelemetry/__init__.py b/chromadb/telemetry/opentelemetry/__init__.py
index 0160c28183d..9f77f2c55f0 100644
--- a/chromadb/telemetry/opentelemetry/__init__.py
+++ b/chromadb/telemetry/opentelemetry/__init__.py
@@ -1,3 +1,4 @@
+import asyncio
 from functools import wraps
 from enum import Enum
 from typing import Any, Callable, Dict, Optional, Sequence, Union
@@ -120,17 +121,32 @@ def trace_method(
     """A decorator that traces a method."""

     def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
-        @wraps(f)
-        def wrapper(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
-            global tracer, granularity
-            if trace_granularity < granularity:
-                return f(*args, **kwargs)
-            if not tracer:
-                return f(*args, **kwargs)
-            with tracer.start_as_current_span(trace_name, attributes=attributes):
-                return f(*args, **kwargs)
-
-        return wrapper
+        if asyncio.iscoroutinefunction(f):
+
+            @wraps(f)
+            async def wrapper(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
+                global tracer, granularity
+                if trace_granularity < granularity:
+                    return await f(*args, **kwargs)
+                if not tracer:
+                    return await f(*args, **kwargs)
+                with tracer.start_as_current_span(trace_name, attributes=attributes):
+                    return await f(*args, **kwargs)
+
+            return wrapper
+        else:
+
+            @wraps(f)
+            def wrapper(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
+                global tracer, granularity
+                if trace_granularity < granularity:
+                    return f(*args, **kwargs)
+                if not tracer:
+                    return f(*args, **kwargs)
+                with tracer.start_as_current_span(trace_name, attributes=attributes):
+                    return f(*args, **kwargs)
+
+            return wrapper

     return decorator
diff --git a/chromadb/test/client/test_multiple_clients_concurrency.py b/chromadb/test/client/test_multiple_clients_concurrency.py
index 14054214cbf..ce7817bbf4f 100644
--- a/chromadb/test/client/test_multiple_clients_concurrency.py
+++ b/chromadb/test/client/test_multiple_clients_concurrency.py
@@ -1,4 +1,5 @@
 from concurrent.futures import ThreadPoolExecutor
+
 from chromadb.api.client import AdminClient, Client
 from chromadb.config import DEFAULT_TENANT
@@ -33,7 +34,7 @@ def run_target(n: int) -> None:
     with ThreadPoolExecutor(max_workers=CLIENT_COUNT) as executor:
         executor.map(run_target, range(CLIENT_COUNT))
-
+        executor.shutdown(wait=True)

     # Create a final client, which will be used to verify the collections were created
     client = Client(settings=client._system.settings)
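[Note: illustrative sketch, not part of the patch] The segment.py change at the top of this diff switches the API implementation's max_batch_size from @property to functools.cached_property, so the producer is consulted once per instance and later reads (for example from pre_flight_checks, which now runs in a worker thread) return the cached value. The behavior in brief, with a toy class:

from functools import cached_property

class Server:
    def __init__(self) -> None:
        self.producer_calls = 0

    @cached_property
    def max_batch_size(self) -> int:
        # computed on first access, then stored on the instance
        self.producer_calls += 1
        return 1000

s = Server()
print(s.max_batch_size, s.max_batch_size, s.producer_calls)  # 1000 1000 1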