Skip to content

Change ProphetModel and SARIMAXModel according to latest architecture #549

Merged
merged 3 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Rename `_CatBoostModel`, `_HoltWintersModel`, `_SklearnModel` ([#543](https://github.com/tinkoff-ai/etna/pull/543))
-
- Rename `_SARIMAXModel` and `_ProphetModel`, make `SARIMAXModel` and `ProphetModel` inherit from `PerSegmentPredictionIntervalModel` ([#549](https://github.com/tinkoff-ai/etna/pull/549))
-
### Fixed
- Fix `TSDataset._update_regressors` logic removing the regressors ([#489](https://github.com/tinkoff-ai/etna/pull/489))
- Fix `TSDataset.info`, `TSDataset.describe` methods ([#519](https://github.com/tinkoff-ai/etna/pull/519))
Expand Down
16 changes: 5 additions & 11 deletions etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Sequence
from typing import Union

import numpy as np
import pandas as pd

from etna.core.mixins import BaseMixin
Expand Down Expand Up @@ -181,7 +182,6 @@ def __init__(self, base_model: Any):
Internal model which will be used to forecast segments, expected to have fit/predict interface
"""
self._base_model = base_model
self._segments: Optional[List[str]] = None
self._models: Optional[Dict[str, Any]] = None

@log_decorator
Expand All @@ -198,7 +198,6 @@ def fit(self, ts: TSDataset) -> "PerSegmentBaseModel":
self:
Model after fit
"""
self._segments = ts.segments
self._models = {}
for segment in ts.segments:
self._models[segment] = deepcopy(self._base_model)
Expand All @@ -223,7 +222,7 @@ def get_model(self) -> Dict[str, Any]:
dictionary where key is segment and value is internal model
"""
if self._models is None:
raise ValueError("Can not get the dict with base models from not fitted model!")
raise ValueError("Can not get the dict with base models, the model is not fitted!")
return self._models

@staticmethod
Expand All @@ -235,17 +234,12 @@ def _forecast_segment(model: Any, segment: str, ts: TSDataset, *args, **kwargs)
dates = segment_features["timestamp"]
dates.reset_index(drop=True, inplace=True)
segment_predict = model.predict(df=segment_features, *args, **kwargs)
segment_predict = pd.DataFrame({"target": segment_predict})
if isinstance(segment_predict, np.ndarray):
segment_predict = pd.DataFrame({"target": segment_predict})
Comment on lines +237 to +238
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

segment_predict["segment"] = segment
segment_predict["timestamp"] = dates
return segment_predict

def _build_models(self):
"""Create a dict with models for each segment (if required)."""
self._models = {}
for segment in self._segments: # type: ignore
self._models[segment] = deepcopy(self._base_model)


class PerSegmentModel(PerSegmentBaseModel, ForecastAbstractModel):
"""Class for holding specific models for per-segment prediction."""
Expand Down Expand Up @@ -305,7 +299,7 @@ def __init__(self, base_model: Any):
"""
super().__init__(base_model=base_model)

@abstractmethod
@log_decorator
def forecast(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
) -> TSDataset:
Expand Down
27 changes: 27 additions & 0 deletions etna/models/catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ def __init__(
self._categorical = None

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_CatBoostAdapter":
"""
Fit Catboost model.

Parameters
----------
df:
Features dataframe
regressors:
List of the columns with regressors (ignored in this model)
Returns
-------
self:
Fitted model
"""
features = df.drop(columns=["timestamp", "target"])
target = df["target"]
self._categorical = features.select_dtypes(include=["category"]).columns.to_list()
Expand All @@ -44,6 +58,19 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_CatBoostAdapter":
return self

def predict(self, df: pd.DataFrame) -> np.ndarray:
"""
Compute predictions from a Catboost model.

Parameters
----------
df:
Features dataframe

Returns
-------
y_pred:
Array with predictions
"""
features = df.drop(columns=["timestamp", "target"])
predict_pool = Pool(features, cat_features=self._categorical)
pred = self.model.predict(predict_pool)
Expand Down
18 changes: 10 additions & 8 deletions etna/models/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Tuple
from typing import Union

import numpy as np
import pandas as pd
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.holtwinters import HoltWintersResults
Expand Down Expand Up @@ -171,17 +172,18 @@ def __init__(

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_HoltWintersAdapter":
"""
Fits a Holt-Winters' model.
Fit Holt-Winters' model.

Parameters
----------
df:
Features dataframe

regressors:
List of the columns with regressors (ignored in this model)
Returns
-------
self: _HoltWintersAdapter
fitted model
self:
Fitted model
"""
self._check_df(df)

Expand Down Expand Up @@ -213,7 +215,7 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_HoltWintersAdapter":
)
return self

def predict(self, df: pd.DataFrame) -> pd.Series:
def predict(self, df: pd.DataFrame) -> np.ndarray:
"""
Compute predictions from a Holt-Winters' model.

Expand All @@ -224,15 +226,15 @@ def predict(self, df: pd.DataFrame) -> pd.Series:

Returns
-------
y_pred: pd.Series
Series with predictions
y_pred:
Array with predictions
"""
if self._result is None or self._model is None:
raise ValueError("This model is not fitted! Fit the model before calling predict method!")
self._check_df(df)

forecast = self._result.predict(start=df["timestamp"].min(), end=df["timestamp"].max())
y_pred = pd.Series(data=forecast.values, name="target")
y_pred = forecast.values
return y_pred

def _check_df(self, df: pd.DataFrame):
Expand Down
103 changes: 12 additions & 91 deletions etna/models/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@
import pandas as pd

from etna import SETTINGS
from etna.datasets import TSDataset
from etna.models.base import PerSegmentModel
from etna.models.base import log_decorator
from etna.models.base import PerSegmentPredictionIntervalModel

if SETTINGS.prophet_required:
from prophet import Prophet


class _ProphetModel:
class _ProphetAdapter:
"""Class for holding Prophet model."""

def __init__(
Expand Down Expand Up @@ -83,7 +81,7 @@ def __init__(

self.regressor_columns: Optional[List[str]] = None

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_ProphetModel":
def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_ProphetAdapter":
"""
Fit Prophet model.

Expand All @@ -104,9 +102,9 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_ProphetModel":
self.model.fit(prophet_df)
return self

def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]):
def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]) -> pd.DataFrame:
"""
Compute Prophet predictions.
Compute predictions from a Prophet model.

Parameters
----------
Expand All @@ -119,7 +117,7 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen

Returns
-------
y_pred: pd.DataFrame
y_pred:
DataFrame with predictions
"""
df = df.reset_index()
Expand All @@ -134,10 +132,14 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen
for quantile in quantiles:
percentile = quantile * 100
y_pred[f"yhat_{quantile:.4g}"] = self.model.percentile(sim_values["yhat"], percentile, axis=1)
rename_dict = {
column: column.replace("yhat", "target") for column in y_pred.columns if column.startswith("yhat")
}
y_pred = y_pred.rename(rename_dict, axis=1)
return y_pred


class ProphetModel(PerSegmentModel):
class ProphetModel(PerSegmentPredictionIntervalModel):
"""Class for holding Prophet model.

Examples
Expand Down Expand Up @@ -296,7 +298,7 @@ def __init__(
self.additional_seasonality_params = additional_seasonality_params

super(ProphetModel, self).__init__(
base_model=_ProphetModel(
base_model=_ProphetAdapter(
growth=self.growth,
n_changepoints=self.n_changepoints,
changepoints=self.changepoints,
Expand All @@ -316,84 +318,3 @@ def __init__(
additional_seasonality_params=self.additional_seasonality_params,
)
)

@log_decorator
def fit(self, ts: TSDataset) -> "ProphetModel":
"""Fit model."""
self._segments = ts.segments
self._build_models()

for segment in self._segments:
model = self._models[segment] # type: ignore
segment_features = ts[:, segment, :]
segment_features = segment_features.dropna()
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
model.fit(df=segment_features, regressors=ts.regressors)
return self

@staticmethod
def _forecast_one_segment(
model,
segment: Union[str, List[str]],
ts: TSDataset,
prediction_interval: bool,
quantiles: Sequence[float],
) -> pd.DataFrame:
segment_features = ts[:, segment, :]
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
dates = segment_features["timestamp"]
dates.reset_index(drop=True, inplace=True)
segment_predict = model.predict(
df=segment_features, prediction_interval=prediction_interval, quantiles=quantiles
)
rename_dict = {
column: column.replace("yhat", "target") for column in segment_predict.columns if column.startswith("yhat")
}
segment_predict = segment_predict.rename(rename_dict, axis=1)
segment_predict["segment"] = segment
segment_predict["timestamp"] = dates
return segment_predict

@log_decorator
def forecast(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
) -> TSDataset:
"""Make predictions.

Parameters
----------
ts:
Dataframe with features
prediction_interval:
If True returns prediction interval for forecast
quantiles:
Levels of prediction distribution. By default 2.5% and 97.5% taken to form a 95% prediction interval

Returns
-------
TSDataset
Models result
"""
if self._segments is None:
raise ValueError("The model is not fitted yet, use fit() to train it")

result_list = list()
for segment in self._segments:
model = self._models[segment] # type: ignore

segment_predict = self._forecast_one_segment(model, segment, ts, prediction_interval, quantiles)
result_list.append(segment_predict)

# need real case to test
result_df = pd.concat(result_list, ignore_index=True)
result_df = result_df.set_index(["timestamp", "segment"])
df = ts.to_pandas(flatten=True)
df = df.set_index(["timestamp", "segment"])
df = df.combine_first(result_df).reset_index()

df = TSDataset.to_dataset(df)
ts.df = df
ts.inverse_transform()
return ts
Loading