diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 7ba22b9a1b80..d837b2bf4e42 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -1256,6 +1256,56 @@ def execute_union(op, **kw): return result +@translate.register(ops.Intersection) +def execute_intersection(op, *, ctx, **kw): + left = gen_name("polars_intersect_left") + right = gen_name("polars_intersect_right") + + ctx.register_many( + frames={ + left: translate(op.left, ctx=ctx, **kw), + right: translate(op.right, ctx=ctx, **kw), + } + ) + + sql = ( + sg.select(STAR) + .from_(sg.to_identifier(left, quoted=True)) + .intersect(sg.select(STAR).from_(sg.to_identifier(right, quoted=True))) + ) + + result = ctx.execute(sql.sql(Polars), eager=False) + + if op.distinct is True: + return result.unique() + return result + + +@translate.register(ops.Difference) +def execute_difference(op, *, ctx, **kw): + left = gen_name("polars_diff_left") + right = gen_name("polars_diff_right") + + ctx.register_many( + frames={ + left: translate(op.left, ctx=ctx, **kw), + right: translate(op.right, ctx=ctx, **kw), + } + ) + + sql = ( + sg.select(STAR) + .from_(sg.to_identifier(left, quoted=True)) + .except_(sg.select(STAR).from_(sg.to_identifier(right, quoted=True))) + ) + + result = ctx.execute(sql.sql(Polars), eager=False) + + if op.distinct is True: + return result.unique() + return result + + @translate.register(ops.Hash) def execute_hash(op, **kw): # polars' hash() returns a uint64, but we want to return an int64 diff --git a/ibis/backends/tests/test_set_ops.py b/ibis/backends/tests/test_set_ops.py index 64467b067012..c459a4055346 100644 --- a/ibis/backends/tests/test_set_ops.py +++ b/ibis/backends/tests/test_set_ops.py @@ -6,7 +6,6 @@ from pytest import param import ibis -import ibis.common.exceptions as com import ibis.expr.types as ir from ibis import _ from ibis.backends.tests.errors import PsycoPg2InternalError, PyDruidProgrammingError @@ -84,7 +83,6 @@ def test_union_mixed_distinct(backend, union_subsets): param(True, id="distinct"), ], ) -@pytest.mark.notimpl(["polars"]) @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) def test_intersect(backend, alltypes, df, distinct): a = alltypes.filter((_.id >= 5200) & (_.id <= 5210)) @@ -129,7 +127,6 @@ def test_intersect(backend, alltypes, df, distinct): param(True, id="distinct"), ], ) -@pytest.mark.notimpl(["polars"]) @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) def test_difference(backend, alltypes, df, distinct): a = alltypes.filter((_.id >= 5200) & (_.id <= 5210)) @@ -238,7 +235,6 @@ def test_top_level_union(backend, con, alltypes, distinct, ordered): ), ], ) -@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) def test_top_level_intersect_difference( backend, con, alltypes, distinct, opname, expected, ordered