Skip to content

Commit

Permalink
feat: Improved server-side serialization with orjson + async
Browse files Browse the repository at this point in the history
- Added orjson serialization improving serialization performance especially on large batches 100+ docs 2x faster (tested with locust)
- Added async body serialization further improving performance (tested with locust)
- Added async handling with AnyIO of the more impactful server queries further reducing concurrent request response times

Future work:  Add orjson serialization at client-side
  • Loading branch information
tazarov committed Feb 3, 2024
1 parent e5751fd commit 316db9b
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 71 deletions.
4 changes: 3 additions & 1 deletion chromadb/api/segment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import cached_property

from chromadb.api import ServerAPI
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.db.system import SysDB
Expand Down Expand Up @@ -780,7 +782,7 @@ def reset(self) -> bool:
def get_settings(self) -> Settings:
return self._settings

@property
@cached_property
@override
def max_batch_size(self) -> int:
return self._producer.max_batch_size
Expand Down
6 changes: 5 additions & 1 deletion chromadb/auth/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import chromadb
from contextvars import ContextVar
from functools import wraps
Expand Down Expand Up @@ -173,7 +175,7 @@ def authz_context(
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
async def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
_dynamic_kwargs = {
"api": args[0]._api,
"function": f,
Expand Down Expand Up @@ -239,6 +241,8 @@ def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
):
kwargs["database"].name = desired_database

if asyncio.iscoroutinefunction(f):
return await f(*args, **kwargs)
return f(*args, **kwargs)

return wrapped
Expand Down
148 changes: 91 additions & 57 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Callable, Dict, List, Sequence, Optional
from typing import Any, Callable, Dict, List, Sequence, Optional, cast
import fastapi
from fastapi import FastAPI as _FastAPI, Response
import orjson
from anyio import to_thread
from fastapi import FastAPI as _FastAPI, Response, Request
from fastapi.responses import JSONResponse

from fastapi.middleware.cors import CORSMiddleware
Expand Down Expand Up @@ -46,7 +48,8 @@
UpdateCollection,
UpdateEmbedding,
)
from starlette.requests import Request

# from starlette.requests import Request

import logging

Expand All @@ -63,6 +66,11 @@
logger = logging.getLogger(__name__)


class ORJSONResponse(JSONResponse): # type: ignore
def render(self, content: Any) -> bytes:
return orjson.dumps(content) # type: ignore


