Skip to content

Commit

Permalink
feat: make Table.bind() a public API
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCrews committed Apr 19, 2024
1 parent 08a33e9 commit 04b3033
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 70 deletions.
8 changes: 3 additions & 5 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,15 @@ def group_by(self, expr) -> Self:
def order_by(self, expr) -> Self:
return self.copy(orderings=self.orderings + util.promote_tuple(expr))

def bind(self, table):
from ibis.expr.types.relations import bind

def bind(self, table: ir.Table):
if table is None:
if self._table is None:
raise IbisInputError("Cannot bind window frame without a table")
else:
table = self._table.to_expr()

grouping = bind(table, self.groupings)
orderings = bind(table, self.orderings)
grouping = table.bind(self.groupings)
orderings = table.bind(self.orderings)
return self.copy(groupings=grouping, orderings=orderings)


Expand Down
5 changes: 1 addition & 4 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,8 +1890,6 @@ def topk(
Table
A top-k expression
"""
from ibis.expr.types.relations import bind

try:
(table,) = self.op().relations
except ValueError:
Expand All @@ -1902,8 +1900,7 @@ def topk(
if by is None:
by = lambda t: t.count()

(metric,) = bind(table, by)

(metric,) = table.bind(by)
return table.aggregate(metric, by=[self]).order_by(metric.desc()).limit(k)

def arbitrary(
Expand Down
7 changes: 3 additions & 4 deletions ibis/expr/types/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from ibis.common.grounds import Concrete
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import rewrite_window_input
from ibis.expr.types.relations import bind

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand Down Expand Up @@ -89,7 +88,7 @@ def having(self, *expr: ir.BooleanScalar) -> GroupedTable:
A grouped table expression
"""
table = self.table.to_expr()
havings = tuple(bind(table, expr))
havings = table.bind(expr)
return self.copy(havings=self.havings + havings)

def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
Expand All @@ -110,7 +109,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
A sorted grouped GroupedTable
"""
table = self.table.to_expr()
orderings = tuple(bind(table, expr))
orderings = table.bind(expr)
return self.copy(orderings=self.orderings + orderings)

def mutate(
Expand Down Expand Up @@ -201,7 +200,7 @@ def _selectables(self, *exprs, **kwexprs):
[`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate)
"""
table = self.table.to_expr()
values = bind(table, (exprs, kwexprs))
values = table.bind((exprs, kwexprs))
window = ibis.window(group_by=self.groupings, order_by=self.orderings)
return [rewrite_window_input(expr.op(), window).to_expr() for expr in values]

Expand Down
7 changes: 3 additions & 4 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ibis.expr.types.generic import Value
from ibis.expr.types.relations import (
Table,
bind,
dereference_mapping,
unwrap_aliases,
)
Expand Down Expand Up @@ -221,8 +220,8 @@ def prepare_predicates(
lk = rk = pred

# bind the predicates to the join chain
(left_value,) = bind(left, lk)
(right_value,) = bind(right, rk)
(left_value,) = left.bind(lk)
(right_value,) = right.bind(rk)

# dereference the left value to one of the relations in the join chain
left_value, right_value = dereference_sides(
Expand Down Expand Up @@ -425,7 +424,7 @@ def cross_join(
@functools.wraps(Table.select)
def select(self, *args, **kwargs):
chain = self.op()
values = bind(self, (args, kwargs))
values = self.bind((args, kwargs))
values = unwrap_aliases(values)

# if there are values referencing fields from the join chain constructed
Expand Down
205 changes: 157 additions & 48 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,44 +95,6 @@ def f( # noqa: D417
return f


# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting
# nested inputs
def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]:
"""Bind a value to a table expression."""
if isinstance(value, str):
# TODO(kszucs): perhaps use getattr(table, value) instead for nicer error msg
yield ops.Field(table, value).to_expr()
elif isinstance(value, bool):
yield literal(value)
elif int_as_column and isinstance(value, int):
name = table.columns[value]
yield ops.Field(table, name).to_expr()
elif isinstance(value, ops.Value):
yield value.to_expr()
elif isinstance(value, Value):
yield value
elif isinstance(value, Table):
for name in value.columns:
yield ops.Field(value, name).to_expr()
elif isinstance(value, Deferred):
yield value.resolve(table)
elif isinstance(value, Resolver):
yield value.resolve({"_": table})
elif isinstance(value, Selector):
yield from value.expand(table)
elif isinstance(value, Mapping):
for k, v in value.items():
for val in bind(table, v, int_as_column=int_as_column):
yield val.name(k)
elif util.is_iterable(value):
for v in value:
yield from bind(table, v, int_as_column=int_as_column)
elif callable(value):
yield value(table)
else:
yield literal(value)


def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]:
"""Unwrap aliases into a mapping of {name: expression}."""
result = {}
Expand Down Expand Up @@ -569,6 +531,153 @@ def preview(
console_width=console_width,
)

# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting
# nested inputs
def bind(self, value: Any, *, int_as_column: bool = False) -> tuple[ir.Value]:
"""Resolve value(s) to a tuple of Columns or Scalars.
This is one of the most flexible way or selecting values from a Table,
accepting a wide range of inputs types.
This is similar to .select(), but this returns a tuple of Columns or Scalars,
while .select() returns a new Table.
Examples
--------
>>> import ibis
>>> ibis.options.interactive = True
>>> t = ibis.examples.diamonds.fetch().head(3)
>>> t
┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓
┃ carat ┃ cut ┃ color ┃ clarity ┃ depth ┃ table ┃ price ┃ x ┃ y ┃ z ┃
┡━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩
│ float64 │ string │ string │ string │ float64 │ float64 │ int64 │ float64 │ float64 │ float64 │
├─────────┼───────────┼────────┼─────────┼─────────┼─────────┼───────┼─────────┼─────────┼─────────┤
│ 0.23 │ Ideal │ E │ SI2 │ 61.5 │ 55.0 │ 326 │ 3.95 │ 3.98 │ 2.43 │
│ 0.21 │ Premium │ E │ SI1 │ 59.8 │ 61.0 │ 326 │ 3.89 │ 3.84 │ 2.31 │
│ 0.23 │ Good │ E │ VS1 │ 56.9 │ 65.0 │ 327 │ 4.05 │ 4.07 │ 2.31 │
└─────────┴───────────┴────────┴─────────┴─────────┴─────────┴───────┴─────────┴─────────┴─────────┘
The simplest use is to retrieve columns:
>>> t.bind("cut")
(┏━━━━━━━━━━━┓
┃ cut ┃
┡━━━━━━━━━━━┩
│ string │
├───────────┤
│ Ideal │
│ Premium │
│ Good │
└───────────┘,)
>>> t.bind(["carat", t.cut, ibis._.price + 1])
(┏━━━━━━━━━┓
┃ carat ┃
┡━━━━━━━━━┩
│ float64 │
├─────────┤
│ 0.23 │
│ 0.21 │
│ 0.23 │
└─────────┘,
┏━━━━━━━━━┓
┃ cut ┃
┡━━━━━━━━━┩
│ string │
├─────────┤
│ Ideal │
│ Premium │
│ Good │
└─────────┘,
┏━━━━━━━━━━━━━━━┓
┃ Add(price, 1) ┃
┡━━━━━━━━━━━━━━━┩
│ int64 │
├───────────────┤
│ 327 │
│ 327 │
│ 328 │
└───────────────┘)
Args that look like literals are returned as literals...
>>> t.bind(False)
(False,)
>>> t.bind(3)
(3,)
...unless you explicitly want to retrieve a column by index:
>>> t.bind(3, int_as_column=True)
(┏━━━━━━━━━┓
┃ clarity ┃
┡━━━━━━━━━┩
│ string │
├─────────┤
│ SI2 │
│ SI1 │
│ VS1 │
└─────────┘,)
Use a mapping to rename values in one go:
>>> t.bind({"x": "price", "y": ibis._.clarity})
(┏━━━━━━━┓
┃ x ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 326 │
│ 326 │
│ 327 │
└───────┘,
┏━━━━━━━━┓
┃ y ┃
┡━━━━━━━━┩
│ string │
├────────┤
│ SI2 │
│ SI1 │
│ VS1 │
└────────┘)
"""

def inner() -> Iterator[ir.Value]:
if isinstance(value, str):
# TODO(kszucs): perhaps use getattr(table, value) instead for nicer error msg
yield ops.Field(self, value).to_expr()
elif isinstance(value, bool):
yield literal(value)
elif int_as_column and isinstance(value, int):
name = self.columns[value]
yield ops.Field(self, name).to_expr()
elif isinstance(value, ops.Value):
yield value.to_expr()
elif isinstance(value, Value):
yield value
elif isinstance(value, Table):
for name in value.columns:
yield ops.Field(value, name).to_expr()
elif isinstance(value, Deferred):
yield value.resolve(self)
elif isinstance(value, Resolver):
yield value.resolve({"_": self})
elif isinstance(value, Selector):
yield from value.expand(self)
elif isinstance(value, Mapping):
for k, v in value.items():
for val in self.bind(v, int_as_column=int_as_column):
yield val.name(k)
elif util.is_iterable(value):
for v in value:
yield from self.bind(v, int_as_column=int_as_column)
elif callable(value):
yield value(self)
else:
yield literal(value)

return tuple(inner())

def __getitem__(self, what):
"""Select items from a table expression.
Expand Down Expand Up @@ -819,7 +928,7 @@ def __getitem__(self, what):
limit, offset = util.slice_to_limit_offset(what, self.count())
return self.limit(limit, offset=offset)

values = tuple(bind(self, what, int_as_column=True))
values = self.bind(what, int_as_column=True)
if isinstance(what, (str, int)):
assert len(values) == 1
return values[0]
Expand Down Expand Up @@ -1004,7 +1113,7 @@ def group_by(
from ibis.expr.types.groupby import GroupedTable

by = tuple(v for v in by if v is not None)
groups = bind(self, (by, key_exprs))
groups = self.bind((by, key_exprs))
return GroupedTable(self, groups)

# TODO(kszucs): shouldn't this be ibis.rowid() instead not bound to a specific table?
Expand Down Expand Up @@ -1183,9 +1292,9 @@ def aggregate(

node = self.op()

groups = bind(self, by)
metrics = bind(self, (metrics, kwargs))
having = bind(self, having)
groups = self.bind(by)
metrics = self.bind((metrics, kwargs))
having = self.bind(having)

groups = unwrap_aliases(groups)
metrics = unwrap_aliases(metrics)
Expand Down Expand Up @@ -1724,7 +1833,7 @@ def order_by(
│ 2 │ B │ 6 │
└───────┴────────┴───────┘
"""
keys = bind(self, by)
keys = self.bind(by)
keys = unwrap_aliases(keys)
keys = dereference_values(self.op(), keys)
if not keys:
Expand Down Expand Up @@ -1973,7 +2082,7 @@ def mutate(self, *exprs: Sequence[ir.Expr] | None, **mutations: ir.Value) -> Tab
# string and integer inputs are going to be coerced to literals instead
# of interpreted as column references like in select
node = self.op()
values = bind(self, (exprs, mutations))
values = self.bind((exprs, mutations))
values = unwrap_aliases(values)
# allow overriding of fields, hence the mutation behavior
values = {**node.fields, **values}
Expand Down Expand Up @@ -2158,7 +2267,7 @@ def select(
"""
from ibis.expr.rewrites import rewrite_project_input

values = bind(self, (exprs, named_exprs))
values = self.bind((exprs, named_exprs))
values = unwrap_aliases(values)
values = dereference_values(self.op(), values)
if not values:
Expand Down Expand Up @@ -2535,7 +2644,7 @@ def filter(
from ibis.expr.analysis import flatten_predicates
from ibis.expr.rewrites import rewrite_filter_input

preds = bind(self, predicates)
preds = self.bind(predicates)
preds = unwrap_aliases(preds)
preds = dereference_values(self.op(), preds)
preds = flatten_predicates(list(preds.values()))
Expand Down Expand Up @@ -2671,7 +2780,7 @@ def dropna(
344
"""
if subset is not None:
subset = bind(self, subset)
subset = self.bind(subset)
return ops.DropNa(self, how, subset).to_expr()

def fillna(
Expand Down
Loading

0 comments on commit 04b3033

Please sign in to comment.