From 352cb850d23e41a2278edf0df584b89ee9619aab Mon Sep 17 00:00:00 2001 From: Garrett Wu <6505921+GarrettWu@users.noreply.github.com> Date: Thu, 21 Mar 2024 23:18:26 -0700 Subject: [PATCH] feat: add ml ARIMAPlus model params (#488) --- bigframes/ml/forecasting.py | 157 +++++++++++++++++++++- tests/system/large/ml/test_forecasting.py | 41 ++++++ 2 files changed, 192 insertions(+), 6 deletions(-) diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index 18380328c7..292389dcbb 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union from google.cloud import bigquery @@ -25,12 +25,108 @@ from bigframes.ml import base, core, globals, utils import bigframes.pandas as bpd +_BQML_PARAMS_MAPPING = { + "horizon": "horizon", + "auto_arima": "autoArima", + "auto_arima_max_order": "autoArimaMaxOrder", + "auto_arima_min_order": "autoArimaMinOrder", + "order": "nonSeasonalOrder", + "data_frequency": "dataFrequency", + "holiday_region": "holidayRegion", + "clean_spikes_and_dips": "cleanSpikesAndDips", + "adjust_step_changes": "adjustStepChanges", + "time_series_length_fraction": "timeSeriesLengthFraction", + "min_time_series_length": "minTimeSeriesLength", + "max_time_series_length": "maxTimeSeriesLength", + "decompose_time_series": "decomposeTimeSeries", + "trend_smoothing_window_size": "trendSmoothingWindowSize", +} + @log_adapter.class_logger class ARIMAPlus(base.SupervisedTrainablePredictor): - """Time Series ARIMA Plus model.""" + """Time Series ARIMA Plus model. + + Args: + horizon (int, default 1,000): + The number of time points to forecast. Default to 1,000, max value 10,000. + + auto_arima (bool, default True): + Determines whether the training process uses auto.ARIMA or not. If True, training automatically finds the best non-seasonal order (that is, the p, d, q tuple) and decides whether or not to include a linear drift term when d is 1. + + auto_arima_max_order (int or None, default None): + The maximum value for the sum of non-seasonal p and q. + + auto_arima_min_order (int or None, default None): + The minimum value for the sum of non-seasonal p and q. + + data_frequency (str, default "auto_frequency"): + The data frequency of the input time series. + Possible values are "auto_frequency", "per_minute", "hourly", "daily", "weekly", "monthly", "quarterly", "yearly" + + include_drift (bool, defalut False): + Determines whether the model should include a linear drift term or not. The drift term is applicable when non-seasonal d is 1. + + holiday_region (str or None, default None): + The geographical region based on which the holiday effect is applied in modeling. By default, holiday effect modeling isn't used. + Possible values see https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-time-series#holiday_region. + + clean_spikes_and_dips (bool, default True): + Determines whether or not to perform automatic spikes and dips detection and cleanup in the model training pipeline. The spikes and dips are replaced with local linear interpolated values when they're detected. + + adjust_step_changes (bool, default True): + Determines whether or not to perform automatic step change detection and adjustment in the model training pipeline. + + time_series_length_fraction (float or None, default None): + The fraction of the interpolated length of the time series that's used to model the time series trend component. All of the time points of the time series are used to model the non-trend component. + + min_time_series_length (int or None, default None): + The minimum number of time points that are used in modeling the trend component of the time series. + + max_time_series_length (int or None, default None): + The maximum number of time points in a time series that can be used in modeling the trend component of the time series. + + trend_smoothing_window_size (int or None, default None): + The smoothing window size for the trend component. + + decompose_time_series (bool, default True): + Determines whether the separate components of both the history and forecast parts of the time series (such as holiday effect and seasonal components) are saved in the model. + """ + + def __init__( + self, + *, + horizon: int = 1000, + auto_arima: bool = True, + auto_arima_max_order: Optional[int] = None, + auto_arima_min_order: Optional[int] = None, + data_frequency: str = "auto_frequency", + include_drift: bool = False, + holiday_region: Optional[str] = None, + clean_spikes_and_dips: bool = True, + adjust_step_changes: bool = True, + time_series_length_fraction: Optional[float] = None, + min_time_series_length: Optional[int] = None, + max_time_series_length: Optional[int] = None, + trend_smoothing_window_size: Optional[int] = None, + decompose_time_series: bool = True, + ): + self.horizon = horizon + self.auto_arima = auto_arima + self.auto_arima_max_order = auto_arima_max_order + self.auto_arima_min_order = auto_arima_min_order + self.data_frequency = data_frequency + self.include_drift = include_drift + self.holiday_region = holiday_region + self.clean_spikes_and_dips = clean_spikes_and_dips + self.adjust_step_changes = adjust_step_changes + self.time_series_length_fraction = time_series_length_fraction + self.min_time_series_length = min_time_series_length + self.max_time_series_length = max_time_series_length + self.trend_smoothing_window_size = trend_smoothing_window_size + self.decompose_time_series = decompose_time_series + # TODO(garrettwu) add order and seasonalities params, which need struct/array - def __init__(self): self._bqml_model: Optional[core.BqmlModel] = None self._bqml_model_factory = globals.bqml_model_factory() @@ -38,16 +134,65 @@ def __init__(self): def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> ARIMAPlus: assert model.model_type == "ARIMA_PLUS" - kwargs: Dict[str, str | int | bool | float | List[str]] = {} + kwargs: dict = {} + last_fitting = model.training_runs[-1]["trainingOptions"] + + dummy_arima = cls() + for bf_param, bf_value in dummy_arima.__dict__.items(): + bqml_param = _BQML_PARAMS_MAPPING.get(bf_param) + if bqml_param in last_fitting: + # Convert types + if bf_param in ["time_series_length_fraction"]: + kwargs[bf_param] = float(last_fitting[bqml_param]) + elif bf_param in [ + "auto_arima_max_order", + "auto_arima_min_order", + "min_time_series_length", + "max_time_series_length", + "trend_smoothing_window_size", + ]: + kwargs[bf_param] = int(last_fitting[bqml_param]) + elif bf_param in ["holiday_region"]: + kwargs[bf_param] = str(last_fitting[bqml_param]) + else: + kwargs[bf_param] = type(bf_value)(last_fitting[bqml_param]) new_arima_plus = cls(**kwargs) new_arima_plus._bqml_model = core.BqmlModel(session, model) return new_arima_plus @property - def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: + def _bqml_options(self) -> dict: """The model options as they will be set for BQML.""" - return {"model_type": "ARIMA_PLUS"} + options = { + "model_type": "ARIMA_PLUS", + "horizon": self.horizon, + "auto_arima": self.auto_arima, + "data_frequency": self.data_frequency, + "clean_spikes_and_dips": self.clean_spikes_and_dips, + "adjust_step_changes": self.adjust_step_changes, + "decompose_time_series": self.decompose_time_series, + } + + if self.auto_arima_max_order is not None: + options["auto_arima_max_order"] = self.auto_arima_max_order + if self.auto_arima_min_order is not None: + options["auto_arima_min_order"] = self.auto_arima_min_order + if self.holiday_region is not None: + options["holiday_region"] = self.holiday_region + if self.time_series_length_fraction is not None: + options["time_series_length_fraction"] = self.time_series_length_fraction + if self.min_time_series_length is not None: + options["min_time_series_length"] = self.min_time_series_length + if self.max_time_series_length is not None: + options["max_time_series_length"] = self.max_time_series_length + if self.trend_smoothing_window_size is not None: + options["trend_smoothing_window_size"] = self.trend_smoothing_window_size + + if self.include_drift: + options["include_drift"] = True + + return options def _fit( self, diff --git a/tests/system/large/ml/test_forecasting.py b/tests/system/large/ml/test_forecasting.py index 2bb136b0f2..b333839e2e 100644 --- a/tests/system/large/ml/test_forecasting.py +++ b/tests/system/large/ml/test_forecasting.py @@ -77,3 +77,44 @@ def test_arima_plus_model_fit_summary(time_series_df_default_index, dataset_id): assert ( f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name ) + + +def test_arima_plus_model_fit_params(time_series_df_default_index, dataset_id): + model = forecasting.ARIMAPlus( + horizon=100, + auto_arima=True, + auto_arima_max_order=4, + auto_arima_min_order=1, + data_frequency="daily", + holiday_region="US", + clean_spikes_and_dips=False, + adjust_step_changes=False, + time_series_length_fraction=0.5, + min_time_series_length=10, + trend_smoothing_window_size=5, + decompose_time_series=False, + ) + + X_train = time_series_df_default_index[["parsed_date"]] + y_train = time_series_df_default_index[["total_visits"]] + model.fit(X_train, y_train) + + # save, load to ensure configuration was kept + reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True) + assert ( + f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name + ) + + assert reloaded_model.horizon == 100 + assert reloaded_model.auto_arima is True + assert reloaded_model.auto_arima_max_order == 4 + # TODO(garrettwu): now BQML doesn't populate auto_arima_min_order + # assert reloaded_model.auto_arima_min_order == 1 + assert reloaded_model.data_frequency == "DAILY" + assert reloaded_model.holiday_region == "US" + assert reloaded_model.clean_spikes_and_dips is False + assert reloaded_model.adjust_step_changes is False + assert reloaded_model.time_series_length_fraction == 0.5 + assert reloaded_model.min_time_series_length == 10 + assert reloaded_model.trend_smoothing_window_size == 5 + assert reloaded_model.decompose_time_series is False