Skip to content

Commit

Permalink
feat(api): move from .case() to .cases() (#9096)
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCrews authored Oct 9, 2024
1 parent 22dcce1 commit 54889db
Show file tree
Hide file tree
Showing 26 changed files with 384 additions and 481 deletions.

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ quartodoc:
- name: ifelse
dynamic: true
signature_name: full
- name: case
- name: cases
dynamic: true
signature_name: full

Expand Down
12 changes: 5 additions & 7 deletions docs/posts/ci-analysis/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,11 @@ Let's also give them some names that'll look nice on our plots.
stats = stats.mutate(
raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int")
).mutate(
improvements=(
_.raw_improvements.case()
.when(0, "None")
.when(1, "Poetry")
.when(2, "Poetry + Team Plan")
.else_("NA")
.end()
improvements=_.raw_improvements.cases(
(0, "None"),
(1, "Poetry"),
(2, "Poetry + Team Plan"),
else_="NA",
),
team_plan=ibis.ifelse(_.raw_improvements > 1, "Poetry + Team Plan", "None"),
)
Expand Down
24 changes: 11 additions & 13 deletions docs/tutorials/ibis-for-sql-users.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -466,11 +466,11 @@ semantics:
case = (
t.one.cast("timestamp")
.year()
.case()
.when(2015, "This year")
.when(2014, "Last year")
.else_("Earlier")
.end()
.cases(
(2015, "This year"),
(2014, "Last year"),
else_="Earlier",
)
)
expr = t.mutate(year_group=case)
Expand All @@ -489,18 +489,16 @@ CASE
END
```

To do this, use `ibis.case`:
To do this, use `ibis.cases`:

```{python}
case = (
ibis.case()
.when(t.two < 0, t.three * 2)
.when(t.two > 1, t.three)
.else_(t.two)
.end()
cases = ibis.cases(
(t.two < 0, t.three * 2),
(t.two > 1, t.three),
else_=t.two,
)
expr = t.mutate(cond_value=case)
expr = t.mutate(cond_value=cases)
ibis.to_sql(expr)
```

Expand Down
14 changes: 5 additions & 9 deletions ibis/backends/clickhouse/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,18 @@ def test_ifelse(alltypes, df, op, pandas_op):

def test_simple_case(con, alltypes, assert_sql):
t = alltypes
expr = (
t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end()
)
expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default")

assert_sql(expr)
assert len(con.execute(expr))


def test_search_case(con, alltypes, assert_sql):
t = alltypes
expr = (
ibis.case()
.when(t.float_col > 0, t.int_col * 2)
.when(t.float_col < 0, t.int_col)
.else_(0)
.end()
expr = ibis.cases(
(t.float_col > 0, t.int_col * 2),
(t.float_col < 0, t.int_col),
else_=0,
)

assert_sql(expr)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/impala/tests/test_case_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ def table(mockcon):

@pytest.fixture
def simple_case(table):
return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return table.g.cases(("foo", "bar"), ("baz", "qux"), else_="default")


@pytest.fixture
def search_case(table):
t = table
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture
Expand Down
38 changes: 12 additions & 26 deletions ibis/backends/snowflake/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
from ibis import udf

Expand Down Expand Up @@ -122,36 +121,23 @@ def predict_price(
df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"]
return model.predict(df)

def cases(value, mapping):
"""This should really be a top-level function or method."""
expr = ibis.case()
for k, v in mapping.items():
expr = expr.when(value == k, v)
return expr.end()

diamonds = con.tables.DIAMONDS
expr = diamonds.mutate(
predicted_price=predict_price(
(_.carat - _.carat.mean()) / _.carat.std(),
cases(
_.cut,
{
c: i
for i, c in enumerate(
("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1
)
},
_.cut.cases(
(c, i)
for i, c in enumerate(
("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1
)
),
cases(_.color, {c: i for i, c in enumerate("DEFGHIJ", start=1)}),
cases(
_.clarity,
{
c: i
for i, c in enumerate(
("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"),
start=1,
)
},
_.color.cases((c, i) for i, c in enumerate("DEFGHIJ", start=1)),
_.clarity.cases(
(c, i)
for i, c in enumerate(
("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"),
start=1,
)
),
)
)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def difference(con):
@pytest.fixture(scope="module")
def simple_case(con):
t = con.table("alltypes")
return t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return t.g.cases(("foo", "bar"), ("baz", "qux"), else_="default")


@pytest.fixture(scope="module")
def search_case(con):
t = con.table("alltypes")
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture(scope="module")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,14 @@
lit2 = ibis.literal("bar")

result = alltypes.select(
alltypes.g.case()
.when(lit, lit2)
.when(lit1, ibis.literal("qux"))
.else_(ibis.literal("default"))
.end()
.name("col1"),
ibis.case()
.when((alltypes.g == lit), lit2)
.when((alltypes.g == lit1), alltypes.g)
.else_(ibis.literal(None))
.end()
.name("col2"),
alltypes.g.cases(
(lit, lit2), (lit1, ibis.literal("qux")), else_=ibis.literal("default")
).name("col1"),
ibis.cases(
((alltypes.g == lit), lit2),
((alltypes.g == lit1), alltypes.g),
else_=ibis.literal(None),
).name("col2"),
alltypes.a,
alltypes.b,
alltypes.c,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/test_select_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ def test_bool_bool(snapshot):

def test_case_in_projection(alltypes, snapshot):
t = alltypes
expr = t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
expr2 = ibis.case().when(t.g == "foo", "bar").when(t.g == "baz", t.g).end()
expr = t.g.cases(("foo", "bar"), ("baz", "qux"), else_=("default"))
expr2 = ibis.cases((t.g == "foo", "bar"), (t.g == "baz", t.g))
expr = t.select(expr.name("col1"), expr2.name("col2"), t)

snapshot.assert_match(to_sql(expr), "out.sql")
Expand Down
10 changes: 4 additions & 6 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def test_first_last(alltypes, method, filtered, include_null):
# To sanely test this we create a column that is a mix of nulls and a
# single value (or a single value after filtering is applied).
if filtered:
new = alltypes.int_col.cases([(3, 30), (4, 40)])
new = alltypes.int_col.cases((3, 30), (4, 40))
where = _.int_col == 3
else:
new = (alltypes.int_col == 3).ifelse(30, None)
Expand Down Expand Up @@ -738,7 +738,7 @@ def test_arbitrary(alltypes, filtered):
# _something_ we create a column that is a mix of nulls and a single value
# (or a single value after filtering is applied).
if filtered:
new = alltypes.int_col.cases([(3, 30), (4, 40)])
new = alltypes.int_col.cases((3, 30), (4, 40))
where = _.int_col == 3
else:
new = (alltypes.int_col == 3).ifelse(30, None)
Expand Down Expand Up @@ -1571,9 +1571,7 @@ def collect_udf(v):

def test_binds_are_cast(alltypes):
expr = alltypes.aggregate(
high_line_count=(
alltypes.string_col.case().when("1-URGENT", 1).else_(0).end().sum()
)
high_line_count=alltypes.string_col.cases(("1-URGENT", 1), else_=0).sum()
)

expr.execute()
Expand Down Expand Up @@ -1616,7 +1614,7 @@ def test_agg_name_in_output_column(alltypes):
def test_grouped_case(backend, con):
table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]})

case_expr = ibis.case().when(table.value < 25, table.value).else_(ibis.null()).end()
case_expr = ibis.cases((table.value < 25, table.value), else_=ibis.null())

expr = (
table.group_by(k="key")
Expand Down
75 changes: 51 additions & 24 deletions ibis/backends/tests/test_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import Counter

import pytest
from pytest import param

import ibis

Expand Down Expand Up @@ -62,18 +63,13 @@ def test_substitute(backend):
@pytest.mark.parametrize(
"inp, exp",
[
pytest.param(
lambda: ibis.literal(1)
.case()
.when(1, "one")
.when(2, "two")
.else_("other")
.end(),
param(
lambda: ibis.literal(1).cases((1, "one"), (2, "two"), else_="other"),
"one",
id="one_kwarg",
),
pytest.param(
lambda: ibis.literal(5).case().when(1, "one").when(2, "two").end(),
param(
lambda: ibis.literal(5).cases((1, "one"), (2, "two")),
None,
id="fallthrough",
),
Expand All @@ -94,13 +90,8 @@ def test_value_cases_column(batting):
np = pytest.importorskip("numpy")

df = batting.to_pandas()
expr = (
batting.RBI.case()
.when(5, "five")
.when(4, "four")
.when(3, "three")
.else_("could be good?")
.end()
expr = batting.RBI.cases(
(5, "five"), (4, "four"), (3, "three"), else_="could be good?"
)
result = expr.execute()
expected = np.select(
Expand All @@ -113,7 +104,7 @@ def test_value_cases_column(batting):


def test_ibis_cases_scalar():
expr = ibis.literal(5).case().when(5, "five").when(4, "four").end()
expr = ibis.literal(5).cases((5, "five"), (4, "four"))
result = expr.execute()
assert result == "five"

Expand All @@ -128,12 +119,8 @@ def test_ibis_cases_column(batting):

t = batting
df = batting.to_pandas()
expr = (
ibis.case()
.when(t.RBI < 5, "really bad team")
.when(t.teamID == "PH1", "ph1 team")
.else_(t.teamID)
.end()
expr = ibis.cases(
(t.RBI < 5, "really bad team"), (t.teamID == "PH1", "ph1 team"), else_=t.teamID
)
result = expr.execute()
expected = np.select(
Expand All @@ -148,5 +135,45 @@ def test_ibis_cases_column(batting):
@pytest.mark.notimpl("clickhouse", reason="special case this and returns 'oops'")
def test_value_cases_null(con):
"""CASE x WHEN NULL never gets hit"""
e = ibis.literal(5).nullif(5).case().when(None, "oops").else_("expected").end()
e = ibis.literal(5).nullif(5).cases((None, "oops"), else_="expected")
assert con.execute(e) == "expected"


@pytest.mark.parametrize(
("example", "expected"),
[
param(lambda: ibis.case().when(True, "yes").end(), "yes", id="top-level-true"),
param(lambda: ibis.case().when(False, "yes").end(), None, id="top-level-false"),
param(
lambda: ibis.case().when(False, "yes").else_("no").end(),
"no",
id="top-level-false-value",
),
param(
lambda: ibis.literal("a").case().when("a", "yes").end(),
"yes",
id="method-true",
),
param(
lambda: ibis.literal("a").case().when("b", "yes").end(),
None,
id="method-false",
),
param(
lambda: ibis.literal("a").case().when("b", "yes").else_("no").end(),
"no",
id="method-false-value",
),
],
)
def test_ibis_case_still_works(con, example, expected):
# test that the soft-deprecated .case() method still works
# https://github.com/ibis-project/ibis/pull/9096
pd = pytest.importorskip("pandas")

with pytest.warns(FutureWarning):
expr = example()

result = con.execute(expr)

assert (expected is None and pd.isna(result)) or result == expected
Loading

0 comments on commit 54889db

Please sign in to comment.