Skip to content

Commit

Permalink
feat(api): add ArrayIntersection operation and corresponding API
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 4, 2023
1 parent df14997 commit 2e956a4
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 2 deletions.
7 changes: 7 additions & 0 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ def _array_filter(t, op):
)


def _array_intersection(t, op):
return t.translate(
ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x))
)


def _map_keys(t, op):
m = t.translate(op.arg)
return sa.cast(
Expand Down Expand Up @@ -466,6 +472,7 @@ def _try_cast(t, op):
ops.Median: reduction(sa.func.median),
ops.First: reduction(sa.func.first),
ops.Last: reduction(sa.func.last),
ops.ArrayIntersection: _array_intersection,
}
)

Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,5 +716,14 @@ def _unnest(t, op):
lambda arg: sa.extract("microsecond", arg) % 1_000_000, 1
),
ops.Levenshtein: fixed_arity(sa.func.levenshtein, 2),
ops.ArrayIntersection: fixed_arity(
lambda left, right: sa.func.array(
sa.intersect(
sa.select(sa.func.unnest(left).column_valued()),
sa.select(sa.func.unnest(right).column_valued()),
).scalar_subquery()
),
2,
),
}
)
7 changes: 7 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,6 +2051,13 @@ def compile_array_union(t, op, **kwargs):
return F.array_union(left, right)


@compiles(ops.ArrayIntersection)
def compile_array_intersection(t, op, **kwargs):
left = t.translate(op.left, **kwargs)
right = t.translate(op.right, **kwargs)
return F.array_intersect(left, right)


@compiles(ops.Hash)
def compile_hash_column(t, op, **kwargs):
return F.hash(t.translate(op.arg, **kwargs))
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ def _map_get(t, op):
lambda left, right: sa.func.array_distinct(sa.func.array_cat(left, right)),
2,
),
ops.ArrayIntersection: fixed_arity(sa.func.array_intersection, 2),
ops.StringSplit: fixed_arity(sa.func.split, 2),
# snowflake typeof only accepts VARIANT, so we cast
ops.TypeOf: unary(lambda arg: sa.func.typeof(sa.func.to_variant(arg))),
Expand Down
23 changes: 22 additions & 1 deletion ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,6 @@ def test_array_position(backend, con):
"pandas",
"polars",
"postgres",
"snowflake",
],
raises=com.OperationNotDefinedError,
)
Expand Down Expand Up @@ -708,6 +707,28 @@ def test_array_union(con):
assert lhs == rhs, f"row {i:d} differs"


@pytest.mark.notimpl(
["bigquery", "dask", "datafusion", "impala", "mssql", "pandas", "polars"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
["sqlite", "mysql"],
raises=com.IbisTypeError,
reason="argument passes none of the following rules:....",
)
def test_array_intersection(con):
t = ibis.memtable(
{"a": [[3, 2], [], []], "b": [[1, 3], [None], [5]], "c": range(3)}
)
expr = t.select("c", d=t.a.intersection(t.b)).order_by("c").drop("c").d
result = con.execute(expr).map(set, na_action="ignore")
expected = pd.Series([{3}, set(), set()], dtype="object")
assert len(result) == len(expected)

for i, (lhs, rhs) in enumerate(zip(result, expected)):
assert lhs == rhs, f"row {i:d} differs"


@unnest
@pytest.mark.notimpl(
["clickhouse"],
Expand Down
16 changes: 16 additions & 0 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,21 @@ def _try_cast(t, op):
return try_cast(arg, type_=to)


def _array_intersection(t, op):
return array_filter(
t.translate(op.left),
sa.literal_column("(x)"),
t.translate(
ops.ArrayContains(
op.right,
ops.Argument(
name="x", shape=op.left.output_shape, dtype=op.left.output_dtype
),
)
),
)


operation_registry.update(
{
# conditional expressions
Expand Down Expand Up @@ -474,6 +489,7 @@ def _try_cast(t, op):
1,
),
ops.Levenshtein: fixed_arity(sa.func.levenshtein_distance, 2),
ops.ArrayIntersection: _array_intersection,
}
)

Expand Down
9 changes: 9 additions & 0 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ class ArrayUnion(Value):
output_shape = rlz.shape_like("args")


@public
class ArrayIntersection(Value):
left = rlz.array
right = rlz.array

output_dtype = rlz.dtype_like("args")
output_shape = rlz.shape_like("args")


@public
class ArrayZip(Value):
arg = rlz.tuple_of(rlz.array, min_length=2)
Expand Down
43 changes: 42 additions & 1 deletion ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,48 @@ def union(self, other: ir.ArrayValue) -> ir.ArrayValue:
"""
return ops.ArrayUnion(self, other).to_expr()

def zip(self, other: ir.Array, *others: ir.Array) -> ir.Array:
def intersection(self, other: ArrayValue) -> ArrayValue:
"""Intersect two arrays.
Parameters
----------
other
Another array to intersect with `self`
Returns
-------
ArrayValue
Intersected arrays
Examples
--------
>>> import ibis
>>> ibis.options.interactive = True
>>> t = ibis.memtable({"arr1": [[3, 2], [], None], "arr2": [[1, 3], [None], [5]]})
>>> t
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ arr1 ┃ arr2 ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │ array<int64> │
├──────────────────────┼──────────────────────┤
│ [3, 2] │ [1, 3] │
│ [] │ [None] │
│ NULL │ [5] │
└──────────────────────┴──────────────────────┘
>>> t.arr1.intersection(t.arr2)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayIntersection(arr1, arr2) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├───────────────────────────────┤
│ [3] │
│ [] │
│ NULL │
└───────────────────────────────┘
"""
return ops.ArrayIntersection(self, other).to_expr()

def zip(self, other: ArrayValue, *others: ArrayValue) -> ArrayValue:
"""Zip two or more arrays together.
Parameters
Expand Down

0 comments on commit 2e956a4

Please sign in to comment.