diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index d01528b840c..f3b3c81d175 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -51,6 +51,10 @@ class FastAPI(ServerAPI): _settings: Settings _max_batch_size: int = -1 + @override + def start(self) -> None: + super().start() + @staticmethod def _validate_host(host: str) -> None: parsed = urlparse(host) @@ -141,10 +145,15 @@ def __init__(self, system: System): if self._settings.chroma_server_ssl_verify is not None: self._session.verify = self._settings.chroma_server_ssl_verify + def _raise_for_running(self) -> None: + if not self._running: + raise RuntimeError("Component not running or already closed") + @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION) @override def heartbeat(self) -> int: """Returns the current server time in nanoseconds to check if the server is alive""" + self._raise_for_running() resp = self._session.get(self._api_url) raise_chroma_error(resp) return int(json.loads(resp.text)["nanosecond heartbeat"]) @@ -157,6 +166,7 @@ def create_database( tenant: str = DEFAULT_TENANT, ) -> None: """Creates a database""" + self._raise_for_running() resp = self._session.post( self._api_url + "/databases", data=json.dumps({"name": name}), @@ -172,6 +182,7 @@ def get_database( tenant: str = DEFAULT_TENANT, ) -> Database: """Returns a database""" + self._raise_for_running() resp = self._session.get( self._api_url + "/databases/" + name, params={"tenant": tenant}, @@ -185,6 +196,7 @@ def get_database( @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION) @override def create_tenant(self, name: str) -> None: + self._raise_for_running() resp = self._session.post( self._api_url + "/tenants", data=json.dumps({"name": name}), @@ -194,6 +206,7 @@ def create_tenant(self, name: str) -> None: @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION) @override def get_tenant(self, name: str) -> Tenant: + self._raise_for_running() resp = 
self._session.get( self._api_url + "/tenants/" + name, ) @@ -211,6 +224,7 @@ def list_collections( database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: """Returns a list of all collections""" + self._raise_for_running() resp = self._session.get( self._api_url + "/collections", params={ @@ -234,6 +248,7 @@ def count_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE ) -> int: """Returns a count of collections""" + self._raise_for_running() resp = self._session.get( self._api_url + "/count_collections", params={"tenant": tenant, "database": database}, @@ -256,6 +271,7 @@ def create_collection( database: str = DEFAULT_DATABASE, ) -> Collection: """Creates a collection""" + self._raise_for_running() resp = self._session.post( self._api_url + "/collections", data=json.dumps( @@ -292,6 +308,7 @@ def get_collection( database: str = DEFAULT_DATABASE, ) -> Collection: """Returns a collection""" + self._raise_for_running() if (name is None and id is None) or (name is not None and id is not None): raise ValueError("Name or id must be specified, but not both") @@ -327,6 +344,7 @@ def get_or_create_collection( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: + self._raise_for_running() return cast( Collection, self.create_collection( @@ -349,6 +367,7 @@ def _modify( new_metadata: Optional[CollectionMetadata] = None, ) -> None: """Updates a collection""" + self._raise_for_running() resp = self._session.put( self._api_url + "/collections/" + str(id), data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}), @@ -364,6 +383,7 @@ def delete_collection( database: str = DEFAULT_DATABASE, ) -> None: """Deletes a collection""" + self._raise_for_running() resp = self._session.delete( self._api_url + "/collections/" + name, params={"tenant": tenant, "database": database}, @@ -377,6 +397,7 @@ def _count( collection_id: UUID, ) -> int: """Returns the number of embeddings in the database""" + 
self._raise_for_running() resp = self._session.get( self._api_url + "/collections/" + str(collection_id) + "/count" ) @@ -390,6 +411,7 @@ def _peek( collection_id: UUID, n: int = 10, ) -> GetResult: + self._raise_for_running() return cast( GetResult, self._get( @@ -414,6 +436,7 @@ def _get( where_document: Optional[WhereDocument] = {}, include: Include = ["metadatas", "documents"], ) -> GetResult: + self._raise_for_running() if page and page_size: offset = (page - 1) * page_size limit = page_size @@ -454,6 +477,7 @@ def _delete( where_document: Optional[WhereDocument] = {}, ) -> IDs: """Deletes embeddings from the database""" + self._raise_for_running() resp = self._session.post( self._api_url + "/collections/" + str(collection_id) + "/delete", data=json.dumps( @@ -479,6 +503,7 @@ def _submit_batch( """ Submits a batch of embeddings to the database """ + self._raise_for_running() resp = self._session.post( self._api_url + url, data=json.dumps( @@ -508,6 +533,7 @@ def _add( Adds a batch of embeddings to the database - pass in column oriented data lists """ + self._raise_for_running() batch = (ids, embeddings, metadatas, documents, uris) validate_batch(batch, {"max_batch_size": self.max_batch_size}) resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add") @@ -529,6 +555,7 @@ def _update( Updates a batch of embeddings in the database - pass in column oriented data lists """ + self._raise_for_running() batch = (ids, embeddings, metadatas, documents, uris) validate_batch(batch, {"max_batch_size": self.max_batch_size}) resp = self._submit_batch( @@ -552,6 +579,7 @@ def _upsert( Upserts a batch of embeddings in the database - pass in column oriented data lists """ + self._raise_for_running() batch = (ids, embeddings, metadatas, documents, uris) validate_batch(batch, {"max_batch_size": self.max_batch_size}) resp = self._submit_batch( @@ -572,6 +600,7 @@ def _query( include: Include = ["metadatas", "documents", "distances"], ) -> QueryResult: """Gets 
the nearest neighbors of a single embedding""" + self._raise_for_running() resp = self._session.post( self._api_url + "/collections/" + str(collection_id) + "/query", data=json.dumps( @@ -602,6 +631,7 @@ def _query( @override def reset(self) -> bool: """Resets the database""" + self._raise_for_running() resp = self._session.post(self._api_url + "/reset") raise_chroma_error(resp) return cast(bool, json.loads(resp.text)) @@ -610,6 +640,7 @@ def reset(self) -> bool: @override def get_version(self) -> str: """Returns the version of the server""" + self._raise_for_running() resp = self._session.get(self._api_url + "/version") raise_chroma_error(resp) return cast(str, json.loads(resp.text)) @@ -623,6 +654,7 @@ def get_settings(self) -> Settings: @trace_method("FastAPI.max_batch_size", OpenTelemetryGranularity.OPERATION) @override def max_batch_size(self) -> int: + self._raise_for_running() if self._max_batch_size == -1: resp = self._session.get(self._api_url + "/pre-flight-checks") raise_chroma_error(resp) @@ -632,7 +664,9 @@ def max_batch_size(self) -> int: @trace_method("FastAPI.close", OpenTelemetryGranularity.OPERATION) @override def close(self) -> None: + self._raise_for_running() self._session.close() + self._system.stop() def raise_chroma_error(resp: requests.Response) -> None: diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index f70a9378ed9..23d23c5e17e 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -109,21 +109,22 @@ def __init__(self, system: System): self._producer = self.require(Producer) self._collection_cache = {} + def _raise_for_running(self) -> None: + self._raise_for_running() + @override def start(self) -> None: super().start() @override def heartbeat(self) -> int: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() return int(time.time_ns()) @trace_method("SegmentAPI.create_database", OpenTelemetryGranularity.OPERATION) @override def 
@trace_method("SegmentAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
    """Create a tenant called ``name`` through the system database.

    Args:
        name: Tenant name; must be at least 3 characters long.

    Raises:
        ValueError: if ``name`` has fewer than 3 characters.
        RuntimeError: presumably raised by ``_raise_for_running`` when the
            component is not running — confirm against that helper.
    """
    self._raise_for_running()

    min_length = 3
    name_too_short = len(name) < min_length
    if name_too_short:
        raise ValueError("Tenant name must be at least 3 characters long")

    self._sysdb.create_tenant(name=name)
This is @@ -173,8 +174,7 @@ def create_collection( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() if metadata is not None: validate_metadata(metadata) @@ -236,8 +236,7 @@ def get_or_create_collection( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() return self.create_collection( # type: ignore name=name, metadata=metadata, @@ -264,8 +263,7 @@ def get_collection( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() if id is None and name is None or (id is not None and name is not None): raise ValueError("Name or id must be specified, but not both") existing = self._sysdb.get_collections( @@ -295,8 +293,7 @@ def list_collections( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() collections = [] db_collections = self._sysdb.get_collections( limit=limit, offset=offset, tenant=tenant, database=database @@ -321,8 +318,7 @@ def count_collections( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> int: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() collection_count = len( self._sysdb.get_collections(tenant=tenant, database=database) ) @@ -337,8 +333,7 @@ def _modify( new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, ) -> None: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() if new_name: # backwards 
compatibility in naming requirements (for now) check_index_name(new_name) @@ -363,8 +358,7 @@ def delete_collection( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> None: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() existing = self._sysdb.get_collections( name=name, tenant=tenant, database=database ) @@ -392,8 +386,7 @@ def _add( documents: Optional[Documents] = None, uris: Optional[URIs] = None, ) -> bool: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() self._quota.static_check(metadatas, documents, embeddings, str(collection_id)) coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.ADD) @@ -437,8 +430,7 @@ def _update( documents: Optional[Documents] = None, uris: Optional[URIs] = None, ) -> bool: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() self._quota.static_check(metadatas, documents, embeddings, str(collection_id)) coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) @@ -484,8 +476,7 @@ def _upsert( documents: Optional[Documents] = None, uris: Optional[URIs] = None, ) -> bool: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() self._quota.static_check(metadatas, documents, embeddings, str(collection_id)) coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) @@ -525,8 +516,7 @@ def _get( where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], ) -> GetResult: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() add_attributes_to_current_span( { "collection_id": str(collection_id), @@ -621,8 
+611,7 @@ def _delete( where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, ) -> IDs: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() add_attributes_to_current_span( { "collection_id": str(collection_id), @@ -688,8 +677,7 @@ def _delete( @trace_method("SegmentAPI._count", OpenTelemetryGranularity.OPERATION) @override def _count(self, collection_id: UUID) -> int: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() add_attributes_to_current_span({"collection_id": str(collection_id)}) metadata_segment = self._manager.get_segment(collection_id, MetadataReader) return metadata_segment.count() @@ -706,8 +694,7 @@ def _query( where_document: WhereDocument = {}, include: Include = ["documents", "metadatas", "distances"], ) -> QueryResult: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() add_attributes_to_current_span( { "collection_id": str(collection_id), @@ -824,14 +811,12 @@ def get_version(self) -> str: @override def reset_state(self) -> None: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() self._collection_cache = {} @override def reset(self) -> bool: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() self._system.reset_state() return True @@ -842,8 +827,7 @@ def get_settings(self) -> Settings: @property @override def max_batch_size(self) -> int: - if not self._running: - raise RuntimeError("Component not running or already closed") + self._raise_for_running() return self._producer.max_batch_size # TODO: This could potentially cause race conditions in a distributed version of the @@ -893,6 +877,7 @@ def _get_collection(self, collection_id: UUID) -> t.Collection: @trace_method("SegmentAPI.close", 
OpenTelemetryGranularity.ALL) @override def close(self) -> None: + self._raise_for_running() self._system.stop() diff --git a/chromadb/test/test_client.py b/chromadb/test/test_client.py index b0b4e66ad4a..6365111c08e 100644 --- a/chromadb/test/test_client.py +++ b/chromadb/test/test_client.py @@ -296,6 +296,10 @@ def test_http_client_double_close(http_api: ClientAPI) -> None: assert len(_pool_manager.pools._container) > 0 http_api.close() assert len(_pool_manager.pools._container) == 0 + with pytest.raises( + RuntimeError, match="Component not running or already closed" + ): + http_api.close() def test_http_client_use_after_close(http_api: ClientAPI) -> None: @@ -312,31 +316,29 @@ def test_http_client_use_after_close(http_api: ClientAPI) -> None: assert len(_pool_manager.pools._container) > 0 http_api.close() assert len(_pool_manager.pools._container) == 0 - http_api.heartbeat() - assert len(_pool_manager.pools._container) > 0 - # with pytest.raises(RuntimeError,match="Component not running"): - # http_api.heartbeat() - # with pytest.raises(RuntimeError,match="Component not running"): - # col.add(ids=["1"], documents=["test"]) - # with pytest.raises(RuntimeError,match="Component not running"): - # col.delete(ids=["1"]) - # with pytest.raises(RuntimeError,match="Component not running"): - # col.update(ids=["1"], documents=["test1231"]) - # with pytest.raises(RuntimeError,match="Component not running"): - # col.upsert(ids=["1"], documents=["test1231"]) - # with pytest.raises(RuntimeError,match="Component not running"): - # col.count() - # with pytest.raises(RuntimeError,match="Component not running"): - # http_api.create_collection("test1") - # with pytest.raises(RuntimeError,match="Component not running"): - # http_api.get_collection("test") - # with pytest.raises(RuntimeError,match="Component not running"): - # http_api.get_or_create_collection("test") - # with pytest.raises(RuntimeError,match="Component not running"): - # http_api.list_collections() - # with 
pytest.raises(RuntimeError,match="Component not running"): - # http_api.delete_collection("test") - # with pytest.raises(RuntimeError,match="Component not running"): - # http_api.count_collections() - # with pytest.raises(RuntimeError,match="Component not running"): - # http_api.heartbeat() + with pytest.raises(RuntimeError, match="Component not running"): + http_api.heartbeat() + with pytest.raises(RuntimeError, match="Component not running"): + col.add(ids=["1"], documents=["test"]) + with pytest.raises(RuntimeError, match="Component not running"): + col.delete(ids=["1"]) + with pytest.raises(RuntimeError, match="Component not running"): + col.update(ids=["1"], documents=["test1231"]) + with pytest.raises(RuntimeError, match="Component not running"): + col.upsert(ids=["1"], documents=["test1231"]) + with pytest.raises(RuntimeError, match="Component not running"): + col.count() + with pytest.raises(RuntimeError, match="Component not running"): + http_api.create_collection("test1") + with pytest.raises(RuntimeError, match="Component not running"): + http_api.get_collection("test") + with pytest.raises(RuntimeError, match="Component not running"): + http_api.get_or_create_collection("test") + with pytest.raises(RuntimeError, match="Component not running"): + http_api.list_collections() + with pytest.raises(RuntimeError, match="Component not running"): + http_api.delete_collection("test") + with pytest.raises(RuntimeError, match="Component not running"): + http_api.count_collections() + with pytest.raises(RuntimeError, match="Component not running"): + http_api.heartbeat()