Skip to content

Commit

Permalink
feat: add ARIMAPlus.predict_explain() to generate forecasts with ex…
Browse files Browse the repository at this point in the history
…planation columns (#1177)

* feat: create arima_plus_predict_attribution method

* tmp: debug notes for time_series_arima_plus_model.predict_attribution

* update test_arima_plus_predict_explain_default test and create test_arima_plus_predict_explain_params test

* Merge branch 'ml-predict-explain' of github.com:googleapis/python-bigquery-dataframes into ml-predict-explain

* update test_arima_plus_predict_explain_params test

* Revert "tmp: debug notes for time_series_arima_plus_model.predict_attribution"

This reverts commit f6dd455.

* format and lint

* Update bigframes/ml/forecasting.py

Co-authored-by: Tim Sweña (Swast) <[email protected]>

* update predict explain params test

* update test

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* add unit test file - bare bones

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* fixed tests

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* lint

* lint

* fix test: float -> int

---------

Co-authored-by: Chelsea Lin <[email protected]>
Co-authored-by: Tim Sweña (Swast) <[email protected]>
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Dec 9, 2024
1 parent 0d8a16b commit 05f8b4d
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 0 deletions.
8 changes: 8 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index()

def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
    """Run ML.EXPLAIN_FORECAST for this model and return the result.

    Args:
        options:
            Option names mapped to numeric values; forwarded unchanged as
            ``struct_options`` to the SQL generator.

    Returns:
        The query result read through the session. It is loaded with
        ``time_series_timestamp`` as the index and the index is then
        reset, so the timestamp comes back as a regular column.
    """
    sql_generator = self._model_manipulation_sql_generator
    query = sql_generator.ml_explain_forecast(struct_options=options)
    explained = self._session.read_gbq(query, index_col="time_series_timestamp")
    return explained.reset_index()

def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
sql = self._model_manipulation_sql_generator.ml_evaluate(
input_data.sql if (input_data is not None) else None
Expand Down
37 changes: 37 additions & 0 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,43 @@ def predict(
options={"horizon": horizon, "confidence_level": confidence_level}
)

def predict_explain(
    self, X=None, *, horizon: int = 3, confidence_level: float = 0.95
) -> bpd.DataFrame:
    """Forecast the time series and include explanation columns.

    .. note::

        Output matches that of the BigQuery ML.EXPLAIN_FORECAST function.
        See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-forecast

    Args:
        X (default None):
            Ignored; accepted only for compatibility with other APIs.
        horizon (int, default: 3):
            Number of time points to forecast. The default value is 3,
            and the maximum value is 1000.
        confidence_level (float, default 0.95):
            Percentage of the future values that fall in the prediction
            interval. The valid input range is [0.0, 1.0).

    Returns:
        bigframes.dataframe.DataFrame: The forecast together with its
        explanation columns.

    Raises:
        ValueError: If ``horizon`` is below 1 or ``confidence_level`` is
            outside [0.0, 1.0).
        RuntimeError: If the model has not been fitted yet.
    """
    # Argument validation runs before the fitted-model check, so bad
    # arguments are reported even on an unfitted estimator.
    if horizon < 1:
        raise ValueError(f"horizon must be at least 1, but is {horizon}.")

    out_of_range = confidence_level < 0.0 or confidence_level >= 1.0
    if out_of_range:
        raise ValueError(
            f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
        )

    fitted_model = self._bqml_model
    if not fitted_model:
        raise RuntimeError("A model must be fitted before predict")

    forecast_options = {"horizon": horizon, "confidence_level": confidence_level}
    return fitted_model.explain_forecast(options=forecast_options)

@property
def coef_(
self,
Expand Down
8 changes: 8 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
return f"""SELECT * FROM ML.FORECAST(MODEL {self._model_ref_sql()},
{struct_options_sql})"""

def ml_explain_forecast(
self, struct_options: Mapping[str, Union[int, float]]
) -> str:
"""Encode ML.EXPLAIN_FORECAST for BQML.

Renders ``struct_options`` through ``self.struct_options`` and embeds the
result, together with the model reference from ``self._model_ref_sql()``,
in a ``SELECT * FROM ML.EXPLAIN_FORECAST(...)`` statement.

Args:
struct_options: Option names mapped to int or float values; expanded
as keyword arguments to ``self.struct_options``.

Returns:
The SQL statement as a string.
"""
struct_options_sql = self.struct_options(**struct_options)
return f"""SELECT * FROM ML.EXPLAIN_FORECAST(MODEL {self._model_ref_sql()},
{struct_options_sql})"""

def ml_generate_text(
self, source_sql: str, struct_options: Mapping[str, Union[int, float]]
) -> str:
Expand Down
63 changes: 63 additions & 0 deletions tests/system/small/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,42 @@ def test_arima_plus_predict_default(
)


def test_arima_plus_predict_explain_default(
    time_series_arima_plus_model: forecasting.ARIMAPlus,
):
    """predict_explain() with default arguments yields three forecast rows.

    The full result is expected to contain 369 rows; the rows whose
    ``time_series_type`` is ``"forecast"`` are compared (timestamp and
    value only) against reference numbers with a 10% relative tolerance.
    """
    explained = time_series_arima_plus_model.predict_explain().to_pandas()
    assert explained.shape[0] == 369

    is_forecast = explained["time_series_type"] == "forecast"
    forecast_rows = explained[is_forecast].reset_index(drop=True)
    assert forecast_rows.shape[0] == 3

    actual = forecast_rows[["time_series_timestamp", "time_series_data"]]

    forecast_days = [datetime(2017, 8, day, tzinfo=pytz.utc) for day in (2, 3, 4)]
    expected = pd.DataFrame(
        {
            "time_series_timestamp": forecast_days,
            "time_series_data": [2727.693349, 2595.290749, 2370.86767],
        }
    )
    # Cast to the dtypes the model output uses so only values are compared.
    expected = expected.astype(
        {
            "time_series_data": pd.Float64Dtype(),
            "time_series_timestamp": pd.ArrowDtype(pa.timestamp("us", tz="UTC")),
        }
    )

    pd.testing.assert_frame_equal(
        actual,
        expected,
        rtol=0.1,
        check_index_type=False,
    )


def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus):
utc = pytz.utc
predictions = time_series_arima_plus_model.predict(
Expand Down Expand Up @@ -96,6 +132,33 @@ def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARI
)


def test_arima_plus_predict_explain_params(
    time_series_arima_plus_model: forecasting.ARIMAPlus,
):
    """predict_explain() honours explicit horizon/confidence_level arguments.

    Checks that the explanation columns are present (extra columns are
    allowed) and that the number of ``"forecast"`` rows equals the
    requested horizon.
    """
    horizon = 4
    predictions = time_series_arima_plus_model.predict_explain(
        horizon=horizon, confidence_level=0.9
    ).to_pandas()
    assert predictions.shape[0] >= 1

    expected_columns = {
        "time_series_timestamp",
        "time_series_type",
        "time_series_data",
        "time_series_adjusted_data",
        "standard_error",
        "confidence_level",
        "prediction_interval_lower_bound",
        "prediction_interval_upper_bound",
        "trend",
        "seasonal_period_yearly",
        "seasonal_period_quarterly",
        "seasonal_period_monthly",
        "seasonal_period_weekly",
        "seasonal_period_daily",
        "holiday_effect",
    }
    # Report exactly which columns are absent instead of a bare subset failure.
    missing_columns = expected_columns - set(predictions.columns)
    assert not missing_columns, f"missing columns: {sorted(missing_columns)}"

    # The requested horizon must be reflected in the number of forecast rows.
    forecast_rows = predictions[predictions["time_series_type"] == "forecast"]
    assert forecast_rows.shape[0] == horizon


def test_arima_plus_detect_anomalies(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/ml/test_forecasting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import pytest

from bigframes.ml import forecasting


def test_predict_explain_low_confidence_level():
    """A negative confidence_level is rejected with a ValueError."""
    confidence_level = -0.5
    expected_message = (
        f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
    )

    estimator = forecasting.ARIMAPlus()

    # The message contains regex metacharacters, so it is escaped for match=.
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        estimator.predict_explain(horizon=4, confidence_level=confidence_level)


def test_predict_high_explain_confidence_level():
    """A confidence_level above the valid range is rejected with a ValueError."""
    confidence_level = 2.1
    expected_message = (
        f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
    )

    estimator = forecasting.ARIMAPlus()

    # The message contains regex metacharacters, so it is escaped for match=.
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        estimator.predict_explain(horizon=4, confidence_level=confidence_level)


def test_predict_explain_low_horizon():
    """A horizon below 1 is rejected with a ValueError."""
    horizon = -1

    model = forecasting.ARIMAPlus()

    # pytest treats match= as a regular expression; escape the message so
    # the trailing "." is matched literally, as in the sibling tests.
    with pytest.raises(
        ValueError,
        match=re.escape(f"horizon must be at least 1, but is {horizon}."),
    ):
        model.predict_explain(horizon=horizon, confidence_level=0.9)

0 comments on commit 05f8b4d

Please sign in to comment.