diff --git a/components/caption_images/Dockerfile b/components/caption_images/Dockerfile index 77e9ccdbe..9713a5785 100644 --- a/components/caption_images/Dockerfile +++ b/components/caption_images/Dockerfile @@ -1,4 +1,4 @@ -FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime +FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime as base # System dependencies RUN apt-get update && \ @@ -15,9 +15,16 @@ ARG FONDANT_VERSION=main RUN pip3 install fondant[aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION} # Set the working directory to the component folder -WORKDIR /component/src +WORKDIR /component +COPY src/ src/ +ENV PYTHONPATH "${PYTHONPATH}:./src" -# Copy over src-files -COPY src/ . +FROM base as test +COPY test_requirements.txt . +RUN pip3 install --no-cache-dir -r test_requirements.txt +COPY tests/ tests/ +RUN python -m pytest tests +FROM base +WORKDIR /component/src ENTRYPOINT ["fondant", "execute", "main"] diff --git a/components/caption_images/test_requirements.txt b/components/caption_images/test_requirements.txt new file mode 100644 index 000000000..de1887bec --- /dev/null +++ b/components/caption_images/test_requirements.txt @@ -0,0 +1 @@ +pytest==7.4.2 \ No newline at end of file diff --git a/components/caption_images/tests/test_caption_images.py b/components/caption_images/tests/test_caption_images.py index 08e34c319..46399f4e3 100644 --- a/components/caption_images/tests/test_caption_images.py +++ b/components/caption_images/tests/test_caption_images.py @@ -1,27 +1,31 @@ import pandas as pd import requests -from caption_images.src.main import CaptionImagesComponent -from fondant.abstract_component_test import AbstractComponentTest +from src.main import CaptionImagesComponent -class TestCaptionImagesComponent(AbstractComponentTest): - def create_component(self): - return CaptionImagesComponent( + +def test_image_caption_component(): + image_urls = [ + "https://cdn.pixabay.com/photo/2023/06/29/09/52/angkor-thom-8096092_1280.jpg", + "https://cdn.pixabay.com/photo/2023/07/19/18/56/japanese-beetle-8137606_1280.png", + ] + input_dataframe = pd.DataFrame( + {"images": {"data": [requests.get(url).content for url in image_urls]}}) + + expected_output_dataframe = pd.DataFrame( + data={("captions", "text"): {0: "a motorcycle", 1: "a beetle"}}, + ) + + component = CaptionImagesComponent( model_id="Salesforce/blip-image-captioning-base", batch_size=4, max_new_tokens=2, ) - def create_input_data(self): - image_urls = [ - "https://cdn.pixabay.com/photo/2023/06/29/09/52/angkor-thom-8096092_1280.jpg", - "https://cdn.pixabay.com/photo/2023/07/19/18/56/japanese-beetle-8137606_1280.png", - ] - return pd.DataFrame( - {"images": {"data": [requests.get(url).content for url in image_urls]}}, - ) + output_dataframe = component.transform(input_dataframe) - def create_output_data(self): - return pd.DataFrame( - data={("captions", "text"): {0: "a motorcycle", 1: "a beetle"}}, - ) + pd.testing.assert_frame_equal( + left=expected_output_dataframe, + right=output_dataframe, + check_dtype=False, + ) diff --git a/src/fondant/abstract_component_test.py b/src/fondant/abstract_component_test.py deleted file mode 100644 index 92a1dd69f..000000000 --- a/src/fondant/abstract_component_test.py +++ /dev/null @@ -1,47 +0,0 @@ -from abc import ABC, abstractmethod - -import pandas as pd -import pytest - - -class AbstractComponentTest(ABC): - @abstractmethod - def create_component(self): - """ - This method should be implemented by concrete test classes - to create the specific component - that needs to be tested. - """ - raise NotImplementedError - - @abstractmethod - def create_input_data(self): - """This method should be implemented by concrete test classes - to create the specific input data. - """ - raise NotImplementedError - - @abstractmethod - def create_output_data(self): - """This method should be implemented by concrete test classes - to create the specific output data. - """ - raise NotImplementedError - - @pytest.fixture(autouse=True) - def __setUp(self): - """ - This method will be run before each test method. - Add any common setup steps for your components here. - """ - self.component = self.create_component() - self.input_data = self.create_input_data() - self.expected_output_data = self.create_output_data() - - def test_transform(self): - """ - Default test for the transform method. - Tests if the transform method executes without errors. - """ - output = self.component.transform(self.input_data) - pd.testing.assert_frame_equal(output, self.expected_output_data)