Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] StatsForecast AutoETS forecasting model #2988

Merged
merged 23 commits into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
06bde2e
feat: add autoets model
AzulGarza Oct 21, 2022
80e05e3
feat: update autoets index
AzulGarza Oct 21, 2022
70e44a9
Merge branch 'main' into feat/autoets-model
AzulGarza Oct 21, 2022
0a5e40f
Merge branch 'main' into feat/autoets-model
jmaslek Oct 22, 2022
42d0984
Merge branch 'main' into feat/autoets-model
martinb-ai Oct 24, 2022
00584dd
Merge branch 'main' into feat/autoets-model
martinb-ai Oct 24, 2022
80fe018
Merge branch 'main' into feat/autoets-model
martinb-ai Oct 24, 2022
e9ca311
Merge branch 'main' into feat/autoets-model
martinb-ai Oct 24, 2022
9e9a7ae
Merge branch 'main' into feat/autoets-model
martinb-ai Oct 24, 2022
374d078
Merge branch 'main' into feat/autoets-model
martinb-ai Oct 25, 2022
023d9c2
Merge branch 'main' into feat/autoets-model
martinb-ai Oct 25, 2022
5d1ea53
feat: add verbose option
AzulGarza Oct 25, 2022
8bfa047
Merge branch 'main' into feat/autoets-model
AzulGarza Oct 25, 2022
7519b80
feat: freq preprocessing
AzulGarza Oct 26, 2022
15af868
feat: improve test size call
AzulGarza Oct 26, 2022
9125832
Merge branch 'feat/autoets-model' of https://github.com/FedericoGarza…
AzulGarza Oct 26, 2022
ba75a1f
Merge branch 'main' into feat/autoets-model
AzulGarza Oct 26, 2022
76e2599
feat: add statsforecast dep to conda env full
AzulGarza Oct 26, 2022
87824b8
Merge branch 'feat/autoets-model' of https://github.com/FedericoGarza…
AzulGarza Oct 26, 2022
ff89dd8
Merge branch 'main' into feat/autoets-model
AzulGarza Oct 26, 2022
69df784
Merge branch 'main' into feat/autoets-model
AzulGarza Oct 28, 2022
b50d6d0
Merge branch 'main' into feat/autoets-model
AzulGarza Oct 30, 2022
f4e3d5a
Merge branch 'main' into feat/autoets-model
AzulGarza Oct 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build/conda/conda-3-9-env-full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ dependencies:
- u8darts[torch]=0.22.0
- poetry=1.1.13
- cvxpy=1.2.1
- statsforecast==1.1.3
134 changes: 134 additions & 0 deletions openbb_terminal/forecast/autoets_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# pylint: disable=too-many-arguments
"""Automatic ETS (Error, Trend, and Seasonality) Model"""
__docformat__ = "numpy"

import logging
from typing import Any, Union, Optional, List, Tuple

import warnings
import numpy as np
import pandas as pd
from statsforecast.models import ETS
from statsforecast.core import StatsForecast

from openbb_terminal.decorators import log_start_end
from openbb_terminal.rich_config import console
from openbb_terminal.forecast import helpers


warnings.simplefilter("ignore")

logger = logging.getLogger(__name__)


@log_start_end(log=logger)
def get_autoets_data(
data: Union[pd.Series, pd.DataFrame],
target_column: str = "close",
seasonal_periods: int = 7,
n_predict: int = 30,
start_window: float = 0.85,
forecast_horizon: int = 5,
) -> Tuple[list[np.ndarray], List[np.ndarray], List[np.ndarray], Optional[float], Any]:

"""Performs Automatic ETS forecasting
This is a wrapper around StatsForecast ETS;
we refer to this link for the original and more complete documentation of the parameters.


https://nixtla.github.io/statsforecast/models.html#ets

Parameters
----------
data : Union[pd.Series, np.ndarray]
Input data.
target_column (str, optional):
Target column to forecast. Defaults to "close".
seasonal_periods: int
Number of seasonal periods in a year (7 for daily data)
If not set, inferred from frequency of the series.
n_predict: int
Number of days to forecast
start_window: float
Size of sliding window from start of timeseries and onwards
forecast_horizon: int
Number of days to forecast when backtesting and retraining historical

Returns
-------
list[float]
Adjusted Data series
list[float]
List of historical fcast values
list[float]
List of predicted fcast values
Optional[float]
precision
Any
Fit ETS model object.
"""

use_scalers = False
# statsforecast preprocessing
# when including more time series
# the preprocessing is similar
_, ticker_series = helpers.get_series(data, target_column, is_scaler=use_scalers)
freq = ticker_series.freq_str
ticker_series = ticker_series.pd_dataframe().reset_index()
ticker_series.columns = ["ds", "y"]
ticker_series.insert(0, "unique_id", target_column)

