Skip to content

Commit

Permalink
Update test_sort.py
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 11, 2024
1 parent 8a9c54a commit e5b129a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
17 changes: 5 additions & 12 deletions src/daft-connect/src/translation/logical_plan/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,24 @@ pub async fn sort(sort: spark_connect::Sort) -> eyre::Result<Plan> {
let null_ordering = NullOrdering::try_from(*null_ordering)
.wrap_err_with(|| format!("Invalid null ordering: {null_ordering:?}"))?;

// todo(correctness): is this correct?
let is_descending = match direction {
SortDirection::Unspecified => {
bail!("Unspecified sort direction is not yet supported")
// default to ascending order
false
}
SortDirection::Ascending => false,
SortDirection::Descending => true,
};

// todo(correctness): is this correct?
let tentative_sort_nulls_first = match null_ordering {
let sort_nulls_first = match null_ordering {
NullOrdering::SortNullsUnspecified => {
bail!("Unspecified null ordering is not yet supported")
// default: match is_descending
is_descending
}
NullOrdering::SortNullsFirst => true,
NullOrdering::SortNullsLast => false,
};

// https://github.com/Eventual-Inc/Daft/blob/7922d2d810ff92b00008d877aa9a6553bc0dedab/src/daft-core/src/utils/mod.rs#L10-L19
let sort_nulls_first = is_descending;

if sort_nulls_first != tentative_sort_nulls_first {
warn!("Ignoring nulls_first {sort_nulls_first}; not yet implemented");
}

sort_by.push(child);
descending.push(is_descending);
nulls_first.push(sort_nulls_first);
Expand Down
29 changes: 21 additions & 8 deletions tests/connect/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,27 @@
from pyspark.sql.functions import col


def test_sort(spark_session):
# Create DataFrame from range(10)
df = spark_session.range(10)
def test_sort_multiple_columns(spark_session):
# Create DataFrame with two columns using range
df = spark_session.range(4).select(col("id").alias("num"), col("id").alias("letter"))

# Sort the DataFrame by 'id' column in descending order
df_sorted = df.sort(col("id").desc())
# Sort by multiple columns
df_sorted = df.sort(col("num").asc(), col("letter").desc())

# Verify the DataFrame is sorted correctly
df_pandas = df.toPandas()
df_sorted_pandas = df_sorted.toPandas()
assert df_sorted_pandas["id"].equals(df_pandas["id"].sort_values(ascending=False).reset_index(drop=True)), "Data should be sorted in descending order"
actual = df_sorted.collect()
expected = [(0, 0), (1, 1), (2, 2), (3, 3)]
assert [(row.num, row.letter) for row in actual] == expected


def test_sort_mixed_order(spark_session):
# Create DataFrame with two columns using range
df = spark_session.range(4).select(col("id").alias("num"), col("id").alias("letter"))

# Sort with mixed ascending/descending order
df_sorted = df.sort(col("num").desc(), col("letter").asc())

# Verify the DataFrame is sorted correctly
actual = df_sorted.collect()
expected = [(3, 3), (2, 2), (1, 1), (0, 0)]
assert [(row.num, row.letter) for row in actual] == expected

0 comments on commit e5b129a

Please sign in to comment.