Skip to content

Commit

Permalink
chore: Simplified tests and removed testcontainers
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Oct 4, 2024
1 parent ff3d64f commit 89a1522
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 50 deletions.
83 changes: 34 additions & 49 deletions chromadb/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import pytest
import tempfile

from testcontainers.chroma import ChromaContainer


@pytest.fixture
def ephemeral_api() -> Generator[ClientAPI, None, None]:
Expand Down Expand Up @@ -202,27 +200,6 @@ def test_persistent_client_use_after_close() -> None:
persistent_api.heartbeat()


@pytest.fixture(params=["sync_client", "async_client"])
def http_client_with_tc(
request: pytest.FixtureRequest,
) -> Generator[ClientAPI, None, None]:
with ChromaContainer() as chroma:
config = chroma.get_config()
if request.param == "sync_client":
http_api = chromadb.HttpClient(host=config["host"], port=config["port"])
yield http_api
http_api.clear_system_cache()
else:

async def init_client() -> AsyncClientAPI:
http_api = await chromadb.AsyncHttpClient(
host=config["host"], port=config["port"]
)
return http_api

yield asyncio.get_event_loop().run_until_complete(init_client())


def get_connection_count(api_client: ClientAPI) -> int:
if isinstance(api_client, AsyncClientAPI):
connections = 0
Expand All @@ -235,7 +212,11 @@ def get_connection_count(api_client: ClientAPI) -> int:
return len(_pool._connections)


def test_http_client_close(http_client_with_tc: ClientAPI) -> None:
def test_http_client_close(client: ClientAPI) -> None:
if client.get_settings().chroma_api_impl == "chromadb.api.segment.SegmentAPI":
pytest.skip(
"Skipping test that closes the persistent client in integration test"
)
if os.environ.get("CHROMA_INTEGRATION_TEST_ONLY") == "1":
pytest.skip(
"Skipping test that closes the persistent client in integration test"
Expand All @@ -244,23 +225,27 @@ def test_http_client_close(http_client_with_tc: ClientAPI) -> None:
async def run_in_async(c: AsyncClientAPI):
col = await c.create_collection("test" + uuid.uuid4().hex)
await col.add(ids=["1"], documents=["test"])
assert get_connection_count(http_client_with_tc) > 0
assert get_connection_count(client) > 0
await c.close()
assert get_connection_count(http_client_with_tc) == 0
assert get_connection_count(client) == 0

if isinstance(http_client_with_tc, AsyncClientAPI):
if isinstance(client, AsyncClientAPI):
asyncio.get_event_loop().run_until_complete(
run_in_async(cast(AsyncClientAPI, http_client_with_tc))
run_in_async(cast(AsyncClientAPI, client))
)
else:
col = http_client_with_tc.create_collection("test" + uuid.uuid4().hex)
col = client.create_collection("test" + uuid.uuid4().hex)
col.add(ids=["1"], documents=["test"])
assert get_connection_count(http_client_with_tc) > 0
http_client_with_tc.close()
assert get_connection_count(http_client_with_tc) == 0
assert get_connection_count(client) > 0
client.close()
assert get_connection_count(client) == 0


def test_http_client_use_after_close(http_client_with_tc: ClientAPI) -> None:
def test_http_client_use_after_close(client: ClientAPI) -> None:
if client.get_settings().chroma_api_impl == "chromadb.api.segment.SegmentAPI":
pytest.skip(
"Skipping test that closes the persistent client in integration test"
)
if os.environ.get("CHROMA_INTEGRATION_TEST_ONLY") == "1":
pytest.skip(
"Skipping test that closes the persistent client in integration test"
Expand All @@ -269,9 +254,9 @@ def test_http_client_use_after_close(http_client_with_tc: ClientAPI) -> None:
async def run_in_async(c: AsyncClientAPI):
col = await c.create_collection("test" + uuid.uuid4().hex)
await col.add(ids=["1"], documents=["test"])
assert get_connection_count(http_client_with_tc) > 0
assert get_connection_count(client) > 0
await c.close()
assert get_connection_count(http_client_with_tc) == 0
assert get_connection_count(client) == 0
with pytest.raises(RuntimeError, match="Component not running"):
await c.heartbeat()
with pytest.raises(RuntimeError, match="Component not running"):
Expand All @@ -293,18 +278,18 @@ async def run_in_async(c: AsyncClientAPI):
with pytest.raises(RuntimeError, match="Component not running"):
await c.list_collections()

if isinstance(http_client_with_tc, AsyncClientAPI):
if isinstance(client, AsyncClientAPI):
asyncio.get_event_loop().run_until_complete(
run_in_async(cast(AsyncClientAPI, http_client_with_tc))
run_in_async(cast(AsyncClientAPI, client))
)
else:
col = http_client_with_tc.create_collection("test" + uuid.uuid4().hex)
col = client.create_collection("test" + uuid.uuid4().hex)
col.add(ids=["1"], documents=["test"])
assert get_connection_count(http_client_with_tc) > 0
http_client_with_tc.close()
assert get_connection_count(http_client_with_tc) == 0
assert get_connection_count(client) > 0
client.close()
assert get_connection_count(client) == 0
with pytest.raises(RuntimeError, match="Component not running"):
http_client_with_tc.heartbeat()
client.heartbeat()
with pytest.raises(RuntimeError, match="Component not running"):
col.add(ids=["1"], documents=["test"])
with pytest.raises(RuntimeError, match="Component not running"):
Expand All @@ -316,16 +301,16 @@ async def run_in_async(c: AsyncClientAPI):
with pytest.raises(RuntimeError, match="Component not running"):
col.count()
with pytest.raises(RuntimeError, match="Component not running"):
http_client_with_tc.create_collection("test1")
client.create_collection("test1")
with pytest.raises(RuntimeError, match="Component not running"):
http_client_with_tc.get_collection("test")
client.get_collection("test")
with pytest.raises(RuntimeError, match="Component not running"):
http_client_with_tc.get_or_create_collection("test")
client.get_or_create_collection("test")
with pytest.raises(RuntimeError, match="Component not running"):
http_client_with_tc.list_collections()
client.list_collections()
with pytest.raises(RuntimeError, match="Component not running"):
http_client_with_tc.delete_collection("test")
client.delete_collection("test")
with pytest.raises(RuntimeError, match="Component not running"):
http_client_with_tc.count_collections()
client.count_collections()
with pytest.raises(RuntimeError, match="Component not running"):
http_client_with_tc.heartbeat()
client.heartbeat()
1 change: 0 additions & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ pytest
pytest-asyncio
pytest-xdist
setuptools_scm
testcontainers[chroma]>=4.7.0
types-protobuf

0 comments on commit 89a1522

Please sign in to comment.