# Model Init
model_ets = ETS(
season_length=int(seasonal_periods),
)
fcst = StatsForecast(df=ticker_series, models=[model_ets], freq=freq, verbose=True)

# Historical backtesting
last_training_point = int((len(ticker_series) - 1) * start_window)
historical_fcast_ets = fcst.cross_validation(
h=int(forecast_horizon),
test_size=len(ticker_series) - last_training_point,
n_windows=None,
input_size=min(10 * forecast_horizon, len(ticker_series)),
)

# train new model on entire timeseries to provide best current forecast
# we have the historical fcast, now lets predict.
forecast = fcst.forecast(int(n_predict))
y_true = historical_fcast_ets["y"].values
y_hat = historical_fcast_ets["ETS"].values
precision = helpers.mean_absolute_percentage_error(y_true, y_hat)
console.print(f"AutoETS obtains MAPE: {precision:.2f}% \n")

# transform outputs to make them compatible with
# plots
use_scalers = False
_, ticker_series = helpers.get_series(
ticker_series.rename(columns={"y": target_column}),
target_column,
is_scaler=use_scalers,
time_col="ds",
)
_, forecast = helpers.get_series(
forecast.rename(columns={"ETS": target_column}),
target_column,
is_scaler=use_scalers,
time_col="ds",
)
_, historical_fcast_ets = helpers.get_series(
historical_fcast_ets.groupby("ds")
.head(1)
.rename(columns={"ETS": target_column}),
target_column,
is_scaler=use_scalers,
time_col="ds",
)

return (
ticker_series,
historical_fcast_ets,
forecast,
precision,
fcst,
)
113 changes: 113 additions & 0 deletions openbb_terminal/forecast/autoets_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Automatic ETS (Error, Trend, Sesonality) View"""
__docformat__ = "numpy"

import logging
from typing import Union, Optional, List
from datetime import datetime

import pandas as pd
import matplotlib.pyplot as plt

from openbb_terminal.forecast import autoets_model
from openbb_terminal.decorators import log_start_end
from openbb_terminal.forecast import helpers

logger = logging.getLogger(__name__)
# pylint: disable=too-many-arguments


@log_start_end(log=logger)
def display_autoets_forecast(
data: Union[pd.DataFrame, pd.Series],
target_column: str = "close",
dataset_name: str = "",
seasonal_periods: int = 7,
n_predict: int = 30,
start_window: float = 0.85,
forecast_horizon: int = 5,
export: str = "",
residuals: bool = False,
forecast_only: bool = False,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
naive: bool = False,
export_pred_raw: bool = False,
external_axes: Optional[List[plt.axes]] = None,
):
"""Display Automatic ETS (Error, Trend, Sesonality) Model

