Skip to content

Commit

Permalink
feat: add ml ARIMAPlus model params (#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
GarrettWu authored Mar 22, 2024
1 parent 60d4a7b commit 352cb85
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 6 deletions.
157 changes: 151 additions & 6 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from __future__ import annotations

from typing import Dict, List, Optional, Union
from typing import List, Optional, Union

from google.cloud import bigquery

Expand All @@ -25,29 +25,174 @@
from bigframes.ml import base, core, globals, utils
import bigframes.pandas as bpd

_BQML_PARAMS_MAPPING = {
"horizon": "horizon",
"auto_arima": "autoArima",
"auto_arima_max_order": "autoArimaMaxOrder",
"auto_arima_min_order": "autoArimaMinOrder",
"order": "nonSeasonalOrder",
"data_frequency": "dataFrequency",
"holiday_region": "holidayRegion",
"clean_spikes_and_dips": "cleanSpikesAndDips",
"adjust_step_changes": "adjustStepChanges",
"time_series_length_fraction": "timeSeriesLengthFraction",
"min_time_series_length": "minTimeSeriesLength",
"max_time_series_length": "maxTimeSeriesLength",
"decompose_time_series": "decomposeTimeSeries",
"trend_smoothing_window_size": "trendSmoothingWindowSize",
}


@log_adapter.class_logger
class ARIMAPlus(base.SupervisedTrainablePredictor):
"""Time Series ARIMA Plus model."""
"""Time Series ARIMA Plus model.
Args:
horizon (int, default 1,000):
The number of time points to forecast. Default to 1,000, max value 10,000.
auto_arima (bool, default True):
Determines whether the training process uses auto.ARIMA or not. If True, training automatically finds the best non-seasonal order (that is, the p, d, q tuple) and decides whether or not to include a linear drift term when d is 1.
auto_arima_max_order (int or None, default None):
The maximum value for the sum of non-seasonal p and q.
auto_arima_min_order (int or None, default None):
The minimum value for the sum of non-seasonal p and q.
data_frequency (str, default "auto_frequency"):
The data frequency of the input time series.
Possible values are "auto_frequency", "per_minute", "hourly", "daily", "weekly", "monthly", "quarterly", "yearly"
include_drift (bool, defalut False):
Determines whether the model should include a linear drift term or not. The drift term is applicable when non-seasonal d is 1.
holiday_region (str or None, default None):
The geographical region based on which the holiday effect is applied in modeling. By default, holiday effect modeling isn't used.
Possible values see https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-time-series#holiday_region.
clean_spikes_and_dips (bool, default True):
Determines whether or not to perform automatic spikes and dips detection and cleanup in the model training pipeline. The spikes and dips are replaced with local linear interpolated values when they're detected.
adjust_step_changes (bool, default True):
Determines whether or not to perform automatic step change detection and adjustment in the model training pipeline.
time_series_length_fraction (float or None, default None):
The fraction of the interpolated length of the time series that's used to model the time series trend component. All of the time points of the time series are used to model the non-trend component.
min_time_series_length (int or None, default None):
The minimum number of time points that are used in modeling the trend component of the time series.
max_time_series_length (int or None, default None):
The maximum number of time points in a time series that can be used in modeling the trend component of the time series.
trend_smoothing_window_size (int or None, default None):
The smoothing window size for the trend component.
decompose_time_series (bool, default True):
Determines whether the separate components of both the history and forecast parts of the time series (such as holiday effect and seasonal components) are saved in the model.
"""

def __init__(
self,
*,
horizon: int = 1000,
auto_arima: bool = True,
auto_arima_max_order: Optional[int] = None,
auto_arima_min_order: Optional[int] = None,
data_frequency: str = "auto_frequency",
include_drift: bool = False,
holiday_region: Optional[str] = None,
clean_spikes_and_dips: bool = True,
adjust_step_changes: bool = True,
time_series_length_fraction: Optional[float] = None,
min_time_series_length: Optional[int] = None,
max_time_series_length: Optional[int] = None,
trend_smoothing_window_size: Optional[int] = None,
decompose_time_series: bool = True,
):
self.horizon = horizon
self.auto_arima = auto_arima
self.auto_arima_max_order = auto_arima_max_order
self.auto_arima_min_order = auto_arima_min_order
self.data_frequency = data_frequency
self.include_drift = include_drift
self.holiday_region = holiday_region
self.clean_spikes_and_dips = clean_spikes_and_dips
self.adjust_step_changes = adjust_step_changes
self.time_series_length_fraction = time_series_length_fraction
self.min_time_series_length = min_time_series_length
self.max_time_series_length = max_time_series_length
self.trend_smoothing_window_size = trend_smoothing_window_size
self.decompose_time_series = decompose_time_series
# TODO(garrettwu) add order and seasonalities params, which need struct/array

def __init__(self):
self._bqml_model: Optional[core.BqmlModel] = None
self._bqml_model_factory = globals.bqml_model_factory()

@classmethod
def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> ARIMAPlus:
assert model.model_type == "ARIMA_PLUS"

kwargs: Dict[str, str | int | bool | float | List[str]] = {}
kwargs: dict = {}
last_fitting = model.training_runs[-1]["trainingOptions"]

dummy_arima = cls()
for bf_param, bf_value in dummy_arima.__dict__.items():
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
if bqml_param in last_fitting:
# Convert types
if bf_param in ["time_series_length_fraction"]:
kwargs[bf_param] = float(last_fitting[bqml_param])
elif bf_param in [
"auto_arima_max_order",
"auto_arima_min_order",
"min_time_series_length",
"max_time_series_length",
"trend_smoothing_window_size",
]:
kwargs[bf_param] = int(last_fitting[bqml_param])
elif bf_param in ["holiday_region"]:
kwargs[bf_param] = str(last_fitting[bqml_param])
else:
kwargs[bf_param] = type(bf_value)(last_fitting[bqml_param])

new_arima_plus = cls(**kwargs)
new_arima_plus._bqml_model = core.BqmlModel(session, model)
return new_arima_plus

@property
def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
def _bqml_options(self) -> dict:
"""The model options as they will be set for BQML."""
return {"model_type": "ARIMA_PLUS"}
options = {
"model_type": "ARIMA_PLUS",
"horizon": self.horizon,
"auto_arima": self.auto_arima,
"data_frequency": self.data_frequency,
"clean_spikes_and_dips": self.clean_spikes_and_dips,
"adjust_step_changes": self.adjust_step_changes,
"decompose_time_series": self.decompose_time_series,
}

if self.auto_arima_max_order is not None:
options["auto_arima_max_order"] = self.auto_arima_max_order
if self.auto_arima_min_order is not None:
options["auto_arima_min_order"] = self.auto_arima_min_order
if self.holiday_region is not None:
options["holiday_region"] = self.holiday_region
if self.time_series_length_fraction is not None:
options["time_series_length_fraction"] = self.time_series_length_fraction
if self.min_time_series_length is not None:
options["min_time_series_length"] = self.min_time_series_length
if self.max_time_series_length is not None:
options["max_time_series_length"] = self.max_time_series_length
if self.trend_smoothing_window_size is not None:
options["trend_smoothing_window_size"] = self.trend_smoothing_window_size

if self.include_drift:
options["include_drift"] = True

return options

def _fit(
self,
Expand Down
41 changes: 41 additions & 0 deletions tests/system/large/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,44 @@ def test_arima_plus_model_fit_summary(time_series_df_default_index, dataset_id):
assert (
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
)


def test_arima_plus_model_fit_params(time_series_df_default_index, dataset_id):
model = forecasting.ARIMAPlus(
horizon=100,
auto_arima=True,
auto_arima_max_order=4,
auto_arima_min_order=1,
data_frequency="daily",
holiday_region="US",
clean_spikes_and_dips=False,
adjust_step_changes=False,
time_series_length_fraction=0.5,
min_time_series_length=10,
trend_smoothing_window_size=5,
decompose_time_series=False,
)

X_train = time_series_df_default_index[["parsed_date"]]
y_train = time_series_df_default_index[["total_visits"]]
model.fit(X_train, y_train)

# save, load to ensure configuration was kept
reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
assert (
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
)

assert reloaded_model.horizon == 100
assert reloaded_model.auto_arima is True
assert reloaded_model.auto_arima_max_order == 4
# TODO(garrettwu): now BQML doesn't populate auto_arima_min_order
# assert reloaded_model.auto_arima_min_order == 1
assert reloaded_model.data_frequency == "DAILY"
assert reloaded_model.holiday_region == "US"
assert reloaded_model.clean_spikes_and_dips is False
assert reloaded_model.adjust_step_changes is False
assert reloaded_model.time_series_length_fraction == 0.5
assert reloaded_model.min_time_series_length == 10
assert reloaded_model.trend_smoothing_window_size == 5
assert reloaded_model.decompose_time_series is False

0 comments on commit 352cb85

Please sign in to comment.