Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ARIMA_EVALUATE options in forecasting models #336

Merged
merged 9 commits into from
Jan 24, 2024
8 changes: 8 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ def evaluate(self, input_data: Optional[bpd.DataFrame] = None):

return self._session.read_gbq(sql)

def arima_evaluate(self, show_all_candidate_models: bool = False) -> bpd.DataFrame:
    """Generate the ARIMA evaluation query for this model and read its result.

    Args:
        show_all_candidate_models (bool, default False):
            Passed through unchanged to the SQL generator's
            ``ml_arima_evaluate``.

    Returns:
        Whatever ``self._session.read_gbq`` returns for the generated SQL.
    """
    sql = self._model_manipulation_sql_generator.ml_arima_evaluate(
        show_all_candidate_models
    )

    return self._session.read_gbq(sql)

def centroids(self) -> bpd.DataFrame:
assert self._model.model_type == "KMEANS"

Expand Down
25 changes: 25 additions & 0 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,31 @@ def score(
input_data = X.join(y, how="outer")
return self._bqml_model.evaluate(input_data)

def summary(
    self,
    show_all_candidate_models: bool = False,
) -> bpd.DataFrame:
    """Summary of the evaluation metrics of the time series model.

    .. note::

        Output matches that of the BigQuery ML.ARIMA_EVALUATE function.
        See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-arima-evaluate
        for the outputs relevant to this model type.

    Args:
        show_all_candidate_models (bool, default to False):
            Whether to show evaluation metrics or an error message for either
            all candidate models or for only the best model with the lowest
            AIC. Default to False.

    Returns:
        bigframes.dataframe.DataFrame: A DataFrame as evaluation result.

    Raises:
        RuntimeError: If ``self._bqml_model`` is not set, i.e. no model has
            been fitted.
    """
    if not self._bqml_model:
        # The message names this method rather than ``score``.
        raise RuntimeError("A model must be fitted before summary")
    return self._bqml_model.arima_evaluate(show_all_candidate_models)

def to_gbq(self, model_name: str, replace: bool = False) -> ARIMAPlus:
"""Save the model to BigQuery.

Expand Down
6 changes: 6 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ def ml_evaluate(self, source_df: Optional[bpd.DataFrame] = None) -> str:
return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`,
({source_sql}))"""

# ML evaluation TVFs
def ml_arima_evaluate(self, show_all_candidate_models: bool = False) -> str:
"""Encode ML.ARIMA_EVALUATE for BQML

Args:
show_all_candidate_models: Interpolated into the query's STRUCT argument
using Python's default formatting, so the SQL text reads ``True`` or
``False``.

Returns:
A SELECT * query over ML.ARIMA_EVALUATE for the model named by
``self._model_name``.
"""
return f"""SELECT * FROM ML.ARIMA_EVALUATE(MODEL `{self._model_name}`,
STRUCT({show_all_candidate_models} AS show_all_candidate_models))"""

def ml_centroids(self) -> str:
    """Build the ML.CENTROIDS query for the model named by ``self._model_name``."""
    model_ref = self._model_name
    return f"""SELECT * FROM ML.CENTROIDS(MODEL `{model_ref}`)"""
Expand Down
35 changes: 33 additions & 2 deletions tests/system/large/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@

from bigframes.ml import forecasting

# Column names expected in the ML.ARIMA_EVALUATE result. The summary test in
# this module asserts a 12-column result, so all 12 names are listed here;
# ``has_drift`` was previously missing.
ARIMA_EVALUATE_OUTPUT_COL = [
    "non_seasonal_p",
    "non_seasonal_d",
    "non_seasonal_q",
    "has_drift",
    "log_likelihood",
    "AIC",
    "variance",
    "seasonal_periods",
    "has_holiday_effect",
    "has_spikes_and_dips",
    "has_step_changes",
    "error_message",
]


def test_arima_plus_model_fit_score(
time_series_df_default_index, dataset_id, new_time_series_df
Expand All @@ -42,7 +56,24 @@ def test_arima_plus_model_fit_score(
pd.testing.assert_frame_equal(result, expected, check_exact=False, rtol=0.1)

# save, load to ensure configuration was kept
reloaded_model = model.to_gbq(f"{dataset_id}.temp_configured_model", replace=True)
reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
assert (
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
)


def test_arima_plus_model_fit_summary(time_series_df_default_index, dataset_id):
    """Fit an ARIMAPlus model, check its summary, then save and reload it."""
    model = forecasting.ARIMAPlus()
    X_train = time_series_df_default_index[["parsed_date"]]
    y_train = time_series_df_default_index[["total_visits"]]
    model.fit(X_train, y_train)

    result = model.summary()
    assert result.shape == (1, 12)
    assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)

    # save, load to ensure configuration was kept
    reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
    # Only the new model name is checked; the old ``temp_configured_model``
    # line was the removed side of the diff.
    assert (
        f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
    )
40 changes: 40 additions & 0 deletions tests/system/small/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@

from bigframes.ml import forecasting

# Column names expected in the ML.ARIMA_EVALUATE result. The summary tests in
# this module assert a 12-column result, so all 12 names are listed here;
# ``has_drift`` was previously missing.
ARIMA_EVALUATE_OUTPUT_COL = [
    "non_seasonal_p",
    "non_seasonal_d",
    "non_seasonal_q",
    "has_drift",
    "log_likelihood",
    "AIC",
    "variance",
    "seasonal_periods",
    "has_holiday_effect",
    "has_spikes_and_dips",
    "has_step_changes",
    "error_message",
]


def test_model_predict_default(time_series_arima_plus_model: forecasting.ARIMAPlus):
utc = pytz.utc
Expand Down Expand Up @@ -104,6 +118,24 @@ def test_model_score(
)


def test_model_summary(
    time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
    """The default summary is a single row exposing every expected column."""
    summary_df = time_series_arima_plus_model.summary()
    expected_shape = (1, 12)
    assert summary_df.shape == expected_shape
    available = summary_df.columns
    assert all(name in available for name in ARIMA_EVALUATE_OUTPUT_COL)


def test_model_summary_show_all_candidates(
    time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
    """Requesting all candidate models yields more than one summary row."""
    candidates_df = time_series_arima_plus_model.summary(show_all_candidate_models=True)
    row_count = candidates_df.shape[0]
    assert row_count > 1
    available = candidates_df.columns
    assert all(name in available for name in ARIMA_EVALUATE_OUTPUT_COL)


def test_model_score_series(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
Expand All @@ -126,3 +158,11 @@ def test_model_score_series(
rtol=0.1,
check_index_type=False,
)


def test_model_summary_series(
    time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
    """The default summary is a single row exposing every expected column.

    NOTE(review): ``summary()`` takes no input data, so nothing here is
    Series-specific despite the test name.
    """
    summary_df = time_series_arima_plus_model.summary()
    expected_shape = (1, 12)
    assert summary_df.shape == expected_shape
    available = summary_df.columns
    assert all(name in available for name in ARIMA_EVALUATE_OUTPUT_COL)
13 changes: 13 additions & 0 deletions tests/unit/ml/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,19 @@ def test_ml_evaluate_produces_correct_sql(
)


def test_ml_arima_evaluate_produces_correct_sql(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
):
# Pins the exact SQL text produced for show_all_candidate_models=True. The
# Python bool is expected to appear in the query as the literal ``True``.
sql = model_manipulation_sql_generator.ml_arima_evaluate(
show_all_candidate_models=True
)
assert (
sql
== """SELECT * FROM ML.ARIMA_EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`,
STRUCT(True AS show_all_candidate_models))"""
)


def test_ml_evaluate_no_source_produces_correct_sql(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
):
Expand Down