Skip to content

Commit

Permalink
feat(polars): add Intersection and Difference ops (#10623)
Browse files Browse the repository at this point in the history
  • Loading branch information
IndexSeek authored Dec 30, 2024
1 parent 43069bd commit 69b848a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
50 changes: 50 additions & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,56 @@ def execute_union(op, **kw):
return result


@translate.register(ops.Intersection)
def execute_intersection(op, *, ctx, **kw):
    """Translate an ``ops.Intersection`` node via an ``INTERSECT`` query.

    Both inputs of ``op`` are translated and registered on ``ctx`` under
    freshly generated names; a ``SELECT * ... INTERSECT SELECT * ...``
    statement over those names is then rendered with the ``Polars`` dialect
    and executed through ``ctx`` with ``eager=False``.

    Parameters
    ----------
    op
        The intersection node; its ``left``, ``right`` and ``distinct``
        attributes are read.
    ctx
        Context object used both to register the translated inputs
        (``register_many``) and to run the generated SQL (``execute``).
    **kw
        Extra keyword arguments forwarded unchanged to ``translate``.

    Returns
    -------
    The result of ``ctx.execute``, passed through ``.unique()`` when
    ``op.distinct`` is exactly ``True``.
    """
    lhs_name = gen_name("polars_intersect_left")
    rhs_name = gen_name("polars_intersect_right")

    # Translate the left input before the right one, then expose both to
    # ``ctx`` under the generated names so the SQL below can refer to them.
    lhs_frame = translate(op.left, ctx=ctx, **kw)
    rhs_frame = translate(op.right, ctx=ctx, **kw)
    ctx.register_many(frames={lhs_name: lhs_frame, rhs_name: rhs_frame})

    lhs_query = sg.select(STAR).from_(sg.to_identifier(lhs_name, quoted=True))
    rhs_query = sg.select(STAR).from_(sg.to_identifier(rhs_name, quoted=True))
    query_text = lhs_query.intersect(rhs_query).sql(Polars)

    frame = ctx.execute(query_text, eager=False)

    # Identity comparison on purpose: only a literal ``True`` triggers the
    # extra de-duplication step.
    return frame.unique() if op.distinct is True else frame


@translate.register(ops.Difference)
def execute_difference(op, *, ctx, **kw):
    """Translate an ``ops.Difference`` node via an ``EXCEPT`` query.

    Both inputs of ``op`` are translated and registered on ``ctx`` under
    freshly generated names; a ``SELECT * ... EXCEPT SELECT * ...``
    statement over those names is then rendered with the ``Polars`` dialect
    and executed through ``ctx`` with ``eager=False``.

    Parameters
    ----------
    op
        The difference node; its ``left``, ``right`` and ``distinct``
        attributes are read.
    ctx
        Context object used both to register the translated inputs
        (``register_many``) and to run the generated SQL (``execute``).
    **kw
        Extra keyword arguments forwarded unchanged to ``translate``.

    Returns
    -------
    The result of ``ctx.execute``, passed through ``.unique()`` when
    ``op.distinct`` is exactly ``True``.
    """
    lhs_name = gen_name("polars_diff_left")
    rhs_name = gen_name("polars_diff_right")

    # Translate the left input before the right one, then expose both to
    # ``ctx`` under the generated names so the SQL below can refer to them.
    lhs_frame = translate(op.left, ctx=ctx, **kw)
    rhs_frame = translate(op.right, ctx=ctx, **kw)
    ctx.register_many(frames={lhs_name: lhs_frame, rhs_name: rhs_frame})

    lhs_query = sg.select(STAR).from_(sg.to_identifier(lhs_name, quoted=True))
    rhs_query = sg.select(STAR).from_(sg.to_identifier(rhs_name, quoted=True))
    query_text = lhs_query.except_(rhs_query).sql(Polars)

    frame = ctx.execute(query_text, eager=False)

    # Identity comparison on purpose: only a literal ``True`` triggers the
    # extra de-duplication step.
    return frame.unique() if op.distinct is True else frame


@translate.register(ops.Hash)
def execute_hash(op, **kw):
# polars' hash() returns a uint64, but we want to return an int64
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/tests/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pytest import param

import ibis
import ibis.common.exceptions as com
import ibis.expr.types as ir
from ibis import _
from ibis.backends.tests.errors import PsycoPg2InternalError, PyDruidProgrammingError
Expand Down Expand Up @@ -84,7 +83,6 @@ def test_union_mixed_distinct(backend, union_subsets):
param(True, id="distinct"),
],
)
@pytest.mark.notimpl(["polars"])
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_intersect(backend, alltypes, df, distinct):
a = alltypes.filter((_.id >= 5200) & (_.id <= 5210))
Expand Down Expand Up @@ -129,7 +127,6 @@ def test_intersect(backend, alltypes, df, distinct):
param(True, id="distinct"),
],
)
@pytest.mark.notimpl(["polars"])
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_difference(backend, alltypes, df, distinct):
a = alltypes.filter((_.id >= 5200) & (_.id <= 5210))
Expand Down Expand Up @@ -238,7 +235,6 @@ def test_top_level_union(backend, con, alltypes, distinct, ordered):
),
],
)
@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_top_level_intersect_difference(
backend, con, alltypes, distinct, opname, expected, ordered
Expand Down

0 comments on commit 69b848a

Please sign in to comment.