Skip to content

Commit

Permalink
feat(api): add ArrayIntersect operation and corresponding `ArrayVal…
Browse files Browse the repository at this point in the history
…ue.intersect` API
  • Loading branch information
cpcloud committed Aug 5, 2023
1 parent 9b0e6c8 commit 76c95b2
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 1 deletion.
7 changes: 7 additions & 0 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,13 @@ def _array_union(op, **kw):
return translate_val(ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))), **kw)


@translate_val.register(ops.ArrayIntersect)
def _array_intersect(op, **kw):
left = translate_val(op.left, **kw)
right = translate_val(op.right, **kw)
return f"arrayIntersect({left}, {right})"


@translate_val.register(ops.ArrayZip)
def _array_zip(op: ops.ArrayZip, **kw: Any) -> str:
arglist = []
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ def _array_filter(t, op):
)


def _array_intersect(t, op):
return t.translate(
ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x))
)


def _map_keys(t, op):
m = t.translate(op.arg)
return sa.cast(
Expand Down Expand Up @@ -466,6 +472,7 @@ def _try_cast(t, op):
ops.Median: reduction(sa.func.median),
ops.First: reduction(sa.func.first),
ops.Last: reduction(sa.func.last),
ops.ArrayIntersect: _array_intersect,
}
)

Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,5 +722,14 @@ def _array_sort(arg):
),
ops.Levenshtein: fixed_arity(sa.func.levenshtein, 2),
ops.ArraySort: fixed_arity(_array_sort, 1),
ops.ArrayIntersect: fixed_arity(
lambda left, right: sa.func.array(
sa.intersect(
sa.select(sa.func.unnest(left).column_valued()),
sa.select(sa.func.unnest(right).column_valued()),
).scalar_subquery()
),
2,
),
}
)
7 changes: 7 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,6 +2051,13 @@ def compile_array_union(t, op, **kwargs):
return F.array_union(left, right)


@compiles(ops.ArrayIntersect)
def compile_array_intersect(t, op, **kwargs):
left = t.translate(op.left, **kwargs)
right = t.translate(op.right, **kwargs)
return F.array_intersect(left, right)


@compiles(ops.Hash)
def compile_hash_column(t, op, **kwargs):
return F.hash(t.translate(op.arg, **kwargs))
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def _map_get(t, op):
2,
),
ops.ArrayRemove: fixed_arity(sa.func.array_remove, 2),
ops.ArrayIntersect: fixed_arity(sa.func.array_intersection, 2),
ops.StringSplit: fixed_arity(sa.func.split, 2),
# snowflake typeof only accepts VARIANT, so we cast
ops.TypeOf: unary(lambda arg: sa.func.typeof(sa.func.to_variant(arg))),
Expand Down
22 changes: 22 additions & 0 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,28 @@ def test_array_union(con):
assert lhs == rhs, f"row {i:d} differs"


@pytest.mark.notimpl(
["bigquery", "dask", "datafusion", "impala", "mssql", "pandas", "polars"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
["sqlite", "mysql"],
raises=com.IbisTypeError,
reason="argument passes none of the following rules:....",
)
def test_array_intersect(con):
t = ibis.memtable(
{"a": [[3, 2], [], []], "b": [[1, 3], [None], [5]], "c": range(3)}
)
expr = t.select("c", d=t.a.intersect(t.b)).order_by("c").drop("c").d
result = con.execute(expr).map(set, na_action="ignore")
expected = pd.Series([{3}, set(), set()], dtype="object")
assert len(result) == len(expected)

for i, (lhs, rhs) in enumerate(zip(result, expected)):
assert lhs == rhs, f"row {i:d} differs"


@unnest
@pytest.mark.notimpl(
["clickhouse"],
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,12 @@ def _try_cast(t, op):
return try_cast(arg, type_=to)


def _array_intersect(t, op):
return t.translate(
ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x))
)


operation_registry.update(
{
# conditional expressions
Expand Down Expand Up @@ -474,6 +480,7 @@ def _try_cast(t, op):
1,
),
ops.Levenshtein: fixed_arity(sa.func.levenshtein_distance, 2),
ops.ArrayIntersect: _array_intersect,
}
)

Expand Down
9 changes: 9 additions & 0 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ class ArrayUnion(Value):
output_shape = rlz.shape_like("args")


@public
class ArrayIntersect(Value):
left = rlz.array
right = rlz.array

output_dtype = rlz.dtype_like("args")
output_shape = rlz.shape_like("args")


@public
class ArrayZip(Value):
arg = rlz.tuple_of(rlz.array, min_length=2)
Expand Down
43 changes: 42 additions & 1 deletion ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,48 @@ def union(self, other: ir.ArrayValue) -> ir.ArrayValue:
"""
return ops.ArrayUnion(self, other).to_expr()

def zip(self, other: ir.Array, *others: ir.Array) -> ir.Array:
def intersect(self, other: ArrayValue) -> ArrayValue:
"""Intersect two arrays.
Parameters
----------
other
Another array to intersect with `self`
Returns
-------
ArrayValue
Intersected arrays
Examples
--------
>>> import ibis
>>> ibis.options.interactive = True
>>> t = ibis.memtable({"arr1": [[3, 2], [], None], "arr2": [[1, 3], [None], [5]]})
>>> t
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ arr1 ┃ arr2 ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │ array<int64> │
├──────────────────────┼──────────────────────┤
│ [3, 2] │ [1, 3] │
│ [] │ [None] │
│ NULL │ [5] │
└──────────────────────┴──────────────────────┘
>>> t.arr1.intersect(t.arr2)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayIntersect(arr1, arr2) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├────────────────────────────┤
│ [3] │
│ [] │
│ NULL │
└────────────────────────────┘
"""
return ops.ArrayIntersect(self, other).to_expr()

def zip(self, other: ArrayValue, *others: ArrayValue) -> ArrayValue:
"""Zip two or more arrays together.
Parameters
Expand Down

0 comments on commit 76c95b2

Please sign in to comment.