Skip to content

Commit

Permalink
feat(datafusion): add StringLength, FindInSet, ArrayStringJoin
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and gforsyth committed Nov 6, 2023
1 parent d87bd8f commit fd03831
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
13 changes: 12 additions & 1 deletion ibis/backends/datafusion/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def translate_val(op, **_):
ops.ArrayContains: "array_contains",
ops.ArrayLength: "array_length",
ops.ArrayRemove: "array_remove_all",
ops.StringLength: "length",
}

for _op, _name in _simple_ops.items():
Expand Down Expand Up @@ -303,7 +304,7 @@ def round(op, *, arg, digits, **_):

@translate_val.register(ops.Substring)
def substring(op, *, arg, start, length, **_):
# Translate ops.Substring into a call to F.substr.
# NOTE(review): this hunk is rendered without +/- markers. The file header
# reports a single deletion for values.py, so the `start += 1` line below
# appears to be the removed pre-commit line and the `if_(...)` line its
# replacement -- confirm against the real diff; the two are presumably not
# meant to execute together.
start += 1
# For a negative start, F.length(arg) + 1 is added to it; otherwise only 1 is
# added.
start = if_(start < 0, F.length(arg) + start + 1, start + 1)
# Pass length through to substr only when one was supplied.
if length is not None:
return F.substr(arg, start, length)
return F.substr(arg, start)
Expand Down Expand Up @@ -776,3 +777,13 @@ def is_null(op, *, arg, **_):
@translate_val.register(ops.IsNan)
def is_nan(op, *, arg, **_):
    """Translate ``ops.IsNan`` into an ``isnan`` call on a coalesced operand.

    ``arg`` is wrapped in ``coalesce`` together with a ``'NaN'::double``
    literal, and the resulting expression is handed to ``isnan``.
    """
    # Fallback operand given to coalesce alongside the original argument.
    nan_fallback = sg.exp.Literal.number("'NaN'::double")
    coalesced = F.coalesce(arg, nan_fallback)
    return F.isnan(coalesced)


@translate_val.register(ops.ArrayStringJoin)
def array_string_join(op, *, sep, arg, **_):
    """Translate ``ops.ArrayStringJoin`` into an ``array_join`` call.

    ``arg`` (the array expression) is passed first and ``sep`` (the
    separator expression) second. Any additional keyword arguments are
    accepted and ignored, matching the other ``translate_val`` handlers.
    """
    return F.array_join(arg, sep)


@translate_val.register(ops.FindInSet)
def array_string_find(op, *, needle, values, **_):
    """Translate ``ops.FindInSet`` into an ``array_position`` lookup.

    ``values`` is unpacked into ``make_array`` and ``needle`` is searched for
    with ``array_position``; the lookup is wrapped in ``coalesce`` with a
    fallback of ``0``. Any additional keyword arguments are accepted and
    ignored, matching the other ``translate_val`` handlers.
    """
    # NOTE(review): the result is whatever array_position yields, or 0 when
    # that is NULL -- confirm the index base matches what FindInSet callers
    # expect.
    haystack = F.make_array(*values)
    return F.coalesce(F.array_position(haystack, needle), 0)
8 changes: 8 additions & 0 deletions ibis/backends/datafusion/tests/test_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations

import ibis


def test_string_length(con):
    """Check that every computed string length is greater than zero."""
    strings = ibis.memtable({"s": ["aaa", "a", "aa"]})
    lengths = con.execute(strings.s.length())
    assert lengths.gt(0).all()
8 changes: 2 additions & 6 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,6 @@ def uses_java_re(t):
pytest.mark.notimpl(
[
"bigquery",
"datafusion",
"pyspark",
"sqlite",
"snowflake",
Expand All @@ -493,7 +492,6 @@ def uses_java_re(t):
pytest.mark.notimpl(
[
"bigquery",
"datafusion",
"pyspark",
"sqlite",
"snowflake",
Expand Down Expand Up @@ -703,9 +701,7 @@ def uses_java_re(t):
id="negative-index",
marks=[
pytest.mark.broken(["druid"], raises=sa.exc.ProgrammingError),
pytest.mark.broken(
["datafusion", "impala", "flink"], raises=AssertionError
),
pytest.mark.broken(["impala", "flink"], raises=AssertionError),
pytest.mark.notimpl(["pyspark"], raises=NotImplementedError),
],
),
Expand Down Expand Up @@ -1012,7 +1008,7 @@ def test_capitalize(con):


@pytest.mark.notimpl(
["dask", "datafusion", "pandas", "polars", "druid", "oracle", "flink"],
["dask", "pandas", "polars", "druid", "oracle", "flink"],
raises=OperationNotDefinedError,
)
@pytest.mark.notyet(
Expand Down

0 comments on commit fd03831

Please sign in to comment.