Skip to content

Commit

Permalink
Change ProphetModel and SARIMAXModel according to latest architec…
Browse files Browse the repository at this point in the history
…ture (#549)
  • Loading branch information
alex-hse-repository authored Feb 18, 2022
1 parent 2242d40 commit 8f594cf
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 215 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Rename `_CatBoostModel`, `_HoltWintersModel`, `_SklearnModel` ([#543](https://github.com/tinkoff-ai/etna/pull/543))
-
- Rename `_SARIMAXModel` and `_ProphetModel`, make `SARIMAXModel` and `ProphetModel` inherit from `PerSegmentPredictionIntervalModel` ([#549](https://github.com/tinkoff-ai/etna/pull/549))
-
### Fixed
- Fix `TSDataset._update_regressors` logic removing the regressors ([#489](https://github.com/tinkoff-ai/etna/pull/489))
- Fix `TSDataset.info`, `TSDataset.describe` methods ([#519](https://github.com/tinkoff-ai/etna/pull/519))
Expand Down
16 changes: 5 additions & 11 deletions etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Sequence
from typing import Union

import numpy as np
import pandas as pd

from etna.core.mixins import BaseMixin
Expand Down Expand Up @@ -181,7 +182,6 @@ def __init__(self, base_model: Any):
Internal model which will be used to forecast segments, expected to have fit/predict interface
"""
self._base_model = base_model
self._segments: Optional[List[str]] = None
self._models: Optional[Dict[str, Any]] = None

@log_decorator
Expand All @@ -198,7 +198,6 @@ def fit(self, ts: TSDataset) -> "PerSegmentBaseModel":
self:
Model after fit
"""
self._segments = ts.segments
self._models = {}
for segment in ts.segments:
self._models[segment] = deepcopy(self._base_model)
Expand All @@ -223,7 +222,7 @@ def get_model(self) -> Dict[str, Any]:
dictionary where key is segment and value is internal model
"""
if self._models is None:
raise ValueError("Can not get the dict with base models from not fitted model!")
raise ValueError("Can not get the dict with base models, the model is not fitted!")
return self._models

@staticmethod
Expand All @@ -235,17 +234,12 @@ def _forecast_segment(model: Any, segment: str, ts: TSDataset, *args, **kwargs)
dates = segment_features["timestamp"]
dates.reset_index(drop=True, inplace=True)
segment_predict = model.predict(df=segment_features, *args, **kwargs)
segment_predict = pd.DataFrame({"target": segment_predict})
if isinstance(segment_predict, np.ndarray):
segment_predict = pd.DataFrame({"target": segment_predict})
segment_predict["segment"] = segment
segment_predict["timestamp"] = dates
return segment_predict

def _build_models(self):
"""Create a dict with models for each segment (if required)."""
self._models = {}
for segment in self._segments: # type: ignore
self._models[segment] = deepcopy(self._base_model)


class PerSegmentModel(PerSegmentBaseModel, ForecastAbstractModel):
"""Class for holding specific models for per-segment prediction."""
Expand Down Expand Up @@ -305,7 +299,7 @@ def __init__(self, base_model: Any):
"""
super().__init__(base_model=base_model)

@abstractmethod
@log_decorator
def forecast(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
) -> TSDataset:
Expand Down
27 changes: 27 additions & 0 deletions etna/models/catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ def __init__(
self._categorical = None

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_CatBoostAdapter":
"""
Fit Catboost model.
Parameters
----------
df:
Features dataframe
regressors:
List of the columns with regressors(ignored in this model)
Returns
-------
self:
Fitted model
"""
features = df.drop(columns=["timestamp", "target"])
target = df["target"]
self._categorical = features.select_dtypes(include=["category"]).columns.to_list()
Expand All @@ -44,6 +58,19 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_CatBoostAdapter":
return self

def predict(self, df: pd.DataFrame) -> np.ndarray:
"""
Compute predictions from a Catboost model.
Parameters
----------
df:
Features dataframe
Returns
-------
y_pred:
Array with predictions
"""
features = df.drop(columns=["timestamp", "target"])
predict_pool = Pool(features, cat_features=self._categorical)
pred = self.model.predict(predict_pool)
Expand Down
18 changes: 10 additions & 8 deletions etna/models/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Tuple
from typing import Union

import numpy as np
import pandas as pd
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.holtwinters import HoltWintersResults
Expand Down Expand Up @@ -171,17 +172,18 @@ def __init__(

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_HoltWintersAdapter":
"""
Fits a Holt-Winters' model.
Fit Holt-Winters' model.
Parameters
----------
df:
Features dataframe
regressors:
List of the columns with regressors(ignored in this model)
Returns
-------
self: _HoltWintersAdapter
fitted model
self:
Fitted model
"""
self._check_df(df)

Expand Down Expand Up @@ -213,7 +215,7 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_HoltWintersAdapter":
)
return self

def predict(self, df: pd.DataFrame) -> pd.Series:
def predict(self, df: pd.DataFrame) -> np.ndarray:
"""
Compute predictions from a Holt-Winters' model.
Expand All @@ -224,15 +226,15 @@ def predict(self, df: pd.DataFrame) -> pd.Series:
Returns
-------
y_pred: pd.Series
Series with predictions
y_pred:
Array with predictions
"""
if self._result is None or self._model is None:
raise ValueError("This model is not fitted! Fit the model before calling predict method!")
self._check_df(df)

forecast = self._result.predict(start=df["timestamp"].min(), end=df["timestamp"].max())
y_pred = pd.Series(data=forecast.values, name="target")
y_pred = forecast.values
return y_pred

def _check_df(self, df: pd.DataFrame):
Expand Down
103 changes: 12 additions & 91 deletions etna/models/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@
import pandas as pd

from etna import SETTINGS
from etna.datasets import TSDataset
from etna.models.base import PerSegmentModel
from etna.models.base import log_decorator
from etna.models.base import PerSegmentPredictionIntervalModel

if SETTINGS.prophet_required:
from prophet import Prophet


class _ProphetModel:
class _ProphetAdapter:
"""Class for holding Prophet model."""

def __init__(
Expand Down Expand Up @@ -83,7 +81,7 @@ def __init__(

self.regressor_columns: Optional[List[str]] = None

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_ProphetModel":
def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_ProphetAdapter":
"""
Fits a Prophet model.
Expand All @@ -104,9 +102,9 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_ProphetModel":
self.model.fit(prophet_df)
return self

def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]):
def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequence[float]) -> pd.DataFrame:
"""
Compute Prophet predictions.
Compute predictions from a Prophet model.
Parameters
----------
Expand All @@ -119,7 +117,7 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen
Returns
-------
y_pred: pd.DataFrame
y_pred:
DataFrame with predictions
"""
df = df.reset_index()
Expand All @@ -134,10 +132,14 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen
for quantile in quantiles:
percentile = quantile * 100
y_pred[f"yhat_{quantile:.4g}"] = self.model.percentile(sim_values["yhat"], percentile, axis=1)
rename_dict = {
column: column.replace("yhat", "target") for column in y_pred.columns if column.startswith("yhat")
}
y_pred = y_pred.rename(rename_dict, axis=1)
return y_pred


class ProphetModel(PerSegmentModel):
class ProphetModel(PerSegmentPredictionIntervalModel):
"""Class for holding Prophet model.
Examples
Expand Down Expand Up @@ -296,7 +298,7 @@ def __init__(
self.additional_seasonality_params = additional_seasonality_params

super(ProphetModel, self).__init__(
base_model=_ProphetModel(
base_model=_ProphetAdapter(
growth=self.growth,
n_changepoints=self.n_changepoints,
changepoints=self.changepoints,
Expand All @@ -316,84 +318,3 @@ def __init__(
additional_seasonality_params=self.additional_seasonality_params,
)
)

@log_decorator
def fit(self, ts: TSDataset) -> "ProphetModel":
"""Fit model."""
self._segments = ts.segments
self._build_models()

for segment in self._segments:
model = self._models[segment] # type: ignore
segment_features = ts[:, segment, :]
segment_features = segment_features.dropna()
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
model.fit(df=segment_features, regressors=ts.regressors)
return self

@staticmethod
def _forecast_one_segment(
model,
segment: Union[str, List[str]],
ts: TSDataset,
prediction_interval: bool,
quantiles: Sequence[float],
) -> pd.DataFrame:
segment_features = ts[:, segment, :]
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = segment_features.reset_index()
dates = segment_features["timestamp"]
dates.reset_index(drop=True, inplace=True)
segment_predict = model.predict(
df=segment_features, prediction_interval=prediction_interval, quantiles=quantiles
)
rename_dict = {
column: column.replace("yhat", "target") for column in segment_predict.columns if column.startswith("yhat")
}
segment_predict = segment_predict.rename(rename_dict, axis=1)
segment_predict["segment"] = segment
segment_predict["timestamp"] = dates
return segment_predict

@log_decorator
def forecast(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
) -> TSDataset:
"""Make predictions.
Parameters
----------
ts:
Dataframe with features
prediction_interval:
If True returns prediction interval for forecast
quantiles:
Levels of prediction distribution. By default 2.5% and 97.5% taken to form a 95% prediction interval
Returns
-------
TSDataset
Models result
"""
if self._segments is None:
raise ValueError("The model is not fitted yet, use fit() to train it")

result_list = list()
for segment in self._segments:
model = self._models[segment] # type: ignore

segment_predict = self._forecast_one_segment(model, segment, ts, prediction_interval, quantiles)
result_list.append(segment_predict)

# need real case to test
result_df = pd.concat(result_list, ignore_index=True)
result_df = result_df.set_index(["timestamp", "segment"])
df = ts.to_pandas(flatten=True)
df = df.set_index(["timestamp", "segment"])
df = df.combine_first(result_df).reset_index()

df = TSDataset.to_dataset(df)
ts.df = df
ts.inverse_transform()
return ts
Loading

0 comments on commit 8f594cf

Please sign in to comment.