From 26da007d3c58c64c75d23b8da1f27a413f167ef9 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 9 Jan 2024 12:06:17 +0100 Subject: [PATCH] fix(python): Fix interchange protocol data buffer dtype (#10787) --- py-polars/polars/interchange/column.py | 6 +----- py-polars/tests/unit/interchange/test_column.py | 6 +++--- py-polars/tests/unit/interchange/test_roundtrip.py | 8 ++++++++ 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/py-polars/polars/interchange/column.py b/py-polars/polars/interchange/column.py index 72835902fad9..3050d2eef404 100644 --- a/py-polars/polars/interchange/column.py +++ b/py-polars/polars/interchange/column.py @@ -157,11 +157,7 @@ def get_buffers(self) -> ColumnBuffers: def _get_data_buffer(self) -> tuple[PolarsBuffer, Dtype]: s = self._col._get_buffer(0) buffer = PolarsBuffer(s, allow_copy=self._allow_copy) - - dtype = self.dtype - if dtype[0] == DtypeKind.CATEGORICAL: - dtype = (DtypeKind.UINT, 32, "I", Endianness.NATIVE) - + dtype = polars_dtype_to_dtype(s.dtype) return buffer, dtype def _get_validity_buffer(self) -> tuple[PolarsBuffer, Dtype] | None: diff --git a/py-polars/tests/unit/interchange/test_column.py b/py-polars/tests/unit/interchange/test_column.py index 2b5b41c1a016..6303146d504d 100644 --- a/py-polars/tests/unit/interchange/test_column.py +++ b/py-polars/tests/unit/interchange/test_column.py @@ -206,7 +206,7 @@ def test_get_buffers_with_validity_and_offsets() -> None: data_buffer, data_dtype = out["data"] expected = pl.Series([97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8) assert_series_equal(data_buffer._data, expected) - assert data_dtype == (DtypeKind.STRING, 8, "U", "=") + assert data_dtype == (DtypeKind.UINT, 8, "C", "=") validity = out["validity"] assert validity is not None @@ -260,14 +260,14 @@ def test_get_buffers_chunked_zero_copy_fails() -> None: ( pl.Series(["a", "bc", None, "éâç"], dtype=pl.String), pl.Series([97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8), - (DtypeKind.STRING, 8, "U", "="), + (DtypeKind.UINT, 8, "C", "="), ), ( pl.Series( [datetime(1988, 1, 2), None, datetime(2022, 12, 3)], dtype=pl.Datetime ), pl.Series([568080000000000, 0, 1670025600000000], dtype=pl.Int64), - (DtypeKind.DATETIME, 64, "tsu:", "="), + (DtypeKind.INT, 64, "l", "="), ), ( pl.Series(["a", "b", None, "a"], dtype=pl.Categorical), diff --git a/py-polars/tests/unit/interchange/test_roundtrip.py b/py-polars/tests/unit/interchange/test_roundtrip.py index 5191a1e9de91..0c9ecf9c82ea 100644 --- a/py-polars/tests/unit/interchange/test_roundtrip.py +++ b/py-polars/tests/unit/interchange/test_roundtrip.py @@ -57,6 +57,10 @@ def test_to_dataframe_pyarrow_zero_copy_parametric(df: pl.DataFrame) -> None: assert_frame_equal(result, df, categorical_as_str=True) +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="The correct `from_dataframe` implementation for pandas is not available before Python 3.9", +) @pytest.mark.filterwarnings( "ignore:.*PEP3118 format string that does not match its itemsize:RuntimeWarning" ) @@ -68,6 +72,10 @@ def test_to_dataframe_pandas_parametric(df: pl.DataFrame) -> None: assert_frame_equal(result, df, categorical_as_str=True) +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="The correct `from_dataframe` implementation for pandas is not available before Python 3.9", +) @pytest.mark.filterwarnings( "ignore:.*PEP3118 format string that does not match its itemsize:RuntimeWarning" )