Skip to content

Commit

Permalink
feat(python,rust,cli): add SQL support for null-aware equality checks (
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored and c-peters committed Jul 14, 2023
1 parent 43453f0 commit 41b3947
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
7 changes: 7 additions & 0 deletions polars/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ impl SqlExprVisitor<'_> {
list,
negated,
} => self.visit_is_in(expr, list, *negated),
SqlExpr::IsDistinctFrom(e1, e2) => {
Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
}
SqlExpr::IsNotDistinctFrom(e1, e2) => {
Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
}
SqlExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
SqlExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
SqlExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
Expand Down Expand Up @@ -171,6 +177,7 @@ impl SqlExprVisitor<'_> {
SQLBinaryOperator::NotEq => left.eq(right).not(),
SQLBinaryOperator::Or => left.or(right),
SQLBinaryOperator::Plus => left + right,
SQLBinaryOperator::Spaceship => left.eq_missing(right),
SQLBinaryOperator::StringConcat => {
left.cast(DataType::Utf8) + right.cast(DataType::Utf8)
}
Expand Down
34 changes: 34 additions & 0 deletions py-polars/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

import polars as pl
import polars.selectors as cs
from polars.testing import assert_frame_equal


Expand Down Expand Up @@ -79,6 +80,39 @@ def test_sql_div() -> None:
)


def test_sql_equal_not_equal() -> None:
# validate null-aware/unaware equality comparisons
df = pl.DataFrame({"a": [1, None, 3, 6, 5], "b": [1, None, 3, 4, None]})

with pl.SQLContext(frame_data=df) as ctx:
out = ctx.execute(
"""
SELECT
-- not null-aware
(a = b) as "1_eq_unaware",
(a <> b) as "2_neq_unaware",
(a != b) as "3_neq_unaware",
-- null-aware
(a <=> b) as "4_eq_aware",
(a IS NOT DISTINCT FROM b) as "5_eq_aware",
(a IS DISTINCT FROM b) as "6_neq_aware",
FROM frame_data
"""
).collect()

assert out.select(cs.contains("_aware").null_count().sum()).row(0) == (0, 0, 0)
assert out.select(cs.contains("_unaware").null_count().sum()).row(0) == (2, 2, 2)

assert out.to_dict(False) == {
"1_eq_unaware": [True, None, True, False, None],
"2_neq_unaware": [False, None, False, True, None],
"3_neq_unaware": [False, None, False, True, None],
"4_eq_aware": [True, True, True, False, False],
"5_eq_aware": [True, True, True, False, False],
"6_neq_aware": [False, False, False, True, True],
}


def test_sql_groupby(foods_ipc_path: Path) -> None:
lf = pl.scan_ipc(foods_ipc_path)

Expand Down

0 comments on commit 41b3947

Please sign in to comment.