Fix DeepARModel and TFTModel to work with changed prediction_size #1251
Conversation
🚀 Deployed on https://deploy-preview-1251--etna-docs.netlify.app
Codecov Report
```
@@           Coverage Diff            @@
##           master    #1251    +/-  ##
========================================
+ Coverage   87.31%   87.67%   +0.35%
========================================
  Files         175      175
  Lines       10330    10330
========================================
+ Hits         9020     9057      +37
+ Misses       1310     1273      -37
```

... and 5 files with indirect coverage changes
```
@@ -139,19 +140,17 @@ def test_forecast_model_equals_pipeline(example_tsds):
    horizon = 10
    pfdb = _get_default_dataset_builder(horizon)

    import torch  # TODO: remove after fix at issue-802
```
We don't make the fix, do we? We use the seed as in the past anyway.
Yes, I thought that it isn't really a problem that it isn't deterministic. If someone needs it to be deterministic, he can fix the seeds.
```
@@ -2,6 +2,7 @@

import pandas as pd
import pytest
from lightning_fabric.utilities.seed import seed_everything
```
It seems like we don't have that package in pyproject.toml
As I understand, it is part of the pytorch_lightning package (source). I first tried to use pytorch_lightning.utilities.seed, but it is deprecated in favor of lightning_fabric.utilities.seed.
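If we also had to support older pytorch_lightning versions without lightning_fabric, a hypothetical compatibility shim (not part of this PR) could look like this:

```python
# Hypothetical compatibility shim: prefer the new location of
# seed_everything and fall back to the deprecated one on older versions.
try:
    from lightning_fabric.utilities.seed import seed_everything
except ImportError:  # older pytorch_lightning without lightning_fabric
    from pytorch_lightning.utilities.seed import seed_everything

seed_everything(0)  # fixes random seeds for python, numpy and torch
```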
```python
# `TimeSeriesDataSet.from_parameters` in predict mode ignores `min_prediction_length`,
# and we can change prediction size only by changing `max_prediction_length`
dataset_params = deepcopy(self.pf_dataset_params)
dataset_params["max_prediction_length"] = horizon
```
It seems this could change the behaviour.
Have you checked both results, before and after the change?
I'll explain the core of the problem. The problem is that max_prediction_length is set during training. You can set min_prediction_length, but it is ignored during forecasting and set equal to max_prediction_length; that is how pytorch_forecasting works.
It leads to a situation where you can't make a forecast on a dataset with a smaller horizon than was used during training: the model expects to forecast max_prediction_length points.
About the identity of the results, I'll write a report below in this discussion.
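To make the workaround concrete, here is a minimal sketch of the idea from the diff above (the standalone helper name here is hypothetical, for illustration only):

```python
from copy import deepcopy

# `min_prediction_length` is ignored by `TimeSeriesDataSet.from_parameters`
# in predict mode, so to forecast fewer points than the model was trained
# with, we override `max_prediction_length` itself with the requested horizon
# before building the predict dataset.
def make_predict_dataset_params(pf_dataset_params: dict, horizon: int) -> dict:
    dataset_params = deepcopy(pf_dataset_params)
    dataset_params["max_prediction_length"] = horizon
    return dataset_params
```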
I have a script:

```python
from pytorch_forecasting.data import GroupNormalizer
from lightning_fabric.utilities.seed import seed_everything

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.models.nn import DeepARModel, PytorchForecastingDatasetBuilder
from etna.pipeline import Pipeline


def main():
    # load data
    df = generate_ar_df(periods=100, n_segments=3, start_time="2020-01-01", freq="D", random_seed=0)
    ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")

    # fit pipeline
    builder = PytorchForecastingDatasetBuilder(
        max_encoder_length=5,
        max_prediction_length=5,
        time_varying_known_reals=["time_idx"],
        time_varying_unknown_reals=["target"],
        target_normalizer=GroupNormalizer(groups=["segment"]),
    )
    model = DeepARModel(dataset_builder=builder, trainer_params=dict(max_epochs=5), lr=0.01)
    pipeline = Pipeline(model=model, horizon=5)
    seed_everything(0)
    pipeline.fit(ts)

    # forecast
    ts_forecast_1 = pipeline.forecast()
    print(ts_forecast_1.to_pandas(flatten=True))


if __name__ == "__main__":
    main()
```
On this branch the result is:

```
    timestamp    segment    target
0  2020-04-10  segment_0  5.302200
1  2020-04-11  segment_0  5.245512
2  2020-04-12  segment_0  5.072897
3  2020-04-13  segment_0  4.956703
4  2020-04-14  segment_0  4.994010
5  2020-04-10  segment_1  8.561027
6  2020-04-11  segment_1  8.674180
7  2020-04-12  segment_1  8.962281
8  2020-04-13  segment_1  8.377350
9  2020-04-14  segment_1  8.384933
10 2020-04-10  segment_2 -6.026866
11 2020-04-11  segment_2 -6.008684
12 2020-04-12  segment_2 -5.824715
13 2020-04-13  segment_2 -6.108476
14 2020-04-14  segment_2 -5.785951
```
The result is the same on the current master branch.
Code behavior isn't the same for all cases, because the newer version fixes the inference tests. But I think this example shows that for the normal scenario we haven't changed the logic of DeepARModel.
In this example both max_prediction_length and horizon are the same.
I guess a difference could arise when the horizon is smaller than max_prediction_length, in the case of transformers (maybe if there is a difference, it's a bug of the source library; we shouldn't have a bug in the case of causal transformers).
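A possible way to check this (a sketch, not from this PR; it assumes `ts` and a fitted `model` built as in the experiment below, with seeds fixed before each call because DeepAR sampling is stochastic):

```python
import numpy as np
from lightning_fabric.utilities.seed import seed_everything

# Compare a 3-step forecast against the first 3 steps of a 5-step forecast
# to see whether prediction_size < max_prediction_length changes the values.
seed_everything(0)
future_full = ts.make_future(future_steps=5, tail_steps=model.context_size)
full_df = model.forecast(future_full, prediction_size=5).to_pandas(flatten=True)

seed_everything(0)
future_short = ts.make_future(future_steps=3, tail_steps=model.context_size)
short_df = model.forecast(future_short, prediction_size=3).to_pandas(flatten=True)

merged = short_df.merge(full_df, on=["timestamp", "segment"], suffixes=("_short", "_full"))
print(np.allclose(merged["target_short"], merged["target_full"], atol=1e-5))
```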
Another experiment:

```python
from pytorch_forecasting.data import GroupNormalizer
from lightning_fabric.utilities.seed import seed_everything

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.models.nn import DeepARModel, PytorchForecastingDatasetBuilder


def main():
    # load data
    df = generate_ar_df(periods=100, n_segments=3, start_time="2020-01-01", freq="D", random_seed=0)
    ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")

    # fit model
    builder = PytorchForecastingDatasetBuilder(
        max_encoder_length=5,
        max_prediction_length=5,
        time_varying_known_reals=["time_idx"],
        time_varying_unknown_reals=["target"],
        target_normalizer=GroupNormalizer(groups=["segment"]),
    )
    model = DeepARModel(dataset_builder=builder, trainer_params=dict(max_epochs=5), lr=0.01)
    seed_everything(0)
    model.fit(ts)

    # forecast with prediction_size smaller than max_prediction_length
    future = ts.make_future(future_steps=3, tail_steps=model.context_size)
    result = model.forecast(future, prediction_size=3)
    print(result.to_pandas(flatten=True))


if __name__ == "__main__":
    main()
```
Current master branch: fails with the error `AssertionError: filters should not remove entries all entries - check encoder/decoder lengths and lags`.
This branch: works fine, with the result:

```
   timestamp    segment    target
0 2020-04-10  segment_0  5.302200
1 2020-04-11  segment_0  5.245512
2 2020-04-12  segment_0  5.072897
3 2020-04-10  segment_1  8.561027
4 2020-04-11  segment_1  8.674180
5 2020-04-12  segment_1  8.962281
6 2020-04-10  segment_2 -6.026866
7 2020-04-11  segment_2 -6.008684
8 2020-04-12  segment_2 -5.824716
```
The goal of this change was exactly to make it possible to make a forecast with a smaller horizon. We have some inference tests that worked in a similar scenario before tsdataset-2.0 and stopped working after it. So I thought that we should make it work as before.
Before submitting (must do checklist)
Proposed Changes

- Fix DeepARModel and TFTModel to work with changed prediction_size.
- Don't fix the random seed inside DeepARModel. If someone wants to make it deterministic, he can use seed_everything (see the sketch below).
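For example (a minimal sketch, assuming `pipeline` and `ts` are built as in the scripts above):

```python
from lightning_fabric.utilities.seed import seed_everything

# Fixing seeds before fitting/forecasting makes DeepARModel reproducible;
# `pipeline` and `ts` are assumed to be built as in the scripts above.
seed_everything(0)
pipeline.fit(ts)
forecast = pipeline.forecast()
```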
Closing issues
Closes #802.