Skip to content

Commit

Permalink
[FEAT] connect: add repartition support
Browse files Browse the repository at this point in the history
- [ ] the test is not great but idk how to do it better since rdd does
  not work with spark connect (I think)
- [ ] do we want to support non-shuffle repartitioning?
  • Loading branch information
andrewgazelka committed Dec 19, 2024
1 parent c30f6a8 commit 7d2d9d6
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ use futures::TryStreamExt;
use spark_connect::{relation::RelType, Limit, Relation, ShowString};
use tracing::warn;

use crate::translation::logical_plan::repartition::repartition;

mod aggregate;
mod drop;
mod filter;
mod local_relation;
mod project;
mod range;
mod read;
mod repartition;
mod to_df;
mod with_columns;

Expand Down Expand Up @@ -99,6 +102,9 @@ impl SparkAnalyzer<'_> {
.with_columns(*w)
.await
.wrap_err("Failed to apply with_columns to logical plan"),
RelType::Repartition(r) => repartition(*r)
.await
.wrap_err("Failed to apply repartition to logical plan"),
RelType::ToDf(t) => self
.to_df(*t)
.await
Expand Down
40 changes: 40 additions & 0 deletions src/daft-connect/src/translation/logical_plan/repartition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use eyre::{bail, ensure, WrapErr};

use crate::translation::{to_logical_plan, Plan};

pub async fn repartition(repartition: spark_connect::Repartition) -> eyre::Result<Plan> {
let spark_connect::Repartition {
input,
num_partitions,
shuffle,
} = repartition;

let Some(input) = input else {
bail!("Input is required");
};

let num_partitions = usize::try_from(num_partitions).map_err(|_| {
eyre::eyre!("Num partitions must be a positive integer, got {num_partitions}")
})?;

ensure!(
num_partitions > 0,
"Num partitions must be greater than 0, got {num_partitions}"
);

let mut plan = Box::pin(to_logical_plan(*input)).await?;

// let's make true is default
let shuffle = shuffle.unwrap_or(true);

if !shuffle {
bail!("Repartitioning without shuffling is not yet supported");
}

plan.builder = plan
.builder
.random_shuffle(Some(num_partitions))
.wrap_err("Failed to apply random_shuffle to logical plan")?;

Ok(plan)
}
13 changes: 13 additions & 0 deletions tests/connect/test_repartition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

def test_repartition(spark_session):
# Create a simple DataFrame
df = spark_session.range(10)

# Test repartitioning to 2 partitions
repartitioned = df.repartition(2)

# Verify data is preserved after repartitioning
original_data = sorted(df.collect())
repartitioned_data = sorted(repartitioned.collect())
assert repartitioned_data == original_data, "Data should be preserved after repartitioning"

0 comments on commit 7d2d9d6

Please sign in to comment.