Skip to content

[BUG] *Logger fails with StackingEnsemble #458 #460

Merged
merged 3 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Add relevance_params in GaleShapleyFeatureSelectionTransform ([#410](https://github.com/tinkoff-ai/etna/pull/410))
- Docs for statistics transforms ([#441](https://github.com/tinkoff-ai/etna/pull/441))
- Logger fails with StackingEnsemble ([#460](https://github.com/tinkoff-ai/etna/pull/460))

## [1.5.0] - 2021-12-24
### Added
Expand Down
3 changes: 2 additions & 1 deletion etna/ensembles/stacking_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def _fit_pipeline(pipeline: Pipeline, ts: TSDataset) -> Pipeline:

def _backtest_pipeline(self, pipeline: Pipeline, ts: TSDataset) -> TSDataset:
    """Get forecasts from backtest for given pipeline.

    The backtest is executed exactly once, inside ``tslogger.disable()``, so the
    loggers registered in ``tslogger`` are switched off for the duration of this
    inner backtest.

    Parameters
    ----------
    pipeline:
        Pipeline to run the backtest for.
    ts:
        Dataset to backtest on; its ``freq`` is reused for the resulting dataset.

    Returns
    -------
    TSDataset
        Backtest forecasts wrapped into a dataset with the same frequency as ``ts``.
    """
    # Only the forecasts are needed here; the metrics and fold info returned by
    # ``backtest`` are discarded.
    with tslogger.disable():
        _, forecasts, _ = pipeline.backtest(ts, metrics=[MAE()], n_folds=self.n_folds)
    forecasts = TSDataset(df=forecasts, freq=ts.freq)
    return forecasts

Expand Down
9 changes: 9 additions & 0 deletions etna/loggers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC
from abc import abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
Expand Down Expand Up @@ -173,6 +174,14 @@ def pl_loggers(self):
"""Pytorch lightning loggers."""
return [logger.pl_logger for logger in self.loggers if "_pl_logger" in vars(logger)]

@contextmanager
def disable(self):
    """Context manager for local logging disabling.

    While the ``with`` body runs, ``self.loggers`` is an empty list. The list
    that was set before entering is put back on exit, both on normal completion
    and when the body raises.
    """
    saved_loggers = self.loggers
    self.loggers = []
    try:
        yield
    finally:
        # Without ``finally`` an exception raised inside the ``with`` body would
        # skip the restore and leave logging disabled for the rest of the run.
        self.loggers = saved_loggers


def percentile(n: int):
"""Percentile for pandas agg."""
Expand Down
40 changes: 40 additions & 0 deletions tests/test_loggers/test_file_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

from etna.datasets import TSDataset
from etna.ensembles import StackingEnsemble
from etna.loggers import LocalFileLogger
from etna.loggers import S3FileLogger
from etna.loggers import tslogger
Expand Down Expand Up @@ -218,6 +219,45 @@ def test_base_file_logger_log_backtest_metrics(example_tsds: TSDataset, aggregat
tslogger.remove(idx)


def test_local_file_logger_with_stacking_ensemble(example_df):
    """Test that LocalFileLogger works correctly with StackingEnsemble."""
    with tempfile.TemporaryDirectory() as dirname:
        cur_dir = pathlib.Path(dirname)
        logger = LocalFileLogger(experiments_folder=dirname, gzip=False)

        idx = tslogger.add(logger)
        try:
            ts = TSDataset(TSDataset.to_dataset(example_df), freq="1H")
            ensemble_pipeline = StackingEnsemble(
                pipelines=[
                    Pipeline(
                        model=NaiveModel(lag=10),
                        transforms=[],
                        horizon=5,
                    ),
                    Pipeline(
                        model=NaiveModel(lag=10),
                        transforms=[],
                        horizon=5,
                    ),
                ]
            )
            n_folds = 5

            _ = ensemble_pipeline.backtest(ts, metrics=[MAE()], n_jobs=4, n_folds=n_folds)

            experiment_dirs = list(cur_dir.iterdir())
            assert len(experiment_dirs) == 1, "we've run one experiment"

            current_experiment_dir = experiment_dirs[0]
            assert len(list(current_experiment_dir.iterdir())) == 2, "crossval and crossval_results folders"

            assert (
                len(list((current_experiment_dir / "crossval").iterdir())) == n_folds
            ), "crossval should have `n_folds` runs"
        finally:
            # Always unregister the logger from the shared ``tslogger`` so a failure
            # here cannot leak it into other tests; done before the temp dir is removed.
            tslogger.remove(idx)


def test_s3_file_logger_fail_init_endpoint_url(monkeypatch):
"""Test that S3FileLogger can't be created without setting 'endpoint_url' environment variable."""
monkeypatch.delenv("endpoint_url", raising=False)
Expand Down