Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH]: Adding close() method to clients #1792

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
5348b7f
feat: Adding close() method to clients - WIP
tazarov Feb 23, 2024
6a4cd70
feat: Added tests
tazarov Mar 15, 2024
0a79168
fix: Added psutil to dev deps
tazarov Mar 15, 2024
1602fdb
fix: Fixing for integration tests
tazarov Mar 16, 2024
20ca135
fix: Adding debug to tests
tazarov Mar 16, 2024
883082a
fix: Checking for persistent path in test.
tazarov Mar 20, 2024
4e4a50e
test: Added test for double close and use after close
tazarov Mar 21, 2024
0e35813
test: Update fastapi for raising error after close + tests
tazarov Mar 21, 2024
4837dc3
fix: Fixed a recursion and a failing test for double close
tazarov Mar 21, 2024
4dd37aa
fix: Added api component start in tests
tazarov Mar 21, 2024
8ef16a0
fix: Added open file filtering to prevent failing test due to multipl…
tazarov Mar 21, 2024
c208914
fix: Simplified assertions
tazarov Mar 21, 2024
53c0b96
fix: Added reset
tazarov Mar 21, 2024
a26fb03
fix: Added a fix for Windows paths escape
tazarov Mar 21, 2024
7dd6683
fix: Added debug for windows failures
tazarov Mar 21, 2024
1114eaa
fix: Added debug for windows failures
tazarov Mar 21, 2024
fcc06ab
fix: Create persistent client for each test (no fixture)
tazarov Mar 22, 2024
2bc23ef
fix: Simplifying regex in test
tazarov Mar 22, 2024
f87b92c
fix: Changing the regexp match
tazarov Mar 26, 2024
cf1b0b8
chore: Adding random collection name
tazarov Mar 26, 2024
cb31780
chore: Adding random collection name
tazarov Mar 26, 2024
d353f91
test: Removing backslash escapes
tazarov Mar 27, 2024
f9264ca
test: Regex on windows are the bane of my existence!!!
tazarov Mar 27, 2024
b0b8b6d
test: Regex, why!
tazarov Mar 27, 2024
628587a
test: Regex, why!
tazarov Mar 27, 2024
78dd080
test: Weird persistent dir under windows
tazarov Mar 27, 2024
b57fb44
test: Me - 1, Windows Paths - 0
tazarov Mar 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,11 @@ def max_batch_size(self) -> int:
to submit_embeddings."""
pass

@abstractmethod
def close(self) -> None:
"""Close the client and release any resources."""
pass


class ClientAPI(BaseAPI, ABC):
tenant: str
Expand Down Expand Up @@ -473,6 +478,11 @@ def clear_system_cache() -> None:
This should only be used for testing purposes."""
pass

@abstractmethod
def close(self) -> None:
"""Close the client and release any resources."""
pass


class AdminAPI(ABC):
@abstractmethod
Expand Down
4 changes: 4 additions & 0 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,10 @@ def _validate_tenant_database(self, tenant: str, database: str) -> None:
f"Could not connect to database {database} for tenant {tenant}. Are you sure it exists?"
)

@override
def close(self) -> None:
self._server.close()

# endregion


Expand Down
39 changes: 39 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class FastAPI(ServerAPI):
_settings: Settings
_max_batch_size: int = -1

@override
def start(self) -> None:
super().start()

@staticmethod
def _validate_host(host: str) -> None:
parsed = urlparse(host)
Expand Down Expand Up @@ -141,10 +145,15 @@ def __init__(self, system: System):
if self._settings.chroma_server_ssl_verify is not None:
self._session.verify = self._settings.chroma_server_ssl_verify

def _raise_for_running(self) -> None:
if not self._running:
raise RuntimeError("Component not running or already closed")

@trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
@override
def heartbeat(self) -> int:
"""Returns the current server time in nanoseconds to check if the server is alive"""
self._raise_for_running()
resp = self._session.get(self._api_url)
raise_chroma_error(resp)
return int(json.loads(resp.text)["nanosecond heartbeat"])
Expand All @@ -157,6 +166,7 @@ def create_database(
tenant: str = DEFAULT_TENANT,
) -> None:
"""Creates a database"""
self._raise_for_running()
resp = self._session.post(
self._api_url + "/databases",
data=json.dumps({"name": name}),
Expand All @@ -172,6 +182,7 @@ def get_database(
tenant: str = DEFAULT_TENANT,
) -> Database:
"""Returns a database"""
self._raise_for_running()
resp = self._session.get(
self._api_url + "/databases/" + name,
params={"tenant": tenant},
Expand All @@ -185,6 +196,7 @@ def get_database(
@trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
self._raise_for_running()
resp = self._session.post(
self._api_url + "/tenants",
data=json.dumps({"name": name}),
Expand All @@ -194,6 +206,7 @@ def create_tenant(self, name: str) -> None:
@trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
def get_tenant(self, name: str) -> Tenant:
self._raise_for_running()
resp = self._session.get(
self._api_url + "/tenants/" + name,
)
Expand All @@ -211,6 +224,7 @@ def list_collections(
database: str = DEFAULT_DATABASE,
) -> Sequence[Collection]:
"""Returns a list of all collections"""
self._raise_for_running()
resp = self._session.get(
self._api_url + "/collections",
params={
Expand All @@ -234,6 +248,7 @@ def count_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> int:
"""Returns a count of collections"""
self._raise_for_running()
resp = self._session.get(
self._api_url + "/count_collections",
params={"tenant": tenant, "database": database},
Expand All @@ -256,6 +271,7 @@ def create_collection(
database: str = DEFAULT_DATABASE,
) -> Collection:
"""Creates a collection"""
self._raise_for_running()
resp = self._session.post(
self._api_url + "/collections",
data=json.dumps(
Expand Down Expand Up @@ -292,6 +308,7 @@ def get_collection(
database: str = DEFAULT_DATABASE,
) -> Collection:
"""Returns a collection"""
self._raise_for_running()
if (name is None and id is None) or (name is not None and id is not None):
raise ValueError("Name or id must be specified, but not both")

Expand Down Expand Up @@ -327,6 +344,7 @@ def get_or_create_collection(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
self._raise_for_running()
return cast(
Collection,
self.create_collection(
Expand All @@ -349,6 +367,7 @@ def _modify(
new_metadata: Optional[CollectionMetadata] = None,
) -> None:
"""Updates a collection"""
self._raise_for_running()
resp = self._session.put(
self._api_url + "/collections/" + str(id),
data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}),
Expand All @@ -364,6 +383,7 @@ def delete_collection(
database: str = DEFAULT_DATABASE,
) -> None:
"""Deletes a collection"""
self._raise_for_running()
resp = self._session.delete(
self._api_url + "/collections/" + name,
params={"tenant": tenant, "database": database},
Expand All @@ -377,6 +397,7 @@ def _count(
collection_id: UUID,
) -> int:
"""Returns the number of embeddings in the database"""
self._raise_for_running()
resp = self._session.get(
self._api_url + "/collections/" + str(collection_id) + "/count"
)
Expand All @@ -390,6 +411,7 @@ def _peek(
collection_id: UUID,
n: int = 10,
) -> GetResult:
self._raise_for_running()
return cast(
GetResult,
self._get(
Expand All @@ -414,6 +436,7 @@ def _get(
where_document: Optional[WhereDocument] = {},
include: Include = ["metadatas", "documents"],
) -> GetResult:
self._raise_for_running()
if page and page_size:
offset = (page - 1) * page_size
limit = page_size
Expand Down Expand Up @@ -454,6 +477,7 @@ def _delete(
where_document: Optional[WhereDocument] = {},
) -> IDs:
"""Deletes embeddings from the database"""
self._raise_for_running()
resp = self._session.post(
self._api_url + "/collections/" + str(collection_id) + "/delete",
data=json.dumps(
Expand All @@ -479,6 +503,7 @@ def _submit_batch(
"""
Submits a batch of embeddings to the database
"""
self._raise_for_running()
resp = self._session.post(
self._api_url + url,
data=json.dumps(
Expand Down Expand Up @@ -508,6 +533,7 @@ def _add(
Adds a batch of embeddings to the database
- pass in column oriented data lists
"""
self._raise_for_running()
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
Expand All @@ -529,6 +555,7 @@ def _update(
Updates a batch of embeddings in the database
- pass in column oriented data lists
"""
self._raise_for_running()
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(
Expand All @@ -552,6 +579,7 @@ def _upsert(
Upserts a batch of embeddings in the database
- pass in column oriented data lists
"""
self._raise_for_running()
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(
Expand All @@ -572,6 +600,7 @@ def _query(
include: Include = ["metadatas", "documents", "distances"],
) -> QueryResult:
"""Gets the nearest neighbors of a single embedding"""
self._raise_for_running()
resp = self._session.post(
self._api_url + "/collections/" + str(collection_id) + "/query",
data=json.dumps(
Expand Down Expand Up @@ -602,6 +631,7 @@ def _query(
@override
def reset(self) -> bool:
"""Resets the database"""
self._raise_for_running()
resp = self._session.post(self._api_url + "/reset")
raise_chroma_error(resp)
return cast(bool, json.loads(resp.text))
Expand All @@ -610,6 +640,7 @@ def reset(self) -> bool:
@override
def get_version(self) -> str:
"""Returns the version of the server"""
self._raise_for_running()
resp = self._session.get(self._api_url + "/version")
raise_chroma_error(resp)
return cast(str, json.loads(resp.text))
Expand All @@ -623,12 +654,20 @@ def get_settings(self) -> Settings:
@trace_method("FastAPI.max_batch_size", OpenTelemetryGranularity.OPERATION)
@override
def max_batch_size(self) -> int:
self._raise_for_running()
if self._max_batch_size == -1:
resp = self._session.get(self._api_url + "/pre-flight-checks")
raise_chroma_error(resp)
self._max_batch_size = cast(int, json.loads(resp.text)["max_batch_size"])
return self._max_batch_size

@trace_method("FastAPI.close", OpenTelemetryGranularity.OPERATION)
@override
def close(self) -> None:
self._raise_for_running()
self._session.close()
self._system.stop()


def raise_chroma_error(resp: requests.Response) -> None:
"""Raises an error if the response is not ok, using a ChromaError if possible"""
Expand Down
Loading
Loading