diff --git a/docs/posts/ci-analysis/index.qmd b/docs/posts/ci-analysis/index.qmd index 5babc2c6d0c61..65159d5bbe9c7 100644 --- a/docs/posts/ci-analysis/index.qmd +++ b/docs/posts/ci-analysis/index.qmd @@ -203,14 +203,12 @@ Let's also give them some names that'll look nice on our plots. stats = stats.mutate( raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int") ).mutate( - improvements=( - _.raw_improvements.case() - .when(0, "None") - .when(1, "Poetry") - .when(2, "Poetry + Team Plan") - .else_("NA") - .end() - ), + improvements=_.raw_improvements.cases( + (0, "None"), + (1, "Poetry"), + (2, "Poetry + Team Plan"), + else_="NA", + ) team_plan=ibis.where(_.raw_improvements > 1, "Poetry + Team Plan", "None"), ) stats diff --git a/docs/tutorials/ibis-for-sql-users.qmd b/docs/tutorials/ibis-for-sql-users.qmd index 577f7b015111e..537742b876957 100644 --- a/docs/tutorials/ibis-for-sql-users.qmd +++ b/docs/tutorials/ibis-for-sql-users.qmd @@ -473,11 +473,11 @@ semantics: case = ( t.one.cast("timestamp") .year() - .case() - .when(2015, "This year") - .when(2014, "Last year") - .else_("Earlier") - .end() + .cases( + (2015, "This year"), + (2014, "Last year"), + else_="Earlier", + ) ) expr = t.mutate(year_group=case) @@ -496,18 +496,16 @@ CASE END ``` -To do this, use `ibis.case`: +To do this, use `ibis.cases`: ```{python} -case = ( - ibis.case() - .when(t.two < 0, t.three * 2) - .when(t.two > 1, t.three) - .else_(t.two) - .end() +cases = ibis.cases( + (t.two < 0, t.three * 2), + (t.two > 1, t.three), + else_=t.two, ) -expr = t.mutate(cond_value=case) +expr = t.mutate(cond_value=cases) ibis.to_sql(expr) ``` diff --git a/ibis/backends/clickhouse/tests/test_operators.py b/ibis/backends/clickhouse/tests/test_operators.py index 4ca53a3d2b9f3..3ff07ce916a45 100644 --- a/ibis/backends/clickhouse/tests/test_operators.py +++ b/ibis/backends/clickhouse/tests/test_operators.py @@ -201,9 +201,7 @@ def test_ifelse(alltypes, df, op, pandas_op): def test_simple_case(con, alltypes, assert_sql): t = alltypes - expr = ( - t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end() - ) + expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default") assert_sql(expr) assert len(con.execute(expr)) @@ -211,12 +209,10 @@ def test_simple_case(con, alltypes, assert_sql): def test_search_case(con, alltypes, assert_sql): t = alltypes - expr = ( - ibis.case() - .when(t.float_col > 0, t.int_col * 2) - .when(t.float_col < 0, t.int_col) - .else_(0) - .end() + expr = ibis.cases( + (t.float_col > 0, t.int_col * 2), + (t.float_col < 0, t.int_col), + else_=0, ) assert_sql(expr) diff --git a/ibis/backends/dask/tests/test_operations.py b/ibis/backends/dask/tests/test_operations.py index e43e5af454933..cf6bd9a9eb040 100644 --- a/ibis/backends/dask/tests/test_operations.py +++ b/ibis/backends/dask/tests/test_operations.py @@ -773,64 +773,6 @@ def q_fun(x, quantile): tm.assert_series_equal(result, expected, check_index=False) -def test_searched_case_scalar(client): - expr = ibis.case().when(True, 1).when(False, 2).end() - result = client.execute(expr) - expected = np.int8(1) - assert result == expected - - -def test_searched_case_column(batting, batting_pandas_df): - t = batting - df = batting_pandas_df - expr = ( - ibis.case() - .when(t.RBI < 5, "really bad team") - .when(t.teamID == "PH1", "ph1 team") - .else_(t.teamID) - .end() - ) - result = expr.execute() - expected = pd.Series( - np.select( - [df.RBI < 5, df.teamID == "PH1"], - ["really bad team", "ph1 team"], - df.teamID, - ) - ) - tm.assert_series_equal(result, expected, check_names=False) - - -def test_simple_case_scalar(client): - x = ibis.literal(2) - expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end() - result = client.execute(expr) - expected = np.int8(1) - assert result == expected - - -def test_simple_case_column(batting, batting_pandas_df): - t = batting - df = batting_pandas_df - expr = ( - t.RBI.case() - .when(5, "five") - .when(4, "four") - .when(3, "three") - .else_("could be good?") - .end() - ) - result = expr.execute() - expected = pd.Series( - np.select( - [df.RBI == 5, df.RBI == 4, df.RBI == 3], - ["five", "four", "three"], - "could be good?", - ) - ) - tm.assert_series_equal(result, expected, check_names=False) - - def test_table_distinct(t, df): expr = t[["dup_strings"]].distinct() result = expr.compile() diff --git a/ibis/backends/impala/tests/test_case_exprs.py b/ibis/backends/impala/tests/test_case_exprs.py index e23a9436c6fb1..e1a97df344b83 100644 --- a/ibis/backends/impala/tests/test_case_exprs.py +++ b/ibis/backends/impala/tests/test_case_exprs.py @@ -14,13 +14,13 @@ def table(mockcon): @pytest.fixture def simple_case(table): - return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() + return table.g.cases(("foo", "bar"), ("baz", "qux"), else_="default") @pytest.fixture def search_case(table): t = table - return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end() + return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2)) @pytest.fixture diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index 858f49173464b..937b7d3ed9554 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -167,6 +167,8 @@ def visit(cls, op: ops.IsNan, arg): def visit( cls, op: ops.SearchedCase | ops.SimpleCase, cases, results, default, base=None ): + if not cases: + return default if base is not None: cases = tuple(base == case for case in cases) cases, _ = cls.asframe(cases, concat=False) diff --git a/ibis/backends/pandas/tests/test_operations.py b/ibis/backends/pandas/tests/test_operations.py index 293fc008a50e9..8fa1338132083 100644 --- a/ibis/backends/pandas/tests/test_operations.py +++ b/ibis/backends/pandas/tests/test_operations.py @@ -683,73 +683,9 @@ def test_summary_non_numeric(batting, batting_df): assert dict(result.iloc[0]) == expected -def test_searched_case_scalar(client): - expr = ibis.case().when(True, 1).when(False, 2).end() - result = client.execute(expr) - expected = np.int8(1) - assert result == expected - - -def test_searched_case_column(batting, batting_df): - t = batting - df = batting_df - expr = ( - ibis.case() - .when(t.RBI < 5, "really bad team") - .when(t.teamID == "PH1", "ph1 team") - .else_(t.teamID) - .end() - ) - result = expr.execute() - expected = pd.Series( - np.select( - [df.RBI < 5, df.teamID == "PH1"], - ["really bad team", "ph1 team"], - df.teamID, - ) - ) - tm.assert_series_equal(result, expected) - - -def test_simple_case_scalar(client): - x = ibis.literal(2) - expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end() - result = client.execute(expr) - expected = np.int8(1) - assert result == expected - - -def test_simple_case_column(batting, batting_df): - t = batting - df = batting_df - expr = ( - t.RBI.case() - .when(5, "five") - .when(4, "four") - .when(3, "three") - .else_("could be good?") - .end() - ) - result = expr.execute() - expected = pd.Series( - np.select( - [df.RBI == 5, df.RBI == 4, df.RBI == 3], - ["five", "four", "three"], - "could be good?", - ) - ) - tm.assert_series_equal(result, expected) - - def test_non_range_index(): def do_replace(col): - return col.cases( - ( - (1, "one"), - (2, "two"), - ), - default="unk", - ) + return col.cases((1, "one"), (2, "two"), else_="unk") df = pd.DataFrame( { diff --git a/ibis/backends/snowflake/tests/test_udf.py b/ibis/backends/snowflake/tests/test_udf.py index 9d30ae92b6864..0e34b7df2ff2e 100644 --- a/ibis/backends/snowflake/tests/test_udf.py +++ b/ibis/backends/snowflake/tests/test_udf.py @@ -8,7 +8,6 @@ import pytest from pytest import param -import ibis import ibis.expr.datatypes as dt from ibis import udf @@ -115,36 +114,23 @@ def predict_price( df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"] return model.predict(df) - def cases(value, mapping): - """This should really be a top-level function or method.""" - expr = ibis.case() - for k, v in mapping.items(): - expr = expr.when(value == k, v) - return expr.end() - diamonds = con.tables.DIAMONDS expr = diamonds.mutate( predicted_price=predict_price( (_.carat - _.carat.mean()) / _.carat.std(), - cases( - _.cut, - { - c: i - for i, c in enumerate( - ("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1 - ) - }, + _.cut.cases( + (c, i) + for i, c in enumerate( + ("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1 + ) ), - cases(_.color, {c: i for i, c in enumerate("DEFGHIJ", start=1)}), - cases( - _.clarity, - { - c: i - for i, c in enumerate( - ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"), - start=1, - ) - }, + _.color.cases((c, i) for i, c in enumerate("DEFGHIJ", start=1)), + _.clarity.cases( + (c, i) + for i, c in enumerate( + ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"), + start=1, + ) ), ) ) diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index ced292f2e107f..eb1b8725e8c78 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -935,6 +935,8 @@ def visit_VarianceStandardDevCovariance(self, op, *, how, where, **kw): ) def visit_SimpleCase(self, op, *, base=None, cases, results, default): + if not cases: + return default return sge.Case( this=base, ifs=list(map(self.if_, cases, results)), default=default ) diff --git a/ibis/backends/tests/sql/conftest.py b/ibis/backends/tests/sql/conftest.py index 04667e60e033b..06de1c83c8c08 100644 --- a/ibis/backends/tests/sql/conftest.py +++ b/ibis/backends/tests/sql/conftest.py @@ -164,13 +164,13 @@ def difference(con): @pytest.fixture(scope="module") def simple_case(con): t = con.table("alltypes") - return t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() + return t.g.cases(("foo", "bar"), ("baz", "qux"), else_="default") @pytest.fixture(scope="module") def search_case(con): t = con.table("alltypes") - return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end() + return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2)) @pytest.fixture(scope="module") diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py index 6058efaa962e6..1b1dcf62dca62 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py @@ -22,18 +22,14 @@ lit2 = ibis.literal("bar") result = alltypes.select( - alltypes.g.case() - .when(lit, lit2) - .when(lit1, ibis.literal("qux")) - .else_(ibis.literal("default")) - .end() - .name("col1"), - ibis.case() - .when(alltypes.g == lit, lit2) - .when(alltypes.g == lit1, alltypes.g) - .else_(ibis.literal(None).cast("string")) - .end() - .name("col2"), + alltypes.g.cases( + (lit, lit2), (lit1, ibis.literal("qux")), else_=ibis.literal("default") + ).name("col1"), + ibis.cases( + (alltypes.g == lit, lit2), + (alltypes.g == lit1, alltypes.g), + else_=ibis.literal(None).cast("string"), + ).name("col2"), alltypes.a, alltypes.b, alltypes.c, diff --git a/ibis/backends/tests/sql/test_select_sql.py b/ibis/backends/tests/sql/test_select_sql.py index 94a52017f763f..24893739fb6eb 100644 --- a/ibis/backends/tests/sql/test_select_sql.py +++ b/ibis/backends/tests/sql/test_select_sql.py @@ -397,8 +397,8 @@ def test_bool_bool(snapshot): def test_case_in_projection(alltypes, snapshot): t = alltypes - expr = t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() - expr2 = ibis.case().when(t.g == "foo", "bar").when(t.g == "baz", t.g).end() + expr = t.g.cases(("foo", "bar"), ("baz", "qux"), else_=("default")) + expr2 = ibis.cases((t.g == "foo", "bar"), (t.g == "baz", t.g)) expr = t[expr.name("col1"), expr2.name("col2"), t] snapshot.assert_match(to_sql(expr), "out.sql") diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index c78508d9d0e61..bef4107c0f1ea 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -680,7 +680,7 @@ def test_arbitrary(backend, alltypes, df, filtered): # _something_ we create a column that is a mix of nulls and a single value # (or a single value after filtering is applied). if filtered: - new = alltypes.int_col.cases([(3, 30), (4, 40)]) + new = alltypes.int_col.cases((3, 30), (4, 40)) where = _.int_col == 3 else: new = (alltypes.int_col == 3).ifelse(30, None) @@ -1428,9 +1428,7 @@ def collect_udf(v): def test_binds_are_cast(alltypes): expr = alltypes.aggregate( - high_line_count=( - alltypes.string_col.case().when("1-URGENT", 1).else_(0).end().sum() - ) + high_line_count=alltypes.string_col.cases(("1-URGENT", 1), else_=0).sum() ) expr.execute() @@ -1476,7 +1474,7 @@ def test_agg_name_in_output_column(alltypes): def test_grouped_case(backend, con): table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]}) - case_expr = ibis.case().when(table.value < 25, table.value).else_(ibis.null()).end() + case_expr = ibis.cases((table.value < 25, table.value), else_=ibis.null()) expr = ( table.group_by(k="key").aggregate(mx=case_expr.max()).dropna("k").order_by("k") diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 23123e4736778..56b7d71786622 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -18,6 +18,7 @@ import ibis.selectors as s from ibis import _ from ibis.backends.conftest import is_newer_than, is_older_than +from ibis.backends.pandas.tests.conftest import TestConf as tm from ibis.backends.tests.errors import ( ClickHouseDatabaseError, ExaQueryError, @@ -356,12 +357,11 @@ def test_case_where(backend, alltypes, df): table = alltypes table = table.mutate( new_col=( - ibis.case() - .when(table["int_col"] == 1, 20) - .when(table["int_col"] == 0, 10) - .else_(0) - .end() - .cast("int64") + ibis.cases( + (table["int_col"] == 1, 20), + (table["int_col"] == 0, 10), + else_=0, + ).cast("int64") ) ) @@ -394,9 +394,7 @@ def test_select_filter_mutate(backend, alltypes, df): # Prepare the float_col so that filter must execute # before the cast to get the correct result. - t = t.mutate( - float_col=ibis.case().when(t["bool_col"], t["float_col"]).else_(np.nan).end() - ) + t = t.mutate(float_col=ibis.cases((t["bool_col"], t["float_col"]), else_=np.nan)) # Actual test t = t[t.columns] @@ -2025,7 +2023,32 @@ def test_sample_with_seed(backend): ), ], ) -def test_value_cases(con, inp, exp): +def test_value_cases_deprecated(con, inp, exp): + with pytest.warns(FutureWarning): + i = inp() + result = con.execute(i) + if exp is None: + assert pd.isna(result) + else: + assert result == exp + + +@pytest.mark.parametrize( + "inp, exp", + [ + pytest.param( + lambda: ibis.literal(1).cases((1, "one"), (2, "two"), else_="other"), + "one", + id="one_kwarg", + ), + pytest.param( + lambda: ibis.literal(5).cases((1, "one"), (2, "two")), + None, + id="fallthrough", + ), + ], +) +def test_value_cases_scalar(con, inp, exp): result = con.execute(inp()) if exp is None: assert pd.isna(result) @@ -2033,6 +2056,46 @@ def test_value_cases(con, inp, exp): assert result == exp +def test_value_cases_column(batting): + df = batting.to_pandas() + expr = batting.RBI.cases( + (5, "five"), (4, "four"), (3, "three"), else_="could be good?" + ) + result = expr.execute() + expected = pd.Series( + np.select( + [df.RBI == 5, df.RBI == 4, df.RBI == 3], + ["five", "four", "three"], + "could be good?", + ) + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.broken( + "sqlite", + reason="the int64 RBI column is .to_pandas()ed to an object column, which is incomparable to 5", + raises=TypeError, +) +def test_ibis_cases_column(batting): + t = batting + df = batting.to_pandas() + expr = ibis.cases( + (t.RBI < 5, "really bad team"), + (t.teamID == "PH1", "ph1 team"), + else_=t.teamID, + ) + result = expr.execute() + expected = pd.Series( + np.select( + [df.RBI < 5, df.teamID == "PH1"], + ["really bad team", "ph1 team"], + df.teamID, + ) + ) + tm.assert_series_equal(result, expected) + + def test_substitute(backend): val = "400" t = backend.functional_alltypes @@ -2046,6 +2109,37 @@ def test_substitute(backend): assert expr["subs_count"].execute()[0] == t.count().execute() // 10 +@pytest.mark.broken("flink", reason="can't handle SELECT NULL AS `SearchedCase(None)`") +def test_ibis_cases_empty(con): + assert pd.isnull(con.execute(ibis.cases())) + assert con.execute(ibis.cases(else_=2)) == 2 + assert pd.isnull(con.execute(ibis.literal(1).cases())) + assert con.execute(ibis.literal(1).cases(else_=2)) == 2 + + +@pytest.mark.broken("clickhouse", reason="special case this and returns 'oops'") +def test_value_cases_null(con): + """CASE x WHEN NULL never gets hit""" + e = ibis.literal(5).nullif(5).cases((None, "oops"), else_="expected") + assert con.execute(e) == "expected" + + +@pytest.mark.broken("pyspark", reason="raises a ResourceWarning that we can't catch") +def test_case(con): + # just to make sure that the deprecated .case() method still works + with pytest.warns(FutureWarning, match=".cases"): + assert con.execute(ibis.case().when(True, "yes").end()) == "yes" + assert pd.isna(con.execute(ibis.case().when(False, "yes").end())) + assert con.execute(ibis.case().when(False, "yes").else_("no").end()) == "no" + + assert con.execute(ibis.literal("a").case().when("a", "yes").end()) == "yes" + assert pd.isna(con.execute(ibis.literal("a").case().when("b", "yes").end())) + assert ( + con.execute(ibis.literal("a").case().when("b", "yes").else_("no").end()) + == "no" + ) + + @pytest.mark.notimpl( ["dask", "pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend" ) diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 097bbb9cb4577..8f78d3560f899 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -65,16 +65,16 @@ def test_group_by_has_index(backend, snapshot): ) expr = countries.group_by( cont=( - _.continent.case() - .when("NA", "North America") - .when("SA", "South America") - .when("EU", "Europe") - .when("AF", "Africa") - .when("AS", "Asia") - .when("OC", "Oceania") - .when("AN", "Antarctica") - .else_("Unknown continent") - .end() + _.continent.cases( + ("NA", "North America"), + ("SA", "South America"), + ("EU", "Europe"), + ("AF", "Africa"), + ("AS", "Asia"), + ("OC", "Oceania"), + ("AN", "Antarctica"), + else_="Unknown continent", + ) ) ).agg(total_pop=_.population.sum()) sql = str(ibis.to_sql(expr, dialect=backend.name())) diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 595129cb9091a..17d6cfae496d3 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -508,9 +508,9 @@ def uses_java_re(t): id="length", ), param( - lambda t: t.int_col.cases([(1, "abcd"), (2, "ABCD")], "dabc").startswith( - "abc" - ), + lambda t: t.int_col.cases( + (1, "abcd"), (2, "ABCD"), else_="dabc" + ).startswith("abc"), lambda t: t.int_col == 1, id="startswith", marks=[ @@ -518,7 +518,7 @@ def uses_java_re(t): ], ), param( - lambda t: t.int_col.cases([(1, "abcd"), (2, "ABCD")], "dabc").endswith( + lambda t: t.int_col.cases((1, "abcd"), (2, "ABCD"), else_="dabc").endswith( "bcd" ), lambda t: t.int_col == 1, @@ -694,11 +694,9 @@ def test_re_replace_global(con): @pytest.mark.notimpl(["druid"], raises=ValidationError) def test_substr_with_null_values(backend, alltypes, df): table = alltypes.mutate( - substr_col_null=ibis.case() - .when(alltypes["bool_col"], alltypes["string_col"]) - .else_(None) - .end() - .substr(0, 2) + substr_col_null=ibis.cases( + (alltypes["bool_col"], alltypes["string_col"]), else_=None + ).substr(0, 2) ) result = table.execute() @@ -919,7 +917,7 @@ def test_levenshtein(con, right): @pytest.mark.parametrize( "expr", [ - param(ibis.case().when(True, "%").end(), id="case"), + param(ibis.cases((True, "%")), id="case"), param(ibis.ifelse(True, "%", ibis.NA), id="ifelse"), ], ) diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index e7f078f7591a1..2101dff274779 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -149,7 +149,7 @@ def test_collect_into_struct(alltypes): @pytest.mark.notimpl(["flink"], raises=Py4JJavaError, reason="not implemented in ibis") def test_field_access_after_case(con): s = ibis.struct({"a": 3}) - x = ibis.case().when(True, s).else_(ibis.struct({"a": 4})).end() + x = ibis.cases((True, s), else_=ibis.struct({"a": 4})) y = x.a assert con.to_pandas(y) == 3 diff --git a/ibis/backends/tests/tpch/test_h08.py b/ibis/backends/tests/tpch/test_h08.py index 651a725b1fc8a..5ffa2f2c53101 100644 --- a/ibis/backends/tests/tpch/test_h08.py +++ b/ibis/backends/tests/tpch/test_h08.py @@ -42,9 +42,7 @@ def test_tpc_h08(part, supplier, region, lineitem, orders, customer, nation): ] ) - q = q.mutate( - nation_volume=ibis.case().when(q.nation == NATION, q.volume).else_(0).end() - ) + q = q.mutate(nation_volume=ibis.cases((q.nation == NATION, q.volume), else_=0)) gq = q.group_by([q.o_year]) q = gq.aggregate(mkt_share=q.nation_volume.sum() / q.volume.sum()) q = q.order_by([q.o_year]) diff --git a/ibis/backends/tests/tpch/test_h12.py b/ibis/backends/tests/tpch/test_h12.py index 0fb092beb8e31..e56efb9507f03 100644 --- a/ibis/backends/tests/tpch/test_h12.py +++ b/ibis/backends/tests/tpch/test_h12.py @@ -32,18 +32,10 @@ def test_tpc_h12(orders, lineitem): gq = q.group_by([q.l_shipmode]) q = gq.aggregate( high_line_count=( - q.o_orderpriority.case() - .when("1-URGENT", 1) - .when("2-HIGH", 1) - .else_(0) - .end() + q.o_orderpriority.cases(("1-URGENT", 1), ("2-HIGH", 1), else_=0) ).sum(), low_line_count=( - q.o_orderpriority.case() - .when("1-URGENT", 0) - .when("2-HIGH", 0) - .else_(1) - .end() + q.o_orderpriority.cases(("1-URGENT", 0), ("2-HIGH", 0), else_=1) ).sum(), ) q = q.order_by(q.l_shipmode) diff --git a/ibis/expr/api.py b/ibis/expr/api.py index 2dd766ac2bc8b..02a0b4695164e 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -68,6 +68,7 @@ "array", "asc", "case", + "cases", "coalesce", "connect", "cross_join", @@ -1105,56 +1106,76 @@ def interval( return functools.reduce(operator.add, intervals) +@util.deprecated(instead="use ibis.cases() instead", as_of="9.1") def case() -> bl.SearchedCaseBuilder: - """Begin constructing a case expression. + """DEPRECATED: Use `ibis.cases()` instead.""" + return bl.SearchedCaseBuilder() + + +@deferrable +def cases(*branches: tuple[Any, Any], else_: Any | None = None) -> ir.Value: + """Create a multi-branch if-else expression. - Use the `.when` method on the resulting object followed by `.end` to create a - complete case expression. + Goes through each (condition, value) pair in `branches`, finding the + first condition that evaluates to True, and returns the corresponding + value. If no condition is True, returns `else_`. Returns ------- - SearchedCaseBuilder - A builder object to use for constructing a case expression. + Value + A value expression See Also -------- - [`Value.case()`](./expression-generic.qmd#ibis.expr.types.generic.Value.case) + [`Value.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) Examples -------- >>> import ibis - >>> from ibis import _ >>> ibis.options.interactive = True - >>> t = ibis.memtable( - ... { - ... "left": [1, 2, 3, 4], - ... "symbol": ["+", "-", "*", "/"], - ... "right": [5, 6, 7, 8], - ... } - ... ) - >>> t.mutate( - ... result=( - ... ibis.case() - ... .when(_.symbol == "+", _.left + _.right) - ... .when(_.symbol == "-", _.left - _.right) - ... .when(_.symbol == "*", _.left * _.right) - ... .when(_.symbol == "/", _.left / _.right) - ... .end() - ... ) - ... ) - ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ - ┃ left ┃ symbol ┃ right ┃ result ┃ - ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ - │ int64 │ string │ int64 │ float64 │ - ├───────┼────────┼───────┼─────────┤ - │ 1 │ + │ 5 │ 6.0 │ - │ 2 │ - │ 6 │ -4.0 │ - │ 3 │ * │ 7 │ 21.0 │ - │ 4 │ / │ 8 │ 0.5 │ - └───────┴────────┴───────┴─────────┘ - + >>> v = ibis.memtable({"values": [1, 2, 1, 2, 3, 2, 4]}).values + >>> ibis.cases((v == 1, "a"), (v > 2, "b"), else_="unk").name("cases") + ┏━━━━━━━━┓ + ┃ cases ┃ + ┡━━━━━━━━┩ + │ string │ + ├────────┤ + │ a │ + │ unk │ + │ a │ + │ unk │ + │ b │ + │ unk │ + │ b │ + └────────┘ + >>> ibis.cases( + ... (v % 2 == 0, "divisible by 2"), + ... (v % 3 == 0, "divisible by 3"), + ... (v % 4 == 0, "shadowed by the 2 case"), + ... ).name("cases") + ┏━━━━━━━━━━━━━━━━┓ + ┃ cases ┃ + ┡━━━━━━━━━━━━━━━━┩ + │ string │ + ├────────────────┤ + │ NULL │ + │ divisible by 2 │ + │ NULL │ + │ divisible by 2 │ + │ divisible by 3 │ + │ divisible by 2 │ + │ divisible by 2 │ + └────────────────┘ """ - return bl.SearchedCaseBuilder() + for b in branches: + try: + condition, result = b + except (TypeError, ValueError) as e: + raise ValueError( + "Each branch must be a tuple of (condition, result)" + ) from e + cases, results = zip(*branches) if branches else ([], []) + return ops.SearchedCase(cases=cases, results=results, default=else_).to_expr() def now() -> ir.TimestampScalar: diff --git a/ibis/expr/decompile.py b/ibis/expr/decompile.py index c7ea1015975a6..1ed2cd769972a 100644 --- a/ibis/expr/decompile.py +++ b/ibis/expr/decompile.py @@ -298,16 +298,12 @@ def ifelse(op, bool_expr, true_expr, false_null_expr): @translate.register(ops.SimpleCase) @translate.register(ops.SearchedCase) -def switch_case(op, cases, results, default, base=None): - out = f"{base}.case()" if base else "ibis.case()" - - for case, result in zip(cases, results): - out = f"{out}.when({case}, {result})" - - if default is not None: - out = f"{out}.else_({default})" - - return f"{out}.end()" +def switch_cases(op, cases, results, default, base=None): + namespace = f"{base}" if base else "ibis" + case_strs = [f"({case}, {result})" for case, result in zip(cases, results)] + cases_str = ", ".join(case_strs) + else_str = f", else_={default}" if default is not None else "" + return f"{namespace}.cases({cases_str}{else_str})" _infix_ops = { diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index a33abd28bfae8..226baa6c269d4 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -275,11 +275,24 @@ class SimpleCase(Value): results: VarTuple[Value] default: Value - shape = rlz.shape_like("base") - - def __init__(self, cases, results, **kwargs): + def __init__(self, base, cases, results, default): assert len(cases) == len(results) - super().__init__(cases=cases, results=results, **kwargs) + + for case in cases: + if not rlz.comparable(base, case): + raise TypeError( + f"Base expression {rlz._arg_type_error_format(base)} and " + f"case {rlz._arg_type_error_format(case)} are not comparable" + ) + + if default.dtype.is_null() and results: + default = Cast(default, rlz.highest_precedence_dtype(results)) + super().__init__(base=base, cases=cases, results=results, default=default) + + @attribute + def shape(self): + exprs = [self.base, *self.cases, *self.results, self.default] + return rlz.highest_precedence_shape(exprs) @attribute def dtype(self): @@ -295,14 +308,14 @@ class SearchedCase(Value): def __init__(self, cases, results, default): assert len(cases) == len(results) - if default.dtype.is_null(): + if default.dtype.is_null() and results: default = Cast(default, rlz.highest_precedence_dtype(results)) super().__init__(cases=cases, results=results, default=default) @attribute def shape(self): - # TODO(kszucs): can be removed after making Sequence iterable - return rlz.highest_precedence_shape(self.cases) + exprs = [*self.cases, *self.results, self.default] + return rlz.highest_precedence_shape(exprs) @attribute def dtype(self): diff --git a/ibis/expr/operations/logical.py b/ibis/expr/operations/logical.py index 78a33de77e1cc..af99b99668582 100644 --- a/ibis/expr/operations/logical.py +++ b/ibis/expr/operations/logical.py @@ -137,9 +137,9 @@ def shape(self): @public class IfElse(Value): - """Ternary case expression, equivalent to. + """Ternary case expression. - bool_expr.case().when(True, true_expr).else_(false_or_null_expr) + Equivalent to bool_expr.cases((True, true_expr), else_=false_or_null_expr) Many backends implement this as a built-in function. """ diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 4a554e17c2739..f18547d9c90e7 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1,6 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +import warnings +from collections.abc import Sequence from typing import TYPE_CHECKING, Any from public import public @@ -10,6 +11,7 @@ import ibis.expr.builders as bl import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis import util from ibis.common.deferred import Deferred, _, deferrable from ibis.common.grounds import Singleton from ibis.expr.rewrites import rewrite_window_input @@ -717,19 +719,16 @@ def substitute( └────────┴──────────────┘ """ if isinstance(value, dict): - expr = ibis.case() - try: - null_replacement = value.pop(None) - except KeyError: - pass - else: - expr = expr.when(self.isnull(), null_replacement) - for k, v in value.items(): - expr = expr.when(self == k, v) + branches = sorted(value.items()) else: - expr = self.case().when(value, replacement) - - return expr.else_(else_ if else_ is not None else self).end() + branches = [(value, replacement)] + nulls = [(k, v) for k, v in branches if k is None] + nonnulls = [(k, v) for k, v in branches if k is not None] + if nulls: + null_replacement = nulls[0][1] + self = self.fillna(null_replacement) + else_ = else_ if else_ is not None else self + return self.cases(*nonnulls, else_=else_) def over( self, @@ -865,99 +864,80 @@ def notnull(self) -> ir.BooleanValue: """ return ops.NotNull(self).to_expr() + @util.deprecated(instead="use Value.cases() instead", as_of="9.1") def case(self) -> bl.SimpleCaseBuilder: - """Create a SimpleCaseBuilder to chain multiple if-else statements. - - Add new search expressions with the `.when()` method. These must be - comparable with this column expression. Conclude by calling `.end()`. - - Returns - ------- - SimpleCaseBuilder - A case builder - - See Also - -------- - [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) - [`ibis.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) - [`ibis.case()`](./expression-generic.qmd#ibis.case) + """DEPRECATED: Use `self.cases()` instead.""" + return bl.SimpleCaseBuilder(self.op()) - Examples - -------- - >>> import ibis - >>> ibis.options.interactive = True - >>> x = ibis.examples.penguins.fetch().head(5)["sex"] - >>> x - ┏━━━━━━━━┓ - ┃ sex ┃ - ┡━━━━━━━━┩ - │ string │ - ├────────┤ - │ male │ - │ female │ - │ female │ - │ NULL │ - │ female │ - └────────┘ - >>> x.case().when("male", "M").when("female", "F").else_("U").end() - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, 'U') ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├──────────────────────┤ - │ M │ - │ F │ - │ F │ - │ U │ - │ F │ - └──────────────────────┘ - - Cases not given result in the ELSE case - - >>> x.case().when("male", "M").else_("OTHER").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, 'OTHER') ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├──────────────────────────┤ - │ M │ - │ OTHER │ - │ OTHER │ - │ OTHER │ - │ OTHER │ - └──────────────────────────┘ - - If you don't supply an ELSE, then NULL is used - - >>> x.case().when("male", "M").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, Cast(None, string)) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├─────────────────────────────────────┤ - │ M │ - │ NULL │ - │ NULL │ - │ NULL │ - │ NULL │ - └─────────────────────────────────────┘ - """ - import ibis.expr.builders as bl + @staticmethod + def _norm_cases_args(*args, **kwargs): + # TODO: remove in v10.0 once we have a deprecation cycle + # before, the API for Value.cases() was + # def cases( + # self, + # case_result_pairs: Iterable[tuple[Value, Value]], + # default: Value | None = None, + # ) -> Value: + # Now it is + # def cases( + # self, + # *branches: tuple[Value, Value], + # else_: Value | None = None, + # ) -> Value: + # This method normalizes the arguments to the new API. + using_old_api = False + branches = [] + else_ = None + if len(args) >= 1: + first_arg = args[0] + first_arg = util.promote_list(first_arg) + if len(first_arg) > 0 and isinstance(first_arg[0], tuple): + # called as .cases([(test, result), ...], ) + using_old_api = True + branches = first_arg + else_ = args[1] if len(args) == 2 else None + else: + # called as .cases((test, result), ...) + branches = list(args) + + if "case_result_pairs" in kwargs: + using_old_api = True + branches = list(kwargs["case_result_pairs"]) + elif "branches" in kwargs: + branches = list(kwargs["branches"]) + + if "default" in kwargs: + using_old_api = True + else_ = kwargs["default"] + elif "else_" in kwargs: + else_ = kwargs["else_"] + + if using_old_api: + warnings.warn( + "You are using the old API for `cases()`. Please see" + " https://ibis-project.org/reference/expression-generic" + " on how to upgrade to the new API.", + FutureWarning, + ) + return branches, else_ - return bl.SimpleCaseBuilder(self.op()) + def cases(self, *args, **kwargs) -> Value: # noqa: D417 + """Create a multi-branch if-else expression. - def cases( - self, - case_result_pairs: Iterable[tuple[ir.BooleanValue, Value]], - default: Value | None = None, - ) -> Value: - """Create a case expression in one shot. + This is semantically equivalent to + CASE self + WHEN test_val0 THEN result0 + WHEN test_val1 THEN result1 + ELSE else_ + END Parameters ---------- - case_result_pairs - Conditional-result pairs - default + branches + (test_val, result) pairs. We look through the test values in order + and return the result corresponding to the first test value that + matches `self`. If none match, we return `else_`. + else_ Value to return if none of the case conditions are true Returns @@ -968,48 +948,64 @@ def cases( See Also -------- [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) - [`ibis.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) - [`ibis.case()`](./expression-generic.qmd#ibis.case) + [`ibis.cases()`](./expression-generic.qmd#ibis.cases) Examples -------- >>> import ibis >>> ibis.options.interactive = True - >>> t = ibis.memtable({"values": [1, 2, 1, 2, 3, 2, 4]}) - >>> t - ┏━━━━━━━━┓ - ┃ values ┃ - ┡━━━━━━━━┩ - │ int64 │ - ├────────┤ - │ 1 │ - │ 2 │ - │ 1 │ - │ 2 │ - │ 3 │ - │ 2 │ - │ 4 │ - └────────┘ - >>> number_letter_map = ((1, "a"), (2, "b"), (3, "c")) - >>> t.values.cases(number_letter_map, default="unk").name("replace") - ┏━━━━━━━━━┓ - ┃ replace ┃ - ┡━━━━━━━━━┩ - │ string │ - ├─────────┤ - │ a │ - │ b │ - │ a │ - │ b │ - │ c │ - │ b │ - │ unk │ - └─────────┘ + >>> t = ibis.memtable( + ... { + ... "left": [5, 6, 7, 8, 9, 10], + ... "symbol": ["+", "-", "*", "/", "bogus", None], + ... "right": [1, 2, 3, 4, 5, 6], + ... } + ... ) + + Note we never hit the `None` case, because `x = NULL` is always NULL, + which is not truthy. If you want to replace NULLs, you should use + `.fillna(-999)` prior to `cases()`. + + >>> t.mutate( + ... result=( + ... t.symbol.cases( + ... ("+", t.left + t.right), + ... ("-", t.left - t.right), + ... ("*", t.left * t.right), + ... ("/", t.left / t.right), + ... (None, -999), + ... ) + ... ) + ... ) + ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ + ┃ left ┃ symbol ┃ right ┃ result ┃ + ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ + │ int64 │ string │ int64 │ float64 │ + ├───────┼────────┼───────┼─────────┤ + │ 5 │ + │ 1 │ 6.0 │ + │ 6 │ - │ 2 │ 4.0 │ + │ 7 │ * │ 3 │ 21.0 │ + │ 8 │ / │ 4 │ 2.0 │ + │ 9 │ bogus │ 5 │ NULL │ + │ 10 │ NULL │ 6 │ NULL │ + └───────┴────────┴───────┴─────────┘ """ - builder = self.case() - for case, result in case_result_pairs: - builder = builder.when(case, result) - return builder.else_(default).end() + branches, else_ = self._norm_cases_args(*args, **kwargs) + + branches2 = [] + for b in branches: + b = tuple(b) + try: + test, result = b + except (TypeError, ValueError) as e: + raise ValueError( + f"Each branch must be a tuple of (condition, result), got {b}" + ) from e + branches2.append(b) + cases, results = zip(*branches2) if branches2 else ([], []) + return ops.SimpleCase( + base=self, cases=cases, results=results, default=else_ + ).to_expr() def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar: """Aggregate this expression's elements into an array. diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index 56d5c2dea178c..b7396c576b833 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools from typing import TYPE_CHECKING, Literal from public import public @@ -1143,13 +1142,7 @@ def label(self, labels: Iterable[str], nulls: str | None = None) -> ir.StringVal │ 2 │ c │ └───────┴─────────┘ """ - return ( - functools.reduce( - lambda stmt, inputs: stmt.when(*inputs), enumerate(labels), self.case() - ) - .else_(nulls) - .end() - ) + return self.cases(*enumerate(labels), else_=nulls) @public diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 387a8a60a1d79..7ceb834cae6d6 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -2813,9 +2813,7 @@ def info(self) -> Table: for pos, colname in enumerate(self.columns): col = self[colname] typ = col.type() - agg = self.select( - isna=ibis.case().when(col.isnull(), 1).else_(0).end() - ).agg( + agg = self.select(isna=ibis.cases((col.isnull(), 1), else_=0)).agg( name=lit(colname), type=lit(str(typ)), nullable=lit(typ.nullable), diff --git a/ibis/tests/expr/test_case.py b/ibis/tests/expr/test_case.py index 89e5b3b7df7b4..6444b2aace813 100644 --- a/ibis/tests/expr/test_case.py +++ b/ibis/tests/expr/test_case.py @@ -1,11 +1,14 @@ from __future__ import annotations +import pytest + import ibis import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir from ibis import _ -from ibis.tests.util import assert_equal, assert_pickle_roundtrip +from ibis.common.annotations import SignatureValidationError +from ibis.tests.util import assert_pickle_roundtrip def test_ifelse_method(table): @@ -44,72 +47,63 @@ def test_ifelse_function_deferred(table): assert res.equals(sol) -def test_simple_case_expr(table): - case1, result1 = "foo", table.a - case2, result2 = "bar", table.c - default_result = table.b - - expr1 = table.g.lower().cases( - [(case1, result1), (case2, result2)], default=default_result - ) - - expr2 = ( - table.g.lower() - .case() - .when(case1, result1) - .when(case2, result2) - .else_(default_result) - .end() - ) +def test_err_on_bad_args(table): + with pytest.raises(ValueError): + ibis.cases((True,)) + with pytest.raises(ValueError): + ibis.cases((True, 3, 4)) + with pytest.raises(ValueError): + ibis.cases((True, 3, 4)) + with pytest.raises(ValueError): + ibis.cases((True, 3), 5) - assert_equal(expr1, expr2) - assert isinstance(expr1, ir.IntegerColumn) + with pytest.raises(ValueError): + table.a.cases(("foo",)) + with pytest.raises(ValueError): + table.a.cases(("foo", 3, 4)) + with pytest.raises(ValueError): + table.a.cases(("foo", 3, 4)) + with pytest.raises(TypeError): + table.a.cases(("foo", 3), 5) def test_multiple_case_expr(table): - expr = ( - ibis.case() - .when(table.a == 5, table.f) - .when(table.b == 128, table.b * 2) - .when(table.c == 1000, table.e) - .else_(table.d) - .end() + expr = ibis.cases( + (table.a == 5, table.f), + (table.b == 128, table.b * 2), + (table.c == 1000, table.e), + else_=table.d, ) # deferred cases - deferred = ( - ibis.case() - .when(_.a == 5, table.f) - .when(_.b == 128, table.b * 2) - .when(_.c == 1000, table.e) - .else_(table.d) - .end() + deferred = ibis.cases( + (_.a == 5, table.f), + (_.b == 128, table.b * 2), + (_.c == 1000, table.e), + else_=table.d, ) expr2 = deferred.resolve(table) # deferred results - expr3 = ( - ibis.case() - .when(table.a == 5, _.f) - .when(table.b == 128, _.b * 2) - .when(table.c == 1000, _.e) - .else_(table.d) - .end() - .resolve(table) - ) + expr3 = ibis.cases( + (table.a == 5, _.f), + (table.b == 128, _.b * 2), + (table.c == 1000, _.e), + else_=table.d, + ).resolve(table) # deferred default - expr4 = ( - ibis.case() - .when(table.a == 5, table.f) - .when(table.b == 128, table.b * 2) - .when(table.c == 1000, table.e) - .else_(_.d) - .end() - .resolve(table) + expr4 = ibis.cases( + (table.a == 5, table.f), + (table.b == 128, table.b * 2), + (table.c == 1000, table.e), + else_=_.d, + ).resolve(table) + + assert ( + repr(deferred) + == "cases(((_.a == 5), ), ((_.b == 128), ), ((_.c == 1000), ), else_=)" ) - - assert repr(deferred) == "" assert expr.equals(expr2) assert expr.equals(expr3) assert expr.equals(expr4) @@ -130,13 +124,11 @@ def test_pickle_multiple_case_node(table): result3 = table.e default = table.d - expr = ( - ibis.case() - .when(case1, result1) - .when(case2, result2) - .when(case3, result3) - .else_(default) - .end() + expr = ibis.cases( + (case1, result1), + (case2, result2), + (case3, result3), + else_=default, ) op = expr.op() @@ -144,7 +136,7 @@ def test_pickle_multiple_case_node(table): def test_simple_case_null_else(table): - expr = table.g.case().when("foo", "bar").end() + expr = table.g.cases(("foo", "bar")) op = expr.op() assert isinstance(expr, ir.StringColumn) @@ -154,8 +146,8 @@ def test_simple_case_null_else(table): def test_multiple_case_null_else(table): - expr = ibis.case().when(table.g == "foo", "bar").end() - expr2 = ibis.case().when(table.g == "foo", _).end().resolve("bar") + expr = ibis.cases((table.g == "foo", "bar")) + expr2 = ibis.cases((table.g == "foo", _)).resolve("bar") assert expr.equals(expr2) @@ -172,8 +164,61 @@ def test_case_mixed_type(): name="my_data", ) - expr = ( - t0.three.case().when(0, "low").when(1, "high").else_("null").end().name("label") - ) + expr = t0.three.cases((0, "low"), (1, "high"), else_="null").name("label") result = t0[expr] assert result["label"].type().equals(dt.string) + + +def test_err_on_nonbool(table): + with pytest.raises(SignatureValidationError): + ibis.cases((table.a, "bar"), else_="baz") + + +@pytest.mark.xfail(reason="Literal('foo', type=bool), should error, but doesn't") +def test_err_on_nonbool2(): + with pytest.raises(SignatureValidationError): + ibis.cases(("foo", "bar"), else_="baz") + + +def test_err_on_noncomparable(table): + table.a.cases((8, "bar")) + table.a.cases((-8, "bar")) + # Can't compare an int to a string + with pytest.raises(TypeError): + table.a.cases(("foo", "bar")) + + +def test_empty_cases(table): + ibis.cases() + ibis.cases(else_=42) + table.a.cases() + table.a.cases(else_=42) + + +def test_dtype(): + assert isinstance(ibis.cases((True, "bar"), (False, "bar")), ir.StringValue) + assert isinstance(ibis.cases((True, None), else_="bar"), ir.StringValue) + with pytest.raises(TypeError): + assert ibis.cases((True, 5), (False, "bar")) + with pytest.raises(TypeError): + assert ibis.cases((True, 5), else_="bar") + + +def test_dshape(table): + assert isinstance(ibis.cases((True, "bar"), (False, "bar")), ir.Scalar) + assert isinstance(ibis.cases((True, None), else_="bar"), ir.Scalar) + assert isinstance(ibis.cases((table.b == 9, None), else_="bar"), ir.Column) + assert isinstance(ibis.cases((True, table.a), else_=42), ir.Column) + assert isinstance(ibis.cases((True, 42), else_=table.a), ir.Column) + assert isinstance(ibis.cases((True, table.a), else_=table.b), ir.Column) + + assert isinstance(ibis.literal(5).cases((9, 42)), ir.Scalar) + assert isinstance(ibis.literal(5).cases((9, 42), else_=43), ir.Scalar) + assert isinstance(ibis.literal(5).cases((table.a, 42)), ir.Column) + assert isinstance(ibis.literal(5).cases((9, table.a)), ir.Column) + assert isinstance(ibis.literal(5).cases((table.a, table.b)), ir.Column) + assert isinstance(ibis.literal(5).cases((9, 42), else_=table.a), ir.Column) + assert isinstance(table.a.cases((9, 42)), ir.Column) + assert isinstance(table.a.cases((table.b, 42)), ir.Column) + assert isinstance(table.a.cases((9, table.b)), ir.Column) + assert isinstance(table.a.cases((table.a, table.b)), ir.Column) diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index 5ded2434bfbfc..2e564b833af7e 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -822,23 +822,11 @@ def test_substitute_dict(): subs = {"a": "one", "b": table.bar} result = table.foo.substitute(subs) - expected = ( - ibis.case() - .when(table.foo == "a", "one") - .when(table.foo == "b", table.bar) - .else_(table.foo) - .end() - ) + expected = table.foo.cases(("a", "one"), ("b", table.bar), else_=table.foo) assert_equal(result, expected) result = table.foo.substitute(subs, else_=ibis.NA) - expected = ( - ibis.case() - .when(table.foo == "a", "one") - .when(table.foo == "b", table.bar) - .else_(ibis.NA) - .end() - ) + expected = table.foo.cases(("a", "one"), ("b", table.bar), else_=ibis.NA) assert_equal(result, expected)