From 76c95b22b083d6f0f43d5e0119dc164d57557aef Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 5 Aug 2023 04:49:09 -0400 Subject: [PATCH] feat(api): add `ArrayIntersect` operation and corresponding `ArrayValue.intersect` API --- ibis/backends/clickhouse/compiler/values.py | 7 ++++ ibis/backends/duckdb/registry.py | 7 ++++ ibis/backends/postgres/registry.py | 9 +++++ ibis/backends/pyspark/compiler.py | 7 ++++ ibis/backends/snowflake/registry.py | 1 + ibis/backends/tests/test_array.py | 22 +++++++++++ ibis/backends/trino/registry.py | 7 ++++ ibis/expr/operations/arrays.py | 9 +++++ ibis/expr/types/arrays.py | 43 ++++++++++++++++++++- 9 files changed, 111 insertions(+), 1 deletion(-) diff --git a/ibis/backends/clickhouse/compiler/values.py b/ibis/backends/clickhouse/compiler/values.py index 44933eca4a32..39cc37f39f36 100644 --- a/ibis/backends/clickhouse/compiler/values.py +++ b/ibis/backends/clickhouse/compiler/values.py @@ -1414,6 +1414,13 @@ def _array_union(op, **kw): return translate_val(ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))), **kw) +@translate_val.register(ops.ArrayIntersect) +def _array_intersect(op, **kw): + left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + return f"arrayIntersect({left}, {right})" + + @translate_val.register(ops.ArrayZip) def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: arglist = [] diff --git a/ibis/backends/duckdb/registry.py b/ibis/backends/duckdb/registry.py index ea756bd46544..6b7c123015b1 100644 --- a/ibis/backends/duckdb/registry.py +++ b/ibis/backends/duckdb/registry.py @@ -270,6 +270,12 @@ def _array_filter(t, op): ) +def _array_intersect(t, op): + return t.translate( + ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x)) + ) + + def _map_keys(t, op): m = t.translate(op.arg) return sa.cast( @@ -466,6 +472,7 @@ def _try_cast(t, op): ops.Median: reduction(sa.func.median), ops.First: reduction(sa.func.first), ops.Last: reduction(sa.func.last), + ops.ArrayIntersect: _array_intersect, } ) diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index fd8f542459ce..009564c29224 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -722,5 +722,14 @@ def _array_sort(arg): ), ops.Levenshtein: fixed_arity(sa.func.levenshtein, 2), ops.ArraySort: fixed_arity(_array_sort, 1), + ops.ArrayIntersect: fixed_arity( + lambda left, right: sa.func.array( + sa.intersect( + sa.select(sa.func.unnest(left).column_valued()), + sa.select(sa.func.unnest(right).column_valued()), + ).scalar_subquery() + ), + 2, + ), } ) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 282291efc760..ad29f26219c3 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -2051,6 +2051,13 @@ def compile_array_union(t, op, **kwargs): return F.array_union(left, right) +@compiles(ops.ArrayIntersect) +def compile_array_intersect(t, op, **kwargs): + left = t.translate(op.left, **kwargs) + right = t.translate(op.right, **kwargs) + return F.array_intersect(left, right) + + @compiles(ops.Hash) def compile_hash_column(t, op, **kwargs): return F.hash(t.translate(op.arg, **kwargs)) diff --git a/ibis/backends/snowflake/registry.py b/ibis/backends/snowflake/registry.py index 8ca1d436ed24..fb2d6384a5be 100644 --- a/ibis/backends/snowflake/registry.py +++ b/ibis/backends/snowflake/registry.py @@ -400,6 +400,7 @@ def _map_get(t, op): 2, ), ops.ArrayRemove: fixed_arity(sa.func.array_remove, 2), + ops.ArrayIntersect: fixed_arity(sa.func.array_intersection, 2), ops.StringSplit: fixed_arity(sa.func.split, 2), # snowflake typeof only accepts VARIANT, so we cast ops.TypeOf: unary(lambda arg: sa.func.typeof(sa.func.to_variant(arg))), diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 42c6f00ae04c..4099a5d5be7b 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -697,6 +697,28 @@ def test_array_union(con): assert lhs == rhs, f"row {i:d} differs" +@pytest.mark.notimpl( + ["bigquery", "dask", "datafusion", "impala", "mssql", "pandas", "polars"], + raises=com.OperationNotDefinedError, +) +@pytest.mark.notimpl( + ["sqlite", "mysql"], + raises=com.IbisTypeError, + reason="argument passes none of the following rules:....", +) +def test_array_intersect(con): + t = ibis.memtable( + {"a": [[3, 2], [], []], "b": [[1, 3], [None], [5]], "c": range(3)} + ) + expr = t.select("c", d=t.a.intersect(t.b)).order_by("c").drop("c").d + result = con.execute(expr).map(set, na_action="ignore") + expected = pd.Series([{3}, set(), set()], dtype="object") + assert len(result) == len(expected) + + for i, (lhs, rhs) in enumerate(zip(result, expected)): + assert lhs == rhs, f"row {i:d} differs" + + @unnest @pytest.mark.notimpl( ["clickhouse"], diff --git a/ibis/backends/trino/registry.py b/ibis/backends/trino/registry.py index ce14dae75fa8..9bb3301c23b3 100644 --- a/ibis/backends/trino/registry.py +++ b/ibis/backends/trino/registry.py @@ -289,6 +289,12 @@ def _try_cast(t, op): return try_cast(arg, type_=to) +def _array_intersect(t, op): + return t.translate( + ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x)) + ) + + operation_registry.update( { # conditional expressions @@ -474,6 +480,7 @@ def _try_cast(t, op): 1, ), ops.Levenshtein: fixed_arity(sa.func.levenshtein_distance, 2), + ops.ArrayIntersect: _array_intersect, } ) diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index ea7e30b7a801..628fb32a7cd2 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -173,6 +173,15 @@ class ArrayUnion(Value): output_shape = rlz.shape_like("args") +@public +class ArrayIntersect(Value): + left = rlz.array + right = rlz.array + + output_dtype = rlz.dtype_like("args") + output_shape = rlz.shape_like("args") + + @public class ArrayZip(Value): arg = rlz.tuple_of(rlz.array, min_length=2) diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index f09690fa753e..225f31fd41b3 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -725,7 +725,48 @@ def union(self, other: ir.ArrayValue) -> ir.ArrayValue: """ return ops.ArrayUnion(self, other).to_expr() - def zip(self, other: ir.Array, *others: ir.Array) -> ir.Array: + def intersect(self, other: ArrayValue) -> ArrayValue: + """Intersect two arrays. + + Parameters + ---------- + other + Another array to intersect with `self` + + Returns + ------- + ArrayValue + Intersected arrays + + Examples + -------- + >>> import ibis + >>> ibis.options.interactive = True + >>> t = ibis.memtable({"arr1": [[3, 2], [], None], "arr2": [[1, 3], [None], [5]]}) + >>> t + ┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ arr1 ┃ arr2 ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ array │ + ├──────────────────────┼──────────────────────┤ + │ [3, 2] │ [1, 3] │ + │ [] │ [None] │ + │ NULL │ [5] │ + └──────────────────────┴──────────────────────┘ + >>> t.arr1.intersect(t.arr2) + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ ArrayIntersect(arr1, arr2) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├────────────────────────────┤ + │ [3] │ + │ [] │ + │ NULL │ + └────────────────────────────┘ + """ + return ops.ArrayIntersect(self, other).to_expr() + + def zip(self, other: ArrayValue, *others: ArrayValue) -> ArrayValue: """Zip two or more arrays together. Parameters