Skip to content

Commit

Permalink
refactor(api): make input value coercion of mutate() identical to `…
Browse files Browse the repository at this point in the history
…select()` (#8878)

String literals passed to select() are interpreted as
column references, whereas mutate() used to interpret them as literals.

BREAKING CHANGE: strings passed to table.mutate() are now interpreted as
column references instead of literals; use `ibis.literal(string)` to
pass the string as a literal.
  • Loading branch information
kszucs authored Apr 3, 2024
1 parent d7f94e5 commit 38e7e14
Show file tree
Hide file tree
Showing 11 changed files with 29 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/ibis-for-sql-users.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ In many other situations, you can use constants without having to use
but the number 5 like so:

```{python}
expr = t3.mutate(number5=5)
expr = t3.mutate(number5=ibis.literal(5))
ibis.to_sql(expr)
```

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/dask/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def test_mutate_scalar_with_window_after_join(npartitions):

joined = left.outer_join(right, left.ints == right.group)
proj = joined[left, right.value]
expr = proj.mutate(sum=proj.value.sum(), const=1)
expr = proj.mutate(sum=proj.value.sum(), const=ibis.literal(1))
result = expr.execute()
result = result.sort_values(["ints", "value"]).reset_index(drop=True)
expected = (
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/pandas/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def test_mutate_after_join():
q_count=joined["q_count"].fillna(0),
p_density=joined.p_density.fillna(1e-10),
q_density=joined.q_density.fillna(1e-10),
features="Order_Priority",
features=ibis.literal("Order_Priority"),
)

expected = pd.DataFrame(
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/pandas/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def test_mutate_scalar_with_window_after_join():

joined = left.outer_join(right, left.ints == right.group)
proj = joined[left, right.value]
expr = proj.mutate(sum=proj.value.sum(), const=1)
expr = proj.mutate(sum=proj.value.sum(), const=ibis.literal(1))
result = expr.execute()
expected = pd.DataFrame(
{
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def test_mutate_filter_join_no_cross_join(snapshot):
[("person_id", "int64"), ("birth_datetime", "timestamp")],
name="person",
)
mutated = person.mutate(age=400)
mutated = person.mutate(age=ibis.literal(400))
expr = mutated.filter(mutated.age <= 40)[mutated.person_id]

snapshot.assert_match(to_sql(expr), "out.sql")
Expand Down
16 changes: 9 additions & 7 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ def test_create_table_timestamp(con, temp_table):
reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
)
def test_persist_expression_ref_count(backend, con, alltypes):
non_persisted_table = alltypes.mutate(test_column="calculation")
non_persisted_table = alltypes.mutate(test_column=ibis.literal("calculation"))
persisted_table = non_persisted_table.cache()

op = non_persisted_table.op()
Expand All @@ -1239,7 +1239,9 @@ def test_persist_expression_ref_count(backend, con, alltypes):
reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
)
def test_persist_expression(backend, alltypes):
non_persisted_table = alltypes.mutate(test_column="calculation", other_calc="xyz")
non_persisted_table = alltypes.mutate(
test_column=ibis.literal("calculation"), other_calc=ibis.literal("xyz")
)
persisted_table = non_persisted_table.cache()
backend.assert_frame_equal(
non_persisted_table.to_pandas(), persisted_table.to_pandas()
Expand All @@ -1259,7 +1261,7 @@ def test_persist_expression(backend, alltypes):
)
def test_persist_expression_contextmanager(backend, alltypes):
non_cached_table = alltypes.mutate(
test_column="calculation", other_column="big calc"
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc")
)
with non_cached_table.cache() as cached_table:
backend.assert_frame_equal(
Expand All @@ -1280,7 +1282,7 @@ def test_persist_expression_contextmanager(backend, alltypes):
)
def test_persist_expression_contextmanager_ref_count(backend, con, alltypes):
non_cached_table = alltypes.mutate(
test_column="calculation", other_column="big calc 2"
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2")
)
op = non_cached_table.op()
with non_cached_table.cache() as cached_table:
Expand All @@ -1304,7 +1306,7 @@ def test_persist_expression_contextmanager_ref_count(backend, con, alltypes):
@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
def test_persist_expression_multiple_refs(backend, con, alltypes):
non_cached_table = alltypes.mutate(
test_column="calculation", other_column="big calc 2"
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2")
)
op = non_cached_table.op()
with non_cached_table.cache() as cached_table:
Expand Down Expand Up @@ -1345,7 +1347,7 @@ def test_persist_expression_multiple_refs(backend, con, alltypes):
)
def test_persist_expression_repeated_cache(alltypes):
non_cached_table = alltypes.mutate(
test_column="calculation", other_column="big calc 2"
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2")
)
with non_cached_table.cache() as cached_table:
with cached_table.cache() as nested_cached_table:
Expand Down Expand Up @@ -1374,7 +1376,7 @@ def test_persist_expression_repeated_cache(alltypes):
)
def test_persist_expression_release(con, alltypes):
non_cached_table = alltypes.mutate(
test_column="calculation", other_column="big calc 3"
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 3")
)
cached_table = non_cached_table.cache()
cached_table.release()
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ def test_uncorrelated_subquery(backend, batting, batting_df):


