From 29e27c8496a9e5ceace5928b012c8883d9a6fa3c Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Mon, 9 Oct 2023 23:12:27 +0200 Subject: [PATCH] Change input_partition_rows to accept -1 as default --- src/fondant/data_io.py | 68 ++++++++++++++++++++--------------------- src/fondant/executor.py | 7 ++--- src/fondant/pipeline.py | 7 +---- src/fondant/schema.py | 10 ------ 4 files changed, 38 insertions(+), 54 deletions(-) diff --git a/src/fondant/data_io.py b/src/fondant/data_io.py index 75e4a6bf5..d324b19af 100644 --- a/src/fondant/data_io.py +++ b/src/fondant/data_io.py @@ -24,7 +24,7 @@ def __init__( *, manifest: Manifest, component_spec: ComponentSpec, - input_partition_rows: t.Optional[t.Union[int, str]] = None, + input_partition_rows: int, ): super().__init__(manifest=manifest, component_spec=component_spec) self.input_partition_rows = input_partition_rows @@ -36,42 +36,42 @@ def partition_loaded_dataframe(self, dataframe: dd.DataFrame) -> dd.DataFrame: Returns: The partitioned dataframe. """ - if self.input_partition_rows != "disable": - if isinstance(self.input_partition_rows, int): - # Only load the index column to trigger a faster compute of the rows - total_rows = len(dataframe.index) - # +1 to handle any remainder rows - n_partitions = (total_rows // self.input_partition_rows) + 1 - dataframe = dataframe.repartition(npartitions=n_partitions) - logger.info( - f"Total number of rows is {total_rows}.\n" - f"Repartitioning the data from {dataframe.partitions} partitions to have" - f" {n_partitions} such that the number of partitions per row is approximately" - f"{self.input_partition_rows}", - ) + if self.input_partition_rows > 1: + # Only load the index column to trigger a faster compute of the rows + total_rows = len(dataframe.index) + # +1 to handle any remainder rows + n_partitions = (total_rows // self.input_partition_rows) + 1 + dataframe = dataframe.repartition(npartitions=n_partitions) + logger.info( + f"Total number of rows is {total_rows}.\n" + f"Repartitioning the data from {dataframe.partitions} partitions to have" + f" {n_partitions} such that the number of partitions per row is approximately" + f"{self.input_partition_rows}", + ) - elif self.input_partition_rows is None: - n_partitions = dataframe.npartitions - n_workers = os.cpu_count() - if n_partitions < n_workers: # type: ignore - logger.info( - f"The number of partitions of the input dataframe is {n_partitions}. The " - f"available number of workers is {n_workers}.", - ) - dataframe = dataframe.repartition(npartitions=n_workers) - logger.info( - f"Repartitioning the data to {n_workers} partitions before processing" - f" to maximize worker usage", - ) - else: - msg = ( - f"{self.input_partition_rows} is not a valid argument. Choose either " - f"the number of partitions or set to 'disable' to disable automated " - f"partitioning" + elif self.input_partition_rows == -1: + n_partitions = dataframe.npartitions + n_workers = os.cpu_count() + if n_partitions < n_workers: # type: ignore + logger.info( + f"The number of partitions of the input dataframe is {n_partitions}. The " + f"available number of workers is {n_workers}.", ) - raise ValueError( - msg, + dataframe = dataframe.repartition(npartitions=n_workers) + logger.info( + f"Repartitioning the data to {n_workers} partitions before processing" + f" to maximize worker usage", ) + else: + msg = ( + f"{self.input_partition_rows} is not a valid value for the 'input_partition_rows' " + f"parameter. It should be a number larger than 0 to indicate the number of " + f"expected rows per partition, or '-1' to let Fondant optimize the number of " + f"partitions based on the number of available workers." + ) + raise ValueError( + msg, + ) return dataframe diff --git a/src/fondant/executor.py b/src/fondant/executor.py index 24ac10039..41e97829f 100644 --- a/src/fondant/executor.py +++ b/src/fondant/executor.py @@ -29,7 +29,6 @@ from fondant.component_spec import Argument, ComponentSpec from fondant.data_io import DaskDataLoader, DaskDataWriter from fondant.manifest import Manifest, Metadata -from fondant.schema import validate_partition_number dask.config.set({"dataframe.convert-string": False}) logger = logging.getLogger(__name__) @@ -65,7 +64,7 @@ def __init__( output_manifest_path: t.Union[str, Path], metadata: t.Dict[str, t.Any], user_arguments: t.Dict[str, t.Any], - input_partition_rows: t.Optional[t.Union[str, int]] = None, + input_partition_rows: int, cluster_type: t.Optional[str] = None, client_kwargs: t.Optional[dict] = None, ) -> None: @@ -111,7 +110,7 @@ def from_args(cls) -> "Executor": parser = argparse.ArgumentParser() parser.add_argument("--component_spec", type=json.loads) parser.add_argument("--cache", type=lambda x: bool(strtobool(x))) - parser.add_argument("--input_partition_rows", type=validate_partition_number) + parser.add_argument("--input_partition_rows", type=int) parser.add_argument("--cluster_type", type=str) parser.add_argument("--client_kwargs", type=json.loads) args, _ = parser.parse_known_args() @@ -140,7 +139,7 @@ def from_spec( component_spec: ComponentSpec, *, cache: bool, - input_partition_rows: t.Optional[t.Union[str, int]], + input_partition_rows: int, cluster_type: t.Optional[str], client_kwargs: t.Optional[dict], ) -> "Executor": diff --git a/src/fondant/pipeline.py b/src/fondant/pipeline.py index e95bd833e..0e2a296cf 100644 --- a/src/fondant/pipeline.py +++ b/src/fondant/pipeline.py @@ -16,7 +16,6 @@ from fondant.component_spec import ComponentSpec from fondant.exceptions import InvalidPipelineDefinition from fondant.manifest import Manifest -from fondant.schema import validate_partition_number logger = logging.getLogger(__name__) @@ -107,11 +106,7 @@ def __init__( self.client_kwargs = client_kwargs self.arguments = arguments or {} - self._add_component_argument( - "input_partition_rows", - input_partition_rows, - validate_partition_number, - ) + self._add_component_argument("input_partition_rows", input_partition_rows) self._add_component_argument("cache", self.cache) self._add_component_argument("cluster_type", cluster_type) self._add_component_argument("client_kwargs", client_kwargs) diff --git a/src/fondant/schema.py b/src/fondant/schema.py index 46a73b1ef..fea6ebe37 100644 --- a/src/fondant/schema.py +++ b/src/fondant/schema.py @@ -157,16 +157,6 @@ class Field(t.NamedTuple): type: Type -def validate_partition_number(arg_value): - if arg_value in ["disable", None, "None"]: - return arg_value if arg_value != "None" else None - try: - return int(arg_value) - except ValueError: - msg = f"Invalid format for '{arg_value}'. The value must be an integer or set to 'disable'" - raise InvalidTypeSchema(msg) - - def validate_partition_size(arg_value): if arg_value in ["disable", None, "None"]: return arg_value if arg_value != "None" else None