Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ARIMAPlus.predict_explain() to generate forecasts with explanation columns #1177

Merged
merged 28 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
8bc1cca
feat: create arima_plus_predict_attribution method
rey-esp Nov 11, 2024
f6dd455
tmp: debug notes for time_series_arima_plus_model.predict_attribution
chelsea-lin Nov 12, 2024
b8ec20d
update test_arima_plus_predict_explain_default test and create test_a…
rey-esp Nov 18, 2024
8056c92
Merge branch 'ml-predict-explain' of github.com:googleapis/python-big…
rey-esp Nov 18, 2024
722181b
Merge branch 'ml-predict-explain' of github.com:googleapis/python-big…
rey-esp Nov 18, 2024
a161b33
update test_arima_plus_predict_explain_params test
rey-esp Nov 18, 2024
dd38aaf
Merge branch 'main' into ml-predict-explain
rey-esp Nov 26, 2024
347c3c4
Revert "tmp: debug notes for time_series_arima_plus_model.predict_att…
chelsea-lin Nov 26, 2024
54175c0
format and lint
rey-esp Nov 26, 2024
75d9f91
Merge branch 'main' into ml-predict-explain
rey-esp Nov 26, 2024
2634a91
Merge branch 'main' into ml-predict-explain
rey-esp Dec 2, 2024
448e63a
Update bigframes/ml/forecasting.py
rey-esp Dec 2, 2024
8347c4b
update predict explain params test
rey-esp Dec 2, 2024
1fe2d37
update test
rey-esp Dec 3, 2024
48c81ed
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Dec 3, 2024
c22eec8
Merge branch 'main' into ml-predict-explain
rey-esp Dec 3, 2024
706a1ae
add unit test file - bare bones
rey-esp Dec 4, 2024
79a5359
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Dec 4, 2024
7922d64
Merge branch 'main' into ml-predict-explain
rey-esp Dec 5, 2024
aba6aca
Merge branch 'main' into ml-predict-explain
rey-esp Dec 9, 2024
b9343cf
Merge branch 'main' into ml-predict-explain
rey-esp Dec 9, 2024
ac271ff
Merge branch 'main' into ml-predict-explain
rey-esp Dec 9, 2024
3befd2e
fixed tests
rey-esp Dec 9, 2024
3fbcb64
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Dec 9, 2024
04d1fb4
lint
rey-esp Dec 9, 2024
75e2994
Merge branch 'ml-predict-explain' of github.com:googleapis/python-big…
rey-esp Dec 9, 2024
6bfb1d3
lint
rey-esp Dec 9, 2024
e2eb29d
fix test: float -> int
rey-esp Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index()

def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
    """Run ML.EXPLAIN_FORECAST for this model and return the result.

    Args:
        options:
            Option names mapped to numeric values. They are forwarded
            unchanged as ``struct_options`` to the SQL generator.

    Returns:
        The query result, read with ``time_series_timestamp`` as the index
        column and then passed through ``reset_index()``.
    """
    query = self._model_manipulation_sql_generator.ml_explain_forecast(
        struct_options=options
    )
    explained = self._session.read_gbq(query, index_col="time_series_timestamp")
    return explained.reset_index()