def test_int_column(alltypes):
expr = alltypes.mutate(x=1).x
expr = alltypes.mutate(x=ibis.literal(1)).x
result = expr.execute()
assert expr.type() == dt.int8
assert result.dtype == np.int8
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ def test_interval_add_cast_column(backend, alltypes, df):
),
param(
lambda t: (
t.mutate(suffix="%d")
t.mutate(suffix=ibis.literal("%d"))
.select(formatted=lambda t: t.timestamp_col.strftime("%Y%m" + t.suffix))
.formatted
),
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_mutate():
def test_mutate_overwrites_existing_column():
t = ibis.table(dict(a="string", b="string"))

mut = t.mutate(a=42)
mut = t.mutate(a=ibis.literal(42))
assert mut.op() == Project(parent=t, values={"a": ibis.literal(42), "b": t.b})

sel = mut.select("a")
Expand Down
15 changes: 8 additions & 7 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def f( # noqa: D417

# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting
# nested inputs
def bind(table: Table, value: Any, prefer_column=True) -> Iterator[ir.Value]:
def bind(table: Table, value: Any) -> Iterator[ir.Value]:
"""Bind a value to a table expression."""
if prefer_column and type(value) in (str, int):
if type(value) in (str, int):
yield table._get_column(value)
elif isinstance(value, ValueExpr):
yield value
Expand All @@ -110,11 +110,11 @@ def bind(table: Table, value: Any, prefer_column=True) -> Iterator[ir.Value]:
yield from value.expand(table)
elif isinstance(value, Mapping):
for k, v in value.items():
for val in bind(table, v, prefer_column=prefer_column):
for val in bind(table, v):
yield val.name(k)
elif util.is_iterable(value):
for v in value:
yield from bind(table, v, prefer_column=prefer_column)
yield from bind(table, v)
elif isinstance(value, ops.Value):
# TODO(kszucs): from certain builders, like ir.GroupedTable we pass
# operation nodes instead of expressions to table methods, it would
Expand Down Expand Up @@ -1946,7 +1946,7 @@ def mutate(self, *exprs: Sequence[ir.Expr] | None, **mutations: ir.Value) -> Tab
# string and integer inputs are interpreted as column references, exactly
# like in select; wrap them in ibis.literal() to pass a literal value
node = self.op()
values = bind(self, (exprs, mutations), prefer_column=False)
values = bind(self, (exprs, mutations))
values = unwrap_aliases(values)
# allow overriding of fields, hence the mutation behavior
values = {**node.fields, **values}
Expand Down Expand Up @@ -3359,7 +3359,8 @@ def cache(self) -> Table:
>>> import ibis
>>> ibis.options.interactive = True
>>> t = ibis.examples.penguins.fetch()
>>> cached_penguins = t.mutate(computation="Heavy Computation").cache()
>>> heavy_computation = ibis.literal("Heavy Computation")
>>> cached_penguins = t.mutate(computation=heavy_computation).cache()
>>> cached_penguins
┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓
┃ species ┃ island ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃
Expand All @@ -3381,7 +3382,7 @@ def cache(self) -> Table:
Explicit cache cleanup
>>> with t.mutate(computation="Heavy Computation").cache() as cached_penguins:
>>> with t.mutate(computation=heavy_computation).cache() as cached_penguins:
... cached_penguins
┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓
┃ species ┃ island ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃
Expand Down
8 changes: 4 additions & 4 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,15 @@ def test_mutate(table):
table.b.sum().name("x2"),
(_.a + 2).name("x3"),
lambda _: (_.a + 3).name("x4"),
4,
"five",
ibis.literal(4),
ibis.literal("five"),
],
kw1=(table.a + 6),
kw2=table.b.sum(),
kw3=(_.a + 7),
kw4=lambda _: (_.a + 8),
kw5=9,
kw6="ten",
kw5=ibis.literal(9),
kw6=ibis.literal("ten"),
)
expected = table[
table,
Expand Down

0 comments on commit 38e7e14

Please sign in to comment.