Skip to content

Commit

Permalink
feat(datafusion): add isnull and isnan operations
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed Nov 5, 2023
1 parent 12118ad commit 0076c25
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 17 deletions.
2 changes: 2 additions & 0 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sqlglot as sg
from sqlglot import exp, transforms
from sqlglot.dialects import Postgres
from sqlglot.dialects.dialect import rename_func

import ibis
import ibis.common.exceptions as com
Expand Down Expand Up @@ -58,6 +59,7 @@ class Generator(Postgres.Generator):
transforms.eliminate_qualify,
]
),
exp.IsNan: rename_func("isnan"),
}


Expand Down
16 changes: 14 additions & 2 deletions ibis/backends/datafusion/compiler/values.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
import math
import operator
from typing import Any

Expand All @@ -11,7 +12,6 @@
import ibis.expr.operations as ops
from ibis.backends.base.sqlglot import (
NULL,
STAR,
AggGen,
F,
interval,
Expand Down Expand Up @@ -202,6 +202,8 @@ def _literal(op, *, value, dtype, **kw):
elif dtype.is_string() or dtype.is_macaddr():
return sg.exp.convert(str(value))
elif dtype.is_numeric():
if isinstance(value, float) and math.isinf(value):
return sg.exp.Literal.number("'+Inf'::double")
return sg.exp.convert(value)
elif dtype.is_interval():
if dtype.unit.short in {"ms", "us", "ns"}:
Expand Down Expand Up @@ -324,7 +326,7 @@ def count_distinct(op, *, arg, where, **_):

@translate_val.register(ops.CountStar)
def count_star(op, *, where, **_):
return agg.count(STAR, where=where)
return agg.count(1, where=where)


@translate_val.register(ops.Sum)
Expand Down Expand Up @@ -764,3 +766,13 @@ def correlation(op, *, left, right, where, **_):
right = cast(right, dt.float64)

return agg["corr"](left, right, where=where)


@translate_val.register(ops.IsNull)
def is_null(op, *, arg, **_):
return arg.is_(NULL)


@translate_val.register(ops.IsNan)
def is_nan(op, *, arg, **_):
return F.isnan(F.coalesce(arg, sg.exp.Literal.number("'NaN'::double")))
17 changes: 3 additions & 14 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ def test_scalar_fillna_nullif(con, expr, expected):
param(
"nan_col",
_.nan_col.isnan(),
marks=pytest.mark.notimpl(["datafusion", "mysql", "sqlite"]),
marks=pytest.mark.notimpl(["mysql", "sqlite"]),
id="nan_col",
),
param(
"none_col",
_.none_col.isnull(),
marks=[pytest.mark.notimpl(["datafusion", "mysql"])],
marks=[pytest.mark.notimpl(["mysql"])],
id="none_col",
),
],
Expand Down Expand Up @@ -376,7 +376,7 @@ def test_case_where(backend, alltypes, df):


# TODO: some of these are notimpl (datafusion) others are probably never
@pytest.mark.notimpl(["datafusion", "mysql", "sqlite", "mssql", "druid", "oracle"])
@pytest.mark.notimpl(["mysql", "sqlite", "mssql", "druid", "oracle"])
@pytest.mark.notyet(["flink"], "NaN is not supported in Flink SQL", raises=ValueError)
def test_select_filter_mutate(backend, alltypes, df):
"""Test that select, filter and mutate are executed in right order.
Expand Down Expand Up @@ -565,7 +565,6 @@ def test_order_by_random(alltypes):
raises=sa.exc.ProgrammingError,
reason="Druid only supports trivial unions",
)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_table_info(alltypes):
expr = alltypes.info()
df = expr.execute()
Expand Down Expand Up @@ -1372,11 +1371,6 @@ def test_try_cast_func(con, from_val, to_type, func):
raises=BadRequest,
reason="bigquery doesn't support OFFSET without LIMIT",
),
pytest.mark.notyet(
["datafusion"],
raises=AssertionError,
reason="no support for offset yet",
),
pytest.mark.notyet(
["mssql"],
raises=sa.exc.CompileError,
Expand All @@ -1401,11 +1395,6 @@ def test_try_cast_func(con, from_val, to_type, func):
lambda _: 1,
id="[3:4]",
marks=[
pytest.mark.notyet(
["datafusion"],
raises=AssertionError,
reason="no support for offset yet",
),
pytest.mark.notyet(
["mssql"],
raises=sa.exc.CompileError,
Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,11 +753,17 @@ def test_decimal_literal(con, backend, expr, expected_types, expected_result):
operator.methodcaller("isinf"),
np.isinf,
id="isinf",
marks=[
pytest.mark.notimpl(
["datafusion"],
raises=com.OperationNotDefinedError,
)
],
),
],
)
@pytest.mark.notimpl(
["mysql", "sqlite", "datafusion", "mssql", "oracle", "flink"],
["mysql", "sqlite", "mssql", "oracle", "flink"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.xfail(
Expand Down

0 comments on commit 0076c25

Please sign in to comment.