def use_route_names_as_operation_ids(app: _FastAPI) -> None:
"""
Simplify operation IDs so that generated API clients have simpler function
Expand Down Expand Up @@ -126,7 +134,7 @@ class FastAPI(chromadb.server.Server):
def __init__(self, settings: Settings):
super().__init__(settings)
ProductTelemetryClient.SERVER_CONTEXT = ServerContext.FASTAPI
self._app = fastapi.FastAPI(debug=True)
self._app = fastapi.FastAPI(debug=True, default_response_class=ORJSONResponse)
self._system = System(settings)
self._api: ServerAPI = self._system.instance(ServerAPI)
self._opentelemetry_client = self._api.require(OpenTelemetryClient)
Expand Down Expand Up @@ -462,14 +470,14 @@ def update_collection(
),
),
)
def delete_collection(
async def delete_collection(
self,
collection_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
return self._api.delete_collection(
collection_name, tenant=tenant, database=database
return await to_thread.run_sync( # type: ignore
self._api.delete_collection, collection_name, tenant, database
)

@trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION)
Expand All @@ -481,15 +489,18 @@ def delete_collection(
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def add(self, collection_id: str, add: AddEmbedding) -> None:
async def add(self, request: Request, collection_id: str) -> None:
try:
result = self._api._add(
collection_id=_uuid(collection_id),
embeddings=add.embeddings, # type: ignore
metadatas=add.metadatas, # type: ignore
documents=add.documents, # type: ignore
uris=add.uris, # type: ignore
ids=add.ids,
raw_body = await request.body()
add = AddEmbedding.model_validate(orjson.loads(raw_body))
result = await to_thread.run_sync(
self._api._add,
add.ids,
_uuid(collection_id),
add.embeddings,
add.metadatas,
add.documents,
add.uris,
)
except InvalidDimensionException as e:
raise HTTPException(status_code=500, detail=str(e))
Expand All @@ -504,14 +515,17 @@ def add(self, collection_id: str, add: AddEmbedding) -> None:
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def update(self, collection_id: str, add: UpdateEmbedding) -> None:
self._api._update(
ids=add.ids,
collection_id=_uuid(collection_id),
embeddings=add.embeddings,
documents=add.documents, # type: ignore
uris=add.uris, # type: ignore
metadatas=add.metadatas, # type: ignore
async def update(self, request: Request, collection_id:str) -> None:
raw_body = await request.body()
add = UpdateEmbedding.model_validate(orjson.loads(raw_body))
await to_thread.run_sync(
self._api._update,
_uuid(collection_id),
add.ids,
add.embeddings,
add.metadatas,
add.documents,
add.uris,
)

@trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION)
Expand All @@ -523,14 +537,18 @@ def update(self, collection_id: str, add: UpdateEmbedding) -> None:
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
self._api._upsert(
collection_id=_uuid(collection_id),
ids=upsert.ids,
embeddings=upsert.embeddings, # type: ignore
documents=upsert.documents, # type: ignore
uris=upsert.uris, # type: ignore
metadatas=upsert.metadatas, # type: ignore
async def upsert(self, request: Request, collection_id:str) -> None:
raw_body = await request.body()
collection_id = request.path_params.get("collection_id")
upsert = AddEmbedding.model_validate(orjson.loads(raw_body))
await to_thread.run_sync(
self._api._upsert,
_uuid(collection_id),
upsert.ids,
upsert.embeddings,
upsert.metadatas,
upsert.documents,
upsert.uris,
)

@trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION)
Expand All @@ -542,16 +560,22 @@ def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
return self._api._get(
collection_id=_uuid(collection_id),
ids=get.ids,
where=get.where,
where_document=get.where_document,
sort=get.sort,
limit=get.limit,
offset=get.offset,
include=get.include,
async def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
return cast(
GetResult,
await to_thread.run_sync(
self._api._get,
_uuid(collection_id),
get.ids,
get.where,
get.sort,
get.limit,
get.offset,
None,
None,
get.where_document,
get.include,
),
)

@trace_method("FastAPI.delete", OpenTelemetryGranularity.OPERATION)
Expand All @@ -563,12 +587,16 @@ def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
return self._api._delete(
where=delete.where, # type: ignore
ids=delete.ids,
collection_id=_uuid(collection_id),
where_document=delete.where_document,
async def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
return cast(
List[UUID],
await to_thread.run_sync(
self._api._delete,
_uuid(collection_id),
delete.ids,
delete.where,
delete.where_document,
),
)

@trace_method("FastAPI.count", OpenTelemetryGranularity.OPERATION)
Expand All @@ -580,8 +608,10 @@ def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def count(self, collection_id: str) -> int:
return self._api._count(_uuid(collection_id))
async def count(self, collection_id: str) -> int:
return cast(
int, await to_thread.run_sync(self._api._count, _uuid(collection_id))
)

@trace_method("FastAPI.reset", OpenTelemetryGranularity.OPERATION)
@authz_context(
Expand All @@ -603,16 +633,20 @@ def reset(self) -> bool:
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def get_nearest_neighbors(
async def get_nearest_neighbors(
self, collection_id: str, query: QueryEmbedding
) -> QueryResult:
nnresult = self._api._query(
collection_id=_uuid(collection_id),
where=query.where, # type: ignore
where_document=query.where_document, # type: ignore
query_embeddings=query.query_embeddings,
n_results=query.n_results,
include=query.include,
nnresult = cast(
QueryResult,
await to_thread.run_sync(
self._api._query,
_uuid(collection_id),
query.query_embeddings,
query.n_results,
query.where,
query.where_document,
query.include,
),
)
return nnresult

Expand Down
38 changes: 27 additions & 11 deletions chromadb/telemetry/opentelemetry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from functools import wraps
from enum import Enum
from typing import Any, Callable, Dict, Optional, Sequence, Union
Expand Down Expand Up @@ -120,17 +121,32 @@ def trace_method(
"""A decorator that traces a method."""

def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapper(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
global tracer, granularity
if trace_granularity < granularity:
return f(*args, **kwargs)
if not tracer:
return f(*args, **kwargs)
with tracer.start_as_current_span(trace_name, attributes=attributes):
return f(*args, **kwargs)

return wrapper
if asyncio.iscoroutinefunction(f):

@wraps(f)
async def wrapper(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
global tracer, granularity
if trace_granularity < granularity:
return await f(*args, **kwargs)
if not tracer:
return await f(*args, **kwargs)
with tracer.start_as_current_span(trace_name, attributes=attributes):
return await f(*args, **kwargs)

return wrapper
else:

@wraps(f)
def wrapper(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
global tracer, granularity
if trace_granularity < granularity:
return f(*args, **kwargs)
if not tracer:
return f(*args, **kwargs)
with tracer.start_as_current_span(trace_name, attributes=attributes):
return f(*args, **kwargs)

return wrapper

return decorator

Expand Down
4 changes: 3 additions & 1 deletion chromadb/test/client/test_multiple_clients_concurrency.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
from time import sleep

from chromadb.api.client import AdminClient, Client
from chromadb.config import DEFAULT_TENANT

Expand Down Expand Up @@ -33,7 +35,7 @@ def run_target(n: int) -> None:

with ThreadPoolExecutor(max_workers=CLIENT_COUNT) as executor:
executor.map(run_target, range(CLIENT_COUNT))

executor.shutdown(wait=True)
# Create a final client, which will be used to verify the collections were created
client = Client(settings=client._system.settings)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
'tenacity>=8.2.3',
'PyYAML>=6.0.0',
'mmh3>=4.0.1',
'orjson>=3.9.12'
]

[tool.black]
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ tqdm>=4.65.0
typer>=0.9.0
typing_extensions>=4.5.0
uvicorn[standard]==0.18.3
orjson>=3.9.12

0 comments on commit 316db9b

Please sign in to comment.