From be8c67be8cb150023b9da868bb3f1683d67b40ec Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Mon, 27 Nov 2023 10:21:30 +0100 Subject: [PATCH] Implement `previous_index` field (#668) #656 We might want to validate this by checking that the field mentioned in `previous_index` is also defined in the `consumes` section. --- .../download_images/fondant_component.yaml | 1 - .../Dockerfile | 15 +++-- .../embedding_based_laion_retrieval/README.md | 7 ++ .../fondant_component.yaml | 5 +- .../src/main.py | 20 +++--- .../test_requirements.txt | 1 + .../tests/pytest.ini | 2 + .../tests/test_component.py | 66 +++++++++++++++++++ .../index_qdrant/fondant_component.yaml | 14 ++-- .../prompt_based_laion_retrieval/Dockerfile | 15 +++-- .../prompt_based_laion_retrieval/README.md | 7 ++ .../fondant_component.yaml | 5 +- .../prompt_based_laion_retrieval/src/main.py | 17 +++-- .../test_requirements.txt | 1 + .../tests/pytest.ini | 2 + .../tests/test_component.py | 66 +++++++++++++++++++ src/fondant/component/executor.py | 19 +----- src/fondant/core/component_spec.py | 4 ++ src/fondant/core/manifest.py | 5 +- src/fondant/core/schemas/component_spec.json | 3 + .../examples/component_specs/component.yaml | 2 - .../evolution_examples/2/component.yaml | 6 +- .../evolution_examples/2/output_manifest.json | 16 ----- 23 files changed, 224 insertions(+), 75 deletions(-) create mode 100644 components/embedding_based_laion_retrieval/test_requirements.txt create mode 100644 components/embedding_based_laion_retrieval/tests/pytest.ini create mode 100644 components/embedding_based_laion_retrieval/tests/test_component.py create mode 100644 components/prompt_based_laion_retrieval/test_requirements.txt create mode 100644 components/prompt_based_laion_retrieval/tests/pytest.ini create mode 100644 components/prompt_based_laion_retrieval/tests/test_component.py diff --git a/components/download_images/fondant_component.yaml b/components/download_images/fondant_component.yaml index abe19c653..91efeca15 
100644 --- a/components/download_images/fondant_component.yaml +++ b/components/download_images/fondant_component.yaml @@ -23,7 +23,6 @@ produces: type: int32 images_height: type: int32 -# additionalFields: false args: timeout: diff --git a/components/embedding_based_laion_retrieval/Dockerfile b/components/embedding_based_laion_retrieval/Dockerfile index 72525d884..0cdcde81a 100644 --- a/components/embedding_based_laion_retrieval/Dockerfile +++ b/components/embedding_based_laion_retrieval/Dockerfile @@ -1,4 +1,4 @@ -FROM --platform=linux/amd64 python:3.8-slim +FROM --platform=linux/amd64 python:3.8-slim as base # System dependencies RUN apt-get update && \ @@ -16,8 +16,15 @@ RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team # Set the working directory to the component folder WORKDIR /component/src +COPY src/ src/ +ENV PYTHONPATH "${PYTHONPATH}:./src" -# Copy over src-files -COPY src/ . +FROM base as test +COPY test_requirements.txt . +RUN pip3 install --no-cache-dir -r test_requirements.txt +COPY tests/ tests/ +RUN python -m pytest tests -ENTRYPOINT ["fondant", "execute", "main"] \ No newline at end of file +FROM base +WORKDIR /component/src +ENTRYPOINT ["fondant", "execute", "main"] diff --git a/components/embedding_based_laion_retrieval/README.md b/components/embedding_based_laion_retrieval/README.md index f19d55b03..97e0866a5 100644 --- a/components/embedding_based_laion_retrieval/README.md +++ b/components/embedding_based_laion_retrieval/README.md @@ -14,6 +14,7 @@ used to find images similar to the embedded images / captions. **This component produces:** - images_url: string +- embedding_id: string ### Arguments @@ -45,3 +46,9 @@ embedding_based_laion_retrieval_op = ComponentOp.from_registry( pipeline.add_op(embedding_based_laion_retrieval_op, dependencies=[...]) #Add previous component as dependency ``` +### Testing + +You can run the tests using docker with BuildKit. From this directory, run: +``` +docker build . 
--target test +``` diff --git a/components/embedding_based_laion_retrieval/fondant_component.yaml b/components/embedding_based_laion_retrieval/fondant_component.yaml index af147c158..d7616cfbd 100644 --- a/components/embedding_based_laion_retrieval/fondant_component.yaml +++ b/components/embedding_based_laion_retrieval/fondant_component.yaml @@ -15,7 +15,10 @@ consumes: produces: images_url: type: string -# additionalFields: false + embedding_id: + type: string + +previous_index: embedding_id args: num_images: diff --git a/components/embedding_based_laion_retrieval/src/main.py b/components/embedding_based_laion_retrieval/src/main.py index 0f7697dc3..4d730f24c 100644 --- a/components/embedding_based_laion_retrieval/src/main.py +++ b/components/embedding_based_laion_retrieval/src/main.py @@ -1,7 +1,6 @@ """This component retrieves image URLs from LAION-5B based on a set of CLIP embeddings.""" import asyncio import concurrent.futures -import functools import logging import typing as t @@ -40,6 +39,10 @@ def __init__( modality=Modality.IMAGE, ) + def query(self, id_: t.Any, embedding: t.List[float]) -> t.List[t.Dict]: + results = self.client.query(embedding_input=embedding) + return [dict(d, embedding_id=id_) for d in results] + def transform( self, dataframe: pd.DataFrame, @@ -53,23 +56,20 @@ async def async_query(): futures = [ loop.run_in_executor( executor, - functools.partial( - self.client.query, - embedding_input=embedding.tolist(), - ), + self.query, + row.id, + row.embeddings_data.tolist(), ) - for embedding in dataframe["embeddings_data"] + for row in dataframe.itertuples() ] for response in await asyncio.gather(*futures): results.extend(response) loop.run_until_complete(async_query()) - results_df = pd.DataFrame(results)["id", "url"] + results_df = pd.DataFrame(results)[["id", "url", "embedding_id"]] results_df = results_df.set_index("id") - # Cast the index to string - results_df.index = results_df.index.astype(str) - results_df.columns = ["images_url"] + 
results_df = results_df.rename(columns={"url": "images_url"}) return results_df diff --git a/components/embedding_based_laion_retrieval/test_requirements.txt b/components/embedding_based_laion_retrieval/test_requirements.txt new file mode 100644 index 000000000..2a929edcc --- /dev/null +++ b/components/embedding_based_laion_retrieval/test_requirements.txt @@ -0,0 +1 @@ +pytest==7.4.2 diff --git a/components/embedding_based_laion_retrieval/tests/pytest.ini b/components/embedding_based_laion_retrieval/tests/pytest.ini new file mode 100644 index 000000000..bf6a8a517 --- /dev/null +++ b/components/embedding_based_laion_retrieval/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = ../src \ No newline at end of file diff --git a/components/embedding_based_laion_retrieval/tests/test_component.py b/components/embedding_based_laion_retrieval/tests/test_component.py new file mode 100644 index 000000000..ba59028bf --- /dev/null +++ b/components/embedding_based_laion_retrieval/tests/test_component.py @@ -0,0 +1,66 @@ +import typing as t + +import numpy as np +import pandas as pd + +from src.main import LAIONRetrievalComponent + + +def test_component(monkeypatch): + def mocked_client_query(embedding_input: t.List[float]) -> t.List[dict]: + if embedding_input == [1, 2]: + return [ + { + "id": "a", + "url": "http://a", + }, + { + "id": "b", + "url": "http://b", + }, + ] + if embedding_input == [2, 3]: + return [ + { + "id": "c", + "url": "http://c", + }, + { + "id": "d", + "url": "http://d", + }, + ] + msg = f"Unexpected value: `embeddings_input` was {embedding_input}" + raise ValueError(msg) + + input_dataframe = pd.DataFrame.from_dict( + { + "id": ["1", "2"], + "embeddings_data": [np.array([1, 2]), np.array([2, 3])], + }, + ) + + expected_output_dataframe = pd.DataFrame.from_dict( + { + "id": ["a", "b", "c", "d"], + "images_url": ["http://a", "http://b", "http://c", "http://d"], + "embedding_id": ["1", "1", "2", "2"], + }, + ) + expected_output_dataframe = 
expected_output_dataframe.set_index("id") + + component = LAIONRetrievalComponent( + num_images=2, + aesthetic_score=9, + aesthetic_weight=0.5, + ) + + monkeypatch.setattr(component.client, "query", mocked_client_query) + + output_dataframe = component.transform(input_dataframe) + + pd.testing.assert_frame_equal( + left=expected_output_dataframe, + right=output_dataframe, + check_dtype=False, + ) diff --git a/components/index_qdrant/fondant_component.yaml b/components/index_qdrant/fondant_component.yaml index 6feb3b257..68ea33847 100644 --- a/components/index_qdrant/fondant_component.yaml +++ b/components/index_qdrant/fondant_component.yaml @@ -7,14 +7,12 @@ image: 'fndnt/index_qdrant:dev' tags: - Data writing consumes: - text: - fields: - data: - type: string - embedding: - type: array - items: - type: float32 + text_data: + type: string + embeddings_data: + type: array + items: + type: float32 args: collection_name: description: The name of the Qdrant collection to upsert data into. diff --git a/components/prompt_based_laion_retrieval/Dockerfile b/components/prompt_based_laion_retrieval/Dockerfile index 72525d884..0cdcde81a 100644 --- a/components/prompt_based_laion_retrieval/Dockerfile +++ b/components/prompt_based_laion_retrieval/Dockerfile @@ -1,4 +1,4 @@ -FROM --platform=linux/amd64 python:3.8-slim +FROM --platform=linux/amd64 python:3.8-slim as base # System dependencies RUN apt-get update && \ @@ -16,8 +16,15 @@ RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team # Set the working directory to the component folder WORKDIR /component/src +COPY src/ src/ +ENV PYTHONPATH "${PYTHONPATH}:./src" -# Copy over src-files -COPY src/ . +FROM base as test +COPY test_requirements.txt . 
+RUN pip3 install --no-cache-dir -r test_requirements.txt +COPY tests/ tests/ +RUN python -m pytest tests -ENTRYPOINT ["fondant", "execute", "main"] \ No newline at end of file +FROM base +WORKDIR /component/src +ENTRYPOINT ["fondant", "execute", "main"] diff --git a/components/prompt_based_laion_retrieval/README.md b/components/prompt_based_laion_retrieval/README.md index 8d7ffcf70..0551730d9 100644 --- a/components/prompt_based_laion_retrieval/README.md +++ b/components/prompt_based_laion_retrieval/README.md @@ -17,6 +17,7 @@ This component doesn’t return the actual images, only URLs. **This component produces:** - images_url: string +- prompt_id: string ### Arguments @@ -50,3 +51,9 @@ prompt_based_laion_retrieval_op = ComponentOp.from_registry( pipeline.add_op(prompt_based_laion_retrieval_op, dependencies=[...]) #Add previous component as dependency ``` +### Testing + +You can run the tests using docker with BuildKit. From this directory, run: +``` +docker build . --target test +``` diff --git a/components/prompt_based_laion_retrieval/fondant_component.yaml b/components/prompt_based_laion_retrieval/fondant_component.yaml index 02ea08349..3ac3604ac 100644 --- a/components/prompt_based_laion_retrieval/fondant_component.yaml +++ b/components/prompt_based_laion_retrieval/fondant_component.yaml @@ -16,7 +16,10 @@ consumes: produces: images_url: type: string -# additionalFields: false + prompt_id: + type: string + +previous_index: prompt_id args: num_images: diff --git a/components/prompt_based_laion_retrieval/src/main.py b/components/prompt_based_laion_retrieval/src/main.py index 2168f5ef0..bd3cee783 100644 --- a/components/prompt_based_laion_retrieval/src/main.py +++ b/components/prompt_based_laion_retrieval/src/main.py @@ -41,6 +41,10 @@ def __init__( modality=Modality.IMAGE, ) + def query(self, id_: t.Any, prompt: str) -> t.List[t.Dict]: + results = self.client.query(text=prompt) + return [dict(d, prompt_id=id_) for d in results] + def transform( self, dataframe: 
pd.DataFrame, @@ -53,21 +57,20 @@ async def async_query(): futures = [ loop.run_in_executor( executor, - self.client.query, - prompt, + self.query, + row.id, + row.prompts_text, ) - for prompt in dataframe["prompts_text"] + for row in dataframe.itertuples() ] for response in await asyncio.gather(*futures): results.extend(response) loop.run_until_complete(async_query()) - results_df = pd.DataFrame(results)["id", "url"] + results_df = pd.DataFrame(results)[["id", "url", "prompt_id"]] results_df = results_df.set_index("id") - # Cast the index to string - results_df.index = results_df.index.astype(str) - results_df.columns = ["images_url"] + results_df = results_df.rename(columns={"url": "images_url"}) return results_df diff --git a/components/prompt_based_laion_retrieval/test_requirements.txt b/components/prompt_based_laion_retrieval/test_requirements.txt new file mode 100644 index 000000000..2a929edcc --- /dev/null +++ b/components/prompt_based_laion_retrieval/test_requirements.txt @@ -0,0 +1 @@ +pytest==7.4.2 diff --git a/components/prompt_based_laion_retrieval/tests/pytest.ini b/components/prompt_based_laion_retrieval/tests/pytest.ini new file mode 100644 index 000000000..bf6a8a517 --- /dev/null +++ b/components/prompt_based_laion_retrieval/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = ../src \ No newline at end of file diff --git a/components/prompt_based_laion_retrieval/tests/test_component.py b/components/prompt_based_laion_retrieval/tests/test_component.py new file mode 100644 index 000000000..7a3a268e6 --- /dev/null +++ b/components/prompt_based_laion_retrieval/tests/test_component.py @@ -0,0 +1,66 @@ +import typing as t + +import pandas as pd + +from src.main import LAIONRetrievalComponent + + +def test_component(monkeypatch): + def mocked_client_query(text: str) -> t.List[dict]: + if text == "first prompt": + return [ + { + "id": "a", + "url": "http://a", + }, + { + "id": "b", + "url": "http://b", + }, + ] + if text == "second prompt": + return [ + { + "id": 
"c", + "url": "http://c", + }, + { + "id": "d", + "url": "http://d", + }, + ] + msg = f"Unexpected value: `text` was {text}" + raise ValueError(msg) + + input_dataframe = pd.DataFrame.from_dict( + { + "id": ["1", "2"], + "prompts_text": ["first prompt", "second prompt"], + }, + ) + + expected_output_dataframe = pd.DataFrame.from_dict( + { + "id": ["a", "b", "c", "d"], + "images_url": ["http://a", "http://b", "http://c", "http://d"], + "prompt_id": ["1", "1", "2", "2"], + }, + ) + expected_output_dataframe = expected_output_dataframe.set_index("id") + + component = LAIONRetrievalComponent( + num_images=2, + aesthetic_score=9, + aesthetic_weight=0.5, + url="", + ) + + monkeypatch.setattr(component.client, "query", mocked_client_query) + + output_dataframe = component.transform(input_dataframe) + + pd.testing.assert_frame_equal( + left=expected_output_dataframe, + right=output_dataframe, + check_dtype=False, + ) diff --git a/src/fondant/component/executor.py b/src/fondant/component/executor.py index d77200da8..571bc60bb 100644 --- a/src/fondant/component/executor.py +++ b/src/fondant/component/executor.py @@ -548,28 +548,11 @@ def _execute_component( ) # Clear divisions if component spec indicates that the index is changed - if self._infer_index_change(): + if self.spec.previous_index is not None: dataframe.clear_divisions() return dataframe - # TODO: fix in #244 - def _infer_index_change(self) -> bool: - """Infer if this component changes the index based on its component spec.""" - """ - if not self.spec.accepts_additional_subsets: - return True - if not self.spec.outputs_additional_subsets: - return True - for subset in self.spec.consumes.values(): - if not subset.additional_fields: - return True - return any( - not subset.additional_fields for subset in self.spec.produces.values() - ) - """ - return False - class DaskWriteExecutor(Executor[DaskWriteComponent]): """Base class for a Fondant write component.""" diff --git a/src/fondant/core/component_spec.py 
b/src/fondant/core/component_spec.py index 4dd945568..1700e10a1 100644 --- a/src/fondant/core/component_spec.py +++ b/src/fondant/core/component_spec.py @@ -181,6 +181,10 @@ def produces(self) -> t.Mapping[str, Field]: }, ) + @property + def previous_index(self) -> t.Optional[str]: + return self._specification.get("previous_index") + @property def args(self) -> t.Mapping[str, Argument]: args = self.default_arguments diff --git a/src/fondant/core/manifest.py b/src/fondant/core/manifest.py index 58c8ab045..4f0aab480 100644 --- a/src/fondant/core/manifest.py +++ b/src/fondant/core/manifest.py @@ -267,7 +267,10 @@ def evolve( # : PLR0912 (too many branches) Field(name="index", location=component_spec.component_folder_name), ) - # TODO handle additionalFields + # Remove all previous fields if the component changes the index + if component_spec.previous_index is not None: + for field_name in list(evolved_manifest.fields): + evolved_manifest.remove_field(field_name) # Add or update all produced fields defined in the component spec for name, field in component_spec.produces.items(): diff --git a/src/fondant/core/schemas/component_spec.json b/src/fondant/core/schemas/component_spec.json index 064ea027d..dfa6bf68c 100644 --- a/src/fondant/core/schemas/component_spec.json +++ b/src/fondant/core/schemas/component_spec.json @@ -33,6 +33,9 @@ "produces": { "$ref": "common.json#/definitions/fields" }, + "previous_index": { + "type": "string" + }, "args": { "$ref": "#/definitions/args" } diff --git a/tests/component/examples/component_specs/component.yaml b/tests/component/examples/component_specs/component.yaml index 973cc3e6b..d1f28b76e 100644 --- a/tests/component/examples/component_specs/component.yaml +++ b/tests/component/examples/component_specs/component.yaml @@ -11,8 +11,6 @@ produces: type: array items: type: float32 -additionalFields: false - args: flag: diff --git a/tests/core/examples/evolution_examples/2/component.yaml b/tests/core/examples/evolution_examples/2/component.yaml index 
2352adcb5..95d9300d1 100644 --- a/tests/core/examples/evolution_examples/2/component.yaml +++ b/tests/core/examples/evolution_examples/2/component.yaml @@ -7,8 +7,10 @@ consumes: type: binary produces: - images_encoding: - type: string + images_data: + type: binary + +previous_index: "true" # Only used to remove old fields for now args: storage_args: diff --git a/tests/core/examples/evolution_examples/2/output_manifest.json b/tests/core/examples/evolution_examples/2/output_manifest.json index ca1f6f361..db62fda15 100644 --- a/tests/core/examples/evolution_examples/2/output_manifest.json +++ b/tests/core/examples/evolution_examples/2/output_manifest.json @@ -9,25 +9,9 @@ "location":"/example_component" }, "fields": { - "images_width": { - "type": "int32", - "location":"/example_component" - }, - "images_height": { - "type": "int32", - "location":"/example_component" - }, "images_data": { "type": "binary", "location":"/example_component" - }, - "captions_data": { - "type": "binary", - "location":"/example_component" - }, - "images_encoding": { - "type": "string", - "location":"/example_component" } } } \ No newline at end of file