diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 173c34dbef0a..9c5c131118ff 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -420,6 +420,13 @@ def visit_IsInf(self, op, *, arg): def visit_ArrayIndex(self, op, *, arg, index): return self.f.array_element(arg, index + self.cast(index >= 0, op.index.dtype)) + @visit_node.register(ops.StringConcat) + def visit_StringConcat(self, op, *, arg): + any_args_null = (a.is_(NULL) for a in arg) + return self.if_( + sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg) + ) + @visit_node.register(ops.Arbitrary) @visit_node.register(ops.ArgMax) @visit_node.register(ops.ArgMin) diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index 598018823aab..ee106e12001d 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -359,6 +359,10 @@ def visit_HexDigest(self, op, *, arg, how): else: raise NotImplementedError(f"No available hashing function for {how}") + @visit_node.register(ops.StringConcat) + def visit_StringConcat(self, op, *, arg): + return reduce(lambda x, y: sge.DPipe(this=x, expression=y), arg) + _SIMPLE_OPS = { ops.ArrayPosition: "list_indexof", diff --git a/ibis/backends/exasol/compiler.py b/ibis/backends/exasol/compiler.py index 0be420b225b0..20fedd4dd193 100644 --- a/ibis/backends/exasol/compiler.py +++ b/ibis/backends/exasol/compiler.py @@ -2,6 +2,7 @@ from functools import singledispatchmethod +import sqlglot as sg import sqlglot.expressions as sge import ibis.common.exceptions as com @@ -101,6 +102,11 @@ def visit_StringContains(self, op, *, haystack, needle): def visit_ExtractSecond(self, op, *, arg): return self.f.floor(self.cast(self.f.extract(self.v.second, arg), op.dtype)) + @visit_node.register(ops.StringConcat) + def visit_StringConcat(self, op, *, arg): + any_args_null = (a.is_(NULL) for a in arg) + return self.if_(sg.or_(*any_args_null), NULL, self.f.concat(*arg)) + @visit_node.register(ops.AnalyticVectorizedUDF) @visit_node.register(ops.ApproxMedian) @visit_node.register(ops.Arbitrary) diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index cb74911200d7..388bbb5dea32 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -390,6 +390,11 @@ def visit_HexDigest(self, op, *, arg, how): ) ) + @visit_node.register(ops.StringConcat) + def visit_StringConcat(self, op, *, arg): + any_args_null = (a.is_(NULL) for a in arg) + return self.if_(sg.or_(*any_args_null), NULL, self.f.concat(*arg)) + @visit_node.register(ops.Any) @visit_node.register(ops.All) @visit_node.register(ops.ApproxMedian) diff --git a/ibis/backends/oracle/compiler.py b/ibis/backends/oracle/compiler.py index 0f3e0472fac5..edab44f726a7 100644 --- a/ibis/backends/oracle/compiler.py +++ b/ibis/backends/oracle/compiler.py @@ -441,6 +441,11 @@ def visit_Window(self, op, *, how, func, start, end, group_by, order_by): return sge.Window(this=func, partition_by=group_by, order=order, spec=spec) + @visit_node.register(ops.StringConcat) + def visit_StringConcat(self, op, *, arg): + any_args_null = (a.is_(NULL) for a in arg) + return self.if_(sg.or_(*any_args_null), NULL, self.f.concat(*arg)) + @visit_node.register(ops.Arbitrary) @visit_node.register(ops.ArgMax) @visit_node.register(ops.ArgMin) diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index b421ea51050a..0c9155dc3f20 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -170,6 +170,10 @@ def _sign(value, dtype): self.cast(self.f.array(), op.dtype), ) + @visit_node.register(ops.StringConcat) + def visit_StringConcat(self, op, *, arg): + return reduce(lambda x, y: sge.DPipe(this=x, expression=y), arg) + @visit_node.register(ops.ArrayConcat) def visit_ArrayConcat(self, op, *, arg): return reduce(self.f.array_cat, map(partial(self.cast, to=op.dtype), arg)) diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_no_cart_join/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_no_cart_join/out.sql index 699792bed260..dbe05db44905 100644 --- a/ibis/backends/tests/sql/snapshots/test_sql/test_no_cart_join/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_no_cart_join/out.sql @@ -20,16 +20,13 @@ FROM ( "t1"."ancestor_level_number", "t1"."ancestor_node_sort_order", "t1"."descendant_node_natural_key", - CONCAT( - LPAD( - '-', - ( - "t1"."ancestor_level_number" - CAST(1 AS TINYINT) - ) * CAST(7 AS TINYINT), - '-' - ), - "t1"."ancestor_level_name" - ) AS "product_level_name" + LPAD( + '-', + ( + "t1"."ancestor_level_number" - CAST(1 AS TINYINT) + ) * CAST(7 AS TINYINT), + '-' + ) || "t1"."ancestor_level_name" AS "product_level_name" FROM "products" AS "t1" ) AS "t4" ON "t2"."product_id" = "t4"."descendant_node_natural_key" diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_order_by_expr/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_order_by_expr/out.sql index 1cd433cac8ee..076f3029f1a7 100644 --- a/ibis/backends/tests/sql/snapshots/test_sql/test_order_by_expr/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_order_by_expr/out.sql @@ -5,4 +5,4 @@ FROM "t" AS "t0" WHERE "t0"."a" = CAST(1 AS TINYINT) ORDER BY - CONCAT("t0"."b", 'a') ASC \ No newline at end of file + "t0"."b" || 'a' ASC \ No newline at end of file diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index ea954243f505..989d6dc393db 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -1,6 +1,8 @@ from __future__ import annotations import contextlib +from functools import reduce +from operator import add import numpy as np import pandas as pd @@ -1033,3 +1035,36 @@ def test_re_split_column_multiple_patterns(alltypes): ) result = expr.execute() assert all(not any(element) for element in result) + + +@pytest.mark.parametrize( + "fn", + [lambda n: n + "a", lambda n: n + n, lambda n: "a" + n], + ids=["null-a", "null-null", "a-null"], +) +@pytest.mark.notimpl(["pandas", "dask"], raises=TypeError) +def test_concat_with_null(con, fn): + null = ibis.literal(None, type="string") + expr = fn(null) + result = con.execute(expr) + assert pd.isna(result) + + +@pytest.mark.parametrize( + "args", + [ + param((ibis.literal(None, str), None), id="null-null"), + param((ibis.literal("abc"), None), id="abc-null"), + param((ibis.literal("abc"), ibis.literal(None, str)), id="abc-typed-null"), + param((ibis.literal("abc"), "def", None), id="abc-def-null"), + ], +) +@pytest.mark.parametrize( + "method", + [lambda args: args[0].concat(*args[1:]), lambda args: reduce(add, args)], + ids=["concat", "add"], +) +@pytest.mark.notimpl(["pandas", "dask"], raises=TypeError) +def test_concat(con, args, method): + expr = method(args) + assert pd.isna(con.execute(expr)) diff --git a/ibis/expr/types/strings.py b/ibis/expr/types/strings.py index 99d9959adc1e..10d3d6331e77 100644 --- a/ibis/expr/types/strings.py +++ b/ibis/expr/types/strings.py @@ -1490,6 +1490,8 @@ def split(self, delimiter: str | StringValue) -> ir.ArrayValue: def concat(self, other: str | StringValue, *args: str | StringValue) -> StringValue: """Concatenate strings. + NULLs are propagated. This methods is equivalent to using the `+` operator. + Parameters ---------- other @@ -1506,16 +1508,24 @@ def concat(self, other: str | StringValue, *args: str | StringValue) -> StringVa -------- >>> import ibis >>> ibis.options.interactive = True - >>> t = ibis.memtable({"s": ["abc", "bac", "bca"]}) - >>> t.s.concat("xyz") + >>> t = ibis.memtable({"s": ["abc", None]}) + >>> t.s.concat("xyz", "123") + ┏━━━━━━━━━━━━━━━━┓ + ┃ StringConcat() ┃ + ┡━━━━━━━━━━━━━━━━┩ + │ string │ + ├────────────────┤ + │ abcxyz123 │ + │ NULL │ + └────────────────┘ + >>> t.s + "xyz" ┏━━━━━━━━━━━━━━━━┓ ┃ StringConcat() ┃ ┡━━━━━━━━━━━━━━━━┩ │ string │ ├────────────────┤ │ abcxyz │ - │ bacxyz │ - │ bcaxyz │ + │ NULL │ └────────────────┘ """ return ops.StringConcat((self, other, *args)).to_expr()