From 05ac8c4a8944032653b7a5cce536c928b3f75533 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 3 Apr 2024 22:29:19 +0200 Subject: [PATCH] refactor(api): treat integer inputs as literals instead of column references --- ibis/expr/types/relations.py | 68 +++++++++++++---------------- ibis/tests/expr/test_selectors.py | 7 ++- ibis/tests/expr/test_table.py | 17 +++++--- ibis/tests/expr/test_value_exprs.py | 14 ++---- 4 files changed, 46 insertions(+), 60 deletions(-) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index a2a96f53bb6e3..4f7b293994fcb 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -23,8 +23,9 @@ from ibis import util from ibis.common.deferred import Deferred, Resolver from ibis.expr.types.core import Expr, _FixedTextJupyterMixin -from ibis.expr.types.generic import ValueExpr, literal +from ibis.expr.types.generic import literal from ibis.expr.types.pretty import to_rich +from ibis.expr.types.generic import Value, literal from ibis.selectors import Selector from ibis.util import deprecated @@ -97,15 +98,23 @@ def f( # noqa: D417 # TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting # nested inputs -def bind(table: Table, value: Any) -> Iterator[ir.Value]: +def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]: """Bind a value to a table expression.""" - if type(value) in (str, int): - yield table._get_column(value) - elif isinstance(value, ValueExpr): + if isinstance(value, str): + # TODO(kszucs): perhaps use getattr(table, value) instead for nicer error msg + yield ops.Field(table, value).to_expr() + elif isinstance(value, bool): + yield literal(value) + elif int_as_column and isinstance(value, int): + name = table.columns[value] + yield ops.Field(table, name).to_expr() + elif isinstance(value, ops.Value): + yield value.to_expr() + elif isinstance(value, Value): yield value elif isinstance(value, Table): for name in value.columns: - yield value._get_column(name) + yield ops.Field(table, name).to_expr() elif isinstance(value, Deferred): yield value.resolve(table) elif isinstance(value, Resolver): @@ -114,17 +123,11 @@ def bind(table: Table, value: Any) -> Iterator[ir.Value]: yield from value.expand(table) elif isinstance(value, Mapping): for k, v in value.items(): - for val in bind(table, v): + for val in bind(table, v, int_as_column=int_as_column): yield val.name(k) elif util.is_iterable(value): for v in value: - yield from bind(table, v) - elif isinstance(value, ops.Value): - # TODO(kszucs): from certain builders, like ir.GroupedTable we pass - # operation nodes instead of expressions to table methods, it would - # be better to convert them to expressions before passing them to - # this function - yield value.to_expr() + yield from bind(table, v, int_as_column=int_as_column) elif callable(value): yield value(table) else: @@ -567,13 +570,6 @@ def preview( console_width=console_width, ) - # TODO(kszucs): expose this method in the public API - def _get_column(self, name: str | int) -> ir.Column: - """Get a column from the table.""" - if isinstance(name, int): - name = self.schema().name_at_position(name) - return ops.Field(self, name).to_expr() - def __getitem__(self, what): """Select items from a table expression. @@ -820,22 +816,18 @@ def __getitem__(self, what): """ from ibis.expr.types.logical import BooleanValue - if isinstance(what, (str, int)): - return self._get_column(what) - elif isinstance(what, slice): + if isinstance(what, slice): limit, offset = util.slice_to_limit_offset(what, self.count()) return self.limit(limit, offset=offset) - elif isinstance(what, (list, tuple, Table)): - # Projection case - return self.select(what) - - items = tuple(bind(self, what)) - if util.all_of(items, BooleanValue): - # TODO(kszucs): this branch should be removed, .filter should be - # used instead - return self.filter(items) + + values = tuple(bind(self, what, int_as_column=True)) + if isinstance(what, (str, int)): + assert len(values) == 1 + return values[0] + elif util.all_of(values, BooleanValue): + return self.filter(values) else: - return self.select(items) + return self.select(values) def __len__(self): raise com.ExpressionError("Use .count() instead") @@ -878,7 +870,7 @@ def __getattr__(self, key: str) -> ir.Column: └───────────┘ """ try: - return self._get_column(key) + return ops.Field(self, key).to_expr() except com.IbisTypeError: pass @@ -2073,7 +2065,7 @@ def select( Projection by zero-indexed column position - >>> t.select(0, 4).head() + >>> t.select(t[0], t[4]).head() ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ species ┃ flipper_length_mm ┃ ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ @@ -4409,10 +4401,10 @@ def relocate( where = 0 # all columns that should come BEFORE the matched selectors - front = [left for left in range(where) if left not in sels] + front = [self[left] for left in range(where) if left not in sels] # all columns that should come AFTER the matched selectors - back = [right for right in range(where, ncols) if right not in sels] + back = [self[right] for right in range(where, ncols) if right not in sels] # selected columns middle = [self[i].name(name) for i, name in sels.items()] diff --git a/ibis/tests/expr/test_selectors.py b/ibis/tests/expr/test_selectors.py index a39b773fcebf3..47d4051fd2a48 100644 --- a/ibis/tests/expr/test_selectors.py +++ b/ibis/tests/expr/test_selectors.py @@ -342,7 +342,7 @@ def test_if_any(penguins): def test_negate_range(penguins): - assert penguins.select(~s.r[3:]).equals(penguins.select(0, 1, 2)) + assert penguins.select(~s.r[3:]).equals(penguins[[0, 1, 2]]) def test_string_range_start(penguins): @@ -378,16 +378,15 @@ def test_all(penguins): @pytest.mark.parametrize( ("seq", "expected"), [ - param([0, 1, 2], (0, 1, 2), id="int_tuple"), param(~s.r[[3, 4, 5]], sorted(set(range(8)) - {3, 4, 5}), id="neg_int_list"), param(~s.r[3, 4, 5], sorted(set(range(8)) - {3, 4, 5}), id="neg_int_tuple"), param(s.r["island", "year"], ("island", "year"), id="string_tuple"), param(s.r[["island", "year"]], ("island", "year"), id="string_list"), - param(iter(["island", 4, "year"]), ("island", 4, "year"), id="mixed_iterable"), + param(iter(["island", "year"]), ("island", "year"), id="mixed_iterable"), ], ) def test_sequence(penguins, seq, expected): - assert penguins.select(seq).equals(penguins.select(*expected)) + assert penguins.select(seq).equals(penguins[expected]) def test_names_callable(penguins): diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 627c227cab66f..56412e9bf16fc 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -19,7 +19,12 @@ from ibis import _ from ibis.common.annotations import ValidationError from ibis.common.deferred import Deferred -from ibis.common.exceptions import ExpressionError, IntegrityError, RelationError +from ibis.common.exceptions import ( + ExpressionError, + IbisTypeError, + IntegrityError, + RelationError, +) from ibis.expr import api from ibis.expr.rewrites import simplify from ibis.expr.tests.test_newrels import join_tables @@ -230,7 +235,7 @@ def test_projection_with_star_expr(table): # cannot pass an invalid table expression t2 = t.aggregate([t["a"].sum().name("sum(a)")], by=["g"]) - with pytest.raises(IntegrityError): + with pytest.raises(IbisTypeError): t[[t2]] # TODO: there may be some ways this can be invalid @@ -581,10 +586,8 @@ def test_order_by_scalar(table, key, expected): ("key", "exc_type"), [ ("bogus", com.IbisTypeError), - # (("bogus", False), com.IbisTypeError), + (("bogus", False), com.IbisTypeError), (ibis.desc("bogus"), com.IbisTypeError), - (1000, IndexError), - # ((1000, False), IndexError), (_.bogus, AttributeError), (_.bogus.desc(), AttributeError), ], @@ -746,7 +749,7 @@ def test_aggregate_keywords(table): def test_select_on_literals(table): # literal ints and strings are column indices, everything else is a value expr1 = table.select(col1=True, col2=1, col3="a") - expr2 = table.select(col1=ibis.literal(True), col2=table.b, col3=table.a) + expr2 = table.select(col1=ibis.literal(True), col2=ibis.literal(1), col3=table.a) assert expr1.equals(expr2) @@ -1280,7 +1283,7 @@ def test_inner_join_overlapping_column_names(): lambda t1, t2: [(t1.foo_id, t2.foo_id)], lambda t1, t2: [(_.foo_id, _.foo_id)], lambda t1, t2: [(t1.foo_id, _.foo_id)], - lambda t1, t2: [(2, 0)], # foo_id is 2nd in t1, 0th in t2 + lambda t1, t2: [(t1[2], t2[0])], # foo_id is 2nd in t1, 0th in t2 lambda t1, t2: [(lambda t: t.foo_id, lambda t: t.foo_id)], ], ) diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index b962daf624b64..40e3f7bfb4b02 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -882,18 +882,10 @@ def test_bitwise_exprs(fn, expected_op): ([1, 0], ["bar", "foo"]), ], ) -@pytest.mark.parametrize( - "expr_func", - [ - lambda t, args: t[args], - lambda t, args: t.order_by(args), - lambda t, args: t.group_by(args).aggregate(bar_avg=t.bar.mean()), - ], -) -def test_table_operations_with_integer_column(position, names, expr_func): +def test_table_operations_with_integer_column(position, names): t = ibis.table([("foo", "string"), ("bar", "double")]) - result = expr_func(t, position) - expected = expr_func(t, names) + result = t[position] + expected = t[names] assert result.equals(expected)