Skip to content

Commit

Permalink
Update test_sort.py
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 25, 2024
1 parent e6833f7 commit b919cd7
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
6 changes: 4 additions & 2 deletions src/daft-connect/src/translation/logical_plan/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ pub fn sort(sort: spark_connect::Sort) -> eyre::Result<LogicalPlanBuilder> {
// todo(correctness): is this correct?
let is_descending = match direction {
SortDirection::Unspecified => {
bail!("Unspecified sort direction is not yet supported")
// default to ascending order
false
}
SortDirection::Ascending => false,
SortDirection::Descending => true,
Expand All @@ -59,7 +60,8 @@ pub fn sort(sort: spark_connect::Sort) -> eyre::Result<LogicalPlanBuilder> {
// todo(correctness): is this correct?
let tentative_sort_nulls_first = match null_ordering {
NullOrdering::SortNullsUnspecified => {
bail!("Unspecified null ordering is not yet supported")
// default: match is_descending
is_descending
}
NullOrdering::SortNullsFirst => true,
NullOrdering::SortNullsLast => false,
Expand Down
35 changes: 27 additions & 8 deletions tests/connect/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,33 @@
from pyspark.sql.functions import col


def test_sort(spark_session):
# Create DataFrame from range(10)
df = spark_session.range(10)
def test_sort_multiple_columns(spark_session):
# Create DataFrame with two columns using range
df = spark_session.range(4).select(
(col("id") % 2).alias("num"),
(col("id") % 2).cast("string").alias("letter")
)

# Sort the DataFrame by 'id' column in descending order
df_sorted = df.sort(col("id").desc())
# Sort by multiple columns
df_sorted = df.sort(col("num").asc(), col("letter").desc())

# Verify the DataFrame is sorted correctly
df_pandas = df.toPandas()
df_sorted_pandas = df_sorted.toPandas()
assert df_sorted_pandas["id"].equals(df_pandas["id"].sort_values(ascending=False).reset_index(drop=True)), "Data should be sorted in descending order"
actual = df_sorted.collect()
expected = [(0, "0"), (0, "0"), (1, "1"), (1, "1")]
assert [(row.num, row.letter) for row in actual] == expected


def test_sort_mixed_order(spark_session):
# Create DataFrame with two columns using range
df = spark_session.range(4).select(
(col("id") % 2).alias("num"),
(col("id") % 2).cast("string").alias("letter")
)

# Sort with mixed ascending/descending order
df_sorted = df.sort(col("num").desc(), col("letter").asc())

# Verify the DataFrame is sorted correctly
actual = df_sorted.collect()
expected = [(1, "1"), (1, "1"), (0, "0"), (0, "0")]
assert [(row.num, row.letter) for row in actual] == expected

0 comments on commit b919cd7

Please sign in to comment.