Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add LinearRegression.predict_explain() to generate ML.EXPLAIN_PREDICT columns #1190

Merged
merged 20 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
3baa0c5
feat: add LinearRegression.predict_explain to generate predict explai…
arwas11 Nov 27, 2024
00e89ac
add test cases
arwas11 Dec 2, 2024
e737221
add test case
arwas11 Dec 5, 2024
e58a0e7
update predict_explain
arwas11 Dec 5, 2024
a672bef
Merge branch 'main' into b377366612-ml-explain-predict
arwas11 Dec 6, 2024
dba9165
Merge branch 'main' into b377366612-ml-explain-predict
arwas11 Dec 10, 2024
cea6526
update the test
arwas11 Dec 10, 2024
663c911
Merge remote-tracking branch 'origin/main' into b377366612-ml-explain…
arwas11 Dec 10, 2024
7026122
Merge remote-tracking branch 'origin/b377366612-ml-explain-predict' i…
arwas11 Dec 10, 2024
533b989
Merge remote-tracking branch 'origin/main' into b377366612-ml-explain…
arwas11 Dec 12, 2024
e74074c
Merge remote-tracking branch 'origin/main' into b377366612-ml-explain…
arwas11 Dec 13, 2024
4804570
Merge remote-tracking branch 'origin/main' into b377366612-ml-explain…
arwas11 Dec 16, 2024
362dd7b
Add sql and core tests
arwas11 Dec 16, 2024
fb2d02f
fix docs error
arwas11 Dec 16, 2024
3a784bd
add TODO comment to support method parameters
arwas11 Dec 16, 2024
81c3cdb
update the test parameter of linear model
arwas11 Dec 16, 2024
9457d79
Merge remote-tracking branch 'origin/main' into b377366612-ml-explain…
arwas11 Dec 16, 2024
c38ed2d
Merge branch 'main' into b377366612-ml-explain-predict
arwas11 Dec 16, 2024
c4a1beb
update test to fix failing checks
arwas11 Dec 16, 2024
73c725d
Merge remote-tracking branch 'origin/b377366612-ml-explain-predict' i…
arwas11 Dec 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
self._model_manipulation_sql_generator.ml_predict,
)

def explain_predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
    """Run the model's explain-predict SQL generator over ``input_data``.

    Args:
        input_data: DataFrame forwarded to ``_apply_ml_tvf`` as the input.

    Returns:
        The DataFrame produced by ``_apply_ml_tvf`` for the
        ``ml_explain_predict`` SQL generator.
    """
    sql_builder = self._model_manipulation_sql_generator.ml_explain_predict
    return self._apply_ml_tvf(input_data, sql_builder)

def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
Expand Down
28 changes: 28 additions & 0 deletions bigframes/ml/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,34 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:

return self._bqml_model.predict(X)

