From 3a8cafa48454735a208e356a5627f1e2119c58e1 Mon Sep 17 00:00:00 2001 From: brsnw250 Date: Wed, 26 Jul 2023 15:54:39 +0300 Subject: [PATCH 1/4] added tests --- tests/test_commands/conftest.py | 36 ++++++++++++++++++++++++++++ tests/test_commands/test_backtest.py | 24 ++++++++++++------- tests/test_commands/test_forecast.py | 32 +++++++++++++++++-------- 3 files changed, 73 insertions(+), 19 deletions(-) diff --git a/tests/test_commands/conftest.py b/tests/test_commands/conftest.py index e1c48f13f..0663ab8a8 100644 --- a/tests/test_commands/conftest.py +++ b/tests/test_commands/conftest.py @@ -50,6 +50,42 @@ def base_pipeline_with_context_size_yaml_path(): tmp.close() +@pytest.fixture +def base_ensemble_yaml_path(): + tmp = NamedTemporaryFile("w") + tmp.write( + """ + _target_: etna.ensembles.VotingEnsemble + pipelines: + - _target_: etna.pipeline.Pipeline + horizon: 4 + model: + _target_: etna.models.SeasonalMovingAverageModel + seasonality: 4 + window: 1 + transforms: [] + - _target_: etna.pipeline.Pipeline + horizon: 4 + model: + _target_: etna.models.SeasonalMovingAverageModel + seasonality: 7 + window: 2 + transforms: [] + - _target_: etna.pipeline.Pipeline + horizon: 4 + model: + _target_: etna.models.SeasonalMovingAverageModel + seasonality: 7 + window: 7 + transforms: [] + context_size: 49 + """ + ) + tmp.flush() + yield Path(tmp.name) + tmp.close() + + @pytest.fixture def elementary_linear_model_pipeline(): tmp = NamedTemporaryFile("w") diff --git a/tests/test_commands/test_backtest.py b/tests/test_commands/test_backtest.py index 0c253bc6e..8ab83e2b3 100644 --- a/tests/test_commands/test_backtest.py +++ b/tests/test_commands/test_backtest.py @@ -67,14 +67,16 @@ def backtest_with_stride_yaml_path(): tmp.close() -def test_dummy_run(base_pipeline_yaml_path, base_backtest_yaml_path, base_timeseries_path): +@pytest.mark.parametrize("pipeline_path_name", ("base_pipeline_yaml_path", "base_ensemble_yaml_path")) +def test_dummy_run(pipeline_path_name, base_backtest_yaml_path, base_timeseries_path, request): tmp_output = TemporaryDirectory() tmp_output_path = Path(tmp_output.name) + pipeline_path = request.getfixturevalue(pipeline_path_name) run( [ "etna", "backtest", - str(base_pipeline_yaml_path), + str(pipeline_path), str(base_backtest_yaml_path), str(base_timeseries_path), "D", @@ -85,16 +87,18 @@ def test_dummy_run(base_pipeline_yaml_path, base_backtest_yaml_path, base_timese assert Path.exists(tmp_output_path / file_name) +@pytest.mark.parametrize("pipeline_path_name", ("base_pipeline_yaml_path", "base_ensemble_yaml_path")) def test_dummy_run_with_exog( - base_pipeline_yaml_path, base_backtest_yaml_path, base_timeseries_path, base_timeseries_exog_path + pipeline_path_name, base_backtest_yaml_path, base_timeseries_path, base_timeseries_exog_path, request ): tmp_output = TemporaryDirectory() tmp_output_path = Path(tmp_output.name) + pipeline_path = request.getfixturevalue(pipeline_path_name) run( [ "etna", "backtest", - str(base_pipeline_yaml_path), + str(pipeline_path), str(base_backtest_yaml_path), str(base_timeseries_path), "D", @@ -126,16 +130,18 @@ def test_forecast_format(base_pipeline_yaml_path, base_backtest_yaml_path, base_ @pytest.mark.parametrize( - "backtest_config_path_name,expected", + "pipeline_path_name,backtest_config_path_name,expected", ( - ("backtest_with_folds_estimation_yaml_path", 24), - ("backtest_with_stride_yaml_path", 1), + ("base_pipeline_with_context_size_yaml_path", "backtest_with_folds_estimation_yaml_path", 24), + ("base_ensemble_yaml_path", "backtest_with_folds_estimation_yaml_path", 12), + ("base_pipeline_with_context_size_yaml_path", "backtest_with_stride_yaml_path", 1), ), ) def test_backtest_estimate_n_folds( - base_pipeline_with_context_size_yaml_path, backtest_config_path_name, base_timeseries_path, expected, request + pipeline_path_name, backtest_config_path_name, base_timeseries_path, expected, request ): backtest_config_path = request.getfixturevalue(backtest_config_path_name) + pipeline_path = request.getfixturevalue(pipeline_path_name) tmp_output = TemporaryDirectory() tmp_output_path = Path(tmp_output.name) @@ -143,7 +149,7 @@ def test_backtest_estimate_n_folds( [ "etna", "backtest", - str(base_pipeline_with_context_size_yaml_path), + str(pipeline_path), str(backtest_config_path), str(base_timeseries_path), "D", diff --git a/tests/test_commands/test_forecast.py b/tests/test_commands/test_forecast.py index cf0ddaa1e..383ff93aa 100644 --- a/tests/test_commands/test_forecast.py +++ b/tests/test_commands/test_forecast.py @@ -59,14 +59,16 @@ def base_forecast_with_folds_estimation_omegaconf_path(): tmp.close() -def test_dummy_run_with_exog(base_pipeline_yaml_path, base_timeseries_path, base_timeseries_exog_path): +@pytest.mark.parametrize("pipeline_path_name", ("base_pipeline_yaml_path", "base_ensemble_yaml_path")) +def test_dummy_run_with_exog(pipeline_path_name, base_timeseries_path, base_timeseries_exog_path, request): tmp_output = NamedTemporaryFile("w") tmp_output_path = Path(tmp_output.name) + pipeline_path = request.getfixturevalue(pipeline_path_name) run( [ "etna", "forecast", - str(base_pipeline_yaml_path), + str(pipeline_path), str(base_timeseries_path), "D", str(tmp_output_path), @@ -103,16 +105,18 @@ def test_dummy_run(base_pipeline_yaml_path, base_timeseries_path): assert len(df_output) == 2 * 4 +@pytest.mark.parametrize("pipeline_path_name", ("base_pipeline_yaml_path", "base_ensemble_yaml_path")) def test_run_with_predictive_intervals( - base_pipeline_yaml_path, base_timeseries_path, base_timeseries_exog_path, base_forecast_omegaconf_path + pipeline_path_name, base_timeseries_path, base_timeseries_exog_path, base_forecast_omegaconf_path, request ): tmp_output = NamedTemporaryFile("w") tmp_output_path = Path(tmp_output.name) + pipeline_path = request.getfixturevalue(pipeline_path_name) run( [ "etna", "forecast", - str(base_pipeline_yaml_path), + str(pipeline_path), str(base_timeseries_path), "D", str(tmp_output_path), @@ -220,17 +224,21 @@ def test_filter_forecast(forecast_params, expected, example_tsds): ], ) def test_forecast_start_timestamp( - model_pipeline, base_timeseries_path, base_timeseries_exog_path, start_timestamp_forecast_omegaconf_path, request + pipeline_path_name, + base_timeseries_path, + base_timeseries_exog_path, + start_timestamp_forecast_omegaconf_path, + request, ): tmp_output = NamedTemporaryFile("w") tmp_output_path = Path(tmp_output.name) - model_pipeline = request.getfixturevalue(model_pipeline) + pipeline_path = request.getfixturevalue(pipeline_path_name) run( [ "etna", "forecast", - str(model_pipeline), + str(pipeline_path), str(base_timeseries_path), "D", str(tmp_output_path), @@ -240,24 +248,28 @@ def test_forecast_start_timestamp( ) df_output = pd.read_csv(tmp_output_path) - assert len(df_output) == 3 * 2 # 3 predictions for 2 segments + assert len(df_output) == 4 * 2 # 4 predictions for 2 segments assert df_output["timestamp"].min() == "2021-09-10" # start_timestamp assert not np.any(df_output.isna().values) +@pytest.mark.parametrize("pipeline_path_name", ("base_pipeline_with_context_size_yaml_path", "base_ensemble_yaml_path")) def test_forecast_estimate_n_folds( - base_pipeline_with_context_size_yaml_path, + pipeline_path_name, base_forecast_with_folds_estimation_omegaconf_path, base_timeseries_path, base_timeseries_exog_path, + request, ): tmp_output = NamedTemporaryFile("w") tmp_output_path = Path(tmp_output.name) + pipeline_path = request.getfixturevalue(pipeline_path_name) + run( [ "etna", "forecast", - str(base_pipeline_with_context_size_yaml_path), + str(pipeline_path), str(base_timeseries_path), "D", str(tmp_output_path), From 08af37e53b4feed151e3fc6658177fe2dc74c524 Mon Sep 17 00:00:00 2001 From: brsnw250 Date: Wed, 26 Jul 2023 15:55:12 +0300 Subject: [PATCH 2/4] handle ensemble horizons --- etna/commands/forecast_command.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/etna/commands/forecast_command.py b/etna/commands/forecast_command.py index d1b863ed3..f22294e4b 100644 --- a/etna/commands/forecast_command.py +++ b/etna/commands/forecast_command.py @@ -44,6 +44,14 @@ def compute_horizon(horizon: int, forecast_params: Dict[str, Any], tsdataset: TS return horizon +def update_horizon(pipeline_configs: Dict[str, Any], forecast_params: Dict[str, Any], tsdataset: TSDataset): + """Update the ``horizon`` parameter in the pipeline config if ``start_timestamp`` is set.""" + for config in pipeline_configs.get("pipelines", [pipeline_configs]): + horizon: int = config["horizon"] # type: ignore + horizon = compute_horizon(horizon=horizon, forecast_params=forecast_params, tsdataset=tsdataset) + config["horizon"] = horizon # type: ignore + + def filter_forecast(forecast_ts: TSDataset, forecast_params: Dict[str, Any]) -> TSDataset: """Filter out forecasts before `start_timestamp` if `start_timestamp` presented in `forecast_params`..""" if "start_timestamp" in forecast_params: @@ -122,9 +130,7 @@ def forecast( tsdataset = TSDataset(df=df_timeseries, freq=freq, df_exog=df_exog, known_future=k_f) - horizon: int = pipeline_configs["horizon"] # type: ignore - horizon = compute_horizon(horizon=horizon, forecast_params=forecast_params, tsdataset=tsdataset) - pipeline_configs["horizon"] = horizon # type: ignore + update_horizon(pipeline_configs=pipeline_configs, forecast_params=forecast_params, tsdataset=tsdataset) pipeline_args = remove_params(params=pipeline_configs, to_remove=ADDITIONAL_PIPELINE_PARAMETERS) pipeline: Pipeline = hydra_slayer.get_from_params(**pipeline_args) From f9ae54d24439b5dfa37693548c80b17585d0896a Mon Sep 17 00:00:00 2001 From: brsnw250 Date: Wed, 26 Jul 2023 15:55:33 +0300 Subject: [PATCH 3/4] added tests --- tests/test_commands/test_forecast.py | 32 +++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/tests/test_commands/test_forecast.py b/tests/test_commands/test_forecast.py index 383ff93aa..c95ec7b6e 100644 --- a/tests/test_commands/test_forecast.py +++ b/tests/test_commands/test_forecast.py @@ -2,12 +2,17 @@ from subprocess import run from tempfile import NamedTemporaryFile +import hydra_slayer import numpy as np import pandas as pd import pytest +from omegaconf import OmegaConf +from etna.commands.forecast_command import ADDITIONAL_PIPELINE_PARAMETERS from etna.commands.forecast_command import compute_horizon from etna.commands.forecast_command import filter_forecast +from etna.commands.forecast_command import update_horizon +from etna.commands.utils import remove_params from etna.datasets import TSDataset @@ -217,11 +222,28 @@ def test_filter_forecast(forecast_params, expected, example_tsds): @pytest.mark.parametrize( - "model_pipeline", - [ - "elementary_linear_model_pipeline", - "elementary_boosting_model_pipeline", - ], + "forecast_params,pipeline_path_name,expected", + ( + ({"start_timestamp": "2020-04-10"}, "base_pipeline_with_context_size_yaml_path", 4), + ({"start_timestamp": "2020-04-12"}, "base_pipeline_with_context_size_yaml_path", 6), + ({"start_timestamp": "2020-04-11"}, "base_ensemble_yaml_path", 5), + ), +) +def test_update_horizon(pipeline_path_name, forecast_params, example_tsds, expected, request): + pipeline_path = request.getfixturevalue(pipeline_path_name) + pipeline_conf = OmegaConf.to_object(OmegaConf.load(pipeline_path)) + + update_horizon(pipeline_configs=pipeline_conf, forecast_params=forecast_params, tsdataset=example_tsds) + + pipeline_conf = remove_params(params=pipeline_conf, to_remove=ADDITIONAL_PIPELINE_PARAMETERS) + pipeline = hydra_slayer.get_from_params(**pipeline_conf) + + assert pipeline.horizon == expected + + +@pytest.mark.parametrize( + "pipeline_path_name", + ("base_pipeline_with_context_size_yaml_path", "base_ensemble_yaml_path"), ) def test_forecast_start_timestamp( pipeline_path_name, From efb4608e8ba8b2f377d56f29f3bb265e0a9017d7 Mon Sep 17 00:00:00 2001 From: brsnw250 Date: Wed, 26 Jul 2023 16:01:49 +0300 Subject: [PATCH 4/4] updated changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1eff1ca78..bc48d187a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Fixed -- +- Pipeline ensembles fail in `etna forecast` CLI ([#1331](https://github.com/tinkoff-ai/etna/pull/1331)) - - - `mrmr` feature selection working with categoricals ([#1311](https://github.com/tinkoff-ai/etna/pull/1311))