-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds [Qdrant](https://qdrant.tech/) as a supported loading destination. Qdrant setup instructions are available at https://qdrant.tech/documentation/quick-start/.
- Loading branch information
Showing
7 changed files
with
319 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Base stage shared by the test stage and the final runtime stage.
FROM --platform=linux/amd64 python:3.8-slim AS base

# System dependencies.
# git is required because Fondant is pip-installed from its GitHub repository
# below. The apt package lists are removed in the same layer so they do not
# end up in the image.
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install git -y && \
    rm -rf /var/lib/apt/lists/*

# Install requirements
COPY requirements.txt /
RUN pip3 install --no-cache-dir -r requirements.txt

# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}

# Set the working directory to the component folder
WORKDIR /component
COPY src/ src/
ENV PYTHONPATH="${PYTHONPATH}:./src"

# Test stage: building with `--target test` installs the test requirements
# and runs the test suite as part of the build.
FROM base AS test
COPY test_requirements.txt .
RUN pip3 install --no-cache-dir -r test_requirements.txt
COPY tests/ tests/
RUN python -m pytest tests

# Final stage: starts from `base`, so test requirements and tests are not
# included in the runtime image.
FROM base
WORKDIR /component/src
ENTRYPOINT ["fondant", "execute", "main"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Index Qdrant | ||
|
||
### Description | ||
A Fondant component to load textual data and embeddings into a Qdrant database. NOTE: A Qdrant collection has to be created in advance with the appropriate configurations. https://qdrant.tech/documentation/concepts/collections/ | ||
|
||
### Inputs / outputs | ||
|
||
**This component consumes:** | ||
|
||
- text | ||
- data: string | ||
- embedding: list<item: float> | ||
|
||
**This component produces no data.** | ||
|
||
### Arguments | ||
|
||
The component takes the following arguments to alter its behavior: | ||
|
||
| argument | type | description | default | | ||
| -------- | ---- | ----------- | ------- | | ||
| collection_name | str | The name of the Qdrant collection to upsert data into. | / | | ||
| location | str | The location of the Qdrant instance. | / | | ||
| batch_size | int | The batch size to use when uploading points to Qdrant. | 64 | | ||
| parallelism | int | The number of parallel workers to use when uploading points to Qdrant. | 1 | | ||
| url | str | Either host or str of 'Optional[scheme], host, Optional[port], Optional[prefix]'. | / | | ||
| port | int | Port of the REST API interface. | 6333 | | ||
| grpc_port | int | Port of the gRPC interface. | 6334 | | ||
| prefer_grpc | bool | If `true` - use gRPC interface whenever possible in custom methods. | False | ||
| https | bool | If `true` - use HTTPS(SSL) protocol. | False | ||
| api_key | str | API key for authentication in Qdrant Cloud. | / | | ||
| prefix | str | If set, add `prefix` to the REST URL path. | / | | ||
| timeout | int | Timeout for API requests. | / | | ||
| host | str | Host name of Qdrant service. If url and host are not set, defaults to 'localhost'. | / | | ||
| path | str | Persistence path for QdrantLocal. Eg. `local_data/qdrant` | / | | ||
| force_disable_check_same_thread | bool | Force disable check_same_thread for QdrantLocal sqlite connection. | False | ||
|
||
### Usage | ||
|
||
You can add this component to your pipeline using the following code: | ||
|
||
```python | ||
from fondant.pipeline import ComponentOp | ||
|
||
|
||
index_qdrant_op = ComponentOp.from_registry( | ||
name="index_qdrant", | ||
arguments={ | ||
# Add arguments | ||
# "collection_name": , | ||
# "location": , | ||
# "batch_size": 64, | ||
# "parallelism": 1, | ||
# "url": , | ||
# "port": 6333, | ||
# "grpc_port": 6334, | ||
# "prefer_grpc": False, | ||
# "https": False, | ||
# "api_key": , | ||
# "prefix": , | ||
# "timeout": 0, | ||
# "host": , | ||
# "path": , | ||
# "force_disable_check_same_thread": False, | ||
} | ||
) | ||
pipeline.add_op(index_qdrant_op, dependencies=[...]) #Add previous component as dependency | ||
``` | ||
|
||
### Testing | ||
|
||
You can run the tests using docker with BuildKit. From this directory, run: | ||
``` | ||
docker build . --target test | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
name: Index Qdrant | ||
description: >- | ||
A Fondant component to load textual data and embeddings into a Qdrant database. | ||
NOTE: A Qdrant collection has to be created in advance with the appropriate configurations. https://qdrant.tech/documentation/concepts/collections/ | ||
image: 'fndnt/index_qdrant:dev' | ||
tags: | ||
- Data writing | ||
consumes: | ||
text: | ||
fields: | ||
data: | ||
type: string | ||
embedding: | ||
type: array | ||
items: | ||
type: float32 | ||
args: | ||
collection_name: | ||
description: The name of the Qdrant collection to upsert data into. | ||
type: str | ||
location: | ||
description: The location of the Qdrant instance. | ||
type: str | ||
default: None | ||
batch_size: | ||
description: The batch size to use when uploading points to Qdrant. | ||
type: int | ||
default: 64 | ||
parallelism: | ||
description: The number of parallel workers to use when uploading points to Qdrant. | ||
type: int | ||
default: 1 | ||
url: | ||
description: >- | ||
Either host or str of 'Optional[scheme], host, Optional[port], | ||
Optional[prefix]'. | ||
type: str | ||
default: None | ||
port: | ||
description: Port of the REST API interface. | ||
type: int | ||
default: 6333 | ||
grpc_port: | ||
description: Port of the gRPC interface. | ||
type: int | ||
default: 6334 | ||
prefer_grpc: | ||
description: If `true` - use gRPC interface whenever possible in custom methods. | ||
type: bool | ||
default: False | ||
https: | ||
description: If `true` - use HTTPS(SSL) protocol. | ||
type: bool | ||
default: False | ||
api_key: | ||
description: API key for authentication in Qdrant Cloud. | ||
type: str | ||
default: None | ||
prefix: | ||
description: 'If set, add `prefix` to the REST URL path.' | ||
type: str | ||
default: None | ||
timeout: | ||
description: Timeout for API requests. | ||
type: int | ||
default: None | ||
host: | ||
description: >- | ||
Host name of Qdrant service. If url and host are not set, defaults to | ||
'localhost'. | ||
type: str | ||
default: None | ||
path: | ||
description: Persistence path for QdrantLocal. Eg. `local_data/qdrant` | ||
type: str | ||
default: None | ||
force_disable_check_same_thread: | ||
description: Force disable check_same_thread for QdrantLocal sqlite connection. | ||
type: bool | ||
default: False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
qdrant_client==1.6.9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import ast
from typing import Iterator, List, Optional

import dask.dataframe as dd
from fondant.component import DaskWriteComponent
from qdrant_client import QdrantClient, models
from qdrant_client.qdrant_fastembed import uuid
|
||
|
||
class IndexQdrantComponent(DaskWriteComponent):
    """Fondant write component that uploads text and embeddings to Qdrant.

    Each dataframe row becomes one Qdrant record whose vector is the row's
    ``text_embedding`` and whose payload carries the row index and the
    ``text_data`` passage. The target collection is not created here; it is
    expected to exist already.
    """

    def __init__(
        self,
        *_,
        collection_name: str,
        location: Optional[str] = None,
        batch_size: int = 64,
        parallelism: int = 1,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
        force_disable_check_same_thread: bool = False,
    ):
        """Initialize the IndexQdrantComponent with the component parameters.

        Args:
            collection_name: The name of the Qdrant collection to upsert data into.
            location: The location of the Qdrant instance.
            batch_size: The batch size to use when uploading points to Qdrant.
            parallelism: The number of parallel workers to use when uploading
                points to Qdrant.
            url: Either host or str of
                'Optional[scheme], host, Optional[port], Optional[prefix]'.
            port: Port of the REST API interface.
            grpc_port: Port of the gRPC interface.
            prefer_grpc: If true, use the gRPC interface whenever possible.
            https: If true, use the HTTPS(SSL) protocol.
            api_key: API key for authentication in Qdrant Cloud.
            prefix: If set, add `prefix` to the REST URL path.
            timeout: Timeout for API requests.
            host: Host name of the Qdrant service.
            path: Persistence path for QdrantLocal.
            force_disable_check_same_thread: Force disable check_same_thread
                for the QdrantLocal sqlite connection.
        """
        # All connection settings are forwarded unchanged to the Qdrant client.
        self.client = QdrantClient(
            location=location,
            url=url,
            port=port,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            https=https,
            api_key=api_key,
            prefix=prefix,
            timeout=timeout,
            host=host,
            path=path,
            force_disable_check_same_thread=force_disable_check_same_thread,
        )
        self.collection_name = collection_name
        self.batch_size = batch_size
        self.parallelism = parallelism

    @staticmethod
    def _iter_records(dataframe: dd.DataFrame) -> Iterator[models.Record]:
        """Lazily yield one Qdrant record per row, computing one partition at a time."""
        for partition in dataframe.partitions:
            # Only the current partition is materialised in memory.
            rows = partition.compute()
            for row in rows.itertuples():
                raw_embedding = row.text_embedding
                # The embedding may arrive serialised as a string; in that case
                # it is parsed safely into a Python literal. Otherwise it is
                # passed through untouched.
                vector = (
                    ast.literal_eval(raw_embedding)
                    if isinstance(raw_embedding, str)
                    else raw_embedding
                )
                yield models.Record(
                    # Every record gets a fresh random UUID as its point id;
                    # the dataframe index is kept in the payload instead.
                    id=str(uuid.uuid4()),
                    payload={
                        "id_": str(row.Index),
                        "passage": row.text_data,
                    },
                    vector=vector,
                )

    def write(self, dataframe: dd.DataFrame) -> None:
        """
        Writes the data from the given Dask DataFrame to the Qdrant collection.

        Records are streamed to the client through a generator, so the full
        dataset is never held in a single in-memory list.

        Args:
            dataframe (dd.DataFrame): The Dask DataFrame containing the data to be written.
        """
        self.client.upload_records(
            collection_name=self.collection_name,
            records=self._iter_records(dataframe),
            batch_size=self.batch_size,
            parallel=self.parallelism,
            wait=True,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pytest==7.4.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import tempfile | ||
import uuid | ||
|
||
import dask.dataframe as dd | ||
|
||
from src.main import IndexQdrantComponent, QdrantClient, models | ||
|
||
|
||
def test_qdrant_write():
    """
    Check that IndexQdrantComponent.write stores every row in Qdrant.

    A collection is created in a temporary local Qdrant storage directory,
    a dataframe of identical rows is written through the component, and the
    number of points in the collection is compared with the number of rows.
    """
    expected_count = 100
    sample_text = "Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Aenean commodo"
    sample_vector = [0.1, 0.2, 0.3, 0.4, 0.5]
    target_collection = uuid.uuid4().hex

    with tempfile.TemporaryDirectory() as storage_dir:
        storage_path = str(storage_dir)

        bootstrap_client = QdrantClient(path=storage_path)
        bootstrap_client.create_collection(
            collection_name=target_collection,
            vectors_config=models.VectorParams(size=5, distance=models.Distance.COSINE),
        )
        # Local persistent storage can only be opened by one client at a time
        # (a Qdrant server would allow concurrent access), so the client is
        # dropped before the component opens the same path.
        del bootstrap_client

        writer = IndexQdrantComponent(
            collection_name=target_collection,
            path=storage_path,
        )

        frame = dd.DataFrame.from_dict(
            {
                "text_data": [sample_text] * expected_count,
                "text_embedding": [sample_vector] * expected_count,
            },
            npartitions=1,
        )

        writer.write(frame)
        # Drop the component so the storage can be reopened for verification.
        del writer

        verification_client = QdrantClient(path=storage_path)
        assert verification_client.count(target_collection).count == expected_count