diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index 1e2224c9bc..7c156b4cb7 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -136,6 +136,13 @@ def evaluate(self, input_data: Optional[bpd.DataFrame] = None): return self._session.read_gbq(sql) + def arima_evaluate(self, show_all_candidate_models: bool = False): + sql = self._model_manipulation_sql_generator.ml_arima_evaluate( + show_all_candidate_models + ) + + return self._session.read_gbq(sql) + def centroids(self) -> bpd.DataFrame: assert self._model.model_type == "KMEANS" diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index 03b9857cc5..8d448fbace 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -151,6 +151,31 @@ def score( input_data = X.join(y, how="outer") return self._bqml_model.evaluate(input_data) + def summary( + self, + show_all_candidate_models: bool = False, + ) -> bpd.DataFrame: + """Summary of the evaluation metrics of the time series model. + + .. note:: + + Output matches that of the BigQuery ML.ARIMA_EVALUATE function. + See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-arima-evaluate + for the outputs relevant to this model type. + + Args: + show_all_candidate_models (bool, default to False): + Whether to show evaluation metrics or an error message for either + all candidate models or for only the best model with the lowest + AIC. Default to False. + + Returns: + bigframes.dataframe.DataFrame: A DataFrame as evaluation result. + """ + if not self._bqml_model: + raise RuntimeError("A model must be fitted before summary") + return self._bqml_model.arima_evaluate(show_all_candidate_models) + def to_gbq(self, model_name: str, replace: bool = False) -> ARIMAPlus: """Save the model to BigQuery. 
diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index 25caaf1ac6..152f881ec0 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -260,6 +260,12 @@ def ml_evaluate(self, source_df: Optional[bpd.DataFrame] = None) -> str: return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`, ({source_sql}))""" + # ML evaluation TVFs + def ml_arima_evaluate(self, show_all_candidate_models: bool = False) -> str: + """Encode ML.ARIMA_EVALUATE for BQML""" + return f"""SELECT * FROM ML.ARIMA_EVALUATE(MODEL `{self._model_name}`, + STRUCT({show_all_candidate_models} AS show_all_candidate_models))""" + def ml_centroids(self) -> str: """Encode ML.CENTROIDS for BQML""" return f"""SELECT * FROM ML.CENTROIDS(MODEL `{self._model_name}`)""" diff --git a/tests/system/large/ml/test_forecasting.py b/tests/system/large/ml/test_forecasting.py index 33b835e852..2bb136b0f2 100644 --- a/tests/system/large/ml/test_forecasting.py +++ b/tests/system/large/ml/test_forecasting.py @@ -16,6 +16,20 @@ from bigframes.ml import forecasting +ARIMA_EVALUATE_OUTPUT_COL = [ + "non_seasonal_p", + "non_seasonal_d", + "non_seasonal_q", + "log_likelihood", + "AIC", + "variance", + "seasonal_periods", + "has_holiday_effect", + "has_spikes_and_dips", + "has_step_changes", + "error_message", +] + def test_arima_plus_model_fit_score( time_series_df_default_index, dataset_id, new_time_series_df @@ -42,7 +56,24 @@ def test_arima_plus_model_fit_score( pd.testing.assert_frame_equal(result, expected, check_exact=False, rtol=0.1) # save, load to ensure configuration was kept - reloaded_model = model.to_gbq(f"{dataset_id}.temp_configured_model", replace=True) + reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True) + assert ( + f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name + ) + + +def test_arima_plus_model_fit_summary(time_series_df_default_index, dataset_id): + model = forecasting.ARIMAPlus() + X_train = 
time_series_df_default_index[["parsed_date"]] + y_train = time_series_df_default_index[["total_visits"]] + model.fit(X_train, y_train) + + result = model.summary() + assert result.shape == (1, 12) + assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL) + + # save, load to ensure configuration was kept + reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True) assert ( - f"{dataset_id}.temp_configured_model" in reloaded_model._bqml_model.model_name + f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name ) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index be8d9c2bac..4726d5ab21 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -20,6 +20,20 @@ from bigframes.ml import forecasting +ARIMA_EVALUATE_OUTPUT_COL = [ + "non_seasonal_p", + "non_seasonal_d", + "non_seasonal_q", + "log_likelihood", + "AIC", + "variance", + "seasonal_periods", + "has_holiday_effect", + "has_spikes_and_dips", + "has_step_changes", + "error_message", +] + def test_model_predict_default(time_series_arima_plus_model: forecasting.ARIMAPlus): utc = pytz.utc @@ -104,6 +118,24 @@ def test_model_score( ) +def test_model_summary( + time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df +): + result = time_series_arima_plus_model.summary() + assert result.shape == (1, 12) + assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL) + + +def test_model_summary_show_all_candidates( + time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df +): + result = time_series_arima_plus_model.summary( + show_all_candidate_models=True, + ) + assert result.shape[0] > 1 + assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL) + + def test_model_score_series( time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df ): @@ -126,3 +158,11 @@ def 
test_model_score_series( rtol=0.1, check_index_type=False, ) + + +def test_model_summary_series( + time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df +): + result = time_series_arima_plus_model.summary() + assert result.shape == (1, 12) + assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL) diff --git a/tests/unit/ml/test_sql.py b/tests/unit/ml/test_sql.py index 73d19cc0bb..37cc33d33e 100644 --- a/tests/unit/ml/test_sql.py +++ b/tests/unit/ml/test_sql.py @@ -273,6 +273,19 @@ def test_ml_evaluate_produces_correct_sql( ) +def test_ml_arima_evaluate_produces_correct_sql( + model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator, +): + sql = model_manipulation_sql_generator.ml_arima_evaluate( + show_all_candidate_models=True + ) + assert ( + sql + == """SELECT * FROM ML.ARIMA_EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`, + STRUCT(True AS show_all_candidate_models))""" + ) + + def test_ml_evaluate_no_source_produces_correct_sql( model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator, ):