Skip to content

Commit

Permalink
feat: add LogisticRegression.predict_explain() to generate `ML.EXPL…
Browse files Browse the repository at this point in the history
…AIN_PREDICT` columns (#1222)

* feat: add LogisticRegression.predict_explain() to generate ML.EXPLAIN_PREDICT columns

* update tests
  • Loading branch information
arwas11 authored Dec 18, 2024
1 parent 684b2a6 commit bcbc732
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
28 changes: 28 additions & 0 deletions bigframes/ml/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,34 @@ def predict(

return self._bqml_model.predict(X)

def predict_explain(
self,
X: utils.ArrayType,
) -> bpd.DataFrame:
"""
Explain predictions for a logistic regression model.
.. note::
Output matches that of the BigQuery ML.EXPLAIN_PREDICT function.
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict
Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series or
pandas.core.frame.DataFrame or pandas.core.series.Series):
Series or a DataFrame to explain its predictions.
Returns:
bigframes.pandas.DataFrame:
The predicted DataFrames with explanation columns.
"""
# TODO(b/377366612): Add support for `top_k_features` parameter
if not self._bqml_model:
raise RuntimeError("A model must be fitted before predict")

(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

return self._bqml_model.explain_predict(X)

def score(
self,
X: utils.ArrayType,
Expand Down
44 changes: 44 additions & 0 deletions tests/system/small/ml/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,50 @@ def test_logistic_model_predict(penguins_logistic_model, new_penguins_df):
)


def test_logistic_model_predict_params(
penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df
):
predictions = penguins_logistic_model.predict(new_penguins_df).to_pandas()
assert predictions.shape[0] >= 1
prediction_columns = set(predictions.columns)
expected_columns = {
"predicted_sex",
"predicted_sex_probs",
"species",
"island",
"culmen_length_mm",
"culmen_depth_mm",
"flipper_length_mm",
"body_mass_g",
"sex",
}
assert expected_columns <= prediction_columns


def test_logistic_model_predict_explain_params(
penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df
):
predictions = penguins_logistic_model.predict_explain(new_penguins_df).to_pandas()
assert predictions.shape[0] >= 1
prediction_columns = set(predictions.columns)
expected_columns = {
"predicted_sex",
"probability",
"top_feature_attributions",
"baseline_prediction_value",
"prediction_value",
"approximation_error",
"species",
"island",
"culmen_length_mm",
"culmen_depth_mm",
"flipper_length_mm",
"body_mass_g",
"sex",
}
assert expected_columns <= prediction_columns


def test_logistic_model_to_gbq_saved_score(
penguins_logistic_model, table_id_unique, penguins_df_default_index
):
Expand Down

0 comments on commit bcbc732

Please sign in to comment.