Add image retrieval from FAISS index (#876)
The component expects a dataset in the form of:
- A Dask dataset containing id mapping (image id to image url)
- The faiss index itself

The component will load the dataset from the remote storage into the component, initialise a CLIP model, and retrieve similar images based on given prompts.
Showing 9 changed files with 282 additions and 0 deletions.
28 changes: 28 additions & 0 deletions
src/fondant/components/retrieve_from_faiss_by_prompt/Dockerfile
@@ -0,0 +1,28 @@
FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime as base

# System dependencies
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install git -y

# Install requirements
COPY requirements.txt ./
RUN pip3 install --no-cache-dir -r requirements.txt

# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
RUN pip3 install fondant[component,aws,azure,gcp,gpu]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}

# Set the working directory to the component folder
WORKDIR /component
COPY src/ src/

FROM base as test
COPY tests/ tests/
RUN pip3 install --no-cache-dir -r tests/requirements.txt
RUN python -m pytest tests

FROM base
WORKDIR /component/src
ENTRYPOINT ["fondant", "execute", "main"]
75 changes: 75 additions & 0 deletions
src/fondant/components/retrieve_from_faiss_by_prompt/README.md
@@ -0,0 +1,75 @@
# Retrieve images from FAISS index

<a id="retrieve_from_faiss_by_prompt#description"></a>
## Description
Retrieve images from a Faiss index. The component should reference a Faiss image dataset,
which includes both the Faiss index and a dataset of image URLs. The input dataset consists
of a list of prompts. These prompts will be embedded using a CLIP model, and similar
images will be retrieved from the index.


<a id="retrieve_from_faiss_by_prompt#inputs_outputs"></a>
## Inputs / outputs

<a id="retrieve_from_faiss_by_prompt#consumes"></a>
### Consumes
**This component consumes:**

- prompt: string




<a id="retrieve_from_faiss_by_prompt#produces"></a>
### Produces
**This component produces:**

- image_url: string
- prompt: string



<a id="retrieve_from_faiss_by_prompt#arguments"></a>
## Arguments

The component takes the following arguments to alter its behavior:

| argument | type | description | default |
| -------- | ---- | ----------- | ------- |
| url_mapping_path | str | Url of the image mapping dataset | / |
| faiss_index_path | str | Url of the Faiss index | / |
| clip_model | str | Clip model name to use for the retrieval | laion/CLIP-ViT-B-32-laion2B-s34B-b79K |
| num_images | int | Number of images that will be retrieved for each prompt | 2 |

<a id="retrieve_from_faiss_by_prompt#usage"></a>
## Usage

You can add this component to your pipeline using the following code:

```python
from fondant.pipeline import Pipeline


pipeline = Pipeline(...)

dataset = pipeline.read(...)

dataset = dataset.apply(
    "retrieve_from_faiss_by_prompt",
    arguments={
        # Add arguments
        # "url_mapping_path": ,
        # "faiss_index_path": ,
        # "clip_model": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
        # "num_images": 2,
    },
)
```

<a id="retrieve_from_faiss_by_prompt#testing"></a>
## Testing

You can run the tests using docker with BuildKit. From this directory, run:
```
docker build . --target test
```
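The retrieval step that the README describes (embed a prompt, look up the nearest image vectors, map their positions to URLs) can be illustrated without Faiss or CLIP. The following is a minimal pure-Python sketch, not the component's implementation. The brute-force inner-product search, the toy vectors, the `example.com` URLs, and the helper names `inner_product` and `top_k_indices` are all assumptions made for illustration; the real index may use a different metric.

```python
from typing import List, Sequence


def inner_product(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k_indices(
    query: Sequence[float],
    vectors: Sequence[Sequence[float]],
    k: int,
) -> List[int]:
    """Return the positions of the k vectors with the highest inner product."""
    ranked = sorted(
        range(len(vectors)),
        key=lambda i: inner_product(query, vectors[i]),
        reverse=True,
    )
    return ranked[:k]


# Hypothetical toy data: three 2-d "image embeddings" and their URLs.
image_vectors = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
image_urls = [
    "https://example.com/a.jpg",
    "https://example.com/b.jpg",
    "https://example.com/c.jpg",
]

# Stands in for the CLIP text embedding of a prompt.
query_vector = [0.9, 0.1]

hits = top_k_indices(query_vector, image_vectors, k=2)
retrieved_urls = [image_urls[i] for i in hits]
```

Here `k` plays the role of the `num_images` argument, and the list lookup mirrors how the positions returned by the index are resolved against the URL mapping dataset.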
38 changes: 38 additions & 0 deletions
src/fondant/components/retrieve_from_faiss_by_prompt/fondant_component.yaml
@@ -0,0 +1,38 @@
name: Retrieve images from FAISS index
description: |
  Retrieve images from a Faiss index. The component should reference a Faiss image dataset,
  which includes both the Faiss index and a dataset of image URLs. The input dataset consists
  of a list of prompts. These prompts will be embedded using a CLIP model, and similar
  images will be retrieved from the index.
image: fndnt/retrieve_from_faiss_by_prompt:dev
tags:
  - Data retrieval

consumes:
  prompt:
    type: string

produces:
  image_url:
    type: string
  prompt:
    type: string

previous_index: prompt_id

args:
  url_mapping_path:
    description: Url of the image mapping dataset
    type: str
  faiss_index_path:
    description: Url of the Faiss index
    type: str
  clip_model:
    description: Clip model name to use for the retrieval
    type: str
    default: laion/CLIP-ViT-B-32-laion2B-s34B-b79K
  num_images:
    description: Number of images that will be retrieved for each prompt
    type: int
    default: 2
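The spec above consumes one `prompt` per row and produces `image_url` and `prompt`, so each input row fans out into `num_images` output rows. A minimal sketch of that reshaping follows, using plain dictionaries instead of dataframes. The function `expand_prompts` and the sample URLs are hypothetical; the `{prompt}_{i}` id scheme follows `main.py` in this commit.

```python
from typing import Dict, List


def expand_prompts(
    prompts: List[str],
    urls_per_prompt: Dict[str, List[str]],
) -> List[Dict[str, str]]:
    """Turn each prompt into one output row per retrieved URL."""
    rows = []
    for prompt in prompts:
        for i, url in enumerate(urls_per_prompt[prompt]):
            rows.append(
                {"id": f"{prompt}_{i}", "prompt": prompt, "image_url": url},
            )
    return rows


# Hypothetical retrieval results for two prompts with num_images = 2.
sample_results = {
    "country style kitchen": [
        "https://example.com/k1.jpg",
        "https://example.com/k2.jpg",
    ],
    "cozy living room": [
        "https://example.com/l1.jpg",
        "https://example.com/l2.jpg",
    ],
}

rows = expand_prompts(list(sample_results), sample_results)
```

With this id scheme, two identical prompts in the input would yield identical ids, which is worth keeping in mind if the input dataset can contain duplicates.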
4 changes: 4 additions & 0 deletions
src/fondant/components/retrieve_from_faiss_by_prompt/requirements.txt
@@ -0,0 +1,4 @@
transformers==4.38.1
torch==2.2.1
faiss-cpu==1.7.4
huggingface_hub==0.21.3
Empty file.
106 changes: 106 additions & 0 deletions
src/fondant/components/retrieve_from_faiss_by_prompt/src/main.py
@@ -0,0 +1,106 @@
import logging
import os
import typing as t

import dask.dataframe as dd
import faiss
import fsspec
import numpy as np
import pandas as pd
import torch
from dask.distributed import Client, get_worker
from dask_cuda import LocalCUDACluster
from fondant.component import PandasTransformComponent
from transformers import AutoTokenizer, CLIPTextModelWithProjection

logger = logging.getLogger(__name__)


class RetrieveFromFaissByPrompt(PandasTransformComponent):
    """Retrieve images from a faiss index using CLIP embeddings."""

    def __init__(  # PLR0913
        self,
        url_mapping_path: str,
        faiss_index_path: str,
        clip_model: str = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
        num_images: int = 2,
    ):
        self.model_id = clip_model
        self.number_of_images = num_images

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Download faiss index to local machine
        if not os.path.exists("faiss_index"):
            logger.info(f"Downloading faiss index from {faiss_index_path}")
            with fsspec.open(faiss_index_path, "rb") as f:
                file_contents = f.read()

            with open("faiss_index", "wb") as out:
                out.write(file_contents)

        dataset = dd.read_parquet(url_mapping_path)
        if "url" not in dataset.columns:
            msg = "Dataset does not contain column 'url'"
            raise ValueError(msg)
        self.image_urls = dataset["url"].compute().to_list()

    def setup(self) -> Client:
        """Setup LocalCudaCluster if gpu is available."""
        if self.device == "cuda":
            cluster = LocalCUDACluster()
            return Client(cluster)

        return super().setup()

    def embed_prompt(self, prompt: str) -> np.ndarray:
        """Embed prompt using CLIP model."""
        worker = get_worker()
        if worker and hasattr(worker, "model"):
            tokenizer = worker.tokenizer
            model = worker.model

        else:
            logger.info("Initializing model '%s' on worker '%s'", self.model_id, worker)
            tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            model = CLIPTextModelWithProjection.from_pretrained(self.model_id).to(
                self.device,
            )

            worker.model = model
            worker.tokenizer = tokenizer

        inputs = tokenizer([prompt], padding=True, return_tensors="pt")
        inputs = inputs.to(self.device)
        outputs = model(**inputs)
        # Faiss operates on float32 vectors
        return outputs.text_embeds.cpu().detach().numpy().astype("float32")

    def retrieve_from_index(
        self,
        query: np.ndarray,
        number_of_images: int = 2,
    ) -> t.List[int]:
        """Retrieve the positions of the nearest images from the faiss index."""
        search_index = faiss.read_index("faiss_index")
        _, indices = search_index.search(query, number_of_images)
        return indices.tolist()[0]

    def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        """Transform partition of dataframe."""
        results = []
        prompts = dataframe["prompt"]

        for prompt in prompts:
            query = self.embed_prompt(prompt)
            indices = self.retrieve_from_index(query, self.number_of_images)
            for i, idx in enumerate(indices):
                url = self.image_urls[idx]
                row_to_add = (f"{prompt}_{i}", prompt, url)
                results.append(row_to_add)

        results_df = pd.DataFrame(
            results,
            columns=["id", "prompt", "image_url"],
        )
        results_df = results_df.set_index("id")
        return results_df
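`retrieve_from_index` above calls `faiss.read_index` for every prompt. One possible refinement, shown here as a sketch rather than a tested change to the component, is to cache the loaded index so that it is read once per process. The `load_index` function below is a hypothetical stand-in that returns a dictionary instead of a real Faiss index, and `retrieve` is likewise illustrative only.

```python
import functools


@functools.lru_cache(maxsize=None)
def load_index(path: str) -> dict:
    """Stand-in for an expensive load such as faiss.read_index(path)."""
    return {"path": path}


def retrieve(path: str, query: str) -> str:
    """Look up a query against the (cached) index."""
    index = load_index(path)  # served from the cache after the first call
    return f"{query} @ {index['path']}"


results = [
    retrieve("faiss_index", query)
    for query in ["kitchen", "living room", "garden"]
]

# cache_info() reports how often the loader actually ran versus hit the cache.
cache_stats = load_index.cache_info()
```

Whether this pays off in the component depends on how partitions are distributed across workers, so it would need to be measured before adopting it.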
2 changes: 2 additions & 0 deletions
src/fondant/components/retrieve_from_faiss_by_prompt/tests/pytest.ini
@@ -0,0 +1,2 @@
[pytest]
pythonpath = ../src
1 change: 1 addition & 0 deletions
src/fondant/components/retrieve_from_faiss_by_prompt/tests/requirements.txt
@@ -0,0 +1 @@
pytest==7.4.2
28 changes: 28 additions & 0 deletions
src/fondant/components/retrieve_from_faiss_by_prompt/tests/test_component.py
@@ -0,0 +1,28 @@
import pandas as pd

from src.main import RetrieveFromFaissByPrompt


def test_component():
    input_dataframe = pd.DataFrame.from_dict(
        {
            "id": [1, 2],
            "prompt": ["country style kitchen", "cozy living room"],
        },
    )

    input_dataframe = input_dataframe.set_index("id")
    input_dataframe["prompt"] = input_dataframe["prompt"].astype(str)

    # Run component
    component = RetrieveFromFaissByPrompt(
        url_mapping_path="gs://soy-audio-379412-embed-datacomp/12M/id_mapping",
        faiss_index_path="gs://soy-audio-379412-embed-datacomp/12M/faiss",
    )

    component.setup()
    output_dataframe = component.transform(input_dataframe)
    assert output_dataframe.columns.tolist() == [
        "prompt",
        "image_url",
    ]