Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add image retrieval from FAISS index #876

Merged
merged 18 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/fondant/components/retrieve_from_faiss_by_prompt/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime as base

# System dependencies
RUN apt-get update && \
apt-get upgrade -y && \
apt-get install git -y

# Install requirements
COPY requirements.txt ./
RUN pip3 install --no-cache-dir -r requirements.txt

# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
RUN pip3 install fondant[component,aws,azure,gcp,gpu]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}

# Set the working directory to the component folder
WORKDIR /component
COPY src/ src/

FROM base as test
COPY tests/ tests/
RUN pip3 install --no-cache-dir -r tests/requirements.txt
RUN python -m pytest tests

FROM base
WORKDIR /component/src
ENTRYPOINT ["fondant", "execute", "main"]
75 changes: 75 additions & 0 deletions src/fondant/components/retrieve_from_faiss_by_prompt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Retrieve images from FAISS index

<a id="retrieve_from_faiss_by_prompt#description"></a>
## Description
Retrieve images from a Faiss index. The component should reference a Faiss image dataset,
which includes both the Faiss index and a dataset of image URLs. The input dataset consists
of a list of prompts. These prompts will be embedded using a CLIP model, and similar
images will be retrieved from the index.


<a id="retrieve_from_faiss_by_prompt#inputs_outputs"></a>
## Inputs / outputs

<a id="retrieve_from_faiss_by_prompt#consumes"></a>
### Consumes
**This component consumes:**

- prompt: string




<a id="retrieve_from_faiss_by_prompt#produces"></a>
### Produces
**This component produces:**

- image_url: string
- prompt: string



<a id="retrieve_from_faiss_by_prompt#arguments"></a>
## Arguments

The component takes the following arguments to alter its behavior:

| argument | type | description | default |
| -------- | ---- | ----------- | ------- |
| url_mapping_path | str | Url of the image mapping dataset | / |
| faiss_index_path | str | Url of the dataset | / |
| clip_model | str | Clip model name to use for the retrieval | laion/CLIP-ViT-B-32-laion2B-s34B-b79K |
| num_images | int | Number of images that will be retrieved for each prompt | 2 |

<a id="retrieve_from_faiss_by_prompt#usage"></a>
## Usage

You can add this component to your pipeline using the following code:

```python
from fondant.pipeline import Pipeline


pipeline = Pipeline(...)

dataset = pipeline.read(...)

dataset = dataset.apply(
"retrieve_from_faiss_by_prompt",
arguments={
# Add arguments
# "url_mapping_path": ,
# "faiss_index_path": ,
# "clip_model": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
# "num_images": 2,
},
)
```

<a id="retrieve_from_faiss_by_prompt#testing"></a>
## Testing

You can run the tests using docker with BuildKit. From this directory, run:
```
docker build . --target test
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Retrieve images from FAISS index
description: |
Retrieve images from a Faiss index. The component should reference a Faiss image dataset,
which includes both the Faiss index and a dataset of image URLs. The input dataset consists
of a list of prompts. These prompts will be embedded using a CLIP model, and similar
images will be retrieved from the index.

image: fndnt/retrieve_from_faiss_by_prompt:dev
tags:
- Data retrieval

consumes:
prompt:
type: string

produces:
image_url:
type: string
prompt:
type: string

previous_index: prompt_id

args:
url_mapping_path:
description: Url of the image mapping dataset
type: str
faiss_index_path:
description: Url of the dataset
type: str
clip_model:
description: Clip model name to use for the retrieval
type: str
default: laion/CLIP-ViT-B-32-laion2B-s34B-b79K
num_images:
description: Number of images that will be retrieved for each prompt
type: int
default: 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
transformers==4.38.1
torch==2.2.1
faiss-cpu==1.7.4
huggingface_hub==0.21.3
Empty file.
106 changes: 106 additions & 0 deletions src/fondant/components/retrieve_from_faiss_by_prompt/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import logging
import os
import typing as t

import dask.dataframe as dd
import faiss
import fsspec
import pandas as pd
import torch
from dask.distributed import Client, get_worker
from dask_cuda import LocalCUDACluster
from fondant.component import PandasTransformComponent
from transformers import AutoTokenizer, CLIPTextModelWithProjection

logger = logging.getLogger(__name__)


class RetrieveFromFaissByPrompt(PandasTransformComponent):
"""Retrieve images from a faiss index using CLIP embeddings."""

def __init__( # PLR0913
self,
url_mapping_path: str,
faiss_index_path: str,
clip_model: str = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
num_images: int = 2,
):
self.model_id = clip_model
self.number_of_images = num_images

self.device = "cuda" if torch.cuda.is_available() else "cpu"

# Download faiss index to local machine
if not os.path.exists("faiss_index"):
logger.info(f"Downloading faiss index from {faiss_index_path}")
with fsspec.open(faiss_index_path, "rb") as f:
file_contents = f.read()

with open("faiss_index", "wb") as out:
out.write(file_contents)

dataset = dd.read_parquet(url_mapping_path)
if "url" not in dataset.columns:
msg = "Dataset does not contain column 'url'"
raise ValueError(msg)
self.image_urls = dataset["url"].compute().to_list()

def setup(self) -> Client:
"""Setup LocalCudaCluster if gpu is available."""
if self.device == "cuda":
cluster = LocalCUDACluster()
return Client(cluster)

return super().setup()

def embed_prompt(self, prompt: str):
"""Embed prompt using CLIP model."""
worker = get_worker()
if worker and hasattr(worker, "model"):
tokenizer = worker.tokenizer
model = worker.model

else:
logger.info("Initializing model '%s' on worker '%s", self.model_id, worker)
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
model = CLIPTextModelWithProjection.from_pretrained(self.model_id).to(
self.device,
)

worker.model = model
worker.tokenizer = tokenizer

inputs = tokenizer([prompt], padding=True, return_tensors="pt")
inputs = inputs.to(self.device)
outputs = model(**inputs)
return outputs.text_embeds.cpu().detach().numpy().astype("float64")

def retrieve_from_index(
self,
query: float,
number_of_images: int = 2,
) -> t.List[str]:
"""Retrieve images from faiss index."""
search_index = faiss.read_index("faiss_index")
_, indices = search_index.search(query, number_of_images)
return indices.tolist()[0]

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
"""Transform partition of dataframe."""
results = []
prompts = dataframe["prompt"]

for prompt in prompts:
query = self.embed_prompt(prompt)
indices = self.retrieve_from_index(query, self.number_of_images)
for i, idx in enumerate(indices):
url = self.image_urls[idx]
row_to_add = (f"{prompt}_{i}", prompt, url)
results.append(row_to_add)

results_df = pd.DataFrame(
results,
columns=["id", "prompt", "image_url"],
)
results_df = results_df.set_index("id")
return results_df
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
pythonpath = ../src
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest==7.4.2
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pandas as pd

from src.main import RetrieveFromFaissByPrompt


def test_component():
input_dataframe = pd.DataFrame.from_dict(
{
"id": [1, 2],
"prompt": ["country style kitchen", "cozy living room"],
},
)

input_dataframe = input_dataframe.set_index("id")
input_dataframe["prompt"] = input_dataframe["prompt"].astype(str)

# Run component
component = RetrieveFromFaissByPrompt(
url_mapping_path="gs://soy-audio-379412-embed-datacomp/12M/id_mapping",
faiss_index_path="gs://soy-audio-379412-embed-datacomp/12M/faiss",
)

component.setup()
output_dataframe = component.transform(input_dataframe)
assert output_dataframe.columns.tolist() == [
"prompt",
"image_url",
]
Loading