[FEAT] connect: add repartition support
- [ ] The test is not great, but I don't know how to do it better since RDDs do not
  work with Spark Connect (I think); see the sketch below.
- [ ] Do we want to support non-shuffle repartitioning?
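One possible way to approximate a partition-count check without RDDs is to count distinct spark_partition_id() values. This is only an untested sketch: it assumes the Daft connect server implements pyspark.sql.functions.spark_partition_id, which may not be the case yet, and count_observed_partitions is a hypothetical helper name.

from pyspark.sql.functions import spark_partition_id

def count_observed_partitions(df):
    # Count the distinct partition ids actually observed in the data
    # (hypothetical helper; requires spark_partition_id support server-side).
    return df.select(spark_partition_id().alias("pid")).distinct().count()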
andrewgazelka committed Dec 18, 2024
1 parent 07752b8 commit 5561e14
Showing 3 changed files with 59 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
@@ -6,7 +6,8 @@ use tracing::warn;

use crate::translation::logical_plan::{
aggregate::aggregate, drop::drop, filter::filter, local_relation::local_relation,
project::project, range::range, read::read, to_df::to_df, with_columns::with_columns,
project::project, range::range, read::read, repartition::repartition, to_df::to_df,
with_columns::with_columns,
};

mod aggregate;
@@ -16,6 +17,7 @@ mod local_relation;
mod project;
mod range;
mod read;
mod repartition;
mod to_df;
mod with_columns;

@@ -70,6 +72,9 @@ pub async fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
RelType::WithColumns(w) => with_columns(*w)
.await
.wrap_err("Failed to apply with_columns to logical plan"),
RelType::Repartition(r) => repartition(*r)
.await
.wrap_err("Failed to apply repartition to logical plan"),
RelType::ToDf(t) => to_df(*t)
.await
.wrap_err("Failed to apply to_df to logical plan"),
40 changes: 40 additions & 0 deletions src/daft-connect/src/translation/logical_plan/repartition.rs
@@ -0,0 +1,40 @@
use eyre::{bail, ensure, WrapErr};

use crate::translation::{to_logical_plan, Plan};

pub async fn repartition(repartition: spark_connect::Repartition) -> eyre::Result<Plan> {
    let spark_connect::Repartition {
        input,
        num_partitions,
        shuffle,
    } = repartition;

    let Some(input) = input else {
        bail!("Input is required");
    };

    let num_partitions = usize::try_from(num_partitions).map_err(|_| {
        eyre::eyre!("Num partitions must be a positive integer, got {num_partitions}")
    })?;

    ensure!(
        num_partitions > 0,
        "Num partitions must be greater than 0, got {num_partitions}"
    );

    let mut plan = Box::pin(to_logical_plan(*input)).await?;

    // Default shuffle to true when the client does not specify it.
    let shuffle = shuffle.unwrap_or(true);

    if !shuffle {
        bail!("Repartitioning without shuffling is not yet supported");
    }

    plan.builder = plan
        .builder
        .random_shuffle(Some(num_partitions))
        .wrap_err("Failed to apply random_shuffle to logical plan")?;

    Ok(plan)
}
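For context, a rough sketch of how a PySpark Connect client might exercise the two branches above. The protobuf mapping is an assumption based on how Spark Connect clients typically encode these calls (repartition() with shuffle=True, coalesce() with shuffle=False), so the coalesce() call would be expected to hit the bail! branch; spark_session is assumed to be an active Spark Connect session, like the fixture used in the test below.

# spark_session: an active Spark Connect session (e.g. the pytest fixture below).
df = spark_session.range(10)

shuffled = df.repartition(4)  # should translate to random_shuffle(Some(4)) above
shuffled.collect()

try:
    df.coalesce(2).collect()  # shuffle=False path, currently rejected
except Exception as exc:
    print(f"non-shuffle repartition rejected: {exc}")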
13 changes: 13 additions & 0 deletions tests/connect/test_repartition.py
@@ -0,0 +1,13 @@
from __future__ import annotations

def test_repartition(spark_session):
    # Create a simple DataFrame
    df = spark_session.range(10)

    # Test repartitioning to 2 partitions
    repartitioned = df.repartition(2)

    # Verify data is preserved after repartitioning
    original_data = sorted(df.collect())
    repartitioned_data = sorted(repartitioned.collect())
    assert repartitioned_data == original_data, "Data should be preserved after repartitioning"
