diff --git a/ibis/backends/datafusion/compiler/values.py b/ibis/backends/datafusion/compiler/values.py index d59822fc0abd..0df4a1ef8867 100644 --- a/ibis/backends/datafusion/compiler/values.py +++ b/ibis/backends/datafusion/compiler/values.py @@ -108,6 +108,7 @@ def translate_val(op, **_): ops.ArrayContains: "array_contains", ops.ArrayLength: "array_length", ops.ArrayRemove: "array_remove_all", + ops.StringLength: "length", } for _op, _name in _simple_ops.items(): @@ -303,7 +304,7 @@ def round(op, *, arg, digits, **_): @translate_val.register(ops.Substring) def substring(op, *, arg, start, length, **_): - start += 1 + start = if_(start < 0, F.length(arg) + start + 1, start + 1) if length is not None: return F.substr(arg, start, length) return F.substr(arg, start) @@ -776,3 +777,13 @@ def is_null(op, *, arg, **_): @translate_val.register(ops.IsNan) def is_nan(op, *, arg, **_): return F.isnan(F.coalesce(arg, sg.exp.Literal.number("'NaN'::double"))) + + +@translate_val.register(ops.ArrayStringJoin) +def array_string_join(op, *, sep, arg): + return F.array_join(arg, sep) + + +@translate_val.register(ops.FindInSet) +def array_string_find(op, *, needle, values): + return F.coalesce(F.array_position(F.make_array(*values), needle), 0) diff --git a/ibis/backends/datafusion/tests/test_string.py b/ibis/backends/datafusion/tests/test_string.py new file mode 100644 index 000000000000..1a927e3fede3 --- /dev/null +++ b/ibis/backends/datafusion/tests/test_string.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +import ibis + + +def test_string_length(con): + t = ibis.memtable({"s": ["aaa", "a", "aa"]}) + assert con.execute(t.s.length()).gt(0).all() diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index e2dea996bed1..884afe0b2cbc 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -466,7 +466,6 @@ def uses_java_re(t): pytest.mark.notimpl( [ "bigquery", - "datafusion", "pyspark", "sqlite", "snowflake", @@ -493,7 +492,6 @@ def uses_java_re(t): pytest.mark.notimpl( [ "bigquery", - "datafusion", "pyspark", "sqlite", "snowflake", @@ -703,9 +701,7 @@ def uses_java_re(t): id="negative-index", marks=[ pytest.mark.broken(["druid"], raises=sa.exc.ProgrammingError), - pytest.mark.broken( - ["datafusion", "impala", "flink"], raises=AssertionError - ), + pytest.mark.broken(["impala", "flink"], raises=AssertionError), pytest.mark.notimpl(["pyspark"], raises=NotImplementedError), ], ), @@ -1012,7 +1008,7 @@ def test_capitalize(con): @pytest.mark.notimpl( - ["dask", "datafusion", "pandas", "polars", "druid", "oracle", "flink"], + ["dask", "pandas", "polars", "druid", "oracle", "flink"], raises=OperationNotDefinedError, ) @pytest.mark.notyet(