From 1e28de8fe7923c9c0032e5af4958d2bb2bda60eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Sun, 4 Feb 2024 17:12:22 +0100 Subject: [PATCH] fix(ir): support converting limit(1) inputs to scalar subqueries --- docs/posts/ibis-duckdb-geospatial/index.qmd | 2 +- .../ibis-duckdb-geospatial/nyc_data.db.wal | 0 ibis/backends/base/sqlglot/rewrites.py | 8 +++ ibis/backends/pandas/rewrites.py | 12 +++++ ibis/expr/operations/relations.py | 40 ++++++++++++--- ibis/expr/rewrites.py | 30 +++++++++--- ibis/expr/tests/test_newrels.py | 49 +++++++++++++++++-- ibis/expr/types/arrays.py | 1 - ibis/expr/types/relations.py | 22 +++++++-- 9 files changed, 140 insertions(+), 24 deletions(-) create mode 100644 docs/posts/ibis-duckdb-geospatial/nyc_data.db.wal diff --git a/docs/posts/ibis-duckdb-geospatial/index.qmd b/docs/posts/ibis-duckdb-geospatial/index.qmd index 575b25fdbd576..bc3f793677fae 100644 --- a/docs/posts/ibis-duckdb-geospatial/index.qmd +++ b/docs/posts/ibis-duckdb-geospatial/index.qmd @@ -83,7 +83,7 @@ Notice that the last column has a `geometry` type, and in this case it contains each subway station. Let's grab the entry for the Broad St subway station. 
```{python} -broad_station = subway_stations.filter(subway_stations.NAME == "Broad St") +broad_station = subway_stations.filter(subway_stations.NAME == "Broad St").limit(1) broad_station ``` diff --git a/docs/posts/ibis-duckdb-geospatial/nyc_data.db.wal b/docs/posts/ibis-duckdb-geospatial/nyc_data.db.wal new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ibis/backends/base/sqlglot/rewrites.py b/ibis/backends/base/sqlglot/rewrites.py index 05140a76158ca..2492f4bea9f69 100644 --- a/ibis/backends/base/sqlglot/rewrites.py +++ b/ibis/backends/base/sqlglot/rewrites.py @@ -44,6 +44,10 @@ def schema(self): def values(self): return self.parent.values + @attribute + def singlerow(self): + return self.parent.singlerow + @public class Select(ops.Relation): @@ -62,6 +66,10 @@ def values(self): def schema(self): return Schema({k: v.dtype for k, v in self.selections.items()}) + @attribute + def singlerow(self): + return self.parent.singlerow + @public class Window(ops.Value): diff --git a/ibis/backends/pandas/rewrites.py b/ibis/backends/pandas/rewrites.py index 63f93c830f2c4..aa0fb83c95ba7 100644 --- a/ibis/backends/pandas/rewrites.py +++ b/ibis/backends/pandas/rewrites.py @@ -45,6 +45,10 @@ def schema(self): {self.mapping[name]: dtype for name, dtype in self.parent.schema.items()} ) + @attribute + def singlerow(self): + return self.parent.singlerow + @public class PandasResetIndex(PandasRelation): @@ -58,6 +62,10 @@ def values(self): def schema(self): return self.parent.schema + @attribute + def singlerow(self): + return self.parent.singlerow + @public class PandasJoin(PandasRelation): @@ -97,6 +105,10 @@ def values(self): def schema(self): return Schema({k: v.dtype for k, v in self.values.items()}) + @attribute + def singlerow(self): + return not self.groups + @public class PandasLimit(PandasRelation): diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index d520a45276344..e1d1570ecb093 100644 --- 
a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -29,6 +29,10 @@ @public class Relation(Node, Coercible): + """Whether the relation is guaranteed to have a single row.""" + + singlerow = False + @classmethod def __coerce__(cls, value): from ibis.expr.types import TableExpr @@ -127,14 +131,15 @@ def dtype(self): @public class ScalarSubquery(Subquery): - def __init__(self, rel): - from ibis.expr.operations import Reduction + shape = ds.scalar + singlerow = True + def __init__(self, rel): super().__init__(rel=rel) - if not isinstance(self.value, Reduction): + if not rel.singlerow: raise IntegrityError( - f"Subquery {self.value!r} is not a reduction, only " - "reductions can be used as scalar subqueries" + "Scalar subquery must have a single row. Either use a reduction " + "or limit the number of rows in the subquery using `.limit(1)`" ) @@ -182,6 +187,11 @@ def __init__(self, parent, values): def schema(self): return Schema({k: v.dtype for k, v in self.values.items()}) + @attribute + def singlerow(self): + # TODO(kszucs): also check that values doesn't contain Unnest + return self.parent.singlerow + class Simple(Relation): parent: Relation @@ -194,6 +204,10 @@ def values(self): def schema(self): return self.parent.schema + @attribute + def singlerow(self): + return self.parent.singlerow + # TODO(kszucs): remove in favor of View @public @@ -282,10 +296,10 @@ class Filter(Simple): predicates: VarTuple[Value[dt.Boolean]] def __init__(self, parent, predicates): - from ibis.expr.rewrites import ReductionLike + from ibis.expr.rewrites import ScalarLike for pred in predicates: - if pred.find(ReductionLike, filter=Value): + if pred.find(ScalarLike, filter=Value): raise IntegrityError( f"Cannot add {pred!r} to filter, it is a reduction which " "must be converted to a scalar subquery first" @@ -304,6 +318,10 @@ class Limit(Simple): n: typing.Union[int, Scalar[dt.Integer], None] = None offset: typing.Union[int, Scalar[dt.Integer]] = 0 + @attribute + 
def singlerow(self): + return self.n == 1 + @public class Aggregate(Relation): @@ -328,6 +346,10 @@ def values(self): def schema(self): return Schema({k: v.dtype for k, v in self.values.items()}) + @attribute + def singlerow(self): + return not self.groups + @public class Set(Relation): @@ -438,6 +460,10 @@ class DummyTable(Relation): def schema(self): return Schema({k: v.dtype for k, v in self.values.items()}) + @attribute + def singlerow(self): + return all(value.shape.is_scalar() for value in self.values.values()) + @public class FillNa(Simple): diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py index 8b1ea87de95d8..6f9c3172c6edf 100644 --- a/ibis/expr/rewrites.py +++ b/ibis/expr/rewrites.py @@ -10,7 +10,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.common.deferred import Item, _, deferred, var -from ibis.common.exceptions import ExpressionError +from ibis.common.exceptions import ExpressionError, RelationError from ibis.common.patterns import Check, pattern, replace from ibis.util import Namespace @@ -98,8 +98,12 @@ def rewrite_sample(_): @replace(p.Analytic) def project_wrap_analytic(_, rel): - # Wrap analytic functions in a window function - return ops.WindowFunction(_, ops.RowsWindowFrame(rel)) + if _.relations == {rel}: + # Wrap analytic functions in a window function + return ops.WindowFunction(_, ops.RowsWindowFrame(rel)) + else: + # TODO(kszucs): cover this with tests + raise RelationError("Analytic function depends on multiple tables") @replace(p.Reduction) @@ -110,8 +114,9 @@ def project_wrap_reduction(_, rel): # it into a window function of `rel` return ops.WindowFunction(_, ops.RowsWindowFrame(rel)) else: - # 1. The reduction doesn't depend on any table, constructed from - # scalar values, so turn it into a scalar subquery. + # 1. The reduction doesn't depend only on `rel` but is either constructed + # from scalar values or depends on other relations, so turn it into + # a scalar subquery. # 2. 
The reduction is originating from `rel` and other tables, # so this is a correlated scalar subquery. # 3. The reduction is originating entirely from other tables, @@ -119,21 +124,30 @@ def project_wrap_reduction(_, rel): return ops.ScalarSubquery(_.to_expr().as_table()) +@replace(p.Field(p.Relation(singlerow=True))) +def project_wrap_scalar_field(_, rel): + if _.relations == {rel}: + return _ + else: + return ops.ScalarSubquery(_.to_expr().as_table()) + + def rewrite_project_input(value, relation): # we need to detect reductions which are either turned into window functions # or scalar subqueries depending on whether they are originating from the # relation return value.replace( - project_wrap_analytic | project_wrap_reduction, + project_wrap_analytic | project_wrap_reduction | project_wrap_scalar_field, filter=p.Value & ~p.WindowFunction, context={"rel": relation}, ) -ReductionLike = p.Reduction | p.Field(p.Aggregate(groups={})) +ScalarLike = p.Reduction | p.Field(p.Relation(singlerow=True)) +# ReductionLike = p.Reduction | p.Field(p.Aggregate(groups={})) -@replace(ReductionLike) +@replace(ScalarLike) def filter_wrap_reduction(_): # Wrap reductions or fields referencing an aggregation without a group by - # which are scalar fields - in a scalar subquery. 
In the latter case we diff --git a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py index 24c9661d6e6cf..1699d04756fa5 100644 --- a/ibis/expr/tests/test_newrels.py +++ b/ibis/expr/tests/test_newrels.py @@ -147,10 +147,25 @@ def test_subquery_integrity_check(): with pytest.raises(IntegrityError, match=msg): ops.ScalarSubquery(t) - agg = t.agg(t.a.sum() + 1) - msg = "is not a reduction" + col = t.select(t.a) + msg = "Scalar subquery must have a single row" with pytest.raises(IntegrityError, match=msg): - ops.ScalarSubquery(agg) + ops.ScalarSubquery(col) + + agg = t.agg(t.a.sum() + 1) + sub = ops.ScalarSubquery(agg) + assert isinstance(sub, ops.ScalarSubquery) + assert sub.shape.is_scalar() + assert sub.dtype.is_int64() + + +# TODO(kszucs): raise a warning about deprecating the use of `to_array` +def test_value_to_array_creates_subquery(): + expr = t.int_col.sum().as_table().to_array() + op = expr.op() + assert op.shape.is_scalar() + assert op.dtype.is_int64() + assert isinstance(op, ops.ScalarSubquery) def test_select_turns_scalar_reduction_into_subquery(): @@ -162,7 +177,7 @@ def test_select_turns_scalar_reduction_into_subquery(): assert t1.op() == expected -def test_select_scalar_foreign_scalar_reduction_into_subquery(): +def test_select_turns_foreign_field_reduction_into_subquery(): t1 = t.filter(t.bool_col) t2 = t.select(summary=t1.int_col.sum()) subquery = ops.ScalarSubquery(t1.int_col.sum().as_table()) @@ -180,6 +195,32 @@ def test_select_turns_value_with_multiple_parents_into_subquery(): assert t1.op() == expected +def test_select_turns_singlerow_relation_field_into_scalar_subquery(): + v = ibis.table(name="v", schema={"a": "int64", "b": "string"}) + + # other is from the same relation + expr = t.select(t.int_col, other=v.limit(1).a) + expected = Project( + parent=t, + values={ + "int_col": t.int_col, + "other": ops.ScalarSubquery(v.limit(1).a.as_table()), + }, + ) + assert expr.op() == expected + + # other is from a different relation + expr 
= t.select(t.int_col, other=t.limit(1).int_col) + expected = Project( + parent=t, + values={ + "int_col": t.int_col, + "other": ops.ScalarSubquery(t.limit(1).int_col.as_table()), + }, + ) + assert expr.op() == expected + + def test_mutate(): proj = t.select(t, other=t.int_col + 1) expected = Project( diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index 923e01f44dccd..b8ab3e26c29a5 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -6,7 +6,6 @@ import ibis.expr.operations as ops from ibis.expr.types.generic import Column, Scalar, Value, literal -from ibis.expr.types.typing import V from ibis.common.deferred import deferrable if TYPE_CHECKING: diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 7db999d5aeca3..c8a9030665064 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -1137,9 +1137,9 @@ def aggregate( metrics = unwrap_aliases(metrics) having = unwrap_aliases(having) - groups = dereference_values(self.op(), groups) - metrics = dereference_values(self.op(), metrics) - having = dereference_values(self.op(), having) + groups = dereference_values(node, groups) + metrics = dereference_values(node, metrics) + having = dereference_values(node, having) # the user doesn't need to specify the metrics used in the having clause # explicitly, we implicitly add them to the metrics list by looking for @@ -1816,6 +1816,22 @@ def intersect(self, table: Table, *rest: Table, distinct: bool = True) -> Table: node = ops.Intersection(node, table, distinct=distinct) return node.to_expr().select(self.columns) + def to_array(self) -> ir.Column: + """View a single column table as an array. 
+ + Returns + ------- + Value + A single column view of a table + """ + schema = self.schema() + if len(schema) != 1: + raise com.ExpressionError( + "Table must have exactly one column when viewed as array" + ) + + return ops.ScalarSubquery(self).to_expr() + def mutate(self, *exprs: Sequence[ir.Expr] | None, **mutations: ir.Value) -> Table: """Add columns to a table expression.