diff --git a/docs/_quarto.yml b/docs/_quarto.yml index 1817659c9f7ef..898ccd67bdd73 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -332,7 +332,7 @@ quartodoc: - name: ifelse dynamic: true signature_name: full - - name: case + - name: cases dynamic: true signature_name: full diff --git a/docs/posts/ci-analysis/index.qmd b/docs/posts/ci-analysis/index.qmd index 5babc2c6d0c61..65159d5bbe9c7 100644 --- a/docs/posts/ci-analysis/index.qmd +++ b/docs/posts/ci-analysis/index.qmd @@ -203,14 +203,12 @@ Let's also give them some names that'll look nice on our plots. stats = stats.mutate( raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int") ).mutate( - improvements=( - _.raw_improvements.case() - .when(0, "None") - .when(1, "Poetry") - .when(2, "Poetry + Team Plan") - .else_("NA") - .end() - ), + improvements=_.raw_improvements.cases( + (0, "None"), + (1, "Poetry"), + (2, "Poetry + Team Plan"), + else_="NA", + ) team_plan=ibis.where(_.raw_improvements > 1, "Poetry + Team Plan", "None"), ) stats diff --git a/docs/tutorials/ibis-for-sql-users.qmd b/docs/tutorials/ibis-for-sql-users.qmd index cbb9b4974d706..c18dba6fd4991 100644 --- a/docs/tutorials/ibis-for-sql-users.qmd +++ b/docs/tutorials/ibis-for-sql-users.qmd @@ -466,11 +466,11 @@ semantics: case = ( t.one.cast("timestamp") .year() - .case() - .when(2015, "This year") - .when(2014, "Last year") - .else_("Earlier") - .end() + .cases( + (2015, "This year"), + (2014, "Last year"), + else_="Earlier", + ) ) expr = t.mutate(year_group=case) @@ -489,18 +489,16 @@ CASE END ``` -To do this, use `ibis.case`: +To do this, use `ibis.cases`: ```{python} -case = ( - ibis.case() - .when(t.two < 0, t.three * 2) - .when(t.two > 1, t.three) - .else_(t.two) - .end() +cases = ibis.cases( + (t.two < 0, t.three * 2), + (t.two > 1, t.three), + else_=t.two, ) -expr = t.mutate(cond_value=case) +expr = t.mutate(cond_value=cases) ibis.to_sql(expr) ``` diff --git a/ibis/backends/clickhouse/tests/test_operators.py b/ibis/backends/clickhouse/tests/test_operators.py index fbfbd4efffb08..72b22e8cb2fb3 100644 --- a/ibis/backends/clickhouse/tests/test_operators.py +++ b/ibis/backends/clickhouse/tests/test_operators.py @@ -201,9 +201,7 @@ def test_ifelse(alltypes, df, op, pandas_op): def test_simple_case(con, alltypes, assert_sql): t = alltypes - expr = ( - t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end() - ) + expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default") assert_sql(expr) assert len(con.execute(expr)) @@ -211,12 +209,10 @@ def test_simple_case(con, alltypes, assert_sql): def test_search_case(con, alltypes, assert_sql): t = alltypes - expr = ( - ibis.case() - .when(t.float_col > 0, t.int_col * 2) - .when(t.float_col < 0, t.int_col) - .else_(0) - .end() + expr = ibis.cases( + (t.float_col > 0, t.int_col * 2), + (t.float_col < 0, t.int_col), + else_=0, ) assert_sql(expr) diff --git a/ibis/backends/impala/tests/test_case_exprs.py b/ibis/backends/impala/tests/test_case_exprs.py index a195928b12214..360fbf9522c8b 100644 --- a/ibis/backends/impala/tests/test_case_exprs.py +++ b/ibis/backends/impala/tests/test_case_exprs.py @@ -14,13 +14,13 @@ def table(mockcon): @pytest.fixture def simple_case(table): - return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() + return table.g.cases(("foo", "bar"), ("baz", "qux"), else_="default") @pytest.fixture def search_case(table): t = table - return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end() + return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2)) @pytest.fixture diff --git a/ibis/backends/snowflake/tests/test_udf.py b/ibis/backends/snowflake/tests/test_udf.py index f73f79c173803..57e551b150b12 100644 --- a/ibis/backends/snowflake/tests/test_udf.py +++ b/ibis/backends/snowflake/tests/test_udf.py @@ -8,7 +8,6 @@ import pytest from pytest import param -import ibis import ibis.expr.datatypes as dt from ibis import udf @@ -122,36 +121,23 @@ def predict_price( df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"] return model.predict(df) - def cases(value, mapping): - """This should really be a top-level function or method.""" - expr = ibis.case() - for k, v in mapping.items(): - expr = expr.when(value == k, v) - return expr.end() - diamonds = con.tables.DIAMONDS expr = diamonds.mutate( predicted_price=predict_price( (_.carat - _.carat.mean()) / _.carat.std(), - cases( - _.cut, - { - c: i - for i, c in enumerate( - ("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1 - ) - }, + _.cut.cases( + (c, i) + for i, c in enumerate( + ("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1 + ) ), - cases(_.color, {c: i for i, c in enumerate("DEFGHIJ", start=1)}), - cases( - _.clarity, - { - c: i - for i, c in enumerate( - ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"), - start=1, - ) - }, + _.color.cases((c, i) for i, c in enumerate("DEFGHIJ", start=1)), + _.clarity.cases( + (c, i) + for i, c in enumerate( + ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"), + start=1, + ) ), ) ) diff --git a/ibis/backends/tests/sql/conftest.py b/ibis/backends/tests/sql/conftest.py index a552cec35a4b8..7d785e247b6d3 100644 --- a/ibis/backends/tests/sql/conftest.py +++ b/ibis/backends/tests/sql/conftest.py @@ -159,13 +159,13 @@ def difference(con): @pytest.fixture(scope="module") def simple_case(con): t = con.table("alltypes") - return t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() + return t.g.cases(("foo", "bar"), ("baz", "qux"), else_="default") @pytest.fixture(scope="module") def search_case(con): t = con.table("alltypes") - return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end() + return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2)) @pytest.fixture(scope="module") diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py index 7e00b4a40109b..5f4322007cb6a 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py @@ -22,18 +22,14 @@ lit2 = ibis.literal("bar") result = alltypes.select( - alltypes.g.case() - .when(lit, lit2) - .when(lit1, ibis.literal("qux")) - .else_(ibis.literal("default")) - .end() - .name("col1"), - ibis.case() - .when((alltypes.g == lit), lit2) - .when((alltypes.g == lit1), alltypes.g) - .else_(ibis.literal(None)) - .end() - .name("col2"), + alltypes.g.cases( + (lit, lit2), (lit1, ibis.literal("qux")), else_=ibis.literal("default") + ).name("col1"), + ibis.cases( + ((alltypes.g == lit), lit2), + ((alltypes.g == lit1), alltypes.g), + else_=ibis.literal(None), + ).name("col2"), alltypes.a, alltypes.b, alltypes.c, diff --git a/ibis/backends/tests/sql/test_select_sql.py b/ibis/backends/tests/sql/test_select_sql.py index 5f4b63df8e7b3..346d6f1c76208 100644 --- a/ibis/backends/tests/sql/test_select_sql.py +++ b/ibis/backends/tests/sql/test_select_sql.py @@ -461,8 +461,8 @@ def test_bool_bool(snapshot): def test_case_in_projection(alltypes, snapshot): t = alltypes - expr = t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() - expr2 = ibis.case().when(t.g == "foo", "bar").when(t.g == "baz", t.g).end() + expr = t.g.cases(("foo", "bar"), ("baz", "qux"), else_=("default")) + expr2 = ibis.cases((t.g == "foo", "bar"), (t.g == "baz", t.g)) expr = t.select(expr.name("col1"), expr2.name("col2"), t) snapshot.assert_match(to_sql(expr), "out.sql") diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 2ff92c14f3619..cfa27b4ab94b6 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -611,7 +611,7 @@ def test_first_last(alltypes, method, filtered, include_null): # To sanely test this we create a column that is a mix of nulls and a # single value (or a single value after filtering is applied). if filtered: - new = alltypes.int_col.cases([(3, 30), (4, 40)]) + new = alltypes.int_col.cases((3, 30), (4, 40)) where = _.int_col == 3 else: new = (alltypes.int_col == 3).ifelse(30, None) @@ -738,7 +738,7 @@ def test_arbitrary(alltypes, filtered): # _something_ we create a column that is a mix of nulls and a single value # (or a single value after filtering is applied). if filtered: - new = alltypes.int_col.cases([(3, 30), (4, 40)]) + new = alltypes.int_col.cases((3, 30), (4, 40)) where = _.int_col == 3 else: new = (alltypes.int_col == 3).ifelse(30, None) @@ -1571,9 +1571,7 @@ def collect_udf(v): def test_binds_are_cast(alltypes): expr = alltypes.aggregate( - high_line_count=( - alltypes.string_col.case().when("1-URGENT", 1).else_(0).end().sum() - ) + high_line_count=alltypes.string_col.cases(("1-URGENT", 1), else_=0).sum() ) expr.execute() @@ -1616,7 +1614,7 @@ def test_agg_name_in_output_column(alltypes): def test_grouped_case(backend, con): table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]}) - case_expr = ibis.case().when(table.value < 25, table.value).else_(ibis.null()).end() + case_expr = ibis.cases((table.value < 25, table.value), else_=ibis.null()) expr = ( table.group_by(k="key") diff --git a/ibis/backends/tests/test_conditionals.py b/ibis/backends/tests/test_conditionals.py index 90bd76dc44418..0c91b0b33acc2 100644 --- a/ibis/backends/tests/test_conditionals.py +++ b/ibis/backends/tests/test_conditionals.py @@ -63,17 +63,12 @@ def test_substitute(backend): "inp, exp", [ pytest.param( - lambda: ibis.literal(1) - .case() - .when(1, "one") - .when(2, "two") - .else_("other") - .end(), + lambda: ibis.literal(1).cases((1, "one"), (2, "two"), else_="other"), "one", id="one_kwarg", ), pytest.param( - lambda: ibis.literal(5).case().when(1, "one").when(2, "two").end(), + lambda: ibis.literal(5).cases((1, "one"), (2, "two")), None, id="fallthrough", ), @@ -94,13 +89,8 @@ def test_value_cases_column(batting): np = pytest.importorskip("numpy") df = batting.to_pandas() - expr = ( - batting.RBI.case() - .when(5, "five") - .when(4, "four") - .when(3, "three") - .else_("could be good?") - .end() + expr = batting.RBI.cases( + (5, "five"), (4, "four"), (3, "three"), else_="could be good?" ) result = expr.execute() expected = np.select( @@ -113,7 +103,7 @@ def test_value_cases_column(batting): def test_ibis_cases_scalar(): - expr = ibis.literal(5).case().when(5, "five").when(4, "four").end() + expr = ibis.literal(5).cases((5, "five"), (4, "four")) result = expr.execute() assert result == "five" @@ -128,12 +118,8 @@ def test_ibis_cases_column(batting): t = batting df = batting.to_pandas() - expr = ( - ibis.case() - .when(t.RBI < 5, "really bad team") - .when(t.teamID == "PH1", "ph1 team") - .else_(t.teamID) - .end() + expr = ibis.cases( + (t.RBI < 5, "really bad team"), (t.teamID == "PH1", "ph1 team"), else_=t.teamID ) result = expr.execute() expected = np.select( @@ -148,5 +134,19 @@ def test_ibis_cases_column(batting): @pytest.mark.notimpl("clickhouse", reason="special case this and returns 'oops'") def test_value_cases_null(con): """CASE x WHEN NULL never gets hit""" - e = ibis.literal(5).nullif(5).case().when(None, "oops").else_("expected").end() + e = ibis.literal(5).nullif(5).cases((None, "oops"), else_="expected") assert con.execute(e) == "expected" + + +def test_ibis_case_still_works(con): + # just to make sure that the soft-deprecated .case() method still works + # https://github.com/ibis-project/ibis/pull/9096 + pd = pytest.importorskip("pandas") + assert con.execute(ibis.case().when(True, "yes").end()) == "yes" + assert pd.isna(con.execute(ibis.case().when(False, "yes").end())) + assert con.execute(ibis.case().when(False, "yes").else_("no").end()) == "no" + assert con.execute(ibis.literal("a").case().when("a", "yes").end()) == "yes" + assert pd.isna(con.execute(ibis.literal("a").case().when("b", "yes").end())) + assert ( + con.execute(ibis.literal("a").case().when("b", "yes").else_("no").end()) == "no" + ) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 9e6c90cabf0d3..0a4ade8f63f8c 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -382,12 +382,11 @@ def test_case_where(backend, alltypes, df): table = alltypes table = table.mutate( new_col=( - ibis.case() - .when(table["int_col"] == 1, 20) - .when(table["int_col"] == 0, 10) - .else_(0) - .end() - .cast("int64") + ibis.cases( + (table["int_col"] == 1, 20), + (table["int_col"] == 0, 10), + else_=0, + ).cast("int64") ) ) @@ -420,9 +419,7 @@ def test_select_filter_mutate(backend, alltypes, df): # Prepare the float_col so that filter must execute # before the cast to get the correct result. - t = t.mutate( - float_col=ibis.case().when(t["bool_col"], t["float_col"]).else_(np.nan).end() - ) + t = t.mutate(float_col=ibis.cases((t["bool_col"], t["float_col"]), else_=np.nan)) # Actual test t = t.select(t.columns) diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 9f94744cc29d1..748addce30149 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -56,16 +56,16 @@ def test_group_by_has_index(backend, snapshot): ) expr = countries.group_by( cont=( - _.continent.case() - .when("NA", "North America") - .when("SA", "South America") - .when("EU", "Europe") - .when("AF", "Africa") - .when("AS", "Asia") - .when("OC", "Oceania") - .when("AN", "Antarctica") - .else_("Unknown continent") - .end() + _.continent.cases( + ("NA", "North America"), + ("SA", "South America"), + ("EU", "Europe"), + ("AF", "Africa"), + ("AS", "Asia"), + ("OC", "Oceania"), + ("AN", "Antarctica"), + else_="Unknown continent", + ) ) ).agg(total_pop=_.population.sum()) sql = str(ibis.to_sql(expr, dialect=backend.name())) diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index cb51c30aa2736..29372fc3cd02a 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -508,14 +508,14 @@ def uses_java_re(t): id="length", ), param( - lambda t: t.int_col.cases([(1, "abcd"), (2, "ABCD")], "dabc").startswith( - "abc" - ), + lambda t: t.int_col.cases( + (1, "abcd"), (2, "ABCD"), else_="dabc" + ).startswith("abc"), lambda t: t.int_col == 1, id="startswith", ), param( - lambda t: t.int_col.cases([(1, "abcd"), (2, "ABCD")], "dabc").endswith( + lambda t: t.int_col.cases((1, "abcd"), (2, "ABCD"), else_="dabc").endswith( "bcd" ), lambda t: t.int_col == 1, @@ -681,11 +681,9 @@ def test_re_replace_global(con): @pytest.mark.notimpl(["druid"], raises=ValidationError) def test_substr_with_null_values(backend, alltypes, df): table = alltypes.mutate( - substr_col_null=ibis.case() - .when(alltypes["bool_col"], alltypes["string_col"]) - .else_(None) - .end() - .substr(0, 2) + substr_col_null=ibis.cases( + (alltypes["bool_col"], alltypes["string_col"]), else_=None + ).substr(0, 2) ) result = table.execute() @@ -885,7 +883,7 @@ def test_levenshtein(con, right): @pytest.mark.parametrize( "expr", [ - param(ibis.case().when(True, "%").end(), id="case"), + param(ibis.cases((True, "%")), id="case"), param(ibis.ifelse(True, "%", ibis.null()), id="ifelse"), ], ) diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index 3098e349baca4..cfa3cf8ff2db5 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -146,7 +146,7 @@ def test_collect_into_struct(alltypes): @pytest.mark.notimpl(["flink"], raises=Py4JJavaError, reason="not implemented in ibis") def test_field_access_after_case(con): s = ibis.struct({"a": 3}) - x = ibis.case().when(True, s).else_(ibis.struct({"a": 4})).end() + x = ibis.cases((True, s), else_=ibis.struct({"a": 4})) y = x.a assert con.to_pandas(y) == 3 diff --git a/ibis/backends/tests/tpc/h/test_queries.py b/ibis/backends/tests/tpc/h/test_queries.py index 57c1384d9338a..bad4566cb0c5f 100644 --- a/ibis/backends/tests/tpc/h/test_queries.py +++ b/ibis/backends/tests/tpc/h/test_queries.py @@ -272,9 +272,7 @@ def test_08(part, supplier, region, lineitem, orders, customer, nation): ] ) - q = q.mutate( - nation_volume=ibis.case().when(q.nation == NATION, q.volume).else_(0).end() - ) + q = q.mutate(nation_volume=ibis.cases((q.nation == NATION, q.volume), else_=0)) gq = q.group_by([q.o_year]) q = gq.aggregate(mkt_share=q.nation_volume.sum() / q.volume.sum()) q = q.order_by([q.o_year]) @@ -400,19 +398,15 @@ def test_12(orders, lineitem): gq = q.group_by([q.l_shipmode]) q = gq.aggregate( - high_line_count=( - q.o_orderpriority.case() - .when("1-URGENT", 1) - .when("2-HIGH", 1) - .else_(0) - .end() + high_line_count=q.o_orderpriority.cases( + ("1-URGENT", 1), + ("2-HIGH", 1), + else_=0, ).sum(), - low_line_count=( - q.o_orderpriority.case() - .when("1-URGENT", 0) - .when("2-HIGH", 0) - .else_(1) - .end() + low_line_count=q.o_orderpriority.cases( + ("1-URGENT", 0), + ("2-HIGH", 0), + else_=1, ).sum(), ) q = q.order_by(q.l_shipmode) diff --git a/ibis/expr/api.py b/ibis/expr/api.py index 153dc4e765df8..53f1628e52064 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -68,6 +68,7 @@ "array", "asc", "case", + "cases", "coalesce", "connect", "cross_join", @@ -1073,55 +1074,78 @@ def interval( def case() -> bl.SearchedCaseBuilder: - """Begin constructing a case expression. + """DEPRECATED: Use `ibis.cases()` instead.""" + return bl.SearchedCaseBuilder() + + +@deferrable +def cases(*branches: tuple[Any, Any], else_: Any | None = None) -> ir.Value: + """Create a multi-branch if-else expression. - Use the `.when` method on the resulting object followed by `.end` to create a - complete case expression. + Goes through each (condition, value) pair in `branches`, finding the + first condition that evaluates to True, and returns the corresponding + value. If no condition is True, returns `else_`. + + Parameters + ---------- + branches + A sequence of (condition, value) pairs. The condition is a boolean + expression and the value is the result if the condition is True. + else_ + The value to return if no condition is True. Defaults to NULL. Returns ------- - SearchedCaseBuilder - A builder object to use for constructing a case expression. + Value + A value expression See Also -------- - [`Value.case()`](./expression-generic.qmd#ibis.expr.types.generic.Value.case) + [`Value.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) + [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) Examples -------- >>> import ibis - >>> from ibis import _ >>> ibis.options.interactive = True - >>> t = ibis.memtable( - ... { - ... "left": [1, 2, 3, 4], - ... "symbol": ["+", "-", "*", "/"], - ... "right": [5, 6, 7, 8], - ... } - ... ) - >>> t.mutate( - ... result=( - ... ibis.case() - ... .when(_.symbol == "+", _.left + _.right) - ... .when(_.symbol == "-", _.left - _.right) - ... .when(_.symbol == "*", _.left * _.right) - ... .when(_.symbol == "/", _.left / _.right) - ... .end() - ... ) - ... ) - ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ - ┃ left ┃ symbol ┃ right ┃ result ┃ - ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ - │ int64 │ string │ int64 │ float64 │ - ├───────┼────────┼───────┼─────────┤ - │ 1 │ + │ 5 │ 6.0 │ - │ 2 │ - │ 6 │ -4.0 │ - │ 3 │ * │ 7 │ 21.0 │ - │ 4 │ / │ 8 │ 0.5 │ - └───────┴────────┴───────┴─────────┘ - + >>> v = ibis.memtable({"values": [1, 2, 1, 2, 3, 2, 4]}).values + >>> ibis.cases((v == 1, "a"), (v > 2, "b"), else_="unk").name("cases") + ┏━━━━━━━━┓ + ┃ cases ┃ + ┡━━━━━━━━┩ + │ string │ + ├────────┤ + │ a │ + │ unk │ + │ a │ + │ unk │ + │ b │ + │ unk │ + │ b │ + └────────┘ + >>> ibis.cases( + ... (v % 2 == 0, "divisible by 2"), + ... (v % 3 == 0, "divisible by 3"), + ... (v % 4 == 0, "shadowed by the 2 case"), + ... ).name("cases") + ┏━━━━━━━━━━━━━━━━┓ + ┃ cases ┃ + ┡━━━━━━━━━━━━━━━━┩ + │ string │ + ├────────────────┤ + │ NULL │ + │ divisible by 2 │ + │ NULL │ + │ divisible by 2 │ + │ divisible by 3 │ + │ divisible by 2 │ + │ divisible by 2 │ + └────────────────┘ """ - return bl.SearchedCaseBuilder() + if not branches: + raise ValueError("At least one branch is required") + cases, results = zip(*branches) + return ops.SearchedCase(cases=cases, results=results, default=else_).to_expr() def now() -> ir.TimestampScalar: diff --git a/ibis/expr/decompile.py b/ibis/expr/decompile.py index 9a913b7cfc0f2..3f4eb578e90dd 100644 --- a/ibis/expr/decompile.py +++ b/ibis/expr/decompile.py @@ -304,16 +304,12 @@ def ifelse(op, bool_expr, true_expr, false_null_expr): @translate.register(ops.SimpleCase) @translate.register(ops.SearchedCase) -def switch_case(op, cases, results, default, base=None): - out = f"{base}.case()" if base else "ibis.case()" - - for case, result in zip(cases, results): - out = f"{out}.when({case}, {result})" - - if default is not None: - out = f"{out}.else_({default})" - - return f"{out}.end()" +def switch_cases(op, cases, results, default, base=None): + namespace = f"{base}" if base else "ibis" + case_strs = [f"({case}, {result})" for case, result in zip(cases, results)] + cases_str = ", ".join(case_strs) + else_str = f", else_={default}" if default is not None else "" + return f"{namespace}.cases({cases_str}{else_str})" _infix_ops = { diff --git a/ibis/expr/operations/logical.py b/ibis/expr/operations/logical.py index bc033f66318ed..7ea03f4d70e85 100644 --- a/ibis/expr/operations/logical.py +++ b/ibis/expr/operations/logical.py @@ -154,7 +154,7 @@ class IfElse(Value): Equivalent to ```python - bool_expr.case().when(True, true_expr).else_(false_or_null_expr) + bool_expr.cases((True, true_expr), else_=false_or_null_expr) ``` Many backends implement this as a built-in function. diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index ac1a57f94c911..0223f39679bd4 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Sequence from typing import TYPE_CHECKING, Any from public import public @@ -28,6 +28,9 @@ from ibis.formats.pyarrow import PyArrowData +_SENTINEL = object() + + @public class Value(Expr): """Base class for a data generating expression having a known type.""" @@ -404,7 +407,7 @@ def fill_null(self, fill_value: Scalar) -> Value: @deprecated(as_of="9.1", instead="use fill_null instead") def fillna(self, fill_value: Scalar) -> Value: - """Deprecated - use `fill_null` instead.""" + """DEPRECATED: use `fill_null` instead.""" return self.fill_null(fill_value) def nullif(self, null_if_expr: Value) -> Value: @@ -687,6 +690,9 @@ def substitute( Value Replaced values + [`Value.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.case) + [`ibis.cases()`](./expression-generic.qmd#ibis.cases) + Examples -------- >>> import ibis @@ -715,20 +721,25 @@ def substitute( │ torg │ 52 │ └────────┴──────────────┘ """ - if isinstance(value, dict): - expr = ibis.case() - try: - null_replacement = value.pop(None) - except KeyError: - pass - else: - expr = expr.when(self.isnull(), null_replacement) - for k, v in value.items(): - expr = expr.when(self == k, v) + try: + branches = value.items() + except AttributeError: + branches = [(value, replacement)] + + if ( + repl := next((v for k, v in branches if k is None), _SENTINEL) + ) is not _SENTINEL: + result = self.fill_null(repl) else: - expr = self.case().when(value, replacement) + result = self + + if else_ is None: + else_ = result - return expr.else_(else_ if else_ is not None else self).end() + if not (nonnulls := [(k, v) for k, v in branches if k is not None]): + return else_ + + return result.cases(*nonnulls, else_=else_) def over( self, @@ -864,100 +875,32 @@ def notnull(self) -> ir.BooleanValue: """ return ops.NotNull(self).to_expr() + @deprecated(as_of="10.0.0", instead="use value.cases() or ibis.cases()") def case(self) -> bl.SimpleCaseBuilder: - """Create a SimpleCaseBuilder to chain multiple if-else statements. - - Add new search expressions with the `.when()` method. These must be - comparable with this column expression. Conclude by calling `.end()`. - - Returns - ------- - SimpleCaseBuilder - A case builder - - See Also - -------- - [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) - [`ibis.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) - [`ibis.case()`](./expression-generic.qmd#ibis.case) - - Examples - -------- - >>> import ibis - >>> ibis.options.interactive = True - >>> x = ibis.examples.penguins.fetch().head(5)["sex"] - >>> x - ┏━━━━━━━━┓ - ┃ sex ┃ - ┡━━━━━━━━┩ - │ string │ - ├────────┤ - │ male │ - │ female │ - │ female │ - │ NULL │ - │ female │ - └────────┘ - >>> x.case().when("male", "M").when("female", "F").else_("U").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, ('male', 'female'), ('M', 'F'), 'U') ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├──────────────────────────────────────────────────────┤ - │ M │ - │ F │ - │ F │ - │ U │ - │ F │ - └──────────────────────────────────────────────────────┘ - - Cases not given result in the ELSE case - - >>> x.case().when("male", "M").else_("OTHER").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, ('male',), ('M',), 'OTHER') ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├─────────────────────────────────────────────┤ - │ M │ - │ OTHER │ - │ OTHER │ - │ OTHER │ - │ OTHER │ - └─────────────────────────────────────────────┘ - - If you don't supply an ELSE, then NULL is used - - >>> x.case().when("male", "M").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, ('male',), ('M',), Cast(None, string)) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├────────────────────────────────────────────────────────┤ - │ M │ - │ NULL │ - │ NULL │ - │ NULL │ - │ NULL │ - └────────────────────────────────────────────────────────┘ - """ - import ibis.expr.builders as bl - + """DEPRECATED: use `value.cases()` or `ibis.cases()` instead.""" return bl.SimpleCaseBuilder(self.op()) def cases( self, - case_result_pairs: Iterable[tuple[ir.BooleanValue, Value]], - default: Value | None = None, + branch: tuple[Value, Value], + *branches: tuple[Value, Value], + else_: Value | None = None, ) -> Value: - """Create a case expression in one shot. + """Create a multi-branch if-else expression. + + This is equivalent to a SQL `CASE` statement. Parameters ---------- - case_result_pairs - Conditional-result pairs - default - Value to return if none of the case conditions are true + branch + First (`condition`, `result`) pair. Required. + branches + Additional (`condition`, `result`) pairs. We look through the test + values in order and return the result corresponding to the first + test value that matches `self`. If none match, we return `else_`. + else_ + Value to return if none of the case conditions evaluate to `True`. + Defaults to `NULL`. Returns ------- @@ -967,48 +910,52 @@ def cases( See Also -------- [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) - [`ibis.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) - [`ibis.case()`](./expression-generic.qmd#ibis.case) + [`ibis.cases()`](./expression-generic.qmd#ibis.cases) Examples -------- >>> import ibis >>> ibis.options.interactive = True - >>> t = ibis.memtable({"values": [1, 2, 1, 2, 3, 2, 4]}) - >>> t - ┏━━━━━━━━┓ - ┃ values ┃ - ┡━━━━━━━━┩ - │ int64 │ - ├────────┤ - │ 1 │ - │ 2 │ - │ 1 │ - │ 2 │ - │ 3 │ - │ 2 │ - │ 4 │ - └────────┘ - >>> number_letter_map = ((1, "a"), (2, "b"), (3, "c")) - >>> t.values.cases(number_letter_map, default="unk").name("replace") - ┏━━━━━━━━━┓ - ┃ replace ┃ - ┡━━━━━━━━━┩ - │ string │ - ├─────────┤ - │ a │ - │ b │ - │ a │ - │ b │ - │ c │ - │ b │ - │ unk │ - └─────────┘ + >>> t = ibis.memtable( + ... { + ... "left": [5, 6, 7, 8, 9, 10], + ... "symbol": ["+", "-", "*", "/", "bogus", None], + ... "right": [1, 2, 3, 4, 5, 6], + ... } + ... ) + + Note that we never hit the `None` case, because `x = NULL` is always + `NULL`, which is not truthy. If you want to replace `NULL`s, you should use + `.fill_null(some_value)` prior to `cases()`. + + >>> t.mutate( + ... result=( + ... t.symbol.cases( + ... ("+", t.left + t.right), + ... ("-", t.left - t.right), + ... ("*", t.left * t.right), + ... ("/", t.left / t.right), + ... (None, -999), + ... ) + ... ) + ... ) + ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ + ┃ left ┃ symbol ┃ right ┃ result ┃ + ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ + │ int64 │ string │ int64 │ float64 │ + ├───────┼────────┼───────┼─────────┤ + │ 5 │ + │ 1 │ 6.0 │ + │ 6 │ - │ 2 │ 4.0 │ + │ 7 │ * │ 3 │ 21.0 │ + │ 8 │ / │ 4 │ 2.0 │ + │ 9 │ bogus │ 5 │ NULL │ + │ 10 │ NULL │ 6 │ NULL │ + └───────┴────────┴───────┴─────────┘ """ - builder = self.case() - for case, result in case_result_pairs: - builder = builder.when(case, result) - return builder.else_(default).end() + cases, results = zip(branch, *branches) + return ops.SimpleCase( + base=self, cases=cases, results=results, default=else_ + ).to_expr() def collect( self, diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index bfe376ef4f117..d6836c68ad041 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools from collections.abc import Sequence from typing import TYPE_CHECKING, Literal @@ -1221,13 +1220,7 @@ def label(self, labels: Iterable[str], nulls: str | None = None) -> ir.StringVal │ 2 │ c │ └───────┴─────────┘ """ - return ( - functools.reduce( - lambda stmt, inputs: stmt.when(*inputs), enumerate(labels), self.case() - ) - .else_(nulls) - .end() - ) + return self.cases(*enumerate(labels), else_=nulls) @public diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 00450a44e837b..9c33d9dd89f9b 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -2830,9 +2830,7 @@ def info(self) -> Table: for pos, colname in enumerate(self.columns): col = self[colname] typ = col.type() - agg = self.select( - isna=ibis.case().when(col.isnull(), 1).else_(0).end() - ).agg( + agg = self.select(isna=ibis.cases((col.isnull(), 1), else_=0)).agg( name=lit(colname), type=lit(str(typ)), nullable=lit(typ.nullable), diff --git a/ibis/tests/expr/test_case.py b/ibis/tests/expr/test_case.py index 97bfcba5d6649..02054f7daf628 100644 --- a/ibis/tests/expr/test_case.py +++ b/ibis/tests/expr/test_case.py @@ -8,7 +8,7 @@ import ibis.expr.types as ir from ibis import _ from ibis.common.annotations import SignatureValidationError -from ibis.tests.util import assert_equal, assert_pickle_roundtrip +from ibis.tests.util import assert_pickle_roundtrip def test_ifelse_method(table): @@ -48,106 +48,67 @@ def test_ifelse_function_deferred(table): def test_case_dshape(table): - assert isinstance(ibis.case().when(True, "bar").when(False, "bar").end(), ir.Scalar) - assert isinstance(ibis.case().when(True, None).else_("bar").end(), ir.Scalar) - assert isinstance( - ibis.case().when(table.b == 9, None).else_("bar").end(), ir.Column - ) - assert isinstance(ibis.case().when(True, table.a).else_(42).end(), ir.Column) - assert isinstance(ibis.case().when(True, 42).else_(table.a).end(), ir.Column) - assert isinstance(ibis.case().when(True, table.a).else_(table.b).end(), ir.Column) - - assert isinstance(ibis.literal(5).case().when(9, 42).end(), ir.Scalar) - assert isinstance(ibis.literal(5).case().when(9, 42).else_(43).end(), ir.Scalar) - assert isinstance(ibis.literal(5).case().when(table.a, 42).end(), ir.Column) - assert isinstance(ibis.literal(5).case().when(9, table.a).end(), ir.Column) - assert isinstance(ibis.literal(5).case().when(table.a, table.b).end(), ir.Column) - assert isinstance( - ibis.literal(5).case().when(9, 42).else_(table.a).end(), ir.Column - ) - assert isinstance(table.a.case().when(9, 42).end(), ir.Column) - assert isinstance(table.a.case().when(table.b, 42).end(), ir.Column) - assert isinstance(table.a.case().when(9, table.b).end(), ir.Column) - assert isinstance(table.a.case().when(table.a, table.b).end(), ir.Column) + assert isinstance(ibis.cases((True, "bar"), (False, "bar")), ir.Scalar) + assert isinstance(ibis.cases((True, None), else_="bar"), ir.Scalar) + assert isinstance(ibis.cases((table.b == 9, None), else_="bar"), ir.Column) + assert isinstance(ibis.cases((True, table.a), else_=42), ir.Column) + assert isinstance(ibis.cases((True, 42), else_=table.a), ir.Column) + assert isinstance(ibis.cases((True, table.a), else_=table.b), ir.Column) + + assert isinstance(ibis.literal(5).cases((9, 42)), ir.Scalar) + assert isinstance(ibis.literal(5).cases((9, 42), else_=43), ir.Scalar) + assert isinstance(ibis.literal(5).cases((table.a, 42)), ir.Column) + assert isinstance(ibis.literal(5).cases((9, table.a)), ir.Column) + assert isinstance(ibis.literal(5).cases((table.a, table.b)), ir.Column) + assert isinstance(ibis.literal(5).cases((9, 42), else_=table.a), ir.Column) + assert isinstance(table.a.cases((9, 42)), ir.Column) + assert isinstance(table.a.cases((table.b, 42)), ir.Column) + assert isinstance(table.a.cases((9, table.b)), ir.Column) + assert isinstance(table.a.cases((table.a, table.b)), ir.Column) def test_case_dtype(): - assert isinstance( - ibis.case().when(True, "bar").when(False, "bar").end(), ir.StringValue - ) - assert isinstance(ibis.case().when(True, None).else_("bar").end(), ir.StringValue) + assert isinstance(ibis.cases((True, "bar"), (False, "bar")), ir.StringValue) + assert isinstance(ibis.cases((True, None), else_="bar"), ir.StringValue) with pytest.raises(TypeError): - ibis.case().when(True, 5).when(False, "bar").end() + ibis.cases((True, 5), (False, "bar")) with pytest.raises(TypeError): - ibis.case().when(True, 5).else_("bar").end() - - -def test_simple_case_expr(table): - case1, result1 = "foo", table.a - case2, result2 = "bar", table.c - default_result = table.b - - expr1 = table.g.lower().cases( - [(case1, result1), (case2, result2)], default=default_result - ) - - expr2 = ( - table.g.lower() - .case() - .when(case1, result1) - .when(case2, result2) - .else_(default_result) - .end() - ) - - assert_equal(expr1, expr2) - assert isinstance(expr1, ir.IntegerColumn) + ibis.cases((True, 5), else_="bar") def test_multiple_case_expr(table): - expr = ( - ibis.case() - .when(table.a == 5, table.f) - .when(table.b == 128, table.b * 2) - .when(table.c == 1000, table.e) - .else_(table.d) - .end() + expr = ibis.cases( + (table.a == 5, table.f), + (table.b == 128, table.b * 2), + (table.c == 1000, table.e), + else_=table.d, ) # deferred cases - deferred = ( - ibis.case() - .when(_.a == 5, table.f) - .when(_.b == 128, table.b * 2) - .when(_.c == 1000, table.e) - .else_(table.d) - .end() + deferred = ibis.cases( + (_.a == 5, table.f), + (_.b == 128, table.b * 2), + (_.c == 1000, table.e), + else_=table.d, ) expr2 = deferred.resolve(table) # deferred results - expr3 = ( - ibis.case() - .when(table.a == 5, _.f) - .when(table.b == 128, _.b * 2) - .when(table.c == 1000, _.e) - .else_(table.d) - .end() - .resolve(table) - ) + expr3 = ibis.cases( + (table.a == 5, _.f), + (table.b == 128, _.b * 2), + (table.c == 1000, _.e), + else_=table.d, + ).resolve(table) # deferred default - expr4 = ( - ibis.case() - .when(table.a == 5, table.f) - .when(table.b == 128, table.b * 2) - .when(table.c == 1000, table.e) - .else_(_.d) - .end() - .resolve(table) - ) + expr4 = ibis.cases( + (table.a == 5, table.f), + (table.b == 128, table.b * 2), + (table.c == 1000, table.e), + else_=_.d, + ).resolve(table) - assert repr(deferred) == "" assert expr.equals(expr2) assert expr.equals(expr3) assert expr.equals(expr4) @@ -168,13 +129,11 @@ def test_pickle_multiple_case_node(table): result3 = table.e default = table.d - expr = ( - ibis.case() - .when(case1, result1) - .when(case2, result2) - .when(case3, result3) - .else_(default) - .end() + expr = ibis.cases( + (case1, result1), + (case2, result2), + (case3, result3), + else_=default, ) op = expr.op() @@ -182,18 +141,16 @@ def test_pickle_multiple_case_node(table): def test_simple_case_null_else(table): - expr = table.g.case().when("foo", "bar").end() + expr = table.g.cases(("foo", "bar")) op = expr.op() assert isinstance(expr, ir.StringColumn) assert isinstance(op.default.to_expr(), ir.Value) - assert isinstance(op.default, ops.Cast) - assert op.default.to == dt.string def test_multiple_case_null_else(table): - expr = ibis.case().when(table.g == "foo", "bar").end() - expr2 = ibis.case().when(table.g == "foo", _).end().resolve("bar") + expr = ibis.cases((table.g == "foo", "bar")) + expr2 = ibis.cases((table.g == "foo", _)).resolve("bar") assert expr.equals(expr2) @@ -208,32 +165,43 @@ def test_case_mixed_type(): name="my_data", ) - expr = ( - t0.three.case().when(0, "low").when(1, "high").else_("null").end().name("label") - ) + expr = t0.three.cases((0, "low"), (1, "high"), else_="null").name("label") result = t0.select(expr) assert result["label"].type().equals(dt.string) +def test_err_on_bad_args(table): + with pytest.raises(ValueError): + ibis.cases((True,)) + with pytest.raises(ValueError): + ibis.cases((True, 3, 4)) + with pytest.raises(ValueError): + ibis.cases((True, 3, 4)) + with pytest.raises(TypeError): + ibis.cases((True, 3), 5) + + def test_err_on_nonbool_expr(table): with pytest.raises(SignatureValidationError): - ibis.case().when(table.a, "bar").else_("baz").end() + ibis.cases((table.a, "bar"), else_="baz") with pytest.raises(SignatureValidationError): - ibis.case().when(ibis.literal(1), "bar").else_("baz").end() + ibis.cases((ibis.literal(1), "bar"), else_=("baz")) def test_err_on_noncomparable(table): + table.a.cases((8, "bar")) + table.a.cases((-8, "bar")) # Can't compare an int to a string with pytest.raises(TypeError): - table.a.case().when("foo", "bar").end() + table.a.cases(("foo", "bar")) def test_err_on_empty_cases(table): - with pytest.raises(SignatureValidationError): - ibis.case().end() - with pytest.raises(SignatureValidationError): - ibis.case().else_(42).end() - with pytest.raises(SignatureValidationError): - table.a.case().end() - with pytest.raises(SignatureValidationError): - table.a.case().else_(42).end() + with pytest.raises(ValueError): + ibis.cases() + with pytest.raises(ValueError): + ibis.cases(else_=42) + with pytest.raises(ValueError): + table.a.cases() + with pytest.raises(ValueError): + table.a.cases(else_=42) diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index d1e1cd5e35c7a..35e99e31458e9 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -825,23 +825,11 @@ def test_substitute_dict(): subs = {"a": "one", "b": table.bar} result = table.foo.substitute(subs) - expected = ( - ibis.case() - .when(table.foo == "a", "one") - .when(table.foo == "b", table.bar) - .else_(table.foo) - .end() - ) + expected = table.foo.cases(("a", "one"), ("b", table.bar), else_=table.foo) assert_equal(result, expected) result = table.foo.substitute(subs, else_=ibis.null()) - expected = ( - ibis.case() - .when(table.foo == "a", "one") - .when(table.foo == "b", table.bar) - .else_(ibis.null()) - .end() - ) + expected = table.foo.cases(("a", "one"), ("b", table.bar), else_=ibis.null()) assert_equal(result, expected)