Skip to content

Commit

Permalink
feat(api): support deferred in reduction filters
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Oct 10, 2023
1 parent 90b4bf7 commit 349f475
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 30 deletions.
6 changes: 6 additions & 0 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,12 @@ def mean_and_std(v):
id="is_in",
marks=[mark.broken(["datafusion"], raises=AssertionError)],
),
param(
lambda _: ibis._.string_col.isin(["1", "7"]),
lambda t: t.string_col.isin(["1", "7"]),
id="is_in_deferred",
marks=[mark.broken(["datafusion"], raises=AssertionError)],
),
],
)
def test_reduction_ops(
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ def _make_any(
predicates=predicates,
)
else:
op = any_op_class(expr, where=where)
op = any_op_class(expr, where=expr._bind_reduction_filter(where))
return op.to_expr()


Expand Down
53 changes: 39 additions & 14 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
from ibis.expr.deferred import Deferred
import ibis.expr.operations as ops
from ibis.common.grounds import Singleton
from ibis.expr.types.core import Expr, _binop, _FixedTextJupyterMixin
Expand Down Expand Up @@ -1030,7 +1031,9 @@ def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar:
│ b │ [4, 5] │
└────────┴──────────────────────┘
"""
return ops.ArrayCollect(self, where=where).to_expr()
return ops.ArrayCollect(
self, where=self._bind_reduction_filter(where)
).to_expr()

def identical_to(self, other: Value) -> ir.BooleanValue:
"""Return whether this expression is identical to other.
Expand Down Expand Up @@ -1106,7 +1109,9 @@ def group_concat(
>>> t.bill_length_mm.group_concat(sep=": ", where=t.bill_depth_mm > 18)
'39.1: 36.7'
"""
return ops.GroupConcat(self, sep=sep, where=where).to_expr()
return ops.GroupConcat(
self, sep=sep, where=self._bind_reduction_filter(where)
).to_expr()

def __hash__(self) -> int:
return super().__hash__()
Expand Down Expand Up @@ -1286,6 +1291,14 @@ def __pandas_result__(self, df: pd.DataFrame) -> pd.Series:
(column,) = df.columns
return PandasData.convert_column(df.loc[:, column], self.type())

def _bind_reduction_filter(self, where):
import ibis.expr.analysis as an

if where is None or not isinstance(where, Deferred):
return where

return where.resolve(an.find_first_base_table(self.op()).to_expr())

def approx_nunique(
self,
where: ir.BooleanValue | None = None,
Expand Down Expand Up @@ -1323,7 +1336,9 @@ def approx_nunique(
>>> t.body_mass_g.approx_nunique(where=t.species == "Adelie")
55
"""
return ops.ApproxCountDistinct(self, where).to_expr()
return ops.ApproxCountDistinct(
self, where=self._bind_reduction_filter(where)
).to_expr()

def approx_median(
self,
Expand Down Expand Up @@ -1362,7 +1377,9 @@ def approx_median(
>>> t.body_mass_g.approx_median(where=t.species == "Chinstrap")
3700
"""
return ops.ApproxMedian(self, where).to_expr()
return ops.ApproxMedian(
self, where=self._bind_reduction_filter(where)
).to_expr()

def mode(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the mode of a column.
Expand All @@ -1387,7 +1404,7 @@ def mode(self, where: ir.BooleanValue | None = None) -> Scalar:
>>> t.body_mass_g.mode(where=(t.species == "Gentoo") & (t.sex == "male"))
5550
"""
return ops.Mode(self, where).to_expr()
return ops.Mode(self, where=self._bind_reduction_filter(where)).to_expr()

def max(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the maximum of a column.
Expand All @@ -1412,7 +1429,7 @@ def max(self, where: ir.BooleanValue | None = None) -> Scalar:
>>> t.body_mass_g.max(where=t.species == "Chinstrap")
4800
"""
return ops.Max(self, where).to_expr()
return ops.Max(self, where=self._bind_reduction_filter(where)).to_expr()

def min(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the minimum of a column.
Expand All @@ -1437,7 +1454,7 @@ def min(self, where: ir.BooleanValue | None = None) -> Scalar:
>>> t.body_mass_g.min(where=t.species == "Adelie")
2850
"""
return ops.Min(self, where).to_expr()
return ops.Min(self, where=self._bind_reduction_filter(where)).to_expr()

def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the value of `self` that maximizes `key`.
Expand All @@ -1462,7 +1479,9 @@ def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
>>> t.species.argmax(t.body_mass_g, where=t.island == "Dream")
'Chinstrap'
"""
return ops.ArgMax(self, key=key, where=where).to_expr()
return ops.ArgMax(
self, key=key, where=self._bind_reduction_filter(where)
).to_expr()

def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the value of `self` that minimizes `key`.
Expand All @@ -1488,7 +1507,9 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
>>> t.species.argmin(t.body_mass_g, where=t.island == "Biscoe")
'Adelie'
"""
return ops.ArgMin(self, key=key, where=where).to_expr()
return ops.ArgMin(
self, key=key, where=self._bind_reduction_filter(where)
).to_expr()

def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of distinct rows in an expression.
Expand All @@ -1513,7 +1534,9 @@ def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
>>> t.body_mass_g.nunique(where=t.species == "Adelie")
55
"""
return ops.CountDistinct(self, where).to_expr()
return ops.CountDistinct(
self, where=self._bind_reduction_filter(where)
).to_expr()

def topk(
self,
Expand Down Expand Up @@ -1586,7 +1609,9 @@ def arbitrary(
Scalar
An expression
"""
return ops.Arbitrary(self, how=how, where=where).to_expr()
return ops.Arbitrary(
self, how=how, where=self._bind_reduction_filter(where)
).to_expr()

def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of rows in an expression.
Expand All @@ -1601,7 +1626,7 @@ def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
IntegerScalar
Number of elements in an expression
"""
return ops.Count(self, where).to_expr()
return ops.Count(self, where=self._bind_reduction_filter(where)).to_expr()

def value_counts(self) -> ir.Table:
"""Compute a frequency table.
Expand Down Expand Up @@ -1677,7 +1702,7 @@ def first(self, where: ir.BooleanValue | None = None) -> Value:
>>> t.chars.first(where=t.chars != "a")
'b'
"""
return ops.First(self, where=where).to_expr()
return ops.First(self, where=self._bind_reduction_filter(where)).to_expr()

def last(self, where: ir.BooleanValue | None = None) -> Value:
"""Return the last value of a column.
Expand All @@ -1703,7 +1728,7 @@ def last(self, where: ir.BooleanValue | None = None) -> Value:
>>> t.chars.last(where=t.chars != "d")
'c'
"""
return ops.Last(self, where=where).to_expr()
return ops.Last(self, where=self._bind_reduction_filter(where)).to_expr()

def rank(self) -> ir.IntegerColumn:
"""Compute position of first element within each equal-value group in sorted order.
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/types/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def all(self, where: BooleanValue | None = None) -> BooleanScalar:
False
"""
return ops.All(self, where=where).to_expr()
return ops.All(self, where=self._bind_reduction_filter(where)).to_expr()

def notall(self, where: BooleanValue | None = None) -> BooleanScalar:
"""Return whether not all elements are `True`.
Expand Down Expand Up @@ -358,7 +358,7 @@ def notall(self, where: BooleanValue | None = None) -> BooleanScalar:
>>> (t.arr == 2).notall(where=t.arr >= 2)
True
"""
return ops.NotAll(self, where=where).to_expr()
return ops.NotAll(self, where=self._bind_reduction_filter(where)).to_expr()

def cumany(self, *, where=None, group_by=None, order_by=None) -> BooleanColumn:
"""Accumulate the `any` aggregate.
Expand Down
30 changes: 19 additions & 11 deletions ibis/expr/types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def median(self, where: ir.BooleanValue | None = None) -> NumericScalar:
NumericScalar
Median of the column
"""
return ops.Median(self, where=where).to_expr()
return ops.Median(self, where=self._bind_reduction_filter(where)).to_expr()

def quantile(
self,
Expand All @@ -804,7 +804,7 @@ def quantile(
op = ops.MultiQuantile
else:
op = ops.Quantile
return op(self, quantile, where=where).to_expr()
return op(self, quantile, where=self._bind_reduction_filter(where)).to_expr()

def std(
self,
Expand All @@ -825,7 +825,9 @@ def std(
NumericScalar
Standard deviation of `arg`
"""
return ops.StandardDev(self, how=how, where=where).to_expr()
return ops.StandardDev(
self, how=how, where=self._bind_reduction_filter(where)
).to_expr()

def var(
self,
Expand All @@ -846,7 +848,9 @@ def var(
NumericScalar
Standard deviation of `arg`
"""
return ops.Variance(self, how=how, where=where).to_expr()
return ops.Variance(
self, how=how, where=self._bind_reduction_filter(where)
).to_expr()

def corr(
self,
Expand All @@ -870,7 +874,9 @@ def corr(
NumericScalar
The correlation of `left` and `right`
"""
return ops.Correlation(self, right, how=how, where=where).to_expr()
return ops.Correlation(
self, right, how=how, where=self._bind_reduction_filter(where)
).to_expr()

def cov(
self,
Expand All @@ -894,7 +900,9 @@ def cov(
NumericScalar
The covariance of `self` and `right`
"""
return ops.Covariance(self, right, how=how, where=where).to_expr()
return ops.Covariance(
self, right, how=how, where=self._bind_reduction_filter(where)
).to_expr()

def mean(
self,
Expand All @@ -914,7 +922,7 @@ def mean(
"""
# TODO(kszucs): remove the alias from the reduction method in favor
# of default name generated by ops.Value operations
return ops.Mean(self, where=where).to_expr()
return ops.Mean(self, where=self._bind_reduction_filter(where)).to_expr()

def cummean(self, *, where=None, group_by=None, order_by=None) -> NumericColumn:
"""Return the cumulative mean of the input."""
Expand All @@ -938,7 +946,7 @@ def sum(
NumericScalar
The sum of the input expression
"""
return ops.Sum(self, where=where).to_expr()
return ops.Sum(self, where=self._bind_reduction_filter(where)).to_expr()

def cumsum(self, *, where=None, group_by=None, order_by=None) -> NumericColumn:
"""Return the cumulative sum of the input."""
Expand Down Expand Up @@ -1192,15 +1200,15 @@ class IntegerScalar(NumericScalar, IntegerValue):
class IntegerColumn(NumericColumn, IntegerValue):
def bit_and(self, where: ir.BooleanValue | None = None) -> IntegerScalar:
"""Aggregate the column using the bitwise and operator."""
return ops.BitAnd(self, where).to_expr()
return ops.BitAnd(self, where=self._bind_reduction_filter(where)).to_expr()

def bit_or(self, where: ir.BooleanValue | None = None) -> IntegerScalar:
"""Aggregate the column using the bitwise or operator."""
return ops.BitOr(self, where).to_expr()
return ops.BitOr(self, where=self._bind_reduction_filter(where)).to_expr()

def bit_xor(self, where: ir.BooleanValue | None = None) -> IntegerScalar:
"""Aggregate the column using the bitwise exclusive or operator."""
return ops.BitXor(self, where).to_expr()
return ops.BitXor(self, where=self._bind_reduction_filter(where)).to_expr()


@public
Expand Down
12 changes: 10 additions & 2 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ def __pandas_result__(self, df: pd.DataFrame) -> pd.DataFrame:

return PandasData.convert_table(df, self.schema())

def _bind_reduction_filter(self, where):
if where is None or not isinstance(where, Deferred):
return where

return where.resolve(self)

def as_table(self) -> Table:
"""Promote the expression to a table.
Expand Down Expand Up @@ -2313,7 +2319,9 @@ def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
>>> t.nunique(t.a != "foo")
1
"""
return ops.CountDistinctStar(self, where=where).to_expr()
return ops.CountDistinctStar(
self, where=self._bind_reduction_filter(where)
).to_expr()

def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of rows in the table.
Expand Down Expand Up @@ -2350,7 +2358,7 @@ def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
>>> type(t.count())
<class 'ibis.expr.types.numeric.IntegerScalar'>
"""
return ops.CountStar(self, where).to_expr()
return ops.CountStar(self, where=self._bind_reduction_filter(where)).to_expr()

def dropna(
self,
Expand Down

0 comments on commit 349f475

Please sign in to comment.