diff --git a/src/fondant/components/retrieve_from_faiss_by_prompt/Dockerfile b/src/fondant/components/retrieve_from_faiss_by_prompt/Dockerfile
new file mode 100644
index 00000000..2e572a09
--- /dev/null
+++ b/src/fondant/components/retrieve_from_faiss_by_prompt/Dockerfile
@@ -0,0 +1,28 @@
+FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime as base
+
+# System dependencies
+RUN apt-get update && \
+    apt-get upgrade -y && \
+    apt-get install git -y
+
+# Install requirements
+COPY requirements.txt ./
+RUN pip3 install --no-cache-dir -r requirements.txt
+
+# Install Fondant
+# This is split from other requirements to leverage caching
+ARG FONDANT_VERSION=main
+RUN pip3 install fondant[component,aws,azure,gcp,gpu]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}
+
+# Set the working directory to the component folder
+WORKDIR /component
+COPY src/ src/
+
+FROM base as test
+COPY tests/ tests/
+RUN pip3 install --no-cache-dir -r tests/requirements.txt
+RUN python -m pytest tests
+
+FROM base
+WORKDIR /component/src
+ENTRYPOINT ["fondant", "execute", "main"]
diff --git a/src/fondant/components/retrieve_from_faiss_by_prompt/README.md b/src/fondant/components/retrieve_from_faiss_by_prompt/README.md
new file mode 100644
index 00000000..d2348fb0
--- /dev/null
+++ b/src/fondant/components/retrieve_from_faiss_by_prompt/README.md
@@ -0,0 +1,75 @@
+# Retrieve images from FAISS index
+
+
+## Description
+Retrieve images from a Faiss index. The component should reference a Faiss image dataset,
+ which includes both the Faiss index and a dataset of image URLs. The input dataset consists
+ of a list of prompts. These prompts will be embedded using a CLIP model, and similar
+ images will be retrieved from the index.
+
+
+
+## Inputs / outputs
+
+
+### Consumes
+**This component consumes:**
+
+- prompt: string
+
+
+
+
+
+### Produces
+**This component produces:**
+
+- image_url: string
+- prompt: string
+
+
+
+
+## Arguments
+
+The component takes the following arguments to alter its behavior:
+
+| argument | type | description | default |
+| -------- | ---- | ----------- | ------- |
+| url_mapping_path | str | Url of the image mapping dataset | / |
+| faiss_index_path | str | Url of the Faiss index | / |
+| clip_model | str | Clip model name to use for the retrieval | laion/CLIP-ViT-B-32-laion2B-s34B-b79K |
+| num_images | int | Number of images that will be retrieved for each prompt | 2 |
+
+
+## Usage
+
+You can add this component to your pipeline using the following code:
+
+```python
+from fondant.pipeline import Pipeline
+
+
+pipeline = Pipeline(...)
+
+dataset = pipeline.read(...)
+
+dataset = dataset.apply(
+    "retrieve_from_faiss_by_prompt",
+    arguments={
+        # Add arguments
+        # "url_mapping_path": ,
+        # "faiss_index_path": ,
+        # "clip_model": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
+        # "num_images": 2,
+    },
+)
+```
+
+
+## Testing
+
+You can run the tests using docker with BuildKit. From this directory, run:
+```
+docker build . --target test
+```
diff --git a/src/fondant/components/retrieve_from_faiss_by_prompt/fondant_component.yaml b/src/fondant/components/retrieve_from_faiss_by_prompt/fondant_component.yaml
new file mode 100644
index 00000000..b4107736
--- /dev/null
+++ b/src/fondant/components/retrieve_from_faiss_by_prompt/fondant_component.yaml
@@ -0,0 +1,38 @@
+name: Retrieve images from FAISS index
+description: |
+  Retrieve images from a Faiss index. The component should reference a Faiss image dataset,
+   which includes both the Faiss index and a dataset of image URLs. The input dataset consists
+   of a list of prompts. These prompts will be embedded using a CLIP model, and similar
+   images will be retrieved from the index.
+
+image: fndnt/retrieve_from_faiss_by_prompt:dev
+tags:
+  - Data retrieval
+
+consumes:
+  prompt:
+    type: string
+
+produces:
+  image_url:
+    type: string
+  prompt:
+    type: string
+
+previous_index: prompt_id
+
+args:
+  url_mapping_path:
+    description: Url of the image mapping dataset
+    type: str
+  faiss_index_path:
+    description: Url of the Faiss index
+    type: str
+  clip_model:
+    description: Clip model name to use for the retrieval
+    type: str
+    default: laion/CLIP-ViT-B-32-laion2B-s34B-b79K
+  num_images:
+    description: Number of images that will be retrieved for each prompt
+    type: int
+    default: 2
diff --git a/src/fondant/components/retrieve_from_faiss_by_prompt/requirements.txt b/src/fondant/components/retrieve_from_faiss_by_prompt/requirements.txt
new file mode 100644
index 00000000..ad72cf3a
--- /dev/null
+++ b/src/fondant/components/retrieve_from_faiss_by_prompt/requirements.txt
@@ -0,0 +1,4 @@
+transformers==4.38.1
+torch==2.2.1
+faiss-cpu==1.7.4
+huggingface_hub==0.21.3
\ No newline at end of file
diff --git a/src/fondant/components/retrieve_from_faiss_by_prompt/src/__init__.py b/src/fondant/components/retrieve_from_faiss_by_prompt/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/fondant/components/retrieve_from_faiss_by_prompt/src/main.py b/src/fondant/components/retrieve_from_faiss_by_prompt/src/main.py
new file mode 100644
index 00000000..44612d67
--- /dev/null
+++ b/src/fondant/components/retrieve_from_faiss_by_prompt/src/main.py
@@ -0,0 +1,152 @@
+import logging
+import os
+import shutil
+import typing as t
+
+import dask.dataframe as dd
+import faiss
+import fsspec
+import numpy as np
+import pandas as pd
+import torch
+from dask.distributed import Client, get_worker
+from dask_cuda import LocalCUDACluster
+from fondant.component import PandasTransformComponent
+from transformers import AutoTokenizer, CLIPTextModelWithProjection
+
+logger = logging.getLogger(__name__)
+
+# Local path the faiss index is downloaded to and read from
+FAISS_INDEX_PATH = "faiss_index"
+
+
+class RetrieveFromFaissByPrompt(PandasTransformComponent):
+    """Retrieve image urls from a faiss index using CLIP text embeddings."""
+
+    def __init__(  # PLR0913
+        self,
+        url_mapping_path: str,
+        faiss_index_path: str,
+        clip_model: str = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
+        num_images: int = 2,
+    ):
+        """
+        Args:
+            url_mapping_path: Url of a parquet dataset with a `url` column. The position
+                of a url in that column is used to look it up from a faiss search result.
+            faiss_index_path: Url of the faiss index file.
+            clip_model: Name of the CLIP model used to embed the prompts.
+            num_images: Number of images to retrieve for each prompt.
+
+        Raises:
+            ValueError: If the url mapping dataset has no `url` column.
+        """
+        self.model_id = clip_model
+        self.number_of_images = num_images
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # Download faiss index to local machine. The copy is streamed to a temporary name
+        # so the index is never held in memory and a failed download is not reused.
+        if not os.path.exists(FAISS_INDEX_PATH):
+            logger.info("Downloading faiss index from %s", faiss_index_path)
+            tmp_path = f"{FAISS_INDEX_PATH}.part"
+            with fsspec.open(faiss_index_path, "rb") as f, open(tmp_path, "wb") as out:
+                shutil.copyfileobj(f, out)
+            os.replace(tmp_path, FAISS_INDEX_PATH)
+
+        dataset = dd.read_parquet(url_mapping_path)
+        if "url" not in dataset.columns:
+            msg = "Dataset does not contain column 'url'"
+            raise ValueError(msg)
+        self.image_urls = dataset["url"].compute().to_list()
+
+    def setup(self) -> Client:
+        """Setup LocalCudaCluster if gpu is available."""
+        if self.device == "cuda":
+            cluster = LocalCUDACluster()
+            return Client(cluster)
+
+        return super().setup()
+
+    def _cache_owner(self) -> t.Any:
+        """Return the object on which the model and the faiss index are cached.
+
+        Inside a dask worker this is the worker, so both are loaded once per worker.
+        `get_worker` raises a `ValueError` outside of a worker (e.g. when `transform` is
+        called directly), in which case the component itself is used.
+        """
+        try:
+            return get_worker()
+        except ValueError:
+            return self
+
+    def embed_prompt(self, prompt: str) -> np.ndarray:
+        """Embed prompt using CLIP model.
+
+        Returns:
+            A float32 array holding the text embedding of the prompt as its only row.
+        """
+        owner = self._cache_owner()
+        if not hasattr(owner, "model"):
+            logger.info("Initializing model '%s' on worker '%s'", self.model_id, owner)
+            # The tokenizer is set first, so a visible model implies a visible tokenizer
+            owner.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+            owner.model = CLIPTextModelWithProjection.from_pretrained(self.model_id).to(
+                self.device,
+            )
+
+        inputs = owner.tokenizer([prompt], padding=True, return_tensors="pt")
+        inputs = inputs.to(self.device)
+        # Inference only, so there is no need to track gradients
+        with torch.no_grad():
+            outputs = owner.model(**inputs)
+        # Faiss works with float32 vectors
+        return outputs.text_embeds.cpu().numpy().astype("float32")
+
+    def retrieve_from_index(
+        self,
+        query: np.ndarray,
+        number_of_images: int = 2,
+    ) -> t.List[int]:
+        """Retrieve the positions of the nearest neighbours of a query from the faiss index.
+
+        Args:
+            query: Embedding to search for, as returned by `embed_prompt`.
+            number_of_images: Number of neighbours to retrieve.
+
+        Returns:
+            The positions found for the first row of the query. Faiss pads them with -1
+            when it finds fewer neighbours than requested.
+        """
+        owner = self._cache_owner()
+        if not hasattr(owner, "faiss_index"):
+            # Read the index once instead of for every prompt
+            owner.faiss_index = faiss.read_index(FAISS_INDEX_PATH)
+
+        _, indices = owner.faiss_index.search(query, number_of_images)
+        return indices.tolist()[0]
+
+    def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
+        """Retrieve image urls for every prompt in a partition of the dataframe.
+
+        Returns:
+            A dataframe indexed by `id` with the columns `prompt` and `image_url`, holding
+            one row per retrieved image.
+        """
+        results = []
+        for prompt in dataframe["prompt"]:
+            query = self.embed_prompt(prompt)
+            indices = self.retrieve_from_index(query, self.number_of_images)
+            for i, idx in enumerate(indices):
+                # Skip the -1 padding, which would otherwise select the last url
+                if idx < 0:
+                    continue
+                # NOTE: ids are derived from the prompt, so duplicate prompts share ids
+                results.append((f"{prompt}_{i}", prompt, self.image_urls[idx]))
+
+        results_df = pd.DataFrame(
+            results,
+            columns=["id", "prompt", "image_url"],
+        )
+        return results_df.set_index("id")
diff --git a/src/fondant/components/retrieve_from_faiss_by_prompt/tests/pytest.ini b/src/fondant/components/retrieve_from_faiss_by_prompt/tests/pytest.ini
new file mode 100644
index 00000000..bf6a8a51
--- /dev/null
+++ b/src/fondant/components/retrieve_from_faiss_by_prompt/tests/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+pythonpath = ../src
\ No newline at end of file
diff --git a/src/fondant/components/retrieve_from_faiss_by_prompt/tests/requirements.txt b/src/fondant/components/retrieve_from_faiss_by_prompt/tests/requirements.txt
new file mode 100644
index 00000000..2a929edc
--- /dev/null
+++ b/src/fondant/components/retrieve_from_faiss_by_prompt/tests/requirements.txt
@@ -0,0 +1 @@
+pytest==7.4.2
diff --git a/src/fondant/components/retrieve_from_faiss_by_prompt/tests/test_component.py b/src/fondant/components/retrieve_from_faiss_by_prompt/tests/test_component.py
new file mode 100644
index 00000000..2d0d4816
--- /dev/null
+++ b/src/fondant/components/retrieve_from_faiss_by_prompt/tests/test_component.py
@@ -0,0 +1,28 @@
+import pandas as pd
+
+from src.main import RetrieveFromFaissByPrompt
+
+
+def test_component():
+    input_dataframe = pd.DataFrame.from_dict(
+        {
+            "id": [1, 2],
+            "prompt": ["country style kitchen", "cozy living room"],
+        },
+    )
+
+    input_dataframe = input_dataframe.set_index("id")
+    input_dataframe["prompt"] = input_dataframe["prompt"].astype(str)
+
+    # Run component
+    component = RetrieveFromFaissByPrompt(
+        url_mapping_path="gs://soy-audio-379412-embed-datacomp/12M/id_mapping",
+        faiss_index_path="gs://soy-audio-379412-embed-datacomp/12M/faiss",
+    )
+
+    component.setup()
+    output_dataframe = component.transform(input_dataframe)
+    assert output_dataframe.columns.tolist() == [
+        "prompt",
+        "image_url",
+    ]