Skip to content

Commit

Permalink
Fix metrics' representation in backtest (#192)
Browse files Browse the repository at this point in the history
* Backtest start

* Backtest fix tests

* Add backtest utils

* Add deprecated decorator

* Add backtest to pipeline

* Revert "Add deprecated decorator"

This reverts commit e507ae7.

* Add docstring

* Upd CHANGELOG

* Make backtest work with pipelines

* Apply suggestions from code review

Co-authored-by: Martin Gabdushev <[email protected]>

* Fix `metric.__class__.__name__` -> `metric.__repr__` in backtest metrics computation

* Add conftests for pipelines, ensembles

* Upd CHANGELOG

* Fix metrics name in pipeline

* Upd CHANGELOG

* Format tests

Co-authored-by: Martin Gabdushev <[email protected]>
  • Loading branch information
julia-shenshina and martins0n authored Oct 13, 2021
1 parent 9a9f2c3 commit fd34ec1
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 21 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- TrendTransform ([#139](https://github.com/tinkoff-ai/etna-ts/pull/139))
- Running notebooks in ci ([#134](https://github.com/tinkoff-ai/etna-ts/issues/134))
- Cluster plotter to EDA ([#169](https://github.com/tinkoff-ai/etna-ts/pull/169))
- Pipeline.backtest method ([#161](https://github.com/tinkoff-ai/etna-ts/pull/161))
- Pipeline.backtest method ([#161](https://github.com/tinkoff-ai/etna-ts/pull/161), [#192](https://github.com/tinkoff-ai/etna-ts/pull/192))
- STLTransform class ([#158](https://github.com/tinkoff-ai/etna-ts/pull/158))
- NN_examples notebook ([#159](https://github.com/tinkoff-ai/etna-ts/pull/159))
- Example for ProphetModel ([#178](https://github.com/tinkoff-ai/etna-ts/pull/178))
Expand All @@ -41,7 +41,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add Correlation heatmap in EDA notebook ([#144](https://github.com/tinkoff-ai/etna-ts/pull/144))
- Add `__repr__` for Pipeline ([#151](https://github.com/tinkoff-ai/etna-ts/pull/151))
- Defined random state for every test case ([#155](https://github.com/tinkoff-ai/etna-ts/pull/155))
- TimeSeriesCrossValidation returns `Metric.__repr__` as a key in `backtest`'s return values ([#161](https://github.com/tinkoff-ai/etna-ts/pull/161))
- Add confidence intervals to Prophet ([#153](https://github.com/tinkoff-ai/etna-ts/pull/153))
- Add confidence intervals to SARIMA ([#172](https://github.com/tinkoff-ai/etna-ts/pull/172))

Expand Down
2 changes: 1 addition & 1 deletion etna/model_selection/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _compute_metrics(self, y_true: TSDataset, y_pred: TSDataset) -> Dict[str, fl
"""
metrics = {}
for metric in self.metrics:
metrics[metric.__repr__()] = metric(y_true=y_true, y_pred=y_pred)
metrics[metric.__class__.__name__] = metric(y_true=y_true, y_pred=y_pred)
return metrics

def get_forecasts(self) -> pd.DataFrame:
Expand Down
11 changes: 9 additions & 2 deletions etna/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from etna.loggers import tslogger
from etna.metrics import Metric
from etna.metrics import MetricAggregationMode
from etna.metrics.utils import compute_metrics
from etna.models.base import Model
from etna.transforms.base import Transform

Expand Down Expand Up @@ -150,6 +149,14 @@ def _generate_folds_datasets(

yield train, test

@staticmethod
def _compute_metrics(metrics: List[Metric], y_true: TSDataset, y_pred: TSDataset) -> Dict[str, float]:
"""Compute metrics for given y_true, y_pred."""
metrics_values = {}
for metric in metrics:
metrics_values[metric.__class__.__name__] = metric(y_true=y_true, y_pred=y_pred)
return metrics_values

def _run_fold(
self,
train: TSDataset,
Expand All @@ -170,7 +177,7 @@ def _run_fold(
fold[f"{stage_name}_timerange"]["start"] = stage_df.index.min()
fold[f"{stage_name}_timerange"]["end"] = stage_df.index.max()
fold["forecast"] = forecast
fold["metrics"] = deepcopy(compute_metrics(metrics=metrics, y_true=test, y_pred=forecast))
fold["metrics"] = deepcopy(self._compute_metrics(metrics=metrics, y_true=test, y_pred=forecast))

tslogger.log_backtest_run(pd.DataFrame(fold["metrics"]), forecast.to_pandas(), test.to_pandas())
tslogger.finish_experiment()
Expand Down
10 changes: 2 additions & 8 deletions tests/test_model_selection/test_backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,11 @@ def _fit_backtest_pipeline(
(
(
False,
[
"MAE(mode = 'per-segment', )",
"MSE(mode = 'per-segment', )",
"SMAPE(mode = 'per-segment', )",
"fold_number",
"segment",
],
["fold_number", "MAE", "MSE", "segment", "SMAPE"],
),
(
True,
["MAE(mode = 'per-segment', )", "MSE(mode = 'per-segment', )", "SMAPE(mode = 'per-segment', )", "segment"],
["MAE", "MSE", "segment", "SMAPE"],
),
),
)
Expand Down
10 changes: 2 additions & 8 deletions tests/test_pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,11 @@ def test_generate_constant_timeranges_hours():
(
(
False,
[
"fold_number",
"MAE(mode = 'per-segment', )",
"MSE(mode = 'per-segment', )",
"segment",
"SMAPE(mode = 'per-segment', )",
],
["fold_number", "MAE", "MSE", "segment", "SMAPE"],
),
(
True,
["MAE(mode = 'per-segment', )", "MSE(mode = 'per-segment', )", "segment", "SMAPE(mode = 'per-segment', )"],
["MAE", "MSE", "segment", "SMAPE"],
),
),
)
Expand Down

0 comments on commit fd34ec1

Please sign in to comment.