From ea8c6f4c0cc4d81418d94b1dbc5cd203718e6925 Mon Sep 17 00:00:00 2001 From: Shubham Krishna Date: Wed, 3 Jan 2024 14:12:09 +0100 Subject: [PATCH] Add component to index aws opensearch (#740) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR aims to add support for indexing to AWS OpenSearch. --------- Co-authored-by: Shubham Krishna <“shubham.krishna@ml6.eu”> --- components/index_aws_opensearch/Dockerfile | 28 +++++ components/index_aws_opensearch/README.md | 73 +++++++++++++ .../fondant_component.yaml | 44 ++++++++ .../index_aws_opensearch/requirements.txt | 3 + components/index_aws_opensearch/src/main.py | 70 ++++++++++++ .../tests/aws_opensearch_test.py | 101 ++++++++++++++++++ .../index_aws_opensearch/tests/pytest.ini | 2 + .../tests/requirements.txt | 3 + 8 files changed, 324 insertions(+) create mode 100644 components/index_aws_opensearch/Dockerfile create mode 100644 components/index_aws_opensearch/README.md create mode 100644 components/index_aws_opensearch/fondant_component.yaml create mode 100644 components/index_aws_opensearch/requirements.txt create mode 100644 components/index_aws_opensearch/src/main.py create mode 100644 components/index_aws_opensearch/tests/aws_opensearch_test.py create mode 100644 components/index_aws_opensearch/tests/pytest.ini create mode 100644 components/index_aws_opensearch/tests/requirements.txt diff --git a/components/index_aws_opensearch/Dockerfile b/components/index_aws_opensearch/Dockerfile new file mode 100644 index 000000000..02b4c9e16 --- /dev/null +++ b/components/index_aws_opensearch/Dockerfile @@ -0,0 +1,28 @@ +FROM --platform=linux/amd64 python:3.8-slim as base + +# System dependencies +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install git -y + +# Install requirements +COPY requirements.txt / +RUN pip3 install --no-cache-dir -r requirements.txt + +# Install Fondant +# This is split from other requirements to leverage caching +ARG FONDANT_VERSION=main +RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION} + +# Set the working directory to the component folder +WORKDIR /component +COPY src/ src/ + +FROM base as test +COPY tests/ tests/ +RUN pip3 install --no-cache-dir -r tests/requirements.txt +RUN python -m pytest tests + +FROM base +WORKDIR /component/src +ENTRYPOINT ["fondant", "execute", "main"] \ No newline at end of file diff --git a/components/index_aws_opensearch/README.md b/components/index_aws_opensearch/README.md new file mode 100644 index 000000000..a594be7d6 --- /dev/null +++ b/components/index_aws_opensearch/README.md @@ -0,0 +1,73 @@ +# Index AWS OpenSearch + +## Description +Component that takes embeddings of text snippets and indexes them into AWS OpenSearch vector database. + +## Inputs / outputs + +### Consumes +**This component consumes:** + +- text: string +- embedding: list + + + + + +### Produces + + +**This component does not produce data.** + +## Arguments + +The component takes the following arguments to alter its behavior: + +| argument | type | description | default | +| -------- | ---- | ----------- | ------- | +| host | str | The Cluster endpoint of the AWS OpenSearch cluster where the embeddings will be indexed. E.g. "my-test-domain.us-east-1.aoss.amazonaws.com" | / | +| region | str | The AWS region where the OpenSearch cluster is located. If not specified, the default region will be used. | / | +| index_name | str | The name of the index in the AWS OpenSearch cluster where the embeddings will be stored. | / | +| index_body | dict | Parameters that specify index settings, mappings, and aliases for newly created index. | / | +| port | int | The port number to connect to the AWS OpenSearch cluster. | 443 | +| use_ssl | bool | A boolean flag indicating whether to use SSL/TLS for the connection to the OpenSearch cluster. | True | +| verify_certs | bool | A boolean flag indicating whether to verify SSL certificates when connecting to the OpenSearch cluster. | True | +| pool_maxsize | int | The maximum size of the connection pool to the AWS OpenSearch cluster. | 20 | + +## Usage + +You can add this component to your pipeline using the following code: + +```python +from fondant.pipeline import Pipeline + + +pipeline = Pipeline(...) + +dataset = pipeline.read(...) + +dataset = dataset.apply(...) + +dataset.write( + "index_aws_opensearch", + arguments={ + # Add arguments + # "host": , + # "region": , + # "index_name": , + # "index_body": {}, + # "port": 443, + # "use_ssl": True, + # "verify_certs": True, + # "pool_maxsize": 20, + }, +) +``` + +## Testing + +You can run the tests using docker with BuildKit. From this directory, run: +``` +docker build . --target test +``` diff --git a/components/index_aws_opensearch/fondant_component.yaml b/components/index_aws_opensearch/fondant_component.yaml new file mode 100644 index 000000000..3b6629f7f --- /dev/null +++ b/components/index_aws_opensearch/fondant_component.yaml @@ -0,0 +1,44 @@ +name: Index AWS OpenSearch +description: Component that takes embeddings of text snippets and indexes them into AWS OpenSearch vector database. +image: fndnt/index_aws_opensearch:dev +tags: + - Data writing + +consumes: + text: + type: string + embedding: + type: array + items: + type: float32 + +args: + host: + description: The Cluster endpoint of the AWS OpenSearch cluster where the embeddings will be indexed. E.g. "my-test-domain.us-east-1.aoss.amazonaws.com" + type: str + region: + description: The AWS region where the OpenSearch cluster is located. If not specified, the default region will be used. + type: str + index_name: + description: The name of the index in the AWS OpenSearch cluster where the embeddings will be stored. + type: str + index_body: + description: Parameters that specify index settings, mappings, and aliases for newly created index. + type: dict + port: + description: The port number to connect to the AWS OpenSearch cluster. + type: int + default: 443 + use_ssl: + description: A boolean flag indicating whether to use SSL/TLS for the connection to the OpenSearch cluster. + type: bool + default: True + verify_certs: + description: A boolean flag indicating whether to verify SSL certificates when connecting to the OpenSearch cluster. + type: bool + default: True + pool_maxsize: + description: The maximum size of the connection pool to the AWS OpenSearch cluster. + type: int + default: 20 + diff --git a/components/index_aws_opensearch/requirements.txt b/components/index_aws_opensearch/requirements.txt new file mode 100644 index 000000000..9da036ec3 --- /dev/null +++ b/components/index_aws_opensearch/requirements.txt @@ -0,0 +1,3 @@ +boto3==1.34.4 +opensearch-py==2.4.2 +tqdm==4.65.0 \ No newline at end of file diff --git a/components/index_aws_opensearch/src/main.py b/components/index_aws_opensearch/src/main.py new file mode 100644 index 000000000..b04820e82 --- /dev/null +++ b/components/index_aws_opensearch/src/main.py @@ -0,0 +1,70 @@ +import logging +from typing import Any, Dict, Optional + +import boto3 +import dask.dataframe as dd +from fondant.component import DaskWriteComponent +from opensearchpy import AWSV4SignerAuth, OpenSearch, RequestsHttpConnection +from tqdm import tqdm + +logger = logging.getLogger(__name__) + + +class IndexAWSOpenSearchComponent(DaskWriteComponent): + def __init__( + self, + *, + host: str, + region: str, + index_name: str, + index_body: Dict[str, Any], + port: Optional[int], + use_ssl: Optional[bool], + verify_certs: Optional[bool], + pool_maxsize: Optional[int], + **kwargs, + ): + session = boto3.Session() + credentials = session.get_credentials() + auth = AWSV4SignerAuth(credentials, region) + self.index_name = index_name + self.client = OpenSearch( + hosts=[{"host": host, "port": port}], + http_auth=auth, + use_ssl=use_ssl, + verify_certs=verify_certs, + connection_class=RequestsHttpConnection, + pool_maxsize=pool_maxsize, + **kwargs, + ) + self.create_index(index_body) + + def create_index(self, index_body: Dict[str, Any]): + """Creates an index if not existing in AWS OpenSearch. + + Args: + index_body: Parameters that specify index settings, + mappings, and aliases for newly created index. + """ + if self.client.indices.exists(index=self.index_name): + logger.info(f"Index: {self.index_name} already exists.") + else: + logger.info(f"Creating Index: {self.index_name} with body: {index_body}") + self.client.indices.create(index=self.index_name, body=index_body) + + def write(self, dataframe: dd.DataFrame): + """ + Writes the data from the given Dask DataFrame to AWS OpenSearch Index. + + Args: + dataframe: The Dask DataFrame containing the data to be written. + """ + for part in tqdm(dataframe.partitions): + df = part.compute() + for row in df.itertuples(): + body = {"embedding": row.embedding, "text": row.text} + self.client.index( + index=self.index_name, + id=str(row.Index), + body=body, + ) diff --git a/components/index_aws_opensearch/tests/aws_opensearch_test.py b/components/index_aws_opensearch/tests/aws_opensearch_test.py new file mode 100644 index 000000000..31893a104 --- /dev/null +++ b/components/index_aws_opensearch/tests/aws_opensearch_test.py @@ -0,0 +1,101 @@ +from unittest.mock import call + +import dask.dataframe as dd +import numpy as np +import pandas as pd + +from src.main import IndexAWSOpenSearchComponent + + +class TestIndexAWSOpenSearchComponent: + def setup_method(self): + self.index_name = "pytest-index" + self.host = "vectordb-domain-x.eu-west-1.es.amazonaws.com" + self.region = "eu-west-1" + self.port = 443 + self.index_body = {"settings": {"index": {"number_of_shards": 4}}} + self.use_ssl = True + self.verify_certs = True + self.pool_maxsize = 20 + + def test_create_index(self, mocker): + # Mock boto3.session + mocker.patch("src.main.boto3.Session") + + # Mock OpenSearch + mock_opensearch_instance = mocker.patch("src.main.OpenSearch").return_value + mock_opensearch_instance.indices.exists.return_value = False + + # Create IndexAWSOpenSearchComponent instance + IndexAWSOpenSearchComponent( + host=self.host, + region=self.region, + index_name=self.index_name, + index_body=self.index_body, + port=self.port, + use_ssl=self.use_ssl, + verify_certs=self.verify_certs, + pool_maxsize=self.pool_maxsize, + ) + + # Assert that indices.create was called + mock_opensearch_instance.indices.create.assert_called_once_with( + index=self.index_name, + body=self.index_body, + ) + + def test_write(self, mocker): + # Mock boto3.session + mocker.patch("src.main.boto3.Session") + + # Mock OpenSearch + mock_opensearch_instance = mocker.patch("src.main.OpenSearch").return_value + mock_opensearch_instance.indices.exists.return_value = True + + pandas_df = pd.DataFrame( + [ + ("hello abc", np.array([1.0, 2.0])), + ("hifasioi", np.array([2.0, 3.0])), + ], + columns=["text", "embedding"], + ) + dask_df = dd.from_pandas(pandas_df, npartitions=2) + + # Create IndexAWSOpenSearchComponent instance + component = IndexAWSOpenSearchComponent( + host=self.host, + region=self.region, + index_name=self.index_name, + index_body=self.index_body, + port=self.port, + use_ssl=self.use_ssl, + verify_certs=self.verify_certs, + pool_maxsize=self.pool_maxsize, + ) + + # Call write method + component.write(dask_df) + + # Assert that index was called with the expected arguments + expected_calls = [ + call( + index=self.index_name, + id="0", + body={"embedding": np.array([1.0, 2.0]), "text": "hello abc"}, + ), + call( + index=self.index_name, + id="1", + body={"embedding": np.array([2.0, 3.0]), "text": "hifasioi"}, + ), + ] + + actual_calls = mock_opensearch_instance.index.call_args_list + for expected, actual in zip(expected_calls, actual_calls): + assert expected[2]["index"] == actual[1]["index"] + assert expected[2]["id"] == actual[1]["id"] + assert np.array_equal( + expected[2]["body"]["embedding"], + actual[1]["body"]["embedding"], + ) + assert expected[2]["body"]["text"] == actual[1]["body"]["text"] diff --git a/components/index_aws_opensearch/tests/pytest.ini b/components/index_aws_opensearch/tests/pytest.ini new file mode 100644 index 000000000..bf6a8a517 --- /dev/null +++ b/components/index_aws_opensearch/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = ../src \ No newline at end of file diff --git a/components/index_aws_opensearch/tests/requirements.txt b/components/index_aws_opensearch/tests/requirements.txt new file mode 100644 index 000000000..5fabbf257 --- /dev/null +++ b/components/index_aws_opensearch/tests/requirements.txt @@ -0,0 +1,3 @@ +pytest==7.4.2 +pytest-mock==3.12.0 +pandas==2.0.3 \ No newline at end of file