Skip to content

Commit

Permalink
feat: add LinearRegression.predict_explain() to generate `ML.EXPLAI…
Browse files Browse the repository at this point in the history
…N_PREDICT` columns (#1190)

* feat: add LinearRegression.predict_explain to generate predict explain columns

* add test cases

* add test case

* update predict_explain

* update the test

* Add sql and core tests

* fix docs error

* add TODO comment to support method parameters

* update the test parameter of linear model

* update test to fix failing checks
  • Loading branch information
arwas11 authored Dec 16, 2024
1 parent 14f24ca commit e13eca2
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 0 deletions.
6 changes: 6 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
self._model_manipulation_sql_generator.ml_predict,
)

def explain_predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
    """Run the explain-predict table-valued function over ``input_data``.

    Hands ``input_data`` and the SQL generator's ``ml_explain_predict``
    builder to ``_apply_ml_tvf`` and returns whatever that call returns.

    Args:
        input_data: Rows to compute explained predictions for.

    Returns:
        The DataFrame produced by ``_apply_ml_tvf``.
    """
    explain_sql_builder = self._model_manipulation_sql_generator.ml_explain_predict
    return self._apply_ml_tvf(input_data, explain_sql_builder)

def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
Expand Down
28 changes: 28 additions & 0 deletions bigframes/ml/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,34 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:

return self._bqml_model.predict(X)

