-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add component to index aws opensearch (#740)
This PR aims to add support for indexing to AWS OpenSearch. --------- Co-authored-by: Shubham Krishna <“[email protected]”>
- Loading branch information
Showing
8 changed files
with
324 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
FROM --platform=linux/amd64 python:3.8-slim as base | ||
|
||
# System dependencies | ||
RUN apt-get update && \ | ||
apt-get upgrade -y && \ | ||
apt-get install git -y | ||
|
||
# Install requirements | ||
COPY requirements.txt / | ||
RUN pip3 install --no-cache-dir -r requirements.txt | ||
|
||
# Install Fondant | ||
# This is split from other requirements to leverage caching | ||
ARG FONDANT_VERSION=main | ||
RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION} | ||
|
||
# Set the working directory to the component folder | ||
WORKDIR /component | ||
COPY src/ src/ | ||
|
||
FROM base as test | ||
COPY tests/ tests/ | ||
RUN pip3 install --no-cache-dir -r tests/requirements.txt | ||
RUN python -m pytest tests | ||
|
||
FROM base | ||
WORKDIR /component/src | ||
ENTRYPOINT ["fondant", "execute", "main"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Index AWS OpenSearch | ||
|
||
## Description | ||
Component that takes embeddings of text snippets and indexes them into AWS OpenSearch vector database. | ||
|
||
## Inputs / outputs | ||
|
||
### Consumes | ||
**This component consumes:** | ||
|
||
- text: string | ||
- embedding: list<item: float> | ||
|
||
|
||
|
||
|
||
|
||
### Produces | ||
|
||
|
||
**This component does not produce data.** | ||
|
||
## Arguments | ||
|
||
The component takes the following arguments to alter its behavior: | ||
|
||
| argument | type | description | default | | ||
| -------- | ---- | ----------- | ------- | | ||
| host | str | The Cluster endpoint of the AWS OpenSearch cluster where the embeddings will be indexed. E.g. "my-test-domain.us-east-1.aoss.amazonaws.com" | / | | ||
| region | str | The AWS region where the OpenSearch cluster is located. If not specified, the default region will be used. | / | | ||
| index_name | str | The name of the index in the AWS OpenSearch cluster where the embeddings will be stored. | / | | ||
| index_body | dict | Parameters that specify index settings, mappings, and aliases for newly created index. | / | | ||
| port | int | The port number to connect to the AWS OpenSearch cluster. | 443 | | ||
| use_ssl | bool | A boolean flag indicating whether to use SSL/TLS for the connection to the OpenSearch cluster. | True | | ||
| verify_certs | bool | A boolean flag indicating whether to verify SSL certificates when connecting to the OpenSearch cluster. | True | | ||
| pool_maxsize | int | The maximum size of the connection pool to the AWS OpenSearch cluster. | 20 | | ||
|
||
## Usage | ||
|
||
You can add this component to your pipeline using the following code: | ||
|
||
```python | ||
from fondant.pipeline import Pipeline | ||
|
||
|
||
pipeline = Pipeline(...) | ||
|
||
dataset = pipeline.read(...) | ||
|
||
dataset = dataset.apply(...) | ||
|
||
dataset.write( | ||
"index_aws_opensearch", | ||
arguments={ | ||
# Add arguments | ||
# "host": , | ||
# "region": , | ||
# "index_name": , | ||
# "index_body": {}, | ||
# "port": 443, | ||
# "use_ssl": True, | ||
# "verify_certs": True, | ||
# "pool_maxsize": 20, | ||
}, | ||
) | ||
``` | ||
|
||
## Testing | ||
|
||
You can run the tests using docker with BuildKit. From this directory, run: | ||
``` | ||
docker build . --target test | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
name: Index AWS OpenSearch | ||
description: Component that takes embeddings of text snippets and indexes them into AWS OpenSearch vector database. | ||
image: fndnt/index_aws_opensearch:dev | ||
tags: | ||
- Data writing | ||
|
||
consumes: | ||
text: | ||
type: string | ||
embedding: | ||
type: array | ||
items: | ||
type: float32 | ||
|
||
args: | ||
host: | ||
description: The Cluster endpoint of the AWS OpenSearch cluster where the embeddings will be indexed. E.g. "my-test-domain.us-east-1.aoss.amazonaws.com" | ||
type: str | ||
region: | ||
description: The AWS region where the OpenSearch cluster is located. If not specified, the default region will be used. | ||
type: str | ||
index_name: | ||
description: The name of the index in the AWS OpenSearch cluster where the embeddings will be stored. | ||
type: str | ||
index_body: | ||
description: Parameters that specify index settings, mappings, and aliases for newly created index. | ||
type: dict | ||
port: | ||
description: The port number to connect to the AWS OpenSearch cluster. | ||
type: int | ||
default: 443 | ||
use_ssl: | ||
description: A boolean flag indicating whether to use SSL/TLS for the connection to the OpenSearch cluster. | ||
type: bool | ||
default: True | ||
verify_certs: | ||
description: A boolean flag indicating whether to verify SSL certificates when connecting to the OpenSearch cluster. | ||
type: bool | ||
default: True | ||
pool_maxsize: | ||
description: The maximum size of the connection pool to the AWS OpenSearch cluster. | ||
type: int | ||
default: 20 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
boto3==1.34.4 | ||
opensearch-py==2.4.2 | ||
tqdm==4.65.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import logging | ||
from typing import Any, Dict, Optional | ||
|
||
import boto3 | ||
import dask.dataframe as dd | ||
from fondant.component import DaskWriteComponent | ||
from opensearchpy import AWSV4SignerAuth, OpenSearch, RequestsHttpConnection | ||
from tqdm import tqdm | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class IndexAWSOpenSearchComponent(DaskWriteComponent): | ||
def __init__( | ||
self, | ||
*, | ||
host: str, | ||
region: str, | ||
index_name: str, | ||
index_body: Dict[str, Any], | ||
port: Optional[int], | ||
use_ssl: Optional[bool], | ||
verify_certs: Optional[bool], | ||
pool_maxsize: Optional[int], | ||
**kwargs, | ||
): | ||
session = boto3.Session() | ||
credentials = session.get_credentials() | ||
auth = AWSV4SignerAuth(credentials, region) | ||
self.index_name = index_name | ||
self.client = OpenSearch( | ||
hosts=[{"host": host, "port": port}], | ||
http_auth=auth, | ||
use_ssl=use_ssl, | ||
verify_certs=verify_certs, | ||
connection_class=RequestsHttpConnection, | ||
pool_maxsize=pool_maxsize, | ||
**kwargs, | ||
) | ||
self.create_index(index_body) | ||
|
||
def create_index(self, index_body: Dict[str, Any]): | ||
"""Creates an index if not existing in AWS OpenSearch. | ||
Args: | ||
index_body: Parameters that specify index settings, | ||
mappings, and aliases for newly created index. | ||
""" | ||
if self.client.indices.exists(index=self.index_name): | ||
logger.info(f"Index: {self.index_name} already exists.") | ||
else: | ||
logger.info(f"Creating Index: {self.index_name} with body: {index_body}") | ||
self.client.indices.create(index=self.index_name, body=index_body) | ||
|
||
def write(self, dataframe: dd.DataFrame): | ||
""" | ||
Writes the data from the given Dask DataFrame to AWS OpenSearch Index. | ||
Args: | ||
dataframe: The Dask DataFrame containing the data to be written. | ||
""" | ||
for part in tqdm(dataframe.partitions): | ||
df = part.compute() | ||
for row in df.itertuples(): | ||
body = {"embedding": row.embedding, "text": row.text} | ||
self.client.index( | ||
index=self.index_name, | ||
id=str(row.Index), | ||
body=body, | ||
) |
101 changes: 101 additions & 0 deletions
101
components/index_aws_opensearch/tests/aws_opensearch_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from unittest.mock import call | ||
|
||
import dask.dataframe as dd | ||
import numpy as np | ||
import pandas as pd | ||
|
||
from src.main import IndexAWSOpenSearchComponent | ||
|
||
|
||
class TestIndexAWSOpenSearchComponent: | ||
def setup_method(self): | ||
self.index_name = "pytest-index" | ||
self.host = "vectordb-domain-x.eu-west-1.es.amazonaws.com" | ||
self.region = "eu-west-1" | ||
self.port = 443 | ||
self.index_body = {"settings": {"index": {"number_of_shards": 4}}} | ||
self.use_ssl = True | ||
self.verify_certs = True | ||
self.pool_maxsize = 20 | ||
|
||
def test_create_index(self, mocker): | ||
# Mock boto3.session | ||
mocker.patch("src.main.boto3.Session") | ||
|
||
# Mock OpenSearch | ||
mock_opensearch_instance = mocker.patch("src.main.OpenSearch").return_value | ||
mock_opensearch_instance.indices.exists.return_value = False | ||
|
||
# Create IndexAWSOpenSearchComponent instance | ||
IndexAWSOpenSearchComponent( | ||
host=self.host, | ||
region=self.region, | ||
index_name=self.index_name, | ||
index_body=self.index_body, | ||
port=self.port, | ||
use_ssl=self.use_ssl, | ||
verify_certs=self.verify_certs, | ||
pool_maxsize=self.pool_maxsize, | ||
) | ||
|
||
# Assert that indices.create was called | ||
mock_opensearch_instance.indices.create.assert_called_once_with( | ||
index=self.index_name, | ||
body=self.index_body, | ||
) | ||
|
||
def test_write(self, mocker): | ||
# Mock boto3.session | ||
mocker.patch("src.main.boto3.Session") | ||
|
||
# Mock OpenSearch | ||
mock_opensearch_instance = mocker.patch("src.main.OpenSearch").return_value | ||
mock_opensearch_instance.indices.exists.return_value = True | ||
|
||
pandas_df = pd.DataFrame( | ||
[ | ||
("hello abc", np.array([1.0, 2.0])), | ||
("hifasioi", np.array([2.0, 3.0])), | ||
], | ||
columns=["text", "embedding"], | ||
) | ||
dask_df = dd.from_pandas(pandas_df, npartitions=2) | ||
|
||
# Create IndexAWSOpenSearchComponent instance | ||
component = IndexAWSOpenSearchComponent( | ||
host=self.host, | ||
region=self.region, | ||
index_name=self.index_name, | ||
index_body=self.index_body, | ||
port=self.port, | ||
use_ssl=self.use_ssl, | ||
verify_certs=self.verify_certs, | ||
pool_maxsize=self.pool_maxsize, | ||
) | ||
|
||
# Call write method | ||
component.write(dask_df) | ||
|
||
# Assert that index was called with the expected arguments | ||
expected_calls = [ | ||
call( | ||
index=self.index_name, | ||
id="0", | ||
body={"embedding": np.array([1.0, 2.0]), "text": "hello abc"}, | ||
), | ||
call( | ||
index=self.index_name, | ||
id="1", | ||
body={"embedding": np.array([2.0, 3.0]), "text": "hifasioi"}, | ||
), | ||
] | ||
|
||
actual_calls = mock_opensearch_instance.index.call_args_list | ||
for expected, actual in zip(expected_calls, actual_calls): | ||
assert expected[2]["index"] == actual[1]["index"] | ||
assert expected[2]["id"] == actual[1]["id"] | ||
assert np.array_equal( | ||
expected[2]["body"]["embedding"], | ||
actual[1]["body"]["embedding"], | ||
) | ||
assert expected[2]["body"]["text"] == actual[1]["body"]["text"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[pytest] | ||
pythonpath = ../src |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
pytest==7.4.2 | ||
pytest-mock==3.12.0 | ||
pandas==2.0.3 |