From 1565b180224ff8f2098ab647cdb6cd06fba40877 Mon Sep 17 00:00:00 2001 From: Mr-Geekman <36005824+Mr-Geekman@users.noreply.github.com> Date: Thu, 11 Aug 2022 10:58:00 +0300 Subject: [PATCH] Teach `AutoARIMAModel` to work with out-sample predictions (#830) --- CHANGELOG.md | 2 +- etna/models/autoarima.py | 133 +-------- etna/models/sarimax.py | 325 +++++++++++----------- tests/test_models/test_autoarima_model.py | 4 +- tests/test_models/test_inference.py | 14 +- tests/test_models/test_sarimax_model.py | 4 +- 6 files changed, 181 insertions(+), 301 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 48f62f687..be6f435e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - - -- +- Teach AutoARIMAModel to work with out-sample predictions ([#830](https://github.com/tinkoff-ai/etna/pull/830)) - - - diff --git a/etna/models/autoarima.py b/etna/models/autoarima.py index f3eb0f6d2..bb6641de4 100644 --- a/etna/models/autoarima.py +++ b/etna/models/autoarima.py @@ -1,16 +1,12 @@ import warnings -from typing import List -from typing import Optional -from typing import Sequence -import numpy as np import pandas as pd import pmdarima as pm -from pmdarima.arima import ARIMA from statsmodels.tools.sm_exceptions import ValueWarning +from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper -from etna.models.base import BaseAdapter from etna.models.base import PerSegmentPredictionIntervalModel +from etna.models.sarimax import _SARIMAXBaseAdapter warnings.filterwarnings( message="No frequency information was provided, so inferred frequency .* will be used", @@ -20,7 +16,7 @@ ) -class _AutoARIMAAdapter(BaseAdapter): +class _AutoARIMAAdapter(_SARIMAXBaseAdapter): """ Class for holding auto arima model. @@ -45,126 +41,11 @@ def __init__( Training parameters for auto_arima from pmdarima package. """ self.kwargs = kwargs - self._model: Optional[ARIMA] = None - self.regressor_columns: List[str] = [] + super().__init__() - def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_AutoARIMAAdapter": - """ - Fits auto ARIMA model. - - Parameters - ---------- - df: - Features dataframe - regressors: - List of the columns with regressors - - Returns - ------- - : - Fitted model - """ - self.regressor_columns = regressors - categorical_cols = df.select_dtypes(include=["category"]).columns.tolist() - try: - df.loc[:, categorical_cols] = df[categorical_cols].astype(int) - except ValueError: - raise ValueError( - f"Categorical columns {categorical_cols} can not been converted to int.\n " - "Try to encode this columns manually." - ) - - self._check_df(df) - - targets = df["target"] - targets.index = df["timestamp"] - - exog_train = self._select_regressors(df) - - self._model = pm.auto_arima(df["target"], X=exog_train, **self.kwargs) - return self - - def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]) -> pd.DataFrame: - """ - Compute predictions from auto ARIMA model. - - Parameters - ---------- - df: - Features dataframe - prediction_interval: - If True returns prediction interval for forecast - quantiles: - Levels of prediction distribution - - Returns - ------- - : - DataFrame with predictions - """ - if self._model is None: - raise ValueError("AutoARIMA model is not fitted! Fit the model before calling predict method!") - horizon = len(df) - self._check_df(df, horizon) - - categorical_cols = df.select_dtypes(include=["category"]).columns.tolist() - try: - df.loc[:, categorical_cols] = df[categorical_cols].astype(int) - except ValueError: - raise ValueError( - f"Categorical columns {categorical_cols} can not been converted to int.\n " - "Try to encode this columns manually." - ) - - exog_future = self._select_regressors(df) - if prediction_interval: - confints = np.unique([2 * i if i < 0.5 else 2 * (1 - i) for i in quantiles]) - - y_pred = pd.DataFrame({"target": self._model.predict(len(df), X=exog_future), "timestamp": df["timestamp"]}) - - for confint in confints: - forecast = self._model.predict(len(df), X=exog_future, return_conf_int=True, alpha=confint) - if confint / 2 in quantiles: - y_pred[f"target_{confint/2:.4g}"] = forecast[1][:, :1] - if 1 - confint / 2 in quantiles: - y_pred[f"target_{1 - confint/2:.4g}"] = forecast[1][:, 1:] - else: - y_pred = pd.DataFrame({"target": self._model.predict(len(df), X=exog_future), "timestamp": df["timestamp"]}) - y_pred = y_pred.reset_index(drop=True, inplace=False) - return y_pred - - def _check_df(self, df: pd.DataFrame, horizon: Optional[int] = None): - column_to_drop = [col for col in df.columns if col not in ["target", "timestamp"] + self.regressor_columns] - if column_to_drop: - warnings.warn( - message=f"AutoARIMA model does not work with exogenous features (features unknown in future).\n " - f"{column_to_drop} will be dropped" - ) - if horizon: - short_regressors = [regressor for regressor in self.regressor_columns if df[regressor].count() < horizon] - if short_regressors: - raise ValueError( - f"Regressors {short_regressors} are too short for chosen horizon value.\n " - "Try lower horizon value, or drop this regressors." - ) - - def _select_regressors(self, df: pd.DataFrame) -> Optional[pd.DataFrame]: - if self.regressor_columns: - exog_future = df[self.regressor_columns] - exog_future.index = df["timestamp"] - else: - exog_future = None - return exog_future - - def get_model(self) -> ARIMA: - """Get internal pmdarima.arima.arima.ARIMA model that is used inside etna class. - - Returns - ------- - : - Internal model - """ - return self._model + def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame) -> SARIMAXResultsWrapper: + model = pm.auto_arima(endog, X=exog, **self.kwargs) + return model.arima_res_ class AutoARIMAModel(PerSegmentPredictionIntervalModel): diff --git a/etna/models/sarimax.py b/etna/models/sarimax.py index de1d19b8f..1b5a9d880 100644 --- a/etna/models/sarimax.py +++ b/etna/models/sarimax.py @@ -1,4 +1,5 @@ import warnings +from abc import abstractmethod from datetime import datetime from typing import List from typing import Optional @@ -8,6 +9,7 @@ import pandas as pd from statsmodels.tools.sm_exceptions import ValueWarning from statsmodels.tsa.statespace.sarimax import SARIMAX +from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper from etna.libs.pmdarima_utils import seasonal_prediction_with_confidence from etna.models.base import BaseAdapter @@ -22,7 +24,163 @@ ) -class _SARIMAXAdapter(BaseAdapter): +class _SARIMAXBaseAdapter(BaseAdapter): + """Base class for adapters based on :py:class:`statsmodels.tsa.statespace.sarimax.SARIMAX`.""" + + def __init__(self): + self.regressor_columns = None + self._fit_results = None + self._freq = None + self._first_train_timestamp = None + + def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SARIMAXBaseAdapter": + """ + Fits a SARIMAX model. + + Parameters + ---------- + df: + Features dataframe + regressors: + List of the columns with regressors + + Returns + ------- + : + Fitted model + """ + self.regressor_columns = regressors + + self._encode_categoricals(df) + self._check_df(df) + + exog_train = self._select_regressors(df) + self._fit_results = self._get_fit_results(endog=df["target"], exog=exog_train) + + freq = pd.infer_freq(df["timestamp"], warn=False) + if freq is None: + raise ValueError("Can't determine frequency of a given dataframe") + self._freq = freq + self._first_train_timestamp = df["timestamp"].min() + + return self + + def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]) -> pd.DataFrame: + """ + Compute predictions from a SARIMAX model. + + Parameters + ---------- + df: + Features dataframe + prediction_interval: + If True returns prediction interval for forecast + quantiles: + Levels of prediction distribution + + Returns + ------- + : + DataFrame with predictions + """ + if self._fit_results is None: + raise ValueError("Model is not fitted! Fit the model before calling predict method!") + + horizon = len(df) + self._encode_categoricals(df) + self._check_df(df, horizon) + + exog_future = self._select_regressors(df) + start_timestamp = df["timestamp"].min() + end_timestamp = df["timestamp"].max() + # determine index of start_timestamp if counting from first timestamp of train + start_idx = determine_num_steps( + start_timestamp=self._first_train_timestamp, end_timestamp=start_timestamp, freq=self._freq # type: ignore + ) + # determine index of end_timestamp if counting from first timestamp of train + end_idx = determine_num_steps( + start_timestamp=self._first_train_timestamp, end_timestamp=end_timestamp, freq=self._freq # type: ignore + ) + + if prediction_interval: + forecast, _ = seasonal_prediction_with_confidence( + arima_res=self._fit_results, start=start_idx, end=end_idx, X=exog_future, alpha=0.05 + ) + y_pred = pd.DataFrame({"mean": forecast}) + for quantile in quantiles: + # set alpha in the way to get a desirable quantile + alpha = min(quantile * 2, (1 - quantile) * 2) + _, borders = seasonal_prediction_with_confidence( + arima_res=self._fit_results, start=start_idx, end=end_idx, X=exog_future, alpha=alpha + ) + if quantile < 1 / 2: + series = borders[:, 0] + else: + series = borders[:, 1] + y_pred[f"mean_{quantile:.4g}"] = series + else: + forecast, _ = seasonal_prediction_with_confidence( + arima_res=self._fit_results, start=start_idx, end=end_idx, X=exog_future, alpha=0.05 + ) + y_pred = pd.DataFrame({"mean": forecast}) + + rename_dict = { + column: column.replace("mean", "target") for column in y_pred.columns if column.startswith("mean") + } + y_pred = y_pred.rename(rename_dict, axis=1) + return y_pred + + @abstractmethod + def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame) -> SARIMAXResultsWrapper: + pass + + def _check_df(self, df: pd.DataFrame, horizon: Optional[int] = None): + if self.regressor_columns is None: + raise ValueError("Something went wrong, regressor_columns is None!") + column_to_drop = [col for col in df.columns if col not in ["target", "timestamp"] + self.regressor_columns] + if column_to_drop: + warnings.warn( + message=f"SARIMAX model does not work with exogenous features (features unknown in future).\n " + f"{column_to_drop} will be dropped" + ) + if horizon: + short_regressors = [regressor for regressor in self.regressor_columns if df[regressor].count() < horizon] + if short_regressors: + raise ValueError( + f"Regressors {short_regressors} are too short for chosen horizon value.\n " + "Try lower horizon value, or drop this regressors." + ) + + def _select_regressors(self, df: pd.DataFrame) -> Optional[pd.DataFrame]: + if self.regressor_columns: + exog_future = df[self.regressor_columns] + exog_future.index = df["timestamp"] + else: + exog_future = None + return exog_future + + def _encode_categoricals(self, df: pd.DataFrame) -> None: + categorical_cols = df.select_dtypes(include=["category"]).columns.tolist() + try: + df.loc[:, categorical_cols] = df[categorical_cols].astype(int) + except ValueError: + raise ValueError( + f"Categorical columns {categorical_cols} can not been converted to int.\n " + "Try to encode this columns manually." + ) + + def get_model(self) -> SARIMAXResultsWrapper: + """Get :py:class:`statsmodels.tsa.statespace.sarimax.SARIMAXResultsWrapper` that is used inside etna class. + + Returns + ------- + : + Internal model + """ + return self._fit_results + + +class _SARIMAXAdapter(_SARIMAXBaseAdapter): """ Class for holding Sarimax model. @@ -163,48 +321,14 @@ def __init__( self.missing = missing self.validate_specification = validate_specification self.kwargs = kwargs - self._model: Optional[SARIMAX] = None - self._result: Optional[SARIMAX] = None - self.regressor_columns: Optional[List[str]] = None - self._freq = None - self._first_train_timestamp = None - - def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SARIMAXAdapter": - """ - Fits a SARIMAX model. - - Parameters - ---------- - df: - Features dataframe - regressors: - List of the columns with regressors - - Returns - ------- - : - Fitted model - """ - self.regressor_columns = regressors - categorical_cols = df.select_dtypes(include=["category"]).columns.tolist() - try: - df.loc[:, categorical_cols] = df[categorical_cols].astype(int) - except ValueError: - raise ValueError( - f"Categorical columns {categorical_cols} can not been converted to int.\n " - "Try to encode this columns manually." - ) - - self._check_df(df) + super().__init__() + def _get_fit_results(self, endog: pd.Series, exog: pd.DataFrame): # make it a numpy array for forgetting about indices, it is necessary for _seasonal_prediction_with_confidence - targets = df["target"].values - - exog_train = self._select_regressors(df) - - self._model = SARIMAX( - endog=targets, - exog=exog_train, + endog_np = endog.values + model = SARIMAX( + endog=endog_np, + exog=exog, order=self.order, seasonal_order=self.seasonal_order, trend=self.trend, @@ -224,123 +348,8 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SARIMAXAdapter": validate_specification=self.validate_specification, **self.kwargs, ) - self._result = self._model.fit() - - freq = pd.infer_freq(df["timestamp"], warn=False) - if freq is None: - raise ValueError("Can't determine frequency of a given dataframe") - self._freq = freq - - self._first_train_timestamp = df["timestamp"].min() - - return self - - def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]) -> pd.DataFrame: - """ - Compute predictions from a SARIMAX model. - - Parameters - ---------- - df: - Features dataframe - prediction_interval: - If True returns prediction interval for forecast - quantiles: - Levels of prediction distribution - - Returns - ------- - : - DataFrame with predictions - """ - if self._result is None or self._model is None: - raise ValueError("SARIMAX model is not fitted! Fit the model before calling predict method!") - horizon = len(df) - self._check_df(df, horizon) - - categorical_cols = df.select_dtypes(include=["category"]).columns.tolist() - try: - df.loc[:, categorical_cols] = df[categorical_cols].astype(int) - except ValueError: - raise ValueError( - f"Categorical columns {categorical_cols} can not been converted to int.\n " - "Try to encode this columns manually." - ) - - exog_future = self._select_regressors(df) - start_timestamp = df["timestamp"].min() - end_timestamp = df["timestamp"].max() - # determine index of start_timestamp if counting from first timestamp of train - start_idx = determine_num_steps( - start_timestamp=self._first_train_timestamp, end_timestamp=start_timestamp, freq=self._freq # type: ignore - ) - # determine index of end_timestamp if counting from first timestamp of train - end_idx = determine_num_steps( - start_timestamp=self._first_train_timestamp, end_timestamp=end_timestamp, freq=self._freq # type: ignore - ) - - if prediction_interval: - forecast, _ = seasonal_prediction_with_confidence( - arima_res=self._result, start=start_idx, end=end_idx, X=exog_future, alpha=0.05 - ) - y_pred = pd.DataFrame({"mean": forecast}) - for quantile in quantiles: - # set alpha in the way to get a desirable quantile - alpha = min(quantile * 2, (1 - quantile) * 2) - _, borders = seasonal_prediction_with_confidence( - arima_res=self._result, start=start_idx, end=end_idx, X=exog_future, alpha=alpha - ) - if quantile < 1 / 2: - series = borders[:, 0] - else: - series = borders[:, 1] - y_pred[f"mean_{quantile:.4g}"] = series - else: - forecast, _ = seasonal_prediction_with_confidence( - arima_res=self._result, start=start_idx, end=end_idx, X=exog_future, alpha=0.05 - ) - y_pred = pd.DataFrame({"mean": forecast}) - - rename_dict = { - column: column.replace("mean", "target") for column in y_pred.columns if column.startswith("mean") - } - y_pred = y_pred.rename(rename_dict, axis=1) - return y_pred - - def _check_df(self, df: pd.DataFrame, horizon: Optional[int] = None): - if self.regressor_columns is None: - raise ValueError("Something went wrong, regressor_columns is None!") - column_to_drop = [col for col in df.columns if col not in ["target", "timestamp"] + self.regressor_columns] - if column_to_drop: - warnings.warn( - message=f"SARIMAX model does not work with exogenous features (features unknown in future).\n " - f"{column_to_drop} will be dropped" - ) - if horizon: - short_regressors = [regressor for regressor in self.regressor_columns if df[regressor].count() < horizon] - if short_regressors: - raise ValueError( - f"Regressors {short_regressors} are too short for chosen horizon value.\n " - "Try lower horizon value, or drop this regressors." - ) - - def _select_regressors(self, df: pd.DataFrame) -> Optional[pd.DataFrame]: - if self.regressor_columns: - exog_future = df[self.regressor_columns] - exog_future.index = df["timestamp"] - else: - exog_future = None - return exog_future - - def get_model(self) -> SARIMAX: - """Get internal statsmodels.tsa.statespace.sarimax.SARIMAX model that is used inside etna class. - - Returns - ------- - : - Internal model - """ - return self._model + result = model.fit() + return result class SARIMAXModel(PerSegmentPredictionIntervalModel): diff --git a/tests/test_models/test_autoarima_model.py b/tests/test_models/test_autoarima_model.py index 376a32efc..503391482 100644 --- a/tests/test_models/test_autoarima_model.py +++ b/tests/test_models/test_autoarima_model.py @@ -1,5 +1,5 @@ import pytest -from pmdarima.arima import ARIMA +from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper from etna.models import AutoARIMAModel from etna.pipeline import Pipeline @@ -135,7 +135,7 @@ def test_get_model_after_training(example_tsds): models_dict = pipeline.model.get_model() assert isinstance(models_dict, dict) for segment in example_tsds.segments: - assert isinstance(models_dict[segment], ARIMA) + assert isinstance(models_dict[segment], SARIMAXResultsWrapper) def test_autoarima_forecast_1_point(example_tsds): diff --git a/tests/test_models/test_inference.py b/tests/test_models/test_inference.py index b140df326..7f43a484d 100644 --- a/tests/test_models/test_inference.py +++ b/tests/test_models/test_inference.py @@ -341,6 +341,7 @@ def test_forecast_out_sample_prefix(model, transforms, example_tsds): (LinearMultiSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]), (ElasticPerSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]), (ElasticMultiSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]), + (AutoARIMAModel(), []), (ProphetModel(), []), (SARIMAXModel(), []), (HoltModel(), []), @@ -394,7 +395,6 @@ def test_forecast_out_sample_suffix_not_implemented(model, transforms, example_t @pytest.mark.parametrize( "model, transforms", [ - (AutoARIMAModel(), []), (MovingAverageModel(window=3), []), (SeasonalMovingAverageModel(), []), (NaiveModel(lag=3), []), @@ -413,6 +413,7 @@ def test_forecast_out_sample_suffix_failed(model, transforms, example_tsds): (LinearMultiSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]), (ElasticPerSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]), (ElasticMultiSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]), + (AutoARIMAModel(), []), (ProphetModel(), []), (SARIMAXModel(), []), (HoltModel(), []), @@ -424,17 +425,6 @@ def test_forecast_mixed_in_out_sample(model, transforms, example_tsds): _test_forecast_mixed_in_out_sample(example_tsds, model, transforms) -@pytest.mark.xfail(strict=True) -@pytest.mark.parametrize( - "model, transforms", - [ - (AutoARIMAModel(), []), - ], -) -def test_forecast_mixed_in_out_sample_failed(model, transforms, example_tsds): - _test_forecast_mixed_in_out_sample(example_tsds, model, transforms) - - @pytest.mark.parametrize( "model, transforms", [ diff --git a/tests/test_models/test_sarimax_model.py b/tests/test_models/test_sarimax_model.py index 4a6e581c0..b156f9456 100644 --- a/tests/test_models/test_sarimax_model.py +++ b/tests/test_models/test_sarimax_model.py @@ -1,5 +1,5 @@ import pytest -from statsmodels.tsa.statespace.sarimax import SARIMAX +from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper from etna.models import SARIMAXModel from etna.pipeline import Pipeline @@ -115,7 +115,7 @@ def test_get_model_after_training(example_tsds): models_dict = pipeline.model.get_model() assert isinstance(models_dict, dict) for segment in example_tsds.segments: - assert isinstance(models_dict[segment], SARIMAX) + assert isinstance(models_dict[segment], SARIMAXResultsWrapper) def test_sarimax_forecast_1_point(example_tsds):