def predict_explain(
    self,
    X: utils.ArrayType,
) -> bpd.DataFrame:
    """
    Explain predictions for a linear regression model.

    .. note::
        Output matches that of the BigQuery ML.EXPLAIN_PREDICT function.
        See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict

    Args:
        X (bigframes.dataframe.DataFrame or bigframes.series.Series or
            pandas.core.frame.DataFrame or pandas.core.series.Series):
            Series or a DataFrame to explain its predictions.

    Returns:
        bigframes.pandas.DataFrame:
            The predicted DataFrames with explanation columns.

    Raises:
        RuntimeError:
            If no underlying BQML model is set, i.e. the estimator has not
            been fitted yet.
    """
    # TODO(b/377366612): Add support for `top_k_features` parameter
    if not self._bqml_model:
        # Name this method in the message so the failure points at the call
        # the user actually made rather than at `predict`.
        raise RuntimeError("A model must be fitted before predict_explain")

    # Normalise the supported input kinds to a single DataFrame bound to the
    # model's session before handing it to the BQML model.
    (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

    return self._bqml_model.explain_predict(X)

def score(
self,
X: utils.ArrayType,
Expand Down
5 changes: 5 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ def ml_predict(self, source_sql: str) -> str:
return f"""SELECT * FROM ML.PREDICT(MODEL {self._model_ref_sql()},
({source_sql}))"""

def ml_explain_predict(self, source_sql: str) -> str:
    """Encode ML.EXPLAIN_PREDICT for BQML.

    Args:
        source_sql: SQL text of the input query; it is placed in parentheses
            as the second argument of ML.EXPLAIN_PREDICT.

    Returns:
        A ``SELECT * FROM ML.EXPLAIN_PREDICT(...)`` statement referencing this
        generator's model and the given source query.
    """
    model_ref = self._model_ref_sql()
    # NOTE(review): the second line of the statement is kept flush-left,
    # exactly as it appears in the reviewed text; confirm against the
    # repository if leading whitespace inside the literal matters.
    return f"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {model_ref},
({source_sql}))"""

def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
"""Encode ML.FORECAST for BQML"""
struct_options_sql = self.struct_options(**struct_options)
Expand Down
55 changes: 55 additions & 0 deletions tests/system/small/ml/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,28 @@ def test_model_predict(penguins_bqml_linear_model: core.BqmlModel, new_penguins_
)


def test_model_predict_explain(
    penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
):
    """explain_predict yields the expected prediction and approximation error."""
    checked_columns = ["predicted_body_mass_g", "approximation_error"]
    explained = penguins_bqml_linear_model.explain_predict(new_penguins_df)
    result_df = explained.to_pandas()

    tag_numbers = pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64")
    expected = pd.DataFrame(
        {
            "predicted_body_mass_g": [4030.1, 3280.8, 3177.9],
            "approximation_error": [0.0, 0.0, 0.0],
        },
        dtype="Float64",
        index=tag_numbers,
    )

    # Compare only the two checked columns, ordered by index, with a loose
    # relative tolerance instead of exact float equality.
    actual = result_df[checked_columns].sort_index()
    pd.testing.assert_frame_equal(actual, expected, check_exact=False, rtol=0.1)


def test_model_predict_with_unnamed_index(
penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
):
Expand Down Expand Up @@ -288,6 +310,39 @@ def test_model_predict_with_unnamed_index(
)


def test_model_predict_explain_with_unnamed_index(
    penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
):
    """explain_predict keeps an unnamed index intact on its output."""
    # After reset_index() the frame's index has no name; the ML library has to
    # persist that index through the explain-predict call.
    unnamed_index_df = new_penguins_df.reset_index()

    # Drop the middle tag number so the surviving labels (0 and 2) show that
    # the original unnamed index was really kept.
    filtered_df = typing.cast(
        bigframes.dataframe.DataFrame,
        unnamed_index_df[unnamed_index_df.tag_number != 1672],
    )

    result_df = penguins_bqml_linear_model.explain_predict(filtered_df).to_pandas()

    expected = pd.DataFrame(
        {
            "predicted_body_mass_g": [4030.1, 3177.9],
            "approximation_error": [0.0, 0.0],
        },
        dtype="Float64",
        index=pd.Index([0, 2], dtype="Int64"),
    )

    checked_columns = ["predicted_body_mass_g", "approximation_error"]
    actual = result_df[checked_columns].sort_index()
    pd.testing.assert_frame_equal(actual, expected, check_exact=False, rtol=0.1)


def test_model_detect_anomalies(
penguins_bqml_pca_model: core.BqmlModel, new_penguins_df
):
Expand Down
68 changes: 68 additions & 0 deletions tests/system/small/ml/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import pandas
import pytest

from bigframes.ml import linear_model


def test_linear_reg_model_score(penguins_linear_model, penguins_df_default_index):
df = penguins_df_default_index.dropna()
Expand Down Expand Up @@ -106,6 +108,72 @@ def test_linear_reg_model_predict(penguins_linear_model, new_penguins_df):
)


def test_linear_reg_model_predict_explain(penguins_linear_model, new_penguins_df):
    """predict_explain returns 3 rows x 12 columns with the expected values."""
    explained = penguins_linear_model.predict_explain(new_penguins_df).to_pandas()
    assert explained.shape == (3, 12)

    actual = explained[["predicted_body_mass_g", "approximation_error"]]
    tag_numbers = pandas.Index([1633, 1672, 1690], name="tag_number", dtype="Int64")
    expected = pandas.DataFrame(
        {
            "predicted_body_mass_g": [4030.1, 3280.8, 3177.9],
            "approximation_error": [0.0, 0.0, 0.0],
        },
        dtype="Float64",
        index=tag_numbers,
    )

    # Order by index before comparing; values are checked with a relative
    # tolerance rather than exactly.
    pandas.testing.assert_frame_equal(
        actual.sort_index(),
        expected,
        check_exact=False,
        rtol=0.1,
    )


def test_linear_reg_model_predict_params(
    penguins_linear_model: linear_model.LinearRegression, new_penguins_df
):
    """predict returns at least one row and includes every expected column."""
    expected_columns = {
        "predicted_body_mass_g",
        "species",
        "island",
        "culmen_length_mm",
        "culmen_depth_mm",
        "flipper_length_mm",
        "body_mass_g",
        "sex",
    }

    predictions = penguins_linear_model.predict(new_penguins_df).to_pandas()
    row_count = predictions.shape[0]
    assert row_count >= 1

    # Extra columns are allowed; only missing ones fail the check.
    missing_columns = expected_columns - set(predictions.columns)
    assert not missing_columns


def test_linear_reg_model_predict_explain_params(
    penguins_linear_model: linear_model.LinearRegression, new_penguins_df
):
    """predict_explain returns at least one row and every expected column."""
    expected_columns = {
        "predicted_body_mass_g",
        "top_feature_attributions",
        "baseline_prediction_value",
        "prediction_value",
        "approximation_error",
        "species",
        "island",
        "culmen_length_mm",
        "culmen_depth_mm",
        "flipper_length_mm",
        "body_mass_g",
        "sex",
    }

    predictions = penguins_linear_model.predict_explain(new_penguins_df).to_pandas()
    row_count = predictions.shape[0]
    assert row_count >= 1

    # Extra columns are allowed; only missing ones fail the check.
    missing_columns = expected_columns - set(predictions.columns)
    assert not missing_columns


def test_to_gbq_saved_linear_reg_model_scores(
penguins_linear_model, table_id_unique, penguins_df_default_index
):
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/ml/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,18 @@ def test_ml_predict_correct(
)


def test_ml_explain_predict_correct(
    model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
    mock_df: bpd.DataFrame,
):
    """ml_explain_predict renders the expected ML.EXPLAIN_PREDICT statement."""
    expected_sql = """SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
(input_X_y_sql))"""
    generated_sql = model_manipulation_sql_generator.ml_explain_predict(
        source_sql=mock_df.sql
    )
    assert generated_sql == expected_sql


def test_ml_llm_evaluate_correct(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
mock_df: bpd.DataFrame,
Expand Down

0 comments on commit e13eca2

Please sign in to comment.