Add image retrieval from FAISS index #876

Merged 18 commits on Mar 6, 2024
Changes from 6 commits
28 changes: 28 additions & 0 deletions src/fondant/components/retrieve_images_from_faiss_index/Dockerfile
@@ -0,0 +1,28 @@
FROM --platform=linux/amd64 python:3.10-slim as base

# System dependencies
RUN apt-get update && \
apt-get upgrade -y && \
apt-get install git -y

# Install requirements
COPY requirements.txt /
RUN pip3 install --no-cache-dir -r requirements.txt

# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}

# Set the working directory to the component folder
WORKDIR /component
COPY src/ src/

FROM base as test
COPY tests/ tests/
RUN pip3 install --no-cache-dir -r tests/requirements.txt
RUN python -m pytest tests

FROM base
WORKDIR /component/src
ENTRYPOINT ["fondant", "execute", "main"]
77 changes: 77 additions & 0 deletions src/fondant/components/retrieve_images_from_faiss_index/README.md
@@ -0,0 +1,77 @@
# Retrieve images from FAISS index

<a id="retrieve_images_from_faiss_index#description"></a>
## Description
Retrieve images from a Faiss index. The component should reference a Faiss image dataset,
which includes both the Faiss index and a dataset of image URLs. The input dataset consists
of a list of prompts. These prompts will be embedded using a CLIP model, and similar
images will be retrieved from the index.
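
In essence, each prompt is embedded with a CLIP text encoder and the resulting vector is used as the query for a nearest-neighbour search over the index. A minimal sketch of that idea outside of Fondant (the model name matches the component's default; the index path and prompt are placeholders):

```python
import faiss
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection

model_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"  # component default
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = CLIPTextModelWithProjection.from_pretrained(model_name)

# Placeholder path to a FAISS index built over the image embeddings
index = faiss.read_index("faiss_index")

# Embed one prompt and retrieve the two most similar image indices
inputs = tokenizer(["an astronaut riding a horse"], padding=True, return_tensors="pt")
with torch.no_grad():
    query = model(**inputs).text_embeds.numpy()
distances, indices = index.search(query, 2)  # rows of `indices` map back to image URLs
```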


<a id="retrieve_images_from_faiss_index#inputs_outputs"></a>
## Inputs / outputs

<a id="retrieve_images_from_faiss_index#consumes"></a>
### Consumes
**This component consumes:**

- prompt: string




<a id="retrieve_images_from_faiss_index#produces"></a>
### Produces
**This component produces:**

- image_url: string
- prompt_id: string



<a id="retrieve_images_from_faiss_index#arguments"></a>
## Arguments

The component takes the following arguments to alter its behavior:

| argument | type | description | default |
| -------- | ---- | ----------- | ------- |
| dataset_url | str | URL of the dataset containing the image URLs | / |
| faiss_index_path | str | URL of the Faiss index | / |
| image_index_column_name | str | Name of the column in the dataset that contains the image index | / |
| clip_model | str | CLIP model name to use for the retrieval | laion/CLIP-ViT-B-32-laion2B-s34B-b79K |
| num_images | int | Number of images that will be retrieved for each prompt | 2 |

<a id="retrieve_images_from_faiss_index#usage"></a>
## Usage

You can add this component to your pipeline using the following code:

```python
from fondant.pipeline import Pipeline


pipeline = Pipeline(...)

dataset = pipeline.read(...)

dataset = dataset.apply(
"retrieve_images_from_faiss_index",
arguments={
# Add arguments
# "dataset_url": ,
# "faiss_index_path": ,
# "image_index_column_name": ,
# "clip_model": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
# "num_images": 2,
},
)
```
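
For illustration, a filled-in call could look like the following; the dataset and index locations are hypothetical placeholders rather than values shipped with the component:

```python
dataset = dataset.apply(
    "retrieve_images_from_faiss_index",
    arguments={
        # Hypothetical locations, shown only to illustrate the expected argument shapes
        "dataset_url": "gs://my-bucket/image-dataset/",
        "faiss_index_path": "gs://my-bucket/image-dataset/faiss_index",
        "image_index_column_name": "image_id",
        "clip_model": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
        "num_images": 2,
    },
)
```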

<a id="retrieve_images_from_faiss_index#testing"></a>
## Testing

You can run the tests using docker with BuildKit. From this directory, run:
```
docker build . --target test
```
@@ -0,0 +1,41 @@
name: Retrieve images from FAISS index
description: |
Retrieve images from a Faiss index. The component should reference a Faiss image dataset,
which includes both the Faiss index and a dataset of image URLs. The input dataset consists
of a list of prompts. These prompts will be embedded using a CLIP model, and similar
images will be retrieved from the index.

image: fndnt/retrieve_images_from_faiss_index:dev
tags:
- Data retrieval

consumes:
prompt:
type: string

produces:
image_url:
type: string
prompt_id:
type: string

previous_index: prompt_id

args:
dataset_url:
    description: URL of the dataset containing the image URLs
type: str
faiss_index_path:
    description: URL of the Faiss index
type: str
image_index_column_name:
description: Name of the column in the dataset that contains the image index
type: str
clip_model:
    description: CLIP model name to use for the retrieval
type: str
default: laion/CLIP-ViT-B-32-laion2B-s34B-b79K
num_images:
description: Number of images that will be retrieved for each prompt
type: int
default: 2
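
For reference, a minimal sketch (with made-up values) of the data that corresponds to the `produces` section above: each retrieved `image_url` carries the `prompt_id` of the prompt it was retrieved for.

```python
import pandas as pd

# Hypothetical output for two prompts with num_images=2; ids and URLs are placeholders
produced = pd.DataFrame(
    {
        "prompt_id": ["1", "1", "2", "2"],
        "image_url": ["http://a", "http://b", "http://c", "http://d"],
    },
)
```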
@@ -0,0 +1,3 @@
transformers==4.38.1
torch==2.2.1
faiss-cpu==1.7.4
Empty file.
106 changes: 106 additions & 0 deletions src/fondant/components/retrieve_images_from_faiss_index/src/main.py
@@ -0,0 +1,106 @@
import logging
import typing as t

import dask.dataframe as dd
import faiss
import fsspec
import numpy as np
import pandas as pd
import torch
from dask.distributed import Client
from fondant.component import DaskTransformComponent
from transformers import AutoTokenizer, CLIPTextModelWithProjection

logger = logging.getLogger(__name__)


class RetrieveImagesFromFaissIndex(DaskTransformComponent):
"""Retrieve images from a faiss index using CLIP embeddings."""

def __init__( # noqa PLR0913
self,
        dataset_url: str,
faiss_index_path: str,
image_index_column_name: str,
clip_model: str = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
num_images: int = 2,
):
self.model = CLIPTextModelWithProjection.from_pretrained(clip_model)
self.tokenizer = AutoTokenizer.from_pretrained(clip_model)
self.number_of_images = num_images
self.image_index_column_name = image_index_column_name

self.device = "cuda" if torch.cuda.is_available() else "cpu"

# Download faiss index to local machine
logger.info(f"Downloading faiss index from {faiss_index_path}")

with fsspec.open(faiss_index_path, "rb") as f:
file_contents = f.read()

with open("faiss_index", "wb") as out:
out.write(file_contents)

self.search_index = faiss.read_index("faiss_index")

        # Reference dataset mapping the image index column to image urls
        self.dataset = dd.read_parquet(dataset_url)

if "url" not in self.dataset.columns:
msg = "Dataset does not contain column 'url'"
raise ValueError(msg)

    def setup(self) -> Client:
        if self.device == "cuda":
            # Imported lazily so the component also runs where dask_cuda is not installed
            from dask_cuda import LocalCUDACluster

            cluster = LocalCUDACluster()
            return Client(cluster)

        return super().setup()

def embed_prompt(self, prompt: str):
inputs = self.tokenizer([prompt], padding=True, return_tensors="pt")
outputs = self.model(**inputs)
return outputs.text_embeds.cpu().detach().numpy().astype("float64")

def retrieve_from_index(
self,
        query: np.ndarray,
        number_of_images: int = 2,
    ) -> t.List[int]:
_, indices = self.search_index.search(query, number_of_images)
return indices.tolist()[0]

def transform_partition(self, dataframe: pd.DataFrame) -> pd.DataFrame:
results = []

for index, row in dataframe.iterrows():
if "prompt" in dataframe.columns:
prompt = row["prompt"]
query = self.embed_prompt(prompt)

elif "embedding" in dataframe.columns:
prompt = None
query = row["embedding"]
else:
                msg = (
                    "Dataframe does not contain a prompt or embedding column. "
                    "Please provide one of the two."
                )
                raise ValueError(msg)

indices = self.retrieve_from_index(query, self.number_of_images)
            for i in indices:
                # Look up the image url for the retrieved index in the reference dataset
                matches = self.dataset[self.dataset[self.image_index_column_name] == i][
                    "url"
                ].compute()
                url = matches.iloc[0] if not matches.empty else None
                # Always append a four-element tuple so it matches the result columns below
                results.append((index, prompt, i, url))

results_df = pd.DataFrame(
results,
columns=["prompt_id", "prompt", "image_index", "image_url"],
)
results_df = results_df.astype({"prompt_id": str})
return results_df

def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame:
return dataframe.map_partitions(self.transform_partition)
@@ -0,0 +1,2 @@
[pytest]
pythonpath = ../src
@@ -0,0 +1 @@
pytest==7.4.2
@@ -0,0 +1,22 @@
import dask.dataframe as dd
import pandas as pd


def create_image_index_mapping():
"""Create a mapping between image index and image url."""
data = {
"image_id": [1, 2, 3, 4, 5],
"image_url": ["url1", "url2", "url3", "url4", "url5"],
}

# Create Dask DataFrame
ddf = dd.from_pandas(
pd.DataFrame(data),
npartitions=1,
) # You can adjust the number of partitions as needed

# Store as Parquet
ddf.to_parquet("./dataset.parquet")


create_image_index_mapping()
@@ -0,0 +1,36 @@
import dask.dataframe as dd
import pandas as pd

from src.main import RetrieveImagesFromFaissIndex


def test_component():
input_dataframe = pd.DataFrame.from_dict(
{
"id": ["1", "2"],
"prompt": ["first prompt", "second prompt"],
},
)

input_dataframe = input_dataframe.set_index("id")

    # Expected output shape; currently not asserted against below
    expected_output_dataframe = pd.DataFrame.from_dict(
        {
            "id": ["a", "b", "c", "d"],
            "image_url": ["http://a", "http://b", "http://c", "http://d"],
            "prompt_id": ["1", "1", "2", "2"],
        },
    )

component = RetrieveImagesFromFaissIndex(
dataset_url="./tests/resources",
)

input_dataframe = dd.from_pandas(input_dataframe, npartitions=4)
output_dataframe = component.transform(input_dataframe)
assert output_dataframe.columns.tolist() == [
"prompt_id",
"prompt",
"image_index",
"image_url",
]