From 026f64f6bd85ccddb2afaeecb1fa4bef93007b2c Mon Sep 17 00:00:00 2001 From: Jerome Dockes Date: Thu, 25 Jan 2024 15:41:55 +0100 Subject: [PATCH 1/3] add polars support to the joiner --- setup.cfg | 1 + skrub/_dataframe/__init__.py | 6 ++ skrub/_dataframe/_dataframe_api.py | 34 ++++++++ skrub/_dataframe/_dataframe_operations.py | 63 +++++++++++++++ skrub/_join_utils.py | 13 ++-- skrub/_joiner.py | 94 +++++++++++------------ skrub/tests/test_joiner.py | 9 ++- 7 files changed, 160 insertions(+), 60 deletions(-) create mode 100644 skrub/_dataframe/_dataframe_api.py create mode 100644 skrub/_dataframe/_dataframe_operations.py diff --git a/setup.cfg b/setup.cfg index 01b29d723..6c13294d4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,7 @@ install_requires = scipy>=1.9.3 pandas>=1.5.3 packaging>=23.1 + dataframe-api-compat @ git+https://github.com/data-apis/dataframe-api-compat.git python_requires = >=3.10 [options.packages.find] diff --git a/skrub/_dataframe/__init__.py b/skrub/_dataframe/__init__.py index e69de29bb..12742d7e4 100644 --- a/skrub/_dataframe/__init__.py +++ b/skrub/_dataframe/__init__.py @@ -0,0 +1,6 @@ +from . import _dataframe_operations +from ._dataframe_api import asdfapi, asnative, dfapi_ns +from ._dataframe_operations import * # noqa: F403,F401 + +__all__ = ["asdfapi", "asnative", "dfapi_ns"] +__all__ += _dataframe_operations.__all__ diff --git a/skrub/_dataframe/_dataframe_api.py b/skrub/_dataframe/_dataframe_api.py new file mode 100644 index 000000000..6c91a5335 --- /dev/null +++ b/skrub/_dataframe/_dataframe_api.py @@ -0,0 +1,34 @@ +def asdfapi(obj): + if hasattr(obj, "__dataframe_namespace__"): + return obj + if hasattr(obj, "__column_namespace__"): + return obj + if hasattr(obj, "__dataframe_consortium_standard__"): + return obj.__dataframe_consortium_standard__() + if hasattr(obj, "__column_consortium_standard__"): + return obj.__column_consortium_standard__() + raise TypeError( + f"{obj} cannot be converted to DataFrame Consortium Standard object." 
+    )
+
+
+def asnative(obj):
+    if hasattr(obj, "__dataframe_namespace__"):
+        return obj.dataframe
+    if hasattr(obj, "__column_namespace__"):
+        return obj.column
+    if hasattr(obj, "__scalar_namespace__"):
+        return obj.scalar
+    return obj
+
+
+def dfapi_ns(obj):
+    obj = asdfapi(obj)
+    for attr in [
+        "__dataframe_namespace__",
+        "__column_namespace__",
+        "__scalar_namespace__",
+    ]:
+        if hasattr(obj, attr):
+            return getattr(obj, attr)()
+    raise TypeError(f"{obj} is not a DataFrame Consortium Standard object.")
diff --git a/skrub/_dataframe/_dataframe_operations.py b/skrub/_dataframe/_dataframe_operations.py
new file mode 100644
index 000000000..5fa1e868f
--- /dev/null
+++ b/skrub/_dataframe/_dataframe_operations.py
@@ -0,0 +1,63 @@
+from ._dataframe_api import asdfapi, asnative, dfapi_ns
+
+__all__ = [
+    "is_numeric",
+    "is_temporal",
+    "numeric_column_names",
+    "temporal_column_names",
+    "select",
+    "set_column_names",
+    "collect",
+]
+
+
+def is_numeric(column, include_bool=True):
+    column = asdfapi(column)
+    ns = dfapi_ns(column)
+    if ns.is_dtype(column.dtype, "numeric"):
+        return True
+    if not include_bool:
+        return False
+    return ns.is_dtype(column.dtype, "bool")
+
+
+def is_temporal(column):
+    column = asdfapi(column)
+    ns = dfapi_ns(column)
+    for dtype in [ns.Date, ns.Datetime]:
+        if isinstance(column.dtype, dtype):
+            return True
+    return False
+
+
+def _select_column_names(df, predicate):
+    df = asdfapi(df)
+    return [col_name for col_name in df.column_names if predicate(df.col(col_name))]
+
+
+def numeric_column_names(df):
+    return _select_column_names(df, is_numeric)
+
+
+def temporal_column_names(df):
+    return _select_column_names(df, is_temporal)
+
+
+def select(df, column_names):
+    return asnative(asdfapi(df).select(*column_names))
+
+
+def collect(df):
+    if hasattr(df, "collect"):
+        return df.collect()
+    return df
+
+
+def set_column_names(df, new_column_names):
+    df = asdfapi(df)
+    ns = dfapi_ns(df)
+    new_columns = (
+        df.col(col_name).rename(new_name).persist()
+        for (col_name, new_name) in zip(df.column_names, new_column_names)
+    )
+    return asnative(ns.dataframe_from_columns(*new_columns))
diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py
index b4b0439c4..a459b497e 100644
--- a/skrub/_join_utils.py
+++ b/skrub/_join_utils.py
@@ -2,7 +2,8 @@
 from collections import Counter
 
 from skrub import _utils
-from skrub._dataframe._namespace import get_df_namespace
+
+from ._dataframe import asdfapi, asnative
 
 
 def check_key(
@@ -84,12 +85,12 @@
     table_name : str
         Name by which to refer to `table` in the error message if necessary.
""" - missing_columns = set(key) - set(table.columns) + missing_columns = set(key) - set(asdfapi(table).column_names) if not missing_columns: return raise ValueError( "The following columns cannot be used for joining because they do not exist" - f" in {table_name}:\n{missing_columns}" + f" in {table_name}: \n{missing_columns}" ) @@ -148,5 +149,7 @@ def check_column_name_duplicates( def add_column_name_suffix(dataframe, suffix): - ns, _ = get_df_namespace(dataframe) - return ns.rename_columns(dataframe, f"{{}}{suffix}".format) + api_df = asdfapi(dataframe) + renaming = {name: f"{name}{suffix}" for name in api_df.column_names} + api_df = api_df.rename(renaming) + return asnative(api_df) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 7f322c460..2f9ddc9bd 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -3,7 +3,6 @@ """ import numpy as np -import pandas as pd from sklearn.base import BaseEstimator, TransformerMixin, clone from sklearn.compose import make_column_transformer from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer @@ -12,12 +11,18 @@ from sklearn.utils.validation import check_is_fitted from skrub import _join_utils, _matching, _utils -from skrub._dataframe._namespace import is_pandas, is_polars from skrub._datetime_encoder import DatetimeEncoder +from . import _dataframe as sb +from ._dataframe import asdfapi, asnative, dfapi_ns + def _as_str(column): - return column.fillna("").astype(str) + column = asdfapi(column) + column = column.fill_null("") + ns = dfapi_ns(column) + column = column.cast(ns.String()) + return asnative(column) DEFAULT_STRING_ENCODER = make_pipeline( @@ -45,17 +50,17 @@ def _make_vectorizer(table, string_encoder, rescale): In addition if ``rescale`` is ``True``, a StandardScaler is applied to numeric and datetime columns. """ - transformers = [ - (clone(string_encoder), c) - for c in table.select_dtypes(include=["string", "category", "object"]).columns - ] - num_columns = table.select_dtypes(include="number").columns - if not num_columns.empty: + numeric_columns = sb.numeric_column_names(table) + dt_columns = sb.temporal_column_names(table) + object_columns = list( + set(asdfapi(table).column_names).difference(numeric_columns, dt_columns) + ) + transformers = [(clone(string_encoder), c) for c in object_columns] + if numeric_columns: transformers.append( - (StandardScaler() if rescale else "passthrough", num_columns) + (StandardScaler() if rescale else "passthrough", numeric_columns) ) - dt_columns = table.select_dtypes(["datetime", "datetimetz"]).columns - if not dt_columns.empty: + if dt_columns: transformers.append( ( make_pipeline( @@ -243,17 +248,6 @@ def __init__( ) self.add_match_info = add_match_info - def _check_dataframe(self, dataframe): - # TODO: add support for polars, ATM we just convert to pandas - if is_polars(dataframe): - return dataframe.to_pandas() - if is_pandas(dataframe): - return dataframe - raise TypeError( - f"{self.__class__.__qualname__} only operates on Pandas or Polars" - " dataframes." - ) - def _check_max_dist(self): if ( self.max_dist is None @@ -288,24 +282,24 @@ def fit(self, X, y=None): Fitted Joiner instance (self). 
""" del y - X = self._check_dataframe(X) - self._aux_table = self._check_dataframe(self.aux_table) self._check_ref_dist() self._check_max_dist() self._main_key, self._aux_key = _join_utils.check_key( self.main_key, self.aux_key, self.key ) _join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)") - _join_utils.check_missing_columns(self._aux_table, self._aux_key, "'aux_table'") + _join_utils.check_missing_columns(self.aux_table, self._aux_key, "'aux_table'") _join_utils.check_column_name_duplicates( - X, self._aux_table, self.suffix, main_table_name="X" + X, self.aux_table, self.suffix, main_table_name="X" ) self.vectorizer_ = _make_vectorizer( - self._aux_table[self._aux_key], + sb.select(self.aux_table, self._aux_key), self.string_encoder, rescale=self.ref_dist != "no_rescaling", ) - aux = self.vectorizer_.fit_transform(self._aux_table[self._aux_key]) + aux = self.vectorizer_.fit_transform( + sb.collect(sb.select(self.aux_table, self._aux_key)) + ) self._matching.fit(aux) return self @@ -325,41 +319,39 @@ def transform(self, X, y=None): The final joined table. """ del y - input_is_polars = is_polars(X) - X = self._check_dataframe(X) check_is_fitted(self, "vectorizer_") _join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)") _join_utils.check_column_name_duplicates( - X, self._aux_table, self.suffix, main_table_name="X" - ) - main = self.vectorizer_.transform( - X[self._main_key].set_axis(self._aux_key, axis="columns") + X, self.aux_table, self.suffix, main_table_name="X" ) + main = sb.select(X, self._main_key) + main = sb.set_column_names(main, self._aux_key) + main = self.vectorizer_.transform(sb.collect(main)) match_result = self._matching.match(main, self.max_dist_) - aux_table = _join_utils.add_column_name_suffix( - self._aux_table, self.suffix - ).reset_index(drop=True) + aux_table = _join_utils.add_column_name_suffix(self.aux_table, self.suffix) matching_col = match_result["index"].copy() matching_col[~match_result["match_accepted"]] = -1 token = _utils.random_string() left_key_name = f"skrub_left_key_{token}" right_key_name = f"skrub_right_key_{token}" - left = X.assign(**{left_key_name: matching_col}) - right = aux_table.assign(**{right_key_name: np.arange(aux_table.shape[0])}) - join = pd.merge( - left, + ns = dfapi_ns(X) + left = asdfapi(X).assign( + ns.column_from_1d_array(matching_col, name=left_key_name) + ) + n_rows = asdfapi(aux_table).persist().shape()[0] + right = asdfapi(aux_table).assign( + ns.column_from_1d_array(np.arange(n_rows), name=right_key_name) + ) + join = left.join( right, + how="left", left_on=left_key_name, right_on=right_key_name, - suffixes=("", ""), - how="left", ) - join = join.drop([left_key_name, right_key_name], axis=1) + join = join.drop(left_key_name, right_key_name) if self.add_match_info: for info_key, info_col_name in self._match_info_key_renaming.items(): - join[info_col_name] = match_result[info_key] - if input_is_polars: - import polars as pl - - join = pl.from_pandas(join) - return join + join = join.assign( + ns.column_from_1d_array(match_result[info_key], name=info_col_name) + ) + return asnative(join) diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index c42dd9008..14a643384 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -5,6 +5,7 @@ from pandas.testing import assert_frame_equal from skrub import Joiner +from skrub import _dataframe as sb from skrub._dataframe._polars import POLARS_SETUP MODULES = [pd] @@ -44,7 +45,7 @@ def test_joiner(px): ) 
joiner.fit(main_table) - big_table = joiner.transform(main_table) + big_table = sb.collect(joiner.transform(main_table)) assert big_table.shape == (main_table.shape[0], 3) assert_array_equal( big_table["Population"].to_numpy(), @@ -81,17 +82,17 @@ def test_multiple_keys(px, assert_frame_equal_): main_key=["Co", "Ca"], add_match_info=False, ) - result = joiner_list.fit_transform(df) + result = sb.collect(joiner_list.fit_transform(df)) try: expected = px.concat([df, df2], axis=1) except TypeError: expected = px.concat([df, df2], how="horizontal") - assert_frame_equal_(result, expected) + assert_frame_equal_(sb.collect(result), expected) joiner_list = Joiner( aux_table=df2, aux_key="CA", main_key="Ca", add_match_info=False ) - result = joiner_list.fit_transform(df) + result = sb.collect(joiner_list.fit_transform(df)) assert_frame_equal_(result, expected) From 25058835f91458065400a322e3a2c3c51daae1d2 Mon Sep 17 00:00:00 2001 From: Jerome Dockes Date: Thu, 25 Jan 2024 16:50:08 +0100 Subject: [PATCH 2/3] fix for old pandas versions --- skrub/_dataframe/_dataframe_api.py | 44 ++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/skrub/_dataframe/_dataframe_api.py b/skrub/_dataframe/_dataframe_api.py index 6c91a5335..979a540e7 100644 --- a/skrub/_dataframe/_dataframe_api.py +++ b/skrub/_dataframe/_dataframe_api.py @@ -7,11 +7,55 @@ def asdfapi(obj): return obj.__dataframe_consortium_standard__() if hasattr(obj, "__column_consortium_standard__"): return obj.__column_consortium_standard__() + try: + return _asdfapi_old_pandas(obj) + except (ImportError, TypeError): + pass + try: + return _asdfapi_old_polars(obj) + except (ImportError, TypeError): + pass raise TypeError( f"{obj} cannot be converted to DataFrame Consortium Standard object." ) +def _asdfapi_old_pandas(obj): + import pandas as pd + + if isinstance(obj, pd.DataFrame): + from dataframe_api_compat.pandas_standard import ( + convert_to_standard_compliant_dataframe, + ) + + return convert_to_standard_compliant_dataframe(obj) + if isinstance(obj, pd.Series): + from dataframe_api_compat.pandas_standard import ( + convert_to_standard_compliant_column, + ) + + return convert_to_standard_compliant_column(obj) + raise TypeError() + + +def _asdfapi_old_polars(obj): + import polars as pl + + if isinstance(obj, (pl.DataFrame, pl.LazyFrame)): + from dataframe_api_compat.polars_standard import ( + convert_to_standard_compliant_dataframe, + ) + + return convert_to_standard_compliant_dataframe(obj) + if isinstance(obj, pl.Series): + from dataframe_api_compat.polars_standard import ( + convert_to_standard_compliant_column, + ) + + return convert_to_standard_compliant_column(obj) + raise TypeError() + + def asnative(obj): if hasattr(obj, "__dataframe_namespace__"): return obj.dataframe From 2442c454fd3f6f7741b8c492aa18e63c9e431d3c Mon Sep 17 00:00:00 2001 From: Jerome Dockes Date: Thu, 25 Jan 2024 17:00:48 +0100 Subject: [PATCH 3/3] fix windows dtype mismatch --- skrub/_joiner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 2f9ddc9bd..d3c1cc29b 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -340,7 +340,9 @@ def transform(self, X, y=None): ) n_rows = asdfapi(aux_table).persist().shape()[0] right = asdfapi(aux_table).assign( - ns.column_from_1d_array(np.arange(n_rows), name=right_key_name) + ns.column_from_1d_array( + np.arange(n_rows, dtype="int64"), name=right_key_name + ) ) join = left.join( right,
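
A note on PATCH 3/3: NumPy 1.x uses the C long as its default integer
type, so on 64-bit Windows np.arange(n_rows) produces int32 values,
while the matching indices written into the left join key are likely
int64; polars will not join on keys of mismatched integer dtypes, which
is presumably the Windows failure the commit subject refers to. Pinning
dtype="int64" gives both join keys the same type on every platform. An
illustrative check (output depends on the OS and NumPy version it runs
on):

    import numpy as np

    # int64 on Linux/macOS; int32 on 64-bit Windows with NumPy 1.x,
    # where the default integer is the C long.
    print(np.arange(3).dtype)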
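
Usage sketch for the series as a whole: with these patches applied, the
Joiner accepts polars frames directly instead of converting them to
pandas internally. A minimal example mirroring test_joiner, assuming
polars and dataframe-api-compat are installed; the table contents are
illustrative:

    import polars as pl

    from skrub import Joiner
    from skrub import _dataframe as sb

    main_table = pl.DataFrame({"Country": ["France", "Italia", "Georgia"]})
    aux_table = pl.DataFrame(
        {
            "Name": ["France", "Italy", "Georgia"],
            "Capital": ["Paris", "Rome", "Tbilisi"],
        }
    )
    joiner = Joiner(
        aux_table=aux_table,
        main_key="Country",
        aux_key="Name",
        max_dist=0.9,
        add_match_info=False,
    )
    # The result is a polars DataFrame; sb.collect is a no-op here but
    # would materialize a LazyFrame.
    joined = sb.collect(joiner.fit_transform(main_table))
    print(joined.shape)  # (3, 3): Country, Name, Capital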