From 6ca584ae75674350ed21248b7306c854f953c02b Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Tue, 23 Apr 2024 10:59:48 -0800 Subject: [PATCH] feat: move from .case() to .cases() --- docs/posts/ci-analysis/index.qmd | 14 +- docs/tutorials/ibis-for-sql-users.qmd | 24 +- .../clickhouse/tests/test_operators.py | 14 +- ibis/backends/dask/tests/test_operations.py | 21 +- ibis/backends/impala/tests/test_case_exprs.py | 4 +- ibis/backends/pandas/executor.py | 2 + ibis/backends/pandas/tests/test_operations.py | 23 +- ibis/backends/snowflake/tests/test_udf.py | 6 +- ibis/backends/sql/compiler.py | 2 + ibis/backends/tests/sql/conftest.py | 4 +- .../test_case_in_projection/decompiled.py | 20 +- ibis/backends/tests/sql/test_select_sql.py | 4 +- ibis/backends/tests/test_aggregation.py | 8 +- ibis/backends/tests/test_generic.py | 29 ++- ibis/backends/tests/test_sql.py | 20 +- ibis/backends/tests/test_string.py | 18 +- ibis/backends/tests/test_struct.py | 2 +- ibis/backends/tests/tpch/test_h08.py | 4 +- ibis/backends/tests/tpch/test_h12.py | 12 +- ibis/expr/api.py | 93 +++++--- ibis/expr/decompile.py | 16 +- ibis/expr/operations/generic.py | 27 ++- ibis/expr/operations/logical.py | 4 +- ibis/expr/types/generic.py | 208 +++++++----------- ibis/expr/types/numeric.py | 9 +- ibis/expr/types/relations.py | 4 +- ibis/tests/expr/test_case.py | 177 +++++++++------ ibis/tests/expr/test_value_exprs.py | 16 +- 28 files changed, 374 insertions(+), 411 deletions(-) diff --git a/docs/posts/ci-analysis/index.qmd b/docs/posts/ci-analysis/index.qmd index 5babc2c6d0c61..65159d5bbe9c7 100644 --- a/docs/posts/ci-analysis/index.qmd +++ b/docs/posts/ci-analysis/index.qmd @@ -203,14 +203,12 @@ Let's also give them some names that'll look nice on our plots. stats = stats.mutate( raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int") ).mutate( - improvements=( - _.raw_improvements.case() - .when(0, "None") - .when(1, "Poetry") - .when(2, "Poetry + Team Plan") - .else_("NA") - .end() - ), + improvements=_.raw_improvements.cases( + (0, "None"), + (1, "Poetry"), + (2, "Poetry + Team Plan"), + else_="NA", + ) team_plan=ibis.where(_.raw_improvements > 1, "Poetry + Team Plan", "None"), ) stats diff --git a/docs/tutorials/ibis-for-sql-users.qmd b/docs/tutorials/ibis-for-sql-users.qmd index 577f7b015111e..537742b876957 100644 --- a/docs/tutorials/ibis-for-sql-users.qmd +++ b/docs/tutorials/ibis-for-sql-users.qmd @@ -473,11 +473,11 @@ semantics: case = ( t.one.cast("timestamp") .year() - .case() - .when(2015, "This year") - .when(2014, "Last year") - .else_("Earlier") - .end() + .cases( + (2015, "This year"), + (2014, "Last year"), + else_="Earlier", + ) ) expr = t.mutate(year_group=case) @@ -496,18 +496,16 @@ CASE END ``` -To do this, use `ibis.case`: +To do this, use `ibis.cases`: ```{python} -case = ( - ibis.case() - .when(t.two < 0, t.three * 2) - .when(t.two > 1, t.three) - .else_(t.two) - .end() +cases = ibis.cases( + (t.two < 0, t.three * 2), + (t.two > 1, t.three), + else_=t.two, ) -expr = t.mutate(cond_value=case) +expr = t.mutate(cond_value=cases) ibis.to_sql(expr) ``` diff --git a/ibis/backends/clickhouse/tests/test_operators.py b/ibis/backends/clickhouse/tests/test_operators.py index 4ca53a3d2b9f3..3ff07ce916a45 100644 --- a/ibis/backends/clickhouse/tests/test_operators.py +++ b/ibis/backends/clickhouse/tests/test_operators.py @@ -201,9 +201,7 @@ def test_ifelse(alltypes, df, op, pandas_op): def test_simple_case(con, alltypes, assert_sql): t = alltypes - expr = ( - t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end() - ) + expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default") assert_sql(expr) assert len(con.execute(expr)) @@ -211,12 +209,10 @@ def test_simple_case(con, alltypes, assert_sql): def test_search_case(con, alltypes, assert_sql): t = alltypes - expr = ( - ibis.case() - .when(t.float_col > 0, t.int_col * 2) - .when(t.float_col < 0, t.int_col) - .else_(0) - .end() + expr = ibis.cases( + (t.float_col > 0, t.int_col * 2), + (t.float_col < 0, t.int_col), + else_=0, ) assert_sql(expr) diff --git a/ibis/backends/dask/tests/test_operations.py b/ibis/backends/dask/tests/test_operations.py index e43e5af454933..d591cb629f9c6 100644 --- a/ibis/backends/dask/tests/test_operations.py +++ b/ibis/backends/dask/tests/test_operations.py @@ -774,7 +774,7 @@ def q_fun(x, quantile): def test_searched_case_scalar(client): - expr = ibis.case().when(True, 1).when(False, 2).end() + expr = ibis.cases((True, 1), (False, 2)) result = client.execute(expr) expected = np.int8(1) assert result == expected @@ -783,12 +783,8 @@ def test_searched_case_scalar(client): def test_searched_case_column(batting, batting_pandas_df): t = batting df = batting_pandas_df - expr = ( - ibis.case() - .when(t.RBI < 5, "really bad team") - .when(t.teamID == "PH1", "ph1 team") - .else_(t.teamID) - .end() + expr = ibis.cases( + (t.RBI < 5, "really bad team"), (t.teamID == "PH1", "ph1 team"), else_=t.teamID ) result = expr.execute() expected = pd.Series( @@ -803,7 +799,7 @@ def test_searched_case_column(batting, batting_pandas_df): def test_simple_case_scalar(client): x = ibis.literal(2) - expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end() + expr = x.cases((2, x - 1), (3, x + 1), (4, x + 2)) result = client.execute(expr) expected = np.int8(1) assert result == expected @@ -812,14 +808,7 @@ def test_simple_case_scalar(client): def test_simple_case_column(batting, batting_pandas_df): t = batting df = batting_pandas_df - expr = ( - t.RBI.case() - .when(5, "five") - .when(4, "four") - .when(3, "three") - .else_("could be good?") - .end() - ) + expr = t.RBI.cases((5, "five"), (4, "four"), (3, "three"), else_="could be good?") result = expr.execute() expected = pd.Series( np.select( diff --git a/ibis/backends/impala/tests/test_case_exprs.py b/ibis/backends/impala/tests/test_case_exprs.py index e23a9436c6fb1..e1a97df344b83 100644 --- a/ibis/backends/impala/tests/test_case_exprs.py +++ b/ibis/backends/impala/tests/test_case_exprs.py @@ -14,13 +14,13 @@ def table(mockcon): @pytest.fixture def simple_case(table): - return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() + return table.g.cases(("foo", "bar"), ("baz", "qux"), else_="default") @pytest.fixture def search_case(table): t = table - return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end() + return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2)) @pytest.fixture diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index b895d892dbac9..28a6ef154f864 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -163,6 +163,8 @@ def visit(cls, op: ops.IsNan, arg): def visit( cls, op: ops.SearchedCase | ops.SimpleCase, cases, results, default, base=None ): + if not cases: + return default if base is not None: cases = tuple(base == case for case in cases) cases, _ = cls.asframe(cases, concat=False) diff --git a/ibis/backends/pandas/tests/test_operations.py b/ibis/backends/pandas/tests/test_operations.py index 3d6e78d9d2c69..5e515f8e4554c 100644 --- a/ibis/backends/pandas/tests/test_operations.py +++ b/ibis/backends/pandas/tests/test_operations.py @@ -684,7 +684,7 @@ def test_summary_non_numeric(batting, batting_df): def test_searched_case_scalar(client): - expr = ibis.case().when(True, 1).when(False, 2).end() + expr = ibis.cases((True, 1), (False, 2)) result = client.execute(expr) expected = np.int8(1) assert result == expected @@ -693,12 +693,10 @@ def test_searched_case_scalar(client): def test_searched_case_column(batting, batting_df): t = batting df = batting_df - expr = ( - ibis.case() - .when(t.RBI < 5, "really bad team") - .when(t.teamID == "PH1", "ph1 team") - .else_(t.teamID) - .end() + expr = ibis.cases( + (t.RBI < 5, "really bad team"), + (t.teamID == "PH1", "ph1 team"), + else_=t.teamID, ) result = expr.execute() expected = pd.Series( @@ -713,7 +711,7 @@ def test_searched_case_column(batting, batting_df): def test_simple_case_scalar(client): x = ibis.literal(2) - expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end() + expr = x.cases((2, x - 1), (3, x + 1), (4, x + 2)) result = client.execute(expr) expected = np.int8(1) assert result == expected @@ -722,14 +720,7 @@ def test_simple_case_scalar(client): def test_simple_case_column(batting, batting_df): t = batting df = batting_df - expr = ( - t.RBI.case() - .when(5, "five") - .when(4, "four") - .when(3, "three") - .else_("could be good?") - .end() - ) + expr = t.RBI.cases((5, "five"), (4, "four"), (3, "three"), else_="could be good?") result = expr.execute() expected = pd.Series( np.select( diff --git a/ibis/backends/snowflake/tests/test_udf.py b/ibis/backends/snowflake/tests/test_udf.py index 87696fc957888..d42c459aa6ff3 100644 --- a/ibis/backends/snowflake/tests/test_udf.py +++ b/ibis/backends/snowflake/tests/test_udf.py @@ -118,10 +118,8 @@ def predict_price( def cases(value, mapping): """This should really be a top-level function or method.""" - expr = ibis.case() - for k, v in mapping.items(): - expr = expr.when(value == k, v) - return expr.end() + pairs = [(value == k, v) for k, v in mapping.items()] + return ibis.cases(*pairs) diamonds = con.tables.DIAMONDS expr = diamonds.mutate( diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index 71d57b7481aad..209df6050ed6f 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -929,6 +929,8 @@ def visit_VarianceStandardDevCovariance(self, op, *, how, where, **kw): ) def visit_SimpleCase(self, op, *, base=None, cases, results, default): + if not cases: + return default return sge.Case( this=base, ifs=list(map(self.if_, cases, results)), default=default ) diff --git a/ibis/backends/tests/sql/conftest.py b/ibis/backends/tests/sql/conftest.py index 04667e60e033b..06de1c83c8c08 100644 --- a/ibis/backends/tests/sql/conftest.py +++ b/ibis/backends/tests/sql/conftest.py @@ -164,13 +164,13 @@ def difference(con): @pytest.fixture(scope="module") def simple_case(con): t = con.table("alltypes") - return t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() + return t.g.cases(("foo", "bar"), ("baz", "qux"), else_="default") @pytest.fixture(scope="module") def search_case(con): t = con.table("alltypes") - return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end() + return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2)) @pytest.fixture(scope="module") diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py index 6058efaa962e6..1b1dcf62dca62 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py @@ -22,18 +22,14 @@ lit2 = ibis.literal("bar") result = alltypes.select( - alltypes.g.case() - .when(lit, lit2) - .when(lit1, ibis.literal("qux")) - .else_(ibis.literal("default")) - .end() - .name("col1"), - ibis.case() - .when(alltypes.g == lit, lit2) - .when(alltypes.g == lit1, alltypes.g) - .else_(ibis.literal(None).cast("string")) - .end() - .name("col2"), + alltypes.g.cases( + (lit, lit2), (lit1, ibis.literal("qux")), else_=ibis.literal("default") + ).name("col1"), + ibis.cases( + (alltypes.g == lit, lit2), + (alltypes.g == lit1, alltypes.g), + else_=ibis.literal(None).cast("string"), + ).name("col2"), alltypes.a, alltypes.b, alltypes.c, diff --git a/ibis/backends/tests/sql/test_select_sql.py b/ibis/backends/tests/sql/test_select_sql.py index 94a52017f763f..24893739fb6eb 100644 --- a/ibis/backends/tests/sql/test_select_sql.py +++ b/ibis/backends/tests/sql/test_select_sql.py @@ -397,8 +397,8 @@ def test_bool_bool(snapshot): def test_case_in_projection(alltypes, snapshot): t = alltypes - expr = t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() - expr2 = ibis.case().when(t.g == "foo", "bar").when(t.g == "baz", t.g).end() + expr = t.g.cases(("foo", "bar"), ("baz", "qux"), else_=("default")) + expr2 = ibis.cases((t.g == "foo", "bar"), (t.g == "baz", t.g)) expr = t[expr.name("col1"), expr2.name("col2"), t] snapshot.assert_match(to_sql(expr), "out.sql") diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index c78508d9d0e61..bef4107c0f1ea 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -680,7 +680,7 @@ def test_arbitrary(backend, alltypes, df, filtered): # _something_ we create a column that is a mix of nulls and a single value # (or a single value after filtering is applied). if filtered: - new = alltypes.int_col.cases([(3, 30), (4, 40)]) + new = alltypes.int_col.cases((3, 30), (4, 40)) where = _.int_col == 3 else: new = (alltypes.int_col == 3).ifelse(30, None) @@ -1428,9 +1428,7 @@ def collect_udf(v): def test_binds_are_cast(alltypes): expr = alltypes.aggregate( - high_line_count=( - alltypes.string_col.case().when("1-URGENT", 1).else_(0).end().sum() - ) + high_line_count=alltypes.string_col.cases(("1-URGENT", 1), else_=0).sum() ) expr.execute() @@ -1476,7 +1474,7 @@ def test_agg_name_in_output_column(alltypes): def test_grouped_case(backend, con): table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]}) - case_expr = ibis.case().when(table.value < 25, table.value).else_(ibis.null()).end() + case_expr = ibis.cases((table.value < 25, table.value), else_=ibis.null()) expr = ( table.group_by(k="key").aggregate(mx=case_expr.max()).dropna("k").order_by("k") diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 4d3c6126c01fe..b95fd352513de 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -356,12 +356,11 @@ def test_case_where(backend, alltypes, df): table = alltypes table = table.mutate( new_col=( - ibis.case() - .when(table["int_col"] == 1, 20) - .when(table["int_col"] == 0, 10) - .else_(0) - .end() - .cast("int64") + ibis.cases( + (table["int_col"] == 1, 20), + (table["int_col"] == 0, 10), + else_=0, + ).cast("int64") ) ) @@ -394,9 +393,7 @@ def test_select_filter_mutate(backend, alltypes, df): # Prepare the float_col so that filter must execute # before the cast to get the correct result. - t = t.mutate( - float_col=ibis.case().when(t["bool_col"], t["float_col"]).else_(np.nan).end() - ) + t = t.mutate(float_col=ibis.cases((t["bool_col"], t["float_col"]), else_=np.nan)) # Actual test t = t[t.columns] @@ -1999,6 +1996,20 @@ def test_substitute(backend): assert expr["subs_count"].execute()[0] == t.count().execute() // 10 +@pytest.mark.broken("flink", reason="can't handle SELECT NULL AS `SearchedCase(None)`") +def test_cases_empty(con): + assert pd.isnull(con.execute(ibis.cases())) + assert con.execute(ibis.cases(else_=2)) == 2 + assert pd.isnull(con.execute(ibis.literal(1).cases())) + assert con.execute(ibis.literal(1).cases(else_=2)) == 2 + + +def test_switch_cases_null(con): + """CASE x WHEN NULL never gets hit""" + e = ibis.literal(5).nullif(5).cases((None, "shouldnt get here"), else_="expected") + assert con.execute(e) == "expected" + + @pytest.mark.notimpl( ["dask", "pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend" ) diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 097bbb9cb4577..8f78d3560f899 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -65,16 +65,16 @@ def test_group_by_has_index(backend, snapshot): ) expr = countries.group_by( cont=( - _.continent.case() - .when("NA", "North America") - .when("SA", "South America") - .when("EU", "Europe") - .when("AF", "Africa") - .when("AS", "Asia") - .when("OC", "Oceania") - .when("AN", "Antarctica") - .else_("Unknown continent") - .end() + _.continent.cases( + ("NA", "North America"), + ("SA", "South America"), + ("EU", "Europe"), + ("AF", "Africa"), + ("AS", "Asia"), + ("OC", "Oceania"), + ("AN", "Antarctica"), + else_="Unknown continent", + ) ) ).agg(total_pop=_.population.sum()) sql = str(ibis.to_sql(expr, dialect=backend.name())) diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 595129cb9091a..17d6cfae496d3 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -508,9 +508,9 @@ def uses_java_re(t): id="length", ), param( - lambda t: t.int_col.cases([(1, "abcd"), (2, "ABCD")], "dabc").startswith( - "abc" - ), + lambda t: t.int_col.cases( + (1, "abcd"), (2, "ABCD"), else_="dabc" + ).startswith("abc"), lambda t: t.int_col == 1, id="startswith", marks=[ @@ -518,7 +518,7 @@ def uses_java_re(t): ], ), param( - lambda t: t.int_col.cases([(1, "abcd"), (2, "ABCD")], "dabc").endswith( + lambda t: t.int_col.cases((1, "abcd"), (2, "ABCD"), else_="dabc").endswith( "bcd" ), lambda t: t.int_col == 1, @@ -694,11 +694,9 @@ def test_re_replace_global(con): @pytest.mark.notimpl(["druid"], raises=ValidationError) def test_substr_with_null_values(backend, alltypes, df): table = alltypes.mutate( - substr_col_null=ibis.case() - .when(alltypes["bool_col"], alltypes["string_col"]) - .else_(None) - .end() - .substr(0, 2) + substr_col_null=ibis.cases( + (alltypes["bool_col"], alltypes["string_col"]), else_=None + ).substr(0, 2) ) result = table.execute() @@ -919,7 +917,7 @@ def test_levenshtein(con, right): @pytest.mark.parametrize( "expr", [ - param(ibis.case().when(True, "%").end(), id="case"), + param(ibis.cases((True, "%")), id="case"), param(ibis.ifelse(True, "%", ibis.NA), id="ifelse"), ], ) diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index e7f078f7591a1..2101dff274779 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -149,7 +149,7 @@ def test_collect_into_struct(alltypes): @pytest.mark.notimpl(["flink"], raises=Py4JJavaError, reason="not implemented in ibis") def test_field_access_after_case(con): s = ibis.struct({"a": 3}) - x = ibis.case().when(True, s).else_(ibis.struct({"a": 4})).end() + x = ibis.cases((True, s), else_=ibis.struct({"a": 4})) y = x.a assert con.to_pandas(y) == 3 diff --git a/ibis/backends/tests/tpch/test_h08.py b/ibis/backends/tests/tpch/test_h08.py index 651a725b1fc8a..5ffa2f2c53101 100644 --- a/ibis/backends/tests/tpch/test_h08.py +++ b/ibis/backends/tests/tpch/test_h08.py @@ -42,9 +42,7 @@ def test_tpc_h08(part, supplier, region, lineitem, orders, customer, nation): ] ) - q = q.mutate( - nation_volume=ibis.case().when(q.nation == NATION, q.volume).else_(0).end() - ) + q = q.mutate(nation_volume=ibis.cases((q.nation == NATION, q.volume), else_=0)) gq = q.group_by([q.o_year]) q = gq.aggregate(mkt_share=q.nation_volume.sum() / q.volume.sum()) q = q.order_by([q.o_year]) diff --git a/ibis/backends/tests/tpch/test_h12.py b/ibis/backends/tests/tpch/test_h12.py index 0fb092beb8e31..e56efb9507f03 100644 --- a/ibis/backends/tests/tpch/test_h12.py +++ b/ibis/backends/tests/tpch/test_h12.py @@ -32,18 +32,10 @@ def test_tpc_h12(orders, lineitem): gq = q.group_by([q.l_shipmode]) q = gq.aggregate( high_line_count=( - q.o_orderpriority.case() - .when("1-URGENT", 1) - .when("2-HIGH", 1) - .else_(0) - .end() + q.o_orderpriority.cases(("1-URGENT", 1), ("2-HIGH", 1), else_=0) ).sum(), low_line_count=( - q.o_orderpriority.case() - .when("1-URGENT", 0) - .when("2-HIGH", 0) - .else_(1) - .end() + q.o_orderpriority.cases(("1-URGENT", 0), ("2-HIGH", 0), else_=1) ).sum(), ) q = q.order_by(q.l_shipmode) diff --git a/ibis/expr/api.py b/ibis/expr/api.py index 2dd766ac2bc8b..8114e7f206946 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -68,6 +68,7 @@ "array", "asc", "case", + "cases", "coalesce", "connect", "cross_join", @@ -1105,56 +1106,76 @@ def interval( return functools.reduce(operator.add, intervals) +@util.deprecated(instead="use ibis.cases() instead", as_of="9.0") def case() -> bl.SearchedCaseBuilder: - """Begin constructing a case expression. + """DEPRECATED: Use `ibis.cases()` instead.""" + return bl.SearchedCaseBuilder() + + +@deferrable +def cases(*branches: tuple[Any, Any], else_: Any | None = None) -> ir.Value: + """Create a multi-branch if-else expression. - Use the `.when` method on the resulting object followed by `.end` to create a - complete case expression. + Goes through each (condition, value) pair in `branches`, finding the + first condition that evaluates to True, and returns the corresponding + value. If no condition is True, returns `else_`. Returns ------- - SearchedCaseBuilder - A builder object to use for constructing a case expression. + Value + A value expression See Also -------- - [`Value.case()`](./expression-generic.qmd#ibis.expr.types.generic.Value.case) + [`Value.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) Examples -------- >>> import ibis - >>> from ibis import _ >>> ibis.options.interactive = True - >>> t = ibis.memtable( - ... { - ... "left": [1, 2, 3, 4], - ... "symbol": ["+", "-", "*", "/"], - ... "right": [5, 6, 7, 8], - ... } - ... ) - >>> t.mutate( - ... result=( - ... ibis.case() - ... .when(_.symbol == "+", _.left + _.right) - ... .when(_.symbol == "-", _.left - _.right) - ... .when(_.symbol == "*", _.left * _.right) - ... .when(_.symbol == "/", _.left / _.right) - ... .end() - ... ) - ... ) - ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ - ┃ left ┃ symbol ┃ right ┃ result ┃ - ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ - │ int64 │ string │ int64 │ float64 │ - ├───────┼────────┼───────┼─────────┤ - │ 1 │ + │ 5 │ 6.0 │ - │ 2 │ - │ 6 │ -4.0 │ - │ 3 │ * │ 7 │ 21.0 │ - │ 4 │ / │ 8 │ 0.5 │ - └───────┴────────┴───────┴─────────┘ - + >>> v = ibis.memtable({"values": [1, 2, 1, 2, 3, 2, 4]}).values + >>> ibis.cases((v == 1, "a"), (v > 2, "b"), else_="unk").name("cases") + ┏━━━━━━━━┓ + ┃ cases ┃ + ┡━━━━━━━━┩ + │ string │ + ├────────┤ + │ a │ + │ unk │ + │ a │ + │ unk │ + │ b │ + │ unk │ + │ b │ + └────────┘ + >>> ibis.cases( + ... (v % 2 == 0, "divisible by 2"), + ... (v % 3 == 0, "divisible by 3"), + ... (v % 4 == 0, "shadowed by the 2 case"), + ... ).name("cases") + ┏━━━━━━━━━━━━━━━━┓ + ┃ cases ┃ + ┡━━━━━━━━━━━━━━━━┩ + │ string │ + ├────────────────┤ + │ NULL │ + │ divisible by 2 │ + │ NULL │ + │ divisible by 2 │ + │ divisible by 3 │ + │ divisible by 2 │ + │ divisible by 2 │ + └────────────────┘ """ - return bl.SearchedCaseBuilder() + for b in branches: + try: + condition, result = b + except (TypeError, ValueError) as e: + raise ValueError( + "Each branch must be a tuple of (condition, result)" + ) from e + cases, results = zip(*branches) if branches else ([], []) + return ops.SearchedCase(cases=cases, results=results, default=else_).to_expr() def now() -> ir.TimestampScalar: diff --git a/ibis/expr/decompile.py b/ibis/expr/decompile.py index 43452fa0c09ac..078d015cd9dc3 100644 --- a/ibis/expr/decompile.py +++ b/ibis/expr/decompile.py @@ -298,16 +298,12 @@ def ifelse(op, bool_expr, true_expr, false_null_expr): @translate.register(ops.SimpleCase) @translate.register(ops.SearchedCase) -def switch_case(op, cases, results, default, base=None): - out = f"{base}.case()" if base else "ibis.case()" - - for case, result in zip(cases, results): - out = f"{out}.when({case}, {result})" - - if default is not None: - out = f"{out}.else_({default})" - - return f"{out}.end()" +def switch_cases(op, cases, results, default, base=None): + namespace = f"{base}" if base else "ibis" + case_strs = [f"({case}, {result})" for case, result in zip(cases, results)] + cases_str = ", ".join(case_strs) + else_str = f", else_={default}" if default is not None else "" + return f"{namespace}.cases({cases_str}{else_str})" _infix_ops = { diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index 754029e292b52..11f6b8bfa0258 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -281,11 +281,24 @@ class SimpleCase(Value): results: VarTuple[Value] default: Value - shape = rlz.shape_like("base") - - def __init__(self, cases, results, **kwargs): + def __init__(self, base, cases, results, default): assert len(cases) == len(results) - super().__init__(cases=cases, results=results, **kwargs) + + for case in cases: + if not rlz.comparable(base, case): + raise TypeError( + f"Base expression {rlz._arg_type_error_format(base)} and " + f"case {rlz._arg_type_error_format(case)} are not comparable" + ) + + if default.dtype.is_null() and results: + default = Cast(default, rlz.highest_precedence_dtype(results)) + super().__init__(base=base, cases=cases, results=results, default=default) + + @attribute + def shape(self): + exprs = [self.base, *self.cases, *self.results, self.default] + return rlz.highest_precedence_shape(exprs) @attribute def dtype(self): @@ -301,14 +314,14 @@ class SearchedCase(Value): def __init__(self, cases, results, default): assert len(cases) == len(results) - if default.dtype.is_null(): + if default.dtype.is_null() and results: default = Cast(default, rlz.highest_precedence_dtype(results)) super().__init__(cases=cases, results=results, default=default) @attribute def shape(self): - # TODO(kszucs): can be removed after making Sequence iterable - return rlz.highest_precedence_shape(self.cases) + exprs = [*self.cases, *self.results, self.default] + return rlz.highest_precedence_shape(exprs) @attribute def dtype(self): diff --git a/ibis/expr/operations/logical.py b/ibis/expr/operations/logical.py index 78a33de77e1cc..af99b99668582 100644 --- a/ibis/expr/operations/logical.py +++ b/ibis/expr/operations/logical.py @@ -137,9 +137,9 @@ def shape(self): @public class IfElse(Value): - """Ternary case expression, equivalent to. + """Ternary case expression. - bool_expr.case().when(True, true_expr).else_(false_or_null_expr) + Equivalent to bool_expr.cases((True, true_expr), else_=false_or_null_expr) Many backends implement this as a built-in function. """ diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index b9e26756a86f1..11bc1fc511c1f 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Sequence from typing import TYPE_CHECKING, Any from public import public @@ -10,6 +10,7 @@ import ibis.expr.builders as bl import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis import util from ibis.common.deferred import Deferred, _, deferrable from ibis.common.grounds import Singleton from ibis.expr.rewrites import rewrite_window_input @@ -717,19 +718,16 @@ def substitute( └────────┴──────────────┘ """ if isinstance(value, dict): - expr = ibis.case() - try: - null_replacement = value.pop(None) - except KeyError: - pass - else: - expr = expr.when(self.isnull(), null_replacement) - for k, v in value.items(): - expr = expr.when(self == k, v) + branches = sorted(value.items()) else: - expr = self.case().when(value, replacement) - - return expr.else_(else_ if else_ is not None else self).end() + branches = [(value, replacement)] + nulls = [(k, v) for k, v in branches if k is None] + nonnulls = [(k, v) for k, v in branches if k is not None] + if nulls: + null_replacement = nulls[0][1] + self = self.fillna(null_replacement) + else_ = else_ if else_ is not None else self + return self.cases(*nonnulls, else_=else_) def over( self, @@ -865,99 +863,32 @@ def notnull(self) -> ir.BooleanValue: """ return ops.NotNull(self).to_expr() + @util.deprecated(instead="use Value.cases() instead", as_of="9.0") def case(self) -> bl.SimpleCaseBuilder: - """Create a SimpleCaseBuilder to chain multiple if-else statements. - - Add new search expressions with the `.when()` method. These must be - comparable with this column expression. Conclude by calling `.end()`. - - Returns - ------- - SimpleCaseBuilder - A case builder - - See Also - -------- - [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) - [`ibis.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) - [`ibis.case()`](./expression-generic.qmd#ibis.case) - - Examples - -------- - >>> import ibis - >>> ibis.options.interactive = True - >>> x = ibis.examples.penguins.fetch().head(5)["sex"] - >>> x - ┏━━━━━━━━┓ - ┃ sex ┃ - ┡━━━━━━━━┩ - │ string │ - ├────────┤ - │ male │ - │ female │ - │ female │ - │ NULL │ - │ female │ - └────────┘ - >>> x.case().when("male", "M").when("female", "F").else_("U").end() - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, 'U') ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├──────────────────────┤ - │ M │ - │ F │ - │ F │ - │ U │ - │ F │ - └──────────────────────┘ - - Cases not given result in the ELSE case - - >>> x.case().when("male", "M").else_("OTHER").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, 'OTHER') ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├──────────────────────────┤ - │ M │ - │ OTHER │ - │ OTHER │ - │ OTHER │ - │ OTHER │ - └──────────────────────────┘ - - If you don't supply an ELSE, then NULL is used - - >>> x.case().when("male", "M").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, Cast(None, string)) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├─────────────────────────────────────┤ - │ M │ - │ NULL │ - │ NULL │ - │ NULL │ - │ NULL │ - └─────────────────────────────────────┘ - """ - import ibis.expr.builders as bl - + """DEPRECATED: Use `self.cases()` instead.""" return bl.SimpleCaseBuilder(self.op()) def cases( self, - case_result_pairs: Iterable[tuple[ir.BooleanValue, Value]], - default: Value | None = None, + *branches: tuple[Value, Value], + else_: Value | None = None, ) -> Value: - """Create a case expression in one shot. + """Create a multi-branch if-else expression. + + This is semantically equivalent to + CASE self + WHEN test_val0 THEN result0 + WHEN test_val1 THEN result1 + ELSE else_ + END Parameters ---------- - case_result_pairs - Conditional-result pairs - default + branches + (test_val, result) pairs. We look through the test values in order + and return the result corresponding to the first test value that + matches `self`. If none match, we return `else_`. + else_ Value to return if none of the case conditions are true Returns @@ -968,48 +899,59 @@ def cases( See Also -------- [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) - [`ibis.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) - [`ibis.case()`](./expression-generic.qmd#ibis.case) + [`ibis.cases()`](./expression-generic.qmd#ibis.cases) Examples -------- >>> import ibis >>> ibis.options.interactive = True - >>> t = ibis.memtable({"values": [1, 2, 1, 2, 3, 2, 4]}) - >>> t - ┏━━━━━━━━┓ - ┃ values ┃ - ┡━━━━━━━━┩ - │ int64 │ - ├────────┤ - │ 1 │ - │ 2 │ - │ 1 │ - │ 2 │ - │ 3 │ - │ 2 │ - │ 4 │ - └────────┘ - >>> number_letter_map = ((1, "a"), (2, "b"), (3, "c")) - >>> t.values.cases(number_letter_map, default="unk").name("replace") - ┏━━━━━━━━━┓ - ┃ replace ┃ - ┡━━━━━━━━━┩ - │ string │ - ├─────────┤ - │ a │ - │ b │ - │ a │ - │ b │ - │ c │ - │ b │ - │ unk │ - └─────────┘ + >>> t = ibis.memtable( + ... { + ... "left": [5, 6, 7, 8, 9, 10], + ... "symbol": ["+", "-", "*", "/", "bogus", None], + ... "right": [1, 2, 3, 4, 5, 6], + ... } + ... ) + + Note we never hit the `None` case, because `x = NULL` is always NULL, + which is not truthy. If you want to replace NULLs, you should use + `.fillna(-999)` prior to `cases()`. + + >>> t.mutate( + ... result=( + ... t.symbol.cases( + ... ("+", t.left + t.right), + ... ("-", t.left - t.right), + ... ("*", t.left * t.right), + ... ("/", t.left / t.right), + ... (None, -999), + ... ) + ... ) + ... ) + ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ + ┃ left ┃ symbol ┃ right ┃ result ┃ + ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ + │ int64 │ string │ int64 │ float64 │ + ├───────┼────────┼───────┼─────────┤ + │ 5 │ + │ 1 │ 6.0 │ + │ 6 │ - │ 2 │ 4.0 │ + │ 7 │ * │ 3 │ 21.0 │ + │ 8 │ / │ 4 │ 2.0 │ + │ 9 │ bogus │ 5 │ NULL │ + │ 10 │ NULL │ 6 │ NULL │ + └───────┴────────┴───────┴─────────┘ """ - builder = self.case() - for case, result in case_result_pairs: - builder = builder.when(case, result) - return builder.else_(default).end() + for b in branches: + try: + test, result = b + except (TypeError, ValueError) as e: + raise ValueError( + "Each branch must be a tuple of (condition, result)" + ) from e + cases, results = zip(*branches) if branches else ([], []) + return ops.SimpleCase( + base=self, cases=cases, results=results, default=else_ + ).to_expr() def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar: """Aggregate this expression's elements into an array. diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index 56d5c2dea178c..b7396c576b833 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools from typing import TYPE_CHECKING, Literal from public import public @@ -1143,13 +1142,7 @@ def label(self, labels: Iterable[str], nulls: str | None = None) -> ir.StringVal │ 2 │ c │ └───────┴─────────┘ """ - return ( - functools.reduce( - lambda stmt, inputs: stmt.when(*inputs), enumerate(labels), self.case() - ) - .else_(nulls) - .end() - ) + return self.cases(*enumerate(labels), else_=nulls) @public diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index c0cd7cb6e1d03..cb25cc80fcca1 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -2872,9 +2872,7 @@ def info(self) -> Table: for pos, colname in enumerate(self.columns): col = self[colname] typ = col.type() - agg = self.select( - isna=ibis.case().when(col.isnull(), 1).else_(0).end() - ).agg( + agg = self.select(isna=ibis.cases((col.isnull(), 1), else_=0)).agg( name=lit(colname), type=lit(str(typ)), nullable=lit(typ.nullable), diff --git a/ibis/tests/expr/test_case.py b/ibis/tests/expr/test_case.py index 89e5b3b7df7b4..c8efe09590a30 100644 --- a/ibis/tests/expr/test_case.py +++ b/ibis/tests/expr/test_case.py @@ -1,11 +1,14 @@ from __future__ import annotations +import pytest + import ibis import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir from ibis import _ -from ibis.tests.util import assert_equal, assert_pickle_roundtrip +from ibis.common.annotations import SignatureValidationError +from ibis.tests.util import assert_pickle_roundtrip def test_ifelse_method(table): @@ -44,72 +47,63 @@ def test_ifelse_function_deferred(table): assert res.equals(sol) -def test_simple_case_expr(table): - case1, result1 = "foo", table.a - case2, result2 = "bar", table.c - default_result = table.b - - expr1 = table.g.lower().cases( - [(case1, result1), (case2, result2)], default=default_result - ) - - expr2 = ( - table.g.lower() - .case() - .when(case1, result1) - .when(case2, result2) - .else_(default_result) - .end() - ) +def test_err_on_bad_args(table): + with pytest.raises(ValueError): + ibis.cases((True,)) + with pytest.raises(ValueError): + ibis.cases((True, 3, 4)) + with pytest.raises(ValueError): + ibis.cases((True, 3, 4)) + with pytest.raises(ValueError): + ibis.cases((True, 3), 5) - assert_equal(expr1, expr2) - assert isinstance(expr1, ir.IntegerColumn) + with pytest.raises(ValueError): + table.a.cases(("foo",)) + with pytest.raises(ValueError): + table.a.cases(("foo", 3, 4)) + with pytest.raises(ValueError): + table.a.cases(("foo", 3, 4)) + with pytest.raises(ValueError): + table.a.cases(("foo", 3), 5) def test_multiple_case_expr(table): - expr = ( - ibis.case() - .when(table.a == 5, table.f) - .when(table.b == 128, table.b * 2) - .when(table.c == 1000, table.e) - .else_(table.d) - .end() + expr = ibis.cases( + (table.a == 5, table.f), + (table.b == 128, table.b * 2), + (table.c == 1000, table.e), + else_=table.d, ) # deferred cases - deferred = ( - ibis.case() - .when(_.a == 5, table.f) - .when(_.b == 128, table.b * 2) - .when(_.c == 1000, table.e) - .else_(table.d) - .end() + deferred = ibis.cases( + (_.a == 5, table.f), + (_.b == 128, table.b * 2), + (_.c == 1000, table.e), + else_=table.d, ) expr2 = deferred.resolve(table) # deferred results - expr3 = ( - ibis.case() - .when(table.a == 5, _.f) - .when(table.b == 128, _.b * 2) - .when(table.c == 1000, _.e) - .else_(table.d) - .end() - .resolve(table) - ) + expr3 = ibis.cases( + (table.a == 5, _.f), + (table.b == 128, _.b * 2), + (table.c == 1000, _.e), + else_=table.d, + ).resolve(table) # deferred default - expr4 = ( - ibis.case() - .when(table.a == 5, table.f) - .when(table.b == 128, table.b * 2) - .when(table.c == 1000, table.e) - .else_(_.d) - .end() - .resolve(table) + expr4 = ibis.cases( + (table.a == 5, table.f), + (table.b == 128, table.b * 2), + (table.c == 1000, table.e), + else_=_.d, + ).resolve(table) + + assert ( + repr(deferred) + == "cases(((_.a == 5), ), ((_.b == 128), ), ((_.c == 1000), ), else_=)" ) - - assert repr(deferred) == "" assert expr.equals(expr2) assert expr.equals(expr3) assert expr.equals(expr4) @@ -130,13 +124,11 @@ def test_pickle_multiple_case_node(table): result3 = table.e default = table.d - expr = ( - ibis.case() - .when(case1, result1) - .when(case2, result2) - .when(case3, result3) - .else_(default) - .end() + expr = ibis.cases( + (case1, result1), + (case2, result2), + (case3, result3), + else_=default, ) op = expr.op() @@ -144,7 +136,7 @@ def test_pickle_multiple_case_node(table): def test_simple_case_null_else(table): - expr = table.g.case().when("foo", "bar").end() + expr = table.g.cases(("foo", "bar")) op = expr.op() assert isinstance(expr, ir.StringColumn) @@ -154,8 +146,8 @@ def test_simple_case_null_else(table): def test_multiple_case_null_else(table): - expr = ibis.case().when(table.g == "foo", "bar").end() - expr2 = ibis.case().when(table.g == "foo", _).end().resolve("bar") + expr = ibis.cases((table.g == "foo", "bar")) + expr2 = ibis.cases((table.g == "foo", _)).resolve("bar") assert expr.equals(expr2) @@ -172,8 +164,61 @@ def test_case_mixed_type(): name="my_data", ) - expr = ( - t0.three.case().when(0, "low").when(1, "high").else_("null").end().name("label") - ) + expr = t0.three.cases((0, "low"), (1, "high"), else_="null").name("label") result = t0[expr] assert result["label"].type().equals(dt.string) + + +def test_err_on_nonbool(table): + with pytest.raises(SignatureValidationError): + ibis.cases((table.a, "bar"), else_="baz") + + +@pytest.mark.xfail(reason="Literal('foo', type=bool), should error, but doesn't") +def test_err_on_nonbool2(): + with pytest.raises(SignatureValidationError): + ibis.cases(("foo", "bar"), else_="baz") + + +def test_err_on_noncomparable(table): + table.a.cases((8, "bar")) + table.a.cases((-8, "bar")) + # Can't compare an int to a string + with pytest.raises(TypeError): + table.a.cases(("foo", "bar")) + + +def test_empty_cases(table): + ibis.cases() + ibis.cases(else_=42) + table.a.cases() + table.a.cases(else_=42) + + +def test_dtype(): + assert isinstance(ibis.cases((True, "bar"), (False, "bar")), ir.StringValue) + assert isinstance(ibis.cases((True, None), else_="bar"), ir.StringValue) + with pytest.raises(TypeError): + assert ibis.cases((True, 5), (False, "bar")) + with pytest.raises(TypeError): + assert ibis.cases((True, 5), else_="bar") + + +def test_dshape(table): + assert isinstance(ibis.cases((True, "bar"), (False, "bar")), ir.Scalar) + assert isinstance(ibis.cases((True, None), else_="bar"), ir.Scalar) + assert isinstance(ibis.cases((table.b == 9, None), else_="bar"), ir.Column) + assert isinstance(ibis.cases((True, table.a), else_=42), ir.Column) + assert isinstance(ibis.cases((True, 42), else_=table.a), ir.Column) + assert isinstance(ibis.cases((True, table.a), else_=table.b), ir.Column) + + assert isinstance(ibis.literal(5).cases((9, 42)), ir.Scalar) + assert isinstance(ibis.literal(5).cases((9, 42), else_=43), ir.Scalar) + assert isinstance(ibis.literal(5).cases((table.a, 42)), ir.Column) + assert isinstance(ibis.literal(5).cases((9, table.a)), ir.Column) + assert isinstance(ibis.literal(5).cases((table.a, table.b)), ir.Column) + assert isinstance(ibis.literal(5).cases((9, 42), else_=table.a), ir.Column) + assert isinstance(table.a.cases((9, 42)), ir.Column) + assert isinstance(table.a.cases((table.b, 42)), ir.Column) + assert isinstance(table.a.cases((9, table.b)), ir.Column) + assert isinstance(table.a.cases((table.a, table.b)), ir.Column) diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index 5ded2434bfbfc..2e564b833af7e 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -822,23 +822,11 @@ def test_substitute_dict(): subs = {"a": "one", "b": table.bar} result = table.foo.substitute(subs) - expected = ( - ibis.case() - .when(table.foo == "a", "one") - .when(table.foo == "b", table.bar) - .else_(table.foo) - .end() - ) + expected = table.foo.cases(("a", "one"), ("b", table.bar), else_=table.foo) assert_equal(result, expected) result = table.foo.substitute(subs, else_=ibis.NA) - expected = ( - ibis.case() - .when(table.foo == "a", "one") - .when(table.foo == "b", table.bar) - .else_(ibis.NA) - .end() - ) + expected = table.foo.cases(("a", "one"), ("b", table.bar), else_=ibis.NA) assert_equal(result, expected)