From 5561e14c1277b8f72f4746b65d704a39f88a0161 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 20:48:52 -0800 Subject: [PATCH] [FEAT] connect: add repartition support - [ ] the test is not great but idk how to do it better since rdd does not work with spark connect (I think) - [ ] do we want to support non-shuffle repartitioning? --- .../src/translation/logical_plan.rs | 7 +++- .../translation/logical_plan/repartition.rs | 40 +++++++++++++++++++ tests/connect/test_repartition.py | 13 ++++++ 3 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 src/daft-connect/src/translation/logical_plan/repartition.rs create mode 100644 tests/connect/test_repartition.py diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index b6097d17ad..9da7a373db 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -6,7 +6,8 @@ use tracing::warn; use crate::translation::logical_plan::{ aggregate::aggregate, drop::drop, filter::filter, local_relation::local_relation, - project::project, range::range, read::read, to_df::to_df, with_columns::with_columns, + project::project, range::range, read::read, repartition::repartition, to_df::to_df, + with_columns::with_columns, }; mod aggregate; @@ -16,6 +17,7 @@ mod local_relation; mod project; mod range; mod read; +mod repartition; mod to_df; mod with_columns; @@ -70,6 +72,9 @@ pub async fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::WithColumns(w) => with_columns(*w) .await .wrap_err("Failed to apply with_columns to logical plan"), + RelType::Repartition(r) => repartition(*r) + .await + .wrap_err("Failed to apply repartition to logical plan"), RelType::ToDf(t) => to_df(*t) .await .wrap_err("Failed to apply to_df to logical plan"), diff --git a/src/daft-connect/src/translation/logical_plan/repartition.rs b/src/daft-connect/src/translation/logical_plan/repartition.rs new file mode 100644 index 0000000000..5e94051591 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/repartition.rs @@ -0,0 +1,40 @@ +use eyre::{bail, ensure, WrapErr}; + +use crate::translation::{to_logical_plan, Plan}; + +pub async fn repartition(repartition: spark_connect::Repartition) -> eyre::Result { + let spark_connect::Repartition { + input, + num_partitions, + shuffle, + } = repartition; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let num_partitions = usize::try_from(num_partitions).map_err(|_| { + eyre::eyre!("Num partitions must be a positive integer, got {num_partitions}") + })?; + + ensure!( + num_partitions > 0, + "Num partitions must be greater than 0, got {num_partitions}" + ); + + let mut plan = Box::pin(to_logical_plan(*input)).await?; + + // let's make true is default + let shuffle = shuffle.unwrap_or(true); + + if !shuffle { + bail!("Repartitioning without shuffling is not yet supported"); + } + + plan.builder = plan + .builder + .random_shuffle(Some(num_partitions)) + .wrap_err("Failed to apply random_shuffle to logical plan")?; + + Ok(plan) +} diff --git a/tests/connect/test_repartition.py b/tests/connect/test_repartition.py new file mode 100644 index 0000000000..7d7c8e25f6 --- /dev/null +++ b/tests/connect/test_repartition.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +def test_repartition(spark_session): + # Create a simple DataFrame + df = spark_session.range(10) + + # Test repartitioning to 2 partitions + repartitioned = df.repartition(2) + + # Verify data is preserved after repartitioning + original_data = sorted(df.collect()) + repartitioned_data = sorted(repartitioned.collect()) + assert repartitioned_data == original_data, "Data should be preserved after repartitioning" \ No newline at end of file