Skip to content

Commit

Permalink
test: Update fastapi for raising error after close + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Mar 21, 2024
1 parent 4e4a50e commit 0e35813
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 72 deletions.
34 changes: 34 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class FastAPI(ServerAPI):
_settings: Settings
_max_batch_size: int = -1

@override
def start(self) -> None:
    """Start the client component.

    Delegates to the superclass implementation; presumably the base
    component sets the internal ``_running`` flag that
    ``_raise_for_running`` later checks — TODO confirm in the base class.
    """
    super().start()

@staticmethod
def _validate_host(host: str) -> None:
parsed = urlparse(host)
Expand Down Expand Up @@ -141,10 +145,15 @@ def __init__(self, system: System):
if self._settings.chroma_server_ssl_verify is not None:
self._session.verify = self._settings.chroma_server_ssl_verify

def _raise_for_running(self) -> None:
    """Guard helper: raise if this client has not been started or was closed.

    Raises:
        RuntimeError: when the component's ``_running`` flag is falsy.
    """
    if self._running:
        return
    raise RuntimeError("Component not running or already closed")

@trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
@override
def heartbeat(self) -> int:
"""Returns the current server time in nanoseconds to check if the server is alive"""
self._raise_for_running()
resp = self._session.get(self._api_url)
raise_chroma_error(resp)
return int(json.loads(resp.text)["nanosecond heartbeat"])
Expand All @@ -157,6 +166,7 @@ def create_database(
tenant: str = DEFAULT_TENANT,
) -> None:
"""Creates a database"""
self._raise_for_running()
resp = self._session.post(
self._api_url + "/databases",
data=json.dumps({"name": name}),
Expand All @@ -172,6 +182,7 @@ def get_database(
tenant: str = DEFAULT_TENANT,
) -> Database:
"""Returns a database"""
self._raise_for_running()
resp = self._session.get(
self._api_url + "/databases/" + name,
params={"tenant": tenant},
Expand All @@ -185,6 +196,7 @@ def get_database(
@trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
self._raise_for_running()
resp = self._session.post(
self._api_url + "/tenants",
data=json.dumps({"name": name}),
Expand All @@ -194,6 +206,7 @@ def create_tenant(self, name: str) -> None:
@trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
def get_tenant(self, name: str) -> Tenant:
self._raise_for_running()
resp = self._session.get(
self._api_url + "/tenants/" + name,
)
Expand All @@ -211,6 +224,7 @@ def list_collections(
database: str = DEFAULT_DATABASE,
) -> Sequence[Collection]:
"""Returns a list of all collections"""
self._raise_for_running()
resp = self._session.get(
self._api_url + "/collections",
params={
Expand All @@ -234,6 +248,7 @@ def count_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> int:
"""Returns a count of collections"""
self._raise_for_running()
resp = self._session.get(
self._api_url + "/count_collections",
params={"tenant": tenant, "database": database},
Expand All @@ -256,6 +271,7 @@ def create_collection(
database: str = DEFAULT_DATABASE,
) -> Collection:
"""Creates a collection"""
self._raise_for_running()
resp = self._session.post(
self._api_url + "/collections",
data=json.dumps(
Expand Down Expand Up @@ -292,6 +308,7 @@ def get_collection(
database: str = DEFAULT_DATABASE,
) -> Collection:
"""Returns a collection"""
self._raise_for_running()
if (name is None and id is None) or (name is not None and id is not None):
raise ValueError("Name or id must be specified, but not both")

Expand Down Expand Up @@ -327,6 +344,7 @@ def get_or_create_collection(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
self._raise_for_running()
return cast(
Collection,
self.create_collection(
Expand All @@ -349,6 +367,7 @@ def _modify(
new_metadata: Optional[CollectionMetadata] = None,
) -> None:
"""Updates a collection"""
self._raise_for_running()
resp = self._session.put(
self._api_url + "/collections/" + str(id),
data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}),
Expand All @@ -364,6 +383,7 @@ def delete_collection(
database: str = DEFAULT_DATABASE,
) -> None:
"""Deletes a collection"""
self._raise_for_running()
resp = self._session.delete(
self._api_url + "/collections/" + name,
params={"tenant": tenant, "database": database},
Expand All @@ -377,6 +397,7 @@ def _count(
collection_id: UUID,
) -> int:
"""Returns the number of embeddings in the database"""
self._raise_for_running()
resp = self._session.get(
self._api_url + "/collections/" + str(collection_id) + "/count"
)
Expand All @@ -390,6 +411,7 @@ def _peek(
collection_id: UUID,
n: int = 10,
) -> GetResult:
self._raise_for_running()
return cast(
GetResult,
self._get(
Expand All @@ -414,6 +436,7 @@ def _get(
where_document: Optional[WhereDocument] = {},
include: Include = ["metadatas", "documents"],
) -> GetResult:
self._raise_for_running()
if page and page_size:
offset = (page - 1) * page_size
limit = page_size
Expand Down Expand Up @@ -454,6 +477,7 @@ def _delete(
where_document: Optional[WhereDocument] = {},
) -> IDs:
"""Deletes embeddings from the database"""
self._raise_for_running()
resp = self._session.post(
self._api_url + "/collections/" + str(collection_id) + "/delete",
data=json.dumps(
Expand All @@ -479,6 +503,7 @@ def _submit_batch(
"""
Submits a batch of embeddings to the database
"""
self._raise_for_running()
resp = self._session.post(
self._api_url + url,
data=json.dumps(
Expand Down Expand Up @@ -508,6 +533,7 @@ def _add(
Adds a batch of embeddings to the database
- pass in column oriented data lists
"""
self._raise_for_running()
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
Expand All @@ -529,6 +555,7 @@ def _update(
Updates a batch of embeddings in the database
- pass in column oriented data lists
"""
self._raise_for_running()
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(
Expand All @@ -552,6 +579,7 @@ def _upsert(
Upserts a batch of embeddings in the database
- pass in column oriented data lists
"""
self._raise_for_running()
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(
Expand All @@ -572,6 +600,7 @@ def _query(
include: Include = ["metadatas", "documents", "distances"],
) -> QueryResult:
"""Gets the nearest neighbors of a single embedding"""
self._raise_for_running()
resp = self._session.post(
self._api_url + "/collections/" + str(collection_id) + "/query",
data=json.dumps(
Expand Down Expand Up @@ -602,6 +631,7 @@ def _query(
@override
def reset(self) -> bool:
    """Resets the database"""
    # Fail fast if the client was never started or has been closed.
    self._raise_for_running()
    response = self._session.post(self._api_url + "/reset")
    raise_chroma_error(response)
    # Server responds with a bare JSON boolean.
    result = json.loads(response.text)
    return cast(bool, result)
Expand All @@ -610,6 +640,7 @@ def reset(self) -> bool:
@override
def get_version(self) -> str:
    """Returns the version of the server"""
    # Guard against use after close().
    self._raise_for_running()
    response = self._session.get(self._api_url + "/version")
    raise_chroma_error(response)
    # The /version endpoint returns a JSON-encoded string.
    version = json.loads(response.text)
    return cast(str, version)
Expand All @@ -623,6 +654,7 @@ def get_settings(self) -> Settings:
@trace_method("FastAPI.max_batch_size", OpenTelemetryGranularity.OPERATION)
@override
def max_batch_size(self) -> int:
self._raise_for_running()
if self._max_batch_size == -1:
resp = self._session.get(self._api_url + "/pre-flight-checks")
raise_chroma_error(resp)
Expand All @@ -632,7 +664,9 @@ def max_batch_size(self) -> int:
@trace_method("FastAPI.close", OpenTelemetryGranularity.OPERATION)
@override
def close(self) -> None:
self._raise_for_running()
self._session.close()
self._system.stop()


def raise_chroma_error(resp: requests.Response) -> None:
Expand Down
Loading

0 comments on commit 0e35813

Please sign in to comment.