diff --git a/components/download_images/fondant_component.yaml b/components/download_images/fondant_component.yaml index 6f4262e29..1efaa48d4 100644 --- a/components/download_images/fondant_component.yaml +++ b/components/download_images/fondant_component.yaml @@ -23,21 +23,28 @@ args: timeout: description: Maximum time (in seconds) to wait when trying to download an image type: int + default: 10 retries: description: Number of times to retry downloading an image if it fails. type: int + default: 0 image_size: description: Size of the images after resizing. type: int + default: 256 resize_mode: description: Resize mode to use. One of "no", "keep_ratio", "center_crop", "border". type: str + default: 'border' resize_only_if_bigger: description: If True, resize only if image is bigger than image_size. type: bool + default: 'False' min_image_size: description: Minimum size of the images. type: int + default: 0 max_aspect_ratio: description: Maximum aspect ratio of the images. - type: float \ No newline at end of file + type: float + default: 'inf' \ No newline at end of file diff --git a/components/download_images/src/main.py b/components/download_images/src/main.py index 9b222e3b0..017001e0d 100644 --- a/components/download_images/src/main.py +++ b/components/download_images/src/main.py @@ -10,10 +10,10 @@ import traceback import urllib -import pandas as pd +import dask.dataframe as dd from resizer import Resizer -from fondant.component import PandasTransformComponent +from fondant.component import DaskTransformComponent logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ def is_disallowed(headers, user_agent_token, disallowed_header_directives): else None ) if (ua_token is None or ua_token == user_agent_token) and any( - x in disallowed_header_directives for x in directives + x in disallowed_header_directives for x in directives ): return True except Exception as err: @@ -53,9 +53,9 @@ def download_image(url, timeout, user_agent_token, disallowed_header_directives) ) with urllib.request.urlopen(request, timeout=timeout) as r: if disallowed_header_directives and is_disallowed( - r.headers, - user_agent_token, - disallowed_header_directives, + r.headers, + user_agent_token, + disallowed_header_directives, ): return None img_stream = io.BytesIO(r.read()) @@ -67,13 +67,13 @@ def download_image(url, timeout, user_agent_token, disallowed_header_directives) def download_image_with_retry( - url, - *, - timeout, - retries, - resizer, - user_agent_token=None, - disallowed_header_directives=None, + url, + *, + timeout, + retries, + resizer, + user_agent_token=None, + disallowed_header_directives=None, ): for _ in range(retries + 1): img_stream = download_image( @@ -81,50 +81,71 @@ def download_image_with_retry( ) if img_stream is not None: # resize the image - return resizer(img_stream) + img_str, width, height = resizer(img_stream) + return img_str, width, height return None, None, None -class DownloadImagesComponent(PandasTransformComponent): +class DownloadImagesComponent(DaskTransformComponent): """Component that downloads images based on URLs.""" - def setup( - self, - *, - timeout: int = 10, - retries: int = 0, - image_size: int = 256, - resize_mode: str = "border", - resize_only_if_bigger: bool = False, - min_image_size: int = 0, - max_aspect_ratio: float = float("inf"), - ): + def transform( + self, + dataframe: dd.DataFrame, + *, + timeout: int, + retries: int, + image_size: int, + resize_mode: str, + resize_only_if_bigger: bool, + min_image_size: int, + max_aspect_ratio: float, + ) -> dd.DataFrame: + """Function that downloads images from a list of URLs and executes filtering and resizing + Args: + dataframe: Dask dataframe + timeout: Maximum time (in seconds) to wait when trying to download an image. + retries: Number of times to retry downloading an image if it fails. + image_size: Size of the images after resizing. + resize_mode: Resize mode to use. One of "no", "keep_ratio", "center_crop", "border". + resize_only_if_bigger: If True, resize only if image is bigger than image_size. + min_image_size: Minimum size of the images. + max_aspect_ratio: Maximum aspect ratio of the images. + + Returns: + Dask dataframe + """ logger.info("Instantiating resizer...") - self.resizer = Resizer( + resizer = Resizer( image_size=image_size, resize_mode=resize_mode, resize_only_if_bigger=resize_only_if_bigger, min_image_size=min_image_size, max_aspect_ratio=max_aspect_ratio, ) - self.timeout = timeout - self.retries = retries - - def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: - dataframe[[ - ("images", "data"), - ("images", "width"), - ("images", "height"), - ]] = dataframe.apply( + + # Remove duplicates from laion retrieval + dataframe = dataframe.drop_duplicates() + + dataframe = dataframe.apply( lambda example: download_image_with_retry( - url=example["images"]["url"], - timeout=self.timeout, - retries=self.retries, - resizer=self.resizer, + url=example.images_url, + timeout=timeout, + retries=retries, + resizer=resizer, ), axis=1, result_type="expand", + meta={0: bytes, 1: int, 2: int}, ) + dataframe.columns = [ + "images_data", + "images_width", + "images_height", + ] + + # Remove images that could not be fetched + dataframe = dataframe.dropna() return dataframe diff --git a/components/download_images/src/resizer.py b/components/download_images/src/resizer.py index 386a71d3f..f545a0bf1 100644 --- a/components/download_images/src/resizer.py +++ b/components/download_images/src/resizer.py @@ -174,20 +174,20 @@ def __call__(self, img_stream, blurring_bbox_list=None): original_height, original_width = img.shape[:2] # check if image is too small if min(original_height, original_width) < self.min_image_size: - return None, None, None, None, None, "image too small" + return None, None, None if original_height * original_width > self.max_image_area: - return None, None, None, None, None, "image area too large" + return None, None, None # check if wrong aspect ratio if ( max(original_height, original_width) / min(original_height, original_width) > self.max_aspect_ratio ): - return None, None, None, None, None, "aspect ratio too large" + return None, None, None # check if resizer was defined during init if needed if blurring_bbox_list is not None and self.blurrer is None: - return None, None, None, None, None, "blurrer not defined" + return None, None, None # Flag to check if blurring is still needed. maybe_blur_still_needed = True diff --git a/components/prompt_based_laion_retrieval/fondant_component.yaml b/components/prompt_based_laion_retrieval/fondant_component.yaml index 09cdb630d..5fa3bf331 100644 --- a/components/prompt_based_laion_retrieval/fondant_component.yaml +++ b/components/prompt_based_laion_retrieval/fondant_component.yaml @@ -25,3 +25,7 @@ args: aesthetic_weight: description: Weight of the aesthetic embedding when added to the query, between 0 and 1 type: float + url: + description: The url of the backend clip retrieval service, defaults to the public service + type: str + default: https://knn.laion.ai/knn-service \ No newline at end of file diff --git a/components/prompt_based_laion_retrieval/src/main.py b/components/prompt_based_laion_retrieval/src/main.py index 5109e94e5..6dbc39a57 100644 --- a/components/prompt_based_laion_retrieval/src/main.py +++ b/components/prompt_based_laion_retrieval/src/main.py @@ -21,6 +21,7 @@ def setup( num_images: int, aesthetic_score: int, aesthetic_weight: float, + url: str, ) -> None: """ @@ -30,10 +31,11 @@ def setup( between 0 and 9. aesthetic_weight: weight of the aesthetic embedding to add to the query, between 0 and 1. + url: The url of the backend clip retrieval service, defaults to the public clip url. """ self.client = ClipClient( - url="https://knn.laion.ai/knn-service", - indice_name="laion5B-L-14", + url=url, + indice_name="laion5B", #TODO:revert back to laion5b-L-14 after backend correction num_images=num_images, aesthetic_score=aesthetic_score, aesthetic_weight=aesthetic_weight, diff --git a/components/segment_images/src/main.py b/components/segment_images/src/main.py index b666127b3..89e9193ef 100644 --- a/components/segment_images/src/main.py +++ b/components/segment_images/src/main.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -def convert_to_rgb(seg: np.array): +def convert_to_rgb(seg: np.array) -> bytes: """ Converts a 2D segmentation to a RGB one which makes it possible to visualize it. @@ -23,7 +23,7 @@ def convert_to_rgb(seg: np.array): seg: 2D segmentation map as a NumPy array. Returns: - color_seg: 3D segmentation map contain RGB values for each pixel. + color_seg: the RGB segmentation map as a binary string """ color_seg = np.zeros( (seg.shape[0], seg.shape[1], 3), dtype=np.uint8, @@ -32,9 +32,13 @@ def convert_to_rgb(seg: np.array): for label, color in enumerate(palette): color_seg[seg == label, :] = color - color_seg = color_seg.astype(np.uint8).tobytes() + color_seg = color_seg.astype(np.uint8) + image = Image.fromarray(color_seg).convert('RGB') - return color_seg + crop_bytes = io.BytesIO() + image.save(crop_bytes, format="JPEG") + + return crop_bytes.getvalue() def process_image(image: bytes, *, processor: SegformerImageProcessor, device: str) -> torch.Tensor: @@ -46,6 +50,7 @@ def process_image(image: bytes, *, processor: SegformerImageProcessor, device: s processor: The processor object for transforming the image. device: The device to move the transformed image to. """ + def load(img: bytes) -> Image: """Load the bytestring as an image.""" bytes_ = io.BytesIO(img) diff --git a/components/write_to_hf_hub/src/main.py b/components/write_to_hf_hub/src/main.py index a81bcb5c9..c3022b234 100644 --- a/components/write_to_hf_hub/src/main.py +++ b/components/write_to_hf_hub/src/main.py @@ -8,6 +8,7 @@ # Define the schema for the struct using PyArrow import huggingface_hub +from datasets.features.features import generate_from_arrow_type from PIL import Image from fondant.component import WriteComponent @@ -71,7 +72,7 @@ def write( if image_column_names and column_name in image_column_names: schema_dict[column_name] = datasets.Image() else: - schema_dict[column_name] = datasets.Value(str(field.type.value)) + schema_dict[column_name] = generate_from_arrow_type(field.type.value) schema = datasets.Features(schema_dict).arrow_schema dataframe = dataframe[write_columns] diff --git a/examples/pipelines/controlnet-interior-design/components/write_to_hub_controlnet/fondant_component.yaml b/examples/pipelines/controlnet-interior-design/components/write_to_hub_controlnet/fondant_component.yaml index efb253159..4915810f0 100644 --- a/examples/pipelines/controlnet-interior-design/components/write_to_hub_controlnet/fondant_component.yaml +++ b/examples/pipelines/controlnet-interior-design/components/write_to_hub_controlnet/fondant_component.yaml @@ -16,9 +16,7 @@ consumes: segmentations: fields: data: - type: array - items: - type: binary + type: binary args: hf_token: diff --git a/examples/pipelines/controlnet-interior-design/pipeline.py b/examples/pipelines/controlnet-interior-design/pipeline.py index 33b4d9054..ecb1c2e06 100644 --- a/examples/pipelines/controlnet-interior-design/pipeline.py +++ b/examples/pipelines/controlnet-interior-design/pipeline.py @@ -20,12 +20,17 @@ ) laion_retrieval_op = ComponentOp.from_registry( name="prompt_based_laion_retrieval", - arguments={"num_images": 2, "aesthetic_score": 9, "aesthetic_weight": 0.5}, + arguments={ + "num_images": 2, + "aesthetic_score": 9, + "aesthetic_weight": 0.5, + "url": None, + }, ) download_images_op = ComponentOp.from_registry( name="download_images", arguments={ - "timeout": 10, + "timeout": 1, "retries": 0, "image_size": 512, "resize_mode": "center_crop", @@ -63,8 +68,6 @@ "hf_token": "hf_token", "image_column_names": ["images_data"], }, - number_of_gpus=1, - node_pool_name="model-inference-pool", ) pipeline = Pipeline(pipeline_name=pipeline_name, base_path=PipelineConfigs.BASE_PATH)