From 25b04f9820194161cc19756cb1dbcebb6a6f2f71 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 12 May 2024 12:01:22 +0100 Subject: [PATCH] make timegapsplit dataframe-agnostic --- pyproject.toml | 2 +- sklego/model_selection.py | 111 +++++++++------- .../test_model_selection/test_timegapsplit.py | 119 +++++++++++++++++- 3 files changed, 186 insertions(+), 46 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2ec1424f4..b3f2354e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ maintainers = [ ] dependencies = [ - "narwhals>=0.8.9", + "narwhals>=0.8.12", "pandas>=1.1.5", "scikit-learn>=1.0", "importlib-metadata >= 1.0; python_version < '3.8'", diff --git a/sklego/model_selection.py b/sklego/model_selection.py index 4d52d817a..52ac577c4 100644 --- a/sklego/model_selection.py +++ b/sklego/model_selection.py @@ -3,6 +3,7 @@ from itertools import combinations from warnings import warn +import narwhals as nw import numpy as np import pandas as pd from sklearn.exceptions import NotFittedError @@ -44,8 +45,10 @@ class TimeGapSplit: Parameters ---------- - date_serie : pd.Series + date_serie : Series Series with the date, that should have all the indices of X used in the split() method. + If the Series is not pandas-like (for example, if it's a Polars Series, which does not have + an index) then it must the same same length as the `X` and `y` objects passed to `split`. valid_duration : datetime.timedelta Retraining period. train_duration : datetime.timedelta | None, default=None @@ -65,6 +68,20 @@ class TimeGapSplit: - `"rolling"` window has fixed size and is shifted entirely. - `"expanding"` left side of window is fixed, right border increases each fold. + + Notes + ----- + Native cross-dataframe support is achieved using + [Narwhals](https://narwhals-dev.github.io/narwhals/){:target="_blank"}. + Supported dataframes are: + + - pandas + - Polars (eager) + - Modin + + See [Narwhals docs](https://narwhals-dev.github.io/narwhals/extending/){:target="_blank"} for an up-to-date list + (and to learn how you can add your dataframe library to it!), though note that only those + convertible to `numpy` arrays will work with this class. """ def __init__( @@ -82,11 +99,7 @@ def __init__( if (train_duration is not None) and (train_duration <= gap_duration): raise ValueError("gap_duration is longer than train_duration, it should be shorter.") - if not date_serie.index.is_unique: - raise ValueError("date_serie doesn't have a unique index") - - self.date_serie = date_serie.copy() - self.date_serie = self.date_serie.rename("__date__") + self.date_serie = nw.from_native(date_serie, series_only=True).alias("__date__") self.train_duration = train_duration self.valid_duration = valid_duration self.gap_duration = gap_duration @@ -98,13 +111,15 @@ def _join_date_and_x(self, X): index and with the 'numpy index' column (i.e. just a range) that is required for the output and the rest of sklearn. + If the user is working with index-less dataframes (e.g. Polars), then `self.date_series` needs to be the same + length as `X`. + Parameters ---------- - X : pd.DataFrame + X : DataFrame Dataframe with the data to split """ - X_index_df = pd.DataFrame(range(len(X)), columns=["np_index"], index=X.index) - X_index_df = X_index_df.join(self.date_serie) + X_index_df = nw.maybe_align_index(self.date_serie, X).to_frame().with_row_index("np_index") return X_index_df @@ -113,7 +128,7 @@ def split(self, X, y=None, groups=None): Parameters ---------- - X : pd.DataFrame + X : DataFrame Dataframe with the data to split. y : array-like | None, default=None Ignored, present for compatibility. @@ -126,8 +141,9 @@ def split(self, X, y=None, groups=None): Train and test indices of the same fold. """ + X = nw.from_native(X, eager_only=True) X_index_df = self._join_date_and_x(X) - X_index_df = X_index_df.sort_values("__date__", ascending=True) + X_index_df = X_index_df.sort("__date__", descending=False) if len(X) != len(X_index_df): raise AssertionError( @@ -167,23 +183,20 @@ def split(self, X, y=None, groups=None): if current_date + self.train_duration + time_shift + self.gap_duration > date_max: break - X_train_df = X_index_df[ - (X_index_df["__date__"] >= start_date) & (X_index_df["__date__"] < current_date + self.train_duration) - ] - X_valid_df = X_index_df[ - (X_index_df["__date__"] >= current_date + self.train_duration + self.gap_duration) - & ( - X_index_df["__date__"] - < current_date + self.train_duration + self.valid_duration + self.gap_duration - ) - ] + X_train_df = X_index_df.filter( + nw.col("__date__") >= start_date, nw.col("__date__") < current_date + self.train_duration + ) + X_valid_df = X_index_df.filter( + nw.col("__date__") >= current_date + self.train_duration + self.gap_duration, + nw.col("__date__") < current_date + self.train_duration + self.valid_duration + self.gap_duration, + ) current_date = current_date + time_shift if self.window == "rolling": start_date = current_date yield ( - X_train_df["np_index"].values, - X_valid_df["np_index"].values, + X_train_df["np_index"].to_numpy(), + X_valid_df["np_index"].to_numpy(), ) def get_n_splits(self, X=None, y=None, groups=None): @@ -191,7 +204,7 @@ def get_n_splits(self, X=None, y=None, groups=None): Parameters ---------- - X : pd.DataFrame + X : DataFrame Dataframe with the data to split. y : array-like | None, default=None Ignored, present for compatibility. @@ -210,42 +223,52 @@ def summary(self, X): Parameters ---------- - X : pd.DataFrame + X : DataFrame Dataframe with the data to split. Returns ------- - pd.DataFrame + DataFrame Summary of all folds. """ summary = [] + X = nw.from_native(X, eager_only=True) X_index_df = self._join_date_and_x(X) - def get_split_info(X, indices, j, part, summary): - dates = X_index_df.iloc[indices]["__date__"] + summary = { + "Start date": [], + "End date": [], + "Period": [], + "Unique days": [], + "nbr samples": [], + "part": [], + "fold": [], + } + native_namespace = nw.get_native_namespace(X) + + def update_split_info(indices, j, part, summary): + dates = X_index_df["__date__"][indices] mindate = dates.min() maxdate = dates.max() + n_unique = dates.n_unique() - s = pd.Series( - { - "Start date": mindate, - "End date": maxdate, - "Period": pd.to_datetime(maxdate, format="%Y%m%d") - pd.to_datetime(mindate, format="%Y%m%d"), - "Unique days": len(dates.unique()), - "nbr samples": len(indices), - }, - name=(j, part), - ) - summary.append(s) - return summary + summary["Start date"].append(mindate) + summary["End date"].append(maxdate) + summary["Period"].append(maxdate - mindate) + summary["Unique days"].append(n_unique) + summary["nbr samples"].append(len(indices)) + summary["part"].append(part) + summary["fold"].append(j) j = 0 - for i in self.split(X): - summary = get_split_info(X, i[0], j, "train", summary) - summary = get_split_info(X, i[1], j, "valid", summary) + for i in self.split(nw.to_native(X)): + update_split_info(native_namespace.Series(i[0]), j, "train", summary) + update_split_info(native_namespace.Series(i[1]), j, "valid", summary) j = j + 1 - return pd.DataFrame(summary) + result = nw.from_native(native_namespace.DataFrame(summary)) + result = nw.maybe_set_index(result, ["fold", "part"]) + return nw.to_native(result) def KlusterFoldValidation(**kwargs): diff --git a/tests/test_model_selection/test_timegapsplit.py b/tests/test_model_selection/test_timegapsplit.py index 48b41b29a..47c5cb45e 100644 --- a/tests/test_model_selection/test_timegapsplit.py +++ b/tests/test_model_selection/test_timegapsplit.py @@ -3,7 +3,10 @@ import numpy as np import pandas as pd +import polars as pl import pytest +from pandas.testing import assert_frame_equal as pandas_assert_frame_equal +from polars.testing import assert_frame_equal as polars_assert_frame_equal from sklearn.linear_model import Lasso from sklearn.model_selection import GridSearchCV from sklearn.pipeline import Pipeline @@ -43,6 +46,42 @@ def test_timegapsplit(): assert valid_mindate == datetime.datetime.strptime("2018-01-21", "%Y-%m-%d") assert valid_maxdate == datetime.datetime.strptime("2018-01-23", "%Y-%m-%d") + expected = [ + (np.array([0, 1, 2, 3, 4]), np.array([5, 6, 7])), + (np.array([3, 4, 5, 6, 7]), np.array([8, 9, 10])), + (np.array([6, 7, 8, 9, 10]), np.array([11, 12, 13])), + (np.array([9, 10, 11, 12, 13]), np.array([14, 15, 16])), + (np.array([12, 13, 14, 15, 16]), np.array([17, 18, 19])), + (np.array([15, 16, 17, 18, 19]), np.array([20, 21, 22])), + ] + for result_indices, expected_indices in zip(list(cv.split(X_train, y_train)), expected): + np.testing.assert_array_equal(result_indices[0], expected_indices[0]) + np.testing.assert_array_equal(result_indices[1], expected_indices[1]) + + # Polars doesn't have an index, so this class behaves a bit differenly for + # index-less objects. We need to first ensure that `date_serie`, `X_train`, + # and `y_train` all have the same length. + date_serie = df["date"].loc[X_train.index] + cv = TimeGapSplit( + date_serie=pl.from_pandas(date_serie), + train_duration=timedelta(days=5), + valid_duration=timedelta(days=3), + gap_duration=timedelta(days=0), + ) + expected = [ + (np.array([0, 1, 2, 3, 4]), np.array([5, 6, 7])), + (np.array([3, 4, 5, 6, 7]), np.array([8, 9, 10])), + (np.array([6, 7, 8, 9, 10]), np.array([11, 12, 13])), + (np.array([9, 10, 11, 12, 13]), np.array([14, 15, 16])), + (np.array([12, 13, 14, 15, 16]), np.array([17, 18, 19])), + (np.array([15, 16, 17, 18, 19]), np.array([20, 21, 22])), + ] + for result_indices, expected_indices in zip( + list(cv.split(pl.from_pandas(X_train), pl.from_pandas(y_train))), expected + ): + np.testing.assert_array_equal(result_indices[0], expected_indices[0]) + np.testing.assert_array_equal(result_indices[1], expected_indices[1]) + def test_timegapsplit_too_big_gap(): try: @@ -151,5 +190,83 @@ def test_timegapsplit_summary(): ) summary = cv.summary(X_train) - assert summary.shape == (12, 5) + + expected_data = { + "Start date": [ + datetime.datetime(2018, 1, 1, 0, 0), + datetime.datetime(2018, 1, 6, 0, 0), + datetime.datetime(2018, 1, 4, 0, 0), + datetime.datetime(2018, 1, 9, 0, 0), + datetime.datetime(2018, 1, 7, 0, 0), + datetime.datetime(2018, 1, 12, 0, 0), + datetime.datetime(2018, 1, 10, 0, 0), + datetime.datetime(2018, 1, 15, 0, 0), + datetime.datetime(2018, 1, 13, 0, 0), + datetime.datetime(2018, 1, 18, 0, 0), + datetime.datetime(2018, 1, 16, 0, 0), + datetime.datetime(2018, 1, 21, 0, 0), + ], + "End date": [ + datetime.datetime(2018, 1, 5, 0, 0), + datetime.datetime(2018, 1, 8, 0, 0), + datetime.datetime(2018, 1, 8, 0, 0), + datetime.datetime(2018, 1, 11, 0, 0), + datetime.datetime(2018, 1, 11, 0, 0), + datetime.datetime(2018, 1, 14, 0, 0), + datetime.datetime(2018, 1, 14, 0, 0), + datetime.datetime(2018, 1, 17, 0, 0), + datetime.datetime(2018, 1, 17, 0, 0), + datetime.datetime(2018, 1, 20, 0, 0), + datetime.datetime(2018, 1, 20, 0, 0), + datetime.datetime(2018, 1, 23, 0, 0), + ], + "Period": [ + datetime.timedelta(days=4), + datetime.timedelta(days=2), + datetime.timedelta(days=4), + datetime.timedelta(days=2), + datetime.timedelta(days=4), + datetime.timedelta(days=2), + datetime.timedelta(days=4), + datetime.timedelta(days=2), + datetime.timedelta(days=4), + datetime.timedelta(days=2), + datetime.timedelta(days=4), + datetime.timedelta(days=2), + ], + "Unique days": [5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3], + "nbr samples": [5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3], + "part": [ + "train", + "valid", + "train", + "valid", + "train", + "valid", + "train", + "valid", + "train", + "valid", + "train", + "valid", + ], + "fold": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], + } + expected = pd.DataFrame(expected_data).set_index(["fold", "part"]) + pandas_assert_frame_equal(summary, expected) + + # Polars doesn't have an index, so this class behaves a bit differenly for + # index-less objects. We need to ensure that `date_serie` and `X_train` have + # the same length. + date_serie = df["date"].loc[X_train.index] + cv = TimeGapSplit( + date_serie=pl.from_pandas(date_serie), + train_duration=timedelta(days=5), + valid_duration=timedelta(days=3), + gap_duration=timedelta(days=0), + ) + summary = cv.summary(pl.from_pandas(X_train)) + + expected = pl.DataFrame(expected_data) + polars_assert_frame_equal(summary, expected)