Skip to content

Commit

Permalink
fix(backends): make string concat-with-null behavior consistent acros…
Browse files Browse the repository at this point in the history
…s backends (ibis-project#8305)

Fixes ibis-project#8302.
  • Loading branch information
cpcloud committed Feb 12, 2024
1 parent 0a4ba95 commit cc2a395
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 15 deletions.
7 changes: 7 additions & 0 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,13 @@ def visit_IsInf(self, op, *, arg):
def visit_ArrayIndex(self, op, *, arg, index):
return self.f.array_element(arg, index + self.cast(index >= 0, op.index.dtype))

@visit_node.register(ops.StringConcat)
def visit_StringConcat(self, op, *, arg):
any_args_null = (a.is_(NULL) for a in arg)
return self.if_(
sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
)

@visit_node.register(ops.Arbitrary)
@visit_node.register(ops.ArgMax)
@visit_node.register(ops.ArgMin)
Expand Down
4 changes: 4 additions & 0 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ def visit_HexDigest(self, op, *, arg, how):
else:
raise NotImplementedError(f"No available hashing function for {how}")

@visit_node.register(ops.StringConcat)
def visit_StringConcat(self, op, *, arg):
return reduce(lambda x, y: sge.DPipe(this=x, expression=y), arg)


_SIMPLE_OPS = {
ops.ArrayPosition: "list_indexof",
Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/exasol/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from functools import singledispatchmethod

import sqlglot as sg
import sqlglot.expressions as sge

import ibis.common.exceptions as com
Expand Down Expand Up @@ -101,6 +102,11 @@ def visit_StringContains(self, op, *, haystack, needle):
def visit_ExtractSecond(self, op, *, arg):
return self.f.floor(self.cast(self.f.extract(self.v.second, arg), op.dtype))

@visit_node.register(ops.StringConcat)
def visit_StringConcat(self, op, *, arg):
any_args_null = (a.is_(NULL) for a in arg)
return self.if_(sg.or_(*any_args_null), NULL, self.f.concat(*arg))

@visit_node.register(ops.AnalyticVectorizedUDF)
@visit_node.register(ops.ApproxMedian)
@visit_node.register(ops.Arbitrary)
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,11 @@ def visit_HexDigest(self, op, *, arg, how):
)
)

@visit_node.register(ops.StringConcat)
def visit_StringConcat(self, op, *, arg):
any_args_null = (a.is_(NULL) for a in arg)
return self.if_(sg.or_(*any_args_null), NULL, self.f.concat(*arg))

@visit_node.register(ops.Any)
@visit_node.register(ops.All)
@visit_node.register(ops.ApproxMedian)
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/oracle/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@ def visit_Window(self, op, *, how, func, start, end, group_by, order_by):

return sge.Window(this=func, partition_by=group_by, order=order, spec=spec)

@visit_node.register(ops.StringConcat)
def visit_StringConcat(self, op, *, arg):
any_args_null = (a.is_(NULL) for a in arg)
return self.if_(sg.or_(*any_args_null), NULL, self.f.concat(*arg))

@visit_node.register(ops.Arbitrary)
@visit_node.register(ops.ArgMax)
@visit_node.register(ops.ArgMin)
Expand Down
4 changes: 4 additions & 0 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ def _sign(value, dtype):
self.cast(self.f.array(), op.dtype),
)

@visit_node.register(ops.StringConcat)
def visit_StringConcat(self, op, *, arg):
return reduce(lambda x, y: sge.DPipe(this=x, expression=y), arg)

@visit_node.register(ops.ArrayConcat)
def visit_ArrayConcat(self, op, *, arg):
return reduce(self.f.array_cat, map(partial(self.cast, to=op.dtype), arg))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@ FROM (
"t1"."ancestor_level_number",
"t1"."ancestor_node_sort_order",
"t1"."descendant_node_natural_key",
CONCAT(
LPAD(
'-',
(
"t1"."ancestor_level_number" - CAST(1 AS TINYINT)
) * CAST(7 AS TINYINT),
'-'
),
"t1"."ancestor_level_name"
) AS "product_level_name"
LPAD(
'-',
(
"t1"."ancestor_level_number" - CAST(1 AS TINYINT)
) * CAST(7 AS TINYINT),
'-'
) || "t1"."ancestor_level_name" AS "product_level_name"
FROM "products" AS "t1"
) AS "t4"
ON "t2"."product_id" = "t4"."descendant_node_natural_key"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ FROM "t" AS "t0"
WHERE
"t0"."a" = CAST(1 AS TINYINT)
ORDER BY
CONCAT("t0"."b", 'a') ASC
"t0"."b" || 'a' ASC
35 changes: 35 additions & 0 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import contextlib
from functools import reduce
from operator import add

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1033,3 +1035,36 @@ def test_re_split_column_multiple_patterns(alltypes):
)
result = expr.execute()
assert all(not any(element) for element in result)


@pytest.mark.parametrize(
"fn",
[lambda n: n + "a", lambda n: n + n, lambda n: "a" + n],
ids=["null-a", "null-null", "a-null"],
)
@pytest.mark.notimpl(["pandas", "dask"], raises=TypeError)
def test_concat_with_null(con, fn):
null = ibis.literal(None, type="string")
expr = fn(null)
result = con.execute(expr)
assert pd.isna(result)


@pytest.mark.parametrize(
"args",
[
param((ibis.literal(None, str), None), id="null-null"),
param((ibis.literal("abc"), None), id="abc-null"),
param((ibis.literal("abc"), ibis.literal(None, str)), id="abc-typed-null"),
param((ibis.literal("abc"), "def", None), id="abc-def-null"),
],
)
@pytest.mark.parametrize(
"method",
[lambda args: args[0].concat(*args[1:]), lambda args: reduce(add, args)],
ids=["concat", "add"],
)
@pytest.mark.notimpl(["pandas", "dask"], raises=TypeError)
def test_concat(con, args, method):
expr = method(args)
assert pd.isna(con.execute(expr))
18 changes: 14 additions & 4 deletions ibis/expr/types/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,8 @@ def split(self, delimiter: str | StringValue) -> ir.ArrayValue:
def concat(self, other: str | StringValue, *args: str | StringValue) -> StringValue:
"""Concatenate strings.
NULLs are propagated. This methods is equivalent to using the `+` operator.
Parameters
----------
other
Expand All @@ -1506,16 +1508,24 @@ def concat(self, other: str | StringValue, *args: str | StringValue) -> StringVa
--------
>>> import ibis
>>> ibis.options.interactive = True
>>> t = ibis.memtable({"s": ["abc", "bac", "bca"]})
>>> t.s.concat("xyz")
>>> t = ibis.memtable({"s": ["abc", None]})
>>> t.s.concat("xyz", "123")
┏━━━━━━━━━━━━━━━━┓
┃ StringConcat() ┃
┡━━━━━━━━━━━━━━━━┩
│ string │
├────────────────┤
│ abcxyz123 │
│ NULL │
└────────────────┘
>>> t.s + "xyz"
┏━━━━━━━━━━━━━━━━┓
┃ StringConcat() ┃
┡━━━━━━━━━━━━━━━━┩
│ string │
├────────────────┤
│ abcxyz │
│ bacxyz │
│ bcaxyz │
│ NULL │
└────────────────┘
"""
return ops.StringConcat((self, other, *args)).to_expr()
Expand Down

0 comments on commit cc2a395

Please sign in to comment.