Parameters
----------
data : Union[pd.Series, np.array]
Data to forecast
dataset_name str
The name of the ticker to be predicted
target_column (str, optional):
Target column to forecast. Defaults to "close".
seasonal_periods: int
Number of seasonal periods in a year
If not set, inferred from frequency of the series.
n_predict: int
Number of days to forecast
start_window: float
Size of sliding window from start of timeseries and onwards
forecast_horizon: int
Number of days to forecast when backtesting and retraining historical
export: str
Format to export data
residuals: bool
Whether to show residuals for the model. Defaults to False.
forecast_only: bool
Whether to only show dates in the forecasting range. Defaults to False.
start_date: Optional[datetime]
The starting date to perform analysis, data before this is trimmed. Defaults to None.
end_date: Optional[datetime]
The ending date to perform analysis, data after this is trimmed. Defaults to None.
naive: bool
Whether to show the naive baseline. This just assumes the closing price will be the same
as the previous day's closing price. Defaults to False.
external_axes:Optional[List[plt.axes]]
External axes to plot on
"""
data = helpers.clean_data(data, start_date, end_date, target_column, None)
if not helpers.check_data(data, target_column, None):
return

(
ticker_series,
historical_fcast,
predicted_values,
precision,
_model,
) = autoets_model.get_autoets_data(
data=data,
target_column=target_column,
seasonal_periods=seasonal_periods,
n_predict=n_predict,
start_window=start_window,
forecast_horizon=forecast_horizon,
)
probabilistic = False
helpers.plot_forecast(
name="AutoETS",
target_col=target_column,
historical_fcast=historical_fcast,
predicted_values=predicted_values,
ticker_series=ticker_series,
ticker_name=dataset_name,
data=data,
n_predict=n_predict,
forecast_horizon=forecast_horizon,
past_covariates=None,
precision=precision,
probabilistic=probabilistic,
export=export,
forecast_only=forecast_only,
naive=naive,
export_pred_raw=export_pred_raw,
external_axes=external_axes,
)
if residuals:
helpers.plot_residuals(
_model, None, ticker_series, forecast_horizon=forecast_horizon
)
58 changes: 58 additions & 0 deletions openbb_terminal/forecast/forecast_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from openbb_terminal.forecast import (
forecast_model,
forecast_view,
autoets_view,
expo_model,
expo_view,
linregr_view,
Expand Down Expand Up @@ -108,6 +109,7 @@ class ForecastController(BaseController):
"delta",
"atr",
"signal",
"autoets",
"expo",
"theta",
"rnn",
Expand Down Expand Up @@ -246,6 +248,7 @@ def update_runtime_choices(self):
"signal",
"combine",
"rename",
"autoets",
"expo",
"theta",
"rnn",
Expand Down Expand Up @@ -324,6 +327,7 @@ def print_help(self):
mt.add_cmd("signal", self.files)
mt.add_raw("\n")
mt.add_info("_tsforecasting_")
mt.add_cmd("autoets", self.files)
mt.add_cmd("expo", self.files)
mt.add_cmd("theta", self.files)
mt.add_cmd("linregr", self.files)
Expand Down Expand Up @@ -1633,6 +1637,60 @@ def call_export(self, other_args: List[str]):
ns_parser.target_dataset,
)

# AutoETS Model
@log_start_end(log=logger)
def call_autoets(self, other_args: List[str]):
"""Process autoets command"""
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
add_help=False,
prog="autoets",
description="""
Perform Automatic ETS (Error, Trend, Seasonality) forecast
""",
)
if other_args and "-" not in other_args[0][0]:
other_args.insert(0, "--target-dataset")

ns_parser = self.parse_known_args_and_warn(
parser,
other_args,
export_allowed=EXPORT_ONLY_FIGURES_ALLOWED,
target_dataset=True,
target_column=True,
n_days=True,
seasonal="A",
periods=True,
window=True,
residuals=True,
forecast_only=True,
start=True,
end=True,
naive=True,
export_pred_raw=True,
)
# TODO Convert this to multi series
if ns_parser:
if not helpers.check_parser_input(ns_parser, self.datasets):
return

autoets_view.display_autoets_forecast(
data=self.datasets[ns_parser.target_dataset],
dataset_name=ns_parser.target_dataset,
n_predict=ns_parser.n_days,
target_column=ns_parser.target_column,
seasonal_periods=ns_parser.seasonal_periods,
start_window=ns_parser.start_window,
forecast_horizon=ns_parser.n_days,
export=ns_parser.export,
residuals=ns_parser.residuals,
forecast_only=ns_parser.forecast_only,
start_date=ns_parser.s_start_date,
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
)

# EXPO Model
@log_start_end(log=logger)
def call_expo(self, other_args: List[str]):
Expand Down
7 changes: 5 additions & 2 deletions openbb_terminal/forecast/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,12 +613,15 @@ def dt_format(x) -> str:


def get_series(
data: pd.DataFrame, target_column: str = None, is_scaler: bool = True
data: pd.DataFrame,
target_column: str = None,
is_scaler: bool = True,
time_col: str = "date",
) -> Tuple[Optional[Scaler], TimeSeries]:
filler = MissingValuesFiller()
filler_kwargs = dict(
df=data,
time_col="date",
time_col=time_col,
value_cols=[target_column],
freq="B",
fill_missing_dates=True,
Expand Down
1 change: 1 addition & 0 deletions openbb_terminal/miscellaneous/i18n/en.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,7 @@ en:
forecast/atr: Add Average True Range
forecast/signal: Add Price Signal (short vs. long term)
forecast/_tsforecasting_: TimeSeries Forecasting
forecast/autoets: Automatic ETS (Error, Trend, Seasonality) Model
forecast/arima: Arima (Non-darts)
forecast/expo: Probabilistic Exponential Smoothing
forecast/theta: Theta Method
Expand Down
4 changes: 4 additions & 0 deletions openbb_terminal/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2004,6 +2004,10 @@
"forecast.roc": {"model": "openbb_terminal.forecast.forecast_model.add_roc"},
"forecast.mom": {"model": "openbb_terminal.forecast.forecast_model.add_momentum"},
"forecast.delta": {"model": "openbb_terminal.forecast.forecast_model.add_delta"},
"forecast.autoets": {
"model": "openbb_terminal.forecast.autoets_model.get_autoets_data",
"view": "openbb_terminal.forecast.autoets_view.display_autoets_forecast",
},
"forecast.expo": {
"model": "openbb_terminal.forecast.expo_model.get_expo_data",
"view": "openbb_terminal.forecast.expo_view.display_expo_forecast",
Expand Down
11 changes: 11 additions & 0 deletions tests/openbb_terminal/forecast/test_autoets_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest
from tests.openbb_terminal.forecast import conftest

try:
from openbb_terminal.forecast import autoets_model
except ImportError:
pytest.skip(allow_module_level=True)


def test_get_autoets_model(tsla_csv):
conftest.test_model(autoets_model.get_autoets_data, tsla_csv)
Loading