Skip to content

Commit

Permalink
refactor(api): treat integer inputs as literals instead of column ref…
Browse files Browse the repository at this point in the history
…erences
  • Loading branch information
kszucs committed Apr 8, 2024
1 parent 356e459 commit 423a733
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 60 deletions.
67 changes: 29 additions & 38 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ibis import util
from ibis.common.deferred import Deferred, Resolver
from ibis.expr.types.core import Expr, _FixedTextJupyterMixin
from ibis.expr.types.generic import ValueExpr, literal
from ibis.expr.types.generic import Value, literal
from ibis.selectors import Selector
from ibis.util import deprecated

Expand Down Expand Up @@ -95,15 +95,23 @@ def f( # noqa: D417

# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting
# nested inputs
def bind(table: Table, value: Any) -> Iterator[ir.Value]:
def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]:
"""Bind a value to a table expression."""
if type(value) in (str, int):
yield table._get_column(value)
elif isinstance(value, ValueExpr):
if isinstance(value, str):
# TODO(kszucs): perhaps use getattr(table, value) instead for nicer error msg
yield ops.Field(table, value).to_expr()
elif isinstance(value, bool):
yield literal(value)
elif int_as_column and isinstance(value, int):
name = table.columns[value]
yield ops.Field(table, name).to_expr()
elif isinstance(value, ops.Value):
yield value.to_expr()
elif isinstance(value, Value):
yield value
elif isinstance(value, Table):
for name in value.columns:
yield value._get_column(name)
yield ops.Field(table, name).to_expr()
elif isinstance(value, Deferred):
yield value.resolve(table)
elif isinstance(value, Resolver):
Expand All @@ -112,17 +120,11 @@ def bind(table: Table, value: Any) -> Iterator[ir.Value]:
yield from value.expand(table)
elif isinstance(value, Mapping):
for k, v in value.items():
for val in bind(table, v):
for val in bind(table, v, int_as_column=int_as_column):
yield val.name(k)
elif util.is_iterable(value):
for v in value:
yield from bind(table, v)
elif isinstance(value, ops.Value):
# TODO(kszucs): from certain builders, like ir.GroupedTable we pass
# operation nodes instead of expressions to table methods, it would
# be better to convert them to expressions before passing them to
# this function
yield value.to_expr()
yield from bind(table, v, int_as_column=int_as_column)
elif callable(value):
yield value(table)
else:
Expand Down Expand Up @@ -533,13 +535,6 @@ def __interactive_rich_console__(self, console, options):
raise e
return console.render(table, options=options)

# TODO(kszucs): expose this method in the public API
def _get_column(self, name: str | int) -> ir.Column:
"""Get a column from the table."""
if isinstance(name, int):
name = self.schema().name_at_position(name)
return ops.Field(self, name).to_expr()

def __getitem__(self, what):
"""Select items from a table expression.
Expand Down Expand Up @@ -786,22 +781,18 @@ def __getitem__(self, what):
"""
from ibis.expr.types.logical import BooleanValue

if isinstance(what, (str, int)):
return self._get_column(what)
elif isinstance(what, slice):
if isinstance(what, slice):
limit, offset = util.slice_to_limit_offset(what, self.count())
return self.limit(limit, offset=offset)
elif isinstance(what, (list, tuple, Table)):
# Projection case
return self.select(what)

items = tuple(bind(self, what))
if util.all_of(items, BooleanValue):
# TODO(kszucs): this branch should be removed, .filter should be
# used instead
return self.filter(items)

values = tuple(bind(self, what, int_as_column=True))
if isinstance(what, (str, int)):
assert len(values) == 1
return values[0]
elif util.all_of(values, BooleanValue):
return self.filter(values)
else:
return self.select(items)
return self.select(values)

