From bcbc732f321ab31f8fb6b995aeb908ac87750587 Mon Sep 17 00:00:00 2001 From: Arwa Sharif <146148342+arwas11@users.noreply.github.com> Date: Wed, 18 Dec 2024 10:51:07 -0600 Subject: [PATCH] feat: add `LogisticRegression.predict_explain()` to generate `ML.EXPLAIN_PREDICT` columns (#1222) * feat: add LogisticRegression.predict_explain() to generate ML.EXPLAIN_PREDICT columns * update tests --- bigframes/ml/linear_model.py | 28 ++++++++++++++ tests/system/small/ml/test_linear_model.py | 44 ++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/bigframes/ml/linear_model.py b/bigframes/ml/linear_model.py index 1a1a5e0ca0..eac0fd1fca 100644 --- a/bigframes/ml/linear_model.py +++ b/bigframes/ml/linear_model.py @@ -353,6 +353,34 @@ def predict( return self._bqml_model.predict(X) + def predict_explain( + self, + X: utils.ArrayType, + ) -> bpd.DataFrame: + """ + Explain predictions for a logistic regression model. + + .. note:: + Output matches that of the BigQuery ML.EXPLAIN_PREDICT function. + See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict + + Args: + X (bigframes.dataframe.DataFrame or bigframes.series.Series or + pandas.core.frame.DataFrame or pandas.core.series.Series): + Series or a DataFrame to explain its predictions. + + Returns: + bigframes.pandas.DataFrame: + The predicted DataFrames with explanation columns. + """ + # TODO(b/377366612): Add support for `top_k_features` parameter + if not self._bqml_model: + raise RuntimeError("A model must be fitted before predict") + + (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session) + + return self._bqml_model.explain_predict(X) + def score( self, X: utils.ArrayType, diff --git a/tests/system/small/ml/test_linear_model.py b/tests/system/small/ml/test_linear_model.py index 0832c559c1..3be1147c1e 100644 --- a/tests/system/small/ml/test_linear_model.py +++ b/tests/system/small/ml/test_linear_model.py @@ -307,6 +307,50 @@ def test_logistic_model_predict(penguins_logistic_model, new_penguins_df): ) +def test_logistic_model_predict_params( + penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df +): + predictions = penguins_logistic_model.predict(new_penguins_df).to_pandas() + assert predictions.shape[0] >= 1 + prediction_columns = set(predictions.columns) + expected_columns = { + "predicted_sex", + "predicted_sex_probs", + "species", + "island", + "culmen_length_mm", + "culmen_depth_mm", + "flipper_length_mm", + "body_mass_g", + "sex", + } + assert expected_columns <= prediction_columns + + +def test_logistic_model_predict_explain_params( + penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df +): + predictions = penguins_logistic_model.predict_explain(new_penguins_df).to_pandas() + assert predictions.shape[0] >= 1 + prediction_columns = set(predictions.columns) + expected_columns = { + "predicted_sex", + "probability", + "top_feature_attributions", + "baseline_prediction_value", + "prediction_value", + "approximation_error", + "species", + "island", + "culmen_length_mm", + "culmen_depth_mm", + "flipper_length_mm", + "body_mass_g", + "sex", + } + assert expected_columns <= prediction_columns + + def test_logistic_model_to_gbq_saved_score( penguins_logistic_model, table_id_unique, penguins_df_default_index ):