def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
sql = self._model_manipulation_sql_generator.ml_evaluate(
input_data.sql if (input_data is not None) else None
Expand Down
37 changes: 37 additions & 0 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,43 @@ def predict(
options={"horizon": horizon, "confidence_level": confidence_level}
)

def predict_explain(
    self, X=None, *, horizon: int = 3, confidence_level: float = 0.95
) -> bpd.DataFrame:
    """Forecast the time series at a future horizon, with explanation columns.

    .. note::

        Output matches that of the BigQuery ML.EXPLAIN_FORECAST function.
        See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-forecast

    Args:
        X (default None):
            Ignored; accepted only to be compatible with other APIs.
        horizon (int, default: 3):
            An int value that specifies the number of time points to forecast.
            The default value is 3, and the maximum value is 1000.
        confidence_level (float, default 0.95):
            A float value that specifies the percentage of the future values
            that fall in the prediction interval.
            The valid input range is [0.0, 1.0).

    Returns:
        bigframes.dataframe.DataFrame: The forecast together with its
        explanation columns.

    Raises:
        ValueError: If ``horizon`` is less than 1, or ``confidence_level``
            is not within [0.0, 1.0) (NaN included).
        RuntimeError: If the model has not been fitted yet.
    """
    if horizon < 1:
        raise ValueError(f"horizon must be at least 1, but is {horizon}.")
    # Expressed as a positive range test so that NaN, for which every
    # comparison is False, is rejected instead of slipping through.
    if not (0.0 <= confidence_level < 1.0):
        raise ValueError(
            f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
        )

    if not self._bqml_model:
        raise RuntimeError("A model must be fitted before predict")

    return self._bqml_model.explain_forecast(
        options={"horizon": horizon, "confidence_level": confidence_level}
    )

@property
def coef_(
self,
Expand Down
8 changes: 8 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
return f"""SELECT * FROM ML.FORECAST(MODEL {self._model_ref_sql()},
{struct_options_sql})"""

def ml_explain_forecast(
    self, struct_options: Mapping[str, Union[int, float]]
) -> str:
    """Encode ML.EXPLAIN_FORECAST for BQML.

    Args:
        struct_options:
            Option names mapped to numeric values; rendered to SQL by
            ``self.struct_options``.

    Returns:
        A ``SELECT * FROM ML.EXPLAIN_FORECAST(...)`` statement for the model
        reference produced by ``self._model_ref_sql()``.
    """
    # Render the options first, then the model reference, matching the
    # evaluation order of the original single-expression form.
    options_sql = self.struct_options(**struct_options)
    model_sql = self._model_ref_sql()
    return f"""SELECT * FROM ML.EXPLAIN_FORECAST(MODEL {model_sql},
{options_sql})"""

def ml_generate_text(
self, source_sql: str, struct_options: Mapping[str, Union[int, float]]
) -> str:
Expand Down
63 changes: 63 additions & 0 deletions tests/system/small/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,42 @@ def test_arima_plus_predict_default(
)


def test_arima_plus_predict_explain_default(
    time_series_arima_plus_model: forecasting.ARIMAPlus,
):
    """predict_explain() with default arguments yields three forecast rows."""
    explained = time_series_arima_plus_model.predict_explain().to_pandas()
    assert explained.shape[0] == 369

    # Keep only the rows tagged as forecasts.
    is_forecast = explained["time_series_type"] == "forecast"
    forecast_rows = explained[is_forecast].reset_index(drop=True)
    assert forecast_rows.shape[0] == 3

    actual = forecast_rows[["time_series_timestamp", "time_series_data"]]

    expected = pd.DataFrame(
        {
            "time_series_timestamp": [
                datetime(2017, 8, day, tzinfo=pytz.utc) for day in (2, 3, 4)
            ],
            "time_series_data": [2727.693349, 2595.290749, 2370.86767],
        }
    ).astype(
        {
            "time_series_data": pd.Float64Dtype(),
            "time_series_timestamp": pd.ArrowDtype(pa.timestamp("us", tz="UTC")),
        }
    )

    pd.testing.assert_frame_equal(
        actual,
        expected,
        rtol=0.1,
        check_index_type=False,
    )


def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus):
utc = pytz.utc
predictions = time_series_arima_plus_model.predict(
Expand Down Expand Up @@ -96,6 +132,33 @@ def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARI
)


def test_arima_plus_predict_explain_params(
    time_series_arima_plus_model: forecasting.ARIMAPlus,
):
    """predict_explain() with explicit options returns the explanation columns."""
    predictions = time_series_arima_plus_model.predict_explain(
        horizon=4, confidence_level=0.9
    ).to_pandas()
    assert predictions.shape[0] >= 1

    expected_columns = {
        "time_series_timestamp",
        "time_series_type",
        "time_series_data",
        "time_series_adjusted_data",
        "standard_error",
        "confidence_level",
        "prediction_interval_lower_bound",
        # The upper bound is documented alongside the lower bound in the
        # ML.EXPLAIN_FORECAST output, so check both ends of the interval.
        "prediction_interval_upper_bound",
        "trend",
        "seasonal_period_yearly",
        "seasonal_period_quarterly",
        "seasonal_period_monthly",
        "seasonal_period_weekly",
        "seasonal_period_daily",
        "holiday_effect",
    }
    # Extra columns are allowed; only report the ones that are absent.
    missing_columns = expected_columns - set(predictions.columns)
    assert not missing_columns, f"missing columns: {sorted(missing_columns)}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!



def test_arima_plus_detect_anomalies(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/ml/test_forecasting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import pytest

from bigframes.ml import forecasting


def test_predict_explain_low_confidence_level():
    """A negative confidence_level makes predict_explain raise ValueError."""
    confidence_level = -0.5
    expected_message = (
        f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
    )

    model = forecasting.ARIMAPlus()

    # The message contains regex metacharacters, so match it literally.
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        model.predict_explain(horizon=4, confidence_level=confidence_level)


def test_predict_high_explain_confidence_level():
    """A confidence_level above 1.0 makes predict_explain raise ValueError."""
    confidence_level = 2.1
    expected_message = (
        f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
    )

    model = forecasting.ARIMAPlus()

    # The message contains regex metacharacters, so match it literally.
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        model.predict_explain(horizon=4, confidence_level=confidence_level)


def test_predict_explain_low_horizon():
    """A horizon below 1 makes predict_explain raise ValueError."""
    horizon = -1

    model = forecasting.ARIMAPlus()

    # ``match`` is treated as a regular expression, so escape the message to
    # compare it literally (otherwise the trailing "." matches any character).
    with pytest.raises(
        ValueError,
        match=re.escape(f"horizon must be at least 1, but is {horizon}."),
    ):
        model.predict_explain(horizon=horizon, confidence_level=0.9)
Loading