Skip to content

Commit

Permalink
feat: move from .case() to .cases()
Browse files Browse the repository at this point in the history
Fixes #7280
  • Loading branch information
NickCrews committed Sep 11, 2024
1 parent bac76ff commit de1ebc9
Show file tree
Hide file tree
Showing 25 changed files with 386 additions and 472 deletions.
2 changes: 1 addition & 1 deletion docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ quartodoc:
- name: ifelse
dynamic: true
signature_name: full
- name: case
- name: cases
dynamic: true
signature_name: full

Expand Down
14 changes: 6 additions & 8 deletions docs/posts/ci-analysis/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,12 @@ Let's also give them some names that'll look nice on our plots.
stats = stats.mutate(
raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int")
).mutate(
improvements=(
_.raw_improvements.case()
.when(0, "None")
.when(1, "Poetry")
.when(2, "Poetry + Team Plan")
.else_("NA")
.end()
),
improvements=_.raw_improvements.cases(
(0, "None"),
(1, "Poetry"),
(2, "Poetry + Team Plan"),
else_="NA",
)
team_plan=ibis.where(_.raw_improvements > 1, "Poetry + Team Plan", "None"),
)
stats
Expand Down
24 changes: 11 additions & 13 deletions docs/tutorials/ibis-for-sql-users.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -466,11 +466,11 @@ semantics:
case = (
t.one.cast("timestamp")
.year()
.case()
.when(2015, "This year")
.when(2014, "Last year")
.else_("Earlier")
.end()
.cases(
(2015, "This year"),
(2014, "Last year"),
else_="Earlier",
)
)
expr = t.mutate(year_group=case)
Expand All @@ -489,18 +489,16 @@ CASE
END
```

To do this, use `ibis.case`:
To do this, use `ibis.cases`:

```{python}
case = (
ibis.case()
.when(t.two < 0, t.three * 2)
.when(t.two > 1, t.three)
.else_(t.two)
.end()
cases = ibis.cases(
(t.two < 0, t.three * 2),
(t.two > 1, t.three),
else_=t.two,
)
expr = t.mutate(cond_value=case)
expr = t.mutate(cond_value=cases)
ibis.to_sql(expr)
```

Expand Down
14 changes: 5 additions & 9 deletions ibis/backends/clickhouse/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,18 @@ def test_ifelse(alltypes, df, op, pandas_op):

def test_simple_case(con, alltypes, assert_sql):
t = alltypes
expr = (
t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end()
)
expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default")

assert_sql(expr)
assert len(con.execute(expr))


def test_search_case(con, alltypes, assert_sql):
t = alltypes
expr = (
ibis.case()
.when(t.float_col > 0, t.int_col * 2)
.when(t.float_col < 0, t.int_col)
.else_(0)
.end()
expr = ibis.cases(
(t.float_col > 0, t.int_col * 2),
(t.float_col < 0, t.int_col),
else_=0,
)

assert_sql(expr)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/impala/tests/test_case_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ def table(mockcon):

@pytest.fixture
def simple_case(table):
return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return table.g.cases(("foo", "bar"), ("baz", "qux"), else_="default")


@pytest.fixture
def search_case(table):
t = table
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture
Expand Down
8 changes: 1 addition & 7 deletions ibis/backends/pandas/tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,13 +685,7 @@ def test_summary_non_numeric(batting, batting_df):

def test_non_range_index():
def do_replace(col):
return col.cases(
(
(1, "one"),
(2, "two"),
),
default="unk",
)
return col.cases((1, "one"), (2, "two"), else_="unk")

df = pd.DataFrame(
{
Expand Down
38 changes: 12 additions & 26 deletions ibis/backends/snowflake/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
from ibis import udf

Expand Down Expand Up @@ -122,36 +121,23 @@ def predict_price(
df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"]
return model.predict(df)

def cases(value, mapping):
"""This should really be a top-level function or method."""
expr = ibis.case()
for k, v in mapping.items():
expr = expr.when(value == k, v)
return expr.end()

diamonds = con.tables.DIAMONDS
expr = diamonds.mutate(
predicted_price=predict_price(
(_.carat - _.carat.mean()) / _.carat.std(),
cases(
_.cut,
{
c: i
for i, c in enumerate(
("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1
)
},
_.cut.cases(
(c, i)
for i, c in enumerate(
("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1
)
),
cases(_.color, {c: i for i, c in enumerate("DEFGHIJ", start=1)}),
cases(
_.clarity,
{
c: i
for i, c in enumerate(
("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"),
start=1,
)
},
_.color.cases((c, i) for i, c in enumerate("DEFGHIJ", start=1)),
_.clarity.cases(
(c, i)
for i, c in enumerate(
("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"),
start=1,
)
),
)
)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def difference(con):
@pytest.fixture(scope="module")
def simple_case(con):
t = con.table("alltypes")
return t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return t.g.cases(("foo", "bar"), ("baz", "qux"), else_="default")


@pytest.fixture(scope="module")
def search_case(con):
t = con.table("alltypes")
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture(scope="module")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,14 @@
lit2 = ibis.literal("bar")

result = alltypes.select(
alltypes.g.case()
.when(lit, lit2)
.when(lit1, ibis.literal("qux"))
.else_(ibis.literal("default"))
.end()
.name("col1"),
ibis.case()
.when((alltypes.g == lit), lit2)
.when((alltypes.g == lit1), alltypes.g)
.else_(ibis.literal(None))
.end()
.name("col2"),
alltypes.g.cases(
(lit, lit2), (lit1, ibis.literal("qux")), else_=ibis.literal("default")
).name("col1"),
ibis.cases(
(alltypes.g == lit, lit2),
(alltypes.g == lit1, alltypes.g),
else_=ibis.literal(None),
).name("col2"),
alltypes.a,
alltypes.b,
alltypes.c,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/test_select_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ def test_bool_bool(snapshot):

def test_case_in_projection(alltypes, snapshot):
t = alltypes
expr = t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
expr2 = ibis.case().when(t.g == "foo", "bar").when(t.g == "baz", t.g).end()
expr = t.g.cases(("foo", "bar"), ("baz", "qux"), else_=("default"))
expr2 = ibis.cases((t.g == "foo", "bar"), (t.g == "baz", t.g))
expr = t.select(expr.name("col1"), expr2.name("col2"), t)

snapshot.assert_match(to_sql(expr), "out.sql")
Expand Down
10 changes: 4 additions & 6 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def test_first_last(backend, alltypes, method, filtered, include_null):
# To sanely test this we create a column that is a mix of nulls and a
# single value (or a single value after filtering is applied).
if filtered:
new = alltypes.int_col.cases([(3, 30), (4, 40)])
new = alltypes.int_col.cases((3, 30), (4, 40))
where = _.int_col == 3
else:
new = (alltypes.int_col == 3).ifelse(30, None)
Expand Down Expand Up @@ -712,7 +712,7 @@ def test_arbitrary(backend, alltypes, df, filtered):
# _something_ we create a column that is a mix of nulls and a single value
# (or a single value after filtering is applied).
if filtered:
new = alltypes.int_col.cases([(3, 30), (4, 40)])
new = alltypes.int_col.cases((3, 30), (4, 40))
where = _.int_col == 3
else:
new = (alltypes.int_col == 3).ifelse(30, None)
Expand Down Expand Up @@ -1580,9 +1580,7 @@ def collect_udf(v):

def test_binds_are_cast(alltypes):
expr = alltypes.aggregate(
high_line_count=(
alltypes.string_col.case().when("1-URGENT", 1).else_(0).end().sum()
)
high_line_count=alltypes.string_col.cases(("1-URGENT", 1), else_=0).sum()
)

expr.execute()
Expand Down Expand Up @@ -1625,7 +1623,7 @@ def test_agg_name_in_output_column(alltypes):
def test_grouped_case(backend, con):
table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]})

case_expr = ibis.case().when(table.value < 25, table.value).else_(ibis.null()).end()
case_expr = ibis.cases((table.value < 25, table.value), else_=ibis.null())

expr = (
table.group_by(k="key")
Expand Down
Loading

0 comments on commit de1ebc9

Please sign in to comment.