diff --git a/components/index_qdrant/Dockerfile b/components/index_qdrant/Dockerfile
new file mode 100644
index 000000000..35e7cc91f
--- /dev/null
+++ b/components/index_qdrant/Dockerfile
@@ -0,0 +1,30 @@
+FROM --platform=linux/amd64 python:3.8-slim as base
+
+# System dependencies
+RUN apt-get update && \
+    apt-get upgrade -y && \
+    apt-get install git -y
+
+# Install requirements
+COPY requirements.txt /
+RUN pip3 install --no-cache-dir -r requirements.txt
+
+# Install Fondant
+# This is split from other requirements to leverage caching
+ARG FONDANT_VERSION=main
+RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}
+
+# Set the working directory to the component folder
+WORKDIR /component
+COPY src/ src/
+ENV PYTHONPATH "${PYTHONPATH}:./src"
+
+FROM base as test
+COPY test_requirements.txt .
+RUN pip3 install --no-cache-dir -r test_requirements.txt
+COPY tests/ tests/
+RUN python -m pytest tests
+
+FROM base
+WORKDIR /component/src
+ENTRYPOINT ["fondant", "execute", "main"]
\ No newline at end of file
diff --git a/components/index_qdrant/README.md b/components/index_qdrant/README.md
new file mode 100644
index 000000000..0335f6835
--- /dev/null
+++ b/components/index_qdrant/README.md
@@ -0,0 +1,75 @@
+# Index Qdrant
+
+### Description
+A Fondant component to load textual data and embeddings into a Qdrant database. NOTE: A Qdrant collection has to be created in advance with the appropriate configurations. https://qdrant.tech/documentation/concepts/collections/
+
+### Inputs / outputs
+
+**This component consumes:**
+
+- text
+    - data: string
+    - embedding: list
+
+**This component produces no data.**
+
+### Arguments
+
+The component takes the following arguments to alter its behavior:
+
+| argument | type | description | default |
+| -------- | ---- | ----------- | ------- |
+| collection_name | str | The name of the Qdrant collection to upsert data into. | / |
+| location | str | The location of the Qdrant instance. | / |
+| batch_size | int | The batch size to use when uploading points to Qdrant. | 64 |
+| parallelism | int | The number of parallel workers to use when uploading points to Qdrant. | 1 |
+| url | str | Either host or str of 'Optional[scheme], host, Optional[port], Optional[prefix]'. | / |
+| port | int | Port of the REST API interface. | 6333 |
+| grpc_port | int | Port of the gRPC interface. | 6334 |
+| prefer_grpc | bool | If `true` - use gRPC interface whenever possible in custom methods. | / |
+| https | bool | If `true` - use HTTPS(SSL) protocol. | / |
+| api_key | str | API key for authentication in Qdrant Cloud. | / |
+| prefix | str | If set, add `prefix` to the REST URL path. | / |
+| timeout | int | Timeout for API requests. | / |
+| host | str | Host name of Qdrant service. If url and host are not set, defaults to 'localhost'. | / |
+| path | str | Persistence path for QdrantLocal. Eg. `local_data/qdrant` | / |
+| force_disable_check_same_thread | bool | Force disable check_same_thread for QdrantLocal sqlite connection. | / |
+
+### Usage
+
+You can add this component to your pipeline using the following code:
+
+```python
+from fondant.pipeline import ComponentOp
+
+
+index_qdrant_op = ComponentOp.from_registry(
+    name="index_qdrant",
+    arguments={
+        # Add arguments
+        # "collection_name": ,
+        # "location": ,
+        # "batch_size": 64,
+        # "parallelism": 1,
+        # "url": ,
+        # "port": 6333,
+        # "grpc_port": 6334,
+        # "prefer_grpc": False,
+        # "https": False,
+        # "api_key": ,
+        # "prefix": ,
+        # "timeout": 0,
+        # "host": ,
+        # "path": ,
+        # "force_disable_check_same_thread": False,
+    }
+)
+pipeline.add_op(index_qdrant_op, dependencies=[...])  #Add previous component as dependency
+```
+
+### Testing
+
+You can run the tests using docker with BuildKit. From this directory, run:
+```
+docker build . --target test
+```
diff --git a/components/index_qdrant/fondant_component.yaml b/components/index_qdrant/fondant_component.yaml
new file mode 100644
index 000000000..6feb3b257
--- /dev/null
+++ b/components/index_qdrant/fondant_component.yaml
@@ -0,0 +1,81 @@
+name: Index Qdrant
+description: >-
+  A Fondant component to load textual data and embeddings into a Qdrant database.
+  NOTE: A Qdrant collection has to be created in advance with the appropriate configurations. https://qdrant.tech/documentation/concepts/collections/
+
+image: 'fndnt/index_qdrant:dev'
+tags:
+  - Data writing
+consumes:
+  text:
+    fields:
+      data:
+        type: string
+      embedding:
+        type: array
+        items:
+          type: float32
+args:
+  collection_name:
+    description: The name of the Qdrant collection to upsert data into.
+    type: str
+  location:
+    description: The location of the Qdrant instance.
+    type: str
+    default: None
+  batch_size:
+    description: The batch size to use when uploading points to Qdrant.
+    type: int
+    default: 64
+  parallelism:
+    description: The number of parallel workers to use when uploading points to Qdrant.
+    type: int
+    default: 1
+  url:
+    description: >-
+      Either host or str of 'Optional[scheme], host, Optional[port],
+      Optional[prefix]'.
+    type: str
+    default: None
+  port:
+    description: Port of the REST API interface.
+    type: int
+    default: 6333
+  grpc_port:
+    description: Port of the gRPC interface.
+    type: int
+    default: 6334
+  prefer_grpc:
+    description: If `true` - use gRPC interface whenever possible in custom methods.
+    type: bool
+    default: False
+  https:
+    description: If `true` - use HTTPS(SSL) protocol.
+    type: bool
+    default: False
+  api_key:
+    description: API key for authentication in Qdrant Cloud.
+    type: str
+    default: None
+  prefix:
+    description: 'If set, add `prefix` to the REST URL path.'
+    type: str
+    default: None
+  timeout:
+    description: Timeout for API requests.
+    type: int
+    default: None
+  host:
+    description: >-
+      Host name of Qdrant service. If url and host are not set, defaults to
+      'localhost'.
+    type: str
+    default: None
+  path:
+    description: Persistence path for QdrantLocal. Eg. `local_data/qdrant`
+    type: str
+    default: None
+  force_disable_check_same_thread:
+    description: Force disable check_same_thread for QdrantLocal sqlite connection.
+    type: bool
+    default: False
\ No newline at end of file
diff --git a/components/index_qdrant/requirements.txt b/components/index_qdrant/requirements.txt
new file mode 100644
index 000000000..05fcad6b2
--- /dev/null
+++ b/components/index_qdrant/requirements.txt
@@ -0,0 +1 @@
+qdrant_client==1.6.9
\ No newline at end of file
diff --git a/components/index_qdrant/src/main.py b/components/index_qdrant/src/main.py
new file mode 100644
index 000000000..06b68426d
--- /dev/null
+++ b/components/index_qdrant/src/main.py
@@ -0,0 +1,104 @@
+import ast
+import uuid
+from typing import List, Optional
+
+import dask.dataframe as dd
+from fondant.component import DaskWriteComponent
+from qdrant_client import QdrantClient, models
+
+
+class IndexQdrantComponent(DaskWriteComponent):
+    """Write text passages and their embeddings to an existing Qdrant collection."""
+
+    def __init__(
+        self,
+        *_,
+        collection_name: str,
+        location: Optional[str] = None,
+        batch_size: int = 64,
+        parallelism: int = 1,
+        url: Optional[str] = None,
+        port: Optional[int] = 6333,
+        grpc_port: int = 6334,
+        prefer_grpc: bool = False,
+        https: Optional[bool] = None,
+        api_key: Optional[str] = None,
+        prefix: Optional[str] = None,
+        timeout: Optional[float] = None,
+        host: Optional[str] = None,
+        path: Optional[str] = None,
+        force_disable_check_same_thread: bool = False,
+    ):
+        """
+        Create the Qdrant client and store the upload settings.
+
+        Args:
+            collection_name: Name of the collection the records are uploaded to.
+            batch_size: Batch size passed to `upload_records`.
+            parallelism: Number of parallel workers passed to `upload_records`.
+            location, url, port, grpc_port, prefer_grpc, https, api_key, prefix,
+                timeout, host, path, force_disable_check_same_thread: Connection
+                settings forwarded unchanged to `QdrantClient`.
+        """
+        self.client = QdrantClient(
+            location=location,
+            url=url,
+            port=port,
+            grpc_port=grpc_port,
+            prefer_grpc=prefer_grpc,
+            https=https,
+            api_key=api_key,
+            prefix=prefix,
+            timeout=timeout,
+            host=host,
+            path=path,
+            force_disable_check_same_thread=force_disable_check_same_thread,
+        )
+        self.collection_name = collection_name
+        self.batch_size = batch_size
+        self.parallelism = parallelism
+
+    def write(self, dataframe: dd.DataFrame) -> None:
+        """
+        Writes the data from the given Dask DataFrame to the Qdrant collection.
+
+        Partitions are computed and uploaded one at a time, so only a single
+        partition's records are held in memory at once.
+
+        Args:
+            dataframe (dd.DataFrame): The Dask DataFrame containing the data to be written.
+        """
+        for part in dataframe.partitions:
+            df = part.compute()
+            records: List[models.Record] = []
+            for row in df.itertuples():
+                payload = {
+                    "id_": str(row.Index),
+                    "passage": row.text_data,
+                }
+                # Embeddings may arrive serialized as a string; literal_eval parses
+                # them into a list without executing arbitrary code.
+                embedding = (
+                    ast.literal_eval(row.text_embedding)
+                    if isinstance(row.text_embedding, str)
+                    else row.text_embedding
+                )
+                records.append(
+                    models.Record(
+                        id=str(uuid.uuid4()),
+                        payload=payload,
+                        vector=embedding,
+                    ),
+                )
+
+            # Nothing to send for an empty partition.
+            if not records:
+                continue
+
+            self.client.upload_records(
+                collection_name=self.collection_name,
+                records=records,
+                batch_size=self.batch_size,
+                parallel=self.parallelism,
+                wait=True,
+            )
diff --git a/components/index_qdrant/test_requirements.txt b/components/index_qdrant/test_requirements.txt
new file mode 100644
index 000000000..2a929edcc
--- /dev/null
+++ b/components/index_qdrant/test_requirements.txt
@@ -0,0 +1 @@
+pytest==7.4.2
diff --git a/components/index_qdrant/tests/qdrant_test.py b/components/index_qdrant/tests/qdrant_test.py
new file mode 100644
index 000000000..835c7de01
--- /dev/null
+++ b/components/index_qdrant/tests/qdrant_test.py
@@ -0,0 +1,50 @@
+import tempfile
+import uuid
+
+import dask.dataframe as dd
+
+from src.main import IndexQdrantComponent, QdrantClient, models
+
+
+def test_qdrant_write():
+    """
+    Test case for the write method of the IndexQdrantComponent class.
+
+    This test creates a temporary collection using a QdrantClient.
+    Writes data to it using the write method of the IndexQdrantComponent.
+    Asserts that the count of entries in the collection is equal to the expected number of entries.
+    """
+    collection_name = uuid.uuid4().hex
+    with tempfile.TemporaryDirectory() as tmpdir:
+        client = QdrantClient(path=str(tmpdir))
+        entries = 100
+
+        client.create_collection(
+            collection_name=collection_name,
+            vectors_config=models.VectorParams(distance=models.Distance.COSINE, size=5),
+        )
+        # There cannot be multiple clients accessing the same local persistent storage
+        # Qdrant server supports multiple concurrent access
+        del client
+
+        component = IndexQdrantComponent(
+            collection_name=collection_name,
+            path=str(tmpdir),
+        )
+
+        dask_dataframe = dd.DataFrame.from_dict(
+            {
+                "text_data": [
+                    "Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Aenean commodo",
+                ]
+                * entries,
+                "text_embedding": [[0.1, 0.2, 0.3, 0.4, 0.5]] * entries,
+            },
+            npartitions=1,
+        )
+
+        component.write(dask_dataframe)
+        del component
+
+        client = QdrantClient(path=str(tmpdir))
+        assert client.count(collection_name).count == entries