diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index be67396fba..810fb1f7bd 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -172,6 +172,14 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame: sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options) return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index() + def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame: + sql = self._model_manipulation_sql_generator.ml_explain_forecast( + struct_options=options + ) + return self._session.read_gbq( + sql, index_col="time_series_timestamp" + ).reset_index() + def evaluate(self, input_data: Optional[bpd.DataFrame] = None): sql = self._model_manipulation_sql_generator.ml_evaluate( input_data.sql if (input_data is not None) else None diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index 523d306719..6079e0ea22 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -253,6 +253,43 @@ def predict( options={"horizon": horizon, "confidence_level": confidence_level} ) + def predict_explain( + self, X=None, *, horizon: int = 3, confidence_level: float = 0.95 + ) -> bpd.DataFrame: + """Explain Forecast time series at future horizon. + + .. note:: + + Output matches that of the BigQuery ML.EXPLAIN_FORECAST function. + See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-forecast + + Args: + X (default None): + ignored, to be compatible with other APIs. + horizon (int, default: 3): + an int value that specifies the number of time points to forecast. + The default value is 3, and the maximum value is 1000. + confidence_level (float, default 0.95): + A float value that specifies percentage of the future values that fall in the prediction interval. + The valid input range is [0.0, 1.0). + + Returns: + bigframes.dataframe.DataFrame: The predicted DataFrames. + """ + if horizon < 1: + raise ValueError(f"horizon must be at least 1, but is {horizon}.") + if confidence_level < 0.0 or confidence_level >= 1.0: + raise ValueError( + f"confidence_level must be [0.0, 1.0), but is {confidence_level}." + ) + + if not self._bqml_model: + raise RuntimeError("A model must be fitted before predict") + + return self._bqml_model.explain_forecast( + options={"horizon": horizon, "confidence_level": confidence_level} + ) + @property def coef_( self, diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index b7d550ac63..1ef43d9ce5 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -310,6 +310,14 @@ def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str: return f"""SELECT * FROM ML.FORECAST(MODEL {self._model_ref_sql()}, {struct_options_sql})""" + def ml_explain_forecast( + self, struct_options: Mapping[str, Union[int, float]] + ) -> str: + """Encode ML.EXPLAIN_FORECAST for BQML""" + struct_options_sql = self.struct_options(**struct_options) + return f"""SELECT * FROM ML.EXPLAIN_FORECAST(MODEL {self._model_ref_sql()}, + {struct_options_sql})""" + def ml_generate_text( self, source_sql: str, struct_options: Mapping[str, Union[int, float]] ) -> str: diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index 7fef189550..1b3a650388 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -65,6 +65,42 @@ def test_arima_plus_predict_default( ) +def test_arima_plus_predict_explain_default( + time_series_arima_plus_model: forecasting.ARIMAPlus, +): + utc = pytz.utc + predictions = time_series_arima_plus_model.predict_explain().to_pandas() + assert predictions.shape[0] == 369 + predictions = predictions[ + predictions["time_series_type"] == "forecast" + ].reset_index(drop=True) + assert predictions.shape[0] == 3 + result = predictions[["time_series_timestamp", "time_series_data"]] + expected = pd.DataFrame( + { + "time_series_timestamp": [ + datetime(2017, 8, 2, tzinfo=utc), + datetime(2017, 8, 3, tzinfo=utc), + datetime(2017, 8, 4, tzinfo=utc), + ], + "time_series_data": [2727.693349, 2595.290749, 2370.86767], + } + ) + expected["time_series_data"] = expected["time_series_data"].astype( + pd.Float64Dtype() + ) + expected["time_series_timestamp"] = expected["time_series_timestamp"].astype( + pd.ArrowDtype(pa.timestamp("us", tz="UTC")) + ) + + pd.testing.assert_frame_equal( + result, + expected, + rtol=0.1, + check_index_type=False, + ) + + def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus): utc = pytz.utc predictions = time_series_arima_plus_model.predict( @@ -96,6 +132,33 @@ def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARI ) +def test_arima_plus_predict_explain_params( + time_series_arima_plus_model: forecasting.ARIMAPlus, +): + predictions = time_series_arima_plus_model.predict_explain( + horizon=4, confidence_level=0.9 + ).to_pandas() + assert predictions.shape[0] >= 1 + prediction_columns = set(predictions.columns) + expected_columns = { + "time_series_timestamp", + "time_series_type", + "time_series_data", + "time_series_adjusted_data", + "standard_error", + "confidence_level", + "prediction_interval_lower_bound", + "trend", + "seasonal_period_yearly", + "seasonal_period_quarterly", + "seasonal_period_monthly", + "seasonal_period_weekly", + "seasonal_period_daily", + "holiday_effect", + } + assert expected_columns <= prediction_columns + + def test_arima_plus_detect_anomalies( time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df ): diff --git a/tests/unit/ml/test_forecasting.py b/tests/unit/ml/test_forecasting.py new file mode 100644 index 0000000000..3bbf4c777e --- /dev/null +++ b/tests/unit/ml/test_forecasting.py @@ -0,0 +1,58 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import pytest + +from bigframes.ml import forecasting + + +def test_predict_explain_low_confidence_level(): + confidence_level = -0.5 + + model = forecasting.ARIMAPlus() + + with pytest.raises( + ValueError, + match=re.escape( + f"confidence_level must be [0.0, 1.0), but is {confidence_level}." + ), + ): + model.predict_explain(horizon=4, confidence_level=confidence_level) + + +def test_predict_high_explain_confidence_level(): + confidence_level = 2.1 + + model = forecasting.ARIMAPlus() + + with pytest.raises( + ValueError, + match=re.escape( + f"confidence_level must be [0.0, 1.0), but is {confidence_level}." + ), + ): + model.predict_explain(horizon=4, confidence_level=confidence_level) + + +def test_predict_explain_low_horizon(): + horizon = -1 + + model = forecasting.ARIMAPlus() + + with pytest.raises( + ValueError, match=f"horizon must be at least 1, but is {horizon}." + ): + model.predict_explain(horizon=horizon, confidence_level=0.9)