diff --git a/ibis/expr/builders.py b/ibis/expr/builders.py index b984599270212..4bd99fabe6478 100644 --- a/ibis/expr/builders.py +++ b/ibis/expr/builders.py @@ -226,7 +226,7 @@ def order_by(self, expr) -> Self: return self.copy(orderings=self.orderings + util.promote_tuple(expr)) def bind(self, table): - from ibis.expr.types.relations import bind + # from ibis.expr.types.relations import bind if table is None: if self._table is None: @@ -234,9 +234,9 @@ def bind(self, table): else: table = self._table.to_expr() - grouping = bind(table, self.groupings) - orderings = bind(table, self.orderings) - return self.copy(groupings=grouping, orderings=orderings) + return self.copy( + groupings=table.bind(self.groupings), orderings=table.bind(self.orderings) + ) class LegacyWindowBuilder(WindowBuilder): diff --git a/ibis/expr/types/groupby.py b/ibis/expr/types/groupby.py index 3e18946eebb92..cbcc20a97f059 100644 --- a/ibis/expr/types/groupby.py +++ b/ibis/expr/types/groupby.py @@ -28,10 +28,9 @@ from ibis.common.grounds import Concrete from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.rewrites import rewrite_window_input -from ibis.expr.types.relations import bind if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Sequence @public @@ -65,13 +64,12 @@ def __getattr__(self, attr): def aggregate(self, *metrics, **kwds) -> ir.Table: """Compute aggregates over a group by.""" - return self.table.to_expr().aggregate( - metrics, by=self.groupings, having=self.havings, **kwds - ) + metrics = self.table.to_expr().bind(*metrics, **kwds) + return self.table.to_expr().aggregate(metrics, by=self.groupings, having=self.havings) agg = aggregate - def having(self, *expr: ir.BooleanScalar) -> GroupedTable: + def having(self, *predicates: ir.BooleanScalar) -> GroupedTable: """Add a post-aggregation result filter `expr`. ::: {.callout-warning} @@ -80,8 +78,8 @@ def having(self, *expr: ir.BooleanScalar) -> GroupedTable: Parameters ---------- - expr - An expression that filters based on an aggregate value. + predicates + Expressions that filters based on an aggregate value. Returns ------- @@ -89,10 +87,10 @@ def having(self, *expr: ir.BooleanScalar) -> GroupedTable: A grouped table expression """ table = self.table.to_expr() - havings = tuple(bind(table, expr)) + havings = table.bind(*predicates) return self.copy(havings=self.havings + havings) - def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable: + def order_by(self, *by: ir.Value) -> GroupedTable: """Sort a grouped table expression by `expr`. Notes @@ -101,7 +99,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable: Parameters ---------- - expr + by Expressions to order the results by Returns @@ -110,7 +108,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable: A sorted grouped GroupedTable """ table = self.table.to_expr() - orderings = tuple(bind(table, expr)) + orderings = table.bind(*by) return self.copy(orderings=self.orderings + orderings) def mutate( @@ -201,7 +199,7 @@ def _selectables(self, *exprs, **kwexprs): [`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate) """ table = self.table.to_expr() - values = bind(table, (exprs, kwexprs)) + values = table.bind(*exprs, **kwexprs) window = ibis.window(group_by=self.groupings, order_by=self.orderings) return [rewrite_window_input(expr.op(), window).to_expr() for expr in values] diff --git a/ibis/expr/types/joins.py b/ibis/expr/types/joins.py index ff081a353a510..7accd5a615b69 100644 --- a/ibis/expr/types/joins.py +++ b/ibis/expr/types/joins.py @@ -176,9 +176,11 @@ def prepare_predicates( else: lk = rk = pred + # TODO(kszucs): bind can emit multiple predicates, this would allow + # selectors to be used as join keys, use zip() # bind the predicates to the join chain - (left_value,) = bind(left, lk) - (right_value,) = bind(right, rk) + (left_value,) = left.bind(lk) + (right_value,) = right.bind(rk) # dereference the left value to one of the relations in the join chain left_value = deref_left.dereference(left_value.op()) @@ -336,7 +338,7 @@ def asof_join( filtered, predicates=[left_on == right_on] + predicates ) values = {**self.op().values, **filtered.op().values} - return result.select(values) + return result.select(**values) chain = self.op() right = right.op() @@ -383,7 +385,7 @@ def cross_join( @functools.wraps(Table.select) def select(self, *args, **kwargs): chain = self.op() - values = bind(self, (args, kwargs)) + values = self.bind(*args, **kwargs) values = unwrap_aliases(values) links = [link.table for link in chain.rest if link.how not in ("semi", "anti")] diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 6678f3e9f6900..5ebbb42071c64 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -91,18 +91,11 @@ def f( # noqa: D417 return f -# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting -# nested inputs -def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]: +def bind(table: Table, value) -> Iterator[ir.Value]: """Bind a value to a table expression.""" if isinstance(value, str): # TODO(kszucs): perhaps use getattr(table, value) instead for nicer error msg yield ops.Field(table, value).to_expr() - elif isinstance(value, bool): - yield literal(value) - elif int_as_column and isinstance(value, int): - name = table.columns[value] - yield ops.Field(table, name).to_expr() elif isinstance(value, ops.Value): yield value.to_expr() elif isinstance(value, Value): @@ -116,13 +109,6 @@ def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]: yield value.resolve({"_": table}) elif isinstance(value, Selector): yield from value.expand(table) - elif isinstance(value, Mapping): - for k, v in value.items(): - for val in bind(table, v, int_as_column=int_as_column): - yield val.name(k) - elif util.is_iterable(value): - for v in value: - yield from bind(table, v, int_as_column=int_as_column) elif callable(value): yield value(table) else: @@ -247,6 +233,28 @@ def _bind_reduction_filter(self, where): return where.resolve(self) + def bind(self, *args, **kwargs): + if len(args) == 1: + if isinstance(args[0], dict): + kwargs = {**args[0], **kwargs} + args = () + else: + args = util.promote_list(args[0]) + + values = [] + for arg in args: + values.extend(bind(self, arg)) + for key, arg in kwargs.items(): + try: + (value,) = bind(self, arg) + except ValueError: + raise com.IbisInputError( + "Keyword arguments cannot produce more than one value" + ) + values.append(value.name(key)) + + return tuple(values) + def as_scalar(self) -> ir.ScalarExpr: """Inform ibis that the table expression should be treated as a scalar. @@ -769,7 +777,12 @@ def __getitem__(self, what): limit, offset = util.slice_to_limit_offset(what, self.count()) return self.limit(limit, offset=offset) - values = tuple(bind(self, what, int_as_column=True)) + args = [ + self.columns[arg] if isinstance(arg, int) else arg + for arg in util.promote_list(what) + ] + values = self.bind(args) + if isinstance(what, (str, int)): assert len(values) == 1 return values[0] @@ -954,7 +967,7 @@ def group_by( from ibis.expr.types.groupby import GroupedTable by = tuple(v for v in by if v is not None) - groups = bind(self, (by, key_exprs)) + groups = self.bind(*by, **key_exprs) return GroupedTable(self, groups) # TODO(kszucs): shouldn't this be ibis.rowid() instead not bound to a specific table? @@ -1133,9 +1146,9 @@ def aggregate( node = self.op() - groups = bind(self, by) - metrics = bind(self, (metrics, kwargs)) - having = tuple(bind(self, having)) + groups = self.bind(by) + metrics = self.bind(metrics, **kwargs) + having = self.bind(having) groups = unwrap_aliases(groups) metrics = unwrap_aliases(metrics) @@ -1672,7 +1685,7 @@ def order_by( │ 2 │ B │ 6 │ └───────┴────────┴───────┘ """ - keys = bind(self, by) + keys = self.bind(*by) keys = unwrap_aliases(keys) keys = dereference_values(self.op(), keys) if not keys: @@ -1921,7 +1934,7 @@ def mutate(self, *exprs: Sequence[ir.Expr] | None, **mutations: ir.Value) -> Tab # string and integer inputs are going to be coerced to literals instead # of interpreted as column references like in select node = self.op() - values = bind(self, (exprs, mutations)) + values = self.bind(*exprs, **mutations) values = unwrap_aliases(values) # allow overriding of fields, hence the mutation behavior values = {**node.fields, **values} @@ -2106,7 +2119,7 @@ def select( """ from ibis.expr.rewrites import rewrite_project_input - values = bind(self, (exprs, named_exprs)) + values = self.bind(*exprs, **named_exprs) values = unwrap_aliases(values) values = dereference_values(self.op(), values) if not values: @@ -2483,7 +2496,7 @@ def filter( from ibis.expr.analysis import flatten_predicates from ibis.expr.rewrites import rewrite_filter_input - preds = bind(self, predicates) + preds = self.bind(*predicates) preds = unwrap_aliases(preds) preds = dereference_values(self.op(), preds) preds = flatten_predicates(list(preds.values())) @@ -2619,7 +2632,7 @@ def dropna( 344 """ if subset is not None: - subset = bind(self, subset) + subset = self.bind(subset) return ops.DropNa(self, how, subset).to_expr() def fillna( diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index c8b2ed0103ce9..0e365c9acf5a9 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -267,10 +267,10 @@ def test_projection_array_expr(table): assert_equal(result, expected) -@pytest.mark.parametrize("empty", [list(), dict()]) -def test_projection_no_expr(table, empty): - with pytest.raises(com.IbisTypeError, match="must select at least one"): - table.select(empty) +# @pytest.mark.parametrize("empty", [list(), dict()]) +# def test_projection_no_expr(table, empty): +# with pytest.raises(com.IbisTypeError, match="must select at least one"): +# table.select(empty) # FIXME(kszucs): currently bind() flattens the list of expressions, so arbitrary @@ -2077,3 +2077,67 @@ def test_unbind_with_namespace(): assert s.op() == expected.op() assert s.equals(expected) + + +def test_table_bind(): + def eq(left, right): + return all(a.equals(b) for a, b in zip(left, right)) + + t = ibis.table({"a": "int", "b": "string"}, name="t") + + # single table arg + exprs = t.bind(t) + expected = (t.a, t.b) + assert eq(exprs, expected) + + # single selector arg + exprs = t.bind(s.all()) + expected = (t.a, t.b) + assert eq(exprs, expected) + + # single tuple arg + exprs = t.bind([1, "a"]) + expected = (ibis.literal(1), t.a) + assert eq(exprs, expected) + + # single list arg + exprs = t.bind([1, 2, "b"]) + expected = (ibis.literal(1), ibis.literal(2), t.b) + assert eq(exprs, expected) + + # single list arg with kwargs + exprs = t.bind([1], b=2) + expected = (ibis.literal(1), ibis.literal(2).name("b")) + assert eq(exprs, expected) + + # single dict arg + exprs = t.bind({"c": 1, "d": 2}) + expected = (ibis.literal(1).name("c"), ibis.literal(2).name("d")) + assert eq(exprs, expected) + + # single dict arg with kwargs + exprs = t.bind({"c": 1}, d=2) + expected = (ibis.literal(1).name("c"), ibis.literal(2).name("d")) + assert eq(exprs, expected) + + # single dict arg with overlapping kwargs + exprs = t.bind({"c": 1, "d": 2}, c=2) + expected = (ibis.literal(2).name("c"), ibis.literal(2).name("d")) + assert eq(exprs, expected) + + # kwargs cannot cannot produce more than one value + with pytest.raises(com.IbisInputError): + t.bind(alias=t) + with pytest.raises(com.IbisInputError): + t.bind(alias=s.all()) + + # multiple args + exprs = t.bind(t, ["a", "b"], {"c": 1}, d=2) + expected = ( + t.a, + t.b, + ibis.literal(["a", "b"]), + ibis.literal({"c": 1}), + ibis.literal(2).name("d"), + ) + assert eq(exprs, expected)