Skip to content

Commit

Permalink
fix(structs): ensure that isin works with struct membership (#8978)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Apr 18, 2024
1 parent 98ef69c commit c0c508e
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 3 deletions.
10 changes: 10 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ def _aggregate(self, funcname: str, *args, where):
args = tuple(self.if_(where, arg, NULL) for arg in args)
return func(*args)

def visit_InSubquery(self, op, *, rel, needle):
    """Compile an ``IN (subquery)`` test, adapting struct needles for pyspark.

    When the needle's dtype is a struct, the compiled needle is wrapped in
    an outer single-field struct whose field is named after the first
    column of ``op.rel``'s schema. Non-struct needles are passed through
    unchanged. The actual ``IN`` expression is built by the parent class.
    """
    needle_is_struct = op.needle.dtype.is_struct()
    if needle_is_struct:
        # pyspark needs the needle nested inside an outer struct keyed by
        # the subquery's (first) column name
        first_column = op.rel.schema.names[0]
        field_name = sge.to_identifier(first_column, quoted=self.quoted)
        named_field = sge.PropertyEQ(this=field_name, expression=needle)
        needle = sge.Struct.from_arg_list([named_field])

    return super().visit_InSubquery(op, rel=rel, needle=needle)

def visit_NonNullLiteral(self, op, *, value, dtype):
if dtype.is_floating():
result = super().visit_NonNullLiteral(op, value=value, dtype=dtype)
Expand Down
5 changes: 4 additions & 1 deletion ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,10 @@ def visit_ExistsSubquery(self, op, *, rel):
return self.f.exists(select)

def visit_InSubquery(self, op, *, rel, needle):
    """Compile ``needle IN (subquery)``.

    ``rel.this`` is used as the subquery. If it is not already a
    ``sge.Select`` it is wrapped in ``SELECT * FROM <rel.this>`` before
    being handed to ``needle.isin``.
    """
    query = rel.this
    if not isinstance(query, sge.Select):
        # NOTE(review): presumably ``rel.this`` can be a bare relation
        # (e.g. a table reference) that is not valid on its own as the
        # right-hand side of IN -- confirm against callers.
        query = sg.select(STAR).from_(query)
    return needle.isin(query=query)

def visit_Array(self, op, *, exprs):
return self.f.array(*exprs)
Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/tests/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@
# polars is an optional dependency: when it cannot be imported, every polars
# exception alias falls back to None so the names can still be imported.
try:
    from polars import ComputeError as PolarsComputeError
    from polars import PanicException as PolarsPanicException
    from polars.exceptions import ColumnNotFoundError as PolarsColumnNotFoundError
    from polars.exceptions import InvalidOperationError as PolarsInvalidOperationError
    from polars.exceptions import SchemaError as PolarsSchemaError
except ImportError:
    PolarsComputeError = PolarsPanicException = PolarsInvalidOperationError = (
        PolarsSchemaError
    ) = PolarsColumnNotFoundError = None

try:
from pyarrow import ArrowInvalid, ArrowNotImplementedError
Expand Down
30 changes: 29 additions & 1 deletion ibis/backends/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from ibis import util
from ibis.backends.tests.errors import (
ClickHouseDatabaseError,
PolarsColumnNotFoundError,
PsycoPg2InternalError,
PsycoPg2SyntaxError,
Py4JJavaError,
)
from ibis.common.exceptions import IbisError
from ibis.common.exceptions import IbisError, OperationNotDefinedError

pytestmark = [
pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"),
Expand Down Expand Up @@ -234,3 +235,30 @@ def test_keyword_fields(con, nullable):
finally:
with contextlib.suppress(NotImplementedError):
con.drop_table(name, force=True)


@pytest.mark.notyet(
    ["postgres"],
    raises=PsycoPg2SyntaxError,
    reason="sqlglot doesn't implement structs for postgres correctly",
)
@pytest.mark.notyet(
    ["risingwave"],
    raises=PsycoPg2InternalError,
    reason="sqlglot doesn't implement structs for postgres correctly",
)
@pytest.mark.notyet(
    ["polars"],
    raises=PolarsColumnNotFoundError,
    reason="doesn't seem to support IN-style subqueries on structs",
)
@pytest.mark.notimpl(["pandas", "dask"], raises=OperationNotDefinedError)
def test_isin_struct(con):
    """Struct literals can be tested for membership in a struct column.

    Two struct literals are checked against a struct column built from an
    in-memory table; the OR of both membership tests must execute to a
    true boolean scalar.
    """
    first_needle = ibis.struct({"x": 1, "y": 2})
    second_needle = ibis.struct({"x": 2, "y": 3})

    source = ibis.memtable({"xs": [1, 2, 3], "ys": [2, 3, 4]})
    haystack = ibis.struct({"x": source.xs, "y": source.ys})

    first_hit = first_needle.isin(haystack)
    second_hit = second_needle.isin(haystack)
    result = con.execute(first_hit | second_hit)

    # TODO(cpcloud): ensure the type is consistent
    accepted = (True, np.bool_(True))
    assert any(result is value for value in accepted)
10 changes: 10 additions & 0 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,16 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
raise com.UnsupportedOperationError(f"{unit!r} unit is not supported")
return self.cast(res, op.dtype)

def visit_InSubquery(self, op, *, rel, needle):
    """Compile an ``IN (subquery)`` test, casting struct needles for trino.

    A struct-typed needle is wrapped in a one-element struct and cast to
    the struct type derived from ``op.rel``'s schema; any other needle is
    forwarded untouched. The parent class builds the ``IN`` expression.
    """
    if not op.needle.dtype.is_struct():
        return super().visit_InSubquery(op, rel=rel, needle=needle)

    # trino is very strict about structs, so the needle is cast to the same
    # type as the column being queried
    wrapped = sge.Struct.from_arg_list([needle])
    target_type = op.rel.schema.as_struct()
    typed_needle = self.cast(wrapped, target_type)
    return super().visit_InSubquery(op, rel=rel, needle=typed_needle)

def visit_StructColumn(self, op, *, names, values):
return self.cast(sge.Struct(expressions=list(values)), op.dtype)

Expand Down

0 comments on commit c0c508e

Please sign in to comment.