Skip to content

Commit

Permalink
feat: remove .case(), move to .cases()
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCrews committed Oct 9, 2023
1 parent 07ea22c commit eb36129
Show file tree
Hide file tree
Showing 34 changed files with 266 additions and 576 deletions.
12 changes: 5 additions & 7 deletions docs/posts/ci-analysis/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,11 @@ Let's also give them some names that'll look nice on our plots.
stats = stats.mutate(
raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int")
).mutate(
improvements=(
_.raw_improvements.case()
.when(0, "None")
.when(1, "Poetry")
.when(2, "Poetry + Team Plan")
.else_("NA")
.end()
improvements=_.raw_improvements.cases(
(0, "None"),
(1, "Poetry"),
(2, "Poetry + Team Plan"),
else_="NA",
),
team_plan=ibis.where(_.raw_improvements > 1, "Poetry + Team Plan", "None"),
)
Expand Down
22 changes: 10 additions & 12 deletions docs/tutorials/ibis-for-sql-users.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,11 @@ semantics:
case = (
t.one.cast("timestamp")
.year()
.case()
.when(2015, "This year")
.when(2014, "Last year")
.else_("Earlier")
.end()
.cases(
(2015, "This year"),
(2014, "Last year"),
else_="Earlier",
)
)
expr = t.mutate(year_group=case)
Expand All @@ -496,15 +496,13 @@ CASE
END
```

To do this, use `ibis.case`:
To do this, use `ibis.cases`:

```{python}
case = (
ibis.case()
.when(t.two < 0, t.three * 2)
.when(t.two > 1, t.three)
.else_(t.two)
.end()
case = ibis.cases(
(t.two < 0, t.three * 2),
(t.two > 1, t.three),
else_=t.two,
)
expr = t.mutate(cond_value=case)
Expand Down
15 changes: 3 additions & 12 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,18 +273,10 @@ def _floor_divide(t, op):
return sa.func.floor(left / right)


def _simple_case(t, op):
return _translate_case(t, op, value=t.translate(op.base))


def _searched_case(t, op):
return _translate_case(t, op, value=None)


def _translate_case(t, op, *, value):
def _translate_case(t, op):
return sa.case(
*zip(map(t.translate, op.cases), map(t.translate, op.results)),
value=value,
value=None,
else_=t.translate(op.default),
)

Expand Down Expand Up @@ -558,8 +550,7 @@ class array_filter(FunctionElement):
ops.Negate: _negate,
ops.Round: _round,
ops.Literal: _literal,
ops.SimpleCase: _simple_case,
ops.SearchedCase: _searched_case,
ops.SearchedCase: _translate_case,
ops.TableColumn: _table_column,
ops.TableArrayView: _table_array_view,
ops.ExistsSubquery: _exists_subquery,
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def count_star(translator, op):
ops.Between: between,
ops.InValues: binary_infix.in_values,
ops.InColumn: binary_infix.in_column,
ops.SimpleCase: case.simple_case,
# ops.SimpleCase: case.simple_case,
ops.SearchedCase: case.searched_case,
ops.TableColumn: table_column,
ops.TableArrayView: table_array_view,
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def _literal(op, *, value, dtype, **kw):
raise NotImplementedError(f"Unsupported type: {dtype!r}")


@translate_val.register(ops.SimpleCase)
# @translate_val.register(ops.SimpleCase)
@translate_val.register(ops.SearchedCase)
def _case(op, *, base=None, cases, results, default, **_):
return sg.exp.Case(this=base, ifs=list(map(if_, cases, results)), default=default)
Expand Down
14 changes: 5 additions & 9 deletions ibis/backends/clickhouse/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,18 @@ def test_ifelse(alltypes, df, op, pandas_op):

def test_simple_case(con, alltypes, snapshot):
t = alltypes
expr = (
t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end()
)
expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default")

snapshot.assert_match(expr.compile(), "out.sql")
assert len(con.execute(expr))


def test_search_case(con, alltypes, snapshot):
t = alltypes
expr = (
ibis.case()
.when(t.float_col > 0, t.int_col * 2)
.when(t.float_col < 0, t.int_col)
.else_(0)
.end()
expr = ibis.cases(
(t.float_col > 0, t.int_col * 2),
(t.float_col < 0, t.int_col),
else_=0,
)

snapshot.assert_match(expr.compile(), "out.sql")
Expand Down
16 changes: 8 additions & 8 deletions ibis/backends/dask/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,14 +509,14 @@ def execute_searched_case_dask(op, when_nodes, then_nodes, otherwise, **kwargs):
return out


@execute_node.register(ops.SimpleCase, dd.Series, tuple, tuple, object)
def execute_simple_case_series(op, value, whens, thens, otherwise, **kwargs):
whens = [execute(arg, **kwargs) for arg in whens]
thens = [execute(arg, **kwargs) for arg in thens]
if otherwise is None:
otherwise = np.nan
raw = np.select([value == when for when in whens], thens, otherwise)
return wrap_case_result(raw, op.to_expr())
# @execute_node.register(ops.SimpleCase, dd.Series, tuple, tuple, object)
# def execute_simple_case_series(op, value, whens, thens, otherwise, **kwargs):
# whens = [execute(arg, **kwargs) for arg in whens]
# thens = [execute(arg, **kwargs) for arg in thens]
# if otherwise is None:
# otherwise = np.nan
# raw = np.select([value == when for when in whens], thens, otherwise)
# return wrap_case_result(raw, op.to_expr())


@execute_node.register(ops.Greatest, tuple)
Expand Down
26 changes: 11 additions & 15 deletions ibis/backends/dask/tests/execution/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ def q_fun(x, quantile):


def test_searched_case_scalar(client):
expr = ibis.case().when(True, 1).when(False, 2).end()
expr = ibis.cases((True, 1), (False, 2))
result = client.execute(expr)
expected = np.int8(1)
assert result == expected
Expand All @@ -854,12 +854,10 @@ def test_searched_case_scalar(client):
def test_searched_case_column(batting, batting_df):
t = batting
df = batting_df
expr = (
ibis.case()
.when(t.RBI < 5, "really bad team")
.when(t.teamID == "PH1", "ph1 team")
.else_(t.teamID)
.end()
expr = ibis.cases(
(t.RBI < 5, "really bad team"),
(t.teamID == "PH1", "ph1 team"),
else_=t.teamID,
)
result = expr.compile()
expected = dd.from_array(
Expand All @@ -874,7 +872,7 @@ def test_searched_case_column(batting, batting_df):

def test_simple_case_scalar(client):
x = ibis.literal(2)
expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end()
expr = x.cases((2, x - 1), (3, x + 1), (4, x + 2))
result = client.execute(expr)
expected = np.int8(1)
assert result == expected
Expand All @@ -883,13 +881,11 @@ def test_simple_case_scalar(client):
def test_simple_case_column(batting, batting_df):
t = batting
df = batting_df
expr = (
t.RBI.case()
.when(5, "five")
.when(4, "four")
.when(3, "three")
.else_("could be good?")
.end()
expr = t.RBI.cases(
(5, "five"),
(4, "four"),
(3, "three"),
else_=("could be good?"),
)
result = expr.compile()
expected = dd.from_array(
Expand Down
12 changes: 10 additions & 2 deletions ibis/backends/impala/tests/test_case_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@ def table(mockcon):

@pytest.fixture
def simple_case(table):
return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return table.g.cases(
("foo", "bar"),
("baz", "qux"),
else_="default",
)


@pytest.fixture
def search_case(table):
t = table
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases(
(t.f > 0, t.d * 2),
(t.c < 0, t.a * 2),
)
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture
Expand Down
48 changes: 24 additions & 24 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,30 +1443,30 @@ def execute_searched_case(op, whens, thens, otherwise, **kwargs):
return _build_select(op, whens, thens, otherwise, **kwargs)


@execute_node.register(ops.SimpleCase, object, tuple, tuple, object)
def execute_simple_case_scalar(op, value, whens, thens, otherwise, **kwargs):
value = getattr(value, "obj", value)
return _build_select(
op,
whens,
thens,
otherwise,
func=lambda whens: np.asarray(whens) == value,
**kwargs,
)


@execute_node.register(ops.SimpleCase, (pd.Series, SeriesGroupBy), tuple, tuple, object)
def execute_simple_case_series(op, value, whens, thens, otherwise, **kwargs):
value = getattr(value, "obj", value)
return _build_select(
op,
whens,
thens,
otherwise,
func=lambda whens: [value == when for when in whens],
**kwargs,
)
# @execute_node.register(ops.SimpleCase, object, tuple, tuple, object)
# def execute_simple_case_scalar(op, value, whens, thens, otherwise, **kwargs):
# value = getattr(value, "obj", value)
# return _build_select(
# op,
# whens,
# thens,
# otherwise,
# func=lambda whens: np.asarray(whens) == value,
# **kwargs,
# )


# @execute_node.register(ops.SimpleCase, (pd.Series, SeriesGroupBy), tuple, tuple, object)
# def execute_simple_case_series(op, value, whens, thens, otherwise, **kwargs):
# value = getattr(value, "obj", value)
# return _build_select(
# op,
# whens,
# thens,
# otherwise,
# func=lambda whens: [value == when for when in whens],
# **kwargs,
# )


@execute_node.register(ops.Distinct, pd.DataFrame)
Expand Down
21 changes: 7 additions & 14 deletions ibis/backends/pandas/tests/execution/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def test_summary_non_numeric(batting, batting_df):


def test_searched_case_scalar(client):
expr = ibis.case().when(True, 1).when(False, 2).end()
expr = ibis.cases((True, 1), (False, 2))
result = client.execute(expr)
expected = np.int8(1)
assert result == expected
Expand All @@ -696,12 +696,10 @@ def test_searched_case_scalar(client):
def test_searched_case_column(batting, batting_df):
t = batting
df = batting_df
expr = (
ibis.case()
.when(t.RBI < 5, "really bad team")
.when(t.teamID == "PH1", "ph1 team")
.else_(t.teamID)
.end()
expr = ibis.cases(
(t.RBI < 5, "really bad team"),
(t.teamID == "PH1", "ph1 team"),
t.teamID,
)
result = expr.execute()
expected = pd.Series(
Expand All @@ -716,7 +714,7 @@ def test_searched_case_column(batting, batting_df):

def test_simple_case_scalar(client):
x = ibis.literal(2)
expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end()
expr = x.cases((2, x - 1), (3, x + 1), (4, x + 2))
result = client.execute(expr)
expected = np.int8(1)
assert result == expected
Expand All @@ -726,12 +724,7 @@ def test_simple_case_column(batting, batting_df):
t = batting
df = batting_df
expr = (
t.RBI.case()
.when(5, "five")
.when(4, "four")
.when(3, "three")
.else_("could be good?")
.end()
t.RBI.cases((5, "five"), (4, "four"), (3, "three"), else_="could be good?"),
)
result = expr.execute()
expected = pd.Series(
Expand Down
18 changes: 9 additions & 9 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,15 +374,15 @@ def ifelse(op, **kw):
return pl.when(bool_expr).then(true_expr).otherwise(false_null_expr)


@translate.register(ops.SimpleCase)
def simple_case(op, **kw):
base = translate(op.base, **kw)
default = translate(op.default, **kw)
for case, result in reversed(list(zip(op.cases, op.results))):
case = base == translate(case, **kw)
result = translate(result, **kw)
default = pl.when(case).then(result).otherwise(default)
return default
# @translate.register(ops.SimpleCase)
# def simple_case(op, **kw):
# base = translate(op.base, **kw)
# default = translate(op.default, **kw)
# for case, result in reversed(list(zip(op.cases, op.results))):
# case = base == translate(case, **kw)
# result = translate(result, **kw)
# default = pl.when(case).then(result).otherwise(default)
# return default


@translate.register(ops.SearchedCase)
Expand Down
12 changes: 5 additions & 7 deletions ibis/backends/sqlite/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,11 @@ def translator(t, op):

def _extract_quarter(t, op):
expr_new = ops.ExtractMonth(op.arg).to_expr()
expr_new = (
ibis.case()
.when(expr_new.isin([1, 2, 3]), 1)
.when(expr_new.isin([4, 5, 6]), 2)
.when(expr_new.isin([7, 8, 9]), 3)
.else_(4)
.end()
expr_new = ibis.cases(
(expr_new.isin([1, 2, 3]), 1),
(expr_new.isin([4, 5, 6]), 2),
(expr_new.isin([7, 8, 9]), 3),
else_=4,
)
return sa.cast(t.translate(expr_new.op()), sa.Integer)

Expand Down
Loading

0 comments on commit eb36129

Please sign in to comment.