Skip to content

Commit

Permalink
[FEAT]: connect: df.sort
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 4, 2024
1 parent caf0626 commit 96332b1
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use eyre::{bail, Context};
use spark_connect::{relation::RelType, Limit, Relation};
use tracing::warn;

use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range};
use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range, sort::sort};

mod aggregate;
mod project;
mod range;
mod sort;

pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
if let Some(common) = relation.common {
Expand All @@ -25,6 +26,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
RelType::Aggregate(a) => {
aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
}
RelType::Sort(s) => sort(*s).wrap_err("Failed to apply sort to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
Expand Down
83 changes: 83 additions & 0 deletions src/daft-connect/src/translation/logical_plan/sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use eyre::{bail, WrapErr};
use spark_connect::expression::{
sort_order::{NullOrdering, SortDirection},
SortOrder,
};
use tracing::warn;

use crate::translation::{logical_plan::LogicalPlanBuilder, to_daft_expr, to_logical_plan};

pub fn sort(sort: spark_connect::Sort) -> eyre::Result<LogicalPlanBuilder> {
let spark_connect::Sort {
input,
order,
is_global,
} = sort;

if let Some(is_global) = is_global {
warn!("Ignoring is_global {is_global}; not yet implemented");
}

let Some(input) = input else {
bail!("Input is required");
};

let plan = to_logical_plan(*input)?;

let mut sort_by = Vec::new();
let mut descending = Vec::new();
let mut nulls_first = Vec::new();

for o in &order {
let SortOrder {
child,
direction,
null_ordering,
} = o;

let Some(child) = child else {
bail!("Child is required");
};

let child = to_daft_expr(child)?;

let direction = SortDirection::try_from(*direction)
.wrap_err_with(|| format!("Invalid sort direction: {direction:?}"))?;

let null_ordering = NullOrdering::try_from(*null_ordering)
.wrap_err_with(|| format!("Invalid null ordering: {null_ordering:?}"))?;

// todo(correctness): is this correct?
let is_descending = match direction {
SortDirection::Unspecified => {
bail!("Unspecified sort direction is not yet supported")
}
SortDirection::Ascending => false,
SortDirection::Descending => true,
};

// todo(correctness): is this correct?
let tentative_sort_nulls_first = match null_ordering {
NullOrdering::SortNullsUnspecified => {
bail!("Unspecified null ordering is not yet supported")
}
NullOrdering::SortNullsFirst => true,
NullOrdering::SortNullsLast => false,
};

// https://github.com/Eventual-Inc/Daft/blob/7922d2d810ff92b00008d877aa9a6553bc0dedab/src/daft-core/src/utils/mod.rs#L10-L19
let sort_nulls_first = is_descending;

if sort_nulls_first != tentative_sort_nulls_first {
warn!("Ignoring nulls_first {sort_nulls_first}; not yet implemented");
}

sort_by.push(child);
descending.push(is_descending);
nulls_first.push(sort_nulls_first);
}

let plan = plan.sort(sort_by, descending, nulls_first)?;

Ok(plan)
}
16 changes: 16 additions & 0 deletions tests/connect/test_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations

from pyspark.sql.functions import col


def test_sort(spark_session):
# Create DataFrame from range(10)
df = spark_session.range(10)

# Sort the DataFrame by 'id' column in descending order
df_sorted = df.sort(col("id").desc())

# Verify the DataFrame is sorted correctly
df_pandas = df.toPandas()
df_sorted_pandas = df_sorted.toPandas()
assert df_sorted_pandas["id"].equals(df_pandas["id"].sort_values(ascending=False).reset_index(drop=True)), "Data should be sorted in descending order"

0 comments on commit 96332b1

Please sign in to comment.