def predict_explain(
    self,
    X: utils.ArrayType,
) -> bpd.DataFrame:
    """
    Explain predictions for a linear regression model.

    .. note::
        Output matches that of the BigQuery ML.EXPLAIN_PREDICT function.
        See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict

    Args:
        X (bigframes.dataframe.DataFrame or bigframes.series.Series or
            pandas.core.frame.DataFrame or pandas.core.series.Series):
            Series or DataFrame whose predictions should be explained.

    Returns:
        bigframes.pandas.DataFrame:
            The predictions together with their explanation columns.

    Raises:
        RuntimeError: If no fitted model is attached to this estimator.
    """
    # TODO(b/377366612): Add support for `top_k_features` parameter
    if not self._bqml_model:
        # Name this method (not `predict`) so the failure points at the call
        # the user actually made.
        raise RuntimeError("A model must be fitted before predict_explain")

    # The helper returns a tuple with one converted frame per input.
    (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

    return self._bqml_model.explain_predict(X)

def score(
self,
X: utils.ArrayType,
Expand Down
5 changes: 5 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ def ml_predict(self, source_sql: str) -> str:
return f"""SELECT * FROM ML.PREDICT(MODEL {self._model_ref_sql()},
({source_sql}))"""

def ml_explain_predict(self, source_sql: str) -> str:
    """Encode ML.EXPLAIN_PREDICT for BQML.

    Args:
        source_sql: SQL text placed, parenthesised, as the second argument
            of ML.EXPLAIN_PREDICT.

    Returns:
        A ``SELECT * FROM ML.EXPLAIN_PREDICT(...)`` statement as a string.
    """
    model_ref = self._model_ref_sql()
    # NOTE(review): the line break after the comma and the lack of padding
    # before "(" are part of the emitted text -- confirm against the unit
    # test expectation before reformatting.
    return (
        f"SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {model_ref},\n"
        f"({source_sql}))"
    )

def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
"""Encode ML.FORECAST for BQML"""
struct_options_sql = self.struct_options(**struct_options)
Expand Down
55 changes: 55 additions & 0 deletions tests/system/small/ml/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,28 @@ def test_model_predict(penguins_bqml_linear_model: core.BqmlModel, new_penguins_
)


def test_model_predict_explain(
    penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
):
    """explain_predict on the linear model yields the expected key columns."""
    checked_columns = ["predicted_body_mass_g", "approximation_error"]

    explained = penguins_bqml_linear_model.explain_predict(new_penguins_df)
    actual = explained.to_pandas()[checked_columns].sort_index()

    tag_numbers = pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64")
    expected = pd.DataFrame(
        {
            "predicted_body_mass_g": [4030.1, 3280.8, 3177.9],
            "approximation_error": [0.0, 0.0, 0.0],
        },
        dtype="Float64",
        index=tag_numbers,
    )

    # Compare loosely: only a 10% relative tolerance is required.
    pd.testing.assert_frame_equal(actual, expected, check_exact=False, rtol=0.1)


def test_model_predict_with_unnamed_index(
penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
):
Expand Down Expand Up @@ -289,6 +311,39 @@ def test_model_predict_with_unnamed_index(
)


def test_model_predict_explain_with_unnamed_index(
    penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
):
    """explain_predict keeps an unnamed, non-contiguous index intact."""
    # reset_index() leaves the frame with an index that has no name; the ML
    # library has to carry that index through the ML.EXPLAIN_PREDICT call.
    unnamed_df = new_penguins_df.reset_index()

    # Drop the middle tag number so the surviving index labels (0 and 2) show
    # that the original unnamed index was kept rather than regenerated.
    unnamed_df = typing.cast(
        bigframes.dataframe.DataFrame,
        unnamed_df[unnamed_df.tag_number != 1672],
    )

    checked_columns = ["predicted_body_mass_g", "approximation_error"]
    explained = penguins_bqml_linear_model.explain_predict(unnamed_df)
    actual = explained.to_pandas()[checked_columns].sort_index()

    expected = pd.DataFrame(
        {
            "predicted_body_mass_g": [4030.1, 3177.9],
            "approximation_error": [0.0, 0.0],
        },
        dtype="Float64",
        index=pd.Index([0, 2], dtype="Int64"),
    )

    pd.testing.assert_frame_equal(actual, expected, check_exact=False, rtol=0.1)


def test_model_detect_anomalies(
penguins_bqml_pca_model: core.BqmlModel, new_penguins_df
):
Expand Down
68 changes: 68 additions & 0 deletions tests/system/small/ml/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import pandas
import pytest

from bigframes.ml import linear_model


def test_linear_reg_model_score(penguins_linear_model, penguins_df_default_index):
df = penguins_df_default_index.dropna()
Expand Down Expand Up @@ -106,6 +108,72 @@ def test_linear_reg_model_predict(penguins_linear_model, new_penguins_df):
)


def test_linear_reg_model_predict_explain(penguins_linear_model, new_penguins_df):
    """predict_explain returns 3 rows x 12 columns with the expected values."""
    explained = penguins_linear_model.predict_explain(new_penguins_df).to_pandas()
    assert explained.shape == (3, 12)

    checked_columns = ["predicted_body_mass_g", "approximation_error"]
    actual = explained[checked_columns].sort_index()

    tag_numbers = pandas.Index(
        [1633, 1672, 1690], name="tag_number", dtype="Int64"
    )
    expected = pandas.DataFrame(
        {
            "predicted_body_mass_g": [4030.1, 3280.8, 3177.9],
            "approximation_error": [0.0, 0.0, 0.0],
        },
        dtype="Float64",
        index=tag_numbers,
    )

    pandas.testing.assert_frame_equal(
        actual, expected, check_exact=False, rtol=0.1
    )


def test_linear_reg_model_predict_params(
    penguins_linear_model: linear_model.LinearRegression, new_penguins_df
):
    """predict() output has at least one row and all required columns."""
    result = penguins_linear_model.predict(new_penguins_df).to_pandas()

    row_count = result.shape[0]
    assert row_count >= 1

    required_columns = {
        "predicted_body_mass_g",
        "species",
        "island",
        "culmen_length_mm",
        "culmen_depth_mm",
        "flipper_length_mm",
        "body_mass_g",
        "sex",
    }
    # Extra columns are allowed; only the required ones must be present.
    assert required_columns.issubset(set(result.columns))


def test_linear_reg_model_predict_explain_params(
    penguins_linear_model: linear_model.LinearRegression, new_penguins_df
):
    """predict_explain() output has at least one row and all required columns."""
    result = penguins_linear_model.predict_explain(new_penguins_df).to_pandas()

    row_count = result.shape[0]
    assert row_count >= 1

    required_columns = {
        "predicted_body_mass_g",
        "top_feature_attributions",
        "baseline_prediction_value",
        "prediction_value",
        "approximation_error",
        "species",
        "island",
        "culmen_length_mm",
        "culmen_depth_mm",
        "flipper_length_mm",
        "body_mass_g",
        "sex",
    }
    # Extra columns are allowed; only the required ones must be present.
    assert required_columns.issubset(set(result.columns))


def test_to_gbq_saved_linear_reg_model_scores(
penguins_linear_model, table_id_unique, penguins_df_default_index
):
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/ml/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,18 @@ def test_ml_predict_correct(
)


def test_ml_explain_predict_correct(
    model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
    mock_df: bpd.DataFrame,
):
    """ml_explain_predict renders the exact ML.EXPLAIN_PREDICT statement."""
    generated = model_manipulation_sql_generator.ml_explain_predict(
        source_sql=mock_df.sql
    )
    # The expected text spans two lines; the break after the comma is part of
    # the statement being compared.
    expected = (
        "SELECT * FROM ML.EXPLAIN_PREDICT("
        "MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,\n"
        "(input_X_y_sql))"
    )
    assert generated == expected


def test_ml_llm_evaluate_correct(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
mock_df: bpd.DataFrame,
Expand Down
Loading