make timegapsplit dataframe-agnostic
MarcoGorelli committed May 12, 2024
1 parent 0773db9 commit 25b04f9
Showing 3 changed files with 186 additions and 46 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -20,7 +20,7 @@ maintainers = [
]

dependencies = [
"narwhals>=0.8.9",
"narwhals>=0.8.12",
"pandas>=1.1.5",
"scikit-learn>=1.0",
"importlib-metadata >= 1.0; python_version < '3.8'",
111 changes: 67 additions & 44 deletions sklego/model_selection.py
@@ -3,6 +3,7 @@
from itertools import combinations
from warnings import warn

import narwhals as nw
import numpy as np
import pandas as pd
from sklearn.exceptions import NotFittedError
@@ -44,8 +45,10 @@ class TimeGapSplit:
Parameters
----------
date_serie : pd.Series
date_serie : Series
Series with the dates; it should contain all the indices of X used in the `split()` method.
If the Series is not pandas-like (for example, if it's a Polars Series, which does not have
an index) then it must have the same length as the `X` and `y` objects passed to `split`.
valid_duration : datetime.timedelta
Retraining period.
train_duration : datetime.timedelta | None, default=None
@@ -65,6 +68,20 @@ class TimeGapSplit:
- `"rolling"` window has fixed size and is shifted entirely.
- `"expanding"` left side of window is fixed, right border increases each fold.
Notes
-----
Native cross-dataframe support is achieved using
[Narwhals](https://narwhals-dev.github.io/narwhals/){:target="_blank"}.
Supported dataframes are:
- pandas
- Polars (eager)
- Modin
See [Narwhals docs](https://narwhals-dev.github.io/narwhals/extending/){:target="_blank"} for an up-to-date list
(and to learn how you can add your dataframe library to it!), though note that only those
convertible to `numpy` arrays will work with this class.
"""

def __init__(
@@ -82,11 +99,7 @@ def __init__(
if (train_duration is not None) and (train_duration <= gap_duration):
raise ValueError("gap_duration is longer than train_duration, it should be shorter.")

if not date_serie.index.is_unique:
raise ValueError("date_serie doesn't have a unique index")

self.date_serie = date_serie.copy()
self.date_serie = self.date_serie.rename("__date__")
self.date_serie = nw.from_native(date_serie, series_only=True).alias("__date__")
self.train_duration = train_duration
self.valid_duration = valid_duration
self.gap_duration = gap_duration
@@ -98,13 +111,15 @@ def _join_date_and_x(self, X):
index and with the 'numpy index' column (i.e. just a range) that is required for the output and the rest of
sklearn.
If the user is working with index-less dataframes (e.g. Polars), then `self.date_serie` needs to be the same
length as `X`.
Parameters
----------
X : pd.DataFrame
X : DataFrame
Dataframe with the data to split
"""
X_index_df = pd.DataFrame(range(len(X)), columns=["np_index"], index=X.index)
X_index_df = X_index_df.join(self.date_serie)
X_index_df = nw.maybe_align_index(self.date_serie, X).to_frame().with_row_index("np_index")

return X_index_df

@@ -113,7 +128,7 @@ def split(self, X, y=None, groups=None):
Parameters
----------
X : pd.DataFrame
X : DataFrame
Dataframe with the data to split.
y : array-like | None, default=None
Ignored, present for compatibility.
@@ -126,8 +141,9 @@
Train and test indices of the same fold.
"""

X = nw.from_native(X, eager_only=True)
X_index_df = self._join_date_and_x(X)
X_index_df = X_index_df.sort_values("__date__", ascending=True)
X_index_df = X_index_df.sort("__date__", descending=False)

if len(X) != len(X_index_df):
raise AssertionError(
@@ -167,31 +183,28 @@
if current_date + self.train_duration + time_shift + self.gap_duration > date_max:
break

X_train_df = X_index_df[
(X_index_df["__date__"] >= start_date) & (X_index_df["__date__"] < current_date + self.train_duration)
]
X_valid_df = X_index_df[
(X_index_df["__date__"] >= current_date + self.train_duration + self.gap_duration)
& (
X_index_df["__date__"]
< current_date + self.train_duration + self.valid_duration + self.gap_duration
)
]
X_train_df = X_index_df.filter(
nw.col("__date__") >= start_date, nw.col("__date__") < current_date + self.train_duration
)
X_valid_df = X_index_df.filter(
nw.col("__date__") >= current_date + self.train_duration + self.gap_duration,
nw.col("__date__") < current_date + self.train_duration + self.valid_duration + self.gap_duration,
)

current_date = current_date + time_shift
if self.window == "rolling":
start_date = current_date
yield (
X_train_df["np_index"].values,
X_valid_df["np_index"].values,
X_train_df["np_index"].to_numpy(),
X_valid_df["np_index"].to_numpy(),
)

def get_n_splits(self, X=None, y=None, groups=None):
"""Get the number of splits
Parameters
----------
X : pd.DataFrame
X : DataFrame
Dataframe with the data to split.
y : array-like | None, default=None
Ignored, present for compatibility.
@@ -210,42 +223,52 @@ def summary(self, X):
Parameters
----------
X : pd.DataFrame
X : DataFrame
Dataframe with the data to split.
Returns
-------
pd.DataFrame
DataFrame
Summary of all folds.
"""
summary = []
X = nw.from_native(X, eager_only=True)
X_index_df = self._join_date_and_x(X)

def get_split_info(X, indices, j, part, summary):
dates = X_index_df.iloc[indices]["__date__"]
summary = {
"Start date": [],
"End date": [],
"Period": [],
"Unique days": [],
"nbr samples": [],
"part": [],
"fold": [],
}
native_namespace = nw.get_native_namespace(X)

def update_split_info(indices, j, part, summary):
dates = X_index_df["__date__"][indices]
mindate = dates.min()
maxdate = dates.max()
n_unique = dates.n_unique()

s = pd.Series(
{
"Start date": mindate,
"End date": maxdate,
"Period": pd.to_datetime(maxdate, format="%Y%m%d") - pd.to_datetime(mindate, format="%Y%m%d"),
"Unique days": len(dates.unique()),
"nbr samples": len(indices),
},
name=(j, part),
)
summary.append(s)
return summary
summary["Start date"].append(mindate)
summary["End date"].append(maxdate)
summary["Period"].append(maxdate - mindate)
summary["Unique days"].append(n_unique)
summary["nbr samples"].append(len(indices))
summary["part"].append(part)
summary["fold"].append(j)

j = 0
for i in self.split(X):
summary = get_split_info(X, i[0], j, "train", summary)
summary = get_split_info(X, i[1], j, "valid", summary)
for i in self.split(nw.to_native(X)):
update_split_info(native_namespace.Series(i[0]), j, "train", summary)
update_split_info(native_namespace.Series(i[1]), j, "valid", summary)
j = j + 1

return pd.DataFrame(summary)
result = nw.from_native(native_namespace.DataFrame(summary))
result = nw.maybe_set_index(result, ["fold", "part"])
return nw.to_native(result)


def KlusterFoldValidation(**kwargs):
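To make the dataframe-agnostic behaviour above concrete, here is a minimal usage sketch. The data (a 23-day daily date range and a single `feature` column) and the variable names are illustrative assumptions rather than part of this commit; the constructor arguments and `split` calls mirror the tests below.

from datetime import timedelta

import pandas as pd
import polars as pl

from sklego.model_selection import TimeGapSplit

# Illustrative data: 23 consecutive daily timestamps and one feature column.
dates = pd.Series(pd.date_range("2018-01-01", periods=23, freq="D"), name="date")
X = pd.DataFrame({"feature": range(23)})

# pandas input: date_serie is aligned with X through its index.
cv = TimeGapSplit(
    date_serie=dates,
    train_duration=timedelta(days=5),
    valid_duration=timedelta(days=3),
    gap_duration=timedelta(days=0),
)
for train_idx, valid_idx in cv.split(X):
    print(train_idx, valid_idx)  # numpy arrays of positional indices

# Polars input (index-less): date_serie only needs the same length as X.
cv_pl = TimeGapSplit(
    date_serie=pl.from_pandas(dates),
    train_duration=timedelta(days=5),
    valid_duration=timedelta(days=3),
    gap_duration=timedelta(days=0),
)
for train_idx, valid_idx in cv_pl.split(pl.from_pandas(X)):
    print(train_idx, valid_idx)

Either way, `split` yields positional train/validation index arrays that can be passed to scikit-learn utilities such as `GridSearchCV`.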
119 changes: 118 additions & 1 deletion tests/test_model_selection/test_timegapsplit.py
@@ -3,7 +3,10 @@

import numpy as np
import pandas as pd
import polars as pl
import pytest
from pandas.testing import assert_frame_equal as pandas_assert_frame_equal
from polars.testing import assert_frame_equal as polars_assert_frame_equal
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
@@ -43,6 +46,42 @@ def test_timegapsplit():
assert valid_mindate == datetime.datetime.strptime("2018-01-21", "%Y-%m-%d")
assert valid_maxdate == datetime.datetime.strptime("2018-01-23", "%Y-%m-%d")

expected = [
(np.array([0, 1, 2, 3, 4]), np.array([5, 6, 7])),
(np.array([3, 4, 5, 6, 7]), np.array([8, 9, 10])),
(np.array([6, 7, 8, 9, 10]), np.array([11, 12, 13])),
(np.array([9, 10, 11, 12, 13]), np.array([14, 15, 16])),
(np.array([12, 13, 14, 15, 16]), np.array([17, 18, 19])),
(np.array([15, 16, 17, 18, 19]), np.array([20, 21, 22])),
]
for result_indices, expected_indices in zip(list(cv.split(X_train, y_train)), expected):
np.testing.assert_array_equal(result_indices[0], expected_indices[0])
np.testing.assert_array_equal(result_indices[1], expected_indices[1])

# Polars doesn't have an index, so this class behaves a bit differently for
# index-less objects. We need to first ensure that `date_serie`, `X_train`,
# and `y_train` all have the same length.
date_serie = df["date"].loc[X_train.index]
cv = TimeGapSplit(
date_serie=pl.from_pandas(date_serie),
train_duration=timedelta(days=5),
valid_duration=timedelta(days=3),
gap_duration=timedelta(days=0),
)
expected = [
(np.array([0, 1, 2, 3, 4]), np.array([5, 6, 7])),
(np.array([3, 4, 5, 6, 7]), np.array([8, 9, 10])),
(np.array([6, 7, 8, 9, 10]), np.array([11, 12, 13])),
(np.array([9, 10, 11, 12, 13]), np.array([14, 15, 16])),
(np.array([12, 13, 14, 15, 16]), np.array([17, 18, 19])),
(np.array([15, 16, 17, 18, 19]), np.array([20, 21, 22])),
]
for result_indices, expected_indices in zip(
list(cv.split(pl.from_pandas(X_train), pl.from_pandas(y_train))), expected
):
np.testing.assert_array_equal(result_indices[0], expected_indices[0])
np.testing.assert_array_equal(result_indices[1], expected_indices[1])


def test_timegapsplit_too_big_gap():
try:
@@ -151,5 +190,83 @@ def test_timegapsplit_summary():
)

summary = cv.summary(X_train)

assert summary.shape == (12, 5)

expected_data = {
"Start date": [
datetime.datetime(2018, 1, 1, 0, 0),
datetime.datetime(2018, 1, 6, 0, 0),
datetime.datetime(2018, 1, 4, 0, 0),
datetime.datetime(2018, 1, 9, 0, 0),
datetime.datetime(2018, 1, 7, 0, 0),
datetime.datetime(2018, 1, 12, 0, 0),
datetime.datetime(2018, 1, 10, 0, 0),
datetime.datetime(2018, 1, 15, 0, 0),
datetime.datetime(2018, 1, 13, 0, 0),
datetime.datetime(2018, 1, 18, 0, 0),
datetime.datetime(2018, 1, 16, 0, 0),
datetime.datetime(2018, 1, 21, 0, 0),
],
"End date": [
datetime.datetime(2018, 1, 5, 0, 0),
datetime.datetime(2018, 1, 8, 0, 0),
datetime.datetime(2018, 1, 8, 0, 0),
datetime.datetime(2018, 1, 11, 0, 0),
datetime.datetime(2018, 1, 11, 0, 0),
datetime.datetime(2018, 1, 14, 0, 0),
datetime.datetime(2018, 1, 14, 0, 0),
datetime.datetime(2018, 1, 17, 0, 0),
datetime.datetime(2018, 1, 17, 0, 0),
datetime.datetime(2018, 1, 20, 0, 0),
datetime.datetime(2018, 1, 20, 0, 0),
datetime.datetime(2018, 1, 23, 0, 0),
],
"Period": [
datetime.timedelta(days=4),
datetime.timedelta(days=2),
datetime.timedelta(days=4),
datetime.timedelta(days=2),
datetime.timedelta(days=4),
datetime.timedelta(days=2),
datetime.timedelta(days=4),
datetime.timedelta(days=2),
datetime.timedelta(days=4),
datetime.timedelta(days=2),
datetime.timedelta(days=4),
datetime.timedelta(days=2),
],
"Unique days": [5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3],
"nbr samples": [5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3],
"part": [
"train",
"valid",
"train",
"valid",
"train",
"valid",
"train",
"valid",
"train",
"valid",
"train",
"valid",
],
"fold": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
}
expected = pd.DataFrame(expected_data).set_index(["fold", "part"])
pandas_assert_frame_equal(summary, expected)

# Polars doesn't have an index, so this class behaves a bit differently for
# index-less objects. We need to ensure that `date_serie` and `X_train` have
# the same length.
date_serie = df["date"].loc[X_train.index]
cv = TimeGapSplit(
date_serie=pl.from_pandas(date_serie),
train_duration=timedelta(days=5),
valid_duration=timedelta(days=3),
gap_duration=timedelta(days=0),
)
summary = cv.summary(pl.from_pandas(X_train))

expected = pl.DataFrame(expected_data)
polars_assert_frame_equal(summary, expected)
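
As a complement to the assertions above, here is a small self-contained sketch of the backend-dependent `summary` output (the toy data is made up, not the fixtures used in these tests): with pandas input the result is indexed by ("fold", "part") via `nw.maybe_set_index`, whereas with Polars input `fold` and `part` stay as ordinary columns.

from datetime import timedelta

import pandas as pd
import polars as pl

from sklego.model_selection import TimeGapSplit

# Illustrative data of matching length for both backends.
dates = pd.Series(pd.date_range("2018-01-01", periods=23, freq="D"), name="date")
X = pd.DataFrame({"feature": range(23)})
kwargs = dict(
    train_duration=timedelta(days=5),
    valid_duration=timedelta(days=3),
    gap_duration=timedelta(days=0),
)

# pandas input -> pandas DataFrame with a (fold, part) MultiIndex.
summary_pd = TimeGapSplit(date_serie=dates, **kwargs).summary(X)
print(summary_pd.index.names)

# Polars input -> Polars DataFrame where "fold" and "part" are regular columns.
summary_pl = TimeGapSplit(date_serie=pl.from_pandas(dates), **kwargs).summary(pl.from_pandas(X))
print(summary_pl.columns)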
