Skip to content

Commit

Permalink
Change input_partition_rows to accept -1 as default
Browse files Browse the repository at this point in the history
  • Loading branch information
RobbeSneyders committed Oct 9, 2023
1 parent 55fc4fe commit 29e27c8
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 54 deletions.
68 changes: 34 additions & 34 deletions src/fondant/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
*,
manifest: Manifest,
component_spec: ComponentSpec,
input_partition_rows: t.Optional[t.Union[int, str]] = None,
input_partition_rows: int,
):
super().__init__(manifest=manifest, component_spec=component_spec)
self.input_partition_rows = input_partition_rows
Expand All @@ -36,42 +36,42 @@ def partition_loaded_dataframe(self, dataframe: dd.DataFrame) -> dd.DataFrame:
Returns:
The partitioned dataframe.
"""
if self.input_partition_rows != "disable":
if isinstance(self.input_partition_rows, int):
# Only load the index column to trigger a faster compute of the rows
total_rows = len(dataframe.index)
# +1 to handle any remainder rows
n_partitions = (total_rows // self.input_partition_rows) + 1
dataframe = dataframe.repartition(npartitions=n_partitions)
logger.info(
f"Total number of rows is {total_rows}.\n"
f"Repartitioning the data from {dataframe.partitions} partitions to have"
f" {n_partitions} such that the number of partitions per row is approximately"
f"{self.input_partition_rows}",
)
if self.input_partition_rows > 1:
# Only load the index column to trigger a faster compute of the rows
total_rows = len(dataframe.index)
# +1 to handle any remainder rows
n_partitions = (total_rows // self.input_partition_rows) + 1
dataframe = dataframe.repartition(npartitions=n_partitions)
logger.info(
f"Total number of rows is {total_rows}.\n"
f"Repartitioning the data from {dataframe.partitions} partitions to have"
f" {n_partitions} such that the number of partitions per row is approximately"
f"{self.input_partition_rows}",
)

elif self.input_partition_rows is None:
n_partitions = dataframe.npartitions
n_workers = os.cpu_count()
if n_partitions < n_workers: # type: ignore
logger.info(
f"The number of partitions of the input dataframe is {n_partitions}. The "
f"available number of workers is {n_workers}.",
)
dataframe = dataframe.repartition(npartitions=n_workers)
logger.info(
f"Repartitioning the data to {n_workers} partitions before processing"
f" to maximize worker usage",
)
else:
msg = (
f"{self.input_partition_rows} is not a valid argument. Choose either "
f"the number of partitions or set to 'disable' to disable automated "
f"partitioning"
elif self.input_partition_rows == -1:
n_partitions = dataframe.npartitions
n_workers = os.cpu_count()
if n_partitions < n_workers: # type: ignore
logger.info(
f"The number of partitions of the input dataframe is {n_partitions}. The "
f"available number of workers is {n_workers}.",
)
raise ValueError(
msg,
dataframe = dataframe.repartition(npartitions=n_workers)
logger.info(
f"Repartitioning the data to {n_workers} partitions before processing"
f" to maximize worker usage",
)
else:
msg = (
f"{self.input_partition_rows} is not a valid value for the 'input_partition_rows' "
f"parameter. It should be a number larger than 0 to indicate the number of "
f"expected rows per partition, or '-1' to let Fondant optimize the number of "
f"partitions based on the number of available workers."
)
raise ValueError(
msg,
)

return dataframe

Expand Down
7 changes: 3 additions & 4 deletions src/fondant/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from fondant.component_spec import Argument, ComponentSpec
from fondant.data_io import DaskDataLoader, DaskDataWriter
from fondant.manifest import Manifest, Metadata
from fondant.schema import validate_partition_number

dask.config.set({"dataframe.convert-string": False})
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -65,7 +64,7 @@ def __init__(
output_manifest_path: t.Union[str, Path],
metadata: t.Dict[str, t.Any],
user_arguments: t.Dict[str, t.Any],
input_partition_rows: t.Optional[t.Union[str, int]] = None,
input_partition_rows: int,
cluster_type: t.Optional[str] = None,
client_kwargs: t.Optional[dict] = None,
) -> None:
Expand Down Expand Up @@ -111,7 +110,7 @@ def from_args(cls) -> "Executor":
parser = argparse.ArgumentParser()
parser.add_argument("--component_spec", type=json.loads)
parser.add_argument("--cache", type=lambda x: bool(strtobool(x)))
parser.add_argument("--input_partition_rows", type=validate_partition_number)
parser.add_argument("--input_partition_rows", type=int)
parser.add_argument("--cluster_type", type=str)
parser.add_argument("--client_kwargs", type=json.loads)
args, _ = parser.parse_known_args()
Expand Down Expand Up @@ -140,7 +139,7 @@ def from_spec(
component_spec: ComponentSpec,
*,
cache: bool,
input_partition_rows: t.Optional[t.Union[str, int]],
input_partition_rows: int,
cluster_type: t.Optional[str],
client_kwargs: t.Optional[dict],
) -> "Executor":
Expand Down
7 changes: 1 addition & 6 deletions src/fondant/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from fondant.component_spec import ComponentSpec
from fondant.exceptions import InvalidPipelineDefinition
from fondant.manifest import Manifest
from fondant.schema import validate_partition_number

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -107,11 +106,7 @@ def __init__(
self.client_kwargs = client_kwargs

self.arguments = arguments or {}
self._add_component_argument(
"input_partition_rows",
input_partition_rows,
validate_partition_number,
)
self._add_component_argument("input_partition_rows", input_partition_rows)
self._add_component_argument("cache", self.cache)
self._add_component_argument("cluster_type", cluster_type)
self._add_component_argument("client_kwargs", client_kwargs)
Expand Down
10 changes: 0 additions & 10 deletions src/fondant/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,6 @@ class Field(t.NamedTuple):
type: Type


def validate_partition_number(arg_value):
if arg_value in ["disable", None, "None"]:
return arg_value if arg_value != "None" else None
try:
return int(arg_value)
except ValueError:
msg = f"Invalid format for '{arg_value}'. The value must be an integer or set to 'disable'"
raise InvalidTypeSchema(msg)


def validate_partition_size(arg_value):
if arg_value in ["disable", None, "None"]:
return arg_value if arg_value != "None" else None
Expand Down

0 comments on commit 29e27c8

Please sign in to comment.