Skip to content

Commit

Permalink
chore: Making chromadb package optional (#32)
Browse files Browse the repository at this point in the history
Closes #30
  • Loading branch information
tazarov authored Sep 26, 2024
1 parent 4bad725 commit 9678cf3
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ jobs:
- name: Run tests
run: |
set -e
poetry update
poetry update --with dev
poetry run pytest
7 changes: 7 additions & 0 deletions chroma_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Import guard for the optional chromadb dependency.
#
# chromadb is declared as an optional (extras) dependency in pyproject.toml,
# but every tool in this package operates on a Chroma persistence directory,
# so a working chromadb installation is a hard prerequisite at runtime.
# Fail fast at package import time with an actionable message instead of
# letting individual modules fail later with a bare ImportError.
try:
    import chromadb  # noqa: F401  # imported only to verify availability, never used here
except ImportError:
    # NOTE(review): raising ImportError (chained with `from`) would be the more
    # conventional exception type for a missing dependency — confirm no callers
    # rely on catching ValueError before changing this.
    raise ValueError(
        "The chromadb is not installed. This package (chromadbx) requires that Chroma is installed to work. "
        "Please install it with `pip install chromadb`"
    )
21 changes: 18 additions & 3 deletions chroma_ops/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def info(
results = cursor.fetchall()
collection["segments"] = []
if len(results) > 0:
collection["dimension"] = c._model.dimension if hasattr(c, "_model") else results[0][8]
collection["dimension"] = (
c._model.dimension if hasattr(c, "_model") else results[0][8]
)
for row in results:
segment = {
"id": row[0],
Expand Down Expand Up @@ -103,7 +105,20 @@ def info(
hnsw_metadata = PersistentData.load_from_file(
segment["segment_metadata_path"]
)
segment["hnsw_metadata_max_seq_id"] = hnsw_metadata.max_seq_id
# support chroma 0.5.7+
if hasattr(hnsw_metadata, "max_seq_id"):
segment[
"hnsw_metadata_max_seq_id"
] = hnsw_metadata.max_seq_id
else:
max_seq_id_query_hnsw_057 = (
"SELECT seq_id FROM max_seq_id WHERE segment_id = ?"
)
cursor.execute(max_seq_id_query_hnsw_057, [row[0]])
results = cursor.fetchall()
segment["hnsw_metadata_max_seq_id"] = (
decode_seq_id(results[0][0]) if len(results) > 0 else 0
)
segment["hnsw_metadata_total_elements"] = len(
hnsw_metadata.id_to_label
)
Expand All @@ -117,7 +132,7 @@ def info(
index.load_index(
os.path.join(segment["path"]),
is_persistent_index=True,
max_elements=hnsw_metadata.max_seq_id,
max_elements=segment["hnsw_metadata_max_seq_id"],
)
hnsw_ids = index.get_ids_list()
segment["hnsw_raw_total_elements"] = len(hnsw_ids)
Expand Down
6 changes: 3 additions & 3 deletions chroma_ops/scripts/drop_fts.sql
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
BEGIN TRANSACTION;
DROP TABLE IF EXISTS embedding_fulltext_search;
DROP TABLE IF EXISTS embedding_fulltext_search;
DROP TABLE IF EXISTS embedding_fulltext_search_config;
DROP TABLE IF EXISTS embedding_fulltext_search_content;
DROP TABLE IF EXISTS embedding_fulltext_search_data;
DROP TABLE IF EXISTS embedding_fulltext_search_docsize;
DROP TABLE IF EXISTS embedding_fulltext_search_idx;
CREATE TABLE embedding_fulltext (id INTEGER PRIMARY KEY);
DELETE FROM migrations WHERE dir='metadb' AND version='3' AND filename='00003-full-text-tokenize.sqlite.sql';
CREATE VIRTUAL TABLE embedding_fulltext_search USING fts5(string_value, tokenize='trigram');
INSERT INTO embedding_fulltext_search (rowid, string_value) SELECT rowid, string_value FROM embedding_metadata;
COMMIT TRANSACTION;
14 changes: 11 additions & 3 deletions chroma_ops/wal_clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_hnsw_index_ids,
get_dir_size,
PersistentData,
decode_seq_id,
)


Expand Down Expand Up @@ -56,10 +57,18 @@ def clean_wal(
metadata_pickle = os.path.join(persist_dir, segment_id, "index_metadata.pickle")
if os.path.exists(metadata_pickle):
metadata = PersistentData.load_from_file(metadata_pickle)
if hasattr(metadata, "max_seq_id"):
max_seq_id = metadata.max_seq_id
else:
max_seq_id_query_hnsw_057 = (
"SELECT seq_id FROM max_seq_id WHERE segment_id = ?"
)
cursor.execute(max_seq_id_query_hnsw_057, [row[0]])
results = cursor.fetchall()
max_seq_id = decode_seq_id(results[0][0]) if len(results) > 0 else 0
wal_cleanup_queries.append(
f"DELETE FROM embeddings_queue WHERE seq_id < {metadata.max_seq_id} AND topic='{topic}';"
f"DELETE FROM embeddings_queue WHERE seq_id < {max_seq_id} AND topic='{topic}';"
)
print("topic", topic)
else:
hnsw_space = cursor.execute(
"select str_value from collection_metadata where collection_id=? and key='hnsw:space'",
Expand All @@ -72,7 +81,6 @@ def clean_wal(
f"{os.path.join(persist_dir, segment_id)}", hnsw_space, row[3]
)
batch_size = 100
print(list_of_ids)
for batch in range(0, len(list_of_ids), batch_size):
wal_cleanup_queries.append(
f"DELETE FROM embeddings_queue WHERE seq_id IN ({','.join([str(i) for i in list_of_ids[batch:batch + batch_size]])});"
Expand Down
29 changes: 17 additions & 12 deletions chroma_ops/wal_commit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,28 @@ def commit_wal(
client = chromadb.PersistentClient(
path=persist_dir
) # TODO we inadvertently migrate the target DB to whatever version of Chroma is installed
vector_segments = [
s
for s in client._server._sysdb.get_segments()
if s["scope"] == SegmentScope.VECTOR
]
all_collections = client.list_collections()
vector_segments = []
for col in all_collections:
if (
col.name in skip_collection_names
if skip_collection_names
else [] or col.count() == 0
):
typer.echo(f"Ignoring skipped collection {col.name}", file=sys.stderr)
continue
vector_segments.extend(
[
s
for s in client._server._sysdb.get_segments(collection=col.id)
if s["scope"] == SegmentScope.VECTOR
]
)

for s in vector_segments:
col = client._server._get_collection(
s["collection"]
) # load the collection and apply WAL
if skip_collection_names and col["name"] in skip_collection_names:
typer.echo(f"Ignoring skipped collection {col['name']}", file=sys.stderr)
continue
if client._server._count(col.id) == 0:
typer.echo(f"Skipping empty collection {col['name']}", file=sys.stderr)
continue

client._server._manager.hint_use_collection(
s["collection"], Operation.ADD
) # Add hint to load the index into memory
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ packages = [{ include = "chroma_ops" }]

[tool.poetry.dependencies]
python = ">=3.9"
chromadb = ">=0.4.0, <0.6.0"
chromadb = { version = ">=0.4.0,<0.6.0", optional = true }
typer = {extras = ["all"], version = "^0.9.0"}


Expand All @@ -30,6 +30,7 @@ pytest = "^7.4.3"
black = "23.3.0"
pre-commit = "^3.6.0"
hypothesis = "^6.92.0"
chromadb = { version = ">=0.4.0,<0.6.0" }

[tool.mypy]
python_version = "3.9"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rebuild_fts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from chroma_ops.rebuild_fts import rebuild_fts


def test_basic_clean() -> None:
def test_rebuild_fts() -> None:
records_to_add = 1
with tempfile.TemporaryDirectory() as temp_dir:
client = chromadb.PersistentClient(path=temp_dir)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_wal_clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ def test_basic_clean(records_to_add: int) -> None:
conn = sqlite3.connect(f"file:{sql_file}?mode=ro", uri=True)
cursor = conn.cursor()
count = cursor.execute("SELECT count(*) FROM embeddings_queue")
assert count.fetchone()[0] == records_to_add
if tuple(int(part) for part in chromadb.__version__.split(".")) > (0, 5, 5):
assert count.fetchone()[0] == 1
else:
assert count.fetchone()[0] == records_to_add
clean_wal(temp_dir)
count = cursor.execute("SELECT count(*) FROM embeddings_queue")
assert count.fetchone()[0] < records_to_add
Expand Down
39 changes: 33 additions & 6 deletions tests/test_wal_commit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ def test_basic_commit(records_to_add: int) -> None:

vector_segments = [
client._server._manager.get_segment(s["collection"], VectorReader)
for s in client._server._sysdb.get_segments()
for s in client._server._sysdb.get_segments(collection=col.id)
if s["scope"] == SegmentScope.VECTOR
]
for segment in vector_segments:
batch_size = segment._batch_size
_sync_threshold = segment._sync_threshold
if records_to_add % batch_size == 0:
assert (
len(segment._index.get_ids_list()) == records_to_add
Expand All @@ -46,10 +47,23 @@ def test_basic_commit(records_to_add: int) -> None:
cursor = conn.cursor()

count = cursor.execute("SELECT count(*) FROM embeddings_queue")
assert count.fetchone()[0] == records_to_add
if tuple(int(part) for part in chromadb.__version__.split(".")) > (0, 5, 5):
if records_to_add % _sync_threshold == 0:
assert count.fetchone()[0] == 1
else:
assert count.fetchone()[0] <= records_to_add
else:
assert count.fetchone()[0] == records_to_add

commit_wal(temp_dir)
count = cursor.execute("SELECT count(*) FROM embeddings_queue")
assert count.fetchone()[0] == records_to_add
if tuple(int(part) for part in chromadb.__version__.split(".")) > (0, 5, 5):
if records_to_add % _sync_threshold == 0:
assert count.fetchone()[0] == 1
else:
assert count.fetchone()[0] <= records_to_add
else:
assert count.fetchone()[0] == records_to_add
for segment in vector_segments:
assert (
len(segment._index.get_ids_list()) == records_to_add
Expand All @@ -73,11 +87,12 @@ def test_commit_skip_collection(records_to_add: int) -> None:

vector_segments = [
client._server._manager.get_segment(s["collection"], VectorReader)
for s in client._server._sysdb.get_segments()
for s in client._server._sysdb.get_segments(collection=col.id)
if s["scope"] == SegmentScope.VECTOR
]
for segment in vector_segments:
batch_size = segment._batch_size
_sync_threshold = segment._sync_threshold
if records_to_add % batch_size == 0:
assert (
len(segment._index.get_ids_list()) == records_to_add
Expand All @@ -89,10 +104,22 @@ def test_commit_skip_collection(records_to_add: int) -> None:
cursor = conn.cursor()

count = cursor.execute("SELECT count(*) FROM embeddings_queue")
assert count.fetchone()[0] == records_to_add
if tuple(int(part) for part in chromadb.__version__.split(".")) > (0, 5, 5):
if records_to_add % _sync_threshold == 0:
assert count.fetchone()[0] == 1
else:
assert count.fetchone()[0] <= records_to_add
else:
assert count.fetchone()[0] == records_to_add
commit_wal(temp_dir, skip_collection_names=["test"])
count = cursor.execute("SELECT count(*) FROM embeddings_queue")
assert count.fetchone()[0] == records_to_add
if tuple(int(part) for part in chromadb.__version__.split(".")) > (0, 5, 5):
if records_to_add % _sync_threshold == 0:
assert count.fetchone()[0] == 1
else:
assert count.fetchone()[0] <= records_to_add
else:
assert count.fetchone()[0] == records_to_add
for segment in vector_segments:
if records_to_add % batch_size == 0:
assert (
Expand Down
16 changes: 15 additions & 1 deletion tests/test_wal_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import tempfile
import uuid

from chromadb.segment import VectorReader
from chromadb.types import SegmentScope
from hypothesis import given, settings
import hypothesis.strategies as st

Expand Down Expand Up @@ -29,6 +31,18 @@ def test_basic_export(records_to_add: int) -> None:
ids, documents, embeddings = zip(*ids_documents)
col.add(ids=list(ids), documents=list(documents), embeddings=list(embeddings))
with tempfile.NamedTemporaryFile() as temp_file:
vector_segments = [
client._server._manager.get_segment(s["collection"], VectorReader)
for s in client._server._sysdb.get_segments(collection=col.id)
if s["scope"] == SegmentScope.VECTOR
]
_sync_threshold = vector_segments[0]._sync_threshold
export_wal(temp_dir, temp_file.name)
assert os.path.exists(temp_file.name)
assert count_lines(temp_file.name) == records_to_add
if tuple(int(part) for part in chromadb.__version__.split(".")) > (0, 5, 5):
if records_to_add % _sync_threshold == 0:
assert count_lines(temp_file.name) == 1
else:
assert count_lines(temp_file.name) <= records_to_add
else:
assert count_lines(temp_file.name) == records_to_add

0 comments on commit 9678cf3

Please sign in to comment.