fix(ir): support converting limit(1) inputs to scalar subqueries
kszucs committed Feb 5, 2024
1 parent e68000c commit 1e28de8
Showing 9 changed files with 140 additions and 24 deletions.
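As a quick orientation before the diff: the change teaches the IR that a `limit(1)` relation has exactly one row, so its columns can be used where a scalar is expected. A minimal sketch of the new behavior (table names and schemas are illustrative, mirroring the new test below):

```python
import ibis

t = ibis.table({"a": "int64", "b": "string"}, name="t")
v = ibis.table({"a": "int64", "b": "string"}, name="v")

# Previously only reductions (e.g. v.a.max()) could be turned into scalar
# subqueries; referencing a limit(1) column here raised an error instead.
expr = t.select(t.a, other=v.limit(1).a)

# `other` is rewritten into a ScalarSubquery over v.limit(1), as exercised by
# test_select_turns_singlerow_relation_field_into_scalar_subquery below.
```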
2 changes: 1 addition & 1 deletion docs/posts/ibis-duckdb-geospatial/index.qmd
@@ -83,7 +83,7 @@ Notice that the last column has a `geometry` type, and in this case it contains
each subway station. Let's grab the entry for the Broad St subway station.

```{python}
-broad_station = subway_stations.filter(subway_stations.NAME == "Broad St")
+broad_station = subway_stations.filter(subway_stations.NAME == "Broad St").limit(1)
broad_station
```
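The `.limit(1)` matters because the filter alone does not guarantee a single row; later cells in the post use `broad_station`'s geometry as a scalar value, which the stricter `ScalarSubquery` check below now requires to come from a provably single-row relation. A hedged sketch of that kind of downstream usage (the `GEOM` column name, the `d_within` call, and the distance value are assumptions, not taken from this diff):

```{python}
# Assumed follow-on usage: compare every station against the single
# Broad St row, treating its geometry column as a scalar subquery.
nearby = subway_stations.filter(
    subway_stations.GEOM.d_within(broad_station.GEOM, 200)  # distance is assumed
)
nearby
```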

Empty file.
8 changes: 8 additions & 0 deletions ibis/backends/base/sqlglot/rewrites.py
@@ -44,6 +44,10 @@ def schema(self):
def values(self):
return self.parent.values

@attribute
def singlerow(self):
return self.parent.singlerow


@public
class Select(ops.Relation):
@@ -62,6 +66,10 @@ def values(self):
def schema(self):
return Schema({k: v.dtype for k, v in self.selections.items()})

@attribute
def singlerow(self):
return self.parent.singlerow


@public
class Window(ops.Value):
12 changes: 12 additions & 0 deletions ibis/backends/pandas/rewrites.py
@@ -45,6 +45,10 @@ def schema(self):
{self.mapping[name]: dtype for name, dtype in self.parent.schema.items()}
)

@attribute
def singlerow(self):
return self.parent.singlerow


@public
class PandasResetIndex(PandasRelation):
@@ -58,6 +62,10 @@ def values(self):
def schema(self):
return self.parent.schema

@attribute
def singlerow(self):
return self.parent.singlerow


@public
class PandasJoin(PandasRelation):
@@ -97,6 +105,10 @@ def values(self):
def schema(self):
return Schema({k: v.dtype for k, v in self.values.items()})

@attribute
def singlerow(self):
return not self.groups


@public
class PandasLimit(PandasRelation):
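The `not self.groups` rule for `PandasAggregate` above encodes the usual SQL fact: an aggregation without group keys always produces exactly one row, while a grouped aggregation produces one row per group. A small illustrative sketch (schema made up for the example):

```python
import ibis

t = ibis.table({"a": "int64", "g": "string"}, name="t")

t.aggregate(total=t.a.sum())                # no group keys -> exactly one row
t.group_by("g").aggregate(total=t.a.sum())  # one row per group -> not single-row
```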
40 changes: 33 additions & 7 deletions ibis/expr/operations/relations.py
@@ -29,6 +29,10 @@

@public
class Relation(Node, Coercible):
"""Whether the relation is guaranteed to have a single row."""

singlerow = False

@classmethod
def __coerce__(cls, value):
from ibis.expr.types import TableExpr
@@ -127,14 +131,15 @@ def dtype(self):

@public
class ScalarSubquery(Subquery):
-def __init__(self, rel):
-from ibis.expr.operations import Reduction
+shape = ds.scalar
+singlerow = True

+def __init__(self, rel):
super().__init__(rel=rel)
-if not isinstance(self.value, Reduction):
+if not rel.singlerow:
raise IntegrityError(
-f"Subquery {self.value!r} is not a reduction, only "
-"reductions can be used as scalar subqueries"
+"Scalar subquery must have a single row. Either use a reduction "
+"or limit the number of rows in the subquery using `.limit(1)`"
)


@@ -182,6 +187,11 @@ def __init__(self, parent, values):
def schema(self):
return Schema({k: v.dtype for k, v in self.values.items()})

@attribute
def singlerow(self):
# TODO(kszucs): also check that values doesn't contain Unnest
return self.parent.singlerow


class Simple(Relation):
parent: Relation
@@ -194,6 +204,10 @@ def values(self):
def schema(self):
return self.parent.schema

@attribute
def singlerow(self):
return self.parent.singlerow


# TODO(kszucs): remove in favor of View
@public
@@ -282,10 +296,10 @@ class Filter(Simple):
predicates: VarTuple[Value[dt.Boolean]]

def __init__(self, parent, predicates):
-from ibis.expr.rewrites import ReductionLike
+from ibis.expr.rewrites import ScalarLike

for pred in predicates:
-if pred.find(ReductionLike, filter=Value):
+if pred.find(ScalarLike, filter=Value):
raise IntegrityError(
f"Cannot add {pred!r} to filter, it is a reduction which "
"must be converted to a scalar subquery first"
@@ -304,6 +318,10 @@ class Limit(Simple):
n: typing.Union[int, Scalar[dt.Integer], None] = None
offset: typing.Union[int, Scalar[dt.Integer]] = 0

@attribute
def singlerow(self):
return self.n == 1


@public
class Aggregate(Relation):
@@ -328,6 +346,10 @@ def values(self):
def schema(self):
return Schema({k: v.dtype for k, v in self.values.items()})

@attribute
def singlerow(self):
return not self.groups


@public
class Set(Relation):
@@ -438,6 +460,10 @@ class DummyTable(Relation):
def schema(self):
return Schema({k: v.dtype for k, v in self.values.items()})

@attribute
def singlerow(self):
return all(value.shape.is_scalar() for value in self.values.values())


@public
class FillNa(Simple):
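Taken together, the relations.py changes mean `ops.ScalarSubquery` now accepts any single-column relation whose `singlerow` property holds (an ungrouped `Aggregate`, a `Limit` with `n == 1`, a `DummyTable` of scalars, or anything that forwards its parent's flag) and rejects the rest with the new error message. A sketch mirroring the updated integrity-check test:

```python
import ibis
import ibis.expr.operations as ops

t = ibis.table({"a": "int64"}, name="t")

ops.ScalarSubquery(t.a.sum().as_table())  # ungrouped aggregate: single row, accepted
ops.ScalarSubquery(t.limit(1))            # Limit with n == 1: single row, accepted

# A plain projection gives no single-row guarantee, so this raises
# IntegrityError: "Scalar subquery must have a single row. ..."
ops.ScalarSubquery(t.select(t.a))
```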
30 changes: 22 additions & 8 deletions ibis/expr/rewrites.py
@@ -10,7 +10,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.deferred import Item, _, deferred, var
-from ibis.common.exceptions import ExpressionError
+from ibis.common.exceptions import ExpressionError, RelationError
from ibis.common.patterns import Check, pattern, replace
from ibis.util import Namespace

@@ -98,8 +98,12 @@ def rewrite_sample(_):

@replace(p.Analytic)
def project_wrap_analytic(_, rel):
-# Wrap analytic functions in a window function
-return ops.WindowFunction(_, ops.RowsWindowFrame(rel))
+if _.relations == {rel}:
+# Wrap analytic functions in a window function
+return ops.WindowFunction(_, ops.RowsWindowFrame(rel))
+else:
+# TODO(kszucs): cover this with tests
+raise RelationError("Analytic function depends on multiple tables")


@replace(p.Reduction)
@@ -110,30 +114,40 @@ def project_wrap_reduction(_, rel):
# it into a window function of `rel`
return ops.WindowFunction(_, ops.RowsWindowFrame(rel))
else:
-# 1. The reduction doesn't depend on any table, constructed from
-# scalar values, so turn it into a scalar subquery.
+# 1. The reduction doesn't depend only on `rel` but is either constructed
+# from scalar values or depends on other relations, so turn it into
+# a scalar subquery.
# 2. The reduction is originating from `rel` and other tables,
# so this is a correlated scalar subquery.
# 3. The reduction is originating entirely from other tables,
# so this is an uncorrelated scalar subquery.
return ops.ScalarSubquery(_.to_expr().as_table())


@replace(p.Field(p.Relation(singlerow=True)))
def project_wrap_scalar_field(_, rel):
if _.relations == {rel}:
return _
else:
return ops.ScalarSubquery(_.to_expr().as_table())


def rewrite_project_input(value, relation):
# we need to detect reductions which are either turned into window functions
# or scalar subqueries depending on whether they are originating from the
# relation
return value.replace(
-project_wrap_analytic | project_wrap_reduction,
+project_wrap_analytic | project_wrap_reduction | project_wrap_scalar_field,
filter=p.Value & ~p.WindowFunction,
context={"rel": relation},
)


-ReductionLike = p.Reduction | p.Field(p.Aggregate(groups={}))
+ScalarLike = p.Reduction | p.Field(p.Relation(singlerow=True))
+# ReductionLike = p.Reduction | p.Field(p.Aggregate(groups={}))


-@replace(ReductionLike)
+@replace(ScalarLike)
def filter_wrap_reduction(_):
# Wrap reductions or fields referencing an aggregation without a group by -
# which are scalar fields - in a scalar subquery. In the latter case we
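In short, the projection rewrites above pick between two shapes for a non-column value: a reduction built only from the projected relation becomes a window function, while a reduction over another relation, or (new here) a field of any single-row relation matched by `project_wrap_scalar_field` / `ScalarLike`, becomes a scalar subquery. A sketch of the three paths, following the tests in this commit:

```python
import ibis

t = ibis.table({"int_col": "int64", "bool_col": "boolean"}, name="t")

# 1. Reduction over the projected relation itself -> window function,
#    so every row sees the overall sum.
t.select(frac=t.int_col / t.int_col.sum())

# 2. Reduction over a different (here: filtered) relation -> scalar subquery.
t1 = t.filter(t.bool_col)
t.select(summary=t1.int_col.sum())

# 3. Field of a single-row relation -> scalar subquery via ScalarLike.
t.select(t.int_col, other=t.limit(1).int_col)
```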
49 changes: 45 additions & 4 deletions ibis/expr/tests/test_newrels.py
@@ -147,10 +147,25 @@ def test_subquery_integrity_check():
with pytest.raises(IntegrityError, match=msg):
ops.ScalarSubquery(t)

-agg = t.agg(t.a.sum() + 1)
-msg = "is not a reduction"
+col = t.select(t.a)
+msg = "Scalar subquery must have a single row"
with pytest.raises(IntegrityError, match=msg):
-ops.ScalarSubquery(agg)
+ops.ScalarSubquery(col)

agg = t.agg(t.a.sum() + 1)
sub = ops.ScalarSubquery(agg)
assert isinstance(sub, ops.ScalarSubquery)
assert sub.shape.is_scalar()
assert sub.dtype.is_int64()


# TODO(kszucs): raise a warning about deprecating the use of `to_array`
def test_value_to_array_creates_subquery():
expr = t.int_col.sum().as_table().to_array()
op = expr.op()
assert op.shape.is_scalar()
assert op.dtype.is_int64()
assert isinstance(op, ops.ScalarSubquery)


def test_select_turns_scalar_reduction_into_subquery():
@@ -162,7 +177,7 @@ def test_select_turns_scalar_reduction_into_subquery():
assert t1.op() == expected


-def test_select_scalar_foreign_scalar_reduction_into_subquery():
+def test_select_turns_foreign_field_reduction_into_subquery():
t1 = t.filter(t.bool_col)
t2 = t.select(summary=t1.int_col.sum())
subquery = ops.ScalarSubquery(t1.int_col.sum().as_table())
@@ -180,6 +195,32 @@ def test_select_turns_value_with_multiple_parents_into_subquery():
assert t1.op() == expected


def test_select_turns_singlerow_relation_field_into_scalar_subquery():
v = ibis.table(name="v", schema={"a": "int64", "b": "string"})

# other is from a different relation
expr = t.select(t.int_col, other=v.limit(1).a)
expected = Project(
parent=t,
values={
"int_col": t.int_col,
"other": ops.ScalarSubquery(v.limit(1).a.as_table()),
},
)
assert expr.op() == expected

# other is from a relation derived from the same table
expr = t.select(t.int_col, other=t.limit(1).int_col)
expected = Project(
parent=t,
values={
"int_col": t.int_col,
"other": ops.ScalarSubquery(t.limit(1).int_col.as_table()),
},
)
assert expr.op() == expected


def test_mutate():
proj = t.select(t, other=t.int_col + 1)
expected = Project(
1 change: 0 additions & 1 deletion ibis/expr/types/arrays.py
@@ -6,7 +6,6 @@

import ibis.expr.operations as ops
from ibis.expr.types.generic import Column, Scalar, Value, literal
-from ibis.expr.types.typing import V
from ibis.common.deferred import deferrable

if TYPE_CHECKING:
22 changes: 19 additions & 3 deletions ibis/expr/types/relations.py
@@ -1137,9 +1137,9 @@ def aggregate(
metrics = unwrap_aliases(metrics)
having = unwrap_aliases(having)

-groups = dereference_values(self.op(), groups)
-metrics = dereference_values(self.op(), metrics)
-having = dereference_values(self.op(), having)
+groups = dereference_values(node, groups)
+metrics = dereference_values(node, metrics)
+having = dereference_values(node, having)

# the user doesn't need to specify the metrics used in the having clause
# explicitly, we implicitly add them to the metrics list by looking for
@@ -1816,6 +1816,22 @@ def intersect(self, table: Table, *rest: Table, distinct: bool = True) -> Table:
node = ops.Intersection(node, table, distinct=distinct)
return node.to_expr().select(self.columns)

def to_array(self) -> ir.Column:
"""View a single column table as an array.
Returns
-------
Value
A single column view of a table
"""
schema = self.schema()
if len(schema) != 1:
raise com.ExpressionError(
"Table must have exactly one column when viewed as array"
)

return ops.ScalarSubquery(self).to_expr()

def mutate(self, *exprs: Sequence[ir.Expr] | None, **mutations: ir.Value) -> Table:
"""Add columns to a table expression.
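For completeness, the `to_array` method added here (the TODO in the tests suggests it is kept mainly for compatibility) is thin sugar over the same machinery: a one-column table becomes a `ScalarSubquery` expression, anything wider is rejected, as exercised by `test_value_to_array_creates_subquery`. A short sketch:

```python
import ibis
import ibis.expr.operations as ops

t = ibis.table({"int_col": "int64"}, name="t")

expr = t.int_col.sum().as_table().to_array()
assert isinstance(expr.op(), ops.ScalarSubquery)

# A table with more than one column is rejected with an ExpressionError:
# ibis.table({"a": "int64", "b": "int64"}, name="u").to_array()
```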