def __len__(self):
raise com.ExpressionError("Use .count() instead")
Expand Down Expand Up @@ -844,7 +835,7 @@ def __getattr__(self, key: str) -> ir.Column:
└───────────┘
"""
try:
return self._get_column(key)
return ops.Field(self, key).to_expr()
except com.IbisTypeError:
pass

Expand Down Expand Up @@ -2039,7 +2030,7 @@ def select(
Projection by zero-indexed column position
>>> t.select(0, 4).head()
>>> t.select(t[0], t[4]).head()
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ species ┃ flipper_length_mm ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
Expand Down Expand Up @@ -4375,10 +4366,10 @@ def relocate(
where = 0

# all columns that should come BEFORE the matched selectors
front = [left for left in range(where) if left not in sels]
front = [self[left] for left in range(where) if left not in sels]

# all columns that should come AFTER the matched selectors
back = [right for right in range(where, ncols) if right not in sels]
back = [self[right] for right in range(where, ncols) if right not in sels]

# selected columns
middle = [self[i].name(name) for i, name in sels.items()]
Expand Down
7 changes: 3 additions & 4 deletions ibis/tests/expr/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def test_if_any(penguins):


def test_negate_range(penguins):
assert penguins.select(~s.r[3:]).equals(penguins.select(0, 1, 2))
assert penguins.select(~s.r[3:]).equals(penguins[[0, 1, 2]])


def test_string_range_start(penguins):
Expand Down Expand Up @@ -378,16 +378,15 @@ def test_all(penguins):
@pytest.mark.parametrize(
("seq", "expected"),
[
param([0, 1, 2], (0, 1, 2), id="int_tuple"),
param(~s.r[[3, 4, 5]], sorted(set(range(8)) - {3, 4, 5}), id="neg_int_list"),
param(~s.r[3, 4, 5], sorted(set(range(8)) - {3, 4, 5}), id="neg_int_tuple"),
param(s.r["island", "year"], ("island", "year"), id="string_tuple"),
param(s.r[["island", "year"]], ("island", "year"), id="string_list"),
param(iter(["island", 4, "year"]), ("island", 4, "year"), id="mixed_iterable"),
param(iter(["island", "year"]), ("island", "year"), id="mixed_iterable"),
],
)
def test_sequence(penguins, seq, expected):
assert penguins.select(seq).equals(penguins.select(*expected))
assert penguins.select(seq).equals(penguins[expected])


def test_names_callable(penguins):
Expand Down
17 changes: 10 additions & 7 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from ibis import _
from ibis.common.annotations import ValidationError
from ibis.common.deferred import Deferred
from ibis.common.exceptions import ExpressionError, IntegrityError, RelationError
from ibis.common.exceptions import (
ExpressionError,
IbisTypeError,
IntegrityError,
RelationError,
)
from ibis.expr import api
from ibis.expr.rewrites import simplify
from ibis.expr.tests.test_newrels import join_tables
Expand Down Expand Up @@ -230,7 +235,7 @@ def test_projection_with_star_expr(table):

# cannot pass an invalid table expression
t2 = t.aggregate([t["a"].sum().name("sum(a)")], by=["g"])
with pytest.raises(IntegrityError):
with pytest.raises(IbisTypeError):
t[[t2]]
# TODO: there may be some ways this can be invalid

Expand Down Expand Up @@ -581,10 +586,8 @@ def test_order_by_scalar(table, key, expected):
("key", "exc_type"),
[
("bogus", com.IbisTypeError),
# (("bogus", False), com.IbisTypeError),
(("bogus", False), com.IbisTypeError),
(ibis.desc("bogus"), com.IbisTypeError),
(1000, IndexError),
# ((1000, False), IndexError),
(_.bogus, AttributeError),
(_.bogus.desc(), AttributeError),
],
Expand Down Expand Up @@ -746,7 +749,7 @@ def test_aggregate_keywords(table):
def test_select_on_literals(table):
# literal ints and strings are column indices, everything else is a value
expr1 = table.select(col1=True, col2=1, col3="a")
expr2 = table.select(col1=ibis.literal(True), col2=table.b, col3=table.a)
expr2 = table.select(col1=ibis.literal(True), col2=ibis.literal(1), col3=table.a)
assert expr1.equals(expr2)


Expand Down Expand Up @@ -1280,7 +1283,7 @@ def test_inner_join_overlapping_column_names():
lambda t1, t2: [(t1.foo_id, t2.foo_id)],
lambda t1, t2: [(_.foo_id, _.foo_id)],
lambda t1, t2: [(t1.foo_id, _.foo_id)],
lambda t1, t2: [(2, 0)], # foo_id is 2nd in t1, 0th in t2
lambda t1, t2: [(t1[2], t2[0])], # foo_id is 2nd in t1, 0th in t2
lambda t1, t2: [(lambda t: t.foo_id, lambda t: t.foo_id)],
],
)
Expand Down
14 changes: 3 additions & 11 deletions ibis/tests/expr/test_value_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,18 +882,10 @@ def test_bitwise_exprs(fn, expected_op):
([1, 0], ["bar", "foo"]),
],
)
@pytest.mark.parametrize(
"expr_func",
[
lambda t, args: t[args],
lambda t, args: t.order_by(args),
lambda t, args: t.group_by(args).aggregate(bar_avg=t.bar.mean()),
],
)
def test_table_operations_with_integer_column(position, names, expr_func):
def test_table_operations_with_integer_column(position, names):
t = ibis.table([("foo", "string"), ("bar", "double")])
result = expr_func(t, position)
expected = expr_func(t, names)
result = t[position]
expected = t[names]
assert result.equals(expected)


Expand Down

0 comments on commit 423a733

Please sign in to comment.