Skip to content

Commit

Permalink
Update test_sort.py
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 4, 2024
1 parent 96332b1 commit 1c94dc2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 21 deletions.
4 changes: 3 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use eyre::{bail, Context};
use spark_connect::{relation::RelType, Limit, Relation};
use tracing::warn;

use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range, sort::sort};
use crate::translation::logical_plan::{
aggregate::aggregate, project::project, range::range, sort::sort,
};

mod aggregate;
mod project;
Expand Down
17 changes: 5 additions & 12 deletions src/daft-connect/src/translation/logical_plan/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,24 @@ pub fn sort(sort: spark_connect::Sort) -> eyre::Result<LogicalPlanBuilder> {
let null_ordering = NullOrdering::try_from(*null_ordering)
.wrap_err_with(|| format!("Invalid null ordering: {null_ordering:?}"))?;

// todo(correctness): is this correct?
let is_descending = match direction {
SortDirection::Unspecified => {
bail!("Unspecified sort direction is not yet supported")
// default to ascending order
false
}
SortDirection::Ascending => false,
SortDirection::Descending => true,
};

// todo(correctness): is this correct?
let tentative_sort_nulls_first = match null_ordering {
let sort_nulls_first = match null_ordering {
NullOrdering::SortNullsUnspecified => {
bail!("Unspecified null ordering is not yet supported")
// default: match is_descending
is_descending
}
NullOrdering::SortNullsFirst => true,
NullOrdering::SortNullsLast => false,
};

// https://github.com/Eventual-Inc/Daft/blob/7922d2d810ff92b00008d877aa9a6553bc0dedab/src/daft-core/src/utils/mod.rs#L10-L19
let sort_nulls_first = is_descending;

if sort_nulls_first != tentative_sort_nulls_first {
warn!("Ignoring nulls_first {sort_nulls_first}; not yet implemented");
}

sort_by.push(child);
descending.push(is_descending);
nulls_first.push(sort_nulls_first);
Expand Down
29 changes: 21 additions & 8 deletions tests/connect/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,27 @@
from pyspark.sql.functions import col


def test_sort_multiple_columns(spark_session):
    """Sort a two-column DataFrame by ``num`` ascending, then ``letter`` descending.

    Args:
        spark_session: Session object used to build the DataFrame via ``range``.
    """
    # Create DataFrame with two columns using range
    df = spark_session.range(4).select(col("id").alias("num"), col("id").alias("letter"))

    # Sort by multiple columns
    df_sorted = df.sort(col("num").asc(), col("letter").desc())

    # Verify the DataFrame is sorted correctly.
    # NOTE(review): both columns are aliases of the same `id` values, so `num`
    # has no ties and the secondary key (`letter` descending) never changes
    # the resulting order.
    actual = df_sorted.collect()
    expected = [(0, 0), (1, 1), (2, 2), (3, 3)]
    assert [(row.num, row.letter) for row in actual] == expected


def test_sort_mixed_order(spark_session):
    """Sort by ``num`` descending, then ``letter`` ascending, and check the row order.

    Args:
        spark_session: Session object used to build the DataFrame via ``range``.
    """
    # Two columns, each an alias of the `id` values produced by range(4).
    base = spark_session.range(4)
    frame = base.select(col("id").alias("num"), col("id").alias("letter"))

    # Mixed directions: primary key descending, secondary key ascending.
    ordering = (col("num").desc(), col("letter").asc())
    rows = frame.sort(*ordering).collect()

    # Compare the collected rows as plain (num, letter) tuples.
    pairs = [(r.num, r.letter) for r in rows]
    assert pairs == [(3, 3), (2, 2), (1, 1), (0, 0)]

0 comments on commit 1c94dc2

Please sign in to comment.