diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py
index e23673dffac2..d8e5323093b9 100644
--- a/ibis/backends/pyspark/compiler.py
+++ b/ibis/backends/pyspark/compiler.py
@@ -90,6 +90,16 @@ def _aggregate(self, funcname: str, *args, where):
             args = tuple(self.if_(where, arg, NULL) for arg in args)
         return func(*args)
 
+    def visit_InSubquery(self, op, *, rel, needle):
+        if op.needle.dtype.is_struct():
+            # construct the outer struct for pyspark
+            ident = sge.to_identifier(op.rel.schema.names[0], quoted=self.quoted)
+            needle = sge.Struct.from_arg_list(
+                [sge.PropertyEQ(this=ident, expression=needle)]
+            )
+
+        return super().visit_InSubquery(op, rel=rel, needle=needle)
+
     def visit_NonNullLiteral(self, op, *, value, dtype):
         if dtype.is_floating():
             result = super().visit_NonNullLiteral(op, value=value, dtype=dtype)
diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py
index a7c8fc7b5b6c..71d57b7481aa 100644
--- a/ibis/backends/sql/compiler.py
+++ b/ibis/backends/sql/compiler.py
@@ -940,7 +940,10 @@ def visit_ExistsSubquery(self, op, *, rel):
         return self.f.exists(select)
 
     def visit_InSubquery(self, op, *, rel, needle):
-        return needle.isin(query=rel.this)
+        query = rel.this
+        if not isinstance(query, sge.Select):
+            query = sg.select(STAR).from_(query)
+        return needle.isin(query=query)
 
     def visit_Array(self, op, *, exprs):
         return self.f.array(*exprs)
diff --git a/ibis/backends/tests/errors.py b/ibis/backends/tests/errors.py
index 29023d861795..b11cdf66a9c5 100644
--- a/ibis/backends/tests/errors.py
+++ b/ibis/backends/tests/errors.py
@@ -61,12 +61,13 @@
 try:
     from polars import ComputeError as PolarsComputeError
     from polars import PanicException as PolarsPanicException
+    from polars.exceptions import ColumnNotFoundError as PolarsColumnNotFoundError
     from polars.exceptions import InvalidOperationError as PolarsInvalidOperationError
     from polars.exceptions import SchemaError as PolarsSchemaError
 except ImportError:
     PolarsComputeError = PolarsPanicException = PolarsInvalidOperationError = (
         PolarsSchemaError
-    ) = None
+    ) = PolarsColumnNotFoundError = None
 
 try:
     from pyarrow import ArrowInvalid, ArrowNotImplementedError
diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py
index 2ba5d1e95c89..f996e9cb6c2e 100644
--- a/ibis/backends/tests/test_struct.py
+++ b/ibis/backends/tests/test_struct.py
@@ -15,11 +15,12 @@
 from ibis import util
 from ibis.backends.tests.errors import (
     ClickHouseDatabaseError,
+    PolarsColumnNotFoundError,
     PsycoPg2InternalError,
     PsycoPg2SyntaxError,
     Py4JJavaError,
 )
-from ibis.common.exceptions import IbisError
+from ibis.common.exceptions import IbisError, OperationNotDefinedError
 
 pytestmark = [
     pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"),
@@ -234,3 +235,30 @@ def test_keyword_fields(con, nullable):
     finally:
         with contextlib.suppress(NotImplementedError):
             con.drop_table(name, force=True)
+
+
+@pytest.mark.notyet(
+    ["postgres"],
+    raises=PsycoPg2SyntaxError,
+    reason="sqlglot doesn't implement structs for postgres correctly",
+)
+@pytest.mark.notyet(
+    ["risingwave"],
+    raises=PsycoPg2InternalError,
+    reason="sqlglot doesn't implement structs for postgres correctly",
+)
+@pytest.mark.notyet(
+    ["polars"],
+    raises=PolarsColumnNotFoundError,
+    reason="doesn't seem to support IN-style subqueries on structs",
+)
+@pytest.mark.notimpl(["pandas", "dask"], raises=OperationNotDefinedError)
+def test_isin_struct(con):
+    needle1 = ibis.struct({"x": 1, "y": 2})
+    needle2 = ibis.struct({"x": 2, "y": 3})
+    haystack_t = ibis.memtable({"xs": [1, 2, 3], "ys": [2, 3, 4]})
+    haystack = ibis.struct({"x": haystack_t.xs, "y": haystack_t.ys})
+    both = needle1.isin(haystack) | needle2.isin(haystack)
+    result = con.execute(both)
+    # TODO(cpcloud): ensure the type is consistent
+    assert result is True or result is np.bool_(True)
diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py
index ed2136574e39..d2511246fada 100644
--- a/ibis/backends/trino/compiler.py
+++ b/ibis/backends/trino/compiler.py
@@ -282,6 +282,16 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
             raise com.UnsupportedOperationError(f"{unit!r} unit is not supported")
         return self.cast(res, op.dtype)
 
+    def visit_InSubquery(self, op, *, rel, needle):
+        # cast the needle to the same type as the column being queried, since
+        # trino is very strict about structs
+        if op.needle.dtype.is_struct():
+            needle = self.cast(
+                sge.Struct.from_arg_list([needle]), op.rel.schema.as_struct()
+            )
+
+        return super().visit_InSubquery(op, rel=rel, needle=needle)
+
     def visit_StructColumn(self, op, *, names, values):
         return self.cast(sge.Struct(expressions=list(values)), op.dtype)
 