-
Notifications
You must be signed in to change notification settings - Fork 3.2k
/
Copy pathautoets_model.py
134 lines (115 loc) · 4.03 KB
/
autoets_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# pylint: disable=too-many-arguments
"""Automatic ETS (Error, Trend, and Seasonality) Model"""
__docformat__ = "numpy"
import logging
from typing import Any, Union, Optional, List, Tuple
import warnings
import numpy as np
import pandas as pd
from statsforecast.models import ETS
from statsforecast.core import StatsForecast
from openbb_terminal.decorators import log_start_end
from openbb_terminal.rich_config import console
from openbb_terminal.forecast import helpers
# NOTE(review): this silences *all* warnings process-wide as soon as the module is
# imported (presumably to hide noisy statsforecast output — TODO confirm); it also
# hides warnings raised by unrelated code in the same process.
warnings.simplefilter("ignore")
# Module-level logger, passed to the @log_start_end decorator on the function below.
logger = logging.getLogger(__name__)
def _frame_to_series(
    frame: pd.DataFrame,
    value_column: str,
    target_column: str,
    is_scaler: bool = False,
):
    """Convert a statsforecast-style frame back into a ``helpers.get_series`` series.

    Parameters
    ----------
    frame : pd.DataFrame
        Frame with a ``ds`` time column and a value column named ``value_column``.
    value_column : str
        Column holding the values (e.g. ``"y"`` or ``"ETS"``).
    target_column : str
        Name the value column is renamed to before conversion.
    is_scaler : bool
        Forwarded to ``helpers.get_series``.

    Returns
    -------
    Any
        The series object produced by ``helpers.get_series`` (scaler discarded).
    """
    _, series = helpers.get_series(
        frame.rename(columns={value_column: target_column}),
        target_column,
        is_scaler=is_scaler,
        time_col="ds",
    )
    return series


@log_start_end(log=logger)
def get_autoets_data(
    data: Union[pd.Series, pd.DataFrame],
    target_column: str = "close",
    seasonal_periods: int = 7,
    n_predict: int = 30,
    start_window: float = 0.85,
    forecast_horizon: int = 5,
) -> Tuple[
    List[np.ndarray], List[np.ndarray], List[np.ndarray], Optional[float], Any
]:
    """Performs Automatic ETS forecasting

    This is a wrapper around StatsForecast ETS;
    we refer to this link for the original and more complete documentation
    of the parameters.

    https://nixtla.github.io/statsforecast/models.html#ets

    Parameters
    ----------
    data : Union[pd.Series, pd.DataFrame]
        Input data.
    target_column : str, optional
        Target column to forecast. Defaults to "close".
    seasonal_periods : int
        Season length passed to ``ETS(season_length=...)``, i.e. the number of
        observations per seasonal cycle (e.g. 7 for daily data). Defaults to 7.
    n_predict : int
        Number of steps, at the series' frequency, to forecast.
    start_window : float
        Fraction of the series used as the initial training window; the
        remaining ``len - int((len - 1) * start_window)`` points form the
        backtesting (cross-validation) test set.
    forecast_horizon : int
        Number of steps forecast in each backtesting window.

    Returns
    -------
    ticker_series
        The input target series, as returned by ``helpers.get_series``.
    historical_fcast_ets
        Backtest forecasts, keeping the first cross-validation row per date.
    forecast
        Out-of-sample forecast of ``n_predict`` steps.
    Optional[float]
        MAPE of the backtest forecasts
        (``helpers.mean_absolute_percentage_error``).
    Any
        The ``StatsForecast`` object wrapping the ETS model.
    """
    use_scalers = False
    # Cast once so h= and input_size= always receive the same int
    # (previously input_size could become a float).
    horizon = int(forecast_horizon)

    # statsforecast expects a long frame with unique_id / ds / y columns; the
    # same layout would be used when forecasting several series at once.
    _, ticker_series = helpers.get_series(data, target_column, is_scaler=use_scalers)
    freq = ticker_series.freq_str
    sf_frame = ticker_series.pd_dataframe().reset_index()
    sf_frame.columns = ["ds", "y"]
    sf_frame.insert(0, "unique_id", target_column)

    model_ets = ETS(season_length=int(seasonal_periods))
    fcst = StatsForecast(df=sf_frame, models=[model_ets], freq=freq, verbose=True)

    # Historical backtesting over the tail of the series that follows the
    # initial training window. input_size caps each window's training length.
    n_obs = len(sf_frame)
    last_training_point = int((n_obs - 1) * start_window)
    historical_fcast_ets = fcst.cross_validation(
        h=horizon,
        test_size=n_obs - last_training_point,
        n_windows=None,
        input_size=min(10 * horizon, n_obs),
    )

    # Fit on the full series and forecast n_predict steps ahead.
    forecast = fcst.forecast(int(n_predict))

    y_true = historical_fcast_ets["y"].values
    y_hat = historical_fcast_ets["ETS"].values
    precision = helpers.mean_absolute_percentage_error(y_true, y_hat)
    console.print(f"AutoETS obtains MAPE: {precision:.2f}% \n")

    # Convert outputs back into series objects so they are compatible with plots.
    ticker_series = _frame_to_series(sf_frame, "y", target_column, use_scalers)
    forecast = _frame_to_series(forecast, "ETS", target_column, use_scalers)
    # Overlapping backtest windows predict the same date several times;
    # keep only the first row per date.
    historical_fcast_ets = _frame_to_series(
        historical_fcast_ets.groupby("ds").head(1),
        "ETS",
        target_column,
        use_scalers,
    )

    # NOTE(review): the return annotation says List[np.ndarray], but the first
    # three values are objects produced by helpers.get_series; confirm and
    # tighten the annotation.
    return (
        ticker_series,
        historical_fcast_ets,
        forecast,
        precision,
        fcst,
    )