feat: Qdrant support #646

Merged
merged 9 commits into from
Nov 21, 2023
30 changes: 30 additions & 0 deletions components/index_qdrant/Dockerfile
@@ -0,0 +1,30 @@
FROM --platform=linux/amd64 python:3.8-slim as base

# System dependencies
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install git -y

# Install requirements
COPY requirements.txt /
RUN pip3 install --no-cache-dir -r requirements.txt

# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}

# Set the working directory to the component folder
WORKDIR /component
COPY src/ src/
ENV PYTHONPATH "${PYTHONPATH}:./src"

FROM base as test
COPY test_requirements.txt .
RUN pip3 install --no-cache-dir -r test_requirements.txt
COPY tests/ tests/
RUN python -m pytest tests

FROM base
WORKDIR /component/src
ENTRYPOINT ["fondant", "execute", "main"]
75 changes: 75 additions & 0 deletions components/index_qdrant/README.md
Member

This file is autogenerated by our pre-commit hooks. You can add any custom information (like the "important" note) to the description field in the fondant_component.yaml file. It supports markdown.

Contributor Author

Should be resolved now.
38acad6

@@ -0,0 +1,75 @@
# Index Qdrant

### Description
A Fondant component to load textual data and embeddings into a Qdrant database. NOTE: A Qdrant collection has to be created in advance with the appropriate configurations. https://qdrant.tech/documentation/concepts/collections/

### Inputs / outputs

**This component consumes:**

- text
    - data: string
    - embedding: list<item: float>

**This component produces no data.**

### Arguments

The component takes the following arguments to alter its behavior:

| argument | type | description | default |
| -------- | ---- | ----------- | ------- |
| collection_name | str | The name of the Qdrant collection to upsert data into. | / |
| location | str | The location of the Qdrant instance. | / |
| batch_size | int | The batch size to use when uploading points to Qdrant. | 64 |
| parallelism | int | The number of parallel workers to use when uploading points to Qdrant. | 1 |
| url | str | Either host or str of 'Optional[scheme], host, Optional[port], Optional[prefix]'. | / |
| port | int | Port of the REST API interface. | 6333 |
| grpc_port | int | Port of the gRPC interface. | 6334 |
| prefer_grpc | bool | If `true` - use gRPC interface whenever possible in custom methods. | / |
| https | bool | If `true` - use HTTPS(SSL) protocol. | / |
| api_key | str | API key for authentication in Qdrant Cloud. | / |
| prefix | str | If set, add `prefix` to the REST URL path. | / |
| timeout | int | Timeout for API requests. | / |
| host | str | Host name of Qdrant service. If url and host are not set, defaults to 'localhost'. | / |
| path | str | Persistence path for QdrantLocal. Eg. `local_data/qdrant` | / |
| force_disable_check_same_thread | bool | Force disable check_same_thread for QdrantLocal sqlite connection. | / |
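
Reviewer note: the `batch_size` argument controls how many points go into each upload request. The sketch below only illustrates that chunking idea with the standard library. `iter_batches` is a hypothetical helper, not part of this component or of `qdrant_client`, and the client's internal batching may differ.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def iter_batches(items: Iterable[T], batch_size: int = 64) -> Iterator[List[T]]:
    """Yield consecutive lists of at most `batch_size` items (hypothetical helper)."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    iterator = iter(items)
    while True:
        # islice pulls the next `batch_size` items without materialising the rest.
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


# 100 points with the default batch size of 64 give one full and one partial batch.
batch_sizes = [len(batch) for batch in iter_batches(range(100), batch_size=64)]
```

With the default of 64, 100 points would be sent as two requests; a larger `batch_size` means fewer, bigger requests.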

### Usage

You can add this component to your pipeline using the following code:

```python
from fondant.pipeline import ComponentOp


index_qdrant_op = ComponentOp.from_registry(
    name="index_qdrant",
    arguments={
        # Add arguments
        # "collection_name": ,
        # "location": ,
        # "batch_size": 64,
        # "parallelism": 1,
        # "url": ,
        # "port": 6333,
        # "grpc_port": 6334,
        # "prefer_grpc": False,
        # "https": False,
        # "api_key": ,
        # "prefix": ,
        # "timeout": 0,
        # "host": ,
        # "path": ,
        # "force_disable_check_same_thread": False,
    }
)
pipeline.add_op(index_qdrant_op, dependencies=[...]) #Add previous component as dependency
```

### Testing

You can run the tests using docker with BuildKit. From this directory, run:
```
docker build . --target test
```
81 changes: 81 additions & 0 deletions components/index_qdrant/fondant_component.yaml
@@ -0,0 +1,81 @@
name: Index Qdrant
description: >-
  A Fondant component to load textual data and embeddings into a Qdrant database.
  NOTE: A Qdrant collection has to be created in advance with the appropriate configurations. https://qdrant.tech/documentation/concepts/collections/

image: 'fndnt/index_qdrant:dev'
tags:
  - Data writing
consumes:
  text:
    fields:
      data:
        type: string
      embedding:
        type: array
        items:
          type: float32
args:
  collection_name:
    description: The name of the Qdrant collection to upsert data into.
    type: str
  location:
    description: The location of the Qdrant instance.
    type: str
    default: None
  batch_size:
    description: The batch size to use when uploading points to Qdrant.
    type: int
    default: 64
  parallelism:
    description: The number of parallel workers to use when uploading points to Qdrant.
    type: int
    default: 1
  url:
    description: >-
      Either host or str of 'Optional[scheme], host, Optional[port],
      Optional[prefix]'.
    type: str
    default: None
  port:
    description: Port of the REST API interface.
    type: int
    default: 6333
  grpc_port:
    description: Port of the gRPC interface.
    type: int
    default: 6334
  prefer_grpc:
    description: If `true` - use gRPC interface whenever possible in custom methods.
    type: bool
    default: False
  https:
    description: If `true` - use HTTPS(SSL) protocol.
    type: bool
    default: False
  api_key:
    description: API key for authentication in Qdrant Cloud.
    type: str
    default: None
  prefix:
    description: 'If set, add `prefix` to the REST URL path.'
    type: str
    default: None
  timeout:
    description: Timeout for API requests.
    type: int
    default: None
  host:
    description: >-
      Host name of Qdrant service. If url and host are not set, defaults to
      'localhost'.
    type: str
    default: None
  path:
    description: Persistence path for QdrantLocal. Eg. `local_data/qdrant`
    type: str
    default: None
  force_disable_check_same_thread:
    description: Force disable check_same_thread for QdrantLocal sqlite connection.
    type: bool
    default: False
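
Reviewer note: the `url` argument is described as a string of `'Optional[scheme], host, Optional[port], Optional[prefix]'`. As a rough illustration of how such a string breaks down into the separate `https`, `host`, `port` and `prefix` arguments, here is a standard-library sketch. `split_qdrant_url` and `QdrantUrlParts` are hypothetical names, and this is not the parser `qdrant_client` uses.

```python
from typing import NamedTuple, Optional
from urllib.parse import urlparse


class QdrantUrlParts(NamedTuple):
    """Hypothetical container for the pieces of a `url` argument."""

    scheme: Optional[str]
    host: Optional[str]
    port: Optional[int]
    prefix: Optional[str]


def split_qdrant_url(url: str) -> QdrantUrlParts:
    """Split a URL-like string into scheme, host, port and path prefix."""
    # urlparse only recognises the host when a "//" separator is present,
    # so prepend it for bare "host:port/prefix" strings.
    parsed = urlparse(url if "://" in url else f"//{url}")
    return QdrantUrlParts(
        scheme=parsed.scheme or None,
        host=parsed.hostname,
        port=parsed.port,
        prefix=parsed.path.strip("/") or None,
    )


full = split_qdrant_url("https://example.com:6333/api")
bare = split_qdrant_url("localhost")
```

In this sketch the full form carries the scheme, port and prefix that would otherwise be passed as `https`, `port` and `prefix`, while a bare host leaves them unset.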
1 change: 1 addition & 0 deletions components/index_qdrant/requirements.txt
@@ -0,0 +1 @@
qdrant_client==1.6.9
81 changes: 81 additions & 0 deletions components/index_qdrant/src/main.py
@@ -0,0 +1,81 @@
import ast
import uuid
from typing import List, Optional

import dask.dataframe as dd
from fondant.component import DaskWriteComponent
from qdrant_client import QdrantClient, models


class IndexQdrantComponent(DaskWriteComponent):
    def __init__(
        self,
        *_,
        collection_name: str,
        location: Optional[str] = None,
        batch_size: int = 64,
        parallelism: int = 1,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
        force_disable_check_same_thread: bool = False,
    ):
        """Initialize the IndexQdrantComponent with the component parameters."""
        self.client = QdrantClient(
            location=location,
            url=url,
            port=port,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            https=https,
            api_key=api_key,
            prefix=prefix,
            timeout=timeout,
            host=host,
            path=path,
            force_disable_check_same_thread=force_disable_check_same_thread,
        )
        self.collection_name = collection_name
        self.batch_size = batch_size
        self.parallelism = parallelism

    def write(self, dataframe: dd.DataFrame) -> None:
        """
        Writes the data from the given Dask DataFrame to the Qdrant collection.

        Args:
            dataframe (dd.DataFrame): The Dask DataFrame containing the data to be written.
        """
        records: List[models.Record] = []
        for part in dataframe.partitions:
            df = part.compute()
            for row in df.itertuples():
                payload = {
                    "id_": str(row.Index),
                    "passage": row.text_data,
                }
                point_id = str(uuid.uuid4())
                # Check if 'text_embedding' attribute is a string.
                # If it is, safely evaluate and convert it into a list of floats.
                # else (i.e., it is already a list), it is directly assigned.
                embedding = (
                    ast.literal_eval(row.text_embedding)
                    if isinstance(row.text_embedding, str)
                    else row.text_embedding
                )
                records.append(
                    models.Record(id=point_id, payload=payload, vector=embedding),
                )

        self.client.upload_records(
            collection_name=self.collection_name,
            records=records,
            batch_size=self.batch_size,
            parallel=self.parallelism,
            wait=True,
        )
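
Reviewer note: `write` accepts `text_embedding` either as a string or as an already-parsed list. A minimal, standard-library sketch of that normalisation step in isolation is shown below. `parse_embedding` is a hypothetical helper name, and the explicit `float` conversion is an added assumption rather than something the component does.

```python
import ast
from typing import List, Sequence, Union


def parse_embedding(value: Union[str, Sequence[float]]) -> List[float]:
    """Return the embedding as a list of floats (hypothetical helper)."""
    if isinstance(value, str):
        # literal_eval only accepts Python literals, so a string such as
        # "[0.1, 0.2]" is parsed without executing arbitrary code.
        value = ast.literal_eval(value)
    # Assumption: coerce every item to float so the vector type is uniform.
    return [float(item) for item in value]


from_string = parse_embedding("[0.1, 0.2, 0.3]")
from_sequence = parse_embedding((1, 2))
```

Pulling the check into a helper like this would also make the string and list paths easy to cover in a unit test without a Qdrant instance.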
1 change: 1 addition & 0 deletions components/index_qdrant/test_requirements.txt
@@ -0,0 +1 @@
pytest==7.4.2
50 changes: 50 additions & 0 deletions components/index_qdrant/tests/qdrant_test.py
@@ -0,0 +1,50 @@
import tempfile
import uuid

import dask.dataframe as dd

from src.main import IndexQdrantComponent, QdrantClient, models


def test_qdrant_write():
    """
    Test case for the write method of the IndexQdrantComponent class.

    This test creates a temporary collection using a QdrantClient.
    Writes data to it using the write method of the IndexQdrantComponent.
    Asserts that the count of entries in the collection is equal to the expected number of entries.
    """
    collection_name = uuid.uuid4().hex
    with tempfile.TemporaryDirectory() as tmpdir:
        client = QdrantClient(path=str(tmpdir))
        entries = 100

        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(distance=models.Distance.COSINE, size=5),
        )
        # There cannot be multiple clients accessing the same local persistent storage
        # Qdrant server supports multiple concurrent access
        del client

        component = IndexQdrantComponent(
            collection_name=collection_name,
            path=str(tmpdir),
        )

        dask_dataframe = dd.DataFrame.from_dict(
            {
                "text_data": [
                    "Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Aenean commodo",
                ]
                * entries,
                "text_embedding": [[0.1, 0.2, 0.3, 0.4, 0.5]] * entries,
            },
            npartitions=1,
        )

        component.write(dask_dataframe)
        del component

        client = QdrantClient(path=str(tmpdir))
        assert client.count(collection_name).count == entries