diff --git a/polars/polars-sql/src/sql_expr.rs b/polars/polars-sql/src/sql_expr.rs index 1ab8757806d7..1700d2e7b9ab 100644 --- a/polars/polars-sql/src/sql_expr.rs +++ b/polars/polars-sql/src/sql_expr.rs @@ -75,6 +75,12 @@ impl SqlExprVisitor<'_> { list, negated, } => self.visit_is_in(expr, list, *negated), + SqlExpr::IsDistinctFrom(e1, e2) => { + Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?)) + } + SqlExpr::IsNotDistinctFrom(e1, e2) => { + Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?)) + } SqlExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))), SqlExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()), SqlExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()), @@ -171,6 +177,7 @@ impl SqlExprVisitor<'_> { SQLBinaryOperator::NotEq => left.eq(right).not(), SQLBinaryOperator::Or => left.or(right), SQLBinaryOperator::Plus => left + right, + SQLBinaryOperator::Spaceship => left.eq_missing(right), SQLBinaryOperator::StringConcat => { left.cast(DataType::Utf8) + right.cast(DataType::Utf8) } diff --git a/py-polars/tests/unit/test_sql.py b/py-polars/tests/unit/test_sql.py index 7d199e517d4d..00f0051d11e8 100644 --- a/py-polars/tests/unit/test_sql.py +++ b/py-polars/tests/unit/test_sql.py @@ -7,6 +7,7 @@ import pytest import polars as pl +import polars.selectors as cs from polars.testing import assert_frame_equal @@ -79,6 +80,39 @@ def test_sql_div() -> None: ) +def test_sql_equal_not_equal() -> None: + # validate null-aware/unaware equality comparisons + df = pl.DataFrame({"a": [1, None, 3, 6, 5], "b": [1, None, 3, 4, None]}) + + with pl.SQLContext(frame_data=df) as ctx: + out = ctx.execute( + """ + SELECT + -- not null-aware + (a = b) as "1_eq_unaware", + (a <> b) as "2_neq_unaware", + (a != b) as "3_neq_unaware", + -- null-aware + (a <=> b) as "4_eq_aware", + (a IS NOT DISTINCT FROM b) as "5_eq_aware", + (a IS DISTINCT FROM b) as "6_neq_aware", + FROM frame_data + """ + ).collect() + + assert out.select(cs.contains("_aware").null_count().sum()).row(0) == (0, 0, 0) + assert out.select(cs.contains("_unaware").null_count().sum()).row(0) == (2, 2, 2) + + assert out.to_dict(False) == { + "1_eq_unaware": [True, None, True, False, None], + "2_neq_unaware": [False, None, False, True, None], + "3_neq_unaware": [False, None, False, True, None], + "4_eq_aware": [True, True, True, False, False], + "5_eq_aware": [True, True, True, False, False], + "6_neq_aware": [False, False, False, True, True], + } + + def test_sql_groupby(foods_ipc_path: Path) -> None: lf = pl.scan_ipc(foods_ipc_path)