Skip to content

Commit

Permalink
Merge branch 'main' into remove-api-wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Jan 10, 2025
2 parents 5aa90b0 + 37d0998 commit 258151d
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 35 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Repository = "https://github.com/Quansight/ragna"
[project.optional-dependencies]
# to update the array below, run scripts/update_optional_dependencies.py
all = [
"chromadb<=0.5.11,>=0.4.13",
"chromadb>=0.6.0",
"httpx_sse",
"ijson",
"lancedb>=0.2",
Expand Down
28 changes: 20 additions & 8 deletions ragna/source_storages/_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,20 @@ class Chroma(VectorDatabaseSourceStorage):
!!! info "Required packages"
- `chromadb>=0.4.13`
- `chromadb>=0.6.0`
!!! warning
The `NE` and `NOT_IN` metadata filter operators behave differently in Chroma
than the other builtin source storages. With most other source storages,
given a key-value pair `(key, value)`, the operators `NE` and `NOT_IN` return
only the sources with a metadata key `key` and a value not equal to or
not in, respectively, `value`. To contrast, the `NE` and `NOT_IN` metadata filter
operators in `ChromaDB` return everything described in the preceding sentence,
together with all sources that do not have the metadata key `key`.
For more information, see the notes for `v0.5.12` in the
[`ChromaDB` migration guide](https://docs.trychroma.com/production/administration/migration).
"""

# Note that this class has no extra requirements, since the chromadb package is
Expand All @@ -39,7 +52,7 @@ def __init__(self) -> None:
)

def list_corpuses(self) -> list[str]:
return [collection.name for collection in self._client.list_collections()]
return [str(c) for c in self._client.list_collections()]

def _get_collection(
self, corpus_name: str, *, create: bool = False
Expand All @@ -49,15 +62,14 @@ def _get_collection(
corpus_name, embedding_function=self._embedding_function
)

collections = list(self._client.list_collections())
if not collections:
corpuses = self.list_corpuses()
if not corpuses:
raise_no_corpuses_available(self)

try:
return next(
collection
for collection in collections
if collection.name == corpus_name
return self._client.get_collection(
name=next(name for name in corpuses if name == corpus_name),
embedding_function=self._embedding_function,
)
except StopIteration:
raise_non_existing_corpus(self, corpus_name)
Expand Down
2 changes: 1 addition & 1 deletion ragna/source_storages/_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class LanceDB(VectorDatabaseSourceStorage):
!!! info "Required packages"
- `chromadb>=0.4.13`
- `chromadb>=0.6.0`
- `lancedb>=0.2`
- `pyarrow`
"""
Expand Down
43 changes: 24 additions & 19 deletions ragna/source_storages/_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,16 @@ def __init__(self) -> None:

from qdrant_client import QdrantClient

url = os.getenv("QDRANT_URL")
api_key = os.getenv("QDRANT_API_KEY")
path = ragna.local_root() / "qdrant"

# Cannot pass both url and path
self._client = (
QdrantClient(url=url, api_key=api_key) if url else QdrantClient(path=path)
)
if (url := os.environ.get("QDRANT_URL")) is not None:
kwargs = dict(url=url, api_key=os.environ.get("QDRANT_API_KEY"))
else:
kwargs = dict(path=str(ragna.local_root() / "qdrant"))
self._client = QdrantClient(**kwargs) # type: ignore[arg-type]

def list_corpuses(self) -> list[str]:
    """Return the names of all corpuses stored on this Qdrant client.

    Each Qdrant collection corresponds to one corpus; only the names are
    returned, not the collection objects themselves.
    """
    collections = self._client.get_collections().collections
    return [collection.name for collection in collections]

def _ensure_table(self, corpus_name: str, *, create: bool = False):
def _ensure_table(self, corpus_name: str, *, create: bool = False) -> None:
table_names = self.list_corpuses()
no_corpuses = not table_names
non_existing_corpus = corpus_name not in table_names
Expand All @@ -91,6 +88,7 @@ def list_metadata(
if corpus_name is None:
corpus_names = self.list_corpuses()
else:
self._ensure_table(corpus_name)
corpus_names = [corpus_name]

metadata = {}
Expand All @@ -101,7 +99,7 @@ def list_metadata(

corpus_metadata = defaultdict(set)
for point in points:
for key, value in point.payload.items():
for key, value in cast(dict[str, Any], point.payload).items():
if any(
[
(key.startswith("__") and key.endswith("__")),
Expand Down Expand Up @@ -142,7 +140,10 @@ def store(
points.append(
models.PointStruct(
id=str(uuid.uuid4()),
vector=self._embedding_function([chunk.text])[0],
vector=cast(
list[float],
self._embedding_function([chunk.text])[0].tolist(),
),
payload={
"document_id": str(document.id),
"document_name": document.name,
Expand All @@ -158,7 +159,9 @@ def store(

self._client.upsert(collection_name=corpus_name, points=points)

def _build_condition(self, operator, key, value):
def _build_condition(
self, operator: MetadataOperator, key: str, value: Any
) -> models.FieldCondition:
from qdrant_client import models

# See https://qdrant.tech/documentation/concepts/filtering/#range
Expand All @@ -184,7 +187,7 @@ def _build_condition(self, operator, key, value):

def _translate_metadata_filter(
self, metadata_filter: MetadataFilter
) -> models.Filter:
) -> models.Filter | models.FieldCondition:
from qdrant_client import models

if metadata_filter.operator is MetadataOperator.RAW:
Expand Down Expand Up @@ -247,12 +250,14 @@ def retrieve(
return self._take_sources_up_to_max_tokens(
(
Source(
id=point.id,
document_id=point.payload["document_id"],
document_name=point.payload["document_name"],
location=point.payload["__page_numbers__"],
content=point.payload[self.DOC_CONTENT_KEY],
num_tokens=point.payload["__num_tokens__"],
id=cast(str, point.id),
document_id=(payload := cast(dict[str, Any], point.payload))[
"document_id"
],
document_name=payload["document_name"],
location=payload["__page_numbers__"],
content=payload[self.DOC_CONTENT_KEY],
num_tokens=payload["__num_tokens__"],
)
for point in points
),
Expand Down
2 changes: 1 addition & 1 deletion ragna/source_storages/_vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def requirements(cls) -> list[Requirement]:
# to manage and mostly not even used by the vector DB. Chroma provides a
# wrapper around a compiled embedding function that has only minimal
# requirements. We use this as base for all of our Vector DBs.
PackageRequirement("chromadb<=0.5.11,>=0.4.13"),
PackageRequirement("chromadb>=0.6.0"),
PackageRequirement("tiktoken"),
]

Expand Down
2 changes: 1 addition & 1 deletion requirements-docker.lock
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ charset-normalizer==3.4.0
# via requests
chroma-hnswlib==0.7.6
# via chromadb
chromadb==0.5.11
chromadb==0.6.2
# via Ragna (pyproject.toml)
click==8.1.7
# via
Expand Down
47 changes: 43 additions & 4 deletions tests/source_storages/test_source_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ragna.core import (
LocalDocument,
MetadataFilter,
MetadataOperator,
PlainTextDocumentHandler,
RagnaException,
)
Expand Down Expand Up @@ -69,7 +70,9 @@
MetadataFilter.and_(
[
MetadataFilter.eq("key", "other_value"),
MetadataFilter.ne("other_key", "other_value"),
MetadataFilter.in_(
"other_key", ["some_value", "other_value"]
),
]
),
]
Expand Down Expand Up @@ -104,7 +107,13 @@
@pytest.mark.parametrize(
"source_storage_cls", set(SOURCE_STORAGES) - {RagnaDemoSourceStorage}
)
def test_smoke(tmp_local_root, source_storage_cls, metadata_filter, expected_idcs):
def test_smoke(
tmp_local_root,
source_storage_cls,
metadata_filter,
expected_idcs,
chroma_override=False,
):
document_root = tmp_local_root / "documents"
document_root.mkdir()
documents = []
Expand Down Expand Up @@ -135,13 +144,43 @@ def test_smoke(tmp_local_root, source_storage_cls, metadata_filter, expected_idc
num_tokens=num_tokens,
)

actual_idcs = sorted(map(int, (source.document_name for source in sources)))
assert actual_idcs == expected_idcs
if (
not (
source_storage_cls is Chroma
and isinstance(metadata_filter, MetadataFilter)
and metadata_filter.operator
in {
MetadataOperator.NE,
MetadataOperator.NOT_IN,
}
)
or chroma_override
):
actual_idcs = sorted(map(int, (source.document_name for source in sources)))
assert actual_idcs == expected_idcs

# Should be able to call .store() multiple times
source_storage.store(corpus_name, documents)


@pytest.mark.parametrize(
    ("metadata_filter", "expected_idcs"),
    [
        pytest.param(MetadataFilter.ne("key", "value"), [2, 3, 4, 5, 6], id="ne"),
        pytest.param(
            MetadataFilter.not_in("key", ["foo", "bar"]), [0, 1, 2, 3, 4], id="not_in"
        ),
    ],
)
def test_chroma_ne_nin_non_existing_keys(
    tmp_local_root, metadata_filter, expected_idcs
):
    """Pin Chroma's divergent NE/NOT_IN semantics for documents missing the key.

    Unlike the other builtin source storages, Chroma (>=0.5.12) also returns
    sources that do not carry the filtered metadata key at all, so the
    expected indices here include those extra matches.  ``chroma_override=True``
    forces ``test_smoke`` to run its assertion even though the generic
    expectations would not hold for Chroma.
    """
    # See https://github.com/Quansight/ragna/issues/523 for details
    test_smoke(
        tmp_local_root, Chroma, metadata_filter, expected_idcs, chroma_override=True
    )


@pytest.mark.parametrize("source_storage_cls", [Chroma, LanceDB])
def test_corpus_names(tmp_local_root, source_storage_cls):
document_root = tmp_local_root / "documents"
Expand Down

0 comments on commit 258151d

Please sign in to comment.