diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index d3d1a8a4e7e..c8e89a84c9a 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -1,4 +1,4 @@ -import json +import orjson as json import logging from typing import Optional, cast, Tuple from typing import Sequence @@ -145,7 +145,7 @@ def heartbeat(self) -> int: """Returns the current server time in nanoseconds to check if the server is alive""" resp = self._session.get(self._api_url) raise_chroma_error(resp) - return int(resp.json()["nanosecond heartbeat"]) + return int(json.loads(resp.text)["nanosecond heartbeat"]) @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION) @override @@ -175,7 +175,7 @@ def get_database( params={"tenant": tenant}, ) raise_chroma_error(resp) - resp_json = resp.json() + resp_json = json.loads(resp.text) return Database( id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"] ) @@ -196,7 +196,7 @@ def get_tenant(self, name: str) -> Tenant: self._api_url + "/tenants/" + name, ) raise_chroma_error(resp) - resp_json = resp.json() + resp_json = json.loads(resp.text) return Tenant(name=resp_json["name"]) @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION) @@ -219,7 +219,7 @@ def list_collections( }, ) raise_chroma_error(resp) - json_collections = resp.json() + json_collections = json.loads(resp.text) collections = [] for json_collection in json_collections: collections.append(Collection(self, **json_collection)) @@ -237,7 +237,7 @@ def count_collections( params={"tenant": tenant, "database": database}, ) raise_chroma_error(resp) - return cast(int, resp.json()) + return cast(int, json.loads(resp.text)) @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION) @override @@ -266,7 +266,7 @@ def create_collection( params={"tenant": tenant, "database": database}, ) raise_chroma_error(resp) - resp_json = resp.json() + resp_json = json.loads(resp.text) return Collection( client=self, id=resp_json["id"], @@ -300,7 +300,7 @@ def get_collection( self._api_url + "/collections/" + name if name else str(id), params=_params ) raise_chroma_error(resp) - resp_json = resp.json() + resp_json = json.loads(resp.text) return Collection( client=self, name=resp_json["name"], @@ -379,7 +379,7 @@ def _count( self._api_url + "/collections/" + str(collection_id) + "/count" ) raise_chroma_error(resp) - return cast(int, resp.json()) + return cast(int, json.loads(resp.text)) @trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION) @override @@ -432,7 +432,7 @@ def _get( ) raise_chroma_error(resp) - body = resp.json() + body = json.loads(resp.text) return GetResult( ids=body["ids"], embeddings=body.get("embeddings", None), @@ -460,7 +460,7 @@ def _delete( ) raise_chroma_error(resp) - return cast(IDs, resp.json()) + return cast(IDs, json.loads(resp.text)) @trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL) def _submit_batch( @@ -584,7 +584,7 @@ def _query( ) raise_chroma_error(resp) - body = resp.json() + body = json.loads(resp.text) return QueryResult( ids=body["ids"], @@ -602,7 +602,7 @@ def reset(self) -> bool: """Resets the database""" resp = self._session.post(self._api_url + "/reset") raise_chroma_error(resp) - return cast(bool, resp.json()) + return cast(bool, json.loads(resp.text)) @trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION) @override @@ -610,7 +610,7 @@ def get_version(self) -> str: """Returns the version of the server""" resp = self._session.get(self._api_url + "/version") raise_chroma_error(resp) - return cast(str, resp.json()) + return cast(str, json.loads(resp.text)) @override def get_settings(self) -> Settings: @@ -624,7 +624,7 @@ def max_batch_size(self) -> int: if self._max_batch_size == -1: resp = self._session.get(self._api_url + "/pre-flight-checks") raise_chroma_error(resp) - self._max_batch_size = cast(int, resp.json()["max_batch_size"]) + self._max_batch_size = cast(int, json.loads(resp.text)["max_batch_size"]) return self._max_batch_size @@ -635,7 +635,7 @@ def raise_chroma_error(resp: requests.Response) -> None: chroma_error = None try: - body = resp.json() + body = json.loads(resp.text) if "error" in body: if body["error"] in errors.error_types: chroma_error = errors.error_types[body["error"]](body["message"]) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index b62c002d095..edd0c00d7cf 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ 'typing_extensions >= 4.5.0', 'tenacity>=8.2.3', 'PyYAML>=6.0.0', + 'orjson>=3.9.12', ] [tool.black] diff --git a/clients/python/requirements.txt b/clients/python/requirements.txt index 1242bf7d7e0..b977b03f064 100644 --- a/clients/python/requirements.txt +++ b/clients/python/requirements.txt @@ -9,3 +9,4 @@ PyYAML>=6.0.0 requests >= 2.28 tenacity>=8.2.3 typing_extensions >= 4.5.0 +orjson>=3.9.12