diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 626e9dc5511f..fe1cd0eb9233 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -756,6 +756,12 @@ def mean_and_std(v): id="is_in", marks=[mark.broken(["datafusion"], raises=AssertionError)], ), + param( + lambda _: ibis._.string_col.isin(["1", "7"]), + lambda t: t.string_col.isin(["1", "7"]), + id="is_in_deferred", + marks=[mark.broken(["datafusion"], raises=AssertionError)], + ), ], ) def test_reduction_ops( diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 2a97a99c5f4c..3430a5b689f9 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -704,7 +704,7 @@ def _make_any( predicates=predicates, ) else: - op = any_op_class(expr, where=where) + op = any_op_class(expr, where=expr._bind_reduction_filter(where)) return op.to_expr() diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index f0f5d422cdaf..1c6e0e1407c5 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -7,6 +7,7 @@ import ibis import ibis.common.exceptions as com import ibis.expr.datatypes as dt +from ibis.expr.deferred import Deferred import ibis.expr.operations as ops from ibis.common.grounds import Singleton from ibis.expr.types.core import Expr, _binop, _FixedTextJupyterMixin @@ -1030,7 +1031,9 @@ def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar: │ b │ [4, 5] │ └────────┴──────────────────────┘ """ - return ops.ArrayCollect(self, where=where).to_expr() + return ops.ArrayCollect( + self, where=self._bind_reduction_filter(where) + ).to_expr() def identical_to(self, other: Value) -> ir.BooleanValue: """Return whether this expression is identical to other. @@ -1106,7 +1109,9 @@ def group_concat( >>> t.bill_length_mm.group_concat(sep=": ", where=t.bill_depth_mm > 18) '39.1: 36.7' """ - return ops.GroupConcat(self, sep=sep, where=where).to_expr() + return ops.GroupConcat( + self, sep=sep, where=self._bind_reduction_filter(where) + ).to_expr() def __hash__(self) -> int: return super().__hash__() @@ -1286,6 +1291,14 @@ def __pandas_result__(self, df: pd.DataFrame) -> pd.Series: (column,) = df.columns return PandasData.convert_column(df.loc[:, column], self.type()) + def _bind_reduction_filter(self, where): + import ibis.expr.analysis as an + + if where is None or not isinstance(where, Deferred): + return where + + return where.resolve(an.find_first_base_table(self.op()).to_expr()) + def approx_nunique( self, where: ir.BooleanValue | None = None, @@ -1323,7 +1336,9 @@ def approx_nunique( >>> t.body_mass_g.approx_nunique(where=t.species == "Adelie") 55 """ - return ops.ApproxCountDistinct(self, where).to_expr() + return ops.ApproxCountDistinct( + self, where=self._bind_reduction_filter(where) + ).to_expr() def approx_median( self, @@ -1362,7 +1377,9 @@ def approx_median( >>> t.body_mass_g.approx_median(where=t.species == "Chinstrap") 3700 """ - return ops.ApproxMedian(self, where).to_expr() + return ops.ApproxMedian( + self, where=self._bind_reduction_filter(where) + ).to_expr() def mode(self, where: ir.BooleanValue | None = None) -> Scalar: """Return the mode of a column. @@ -1387,7 +1404,7 @@ def mode(self, where: ir.BooleanValue | None = None) -> Scalar: >>> t.body_mass_g.mode(where=(t.species == "Gentoo") & (t.sex == "male")) 5550 """ - return ops.Mode(self, where).to_expr() + return ops.Mode(self, where=self._bind_reduction_filter(where)).to_expr() def max(self, where: ir.BooleanValue | None = None) -> Scalar: """Return the maximum of a column. @@ -1412,7 +1429,7 @@ def max(self, where: ir.BooleanValue | None = None) -> Scalar: >>> t.body_mass_g.max(where=t.species == "Chinstrap") 4800 """ - return ops.Max(self, where).to_expr() + return ops.Max(self, where=self._bind_reduction_filter(where)).to_expr() def min(self, where: ir.BooleanValue | None = None) -> Scalar: """Return the minimum of a column. @@ -1437,7 +1454,7 @@ def min(self, where: ir.BooleanValue | None = None) -> Scalar: >>> t.body_mass_g.min(where=t.species == "Adelie") 2850 """ - return ops.Min(self, where).to_expr() + return ops.Min(self, where=self._bind_reduction_filter(where)).to_expr() def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: """Return the value of `self` that maximizes `key`. @@ -1462,7 +1479,9 @@ def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: >>> t.species.argmax(t.body_mass_g, where=t.island == "Dream") 'Chinstrap' """ - return ops.ArgMax(self, key=key, where=where).to_expr() + return ops.ArgMax( + self, key=key, where=self._bind_reduction_filter(where) + ).to_expr() def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: """Return the value of `self` that minimizes `key`. @@ -1488,7 +1507,9 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: >>> t.species.argmin(t.body_mass_g, where=t.island == "Biscoe") 'Adelie' """ - return ops.ArgMin(self, key=key, where=where).to_expr() + return ops.ArgMin( + self, key=key, where=self._bind_reduction_filter(where) + ).to_expr() def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: """Compute the number of distinct rows in an expression. @@ -1513,7 +1534,9 @@ def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: >>> t.body_mass_g.nunique(where=t.species == "Adelie") 55 """ - return ops.CountDistinct(self, where).to_expr() + return ops.CountDistinct( + self, where=self._bind_reduction_filter(where) + ).to_expr() def topk( self, @@ -1586,7 +1609,9 @@ def arbitrary( Scalar An expression """ - return ops.Arbitrary(self, how=how, where=where).to_expr() + return ops.Arbitrary( + self, how=how, where=self._bind_reduction_filter(where) + ).to_expr() def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: """Compute the number of rows in an expression. @@ -1601,7 +1626,7 @@ def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: IntegerScalar Number of elements in an expression """ - return ops.Count(self, where).to_expr() + return ops.Count(self, where=self._bind_reduction_filter(where)).to_expr() def value_counts(self) -> ir.Table: """Compute a frequency table. @@ -1677,7 +1702,7 @@ def first(self, where: ir.BooleanValue | None = None) -> Value: >>> t.chars.first(where=t.chars != "a") 'b' """ - return ops.First(self, where=where).to_expr() + return ops.First(self, where=self._bind_reduction_filter(where)).to_expr() def last(self, where: ir.BooleanValue | None = None) -> Value: """Return the last value of a column. @@ -1703,7 +1728,7 @@ def last(self, where: ir.BooleanValue | None = None) -> Value: >>> t.chars.last(where=t.chars != "d") 'c' """ - return ops.Last(self, where=where).to_expr() + return ops.Last(self, where=self._bind_reduction_filter(where)).to_expr() def rank(self) -> ir.IntegerColumn: """Compute position of first element within each equal-value group in sorted order. diff --git a/ibis/expr/types/logical.py b/ibis/expr/types/logical.py index b559746d51c8..4f79ca210d36 100644 --- a/ibis/expr/types/logical.py +++ b/ibis/expr/types/logical.py @@ -329,7 +329,7 @@ def all(self, where: BooleanValue | None = None) -> BooleanScalar: False """ - return ops.All(self, where=where).to_expr() + return ops.All(self, where=self._bind_reduction_filter(where)).to_expr() def notall(self, where: BooleanValue | None = None) -> BooleanScalar: """Return whether not all elements are `True`. @@ -358,7 +358,7 @@ def notall(self, where: BooleanValue | None = None) -> BooleanScalar: >>> (t.arr == 2).notall(where=t.arr >= 2) True """ - return ops.NotAll(self, where=where).to_expr() + return ops.NotAll(self, where=self._bind_reduction_filter(where)).to_expr() def cumany(self, *, where=None, group_by=None, order_by=None) -> BooleanColumn: """Accumulate the `any` aggregate. diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index b6ad3e4de129..0fbb012ac5c5 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -779,7 +779,7 @@ def median(self, where: ir.BooleanValue | None = None) -> NumericScalar: NumericScalar Median of the column """ - return ops.Median(self, where=where).to_expr() + return ops.Median(self, where=self._bind_reduction_filter(where)).to_expr() def quantile( self, @@ -804,7 +804,7 @@ def quantile( op = ops.MultiQuantile else: op = ops.Quantile - return op(self, quantile, where=where).to_expr() + return op(self, quantile, where=self._bind_reduction_filter(where)).to_expr() def std( self, @@ -825,7 +825,9 @@ def std( NumericScalar Standard deviation of `arg` """ - return ops.StandardDev(self, how=how, where=where).to_expr() + return ops.StandardDev( + self, how=how, where=self._bind_reduction_filter(where) + ).to_expr() def var( self, @@ -846,7 +848,9 @@ def var( NumericScalar Standard deviation of `arg` """ - return ops.Variance(self, how=how, where=where).to_expr() + return ops.Variance( + self, how=how, where=self._bind_reduction_filter(where) + ).to_expr() def corr( self, @@ -870,7 +874,9 @@ def corr( NumericScalar The correlation of `left` and `right` """ - return ops.Correlation(self, right, how=how, where=where).to_expr() + return ops.Correlation( + self, right, how=how, where=self._bind_reduction_filter(where) + ).to_expr() def cov( self, @@ -894,7 +900,9 @@ def cov( NumericScalar The covariance of `self` and `right` """ - return ops.Covariance(self, right, how=how, where=where).to_expr() + return ops.Covariance( + self, right, how=how, where=self._bind_reduction_filter(where) + ).to_expr() def mean( self, @@ -914,7 +922,7 @@ def mean( """ # TODO(kszucs): remove the alias from the reduction method in favor # of default name generated by ops.Value operations - return ops.Mean(self, where=where).to_expr() + return ops.Mean(self, where=self._bind_reduction_filter(where)).to_expr() def cummean(self, *, where=None, group_by=None, order_by=None) -> NumericColumn: """Return the cumulative mean of the input.""" @@ -938,7 +946,7 @@ def sum( NumericScalar The sum of the input expression """ - return ops.Sum(self, where=where).to_expr() + return ops.Sum(self, where=self._bind_reduction_filter(where)).to_expr() def cumsum(self, *, where=None, group_by=None, order_by=None) -> NumericColumn: """Return the cumulative sum of the input.""" @@ -1192,15 +1200,15 @@ class IntegerScalar(NumericScalar, IntegerValue): class IntegerColumn(NumericColumn, IntegerValue): def bit_and(self, where: ir.BooleanValue | None = None) -> IntegerScalar: """Aggregate the column using the bitwise and operator.""" - return ops.BitAnd(self, where).to_expr() + return ops.BitAnd(self, where=self._bind_reduction_filter(where)).to_expr() def bit_or(self, where: ir.BooleanValue | None = None) -> IntegerScalar: """Aggregate the column using the bitwise or operator.""" - return ops.BitOr(self, where).to_expr() + return ops.BitOr(self, where=self._bind_reduction_filter(where)).to_expr() def bit_xor(self, where: ir.BooleanValue | None = None) -> IntegerScalar: """Aggregate the column using the bitwise exclusive or operator.""" - return ops.BitXor(self, where).to_expr() + return ops.BitXor(self, where=self._bind_reduction_filter(where)).to_expr() @public diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index da40a802965c..f8ecd9b4204f 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -160,6 +160,12 @@ def __pandas_result__(self, df: pd.DataFrame) -> pd.DataFrame: return PandasData.convert_table(df, self.schema()) + def _bind_reduction_filter(self, where): + if where is None or not isinstance(where, Deferred): + return where + + return where.resolve(self) + def as_table(self) -> Table: """Promote the expression to a table. @@ -2313,7 +2319,9 @@ def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: >>> t.nunique(t.a != "foo") 1 """ - return ops.CountDistinctStar(self, where=where).to_expr() + return ops.CountDistinctStar( + self, where=self._bind_reduction_filter(where) + ).to_expr() def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: """Compute the number of rows in the table. @@ -2350,7 +2358,7 @@ def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: >>> type(t.count()) """ - return ops.CountStar(self, where).to_expr() + return ops.CountStar(self, where=self._bind_reduction_filter(where)).to_expr() def dropna( self,