From 9599662513a64a0bfe194250f883c33bbf17ecff Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 9 Nov 2024 17:32:31 +0000 Subject: [PATCH] fix: use PyCapsule Interface instead of Dataframe Interchange Protocol --- pyproject.toml | 1 + seaborn/_core/data.py | 66 +++++++++++++++++++--------------------- seaborn/_core/plot.py | 2 +- seaborn/_core/typing.py | 4 +-- tests/_core/test_data.py | 14 ++++----- tests/_core/test_plot.py | 4 +-- tests/conftest.py | 7 +++-- tests/test_axisgrid.py | 8 ++--- 8 files changed, 52 insertions(+), 54 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0a4e497d0d..ae5b401f12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dev = [ "mypy", "pandas-stubs", "pre-commit", + "pyarrow", "flit", ] docs = [ diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index c17bfe95c5..2c2a0cf0d1 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -5,7 +5,6 @@ from collections.abc import Mapping, Sized from typing import cast -import warnings import pandas as pd from pandas import DataFrame @@ -269,9 +268,9 @@ def _assign_variables( def handle_data_source(data: object) -> pd.DataFrame | Mapping | None: """Convert the data source object to a common union representation.""" - if isinstance(data, pd.DataFrame) or hasattr(data, "__dataframe__"): + if isinstance(data, pd.DataFrame) or hasattr(data, "__arrow_c_stream__"): # Check for pd.DataFrame inheritance could be removed once - # minimal pandas version supports dataframe interchange (1.5.0). + # minimal pandas version supports PyCapsule Interface (2.2). data = convert_dataframe_to_pandas(data) elif data is not None and not isinstance(data, Mapping): err = f"Data source must be a DataFrame or Mapping, not {type(data)!r}." @@ -285,35 +284,32 @@ def convert_dataframe_to_pandas(data: object) -> pd.DataFrame: if isinstance(data, pd.DataFrame): return data - if not hasattr(pd.api, "interchange"): - msg = ( - "Support for non-pandas DataFrame objects requires a version of pandas " - "that implements the DataFrame interchange protocol. Please upgrade " - "your pandas version or coerce your data to pandas before passing " - "it to seaborn." - ) - raise TypeError(msg) - - if _version_predates(pd, "2.0.2"): - msg = ( - "DataFrame interchange with pandas<2.0.2 has some known issues. " - f"You are using pandas {pd.__version__}. " - "Continuing, but it is recommended to carefully inspect the results and to " - "consider upgrading." - ) - warnings.warn(msg, stacklevel=2) - - try: - # This is going to convert all columns in the input dataframe, even though - # we may only need one or two of them. It would be more efficient to select - # the columns that are going to be used in the plot prior to interchange. - # Solving that in general is a hard problem, especially with the objects - # interface where variables passed in Plot() may only be referenced later - # in Plot.add(). But noting here in case this seems to be a bottleneck. - return pd.api.interchange.from_dataframe(data) - except Exception as err: - msg = ( - "Encountered an exception when converting data source " - "to a pandas DataFrame. See traceback above for details." - ) - raise RuntimeError(msg) from err + if hasattr(data, '__arrow_c_stream__'): + try: + import pyarrow + except ImportError as err: + msg = "PyArrow is required for non-pandas Dataframe support." + raise RuntimeError(msg) from err + if _version_predates(pyarrow, '14.0.0'): + msg = "PyArrow>=14.0.0 is required for non-pandas Dataframe support." + raise RuntimeError(msg) + try: + # This is going to convert all columns in the input dataframe, even though + # we may only need one or two of them. It would be more efficient to select + # the columns that are going to be used in the plot prior to interchange. + # Solving that in general is a hard problem, especially with the objects + # interface where variables passed in Plot() may only be referenced later + # in Plot.add(). But noting here in case this seems to be a bottleneck. + return pyarrow.table(data).to_pandas() + except Exception as err: + msg = ( + "Encountered an exception when converting data source " + "to a pandas DataFrame. See traceback above for details." + ) + raise RuntimeError(msg) from err + + msg = ( + "Expected object which implements '__arrow_c_stream__' from the " + f"PyCapsule Interface, got: {type(data)}" + ) + raise TypeError(msg) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index c9dc61c8a7..b021e7b959 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -349,7 +349,7 @@ def _resolve_positionals( if ( isinstance(args[0], (abc.Mapping, pd.DataFrame)) - or hasattr(args[0], "__dataframe__") + or hasattr(args[0], "__arrow_c_stream__") ): if data is not None: raise TypeError("`data` given by both name and position.") diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py index 9bdf8a6ef8..e5529b766a 100644 --- a/seaborn/_core/typing.py +++ b/seaborn/_core/typing.py @@ -17,9 +17,9 @@ VariableSpec = Union[ColumnName, Vector, None] VariableSpecList = Union[List[VariableSpec], Index, None] -# A DataSource can be an object implementing __dataframe__, or a Mapping +# A DataSource can be an object implementing __arrow_c_stream__, or a Mapping # (and is optional in all contexts where it is used). -# I don't think there's an abc for "has __dataframe__", so we type as object +# I don't think there's an abc for "has __arrow_c_stream__", so we type as object # but keep the (slightly odd) Union alias for better user-facing annotations. DataSource = Union[object, Mapping, None] diff --git a/tests/_core/test_data.py b/tests/_core/test_data.py index 0e67ed37b4..72b9220a2a 100644 --- a/tests/_core/test_data.py +++ b/tests/_core/test_data.py @@ -405,8 +405,8 @@ def test_bad_type(self, flat_list): PlotData(flat_list, {}) @pytest.mark.skipif( - condition=not hasattr(pd.api, "interchange"), - reason="Tests behavior assuming support for dataframe interchange" + condition=not hasattr(pd.DataFrame, "__arrow_c_stream__"), + reason="Tests behavior assuming support for PyCapsule Interface" ) def test_data_interchange(self, mock_long_df, long_df): @@ -420,18 +420,18 @@ def test_data_interchange(self, mock_long_df, long_df): assert_vector_equal(p.frame[var], long_df[col]) @pytest.mark.skipif( - condition=not hasattr(pd.api, "interchange"), - reason="Tests behavior assuming support for dataframe interchange" + condition=not hasattr(pd.DataFrame, "__arrow_c_stream__"), + reason="Tests behavior assuming support for PyCapsule Interface" ) def test_data_interchange_failure(self, mock_long_df): - mock_long_df._data = None # Break __dataframe__() + mock_long_df.__arrow_c_stream__ = lambda _x: 1 / 0 # Break __arrow_c_stream__() with pytest.raises(RuntimeError, match="Encountered an exception"): PlotData(mock_long_df, {"x": "x"}) @pytest.mark.skipif( - condition=hasattr(pd.api, "interchange"), - reason="Tests graceful failure without support for dataframe interchange" + condition=hasattr(pd.DataFrame, "__arrow_c_stream__"), + reason="Tests graceful failure without support for PyCapsule Interface" ) def test_data_interchange_support_test(self, mock_long_df): diff --git a/tests/_core/test_plot.py b/tests/_core/test_plot.py index 5554ea650f..e0bb96ddd2 100644 --- a/tests/_core/test_plot.py +++ b/tests/_core/test_plot.py @@ -171,8 +171,8 @@ def test_positional_x(self, long_df): assert list(p._data.source_vars) == ["x"] @pytest.mark.skipif( - condition=not hasattr(pd.api, "interchange"), - reason="Tests behavior assuming support for dataframe interchange" + condition=not hasattr(pd.DataFrame, "__arrow_c_stream__"), + reason="Tests behavior assuming support for PyCapsule Interface" ) def test_positional_interchangeable_dataframe(self, mock_long_df, long_df): diff --git a/tests/conftest.py b/tests/conftest.py index 6ee53e7ee4..620b417cf2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -188,11 +188,12 @@ class MockInterchangeableDataFrame: def __init__(self, data): self._data = data - def __dataframe__(self, *args, **kwargs): - return self._data.__dataframe__(*args, **kwargs) + def __arrow_c_stream__(self, *args, **kwargs): + return self._data.__arrow_c_stream__() @pytest.fixture def mock_long_df(long_df): + import pyarrow - return MockInterchangeableDataFrame(long_df) + return MockInterchangeableDataFrame(pyarrow.Table.from_pandas(long_df)) diff --git a/tests/test_axisgrid.py b/tests/test_axisgrid.py index 6470edfa4f..b582e2f371 100644 --- a/tests/test_axisgrid.py +++ b/tests/test_axisgrid.py @@ -709,8 +709,8 @@ def test_tick_params(self): assert tick.get_pad() == pad @pytest.mark.skipif( - condition=not hasattr(pd.api, "interchange"), - reason="Tests behavior assuming support for dataframe interchange" + condition=not hasattr(pd.DataFrame, "__arrow_c_stream__"), + reason="Tests behavior assuming support for PyCapsule Interface" ) def test_data_interchange(self, mock_long_df, long_df): @@ -1478,8 +1478,8 @@ def test_tick_params(self): assert tick.get_pad() == pad @pytest.mark.skipif( - condition=not hasattr(pd.api, "interchange"), - reason="Tests behavior assuming support for dataframe interchange" + condition=not hasattr(pd.DataFrame, "__arrow_c_stream__"), + reason="Tests behavior assuming support for PyCapsule Interface" ) def test_data_interchange(self, mock_long_df, long_df):