Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ARIMAPlus.predict_explain() to generate forecasts with explanation columns #1177

Merged
merged 28 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
8bc1cca
feat: create arima_plus_predict_attribution method
rey-esp Nov 11, 2024
f6dd455
tmp: debug notes for time_series_arima_plus_model.predict_attribution
chelsea-lin Nov 12, 2024
b8ec20d
update test_arima_plus_predict_explain_default test and create test_a…
rey-esp Nov 18, 2024
8056c92
Merge branch 'ml-predict-explain' of github.com:googleapis/python-big…
rey-esp Nov 18, 2024
722181b
Merge branch 'ml-predict-explain' of github.com:googleapis/python-big…
rey-esp Nov 18, 2024
a161b33
update test_arima_plus_predict_explain_params test
rey-esp Nov 18, 2024
dd38aaf
Merge branch 'main' into ml-predict-explain
rey-esp Nov 26, 2024
347c3c4
Revert "tmp: debug notes for time_series_arima_plus_model.predict_att…
chelsea-lin Nov 26, 2024
54175c0
format and lint
rey-esp Nov 26, 2024
75d9f91
Merge branch 'main' into ml-predict-explain
rey-esp Nov 26, 2024
2634a91
Merge branch 'main' into ml-predict-explain
rey-esp Dec 2, 2024
448e63a
Update bigframes/ml/forecasting.py
rey-esp Dec 2, 2024
8347c4b
update predict explain params test
rey-esp Dec 2, 2024
1fe2d37
update test
rey-esp Dec 3, 2024
48c81ed
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Dec 3, 2024
c22eec8
Merge branch 'main' into ml-predict-explain
rey-esp Dec 3, 2024
706a1ae
add unit test file - bare bones
rey-esp Dec 4, 2024
79a5359
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Dec 4, 2024
7922d64
Merge branch 'main' into ml-predict-explain
rey-esp Dec 5, 2024
aba6aca
Merge branch 'main' into ml-predict-explain
rey-esp Dec 9, 2024
b9343cf
Merge branch 'main' into ml-predict-explain
rey-esp Dec 9, 2024
ac271ff
Merge branch 'main' into ml-predict-explain
rey-esp Dec 9, 2024
3befd2e
fixed tests
rey-esp Dec 9, 2024
3fbcb64
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Dec 9, 2024
04d1fb4
lint
rey-esp Dec 9, 2024
75e2994
Merge branch 'ml-predict-explain' of github.com:googleapis/python-big…
rey-esp Dec 9, 2024
6bfb1d3
lint
rey-esp Dec 9, 2024
e2eb29d
fix test: float -> int
rey-esp Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index()

def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_explain_forecast(
struct_options=options
)
return self._session.read_gbq(
sql, index_col="time_series_timestamp"
).reset_index()

def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
sql = self._model_manipulation_sql_generator.ml_evaluate(
input_data.sql if (input_data is not None) else None
Expand Down
37 changes: 37 additions & 0 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,43 @@ def predict(
options={"horizon": horizon, "confidence_level": confidence_level}
)

def predict_explain(
self, X=None, *, horizon: int = 3, confidence_level: float = 0.95
) -> bpd.DataFrame:
"""Explain Forecast time series at future horizon.

.. note::

Output matches that of the BigQuery ML.EXPLAIN_FORECAST function.
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-forecast

Args:
X (default None):
ignored, to be compatible with other APIs.
horizon (int, default: 3):
an int value that specifies the number of time points to forecast.
The default value is 3, and the maximum value is 1000.
confidence_level (float, default 0.95):
A float value that specifies percentage of the future values that fall in the prediction interval.
The valid input range is [0.0, 1.0).

Returns:
bigframes.dataframe.DataFrame: The predicted DataFrames.
"""
if horizon < 1 or horizon > 1000:
raise ValueError(f"horizon must be [1, 1000], but is {horizon}.")
rey-esp marked this conversation as resolved.
Show resolved Hide resolved
if confidence_level < 0.0 or confidence_level >= 1.0:
tswast marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
)

if not self._bqml_model:
raise RuntimeError("A model must be fitted before predict")

return self._bqml_model.explain_forecast(
options={"horizon": horizon, "confidence_level": confidence_level}
)

@property
def coef_(
self,
Expand Down
8 changes: 8 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
return f"""SELECT * FROM ML.FORECAST(MODEL {self._model_ref_sql()},
{struct_options_sql})"""

def ml_explain_forecast(
self, struct_options: Mapping[str, Union[int, float]]
) -> str:
"""Encode ML.EXPLAIN_FORECAST for BQML"""
struct_options_sql = self.struct_options(**struct_options)
return f"""SELECT * FROM ML.EXPLAIN_FORECAST(MODEL {self._model_ref_sql()},
{struct_options_sql})"""

def ml_generate_text(
self, source_sql: str, struct_options: Mapping[str, Union[int, float]]
) -> str:
Expand Down
71 changes: 71 additions & 0 deletions tests/system/small/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,42 @@ def test_arima_plus_predict_default(
)


def test_arima_plus_predict_explain_default(
time_series_arima_plus_model: forecasting.ARIMAPlus,
):
utc = pytz.utc
predictions = time_series_arima_plus_model.predict_explain().to_pandas()
assert predictions.shape[0] == 369
predictions = predictions[
predictions["time_series_type"] == "forecast"
].reset_index(drop=True)
assert predictions.shape[0] == 3
result = predictions[["time_series_timestamp", "time_series_data"]]
expected = pd.DataFrame(
{
"time_series_timestamp": [
datetime(2017, 8, 2, tzinfo=utc),
datetime(2017, 8, 3, tzinfo=utc),
datetime(2017, 8, 4, tzinfo=utc),
],
"time_series_data": [2727.693349, 2595.290749, 2370.86767],
}
)
expected["time_series_data"] = expected["time_series_data"].astype(
pd.Float64Dtype()
)
expected["time_series_timestamp"] = expected["time_series_timestamp"].astype(
pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
)

pd.testing.assert_frame_equal(
result,
expected,
rtol=0.1,
check_index_type=False,
)


def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus):
utc = pytz.utc
predictions = time_series_arima_plus_model.predict(
Expand Down Expand Up @@ -96,6 +132,41 @@ def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARI
)


def test_arima_plus_predict_explain_params(
time_series_arima_plus_model: forecasting.ARIMAPlus,
):
utc = pytz.utc
predictions = time_series_arima_plus_model.predict_explain(
horizon=4, confidence_level=0.9
).to_pandas()
assert predictions.shape == (4, 8)
rey-esp marked this conversation as resolved.
Show resolved Hide resolved
result = predictions[["time_series_timestamp", "time_series_data"]]
expected = pd.DataFrame(
{
"time_series_timestamp": [
datetime(2017, 8, 2, tzinfo=utc),
datetime(2017, 8, 3, tzinfo=utc),
datetime(2017, 8, 4, tzinfo=utc),
datetime(2017, 8, 5, tzinfo=utc),
],
"time_series_data": [2724.472284, 2593.368389, 2353.613034, 1781.623071],
}
)
expected["time_series_data"] = expected["time_series_data"].astype(
pd.Float64Dtype()
)
expected["time_series_timestamp"] = expected["time_series_timestamp"].astype(
pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
)

pd.testing.assert_frame_equal(
result,
expected,
rtol=0.1,
check_index_type=False,
)


def test_arima_plus_detect_anomalies(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
Expand Down