Skip to content

Commit

Permalink
feat: move from .case() to .cases()
Browse files Browse the repository at this point in the history
Fixes #7280
  • Loading branch information
NickCrews committed May 8, 2024
1 parent 4d33830 commit ae70dc6
Show file tree
Hide file tree
Showing 28 changed files with 517 additions and 527 deletions.
14 changes: 6 additions & 8 deletions docs/posts/ci-analysis/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,12 @@ Let's also give them some names that'll look nice on our plots.
stats = stats.mutate(
raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int")
).mutate(
improvements=(
_.raw_improvements.case()
.when(0, "None")
.when(1, "Poetry")
.when(2, "Poetry + Team Plan")
.else_("NA")
.end()
),
improvements=_.raw_improvements.cases(
(0, "None"),
(1, "Poetry"),
(2, "Poetry + Team Plan"),
else_="NA",
)
team_plan=ibis.where(_.raw_improvements > 1, "Poetry + Team Plan", "None"),
)
stats
Expand Down
24 changes: 11 additions & 13 deletions docs/tutorials/ibis-for-sql-users.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,11 @@ semantics:
case = (
t.one.cast("timestamp")
.year()
.case()
.when(2015, "This year")
.when(2014, "Last year")
.else_("Earlier")
.end()
.cases(
(2015, "This year"),
(2014, "Last year"),
else_="Earlier",
)
)
expr = t.mutate(year_group=case)
Expand All @@ -496,18 +496,16 @@ CASE
END
```

To do this, use `ibis.case`:
To do this, use `ibis.cases`:

```{python}
case = (
ibis.case()
.when(t.two < 0, t.three * 2)
.when(t.two > 1, t.three)
.else_(t.two)
.end()
cases = ibis.cases(
(t.two < 0, t.three * 2),
(t.two > 1, t.three),
else_=t.two,
)
expr = t.mutate(cond_value=case)
expr = t.mutate(cond_value=cases)
ibis.to_sql(expr)
```

Expand Down
14 changes: 5 additions & 9 deletions ibis/backends/clickhouse/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,18 @@ def test_ifelse(alltypes, df, op, pandas_op):

def test_simple_case(con, alltypes, assert_sql):
t = alltypes
expr = (
t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end()
)
expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default")

assert_sql(expr)
assert len(con.execute(expr))


def test_search_case(con, alltypes, assert_sql):
t = alltypes
expr = (
ibis.case()
.when(t.float_col > 0, t.int_col * 2)
.when(t.float_col < 0, t.int_col)
.else_(0)
.end()
expr = ibis.cases(
(t.float_col > 0, t.int_col * 2),
(t.float_col < 0, t.int_col),
else_=0,
)

assert_sql(expr)
Expand Down
58 changes: 0 additions & 58 deletions ibis/backends/dask/tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,64 +773,6 @@ def q_fun(x, quantile):
tm.assert_series_equal(result, expected, check_index=False)


def test_searched_case_scalar(client):
expr = ibis.case().when(True, 1).when(False, 2).end()
result = client.execute(expr)
expected = np.int8(1)
assert result == expected


def test_searched_case_column(batting, batting_pandas_df):
t = batting
df = batting_pandas_df
expr = (
ibis.case()
.when(t.RBI < 5, "really bad team")
.when(t.teamID == "PH1", "ph1 team")
.else_(t.teamID)
.end()
)
result = expr.execute()
expected = pd.Series(
np.select(
[df.RBI < 5, df.teamID == "PH1"],
["really bad team", "ph1 team"],
df.teamID,
)
)
tm.assert_series_equal(result, expected, check_names=False)


def test_simple_case_scalar(client):
x = ibis.literal(2)
expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end()
result = client.execute(expr)
expected = np.int8(1)
assert result == expected


def test_simple_case_column(batting, batting_pandas_df):
t = batting
df = batting_pandas_df
expr = (
t.RBI.case()
.when(5, "five")
.when(4, "four")
.when(3, "three")
.else_("could be good?")
.end()
)
result = expr.execute()
expected = pd.Series(
np.select(
[df.RBI == 5, df.RBI == 4, df.RBI == 3],
["five", "four", "three"],
"could be good?",
)
)
tm.assert_series_equal(result, expected, check_names=False)


def test_table_distinct(t, df):
expr = t[["dup_strings"]].distinct()
result = expr.compile()
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/impala/tests/test_case_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ def table(mockcon):

@pytest.fixture
def simple_case(table):
return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return table.g.cases(("foo", "bar"), ("baz", "qux"), else_="default")


@pytest.fixture
def search_case(table):
t = table
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def visit(cls, op: ops.IsNan, arg):
def visit(
cls, op: ops.SearchedCase | ops.SimpleCase, cases, results, default, base=None
):
if not cases:
return default
if base is not None:
cases = tuple(base == case for case in cases)
cases, _ = cls.asframe(cases, concat=False)
Expand Down
66 changes: 1 addition & 65 deletions ibis/backends/pandas/tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,73 +683,9 @@ def test_summary_non_numeric(batting, batting_df):
assert dict(result.iloc[0]) == expected


def test_searched_case_scalar(client):
expr = ibis.case().when(True, 1).when(False, 2).end()
result = client.execute(expr)
expected = np.int8(1)
assert result == expected


def test_searched_case_column(batting, batting_df):
t = batting
df = batting_df
expr = (
ibis.case()
.when(t.RBI < 5, "really bad team")
.when(t.teamID == "PH1", "ph1 team")
.else_(t.teamID)
.end()
)
result = expr.execute()
expected = pd.Series(
np.select(
[df.RBI < 5, df.teamID == "PH1"],
["really bad team", "ph1 team"],
df.teamID,
)
)
tm.assert_series_equal(result, expected)


def test_simple_case_scalar(client):
x = ibis.literal(2)
expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end()
result = client.execute(expr)
expected = np.int8(1)
assert result == expected


def test_simple_case_column(batting, batting_df):
t = batting
df = batting_df
expr = (
t.RBI.case()
.when(5, "five")
.when(4, "four")
.when(3, "three")
.else_("could be good?")
.end()
)
result = expr.execute()
expected = pd.Series(
np.select(
[df.RBI == 5, df.RBI == 4, df.RBI == 3],
["five", "four", "three"],
"could be good?",
)
)
tm.assert_series_equal(result, expected)


def test_non_range_index():
def do_replace(col):
return col.cases(
(
(1, "one"),
(2, "two"),
),
default="unk",
)
return col.cases((1, "one"), (2, "two"), else_="unk")

df = pd.DataFrame(
{
Expand Down
38 changes: 12 additions & 26 deletions ibis/backends/snowflake/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
from ibis import udf

Expand Down Expand Up @@ -115,36 +114,23 @@ def predict_price(
df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"]
return model.predict(df)

def cases(value, mapping):
"""This should really be a top-level function or method."""
expr = ibis.case()
for k, v in mapping.items():
expr = expr.when(value == k, v)
return expr.end()

diamonds = con.tables.DIAMONDS
expr = diamonds.mutate(
predicted_price=predict_price(
(_.carat - _.carat.mean()) / _.carat.std(),
cases(
_.cut,
{
c: i
for i, c in enumerate(
("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1
)
},
_.cut.cases(
(c, i)
for i, c in enumerate(
("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1
)
),
cases(_.color, {c: i for i, c in enumerate("DEFGHIJ", start=1)}),
cases(
_.clarity,
{
c: i
for i, c in enumerate(
("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"),
start=1,
)
},
_.color.cases((c, i) for i, c in enumerate("DEFGHIJ", start=1)),
_.clarity.cases(
(c, i)
for i, c in enumerate(
("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"),
start=1,
)
),
)
)
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,8 @@ def visit_VarianceStandardDevCovariance(self, op, *, how, where, **kw):
)

def visit_SimpleCase(self, op, *, base=None, cases, results, default):
if not cases:
return default
return sge.Case(
this=base, ifs=list(map(self.if_, cases, results)), default=default
)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,13 @@ def difference(con):
@pytest.fixture(scope="module")
def simple_case(con):
t = con.table("alltypes")
return t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return t.g.cases(("foo", "bar"), ("baz", "qux"), else_="default")


@pytest.fixture(scope="module")
def search_case(con):
t = con.table("alltypes")
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture(scope="module")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,14 @@
lit2 = ibis.literal("bar")

result = alltypes.select(
alltypes.g.case()
.when(lit, lit2)
.when(lit1, ibis.literal("qux"))
.else_(ibis.literal("default"))
.end()
.name("col1"),
ibis.case()
.when(alltypes.g == lit, lit2)
.when(alltypes.g == lit1, alltypes.g)
.else_(ibis.literal(None).cast("string"))
.end()
.name("col2"),
alltypes.g.cases(
(lit, lit2), (lit1, ibis.literal("qux")), else_=ibis.literal("default")
).name("col1"),
ibis.cases(
(alltypes.g == lit, lit2),
(alltypes.g == lit1, alltypes.g),
else_=ibis.literal(None).cast("string"),
).name("col2"),
alltypes.a,
alltypes.b,
alltypes.c,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/test_select_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,8 @@ def test_bool_bool(snapshot):

def test_case_in_projection(alltypes, snapshot):
t = alltypes
expr = t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
expr2 = ibis.case().when(t.g == "foo", "bar").when(t.g == "baz", t.g).end()
expr = t.g.cases(("foo", "bar"), ("baz", "qux"), else_=("default"))
expr2 = ibis.cases((t.g == "foo", "bar"), (t.g == "baz", t.g))
expr = t[expr.name("col1"), expr2.name("col2"), t]

snapshot.assert_match(to_sql(expr), "out.sql")
Expand Down
Loading

0 comments on commit ae70dc6

Please sign in to comment.