diff --git a/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_dataflow.py b/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_dataflow.py index fe2a7f35c1..e9d6a756b2 100644 --- a/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_dataflow.py +++ b/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_dataflow.py @@ -1,3 +1,4 @@ +import os from typing import List import pyarrow as pa @@ -11,6 +12,8 @@ from feast import FeatureStore, FeatureView, RepoConfig from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping +DEFAULT_BATCH_SIZE = 1000 + class BytewaxMaterializationDataflow: def __init__( @@ -44,6 +47,11 @@ def input_builder(self, worker_index, worker_count, _state): return def output_builder(self, worker_index, worker_count): + def yield_batch(iterable, batch_size): + """Yield mini-batches from an iterable.""" + for i in range(0, len(iterable), batch_size): + yield iterable[i : i + batch_size] + def output_fn(batch): table = pa.Table.from_batches([batch]) @@ -62,12 +70,17 @@ def output_fn(batch): ) provider = self.feature_store._get_provider() with tqdm(total=len(rows_to_write)) as progress: - provider.online_write_batch( - config=self.config, - table=self.feature_view, - data=rows_to_write, - progress=progress.update, + # break rows_to_write to mini-batches + batch_size = int( + os.getenv("BYTEWAX_MINI_BATCH_SIZE", DEFAULT_BATCH_SIZE) ) + for mini_batch in yield_batch(rows_to_write, batch_size): + provider.online_write_batch( + config=self.config, + table=self.feature_view, + data=mini_batch, + progress=progress.update, + ) return output_fn diff --git a/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_engine.py index 21b7a5da1f..787dd585ff 100644 --- a/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_engine.py @@ -67,6 +67,9 @@ class BytewaxMaterializationEngineConfig(FeastConfigBaseModel): max_parallelism: int = 10 """ (optional) Maximum number of pods (default 10) allowed to run in parallel per job""" + mini_batch_size: int = 1000 + """ (optional) Number of rows to process per write operation (default 1000)""" + class BytewaxMaterializationEngine(BatchMaterializationEngine): def __init__( @@ -254,6 +257,10 @@ def _create_job_definition(self, job_id, namespace, pods, env): "name": "BYTEWAX_STATEFULSET_NAME", "value": f"dataflow-{job_id}", }, + { + "name": "BYTEWAX_MINI_BATCH_SIZE", + "value": str(self.batch_engine_config.mini_batch_size), + }, ] # Add any Feast configured environment variables job_env.extend(env)