Skip to content

Commit

Permalink
feat: create arima_plus_predict_attribution method
Browse files Browse the repository at this point in the history
  • Loading branch information
rey-esp committed Nov 11, 2024
1 parent 7ac6639 commit 8bc1cca
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
4 changes: 4 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index()

def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_explain_forecast(struct_options=options)
return self._session.read_gbq(sql, index_col="time_series_timestamp").reset_index()

def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
sql = self._model_manipulation_sql_generator.ml_evaluate(
input_data.sql if (input_data is not None) else None
Expand Down
38 changes: 38 additions & 0 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,44 @@ def predict(
return self._bqml_model.forecast(
options={"horizon": horizon, "confidence_level": confidence_level}
)

def predict_attribution(
self, X=None, *, horizon: int = 3, confidence_level: float = 0.95
) -> bpd.DataFrame:
"""Forecast time series at future horizon.
.. note::
Output matches that of the BigQuery ML.FORECAST function.
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-forecast
Args:
X (default None):
ignored, to be compatible with other APIs.
horizon (int, default: 3):
an int value that specifies the number of time points to forecast.
The default value is 3, and the maximum value is 1000.
confidence_level (float, default 0.95):
A float value that specifies percentage of the future values that fall in the prediction interval.
The valid input range is [0.0, 1.0).
Returns:
bigframes.dataframe.DataFrame: The predicted DataFrames. Which
contains 2 columns: "forecast_timestamp" and "forecast_value".
"""
if horizon < 1 or horizon > 1000:
raise ValueError(f"horizon must be [1, 1000], but is {horizon}.")
if confidence_level < 0.0 or confidence_level >= 1.0:
raise ValueError(
f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
)

if not self._bqml_model:
raise RuntimeError("A model must be fitted before predict")

return self._bqml_model.explain_forecast(
options={"horizon": horizon, "confidence_level": confidence_level}
)

@property
def coef_(
Expand Down
6 changes: 6 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,12 @@ def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
struct_options_sql = self.struct_options(**struct_options)
return f"""SELECT * FROM ML.FORECAST(MODEL {self._model_ref_sql()},
{struct_options_sql})"""

def ml_explain_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
"""Encode ML.EXPLAIN_FORECAST for BQML"""
struct_options_sql = self.struct_options(**struct_options)
return f"""SELECT * FROM ML.EXPLAIN_FORECAST(MODEL {self._model_ref_sql()},
{struct_options_sql})"""

def ml_generate_text(
self, source_sql: str, struct_options: Mapping[str, Union[int, float]]
Expand Down
29 changes: 29 additions & 0 deletions tests/system/small/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,35 @@ def test_arima_plus_predict_default(
check_index_type=False,
)

def test_arima_plus_predict_attribution_default(
time_series_arima_plus_model: forecasting.ARIMAPlus,
):
utc = pytz.utc
predictions = time_series_arima_plus_model.predict_attribution().to_pandas()
assert predictions.shape == (3, 8)
result = predictions[["forecast_timestamp", "forecast_value"]]
expected = pd.DataFrame(
{
"forecast_timestamp": [
datetime(2017, 8, 2, tzinfo=utc),
datetime(2017, 8, 3, tzinfo=utc),
datetime(2017, 8, 4, tzinfo=utc),
],
"forecast_value": [2724.472284, 2593.368389, 2353.613034],
}
)
expected["forecast_value"] = expected["forecast_value"].astype(pd.Float64Dtype())
expected["forecast_timestamp"] = expected["forecast_timestamp"].astype(
pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
)

pd.testing.assert_frame_equal(
result,
expected,
rtol=0.1,
check_index_type=False,
)


def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus):
utc = pytz.utc
Expand Down

0 comments on commit 8bc1cca

Please sign in to comment.