Skip to content

Commit

Permalink
Implement previous_index field (#668)
Browse files Browse the repository at this point in the history
#656 

We might want to validate this by checking that the field mentioned in
`previous_index` is also defined in the `consumes` section.
  • Loading branch information
RobbeSneyders committed Nov 27, 2023
1 parent 1c6cb6d commit be8c67b
Show file tree
Hide file tree
Showing 23 changed files with 224 additions and 75 deletions.
1 change: 0 additions & 1 deletion components/download_images/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ produces:
type: int32
images_height:
type: int32
# additionalFields: false

args:
timeout:
Expand Down
15 changes: 11 additions & 4 deletions components/embedding_based_laion_retrieval/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM --platform=linux/amd64 python:3.8-slim
FROM --platform=linux/amd64 python:3.8-slim as base

# System dependencies
RUN apt-get update && \
Expand All @@ -16,8 +16,15 @@ RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team

# Set the working directory to the component folder
WORKDIR /component/src
COPY src/ src/
ENV PYTHONPATH "${PYTHONPATH}:./src"

# Copy over src-files
COPY src/ .
FROM base as test
COPY test_requirements.txt .
RUN pip3 install --no-cache-dir -r test_requirements.txt
COPY tests/ tests/
RUN python -m pytest tests

ENTRYPOINT ["fondant", "execute", "main"]
FROM base
WORKDIR /component/src
ENTRYPOINT ["fondant", "execute", "main"]
7 changes: 7 additions & 0 deletions components/embedding_based_laion_retrieval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ used to find images similar to the embedded images / captions.
**This component produces:**

- images_url: string
- embedding_id: string

### Arguments

Expand Down Expand Up @@ -45,3 +46,9 @@ embedding_based_laion_retrieval_op = ComponentOp.from_registry(
pipeline.add_op(embedding_based_laion_retrieval_op, dependencies=[...]) #Add previous component as dependency
```

### Testing

You can run the tests using Docker with BuildKit. From this directory, run:
```
docker build . --target test
```
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ consumes:
produces:
images_url:
type: string
# additionalFields: false
embedding_id:
type: string

previous_index: embedding_id

args:
num_images:
Expand Down
20 changes: 10 additions & 10 deletions components/embedding_based_laion_retrieval/src/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""This component retrieves image URLs from LAION-5B based on a set of CLIP embeddings."""
import asyncio
import concurrent.futures
import functools
import logging
import typing as t

Expand Down Expand Up @@ -40,6 +39,10 @@ def __init__(
modality=Modality.IMAGE,
)

def query(self, id_: t.Any, embedding: t.List[float]) -> t.List[t.Dict]:
    """Query the client with one embedding and tag every hit with its source id.

    Args:
        id_: Identifier of the input row the embedding belongs to.
        embedding: Embedding forwarded to the client as `embedding_input`.

    Returns:
        One dict per client result, each extended with an `embedding_id`
        key holding `id_`.
    """
    tagged_hits = []
    for hit in self.client.query(embedding_input=embedding):
        # Build a new dict so the objects returned by the client are not mutated.
        tagged_hits.append(dict(hit, embedding_id=id_))
    return tagged_hits

def transform(
self,
dataframe: pd.DataFrame,
Expand All @@ -53,23 +56,20 @@ async def async_query():
futures = [
loop.run_in_executor(
executor,
functools.partial(
self.client.query,
embedding_input=embedding.tolist(),
),
self.query,
row.id,
row.embeddings_data.tolist(),
)
for embedding in dataframe["embeddings_data"]
for row in dataframe.itertuples()
]
for response in await asyncio.gather(*futures):
results.extend(response)

loop.run_until_complete(async_query())

results_df = pd.DataFrame(results)["id", "url"]
results_df = pd.DataFrame(results)[["id", "url", "embedding_id"]]
results_df = results_df.set_index("id")

# Cast the index to string
results_df.index = results_df.index.astype(str)
results_df.columns = ["images_url"]
results_df.rename(columns={"url": "images_url"})

return results_df
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest==7.4.2
2 changes: 2 additions & 0 deletions components/embedding_based_laion_retrieval/tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
pythonpath = ../src
66 changes: 66 additions & 0 deletions components/embedding_based_laion_retrieval/tests/test_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import typing as t

import numpy as np
import pandas as pd

from src.main import LAIONRetrievalComponent


def test_component(monkeypatch):
    """Check that `transform` tags every retrieved hit with its source row id.

    The client's `query` is replaced by a canned lookup, so the expected frame
    is fully determined by the two input embeddings.
    """
    # NOTE(review): the expected frame keeps the raw `url` column name, while the
    # component spec declares `images_url` as the produced field -- confirm that
    # this is the intended output of `transform`.

    def fake_query(embedding_input: t.List[float]) -> t.List[dict]:
        """Return two canned hits per known embedding; reject anything else."""
        if embedding_input == [1, 2]:
            return [
                {"id": "a", "url": "http://a"},
                {"id": "b", "url": "http://b"},
            ]
        if embedding_input == [2, 3]:
            return [
                {"id": "c", "url": "http://c"},
                {"id": "d", "url": "http://d"},
            ]
        msg = f"Unexpected value: `embeddings_input` was {embedding_input}"
        raise ValueError(msg)

    source_df = pd.DataFrame.from_dict(
        {
            "id": ["1", "2"],
            "embeddings_data": [np.array([1, 2]), np.array([2, 3])],
        },
    )

    # Each input row yields two hits; `embedding_id` points back to that row.
    expected_df = pd.DataFrame.from_dict(
        {
            "id": ["a", "b", "c", "d"],
            "url": ["http://a", "http://b", "http://c", "http://d"],
            "embedding_id": ["1", "1", "2", "2"],
        },
    ).set_index("id")

    component = LAIONRetrievalComponent(
        num_images=2,
        aesthetic_score=9,
        aesthetic_weight=0.5,
    )
    monkeypatch.setattr(component.client, "query", fake_query)

    actual_df = component.transform(source_df)

    pd.testing.assert_frame_equal(
        left=expected_df,
        right=actual_df,
        check_dtype=False,
    )
14 changes: 6 additions & 8 deletions components/index_qdrant/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ image: 'fndnt/index_qdrant:dev'
tags:
- Data writing
consumes:
text:
fields:
data:
type: string
embedding:
type: array
items:
type: float32
text_data:
type: string
embeddings_data:
type: array
items:
type: float32
args:
collection_name:
description: The name of the Qdrant collection to upsert data into.
Expand Down
15 changes: 11 additions & 4 deletions components/prompt_based_laion_retrieval/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM --platform=linux/amd64 python:3.8-slim
FROM --platform=linux/amd64 python:3.8-slim as base

# System dependencies
RUN apt-get update && \
Expand All @@ -16,8 +16,15 @@ RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team

# Set the working directory to the component folder
WORKDIR /component/src
COPY src/ src/
ENV PYTHONPATH "${PYTHONPATH}:./src"

# Copy over src-files
COPY src/ .
FROM base as test
COPY test_requirements.txt .
RUN pip3 install --no-cache-dir -r test_requirements.txt
COPY tests/ tests/
RUN python -m pytest tests

ENTRYPOINT ["fondant", "execute", "main"]
FROM base
WORKDIR /component/src
ENTRYPOINT ["fondant", "execute", "main"]
7 changes: 7 additions & 0 deletions components/prompt_based_laion_retrieval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ This component doesn’t return the actual images, only URLs.
**This component produces:**

- images_url: string
- prompt_id: string

### Arguments

Expand Down Expand Up @@ -50,3 +51,9 @@ prompt_based_laion_retrieval_op = ComponentOp.from_registry(
pipeline.add_op(prompt_based_laion_retrieval_op, dependencies=[...]) #Add previous component as dependency
```

### Testing

You can run the tests using Docker with BuildKit. From this directory, run:
```
docker build . --target test
```
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ consumes:
produces:
images_url:
type: string
# additionalFields: false
prompt_id:
type: string

previous_index: prompt_id

args:
num_images:
Expand Down
17 changes: 10 additions & 7 deletions components/prompt_based_laion_retrieval/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __init__(
modality=Modality.IMAGE,
)

def query(self, id_: t.Any, prompt: str) -> t.List[t.Dict]:
    """Query the client with one prompt and tag every hit with its source id.

    Args:
        id_: Identifier of the input row the prompt belongs to.
        prompt: Prompt forwarded to the client as `text`.

    Returns:
        One dict per client result, each extended with a `prompt_id` key
        holding `id_`.
    """
    tagged_hits = []
    for hit in self.client.query(text=prompt):
        # Build a new dict so the objects returned by the client are not mutated.
        tagged_hits.append(dict(hit, prompt_id=id_))
    return tagged_hits

def transform(
self,
dataframe: pd.DataFrame,
Expand All @@ -53,21 +57,20 @@ async def async_query():
futures = [
loop.run_in_executor(
executor,
self.client.query,
prompt,
self.query,
row.id,
row.prompts_text,
)
for prompt in dataframe["prompts_text"]
for row in dataframe.itertuples()
]
for response in await asyncio.gather(*futures):
results.extend(response)

loop.run_until_complete(async_query())

results_df = pd.DataFrame(results)["id", "url"]
results_df = pd.DataFrame(results)[["id", "url", "prompt_id"]]
results_df = results_df.set_index("id")

# Cast the index to string
results_df.index = results_df.index.astype(str)
results_df.columns = ["images_url"]
results_df.rename(columns={"url": "images_url"})

return results_df
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest==7.4.2
2 changes: 2 additions & 0 deletions components/prompt_based_laion_retrieval/tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
pythonpath = ../src
66 changes: 66 additions & 0 deletions components/prompt_based_laion_retrieval/tests/test_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import typing as t

import pandas as pd

from src.main import LAIONRetrievalComponent


def test_component(monkeypatch):
    """Check that `transform` tags every retrieved hit with its source row id.

    The client's `query` is replaced by a canned lookup, so the expected frame
    is fully determined by the two input prompts.
    """
    # NOTE(review): the expected frame keeps the raw `url` column name, while the
    # component spec declares `images_url` as the produced field -- confirm that
    # this is the intended output of `transform`.

    def fake_query(text: str) -> t.List[dict]:
        """Return two canned hits per known prompt; reject anything else."""
        if text == "first prompt":
            return [
                {"id": "a", "url": "http://a"},
                {"id": "b", "url": "http://b"},
            ]
        if text == "second prompt":
            return [
                {"id": "c", "url": "http://c"},
                {"id": "d", "url": "http://d"},
            ]
        msg = f"Unexpected value: `text` was {text}"
        raise ValueError(msg)

    source_df = pd.DataFrame.from_dict(
        {
            "id": ["1", "2"],
            "prompts_text": ["first prompt", "second prompt"],
        },
    )

    # Each input row yields two hits; `prompt_id` points back to that row.
    expected_df = pd.DataFrame.from_dict(
        {
            "id": ["a", "b", "c", "d"],
            "url": ["http://a", "http://b", "http://c", "http://d"],
            "prompt_id": ["1", "1", "2", "2"],
        },
    ).set_index("id")

    component = LAIONRetrievalComponent(
        num_images=2,
        aesthetic_score=9,
        aesthetic_weight=0.5,
        url="",
    )
    monkeypatch.setattr(component.client, "query", fake_query)

    actual_df = component.transform(source_df)

    pd.testing.assert_frame_equal(
        left=expected_df,
        right=actual_df,
        check_dtype=False,
    )
Loading

0 comments on commit be8c67b

Please sign in to comment.