Add image retrieval from FAISS index (#876)
The component expects a dataset in the form of:
- A Dask dataset containing id mapping (image id to image url)
- The faiss index itself

The component loads the Faiss index and the id mapping from remote storage,
initialises a CLIP model,
and retrieves similar images for the given prompts.
mrchtr authored Mar 6, 2024
1 parent f2da61b commit 6f78d09
Showing 9 changed files with 282 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/fondant/components/retrieve_from_faiss_by_prompt/Dockerfile
@@ -0,0 +1,28 @@
FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime as base

# System dependencies
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install git -y

# Install requirements
COPY requirements.txt ./
RUN pip3 install --no-cache-dir -r requirements.txt

# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
RUN pip3 install fondant[component,aws,azure,gcp,gpu]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}

# Set the working directory to the component folder
WORKDIR /component
COPY src/ src/

FROM base as test
COPY tests/ tests/
RUN pip3 install --no-cache-dir -r tests/requirements.txt
RUN python -m pytest tests

FROM base
WORKDIR /component/src
ENTRYPOINT ["fondant", "execute", "main"]
75 changes: 75 additions & 0 deletions src/fondant/components/retrieve_from_faiss_by_prompt/README.md
@@ -0,0 +1,75 @@
# Retrieve images from FAISS index

<a id="retrieve_from_faiss_by_prompt#description"></a>
## Description
Retrieve images from a Faiss index. The component should reference a Faiss image dataset,
which includes both the Faiss index and a dataset of image URLs. The input dataset consists
of a list of prompts. These prompts will be embedded using a CLIP model, and similar
images will be retrieved from the index.
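
As a rough illustration of the lookup described above, the sketch below replaces the Faiss index with a brute-force inner-product search over a few made-up vectors, in plain Python. The names `toy_index`, `toy_urls`, and `top_k`, the vectors, and the example URLs are all hypothetical. The component itself embeds prompts with CLIP and searches a real Faiss index, and whether that index ranks by inner product or by another metric depends on how it was built.

```python
import typing as t

# Hypothetical stand-ins: row i of the index corresponds to toy_urls[i].
toy_index: t.List[t.List[float]] = [
    [1.0, 0.0],
    [0.0, 1.0],
    [0.7, 0.7],
]
toy_urls = [
    "https://example.com/kitchen.jpg",
    "https://example.com/sofa.jpg",
    "https://example.com/loft.jpg",
]


def top_k(query: t.List[float], vectors: t.List[t.List[float]], k: int) -> t.List[int]:
    """Return the positions of the k vectors with the largest inner product."""
    scores = [sum(q * v for q, v in zip(query, vec)) for vec in vectors]
    ranked = sorted(range(len(vectors)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


# A made-up "prompt embedding"; the component would obtain this from CLIP.
query = [0.9, 0.1]
positions = top_k(query, toy_index, k=2)

# The search only yields row positions; the URL list turns them into results.
urls = [toy_urls[i] for i in positions]
```

The point of the sketch is the last step: the index stores vectors only, so the URL mapping dataset must list URLs in the same order as the rows of the index.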


<a id="retrieve_from_faiss_by_prompt#inputs_outputs"></a>
## Inputs / outputs

<a id="retrieve_from_faiss_by_prompt#consumes"></a>
### Consumes
**This component consumes:**

- prompt: string




<a id="retrieve_from_faiss_by_prompt#produces"></a>
### Produces
**This component produces:**

- image_url: string
- prompt: string
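
Each input prompt fans out into `num_images` output rows. The sketch below mirrors how `src/main.py` in this commit assembles those rows, using plain Python and a hypothetical `fake_lookup` helper in place of the CLIP embedding and Faiss search; the URLs are made up for illustration.

```python
import typing as t


def fake_lookup(prompt: str, num_images: int) -> t.List[str]:
    """Hypothetical stand-in for embedding the prompt and searching the index."""
    slug = prompt.replace(" ", "-")
    return [f"https://example.com/{slug}/{n}.jpg" for n in range(num_images)]


def build_rows(
    prompts: t.List[str],
    num_images: int = 2,
) -> t.List[t.Tuple[str, str, str]]:
    """Return (id, prompt, image_url) tuples, one per retrieved image."""
    rows = []
    for prompt in prompts:
        for i, url in enumerate(fake_lookup(prompt, num_images)):
            # The row id combines the prompt text and the rank of the result.
            rows.append((f"{prompt}_{i}", prompt, url))
    return rows


rows = build_rows(["cozy living room"], num_images=2)
```

Because the row id is derived from the prompt text, repeated prompts in the input would yield repeated ids, so deduplicating prompts upstream may be worth considering.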



<a id="retrieve_from_faiss_by_prompt#arguments"></a>
## Arguments

The component takes the following arguments to alter its behavior:

| argument | type | description | default |
| -------- | ---- | ----------- | ------- |
| url_mapping_path | str | URL of the image mapping dataset (read as Parquet; must contain a `url` column) | / |
| faiss_index_path | str | URL of the Faiss index file | / |
| clip_model | str | CLIP model name to use for the retrieval | laion/CLIP-ViT-B-32-laion2B-s34B-b79K |
| num_images | int | Number of images that will be retrieved for each prompt | 2 |

<a id="retrieve_from_faiss_by_prompt#usage"></a>
## Usage

You can add this component to your pipeline using the following code:

```python
from fondant.pipeline import Pipeline


pipeline = Pipeline(...)

dataset = pipeline.read(...)

dataset = dataset.apply(
    "retrieve_from_faiss_by_prompt",
    arguments={
        # Add arguments
        # "url_mapping_path": ,
        # "faiss_index_path": ,
        # "clip_model": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
        # "num_images": 2,
    },
)
```

<a id="retrieve_from_faiss_by_prompt#testing"></a>
## Testing

You can run the tests using docker with BuildKit. From this directory, run:
```
docker build . --target test
```
@@ -0,0 +1,38 @@
name: Retrieve images from FAISS index
description: |
  Retrieve images from a Faiss index. The component should reference a Faiss image dataset,
  which includes both the Faiss index and a dataset of image URLs. The input dataset consists
  of a list of prompts. These prompts will be embedded using a CLIP model, and similar
  images will be retrieved from the index.
image: fndnt/retrieve_from_faiss_by_prompt:dev
tags:
  - Data retrieval

consumes:
  prompt:
    type: string

produces:
  image_url:
    type: string
  prompt:
    type: string

previous_index: prompt_id

args:
  url_mapping_path:
    description: URL of the image mapping dataset
    type: str
  faiss_index_path:
    description: URL of the Faiss index file
    type: str
  clip_model:
    description: CLIP model name to use for the retrieval
    type: str
    default: laion/CLIP-ViT-B-32-laion2B-s34B-b79K
  num_images:
    description: Number of images that will be retrieved for each prompt
    type: int
    default: 2
@@ -0,0 +1,4 @@
transformers==4.38.1
torch==2.2.1
faiss-cpu==1.7.4
huggingface_hub==0.21.3
Empty file.
106 changes: 106 additions & 0 deletions src/fondant/components/retrieve_from_faiss_by_prompt/src/main.py
@@ -0,0 +1,106 @@
import logging
import os
import typing as t

import dask.dataframe as dd
import faiss
import fsspec
import pandas as pd
import torch
from dask.distributed import Client, get_worker
from dask_cuda import LocalCUDACluster
from fondant.component import PandasTransformComponent
from transformers import AutoTokenizer, CLIPTextModelWithProjection

logger = logging.getLogger(__name__)


class RetrieveFromFaissByPrompt(PandasTransformComponent):
    """Retrieve images from a faiss index using CLIP embeddings."""

    def __init__(  # PLR0913
        self,
        url_mapping_path: str,
        faiss_index_path: str,
        clip_model: str = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
        num_images: int = 2,
    ):
        self.model_id = clip_model
        self.number_of_images = num_images

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Download faiss index to local machine
        if not os.path.exists("faiss_index"):
            logger.info(f"Downloading faiss index from {faiss_index_path}")
            with fsspec.open(faiss_index_path, "rb") as f:
                file_contents = f.read()

            with open("faiss_index", "wb") as out:
                out.write(file_contents)

        dataset = dd.read_parquet(url_mapping_path)
        if "url" not in dataset.columns:
            msg = "Dataset does not contain column 'url'"
            raise ValueError(msg)
        self.image_urls = dataset["url"].compute().to_list()

    def setup(self) -> Client:
        """Setup LocalCudaCluster if gpu is available."""
        if self.device == "cuda":
            cluster = LocalCUDACluster()
            return Client(cluster)

        return super().setup()

    def embed_prompt(self, prompt: str):
        """Embed prompt using CLIP model."""
        worker = get_worker()
        if worker and hasattr(worker, "model"):
            tokenizer = worker.tokenizer
            model = worker.model

        else:
            logger.info("Initializing model '%s' on worker '%s'", self.model_id, worker)
            tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            model = CLIPTextModelWithProjection.from_pretrained(self.model_id).to(
                self.device,
            )

            worker.model = model
            worker.tokenizer = tokenizer

        inputs = tokenizer([prompt], padding=True, return_tensors="pt")
        inputs = inputs.to(self.device)
        outputs = model(**inputs)
        return outputs.text_embeds.cpu().detach().numpy().astype("float64")

    def retrieve_from_index(
        self,
        query: float,
        number_of_images: int = 2,
    ) -> t.List[str]:
        """Retrieve images from faiss index."""
        search_index = faiss.read_index("faiss_index")
        _, indices = search_index.search(query, number_of_images)
        return indices.tolist()[0]

    def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        """Transform partition of dataframe."""
        results = []
        prompts = dataframe["prompt"]

        for prompt in prompts:
            query = self.embed_prompt(prompt)
            indices = self.retrieve_from_index(query, self.number_of_images)
            for i, idx in enumerate(indices):
                url = self.image_urls[idx]
                row_to_add = (f"{prompt}_{i}", prompt, url)
                results.append(row_to_add)

        results_df = pd.DataFrame(
            results,
            columns=["id", "prompt", "image_url"],
        )
        results_df = results_df.set_index("id")
        return results_df
@@ -0,0 +1,2 @@
[pytest]
pythonpath = ../src
@@ -0,0 +1 @@
pytest==7.4.2
@@ -0,0 +1,28 @@
import pandas as pd

from src.main import RetrieveFromFaissByPrompt


def test_component():
    input_dataframe = pd.DataFrame.from_dict(
        {
            "id": [1, 2],
            "prompt": ["country style kitchen", "cozy living room"],
        },
    )

    input_dataframe = input_dataframe.set_index("id")
    input_dataframe["prompt"] = input_dataframe["prompt"].astype(str)

    # Run component
    component = RetrieveFromFaissByPrompt(
        url_mapping_path="gs://soy-audio-379412-embed-datacomp/12M/id_mapping",
        faiss_index_path="gs://soy-audio-379412-embed-datacomp/12M/faiss",
    )

    component.setup()
    output_dataframe = component.transform(input_dataframe)
    assert output_dataframe.columns.tolist() == [
        "prompt",
        "image_url",
    ]
