add polars support to the joiner #885

Closed · wants to merge 3 commits
1 change: 1 addition & 0 deletions setup.cfg
@@ -32,6 +32,7 @@ install_requires =
scipy>=1.9.3
pandas>=1.5.3
packaging>=23.1
dataframe-api-compat @ git+https://github.com/data-apis/dataframe-api-compat.git
python_requires = >=3.10

[options.packages.find]
6 changes: 6 additions & 0 deletions skrub/_dataframe/__init__.py
@@ -0,0 +1,6 @@
from . import _dataframe_operations
from ._dataframe_api import asdfapi, asnative, dfapi_ns
from ._dataframe_operations import * # noqa: F403,F401

__all__ = ["asdfapi", "asnative", "dfapi_ns"]
__all__ += _dataframe_operations.__all__
78 changes: 78 additions & 0 deletions skrub/_dataframe/_dataframe_api.py
@@ -0,0 +1,78 @@
def asdfapi(obj):
if hasattr(obj, "__dataframe_namespace__"):
return obj
if hasattr(obj, "__column_namespace__"):
return obj
if hasattr(obj, "__dataframe_consortium_standard__"):
return obj.__dataframe_consortium_standard__()
if hasattr(obj, "__column_consortium_standard__"):
return obj.__column_consortium_standard__()
try:
return _asdfapi_old_pandas(obj)
except (ImportError, TypeError):
pass
try:
return _asdfapi_old_polars(obj)
except (ImportError, TypeError):
pass
raise TypeError(
f"{obj} cannot be converted to DataFrame Consortium Standard object."
)


def _asdfapi_old_pandas(obj):
import pandas as pd

if isinstance(obj, pd.DataFrame):
from dataframe_api_compat.pandas_standard import (
convert_to_standard_compliant_dataframe,
)

return convert_to_standard_compliant_dataframe(obj)
if isinstance(obj, pd.Series):
from dataframe_api_compat.pandas_standard import (
convert_to_standard_compliant_column,
)

return convert_to_standard_compliant_column(obj)
raise TypeError()


def _asdfapi_old_polars(obj):
import polars as pl

if isinstance(obj, (pl.DataFrame, pl.LazyFrame)):
from dataframe_api_compat.polars_standard import (
convert_to_standard_compliant_dataframe,
)

return convert_to_standard_compliant_dataframe(obj)
if isinstance(obj, pl.Series):
from dataframe_api_compat.polars_standard import (
convert_to_standard_compliant_column,
)

return convert_to_standard_compliant_column(obj)
raise TypeError()


def asnative(obj):
if hasattr(obj, "__dataframe_namespace__"):
return obj.dataframe
if hasattr(obj, "__column_namespace__"):
return obj.column
if hasattr(obj, "__scalar_namespace__"):
return obj.scalar
return obj


def dfapi_ns(obj):
obj = asdfapi(obj)
for attr in [
"__dataframe_namespace__",
"__column_namespace__",
"__scalar_namespace__",
]:
if hasattr(obj, attr):
return getattr(obj, attr)()
raise TypeError(f"{obj} is not a Dataframe Consortium Standard object.")
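For context, a minimal sketch of how these adapters are meant to be used (not part of the diff; assumes polars and dataframe-api-compat are installed, and the data is made up):

```python
import polars as pl

from skrub._dataframe import asdfapi, asnative, dfapi_ns

df = pl.DataFrame({"city": ["Paris", "Oslo"], "population": [2.1, 0.7]})

std_df = asdfapi(df)        # standard-compliant wrapper around the polars frame
ns = dfapi_ns(std_df)       # namespace exposing dtypes and constructors, e.g. ns.String()
print(std_df.column_names)  # ['city', 'population']

native = asnative(std_df)   # back to a native object; for polars this may be a
print(type(native))         # LazyFrame, which is why a collect() helper is added below
```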

63 changes: 63 additions & 0 deletions skrub/_dataframe/_dataframe_operations.py
@@ -0,0 +1,63 @@
from ._dataframe_api import asdfapi, asnative, dfapi_ns

__all__ = [
"is_numeric",
"is_temporal",
"numeric_column_names",
"temporal_column_names",
"select",
"set_column_names",
"collect",
]


def is_numeric(column, include_bool=True):
column = asdfapi(column)
ns = dfapi_ns(column)
if ns.is_dtype(column.dtype, "numeric"):
return True
if not include_bool:
return False
return ns.is_dtype(column.dtype, "bool")


def is_temporal(column):
column = asdfapi(column)
ns = dfapi_ns(column)
for dtype in [ns.Date, ns.Datetime]:
if isinstance(column.dtype, dtype):
return True
return False


def _select_column_names(df, predicate):
df = asdfapi(df)
return [col_name for col_name in df.column_names if predicate(df.col(col_name))]


def numeric_column_names(df):
return _select_column_names(df, is_numeric)


def temporal_column_names(df):
return _select_column_names(df, is_temporal)


def select(df, column_names):
return asnative(asdfapi(df).select(*column_names))


def collect(df):
if hasattr(df, "collect"):
return df.collect()
return df


def set_column_names(df, new_column_names):
df = asdfapi(df)
ns = dfapi_ns(df)
new_columns = (
df.col(col_name).rename(new_name).persist()
for (col_name, new_name) in zip(df.column_names, new_column_names)
)
return asnative(ns.dataframe_from_columns(*new_columns))
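A hypothetical example of the helpers above on a small pandas frame (illustrative data; assumes dataframe-api-compat is installed):

```python
import pandas as pd

from skrub import _dataframe as sb

df = pd.DataFrame({"a": [1, 2], "b": [True, False], "c": ["x", "y"]})

print(sb.numeric_column_names(df))  # ['a', 'b']: bool columns are included by default
print(sb.is_numeric(df["b"], include_bool=False))  # False

sub = sb.select(df, ["a", "c"])                # native frame with only 'a' and 'c'
renamed = sb.set_column_names(sub, ["x", "y"])
print(sb.collect(renamed).columns.tolist())    # ['x', 'y']; collect() is a no-op
                                               # for eager (pandas) frames
```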
13 changes: 8 additions & 5 deletions skrub/_join_utils.py
@@ -2,7 +2,8 @@
from collections import Counter

from skrub import _utils
from skrub._dataframe._namespace import get_df_namespace

from ._dataframe import asdfapi, asnative


def check_key(
@@ -84,12 +85,12 @@ def check_missing_columns(table, key, table_name):
table_name : str
Name by which to refer to `table` in the error message if necessary.
"""
missing_columns = set(key) - set(table.columns)
missing_columns = set(key) - set(asdfapi(table).column_names)
if not missing_columns:
return
raise ValueError(
"The following columns cannot be used for joining because they do not exist"
f" in {table_name}:\n{missing_columns}"
f" in {table_name}: \n{missing_columns}"
)


@@ -148,5 +149,7 @@ def check_column_name_duplicates(


def add_column_name_suffix(dataframe, suffix):
ns, _ = get_df_namespace(dataframe)
return ns.rename_columns(dataframe, f"{{}}{suffix}".format)
api_df = asdfapi(dataframe)
renaming = {name: f"{name}{suffix}" for name in api_df.column_names}
api_df = api_df.rename(renaming)
return asnative(api_df)
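A quick illustration of the rewritten helper (made-up data; pandas shown, but any supported dataframe should behave the same way):

```python
import pandas as pd

from skrub._join_utils import add_column_name_suffix

df = pd.DataFrame({"city": ["Paris"], "population": [2.1]})
# Every column name gets the suffix appended:
print(add_column_name_suffix(df, "_aux").columns.tolist())
# ['city_aux', 'population_aux']
```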
96 changes: 45 additions & 51 deletions skrub/_joiner.py
@@ -3,7 +3,6 @@
"""

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.compose import make_column_transformer
from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer
@@ -12,12 +11,18 @@
from sklearn.utils.validation import check_is_fitted

from skrub import _join_utils, _matching, _utils
from skrub._dataframe._namespace import is_pandas, is_polars
from skrub._datetime_encoder import DatetimeEncoder

from . import _dataframe as sb
from ._dataframe import asdfapi, asnative, dfapi_ns


def _as_str(column):
return column.fillna("").astype(str)
column = asdfapi(column)
column = column.fill_null("")
ns = dfapi_ns(column)
column = column.cast(ns.String())
return asnative(column)


DEFAULT_STRING_ENCODER = make_pipeline(
@@ -45,17 +50,17 @@ def _make_vectorizer(table, string_encoder, rescale):
In addition if ``rescale`` is ``True``, a StandardScaler is applied to
numeric and datetime columns.
"""
transformers = [
(clone(string_encoder), c)
for c in table.select_dtypes(include=["string", "category", "object"]).columns
]
num_columns = table.select_dtypes(include="number").columns
if not num_columns.empty:
numeric_columns = sb.numeric_column_names(table)
dt_columns = sb.temporal_column_names(table)
object_columns = list(
set(asdfapi(table).column_names).difference(numeric_columns, dt_columns)
)
transformers = [(clone(string_encoder), c) for c in object_columns]
if numeric_columns:
transformers.append(
(StandardScaler() if rescale else "passthrough", num_columns)
(StandardScaler() if rescale else "passthrough", numeric_columns)
)
dt_columns = table.select_dtypes(["datetime", "datetimetz"]).columns
if not dt_columns.empty:
if dt_columns:
transformers.append(
(
make_pipeline(
@@ -243,17 +248,6 @@ def __init__(
)
self.add_match_info = add_match_info

def _check_dataframe(self, dataframe):
# TODO: add support for polars, ATM we just convert to pandas
if is_polars(dataframe):
return dataframe.to_pandas()
if is_pandas(dataframe):
return dataframe
raise TypeError(
f"{self.__class__.__qualname__} only operates on Pandas or Polars"
" dataframes."
)

def _check_max_dist(self):
if (
self.max_dist is None
@@ -288,24 +282,24 @@ def fit(self, X, y=None):
Fitted Joiner instance (self).
"""
del y
X = self._check_dataframe(X)
self._aux_table = self._check_dataframe(self.aux_table)
self._check_ref_dist()
self._check_max_dist()
self._main_key, self._aux_key = _join_utils.check_key(
self.main_key, self.aux_key, self.key
)
_join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)")
_join_utils.check_missing_columns(self._aux_table, self._aux_key, "'aux_table'")
_join_utils.check_missing_columns(self.aux_table, self._aux_key, "'aux_table'")
_join_utils.check_column_name_duplicates(
X, self._aux_table, self.suffix, main_table_name="X"
X, self.aux_table, self.suffix, main_table_name="X"
)
self.vectorizer_ = _make_vectorizer(
self._aux_table[self._aux_key],
sb.select(self.aux_table, self._aux_key),
self.string_encoder,
rescale=self.ref_dist != "no_rescaling",
)
aux = self.vectorizer_.fit_transform(self._aux_table[self._aux_key])
aux = self.vectorizer_.fit_transform(
sb.collect(sb.select(self.aux_table, self._aux_key))
)
self._matching.fit(aux)
return self

@@ -325,41 +319,41 @@ def transform(self, X, y=None):
The final joined table.
"""
del y
input_is_polars = is_polars(X)
X = self._check_dataframe(X)
check_is_fitted(self, "vectorizer_")
_join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)")
_join_utils.check_column_name_duplicates(
X, self._aux_table, self.suffix, main_table_name="X"
)
main = self.vectorizer_.transform(
X[self._main_key].set_axis(self._aux_key, axis="columns")
X, self.aux_table, self.suffix, main_table_name="X"
)
main = sb.select(X, self._main_key)
main = sb.set_column_names(main, self._aux_key)
main = self.vectorizer_.transform(sb.collect(main))
match_result = self._matching.match(main, self.max_dist_)
aux_table = _join_utils.add_column_name_suffix(
self._aux_table, self.suffix
).reset_index(drop=True)
aux_table = _join_utils.add_column_name_suffix(self.aux_table, self.suffix)
matching_col = match_result["index"].copy()
matching_col[~match_result["match_accepted"]] = -1
token = _utils.random_string()
left_key_name = f"skrub_left_key_{token}"
right_key_name = f"skrub_right_key_{token}"
left = X.assign(**{left_key_name: matching_col})
right = aux_table.assign(**{right_key_name: np.arange(aux_table.shape[0])})
join = pd.merge(
left,
ns = dfapi_ns(X)
left = asdfapi(X).assign(
ns.column_from_1d_array(matching_col, name=left_key_name)
)
n_rows = asdfapi(aux_table).persist().shape()[0]
right = asdfapi(aux_table).assign(
ns.column_from_1d_array(
np.arange(n_rows, dtype="int64"), name=right_key_name
)
)
join = left.join(
right,
how="left",
left_on=left_key_name,
right_on=right_key_name,
suffixes=("", ""),
how="left",
)
join = join.drop([left_key_name, right_key_name], axis=1)
join = join.drop(left_key_name, right_key_name)
if self.add_match_info:
for info_key, info_col_name in self._match_info_key_renaming.items():
join[info_col_name] = match_result[info_key]
if input_is_polars:
import polars as pl

join = pl.from_pandas(join)
return join
join = join.assign(
ns.column_from_1d_array(match_result[info_key], name=info_col_name)
)
return asnative(join)
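The join itself relies on a synthetic-key pattern: each row of X is assigned the index of its accepted match in aux_table (-1 for rejected matches) under a random, collision-proof column name, the two tables are left-joined on those keys, and the temporary columns are dropped. A minimal pandas sketch of the same pattern (illustrative names and data, not the dataframe-API code from the diff):

```python
import numpy as np
import pandas as pd

X = pd.DataFrame({"city": ["Paris", "Osol"]})
aux = pd.DataFrame({"City": ["Oslo", "Paris"], "population": [0.7, 2.1]})

# Stand-in for match_result["index"]: row 0 of X matched aux row 1;
# row 1's closest match was rejected, encoded as -1.
matching_col = np.array([1, -1])

left = X.assign(_left_key=matching_col)
right = aux.assign(_right_key=np.arange(len(aux)))
join = left.merge(right, how="left", left_on="_left_key", right_on="_right_key")
join = join.drop(columns=["_left_key", "_right_key"])
print(join)  # row 0 carries aux's 'Paris' row; row 1 has NaNs (no accepted match)
```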
9 changes: 5 additions & 4 deletions skrub/tests/test_joiner.py
@@ -5,6 +5,7 @@
from pandas.testing import assert_frame_equal

from skrub import Joiner
from skrub import _dataframe as sb
from skrub._dataframe._polars import POLARS_SETUP

MODULES = [pd]
@@ -44,7 +45,7 @@ def test_joiner(px):
)

joiner.fit(main_table)
big_table = joiner.transform(main_table)
big_table = sb.collect(joiner.transform(main_table))
assert big_table.shape == (main_table.shape[0], 3)
assert_array_equal(
big_table["Population"].to_numpy(),
@@ -81,17 +82,17 @@ def test_multiple_keys(px, assert_frame_equal_):
main_key=["Co", "Ca"],
add_match_info=False,
)
result = joiner_list.fit_transform(df)
result = sb.collect(joiner_list.fit_transform(df))
try:
expected = px.concat([df, df2], axis=1)
except TypeError:
expected = px.concat([df, df2], how="horizontal")
assert_frame_equal_(result, expected)
assert_frame_equal_(sb.collect(result), expected)

joiner_list = Joiner(
aux_table=df2, aux_key="CA", main_key="Ca", add_match_info=False
)
result = joiner_list.fit_transform(df)
result = sb.collect(joiner_list.fit_transform(df))
assert_frame_equal_(result, expected)

