From 37923c52279057e8d14464eeb87d87695010c69a Mon Sep 17 00:00:00 2001 From: William Black <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 9 Jan 2025 23:38:32 -0800 Subject: [PATCH 1/3] Upgrade ChromaDB to >=0.6.0 and fix broken tests (#530) Co-authored-by: Philip Meier --- pyproject.toml | 2 +- ragna/source_storages/_chroma.py | 28 +++++++---- ragna/source_storages/_lancedb.py | 2 +- ragna/source_storages/_vector_database.py | 2 +- tests/source_storages/test_source_storages.py | 47 +++++++++++++++++-- 5 files changed, 66 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 46dc9b80..f849db91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ Repository = "https://github.com/Quansight/ragna" [project.optional-dependencies] # to update the array below, run scripts/update_optional_dependencies.py all = [ - "chromadb<=0.5.11,>=0.4.13", + "chromadb>=0.6.0", "httpx_sse", "ijson", "lancedb>=0.2", diff --git a/ragna/source_storages/_chroma.py b/ragna/source_storages/_chroma.py index 70f8a90a..ba9b3e35 100644 --- a/ragna/source_storages/_chroma.py +++ b/ragna/source_storages/_chroma.py @@ -19,7 +19,20 @@ class Chroma(VectorDatabaseSourceStorage): !!! info "Required packages" - - `chromadb>=0.4.13` + - `chromadb>=0.6.0` + + !!! warning + + The `NE` and `NOT_IN` metadata filter operators behave differently in Chroma + than the other builtin source storages. With most other source storages, + given a key-value pair `(key, value)`, the operators `NE` and `NOT_IN` return + only the sources with a metadata key `key` and a value not equal to or + not in, respectively, `value`. To contrast, the `NE` and `NOT_IN` metadata filter + operators in `ChromaDB` return everything described in the preceding sentence, + together with all sources that do not have the metadata key `key`. + + For more information, see the notes for `v0.5.12` in the + [`ChromaDB` migration guide](https://docs.trychroma.com/production/administration/migration). """ # Note that this class has no extra requirements, since the chromadb package is @@ -39,7 +52,7 @@ def __init__(self) -> None: ) def list_corpuses(self) -> list[str]: - return [collection.name for collection in self._client.list_collections()] + return [str(c) for c in self._client.list_collections()] def _get_collection( self, corpus_name: str, *, create: bool = False @@ -49,15 +62,14 @@ def _get_collection( corpus_name, embedding_function=self._embedding_function ) - collections = list(self._client.list_collections()) - if not collections: + corpuses = self.list_corpuses() + if not corpuses: raise_no_corpuses_available(self) try: - return next( - collection - for collection in collections - if collection.name == corpus_name + return self._client.get_collection( + name=next(name for name in corpuses if name == corpus_name), + embedding_function=self._embedding_function, ) except StopIteration: raise_non_existing_corpus(self, corpus_name) diff --git a/ragna/source_storages/_lancedb.py b/ragna/source_storages/_lancedb.py index e59c1df6..4ada5187 100644 --- a/ragna/source_storages/_lancedb.py +++ b/ragna/source_storages/_lancedb.py @@ -27,7 +27,7 @@ class LanceDB(VectorDatabaseSourceStorage): !!! info "Required packages" - - `chromadb>=0.4.13` + - `chromadb>=0.6.0` - `lancedb>=0.2` - `pyarrow` """ diff --git a/ragna/source_storages/_vector_database.py b/ragna/source_storages/_vector_database.py index d6e6a8b9..2cb66398 100644 --- a/ragna/source_storages/_vector_database.py +++ b/ragna/source_storages/_vector_database.py @@ -46,7 +46,7 @@ def requirements(cls) -> list[Requirement]: # to manage and mostly not even used by the vector DB. Chroma provides a # wrapper around a compiled embedding function that has only minimal # requirements. We use this as base for all of our Vector DBs. - PackageRequirement("chromadb<=0.5.11,>=0.4.13"), + PackageRequirement("chromadb>=0.6.0"), PackageRequirement("tiktoken"), ] diff --git a/tests/source_storages/test_source_storages.py b/tests/source_storages/test_source_storages.py index ea183b6b..fe382eb0 100644 --- a/tests/source_storages/test_source_storages.py +++ b/tests/source_storages/test_source_storages.py @@ -7,6 +7,7 @@ from ragna.core import ( LocalDocument, MetadataFilter, + MetadataOperator, PlainTextDocumentHandler, RagnaException, ) @@ -69,7 +70,9 @@ MetadataFilter.and_( [ MetadataFilter.eq("key", "other_value"), - MetadataFilter.ne("other_key", "other_value"), + MetadataFilter.in_( + "other_key", ["some_value", "other_value"] + ), ] ), ] @@ -104,7 +107,13 @@ @pytest.mark.parametrize( "source_storage_cls", set(SOURCE_STORAGES) - {RagnaDemoSourceStorage} ) -def test_smoke(tmp_local_root, source_storage_cls, metadata_filter, expected_idcs): +def test_smoke( + tmp_local_root, + source_storage_cls, + metadata_filter, + expected_idcs, + chroma_override=False, +): document_root = tmp_local_root / "documents" document_root.mkdir() documents = [] @@ -135,13 +144,43 @@ def test_smoke(tmp_local_root, source_storage_cls, metadata_filter, expected_idc num_tokens=num_tokens, ) - actual_idcs = sorted(map(int, (source.document_name for source in sources))) - assert actual_idcs == expected_idcs + if ( + not ( + source_storage_cls is Chroma + and isinstance(metadata_filter, MetadataFilter) + and metadata_filter.operator + in { + MetadataOperator.NE, + MetadataOperator.NOT_IN, + } + ) + or chroma_override + ): + actual_idcs = sorted(map(int, (source.document_name for source in sources))) + assert actual_idcs == expected_idcs # Should be able to call .store() multiple times source_storage.store(corpus_name, documents) +@pytest.mark.parametrize( + ("metadata_filter", "expected_idcs"), + [ + pytest.param(MetadataFilter.ne("key", "value"), [2, 3, 4, 5, 6], id="ne"), + pytest.param( + MetadataFilter.not_in("key", ["foo", "bar"]), [0, 1, 2, 3, 4], id="not_in" + ), + ], +) +def test_chroma_ne_nin_non_existing_keys( + tmp_local_root, metadata_filter, expected_idcs +): + # See https://github.com/Quansight/ragna/issues/523 for details + test_smoke( + tmp_local_root, Chroma, metadata_filter, expected_idcs, chroma_override=True + ) + + @pytest.mark.parametrize("source_storage_cls", [Chroma, LanceDB]) def test_corpus_names(tmp_local_root, source_storage_cls): document_root = tmp_local_root / "documents" From 4385620c864a02d64db4c13a3e34e85719ba54a2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 09:14:24 +0100 Subject: [PATCH 2/3] Update requirements-docker.lock (#535) Co-authored-by: smokestacklightnin --- requirements-docker.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-docker.lock b/requirements-docker.lock index e277b8c1..2a4074cc 100644 --- a/requirements-docker.lock +++ b/requirements-docker.lock @@ -41,7 +41,7 @@ charset-normalizer==3.4.0 # via requests chroma-hnswlib==0.7.6 # via chromadb -chromadb==0.5.11 +chromadb==0.6.2 # via Ragna (pyproject.toml) click==8.1.7 # via From 37d0998d0a80d86e2f6bdb44d8a65d06462d07e1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 10 Jan 2025 09:14:37 +0100 Subject: [PATCH 3/3] Fix Qdrant source storage (#534) --- ragna/source_storages/_qdrant.py | 43 ++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/ragna/source_storages/_qdrant.py b/ragna/source_storages/_qdrant.py index fe494fbb..68a179f2 100644 --- a/ragna/source_storages/_qdrant.py +++ b/ragna/source_storages/_qdrant.py @@ -54,19 +54,16 @@ def __init__(self) -> None: from qdrant_client import QdrantClient - url = os.getenv("QDRANT_URL") - api_key = os.getenv("QDRANT_API_KEY") - path = ragna.local_root() / "qdrant" - - # Cannot pass both url and path - self._client = ( - QdrantClient(url=url, api_key=api_key) if url else QdrantClient(path=path) - ) + if (url := os.environ.get("QDRANT_URL")) is not None: + kwargs = dict(url=url, api_key=os.environ.get("QDRANT_API_KEY")) + else: + kwargs = dict(path=str(ragna.local_root() / "qdrant")) + self._client = QdrantClient(**kwargs) # type: ignore[arg-type] def list_corpuses(self) -> list[str]: return [c.name for c in self._client.get_collections().collections] - def _ensure_table(self, corpus_name: str, *, create: bool = False): + def _ensure_table(self, corpus_name: str, *, create: bool = False) -> None: table_names = self.list_corpuses() no_corpuses = not table_names non_existing_corpus = corpus_name not in table_names @@ -91,6 +88,7 @@ def list_metadata( if corpus_name is None: corpus_names = self.list_corpuses() else: + self._ensure_table(corpus_name) corpus_names = [corpus_name] metadata = {} @@ -101,7 +99,7 @@ def list_metadata( corpus_metadata = defaultdict(set) for point in points: - for key, value in point.payload.items(): + for key, value in cast(dict[str, Any], point.payload).items(): if any( [ (key.startswith("__") and key.endswith("__")), @@ -142,7 +140,10 @@ def store( points.append( models.PointStruct( id=str(uuid.uuid4()), - vector=self._embedding_function([chunk.text])[0], + vector=cast( + list[float], + self._embedding_function([chunk.text])[0].tolist(), + ), payload={ "document_id": str(document.id), "document_name": document.name, @@ -158,7 +159,9 @@ def store( self._client.upsert(collection_name=corpus_name, points=points) - def _build_condition(self, operator, key, value): + def _build_condition( + self, operator: MetadataOperator, key: str, value: Any + ) -> models.FieldCondition: from qdrant_client import models # See https://qdrant.tech/documentation/concepts/filtering/#range @@ -184,7 +187,7 @@ def _build_condition(self, operator, key, value): def _translate_metadata_filter( self, metadata_filter: MetadataFilter - ) -> models.Filter: + ) -> models.Filter | models.FieldCondition: from qdrant_client import models if metadata_filter.operator is MetadataOperator.RAW: @@ -247,12 +250,14 @@ def retrieve( return self._take_sources_up_to_max_tokens( ( Source( - id=point.id, - document_id=point.payload["document_id"], - document_name=point.payload["document_name"], - location=point.payload["__page_numbers__"], - content=point.payload[self.DOC_CONTENT_KEY], - num_tokens=point.payload["__num_tokens__"], + id=cast(str, point.id), + document_id=(payload := cast(dict[str, Any], point.payload))[ + "document_id" + ], + document_name=payload["document_name"], + location=payload["__page_numbers__"], + content=payload[self.DOC_CONTENT_KEY], + num_tokens=payload["__num_tokens__"], ) for point in points ),