Skip to content

Commit

Permalink
feat: Chroma python client orjson serialization
Browse files Browse the repository at this point in the history
Initial implementation.
  • Loading branch information
tazarov committed Feb 6, 2024
1 parent 64b63bb commit 92df427
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
32 changes: 16 additions & 16 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import json
import orjson as json
import logging
from typing import Optional, cast, Tuple
from typing import Sequence
Expand Down Expand Up @@ -145,7 +145,7 @@ def heartbeat(self) -> int:
"""Returns the current server time in nanoseconds to check if the server is alive"""
resp = self._session.get(self._api_url)
raise_chroma_error(resp)
return int(resp.json()["nanosecond heartbeat"])
return int(json.loads(resp.text)["nanosecond heartbeat"])

@trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
@override
Expand Down Expand Up @@ -175,7 +175,7 @@ def get_database(
params={"tenant": tenant},
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Database(
id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
)
Expand All @@ -196,7 +196,7 @@ def get_tenant(self, name: str) -> Tenant:
self._api_url + "/tenants/" + name,
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Tenant(name=resp_json["name"])

@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
Expand All @@ -219,7 +219,7 @@ def list_collections(
},
)
raise_chroma_error(resp)
json_collections = resp.json()
json_collections = json.loads(resp.text)
collections = []
for json_collection in json_collections:
collections.append(Collection(self, **json_collection))
Expand All @@ -237,7 +237,7 @@ def count_collections(
params={"tenant": tenant, "database": database},
)
raise_chroma_error(resp)
return cast(int, resp.json())
return cast(int, json.loads(resp.text))

@trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
@override
Expand Down Expand Up @@ -266,7 +266,7 @@ def create_collection(
params={"tenant": tenant, "database": database},
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Collection(
client=self,
id=resp_json["id"],
Expand Down Expand Up @@ -300,7 +300,7 @@ def get_collection(
self._api_url + "/collections/" + name if name else str(id), params=_params
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Collection(
client=self,
name=resp_json["name"],
Expand Down Expand Up @@ -379,7 +379,7 @@ def _count(
self._api_url + "/collections/" + str(collection_id) + "/count"
)
raise_chroma_error(resp)
return cast(int, resp.json())
return cast(int, json.loads(resp.text))

@trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
@override
Expand Down Expand Up @@ -432,7 +432,7 @@ def _get(
)

raise_chroma_error(resp)
body = resp.json()
body = json.loads(resp.text)
return GetResult(
ids=body["ids"],
embeddings=body.get("embeddings", None),
Expand Down Expand Up @@ -460,7 +460,7 @@ def _delete(
)

raise_chroma_error(resp)
return cast(IDs, resp.json())
return cast(IDs, json.loads(resp.text))

@trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
def _submit_batch(
Expand Down Expand Up @@ -584,7 +584,7 @@ def _query(
)

raise_chroma_error(resp)
body = resp.json()
body = json.loads(resp.text)

return QueryResult(
ids=body["ids"],
Expand All @@ -602,15 +602,15 @@ def reset(self) -> bool:
"""Resets the database"""
resp = self._session.post(self._api_url + "/reset")
raise_chroma_error(resp)
return cast(bool, resp.json())
return cast(bool, json.loads(resp.text))

@trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
@override
def get_version(self) -> str:
"""Returns the version of the server"""
resp = self._session.get(self._api_url + "/version")
raise_chroma_error(resp)
return cast(str, resp.json())
return cast(str, json.loads(resp.text))

@override
def get_settings(self) -> Settings:
Expand All @@ -624,7 +624,7 @@ def max_batch_size(self) -> int:
if self._max_batch_size == -1:
resp = self._session.get(self._api_url + "/pre-flight-checks")
raise_chroma_error(resp)
self._max_batch_size = cast(int, resp.json()["max_batch_size"])
self._max_batch_size = cast(int, json.loads(resp.text)["max_batch_size"])
return self._max_batch_size


Expand All @@ -635,7 +635,7 @@ def raise_chroma_error(resp: requests.Response) -> None:

chroma_error = None
try:
body = resp.json()
body = json.loads(resp.text)
if "error" in body:
if body["error"] in errors.error_types:
chroma_error = errors.error_types[body["error"]](body["message"])
Expand Down
1 change: 1 addition & 0 deletions clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
'typing_extensions >= 4.5.0',
'tenacity>=8.2.3',
'PyYAML>=6.0.0',
'orjson>=3.9.12',
]

[tool.black]
Expand Down
1 change: 1 addition & 0 deletions clients/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ PyYAML>=6.0.0
requests >= 2.28
tenacity>=8.2.3
typing_extensions >= 4.5.0
orjson>=3.9.12

0 comments on commit 92df427

Please sign in to comment.