diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index b6b764a09cec..6a3f5b39f814 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -9,7 +9,7 @@ import polars as pl from polars import StringCache -from polars.testing import assert_series_equal +from polars.testing import assert_frame_equal, assert_series_equal def test_enum_creation() -> None: @@ -402,3 +402,22 @@ def test_enum_cast_from_other_integer_dtype_oob() -> None: pl.ComputeError, match="conversion from `u64` to `u32` failed in column" ): series.cast(enum_dtype) + + +def test_enum_creating_col_expr() -> None: + df = pl.DataFrame( + { + "col1": ["a", "b", "c"], + "col2": ["d", "e", "f"], + "col3": ["g", "h", "i"], + }, + schema={ + "col1": pl.Enum(["a", "b", "c"]), + "col2": pl.Categorical(), + "col3": pl.Enum(["g", "h", "i"]), + }, + ) + + out = df.select(pl.col(pl.Enum)) + expected = df.select("col1", "col3") + assert_frame_equal(out, expected)