From 68d737a4b2d7c1b719eb434f11f3b4635448b6f2 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Wed, 4 Sep 2024 14:49:26 -0500 Subject: [PATCH] fix(polars): support polars `Enum` type --- ibis/backends/polars/tests/test_client.py | 23 +++++++++++++++++++++++ ibis/formats/polars.py | 2 +- ibis/formats/tests/test_polars.py | 3 ++- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/ibis/backends/polars/tests/test_client.py b/ibis/backends/polars/tests/test_client.py index baa4ce99132d..e0d4c698c103 100644 --- a/ibis/backends/polars/tests/test_client.py +++ b/ibis/backends/polars/tests/test_client.py @@ -1,5 +1,7 @@ from __future__ import annotations +import polars as pl +import polars.testing import pytest import ibis @@ -37,3 +39,24 @@ def test_array_flatten(con): {"id": data["id"], "flat": [row[0] for row in data["happy"]]} ) tm.assert_frame_equal(result.to_pandas(), expected) + + +def test_memtable_polars_types(con): + # Check that we can create a memtable with some polars-specific types, + # and that those columns then work in downstream operations + df = pl.DataFrame( + { + "x": ["a", "b", "a"], + "y": ["c", "d", "c"], + "z": ["e", "f", "e"], + }, + schema={ + "x": pl.String, + "y": pl.Categorical, + "z": pl.Enum(["e", "f"]), + }, + ) + t = ibis.memtable(df) + res = con.to_polars((t.x + t.y + t.z).name("test")) + sol = (df["x"] + df["y"] + df["z"]).rename("test") + pl.testing.assert_series_equal(res, sol) diff --git a/ibis/formats/polars.py b/ibis/formats/polars.py index 0b9fcfd98cb1..20cc8998d4a4 100644 --- a/ibis/formats/polars.py +++ b/ibis/formats/polars.py @@ -43,7 +43,7 @@ def to_ibis(cls, typ: pl.DataType, nullable=True) -> dt.DataType: """Convert a polars type to an ibis type.""" base_type = typ.base_type() - if base_type is pl.Categorical: + if base_type in (pl.Categorical, pl.Enum): return dt.String(nullable=nullable) elif base_type is pl.Decimal: return dt.Decimal( diff --git a/ibis/formats/tests/test_polars.py b/ibis/formats/tests/test_polars.py index f46771df2cb2..201e42db1ad0 100644 --- a/ibis/formats/tests/test_polars.py +++ b/ibis/formats/tests/test_polars.py @@ -100,8 +100,9 @@ def test_decimal(): ) -def test_categorical(): +def test_enum_categorical(): assert PolarsType.to_ibis(pl.Categorical()) == dt.string + assert PolarsType.to_ibis(pl.Enum(["a", "b"])) == dt.string def test_interval_unsupported_unit():