diff --git a/bigframes/ml/linear_model.py b/bigframes/ml/linear_model.py index 1a1a5e0ca0..eac0fd1fca 100644 --- a/bigframes/ml/linear_model.py +++ b/bigframes/ml/linear_model.py @@ -353,6 +353,34 @@ def predict( return self._bqml_model.predict(X) + def predict_explain( + self, + X: utils.ArrayType, + ) -> bpd.DataFrame: + """ + Explain predictions for a logistic regression model. + + .. note:: + Output matches that of the BigQuery ML.EXPLAIN_PREDICT function. + See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict + + Args: + X (bigframes.dataframe.DataFrame or bigframes.series.Series or + pandas.core.frame.DataFrame or pandas.core.series.Series): + Series or a DataFrame to explain its predictions. + + Returns: + bigframes.pandas.DataFrame: + The predicted DataFrames with explanation columns. + """ + # TODO(b/377366612): Add support for `top_k_features` parameter + if not self._bqml_model: + raise RuntimeError("A model must be fitted before predict") + + (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session) + + return self._bqml_model.explain_predict(X) + def score( self, X: utils.ArrayType, diff --git a/tests/system/small/ml/test_linear_model.py b/tests/system/small/ml/test_linear_model.py index 0832c559c1..3be1147c1e 100644 --- a/tests/system/small/ml/test_linear_model.py +++ b/tests/system/small/ml/test_linear_model.py @@ -307,6 +307,50 @@ def test_logistic_model_predict(penguins_logistic_model, new_penguins_df): ) +def test_logistic_model_predict_params( + penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df +): + predictions = penguins_logistic_model.predict(new_penguins_df).to_pandas() + assert predictions.shape[0] >= 1 + prediction_columns = set(predictions.columns) + expected_columns = { + "predicted_sex", + "predicted_sex_probs", + "species", + "island", + "culmen_length_mm", + "culmen_depth_mm", + "flipper_length_mm", + "body_mass_g", + "sex", + } + assert expected_columns <= prediction_columns + + +def test_logistic_model_predict_explain_params( + penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df +): + predictions = penguins_logistic_model.predict_explain(new_penguins_df).to_pandas() + assert predictions.shape[0] >= 1 + prediction_columns = set(predictions.columns) + expected_columns = { + "predicted_sex", + "probability", + "top_feature_attributions", + "baseline_prediction_value", + "prediction_value", + "approximation_error", + "species", + "island", + "culmen_length_mm", + "culmen_depth_mm", + "flipper_length_mm", + "body_mass_g", + "sex", + } + assert expected_columns <= prediction_columns + + def test_logistic_model_to_gbq_saved_score( penguins_logistic_model, table_id_unique, penguins_df_default_index ):