From 24d8c23e8131d35d1abb539cd23d0ec1005af033 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Tue, 18 Jul 2023 15:06:30 +0200 Subject: [PATCH] Split component implementation and execution (#302) This PR follows up on the PoC presented in #268 --- Fixes #257 It splits the implementation and execution of components, this has some advantages: - Pandas components can use `__init__` instead of setup, which is probably more familiar to users - Other components can use `__init__` as well instead of receiving all arguments to their transform or equivalent method, aligning implementation of different component types - Component implementation and execution should be easier to test separately I borrowed the executor terminology from KfP. --- Fixes #203 Since I had to update all the components, I also switched some of them to subclass `PandasTransformComponent` instead of `DaskTransformComponent`. --- These changes open some opportunities for additional improvements, but I propose to tackle those as separate PRs as this PR is already quite huge due to all the changes to the components. - [ ] #300 - [ ] #301 --- README.md | 33 +- components/caption_images/src/main.py | 7 +- components/download_images/src/main.py | 48 ++- .../src/main.py | 7 +- components/filter_comments/src/main.py | 48 +-- .../filter_image_resolution/src/main.py | 21 +- components/filter_line_length/src/main.py | 37 +- components/image_cropping/src/main.py | 45 +- components/image_embedding/src/main.py | 9 +- .../image_resolution_extraction/src/main.py | 5 +- components/language_filter/src/main.py | 7 +- components/load_from_hf_hub/src/main.py | 39 +- components/minhash_generator/src/main.py | 7 +- components/pii_redaction/src/main.py | 33 +- .../prompt_based_laion_retrieval/src/main.py | 9 +- components/segment_images/src/main.py | 7 +- components/text_length_filter/src/main.py | 17 +- components/text_normalization/src/main.py | 7 +- components/write_to_hf_hub/src/main.py | 43 +- docs/component_spec.md | 36 +- docs/custom_component.md | 34 +- .../load_from_commoncrawl/src/main.py | 28 +- .../components/generate_prompts/src/main.py | 21 +- .../cluster_image_embeddings/src/main.py | 22 +- .../filter_text_complexity/src/main.py | 9 +- src/fondant/component.py | 391 ++---------------- src/fondant/executor.py | 377 +++++++++++++++++ src/fondant/pipeline.py | 3 +- tests/test_component.py | 118 ++++-- 29 files changed, 799 insertions(+), 669 deletions(-) create mode 100644 src/fondant/executor.py diff --git a/README.md b/README.md index 3b061a4c9..03ced8bf3 100644 --- a/README.md +++ b/README.md @@ -222,24 +222,39 @@ args: type: str ``` -Once you have your component specification, all you need to do is implement a single `.transform` -method and Fondant will do the rest. You will get the data defined in your specification as a -[Dask](https://www.dask.org/) dataframe, which is evaluated lazily. +Once you have your component specification, all you need to do is implement a constructor +and a single `.transform` method and Fondant will do the rest. You will get the data defined in +your specification partition by partition as a Pandas dataframe. ```python -from fondant.component import TransformComponent +import pandas as pd +from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor -class ExampleComponent(TransformComponent): - def transform(self, dataframe, *, argument1, argument2): - """Implement your custom logic in this single method - +class ExampleComponent(PandasTransformComponent): + + def __init__(self, *args, argument1, argument2) -> None: + """ Args: - dataframe: A Dask dataframe containing the data argumentX: An argument passed to the component """ + # Initialize your component here based on the arguments + + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: + """Implement your custom logic in this single method + Args: + dataframe: A Pandas dataframe containing the data + Returns: + A pandas dataframe containing the transformed data + """ + +if __name__ == "__main__": + executor = PandasTransformExecutor.from_args() + executor.execute(ExampleComponent) ``` +For more advanced use cases, you can use the `DaskTransformComponent` instead. ### Running your pipeline diff --git a/components/caption_images/src/main.py b/components/caption_images/src/main.py index 72f4a5fe4..669fc4bd3 100644 --- a/components/caption_images/src/main.py +++ b/components/caption_images/src/main.py @@ -7,6 +7,7 @@ import pandas as pd import torch from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor from PIL import Image from transformers import BatchEncoding, BlipForConditionalGeneration, BlipProcessor @@ -52,7 +53,7 @@ def caption_image_batch( class CaptionImagesComponent(PandasTransformComponent): """Component that captions images using a model from the Hugging Face hub.""" - def setup(self, *, model_id: str, batch_size: int, max_new_tokens: int) -> None: + def __init__(self, *args, model_id: str, batch_size: int, max_new_tokens: int) -> None: self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Device: {self.device}") @@ -85,5 +86,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: if __name__ == "__main__": - component = CaptionImagesComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(CaptionImagesComponent) diff --git a/components/download_images/src/main.py b/components/download_images/src/main.py index c5dd07376..f94f8bb57 100644 --- a/components/download_images/src/main.py +++ b/components/download_images/src/main.py @@ -12,6 +12,7 @@ import dask.dataframe as dd from fondant.component import DaskTransformComponent +from fondant.executor import DaskTransformExecutor from resizer import Resizer logger = logging.getLogger(__name__) @@ -88,21 +89,19 @@ def download_image_with_retry( class DownloadImagesComponent(DaskTransformComponent): """Component that downloads images based on URLs.""" - def transform( - self, - dataframe: dd.DataFrame, - *, - timeout: int, - retries: int, - image_size: int, - resize_mode: str, - resize_only_if_bigger: bool, - min_image_size: int, - max_aspect_ratio: float, - ) -> dd.DataFrame: - """Function that downloads images from a list of URLs and executes filtering and resizing + def __init__(self, + *_, + timeout: int, + retries: int, + image_size: int, + resize_mode: str, + resize_only_if_bigger: bool, + min_image_size: int, + max_aspect_ratio: float, + ): + """Component that downloads images from a list of URLs and executes filtering and resizing. + Args: - dataframe: Dask dataframe timeout: Maximum time (in seconds) to wait when trying to download an image. retries: Number of times to retry downloading an image if it fails. image_size: Size of the images after resizing. @@ -114,8 +113,9 @@ def transform( Returns: Dask dataframe """ - logger.info("Instantiating resizer...") - resizer = Resizer( + self.timeout = timeout + self.retries = retries + self.resizer = Resizer( image_size=image_size, resize_mode=resize_mode, resize_only_if_bigger=resize_only_if_bigger, @@ -123,15 +123,21 @@ def transform( max_aspect_ratio=max_aspect_ratio, ) + def transform( + self, + dataframe: dd.DataFrame, + ) -> dd.DataFrame: + logger.info("Instantiating resizer...") + # Remove duplicates from laion retrieval dataframe = dataframe.drop_duplicates() dataframe = dataframe.apply( lambda example: download_image_with_retry( url=example.images_url, - timeout=timeout, - retries=retries, - resizer=resizer, + timeout=self.timeout, + retries=self.retries, + resizer=self.resizer, ), axis=1, result_type="expand", @@ -150,5 +156,5 @@ def transform( if __name__ == "__main__": - component = DownloadImagesComponent.from_args() - component.run() + executor = DaskTransformExecutor.from_args() + executor.execute(DownloadImagesComponent) diff --git a/components/embedding_based_laion_retrieval/src/main.py b/components/embedding_based_laion_retrieval/src/main.py index 4a4e55204..3c870372f 100644 --- a/components/embedding_based_laion_retrieval/src/main.py +++ b/components/embedding_based_laion_retrieval/src/main.py @@ -8,6 +8,7 @@ import pandas as pd from clip_client import ClipClient, Modality from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor logger = logging.getLogger(__name__) @@ -15,7 +16,7 @@ class LAIONRetrievalComponent(PandasTransformComponent): """Component that retrieves image URLs from LAION-5B based on a set of CLIP embeddings.""" - def setup( + def __init__( self, *, num_images: int, @@ -70,5 +71,5 @@ async def async_query(): if __name__ == "__main__": - component = LAIONRetrievalComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(LAIONRetrievalComponent) diff --git a/components/filter_comments/src/main.py b/components/filter_comments/src/main.py index 6925afe66..57f6f490c 100644 --- a/components/filter_comments/src/main.py +++ b/components/filter_comments/src/main.py @@ -4,41 +4,35 @@ """ import logging -import dask.dataframe as dd -from fondant.component import DaskTransformComponent +import pandas as pd +from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor from utils.text_extraction import get_comments_to_code_ratio logger = logging.getLogger(__name__) -class FilterCommentsComponent(DaskTransformComponent): - """Component that filters instances based on code to comments ratio.""" +class FilterCommentsComponent(PandasTransformComponent): + """Component that filters instances based on code to comments ratio. + + Args: + min_comments_ratio: The minimum code to comment ratio + max_comments_ratio: The maximum code to comment ratio + """ + + def __init__(self, *args, min_comments_ratio: float, max_comments_ratio: float) -> None: + self.min_comments_ratio = min_comments_ratio + self.max_comments_ratio = max_comments_ratio def transform( self, - *, - dataframe: dd.DataFrame, - min_comments_ratio: float, - max_comments_ratio: float, - ) -> dd.DataFrame: - """ - Args: - dataframe: Dask dataframe - min_comments_ratio: The minimum code to comment ratio - max_comments_ratio: The maximum code to comment ratio - Returns: - Filtered dask dataframe. - """ - # Apply the function to the desired column and filter the DataFrame - return dataframe[ - dataframe["code_content"].map_partitions( - lambda example: example.map(get_comments_to_code_ratio).between( - min_comments_ratio, max_comments_ratio, - ), - ) - ] + dataframe: pd.DataFrame, + ) -> pd.DataFrame: + comment_to_code_ratio = dataframe["code"]["content"].apply(get_comments_to_code_ratio) + mask = comment_to_code_ratio.between(self.min_comments_ratio, self.max_comments_ratio) + return dataframe[mask] if __name__ == "__main__": - component = FilterCommentsComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(FilterCommentsComponent) diff --git a/components/filter_image_resolution/src/main.py b/components/filter_image_resolution/src/main.py index c59ca5a60..c6e0276c0 100644 --- a/components/filter_image_resolution/src/main.py +++ b/components/filter_image_resolution/src/main.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor logger = logging.getLogger(__name__) @@ -13,20 +14,16 @@ class FilterImageResolutionComponent(PandasTransformComponent): """Component that filters images based on height and width.""" - def setup(self, *, min_image_dim: int, max_aspect_ratio: float) -> None: - self.min_image_dim = min_image_dim - self.max_aspect_ratio = max_aspect_ratio - - def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: + def __init__(self, *_, min_image_dim: int, max_aspect_ratio: float) -> None: """ Args: - dataframe: Pandas dataframe min_image_dim: minimum image dimension. - min_aspect_ratio: minimum aspect ratio. - - Returns: - filtered Pandas dataframe + max_aspect_ratio: maximum aspect ratio. """ + self.min_image_dim = min_image_dim + self.max_aspect_ratio = max_aspect_ratio + + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: width = dataframe["image"]["width"] height = dataframe["image"]["height"] min_image_dim = np.minimum(width, height) @@ -38,5 +35,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: if __name__ == "__main__": - component = FilterImageResolutionComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(FilterImageResolutionComponent) diff --git a/components/filter_line_length/src/main.py b/components/filter_line_length/src/main.py index a2792b38c..10ded7111 100644 --- a/components/filter_line_length/src/main.py +++ b/components/filter_line_length/src/main.py @@ -1,42 +1,45 @@ """This component filters code based on a set of metadata associated with it.""" import logging -import dask.dataframe as dd -from fondant.component import DaskTransformComponent +import pandas as pd +from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor logger = logging.getLogger(__name__) -class FilterLineLengthComponent(DaskTransformComponent): +class FilterLineLengthComponent(PandasTransformComponent): """ This component filters code based on a set of metadata associated with it: average line length, maximum line length and alphanum fraction. """ - def transform( - self, - *, - dataframe: dd.DataFrame, + def __init__(self, *_, avg_line_length_threshold: int, max_line_length_threshold: int, alphanum_fraction_threshold: float, - ) -> dd.DataFrame: + ) -> None: """ Args: - dataframe: Dask dataframe avg_line_length_threshold: Threshold for average line length to filter on max_line_length_threshold: Threshold for max line length to filter on - alphanum_fraction_threshold: Alphanum fraction to filter on - Returns: - Filtered dask dataframe. + alphanum_fraction_threshold: Alphanum fraction to filter on. """ + self.avg_line_length_threshold = avg_line_length_threshold + self.max_line_length_threshold = max_line_length_threshold + self.alphanum_fraction_threshold = alphanum_fraction_threshold + + def transform( + self, + dataframe: pd.DataFrame, + ) -> pd.DataFrame: return dataframe[ - (dataframe["code_avg_line_length"] > avg_line_length_threshold) - & (dataframe["code_max_line_length"] > max_line_length_threshold) - & (dataframe["code_alphanum_fraction"] > alphanum_fraction_threshold) + (dataframe["code_avg_line_length"] > self.avg_line_length_threshold) + & (dataframe["code_max_line_length"] > self.max_line_length_threshold) + & (dataframe["code_alphanum_fraction"] > self.alphanum_fraction_threshold) ] if __name__ == "__main__": - component = FilterLineLengthComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(FilterLineLengthComponent) diff --git a/components/image_cropping/src/main.py b/components/image_cropping/src/main.py index 2be401db8..4215d2f92 100644 --- a/components/image_cropping/src/main.py +++ b/components/image_cropping/src/main.py @@ -3,65 +3,62 @@ import logging import typing as t -import dask.dataframe as dd import numpy as np -from fondant.component import DaskTransformComponent +import pandas as pd +from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor from image_crop import remove_borders from PIL import Image logger = logging.getLogger(__name__) -def extract_dimensions(image_df: dd.DataFrame) -> t.Tuple[np.int16, np.int16]: +def extract_dimensions(image_bytes: bytes) -> t.Tuple[np.int16, np.int16]: """Extract the width and height of an image. Args: - image_df (dd.DataFrame): input dataframe with images_data column + image_bytes: input image as bytes Returns: np.int16: width of the image np.int16: height of the image """ - image = Image.open(io.BytesIO(image_df["images_data"])) + image = Image.open(io.BytesIO(image_bytes)) return np.int16(image.size[0]), np.int16(image.size[1]) -class ImageCroppingComponent(DaskTransformComponent): +class ImageCroppingComponent(PandasTransformComponent): """Component that crops images.""" - def transform( + def __init__( self, - *, - dataframe: dd.DataFrame, + *_, cropping_threshold: int, padding: int, - ) -> dd.DataFrame: + ) -> None: """ Args: - dataframe (dd.DataFrame): Dask dataframe cropping_threshold (int): threshold parameter used for detecting borders padding (int): padding for the image cropping. - - Returns: - dd.DataFrame: Dask dataframe with cropped images """ + self.cropping_threshold = cropping_threshold + self.padding = padding + + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: # crop images - dataframe["images_data"] = dataframe["images_data"].map( - lambda x: remove_borders(x, cropping_threshold, padding), - meta=("images_data", "bytes"), + dataframe["images"]["data"] = dataframe["images"]["data"].apply( + lambda image: remove_borders(image, self.cropping_threshold, self.padding), ) # extract width and height - dataframe[["images_width", "images_height"]] = dataframe[ - [ - "images_data", - ] - ].apply(extract_dimensions, axis=1, result_type="expand", meta={0: int, 1: int}) + dataframe["images"][["width", "height"]] = dataframe["images"]["data"].apply( + extract_dimensions, axis=1, result_type="expand", meta={0: int, 1: int}, + ) return dataframe if __name__ == "__main__": - component = ImageCroppingComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(ImageCroppingComponent) diff --git a/components/image_embedding/src/main.py b/components/image_embedding/src/main.py index 05b71f7cb..2ae6698c3 100644 --- a/components/image_embedding/src/main.py +++ b/components/image_embedding/src/main.py @@ -7,6 +7,7 @@ import pandas as pd import torch from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor from PIL import Image from transformers import CLIPProcessor, CLIPVisionModelWithProjection @@ -49,9 +50,9 @@ def embed_image_batch(image_batch: pd.DataFrame, *, model: CLIPVisionModelWithPr class EmbedImagesComponent(PandasTransformComponent): """Component that captions images using a model from the Hugging Face hub.""" - def setup( + def __init__( self, - *, + *_, model_id: str, batch_size: int, ): @@ -85,5 +86,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: if __name__ == "__main__": - component = EmbedImagesComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(EmbedImagesComponent) diff --git a/components/image_resolution_extraction/src/main.py b/components/image_resolution_extraction/src/main.py index 89e230348..681840da3 100644 --- a/components/image_resolution_extraction/src/main.py +++ b/components/image_resolution_extraction/src/main.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor logger = logging.getLogger(__name__) @@ -46,5 +47,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: if __name__ == "__main__": - component = ImageResolutionExtractionComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(ImageResolutionExtractionComponent) diff --git a/components/language_filter/src/main.py b/components/language_filter/src/main.py index 540f40163..415996f83 100644 --- a/components/language_filter/src/main.py +++ b/components/language_filter/src/main.py @@ -4,6 +4,7 @@ import fasttext import pandas as pd from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor logger = logging.getLogger(__name__) @@ -45,7 +46,7 @@ def is_language(self, row): class LanguageFilterComponent(PandasTransformComponent): """Component that filter columns based on provided language.""" - def setup(self, *, language): + def __init__(self, *_, language): """Setup language filter component. Args: @@ -67,5 +68,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: if __name__ == "__main__": - component = LanguageFilterComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(LanguageFilterComponent) diff --git a/components/load_from_hf_hub/src/main.py b/components/load_from_hf_hub/src/main.py index 03a1e4c8b..a55daf0c9 100644 --- a/components/load_from_hf_hub/src/main.py +++ b/components/load_from_hf_hub/src/main.py @@ -3,18 +3,20 @@ import typing as t import dask.dataframe as dd -from fondant.component import LoadComponent +from fondant.component import DaskLoadComponent +from fondant.executor import DaskLoadExecutor logger = logging.getLogger(__name__) -class LoadFromHubComponent(LoadComponent): - def load(self, - *, +class LoadFromHubComponent(DaskLoadComponent): + + def __init__(self, *_, dataset_name: str, column_name_mapping: dict, image_column_names: t.Optional[list], - n_rows_to_load: t.Optional[int]) -> dd.DataFrame: + n_rows_to_load: t.Optional[int], + ) -> None: """ Args: dataset_name: name of the dataset to load. @@ -22,33 +24,36 @@ def load(self, image_column_names: A list containing the original hub image column names. Used to format the image from HF hub format to a byte string n_rows_to_load: optional argument that defines the number of rows to load. Useful for - testing pipeline runs on a small scale - Returns: - Dataset: HF dataset. + testing pipeline runs on a small scale. """ + self.dataset_name = dataset_name + self.column_name_mapping = column_name_mapping + self.image_column_names = image_column_names + self.n_rows_to_load = n_rows_to_load + + def load(self) -> dd.DataFrame: # 1) Load data, read as Dask dataframe logger.info("Loading dataset from the hub...") - dask_df = dd.read_parquet(f"hf://datasets/{dataset_name}") + dask_df = dd.read_parquet(f"hf://datasets/{self.dataset_name}") # 2) Make sure images are bytes instead of dicts - if image_column_names is not None: - for image_column_name in image_column_names: + if self.image_column_names is not None: + for image_column_name in self.image_column_names: dask_df[image_column_name] = dask_df[image_column_name].map( lambda x: x["bytes"], meta=("bytes", bytes), ) # 3) Rename columns - dask_df = dask_df.rename(columns=column_name_mapping) + dask_df = dask_df.rename(columns=self.column_name_mapping) # 4) Optional: only return specific amount of rows - - if n_rows_to_load: - dask_df = dask_df.head(n_rows_to_load) + if self.n_rows_to_load: + dask_df = dask_df.head(self.n_rows_to_load) dask_df = dd.from_pandas(dask_df, npartitions=1) return dask_df if __name__ == "__main__": - component = LoadFromHubComponent.from_args() - component.run() + executor = DaskLoadExecutor.from_args() + executor.execute(LoadFromHubComponent) diff --git a/components/minhash_generator/src/main.py b/components/minhash_generator/src/main.py index 2135ec3a2..6da4e3f61 100644 --- a/components/minhash_generator/src/main.py +++ b/components/minhash_generator/src/main.py @@ -5,6 +5,7 @@ import pandas as pd from datasketch import MinHash from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor from nltk.util import ngrams logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ def compute_minhash(shingles: list) -> np.ndarray: class MinHashGeneratorComponent(PandasTransformComponent): """Component generates minhashes of text.""" - def setup(self, *, shingle_ngram_size: int): + def __init__(self, *_, shingle_ngram_size: int): """Setup component. Args: @@ -60,5 +61,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: if __name__ == "__main__": - component = MinHashGeneratorComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(MinHashGeneratorComponent) diff --git a/components/pii_redaction/src/main.py b/components/pii_redaction/src/main.py index 98237dc05..803149bfd 100644 --- a/components/pii_redaction/src/main.py +++ b/components/pii_redaction/src/main.py @@ -3,29 +3,22 @@ import json import logging -import dask.dataframe as dd -from fondant.component import DaskTransformComponent +import pandas as pd +from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor from pii_detection import scan_pii from pii_redaction import redact_pii logger = logging.getLogger(__name__) -class RemovePIIComponent(DaskTransformComponent): +class RemovePIIComponent(PandasTransformComponent): """Component that detects and redacts PII from code.""" def transform( self, - *, - dataframe: dd.DataFrame, - ) -> dd.DataFrame: - """ - Args: - dataframe: Dask dataframe. - - Returns: - Dask dataframe - """ + dataframe: pd.DataFrame, + ) -> pd.DataFrame: # detect PII result = dataframe.apply( lambda example: scan_pii(text=example.code_content), @@ -33,7 +26,7 @@ def transform( result_type="expand", meta={0: object, 1: bool, 2: int}, ) - result.columns = ["code_secrets", "code_has_secrets", "code_number_secrets"] + result.columns = [("code", "secrets"), ("code", "has_secrets"), ("code", "number_secrets")] dataframe = dataframe.merge(result, left_index=True, right_index=True) @@ -42,7 +35,7 @@ def transform( with open("replacements.json") as f: replacements = json.load(f) - dataframe["code_content"] = dataframe.apply( + dataframe["code"]["content"] = dataframe.apply( lambda example: redact_pii( text=example.code_content, secrets=example.code_secrets, @@ -52,13 +45,11 @@ def transform( axis=1, meta=(None, "str"), ) - dataframe = dataframe.drop( - ["code_secrets", "code_has_secrets", "code_number_secrets"], axis=1, + return dataframe.drop( + [("code", "secrets"), ("code", "has_secrets"), ("code", "number_secrets")], axis=1, ) - return dataframe - if __name__ == "__main__": - component = RemovePIIComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(RemovePIIComponent) diff --git a/components/prompt_based_laion_retrieval/src/main.py b/components/prompt_based_laion_retrieval/src/main.py index d1e057b0b..49024dc52 100644 --- a/components/prompt_based_laion_retrieval/src/main.py +++ b/components/prompt_based_laion_retrieval/src/main.py @@ -7,6 +7,7 @@ import pandas as pd from clip_client import ClipClient, Modality from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor logger = logging.getLogger(__name__) @@ -14,9 +15,9 @@ class LAIONRetrievalComponent(PandasTransformComponent): """Component that retrieves image URLs from LAION-5B based on a set of prompts.""" - def setup( + def __init__( self, - *, + *_, num_images: int, aesthetic_score: int, aesthetic_weight: float, @@ -71,5 +72,5 @@ async def async_query(): if __name__ == "__main__": - component = LAIONRetrievalComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(LAIONRetrievalComponent) diff --git a/components/segment_images/src/main.py b/components/segment_images/src/main.py index 434b6fb50..2fcc12026 100644 --- a/components/segment_images/src/main.py +++ b/components/segment_images/src/main.py @@ -7,6 +7,7 @@ import pandas as pd import torch from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor from palette import palette from PIL import Image from transformers import AutoModelForSemanticSegmentation, BatchFeature, SegformerImageProcessor @@ -85,7 +86,7 @@ def segment_image_batch(image_batch: pd.DataFrame, class SegmentImagesComponent(PandasTransformComponent): """Component that segments images using a model from the Hugging Face hub.""" - def setup(self, model_id: str, batch_size: int) -> None: + def __init__(self, *_, model_id: str, batch_size: int) -> None: """ Args: model_id: id of the model on the Hugging Face hub @@ -121,5 +122,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: if __name__ == "__main__": - component = SegmentImagesComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(SegmentImagesComponent) diff --git a/components/text_length_filter/src/main.py b/components/text_length_filter/src/main.py index 58e154311..8acce86fc 100644 --- a/components/text_length_filter/src/main.py +++ b/components/text_length_filter/src/main.py @@ -4,6 +4,7 @@ import fasttext import pandas as pd from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor logger = logging.getLogger(__name__) @@ -11,7 +12,7 @@ class TextLengthFilterComponent(PandasTransformComponent): """A component that filters out text based on their length.""" - def setup(self, *, min_characters_length: int, min_words_length: int): + def __init__(self, *_, min_characters_length: int, min_words_length: int): """Setup component. Args: @@ -22,15 +23,7 @@ def setup(self, *, min_characters_length: int, min_words_length: int): self.min_words_length = min_words_length def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: - """ - Filter out text based on their length. - - Args: - dataframe: Pandas dataframe. - - Returns: - Pandas dataframe. - """ + """Filter out text based on their length.""" caption_num_words = dataframe["text"]["data"].apply(lambda x: len(fasttext.tokenize(x))) caption_num_chars = dataframe["text"]["data"].apply(len) @@ -41,5 +34,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: if __name__ == "__main__": - component = TextLengthFilterComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(TextLengthFilterComponent) diff --git a/components/text_normalization/src/main.py b/components/text_normalization/src/main.py index 716273410..a3c415717 100644 --- a/components/text_normalization/src/main.py +++ b/components/text_normalization/src/main.py @@ -6,6 +6,7 @@ import pandas as pd from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor logger = logging.getLogger(__name__) @@ -13,7 +14,7 @@ class TextNormalizationComponent(PandasTransformComponent): """Component that normalizes text.""" - def setup(self, *, apply_nfc: bool, do_lowercase: bool, characters_to_remove: List[str]): + def __init__(self, *args, apply_nfc: bool, do_lowercase: bool, characters_to_remove: List[str]): self.apply_nfc = apply_nfc self.do_lowercase = do_lowercase self.characters_to_remove = characters_to_remove @@ -60,5 +61,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: if __name__ == "__main__": - component = TextNormalizationComponent.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(TextNormalizationComponent) diff --git a/components/write_to_hf_hub/src/main.py b/components/write_to_hf_hub/src/main.py index bbe35d883..46ebcad73 100644 --- a/components/write_to_hf_hub/src/main.py +++ b/components/write_to_hf_hub/src/main.py @@ -9,7 +9,9 @@ # Define the schema for the struct using PyArrow import huggingface_hub from datasets.features.features import generate_from_arrow_type -from fondant.component import WriteComponent +from fondant.component import DaskWriteComponent +from fondant.component_spec import ComponentSpec +from fondant.executor import DaskWriteExecutor from PIL import Image logger = logging.getLogger(__name__) @@ -30,10 +32,10 @@ def convert_bytes_to_image(image_bytes: bytes, feature_encoder: datasets.Image) return image -class WriteToHubComponent(WriteComponent): - def write( - self, - dataframe: dd.DataFrame, +class WriteToHubComponent(DaskWriteComponent): + + def __init__(self, + spec: ComponentSpec, *, hf_token: str, username: str, @@ -43,7 +45,7 @@ def write( ): """ Args: - dataframe: Dask dataframe + spec: Dynamic component specification describing the dataset to write hf_token: The hugging face token used to write to the hub username: The username under which to upload the dataset dataset_name: The name of the dataset to upload @@ -52,15 +54,22 @@ def write( column_name_mapping: Mapping of the consumed fondant column names to the written hub column names. """ - # login huggingface_hub.login(token=hf_token) - # Create HF dataset repository repo_id = f"{username}/{dataset_name}" - repo_path = f"hf://datasets/{repo_id}" + self.repo_path = f"hf://datasets/{repo_id}" + logger.info(f"Creating HF dataset repository under ID: '{repo_id}'") huggingface_hub.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True) + self.spec = spec + self.image_column_names = image_column_names + self.column_name_mapping = column_name_mapping + + def write( + self, + dataframe: dd.DataFrame, + ): # Get columns to write and schema write_columns = [] schema_dict = {} @@ -68,7 +77,7 @@ def write( for field in subset.fields.values(): column_name = f"{subset_name}_{field.name}" write_columns.append(column_name) - if image_column_names and column_name in image_column_names: + if self.image_column_names and column_name in self.image_column_names: schema_dict[column_name] = datasets.Image() else: schema_dict[column_name] = generate_from_arrow_type(field.type.value) @@ -79,21 +88,21 @@ def write( # Map image column to hf data format feature_encoder = datasets.Image(decode=True) - if image_column_names is not None: - for image_column_name in image_column_names: + if self.image_column_names is not None: + for image_column_name in self.image_column_names: dataframe[image_column_name] = dataframe[image_column_name].map( lambda x: convert_bytes_to_image(x, feature_encoder), meta=(image_column_name, "object"), ) # Map column names to hf data format - if column_name_mapping: - dataframe = dataframe.rename(columns=column_name_mapping) + if self.column_name_mapping: + dataframe = dataframe.rename(columns=self.column_name_mapping) # Write dataset to the hub - dd.to_parquet(dataframe, path=f"{repo_path}/data", schema=schema) + dd.to_parquet(dataframe, path=f"{self.repo_path}/data", schema=schema) if __name__ == "__main__": - component = WriteToHubComponent.from_args() - component.run() + executor = DaskWriteExecutor.from_args() + executor.execute(WriteToHubComponent) diff --git a/docs/component_spec.md b/docs/component_spec.md index e42efddab..3445fd093 100644 --- a/docs/component_spec.md +++ b/docs/component_spec.md @@ -150,23 +150,39 @@ custom_op = ComponentOp( ) ``` -Afterwards, we pass all keyword arguments to the `transform()` method of the component. +Afterwards, we pass all keyword arguments to the `__init__()` method of the component. ```python -from fondant.component import TransformComponent +import pandas as pd +from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor -class ExampleComponent(TransformComponent): - def transform(self, dataframe, *, custom_argument, default_argument): - """Implement your custom logic in this single method +class ExampleComponent(PandasTransformComponent): + + def __init__(self, *args, custom_argument, default_argument) -> None: + """ + Args: + x_argument: An argument passed to the component + """ + # Initialize your component here based on the arguments + + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: + """Implement your custom logic in this single method + + Args: + dataframe: A Pandas dataframe containing the data - Args: - dataframe: A Dask dataframe containing the data - custom_argument: An argument passed to the component - default_argument: A default argument passed to the components - """ + Returns: + A pandas dataframe containing the transformed data + """ + +if __name__ == "__main__": + executor = PandasTransformExecutor.from_args() + executor.execute(ExampleComponent) ``` + ## Examples Each component specification defines how the input manifest will be transformed into the output diff --git a/docs/custom_component.md b/docs/custom_component.md index a415036be..2be91564a 100644 --- a/docs/custom_component.md +++ b/docs/custom_component.md @@ -39,35 +39,39 @@ The easiest way to implement a `TransformComponent` is to subclass the provided chunks as pandas dataframes. ```python +import pandas as pd from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor + class ExampleComponent(PandasTransformComponent): - - def setup(self, argument1, argument2): - """This method is called once per component with any custom arguments it received. Use - it for component wide setup or to store your arguments as instance attributes to access - them in the `transform` method. - + + def __init__(self, *args, argument1, argument2) -> None: + """ Args: - argumentX: A custom argument passed to the component - """ + argumentX: An argument passed to the component + """ + # Initialize your component here based on the arguments - def transform(self, dataframe): - """Implement your custom transformation logic in this single method - + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: + """Implement your custom logic in this single method Args: dataframe: A Pandas dataframe containing one partition of your data - Returns: - A pandas dataframe with the transformed data + A pandas dataframe containing the transformed data """ + +if __name__ == "__main__": + executor = PandasTransformExecutor.from_args() + executor.execute(ExampleComponent) ``` -The `setup` method is called once for each component class with custom arguments defined in the + +The `__init__` method is called once for each component class with custom arguments defined in the `args` section of the [component specification](component_spec).) The `transform` method is called multiple times, each time containing a pandas `dataframe` -loaded in memory. +with a partition of your data loaded in memory. The `dataframes` passed to the `transform` method contains the data specified in the `produces` section of the component specification. If a component defines that it consumes an `images` subset diff --git a/examples/pipelines/commoncrawl/components/load_from_commoncrawl/src/main.py b/examples/pipelines/commoncrawl/components/load_from_commoncrawl/src/main.py index 99755f1fa..6241d88d6 100644 --- a/examples/pipelines/commoncrawl/components/load_from_commoncrawl/src/main.py +++ b/examples/pipelines/commoncrawl/components/load_from_commoncrawl/src/main.py @@ -9,7 +9,8 @@ import dask.dataframe as dd import pandas as pd -from fondant.component import LoadComponent +from fondant.component import DaskLoadComponent +from fondant.executor import DaskLoadExecutor logger = logging.getLogger(__name__) @@ -60,30 +61,37 @@ def read_warc_paths_file( return dask_df -class LoadFromCommonCrawl(LoadComponent): - def load( - self, index_name: str, n_segments_to_load: t.Optional[int] = None - ) -> dd.DataFrame: +class LoadFromCommonCrawlComponent(DaskLoadComponent): + def __init__( + self, *args, index_name: str, n_segments_to_load: t.Optional[int] = None + ) -> None: + self.index_name = index_name + self.n_segments_to_load = n_segments_to_load """Loads a dataset of segment file paths from CommonCrawl based on a given index. + Args: index_name: The name of the CommonCrawl index to load. n_segments_to_load: The number of segments to load from the index. + """ + + def load(self) -> dd.DataFrame: + """ Returns: A Dask DataFrame containing the segment file paths. """ - logger.info(f"Loading CommonCrawl index {index_name}...") - warc_paths_file_key = f"crawl-data/{index_name}/warc.paths.gz" + logger.info(f"Loading CommonCrawl index {self.index_name}...") + warc_paths_file_key = f"crawl-data/{self.index_name}/warc.paths.gz" warc_paths_file_content = fetch_warc_file_from_s3( S3_COMMONCRAWL_BUCKET, warc_paths_file_key ) warc_paths_df = read_warc_paths_file( - warc_paths_file_content, n_segments_to_load + warc_paths_file_content, self.n_segments_to_load ) return warc_paths_df if __name__ == "__main__": - component = LoadFromCommonCrawl.from_args() - component.run() + executor = DaskLoadExecutor.from_args() + executor.execute(LoadFromCommonCrawlComponent) diff --git a/examples/pipelines/controlnet-interior-design/components/generate_prompts/src/main.py b/examples/pipelines/controlnet-interior-design/components/generate_prompts/src/main.py index 7b0369098..28bc055b8 100644 --- a/examples/pipelines/controlnet-interior-design/components/generate_prompts/src/main.py +++ b/examples/pipelines/controlnet-interior-design/components/generate_prompts/src/main.py @@ -8,7 +8,8 @@ import dask.dataframe as dd import pandas as pd -from fondant.component import LoadComponent +from fondant.component import DaskLoadComponent +from fondant.executor import DaskLoadExecutor logger = logging.getLogger(__name__) @@ -95,24 +96,26 @@ def make_interior_prompt(room: str, prefix: str, style: str) -> str: return f"{prefix.lower()} {room.lower()}, {style.lower()} interior design" -class GeneratePromptsComponent(LoadComponent): - def load(self, n_rows_to_load: t.Optional[int]) -> dd.DataFrame: +class GeneratePromptsComponent(DaskLoadComponent): + def __init__(self, *args, n_rows_to_load: t.Optional[int]) -> None: """ Generate a set of initial prompts that will be used to retrieve images from the LAION-5B dataset. + Args: n_rows_to_load: Optional argument that defines the number of rows to load. Useful for testing pipeline runs on a small scale - Returns: - Dask dataframe """ + self.n_rows_to_load = n_rows_to_load + + def load(self) -> dd.DataFrame: room_tuples = itertools.product(rooms, interior_prefix, interior_styles) prompts = map(lambda x: make_interior_prompt(*x), room_tuples) pandas_df = pd.DataFrame(prompts, columns=["prompts_text"]) - if n_rows_to_load: - pandas_df = pandas_df.head(n_rows_to_load) + if self.n_rows_to_load: + pandas_df = pandas_df.head(self.n_rows_to_load) df = dd.from_pandas(pandas_df, npartitions=1) @@ -120,5 +123,5 @@ def load(self, n_rows_to_load: t.Optional[int]) -> dd.DataFrame: if __name__ == "__main__": - component = GeneratePromptsComponent.from_args() - component.run() + executor = DaskLoadExecutor.from_args() + executor.execute(GeneratePromptsComponent) diff --git a/examples/pipelines/datacomp/components/cluster_image_embeddings/src/main.py b/examples/pipelines/datacomp/components/cluster_image_embeddings/src/main.py index d84320108..c2ee12c51 100644 --- a/examples/pipelines/datacomp/components/cluster_image_embeddings/src/main.py +++ b/examples/pipelines/datacomp/components/cluster_image_embeddings/src/main.py @@ -8,6 +8,7 @@ from sklearn.cluster import KMeans from fondant.component import DaskTransformComponent +from fondant.executor import DaskTransformExecutor logger = logging.getLogger(__name__) @@ -15,22 +16,17 @@ class ClusterImageEmbeddingsComponent(DaskTransformComponent): """Component that clusters images based on embeddings.""" - def transform( - self, dataframe: dd.DataFrame, sample_ratio: float, num_clusters: int - ) -> dd.DataFrame: - """ - Args: - dataframe: Dask dataframe + def __init__(self, sample_ratio: float, num_clusters: int) -> None: + self.sample_ratio = sample_ratio + self.num_clusters = num_clusters - Returns: - Dask dataframe - """ + def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame: embeddings = dataframe["image_embedding"].sample( - frac=sample_ratio, random_state=1 + frac=self.sample_ratio, random_state=1 ) embeddings = np.vstack(list(embeddings)) - kmeans = KMeans(n_clusters=num_clusters, random_state=0, n_init="auto") + kmeans = KMeans(n_clusters=self.num_clusters, random_state=0, n_init="auto") kmeans = kmeans.fit(embeddings) # call predict per row @@ -44,5 +40,5 @@ def transform( if __name__ == "__main__": - component = ClusterImageEmbeddingsComponent.from_args() - component.run() + executor = DaskTransformExecutor.from_args() + executor.execute(ClusterImageEmbeddingsComponent) diff --git a/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py b/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py index 2e7f6616a..b5685be0e 100644 --- a/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py +++ b/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py @@ -12,6 +12,7 @@ from spacy.symbols import nsubj, VERB from fondant.component import PandasTransformComponent +from fondant.executor import PandasTransformExecutor logger = logging.getLogger(__name__) @@ -41,9 +42,9 @@ class FilterTextComplexity(PandasTransformComponent): - complexity of the dependency parse tree - number of actions""" - def setup( + def __init__( self, - *, + *args, spacy_pipeline, batch_size: int, min_complexity: int, @@ -72,5 +73,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: if __name__ == "__main__": - component = FilterTextComplexity.from_args() - component.run() + executor = PandasTransformExecutor.from_args() + executor.execute(FilterTextComplexity) diff --git a/src/fondant/component.py b/src/fondant/component.py index 325393db8..724f4ea8b 100644 --- a/src/fondant/component.py +++ b/src/fondant/component.py @@ -1,260 +1,51 @@ -""" -This Python module defines abstract base class for components in the Fondant data processing -framework, providing a standardized interface for extending loading and transforming components. -The loading component is the first component that loads the initial dataset and the transform -components take care of processing, filtering and extending the data. -""" +"""This module defines interfaces which components should implement to be executed by fondant.""" -import argparse -import json -import logging import typing as t -from abc import ABC, abstractmethod -from pathlib import Path import dask.dataframe as dd import pandas as pd -from fondant.component_spec import Argument, ComponentSpec, kubeflow2python_type -from fondant.data_io import DaskDataLoader, DaskDataWriter -from fondant.manifest import Manifest +from fondant.component_spec import ComponentSpec -logger = logging.getLogger(__name__) +class BaseComponent: + """Base interface for each component, specifying only the constructor. -class Component(ABC): - """Abstract base class for a Fondant component.""" + Args: + spec: The specification of the component + **kwargs: The provided user arguments are passed in as keyword arguments + """ - def __init__( - self, - spec: ComponentSpec, - *, - input_manifest_path: t.Union[str, Path], - output_manifest_path: t.Union[str, Path], - metadata: t.Dict[str, t.Any], - user_arguments: t.Dict[str, Argument], - ) -> None: - self.spec = spec - self.input_manifest_path = input_manifest_path - self.output_manifest_path = output_manifest_path - self.metadata = metadata - self.user_arguments = user_arguments - - @classmethod - def from_file( - cls, - path: t.Union[str, Path] = "../fondant_component.yaml", - ) -> "Component": - """Create a component from a component spec file. - - Args: - path: Path to the component spec file - """ - component_spec = ComponentSpec.from_file(path) - return cls.from_spec(component_spec) - - @classmethod - def from_args(cls) -> "Component": - """Create a component from a passed argument containing the specification as a dict.""" - parser = argparse.ArgumentParser() - parser.add_argument("--component_spec", type=json.loads) - args, _ = parser.parse_known_args() - - if "component_spec" not in args: - msg = "Error: The --component_spec argument is required." - raise ValueError(msg) - - component_spec = ComponentSpec(args.component_spec) - - return cls.from_spec(component_spec) - - @classmethod - def from_spec(cls, component_spec: ComponentSpec) -> "Component": - """Create a component from a component spec.""" - args_dict = vars(cls._add_and_parse_args(component_spec)) - - if "component_spec" in args_dict: - args_dict.pop("component_spec") - input_manifest_path = args_dict.pop("input_manifest_path") - output_manifest_path = args_dict.pop("output_manifest_path") - metadata = args_dict.pop("metadata") - - metadata = json.loads(metadata) if metadata else {} - - return cls( - component_spec, - input_manifest_path=input_manifest_path, - output_manifest_path=output_manifest_path, - metadata=metadata, - user_arguments=args_dict, - ) - - @classmethod - def _add_and_parse_args(cls, spec: ComponentSpec): - parser = argparse.ArgumentParser() - component_arguments = cls._get_component_arguments(spec) - - for arg in component_arguments.values(): - if arg.name in cls.optional_fondant_arguments(): - input_required = False - default = None - elif arg.default: - input_required = False - default = arg.default - else: - input_required = True - default = None - - parser.add_argument( - f"--{arg.name}", - type=kubeflow2python_type(arg.type), # type: ignore - required=input_required, - default=default, - help=arg.description, - ) - - return parser.parse_args() - - @staticmethod - def optional_fondant_arguments() -> t.List[str]: - return [] - - @staticmethod - def _get_component_arguments(spec: ComponentSpec) -> t.Dict[str, Argument]: - """ - Get the component arguments as a dictionary representation containing both input and output - arguments of a component - Args: - spec: the component spec - Returns: - Input and output arguments of the component. - """ - component_arguments: t.Dict[str, Argument] = {} - kubeflow_component_spec = spec.kubeflow_specification - component_arguments.update(kubeflow_component_spec.input_arguments) - component_arguments.update(kubeflow_component_spec.output_arguments) - return component_arguments - - @abstractmethod - def _load_or_create_manifest(self) -> Manifest: - """Abstract method that returns the dataset manifest.""" - - @abstractmethod - def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]: - """Abstract method that processes the manifest and - returns another dataframe. - """ - - def _write_data(self, dataframe: dd.DataFrame, *, manifest: Manifest): - """Create a data writer given a manifest and writes out the index and subsets.""" - data_writer = DaskDataWriter(manifest=manifest, component_spec=self.spec) - data_writer.write_dataframe(dataframe) - - def run(self): - """Runs the component.""" - input_manifest = self._load_or_create_manifest() - - output_df = self._process_dataset(manifest=input_manifest) - - output_manifest = input_manifest.evolve(component_spec=self.spec) - - self._write_data(dataframe=output_df, manifest=output_manifest) - - self.upload_manifest(output_manifest, save_path=self.output_manifest_path) - - def upload_manifest(self, manifest: Manifest, save_path: str): - Path(save_path).parent.mkdir(parents=True, exist_ok=True) - manifest.to_file(save_path) - - -class LoadComponent(Component): - """Base class for a Fondant load component.""" - - @staticmethod - def optional_fondant_arguments() -> t.List[str]: - return ["input_manifest_path"] - - def _load_or_create_manifest(self) -> Manifest: - component_id = self.spec.name.lower().replace(" ", "_") - return Manifest.create( - base_path=self.metadata["base_path"], - run_id=self.metadata["run_id"], - component_id=component_id, - ) - - @abstractmethod - def load(self, *args, **kwargs) -> dd.DataFrame: - """Abstract method that loads the initial dataframe.""" + def __init__(self, spec: ComponentSpec, **kwargs): + pass - def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: - """This function loads the initial dataframe sing the user-provided `load` method. - Returns: - A `dd.DataFrame` instance with initial data'. - """ - # Load the dataframe according to the custom function provided to the user - return self.load(**self.user_arguments) +class DaskLoadComponent(BaseComponent): + """Component that loads data and returns a Dask DataFrame.""" + def load(self) -> dd.DataFrame: + raise NotImplementedError -class TransformComponent(Component): - """Base class for a Fondant transform component.""" - def _load_or_create_manifest(self) -> Manifest: - return Manifest.from_file(self.input_manifest_path) +class DaskTransformComponent(BaseComponent): + """Component that transforms an incoming Dask DataFrame.""" - @abstractmethod - def transform(self, *args, **kwargs) -> dd.DataFrame: + def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame: """ Abstract method for applying data transformations to the input dataframe. Args: - args: The dataframe will be passed in as a positional argument - kwargs: Arguments provided to the component are passed as keyword arguments + dataframe: A Dask dataframe containing the data specified in the `consumes` section + of the component specification """ + raise NotImplementedError - def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: - """ - Load the data based on the manifest using a DaskDataloader and call the transform method to - process it. - - Returns: - A `dd.DataFrame` instance with updated data based on the applied data transformations. - """ +class PandasTransformComponent(BaseComponent): + """Component that transforms the incoming dataset partition per partition as a pandas + DataFrame. + """ -class DaskTransformComponent(TransformComponent): - @abstractmethod - def transform(self, *args, **kwargs) -> dd.DataFrame: - """ - Abstract method for applying data transformations to the input dataframe. - - Args: - args: A Dask dataframe will be passed in as a positional argument - kwargs: Arguments provided to the component are passed as keyword arguments - """ - - def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: - """ - Load the data based on the manifest using a DaskDataloader and call the transform method to - process it. - - Returns: - A `dd.DataFrame` instance with updated data based on the applied data transformations. - """ - data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec) - dataframe = data_loader.load_dataframe() - dataframe = self.transform(dataframe, **self.user_arguments) - return dataframe - - -class PandasTransformComponent(TransformComponent): - def setup(self, *args, **kwargs): - """Called once for each instance of the Component class. Use this to set up resources - such as database connections. - """ - return - - @abstractmethod def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: """ Abstract method for applying data transformations to the input dataframe. @@ -263,135 +54,15 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: Args: dataframe: A Pandas dataframe containing a partition of the data """ + raise NotImplementedError - @staticmethod - def wrap_transform(transform: t.Callable, *, spec: ComponentSpec) -> t.Callable: - """Factory that creates a function to wrap the component transform function. The wrapper: - - Converts the columns to hierarchical format before passing the dataframe to the - transform function - - Removes extra columns from the returned dataframe which are not defined in the component - spec `produces` section - - Sorts the columns from the returned dataframe according to the order in the component - spec `produces` section to match the order in the `meta` argument passed to Dask's - `map_partitions`. - - Flattens the returned dataframe columns. - - Args: - transform: Transform method to wrap - spec: Component specification to base behavior on - """ - - def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame: - # Switch to hierarchical columns - dataframe.columns = pd.MultiIndex.from_tuples( - tuple(column.split("_")) for column in dataframe.columns - ) - - # Call transform method - dataframe = transform(dataframe) - - # Drop columns not in specification - columns = [ - (subset_name, field) - for subset_name, subset in spec.produces.items() - for field in subset.fields - ] - dataframe = dataframe[columns] - - # Switch to flattened columns - dataframe.columns = [ - "_".join(column) for column in dataframe.columns.to_flat_index() - ] - return dataframe - - return wrapped_transform - - def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: - """ - Load the data based on the manifest using a DaskDataloader and call the transform method to - process it. - - Returns: - A `dd.DataFrame` instance with updated data based on the applied data transformations. - """ - data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec) - dataframe = data_loader.load_dataframe() - - # Call the component setup method with user provided argument - self.setup(**self.user_arguments) - - # Create meta dataframe with expected format - meta_dict = {"id": pd.Series(dtype="object")} - for subset_name, subset in self.spec.produces.items(): - for field_name, field in subset.fields.items(): - meta_dict[f"{subset_name}_{field_name}"] = pd.Series( - dtype=pd.ArrowDtype(field.type.value), - ) - meta_df = pd.DataFrame(meta_dict).set_index("id") - - wrapped_transform = self.wrap_transform(self.transform, spec=self.spec) - - # Call the component transform method for each partition - dataframe = dataframe.map_partitions( - wrapped_transform, - meta=meta_df, - ) - # Clear divisions if component spec indicates that the index is changed - if self._infer_index_change(): - dataframe.clear_divisions() +class DaskWriteComponent(BaseComponent): + """Component that accepts a Dask DataFrame and writes its contents.""" - return dataframe + def write(self, dataframe: dd.DataFrame) -> None: + raise NotImplementedError - def _infer_index_change(self) -> bool: - """Infer if this component changes the index based on its component spec.""" - if not self.spec.accepts_additional_subsets: - return True - if not self.spec.outputs_additional_subsets: - return True - for subset in self.spec.consumes.values(): - if not subset.additional_fields: - return True - return any( - not subset.additional_fields for subset in self.spec.produces.values() - ) - -class WriteComponent(Component): - """Base class for a Fondant write component.""" - - @staticmethod - def optional_fondant_arguments() -> t.List[str]: - return ["output_manifest_path"] - - def _load_or_create_manifest(self) -> Manifest: - return Manifest.from_file(self.input_manifest_path) - - @abstractmethod - def write(self, *args, **kwargs): - """ - Abstract method to write a dataframe to a final custom location. - - Args: - args: The dataframe will be passed in as a positional argument - kwargs: Arguments provided to the component are passed as keyword arguments - """ - - def _process_dataset(self, manifest: Manifest) -> None: - """ - Creates a DataLoader using the provided manifest and loads the input dataframe using the - `load_dataframe` instance, and applies data transformations to it using the `transform` - method implemented by the derived class. Returns a single dataframe. - - Returns: - A `dd.DataFrame` instance with updated data based on the applied data transformations. - """ - data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec) - dataframe = data_loader.load_dataframe() - self.write(dataframe, **self.user_arguments) - - def _write_data(self, dataframe: dd.DataFrame, *, manifest: Manifest): - """Create a data writer given a manifest and writes out the index and subsets.""" - - def upload_manifest(self, manifest: Manifest, save_path: str): - pass +Component = t.TypeVar("Component", bound=BaseComponent) +"""Component type which can represents any of the subclasses of BaseComponent""" diff --git a/src/fondant/executor.py b/src/fondant/executor.py new file mode 100644 index 000000000..aca8a235d --- /dev/null +++ b/src/fondant/executor.py @@ -0,0 +1,377 @@ +""" +This Python module defines abstract base class for components in the Fondant data processing +framework, providing a standardized interface for extending loading and transforming components. +The loading component is the first component that loads the initial dataset and the transform +components take care of processing, filtering and extending the data. +""" + +import argparse +import json +import logging +import typing as t +from abc import abstractmethod +from pathlib import Path + +import dask.dataframe as dd +import pandas as pd + +from fondant.component import ( + Component, + DaskLoadComponent, + DaskTransformComponent, + DaskWriteComponent, + PandasTransformComponent, +) +from fondant.component_spec import Argument, ComponentSpec, kubeflow2python_type +from fondant.data_io import DaskDataLoader, DaskDataWriter +from fondant.manifest import Manifest + +logger = logging.getLogger(__name__) + + +class Executor(t.Generic[Component]): + """An executor executes a Component.""" + + def __init__( + self, + spec: ComponentSpec, + *, + input_manifest_path: t.Union[str, Path], + output_manifest_path: t.Union[str, Path], + metadata: t.Dict[str, t.Any], + user_arguments: t.Dict[str, Argument], + ) -> None: + self.spec = spec + self.input_manifest_path = input_manifest_path + self.output_manifest_path = output_manifest_path + self.metadata = metadata + self.user_arguments = user_arguments + + @classmethod + def from_file( + cls, + path: t.Union[str, Path] = "../fondant_component.yaml", + ) -> "Executor": + """Create an executor from a component spec file. + + Args: + path: Path to the component spec file + """ + component_spec = ComponentSpec.from_file(path) + return cls.from_spec(component_spec) + + @classmethod + def from_args(cls) -> "Executor": + """Create an executor from a passed argument containing the specification as a dict.""" + parser = argparse.ArgumentParser() + parser.add_argument("--component_spec", type=json.loads) + args, _ = parser.parse_known_args() + + if "component_spec" not in args: + msg = "Error: The --component_spec argument is required." + raise ValueError(msg) + + component_spec = ComponentSpec(args.component_spec) + + return cls.from_spec(component_spec) + + @classmethod + def from_spec(cls, component_spec: ComponentSpec) -> "Executor": + """Create an executor from a component spec.""" + args_dict = vars(cls._add_and_parse_args(component_spec)) + + if "component_spec" in args_dict: + args_dict.pop("component_spec") + input_manifest_path = args_dict.pop("input_manifest_path") + output_manifest_path = args_dict.pop("output_manifest_path") + metadata = args_dict.pop("metadata") + + metadata = json.loads(metadata) if metadata else {} + + return cls( + component_spec, + input_manifest_path=input_manifest_path, + output_manifest_path=output_manifest_path, + metadata=metadata, + user_arguments=args_dict, + ) + + @classmethod + def _add_and_parse_args(cls, spec: ComponentSpec): + parser = argparse.ArgumentParser() + component_arguments = cls._get_component_arguments(spec) + + for arg in component_arguments.values(): + if arg.name in cls.optional_fondant_arguments(): + input_required = False + default = None + elif arg.default: + input_required = False + default = arg.default + else: + input_required = True + default = None + + parser.add_argument( + f"--{arg.name}", + type=kubeflow2python_type(arg.type), # type: ignore + required=input_required, + default=default, + help=arg.description, + ) + + return parser.parse_args() + + @staticmethod + def optional_fondant_arguments() -> t.List[str]: + return [] + + @staticmethod + def _get_component_arguments(spec: ComponentSpec) -> t.Dict[str, Argument]: + """ + Get the component arguments as a dictionary representation containing both input and output + arguments of a component + Args: + spec: the component spec + Returns: + Input and output arguments of the component. + """ + component_arguments: t.Dict[str, Argument] = {} + kubeflow_component_spec = spec.kubeflow_specification + component_arguments.update(kubeflow_component_spec.input_arguments) + component_arguments.update(kubeflow_component_spec.output_arguments) + return component_arguments + + @abstractmethod + def _load_or_create_manifest(self) -> Manifest: + """Abstract method that returns the dataset manifest.""" + + @abstractmethod + def _execute_component( + self, + component: Component, + *, + manifest: Manifest, + ) -> t.Union[None, dd.DataFrame]: + """ + Abstract method to execute a component with the provided manifest. + + Args: + component: Component instance to execute + manifest: Manifest describing the input data + + Returns: + A Dask DataFrame containing the output data + """ + + def _write_data(self, dataframe: dd.DataFrame, *, manifest: Manifest): + """Create a data writer given a manifest and writes out the index and subsets.""" + data_writer = DaskDataWriter(manifest=manifest, component_spec=self.spec) + data_writer.write_dataframe(dataframe) + + def execute(self, component_cls: t.Type[Component]) -> None: + """Execute a component. + + Args: + component_cls: The class of the component to execute + """ + input_manifest = self._load_or_create_manifest() + + component = component_cls(self.spec, **self.user_arguments) + output_df = self._execute_component(component, manifest=input_manifest) + + output_manifest = input_manifest.evolve(component_spec=self.spec) + + self._write_data(dataframe=output_df, manifest=output_manifest) + + self.upload_manifest(output_manifest, save_path=self.output_manifest_path) + + def upload_manifest(self, manifest: Manifest, save_path: t.Union[str, Path]): + Path(save_path).parent.mkdir(parents=True, exist_ok=True) + manifest.to_file(save_path) + + +class DaskLoadExecutor(Executor[DaskLoadComponent]): + """Base class for a Fondant load component.""" + + @staticmethod + def optional_fondant_arguments() -> t.List[str]: + return ["input_manifest_path"] + + def _load_or_create_manifest(self) -> Manifest: + component_id = self.spec.name.lower().replace(" ", "_") + return Manifest.create( + base_path=self.metadata["base_path"], + run_id=self.metadata["run_id"], + component_id=component_id, + ) + + def _execute_component( + self, + component: DaskLoadComponent, + *, + manifest: Manifest, + ) -> dd.DataFrame: + """This function loads the initial dataframe using the user-provided `load` method. + + Returns: + A `dd.DataFrame` instance with initial data. + """ + return component.load() + + +class TransformExecutor(Executor[Component]): + """Base class for a Fondant transform component.""" + + def _load_or_create_manifest(self) -> Manifest: + return Manifest.from_file(self.input_manifest_path) + + def _execute_component( + self, + component: Component, + *, + manifest: Manifest, + ) -> dd.DataFrame: + raise NotImplementedError + + +class DaskTransformExecutor(TransformExecutor[DaskTransformComponent]): + def _execute_component( + self, + component: DaskTransformComponent, + *, + manifest: Manifest, + ) -> dd.DataFrame: + """ + Load the data based on the manifest using a DaskDataloader and call the transform method to + process it. + + Returns: + A `dd.DataFrame` instance with updated data based on the applied data transformations. + """ + data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec) + dataframe = data_loader.load_dataframe() + return component.transform(dataframe) + + +class PandasTransformExecutor(TransformExecutor[PandasTransformComponent]): + @staticmethod + def wrap_transform(transform: t.Callable, *, spec: ComponentSpec) -> t.Callable: + """Factory that creates a function to wrap the component transform function. The wrapper: + - Converts the columns to hierarchical format before passing the dataframe to the + transform function + - Removes extra columns from the returned dataframe which are not defined in the component + spec `produces` section + - Sorts the columns from the returned dataframe according to the order in the component + spec `produces` section to match the order in the `meta` argument passed to Dask's + `map_partitions`. + - Flattens the returned dataframe columns. + + Args: + transform: Transform method to wrap + spec: Component specification to base behavior on + """ + + def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame: + # Switch to hierarchical columns + dataframe.columns = pd.MultiIndex.from_tuples( + tuple(column.split("_")) for column in dataframe.columns + ) + + # Call transform method + dataframe = transform(dataframe) + + # Drop columns not in specification + columns = [ + (subset_name, field) + for subset_name, subset in spec.produces.items() + for field in subset.fields + ] + dataframe = dataframe[columns] + + # Switch to flattened columns + dataframe.columns = [ + "_".join(column) for column in dataframe.columns.to_flat_index() + ] + return dataframe + + return wrapped_transform + + def _execute_component( + self, + component: PandasTransformComponent, + *, + manifest: Manifest, + ) -> dd.DataFrame: + """ + Load the data based on the manifest using a DaskDataloader and call the component's + transform method for each partition of the data. + + Returns: + A `dd.DataFrame` instance with updated data based on the applied data transformations. + """ + data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec) + dataframe = data_loader.load_dataframe() + + # Create meta dataframe with expected format + meta_dict = {"id": pd.Series(dtype="object")} + for subset_name, subset in self.spec.produces.items(): + for field_name, field in subset.fields.items(): + meta_dict[f"{subset_name}_{field_name}"] = pd.Series( + dtype=pd.ArrowDtype(field.type.value), + ) + meta_df = pd.DataFrame(meta_dict).set_index("id") + + wrapped_transform = self.wrap_transform(component.transform, spec=self.spec) + + # Call the component transform method for each partition + dataframe = dataframe.map_partitions( + wrapped_transform, + meta=meta_df, + ) + + # Clear divisions if component spec indicates that the index is changed + if self._infer_index_change(): + dataframe.clear_divisions() + + return dataframe + + def _infer_index_change(self) -> bool: + """Infer if this component changes the index based on its component spec.""" + if not self.spec.accepts_additional_subsets: + return True + if not self.spec.outputs_additional_subsets: + return True + for subset in self.spec.consumes.values(): + if not subset.additional_fields: + return True + return any( + not subset.additional_fields for subset in self.spec.produces.values() + ) + + +class DaskWriteExecutor(Executor[DaskWriteComponent]): + """Base class for a Fondant write component.""" + + @staticmethod + def optional_fondant_arguments() -> t.List[str]: + return ["output_manifest_path"] + + def _load_or_create_manifest(self) -> Manifest: + return Manifest.from_file(self.input_manifest_path) + + def _execute_component( + self, + component: DaskWriteComponent, + *, + manifest: Manifest, + ) -> None: + data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec) + dataframe = data_loader.load_dataframe() + component.write(dataframe) + + def _write_data(self, dataframe: dd.DataFrame, *, manifest: Manifest): + """Create a data writer given a manifest and writes out the index and subsets.""" + + def upload_manifest(self, manifest: Manifest, save_path: t.Union[str, Path]): + pass diff --git a/src/fondant/pipeline.py b/src/fondant/pipeline.py index f0b9a31de..9693af05f 100644 --- a/src/fondant/pipeline.py +++ b/src/fondant/pipeline.py @@ -11,9 +11,10 @@ except ImportError: from importlib_resources import files # type: ignore -from fondant.component import ComponentSpec, Manifest +from fondant.component_spec import ComponentSpec from fondant.exceptions import InvalidPipelineDefinition from fondant.import_utils import is_kfp_available +from fondant.manifest import Manifest if is_kfp_available(): import kfp diff --git a/tests/test_component.py b/tests/test_component.py index e6b0f2576..5fa862d8f 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -9,14 +9,20 @@ import pytest import yaml from fondant.component import ( - Component, + DaskLoadComponent, DaskTransformComponent, - LoadComponent, + DaskWriteComponent, PandasTransformComponent, - WriteComponent, ) from fondant.component_spec import ComponentSpec from fondant.data_io import DaskDataLoader, DaskDataWriter +from fondant.executor import ( + DaskLoadExecutor, + DaskTransformExecutor, + DaskWriteExecutor, + Executor, + PandasTransformExecutor, +) from fondant.manifest import Manifest components_path = Path(__file__).parent / "example_specs/components" @@ -48,12 +54,26 @@ def mocked_write_dataframe(self, dataframe): monkeypatch.setattr(DaskDataWriter, "write_dataframe", mocked_write_dataframe) monkeypatch.setattr( - Component, + Executor, "upload_manifest", lambda self, manifest, save_path: None, ) +def patch_method_class(method): + """Patch a method on a class instead of an instance. The returned method can be passed to + `mock.patch.object` as the `wraps` argument. + """ + m = mock.MagicMock() + + def wrapper(self, *args, **kwargs): + m(*args, **kwargs) + return method(self, *args, **kwargs) + + wrapper.mock = m + return wrapper + + def test_component_arguments(): # Mock CLI arguments sys.argv = [ @@ -74,7 +94,7 @@ def test_component_arguments(): "None", ] - class MyComponent(Component): + class MyExecutor(Executor): """Base component with dummy methods so it can be instantiated.""" def _load_or_create_manifest(self) -> Manifest: @@ -83,8 +103,8 @@ def _load_or_create_manifest(self) -> Manifest: def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]: pass - component = MyComponent.from_args() - assert component.user_arguments == { + executor = MyExecutor.from_args() + assert executor.user_arguments == { "string_default_arg": "foo", "integer_default_arg": 1, "float_default_arg": 3.14, @@ -121,21 +141,25 @@ def test_load_component(): yaml_file_to_json_string(components_path / "component.yaml"), ] - class MyLoadComponent(LoadComponent): - def load(self, *, flag, value): - assert flag == "success" - assert value == 1 + class MyLoadComponent(DaskLoadComponent): + def __init__(self, *args, flag, value): + self.flag = flag + self.value = value + def load(self): + assert self.flag == "success" + assert self.value == 1 data = { "id": [0, 1], "captions_data": ["hello world", "this is another caption"], } return dd.DataFrame.from_dict(data, npartitions=N_PARTITIONS) - component = MyLoadComponent.from_args() - with mock.patch.object(MyLoadComponent, "load", wraps=component.load) as load: - component.run() - load.assert_called_once() + executor = DaskLoadExecutor.from_args() + load = patch_method_class(MyLoadComponent.load) + with mock.patch.object(MyLoadComponent, "load", load): + executor.execute(MyLoadComponent) + load.mock.assert_called_once() @pytest.mark.usefixtures("_patched_data_loading", "_patched_data_writing") @@ -158,20 +182,25 @@ def test_dask_transform_component(): ] class MyDaskComponent(DaskTransformComponent): - def transform(self, dataframe, *, flag, value): - assert flag == "success" - assert value == 1 + def __init__(self, *args, flag, value): + self.flag = flag + self.value = value + + def transform(self, dataframe): + assert self.flag == "success" + assert self.value == 1 assert isinstance(dataframe, dd.DataFrame) return dataframe - component = MyDaskComponent.from_args() + executor = DaskTransformExecutor.from_args() + transform = patch_method_class(MyDaskComponent.transform) with mock.patch.object( MyDaskComponent, "transform", - wraps=component.transform, - ) as transform: - component.run() - transform.assert_called_once() + transform, + ): + executor.execute(MyDaskComponent) + transform.mock.assert_called_once() @pytest.mark.usefixtures("_patched_data_loading", "_patched_data_writing") @@ -194,7 +223,7 @@ def test_pandas_transform_component(): ] class MyPandasComponent(PandasTransformComponent): - def setup(self, *, flag, value): + def __init__(self, *args, flag, value): assert flag == "success" assert value == 1 @@ -202,17 +231,17 @@ def transform(self, dataframe): assert isinstance(dataframe, pd.DataFrame) return dataframe.rename(columns={"images": "embeddings"}) - component = MyPandasComponent.from_args() - setup = mock.patch.object(MyPandasComponent, "setup", wraps=component.setup) - transform = mock.patch.object( + executor = PandasTransformExecutor.from_args() + init = patch_method_class(MyPandasComponent.__init__) + transform = patch_method_class(MyPandasComponent.transform) + with mock.patch.object(MyPandasComponent, "__init__", init), mock.patch.object( MyPandasComponent, "transform", - wraps=component.transform, - ) - with setup as setup, transform as transform: - component.run() - setup.assert_called_once() - assert transform.call_count == N_PARTITIONS + transform, + ): + executor.execute(MyPandasComponent) + init.mock.assert_called_once() + assert transform.mock.call_count == N_PARTITIONS def test_wrap_transform(): @@ -282,7 +311,7 @@ def transform(dataframe: pd.DataFrame) -> pd.DataFrame: ] return dataframe - wrapped_transform = PandasTransformComponent.wrap_transform(transform, spec=spec) + wrapped_transform = PandasTransformExecutor.wrap_transform(transform, spec=spec) output_df = wrapped_transform(input_df) # Check column flattening, trimming, and ordering @@ -306,13 +335,18 @@ def test_write_component(): yaml_file_to_json_string(components_path / "component.yaml"), ] - class MyWriteComponent(WriteComponent): - def write(self, dataframe, *, flag, value): - assert flag == "success" - assert value == 1 + class MyWriteComponent(DaskWriteComponent): + def __init__(self, *args, flag, value): + self.flag = flag + self.value = value + + def write(self, dataframe): + assert self.flag == "success" + assert self.value == 1 assert isinstance(dataframe, dd.DataFrame) - component = MyWriteComponent.from_args() - with mock.patch.object(MyWriteComponent, "write", wraps=component.write) as write: - component.run() - write.assert_called_once() + executor = DaskWriteExecutor.from_args() + write = patch_method_class(MyWriteComponent.write) + with mock.patch.object(MyWriteComponent, "write", write): + executor.execute(MyWriteComponent) + write.mock.assert_called_once()