Skip to content

Commit

Permalink
Update make_future and train_test_split regressors handling (#447)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hse-repository committed Jan 31, 2022
1 parent 287989c commit 7aa64f0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
13 changes: 10 additions & 3 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def _update_regressors(self, transform: "Transform", columns_before: Set[str], c
else:
raise ValueError("Transform is not FutureMixin and does not have in_column attribute!")

new_regressors = [regressor for regressor in new_regressors if regressor not in self.regressors]
self._regressors.extend(new_regressors)

def __repr__(self):
Expand Down Expand Up @@ -281,7 +282,11 @@ def make_future(self, future_steps: int) -> "TSDataset":

future_dataset = df.tail(future_steps).copy(deep=True)
future_dataset = future_dataset.sort_index(axis=1, level=(0, 1))
future_ts = TSDataset(future_dataset, freq=self.freq)
future_ts = TSDataset(df=future_dataset, freq=self.freq)

# can't put known_future into constructor, _check_known_future fails with df_exog=None
future_ts.known_future = self.known_future
future_ts._regressors = self.regressors
future_ts.transforms = self.transforms
future_ts.df_exog = self.df_exog
return future_ts
Expand Down Expand Up @@ -749,13 +754,15 @@ def train_test_split(

train_df = self.df[train_start_defined:train_end_defined][self.raw_df.columns] # type: ignore
train_raw_df = self.raw_df[train_start_defined:train_end_defined] # type: ignore
train = TSDataset(df=train_df, df_exog=self.df_exog, freq=self.freq)
train = TSDataset(df=train_df, df_exog=self.df_exog, freq=self.freq, known_future=self.known_future)
train.raw_df = train_raw_df
train._regressors = self.regressors

test_df = self.df[test_start_defined:test_end_defined][self.raw_df.columns] # type: ignore
test_raw_df = self.raw_df[train_start_defined:test_end_defined] # type: ignore
test = TSDataset(df=test_df, df_exog=self.df_exog, freq=self.freq)
test = TSDataset(df=test_df, df_exog=self.df_exog, freq=self.freq, known_future=self.known_future)
test.raw_df = test_raw_df
test._regressors = self.regressors

return train, test

Expand Down
15 changes: 15 additions & 0 deletions tests/test_datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,14 @@ def test_train_test_split_failed(test_size, borders, match, tsdf_with_exog):
)


def test_train_test_split_pass_regressors_to_output(df_and_regressors):
df, df_exog, known_future = df_and_regressors
ts = TSDataset(df=df, df_exog=df_exog, freq="D", known_future=known_future)
train, test = ts.train_test_split(test_size=5)
assert train.regressors == ts.regressors
assert test.regressors == ts.regressors


def test_dataset_datetime_conversion():
classic_df = generate_ar_df(periods=30, start_time="2021-06-01", n_segments=2)
classic_df["timestamp"] = classic_df["timestamp"].astype(str)
Expand Down Expand Up @@ -375,6 +383,13 @@ def test_make_future_with_regressors(df_and_regressors):
assert set(ts_future.columns.get_level_values("feature")) == {"target", "regressor_1", "regressor_2"}


def test_make_future_inherits_regressors(df_and_regressors):
df, df_exog, known_future = df_and_regressors
ts = TSDataset(df=df, df_exog=df_exog, freq="D", known_future=known_future)
ts_future = ts.make_future(10)
assert ts_future.regressors == ts.regressors


def test_make_future_warn_not_enough_regressors(df_and_regressors):
"""Check that warning is thrown if regressors don't have enough values for the future."""
df, df_exog, known_future = df_and_regressors
Expand Down

0 comments on commit 7aa64f0

Please sign in to comment.