diff --git a/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_dataflow.py b/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_dataflow.py index 31be7a6b89..6fc53b67f2 100644 --- a/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_dataflow.py +++ b/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_dataflow.py @@ -1,3 +1,4 @@ +import logging import os from typing import List @@ -7,11 +8,11 @@ from bytewax.execution import cluster_main from bytewax.inputs import ManualInputConfig from bytewax.outputs import ManualOutputConfig -from tqdm import tqdm from feast import FeatureStore, FeatureView, RepoConfig from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping +logger = logging.getLogger(__name__) DEFAULT_BATCH_SIZE = 1000 @@ -29,14 +30,20 @@ def __init__( self.feature_view = feature_view self.worker_index = worker_index self.paths = paths + self.mini_batch_size = int( + os.getenv("BYTEWAX_MINI_BATCH_SIZE", DEFAULT_BATCH_SIZE) + ) self._run_dataflow() def process_path(self, path): + logger.info(f"Processing path {path}") dataset = pq.ParquetDataset(path, use_legacy_dataset=False) batches = [] for fragment in dataset.fragments: - for batch in fragment.to_table().to_batches(): + for batch in fragment.to_table().to_batches( + max_chunksize=self.mini_batch_size + ): batches.append(batch) return batches @@ -45,40 +52,26 @@ def input_builder(self, worker_index, worker_count, _state): return [(None, self.paths[self.worker_index])] def output_builder(self, worker_index, worker_count): - def yield_batch(iterable, batch_size): - """Yield mini-batches from an iterable.""" - for i in range(0, len(iterable), batch_size): - yield iterable[i : i + batch_size] - - def output_fn(batch): - table = pa.Table.from_batches([batch]) + def output_fn(mini_batch): + table: pa.Table = pa.Table.from_batches([mini_batch]) if self.feature_view.batch_source.field_mapping is not None: table = _run_pyarrow_field_mapping( table, self.feature_view.batch_source.field_mapping ) - join_key_to_value_type = { entity.name: entity.dtype.to_value_type() for entity in self.feature_view.entity_columns } - rows_to_write = _convert_arrow_to_proto( table, self.feature_view, join_key_to_value_type ) - provider = self.feature_store._get_provider() - with tqdm(total=len(rows_to_write)) as progress: - # break rows_to_write to mini-batches - batch_size = int( - os.getenv("BYTEWAX_MINI_BATCH_SIZE", DEFAULT_BATCH_SIZE) - ) - for mini_batch in yield_batch(rows_to_write, batch_size): - provider.online_write_batch( - config=self.config, - table=self.feature_view, - data=mini_batch, - progress=progress.update, - ) + self.feature_store._get_provider().online_write_batch( + config=self.config, + table=self.feature_view, + data=rows_to_write, + progress=None, + ) return output_fn diff --git a/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_engine.py index b2f8985f87..16742ff0f7 100644 --- a/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/bytewax/bytewax_materialization_engine.py @@ -82,6 +82,12 @@ class BytewaxMaterializationEngineConfig(FeastConfigBaseModel): mini_batch_size: int = 1000 """ (optional) Number of rows to process per write operation (default 1000)""" + bytewax_replicas: int = 5 + """ (optional) Number of process to spawn in each pods to handle a file in parallel""" + + bytewax_worker_per_process: int = 1 + """ (optional) Number of threads as worker per bytewax process""" + active_deadline_seconds: int = 86400 """ (optional) Maximum amount of time a materialization job is allowed to run""" @@ -111,7 +117,6 @@ def __init__( self.offline_store = offline_store self.online_store = online_store - # TODO: Configure k8s here k8s_config.load_config() self.k8s_client = client.api_client.ApiClient() @@ -299,6 +304,9 @@ def _create_kubernetes_job(self, job_id, paths, feature_view): len(paths), # Create a pod for each parquet file self.batch_engine_config.env, ) + logger.info( + f"Created job `dataflow-{job_id}` on namespace `{self.namespace}`" + ) except FailToCreateError as failures: return BytewaxMaterializationJob(job_id, self.namespace, error=failures) @@ -345,7 +353,7 @@ def _create_job_definition(self, job_id, namespace, pods, env, index_offset=0): {"name": "BYTEWAX_WORKDIR", "value": "/bytewax"}, { "name": "BYTEWAX_WORKERS_PER_PROCESS", - "value": "1", + "value": f"{self.batch_engine_config.bytewax_worker_per_process}", }, { "name": "BYTEWAX_POD_NAME", @@ -358,7 +366,7 @@ def _create_job_definition(self, job_id, namespace, pods, env, index_offset=0): }, { "name": "BYTEWAX_REPLICAS", - "value": f"{pods}", + "value": f"{self.batch_engine_config.bytewax_replicas}", }, { "name": "BYTEWAX_KEEP_CONTAINER_ALIVE", diff --git a/sdk/python/feast/infra/materialization/contrib/bytewax/dataflow.py b/sdk/python/feast/infra/materialization/contrib/bytewax/dataflow.py index 23cdc20ef3..9d9b328c0e 100644 --- a/sdk/python/feast/infra/materialization/contrib/bytewax/dataflow.py +++ b/sdk/python/feast/infra/materialization/contrib/bytewax/dataflow.py @@ -1,3 +1,4 @@ +import logging import os import yaml @@ -8,6 +9,8 @@ ) if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + with open("/var/feast/feature_store.yaml") as f: feast_config = yaml.safe_load(f)