Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for LinearRegression.predict_explain and LogisticRegression.predict_explain parameter, top_k_features #1228

Merged
merged 18 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
d333fdd
feat: add LogisticRegression.predict_explain() to generate ML.EXPLAIN…
arwas11 Dec 16, 2024
0f88005
update tests
arwas11 Dec 17, 2024
2812053
Merge branch 'main' into b379743612-ml-logistic-regression-explain-pr…
arwas11 Dec 17, 2024
5336780
Merge branch 'main' into b379743612-ml-logistic-regression-explain-pr…
arwas11 Dec 17, 2024
af1f29b
chore: add support for predict_explain paramater, top_k_features
arwas11 Dec 17, 2024
3befe68
Merge branch 'main' into b379743612-support-predict-explain-params
arwas11 Dec 17, 2024
c722f10
update test
arwas11 Dec 17, 2024
7d78018
Merge remote-tracking branch 'origin/b379743612-support-predict-expla…
arwas11 Dec 17, 2024
430ff9f
Merge remote-tracking branch 'origin/b379743612-ml-logistic-regressio…
arwas11 Dec 17, 2024
70108f9
update logistic reg method with the new param
arwas11 Dec 17, 2024
26f16b3
add and test new param's validation
arwas11 Dec 18, 2024
ca7b2d3
resolve merge conflict
arwas11 Dec 18, 2024
576dbb0
Merge branch 'main' into b379743612-support-predict-explain-params
arwas11 Dec 18, 2024
fe7ebbe
Merge remote-tracking branch 'origin/main' into b379743612-support-pr…
arwas11 Dec 19, 2024
1311162
Merge remote-tracking branch 'origin/b379743612-support-predict-expla…
arwas11 Dec 19, 2024
0971bf6
Merge branch 'main' into b379743612-support-predict-explain-params
arwas11 Dec 20, 2024
1d8a1fd
Update bigframes/ml/linear_model.py
tswast Dec 20, 2024
3448571
Merge branch 'main' into b379743612-support-predict-explain-params
arwas11 Dec 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,15 @@ def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
self._model_manipulation_sql_generator.ml_predict,
)

def explain_predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
def explain_predict(
self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
self._model_manipulation_sql_generator.ml_explain_predict,
lambda source_sql: self._model_manipulation_sql_generator.ml_explain_predict(
source_sql=source_sql,
struct_options=options,
),
)

def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
Expand Down
44 changes: 38 additions & 6 deletions bigframes/ml/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,16 @@ def _fit(
def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
if not self._bqml_model:
raise RuntimeError("A model must be fitted before predict")

(X,) = utils.batch_convert_to_dataframe(X)
# add session
tswast marked this conversation as resolved.
Show resolved Hide resolved
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

return self._bqml_model.predict(X)

def predict_explain(
self,
X: utils.ArrayType,
*,
top_k_features: int = 5,
) -> bpd.DataFrame:
"""
Explain predictions for a linear regression model.
Expand All @@ -175,18 +177,32 @@ def predict_explain(
X (bigframes.dataframe.DataFrame or bigframes.series.Series or
pandas.core.frame.DataFrame or pandas.core.series.Series):
Series or a DataFrame to explain its predictions.
top_k_features (int, default 5):
an INT64 value that specifies how many top feature attribution
pairs are generated for each row of input data. The features are
ranked by the absolute values of their attributions.

By default, top_k_features is set to 5. If its value is greater
than the number of features in the training data, the
attributions of all features are returned.

Returns:
bigframes.pandas.DataFrame:
The predicted DataFrames with explanation columns.
"""
# TODO(b/377366612): Add support for `top_k_features` parameter
if top_k_features < 1:
raise ValueError(
f"top_k_features must be at least 1, but is {top_k_features}."
)

if not self._bqml_model:
raise RuntimeError("A model must be fitted before predict")

(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

return self._bqml_model.explain_predict(X)
return self._bqml_model.explain_predict(
X, options={"top_k_features": top_k_features}
)

def score(
self,
Expand Down Expand Up @@ -356,6 +372,8 @@ def predict(
def predict_explain(
self,
X: utils.ArrayType,
*,
top_k_features: int = 5,
) -> bpd.DataFrame:
"""
Explain predictions for a logistic regression model.
Expand All @@ -368,18 +386,32 @@ def predict_explain(
X (bigframes.dataframe.DataFrame or bigframes.series.Series or
pandas.core.frame.DataFrame or pandas.core.series.Series):
Series or a DataFrame to explain its predictions.
top_k_features (int, default 5):
an INT64 value that specifies how many top feature attribution
pairs are generated for each row of input data. The features are
ranked by the absolute values of their attributions.

By default, top_k_features is set to 5. If its value is greater
than the number of features in the training data, the
attributions of all features are returned.

Returns:
bigframes.pandas.DataFrame:
The predicted DataFrames with explanation columns.
"""
# TODO(b/377366612): Add support for `top_k_features` parameter
if top_k_features < 1:
raise ValueError(
f"top_k_features must be at least 1, but is {top_k_features}."
)

if not self._bqml_model:
raise RuntimeError("A model must be fitted before predict")

(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

return self._bqml_model.explain_predict(X)
return self._bqml_model.explain_predict(
X, options={"top_k_features": top_k_features}
)

def score(
self,
Expand Down
7 changes: 5 additions & 2 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,13 @@ def ml_predict(self, source_sql: str) -> str:
return f"""SELECT * FROM ML.PREDICT(MODEL {self._model_ref_sql()},
({source_sql}))"""

def ml_explain_predict(self, source_sql: str) -> str:
def ml_explain_predict(
self, source_sql: str, struct_options: Mapping[str, Union[int, float]]
) -> str:
"""Encode ML.EXPLAIN_PREDICT for BQML"""
struct_options_sql = self.struct_options(**struct_options)
return f"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {self._model_ref_sql()},
({source_sql}))"""
({source_sql}), {struct_options_sql})"""

def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
"""Encode ML.FORECAST for BQML"""
Expand Down
6 changes: 4 additions & 2 deletions tests/system/small/ml/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,9 @@ def test_model_predict(penguins_bqml_linear_model: core.BqmlModel, new_penguins_
def test_model_predict_explain(
penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
):
options = {"top_k_features": 3}
predictions = penguins_bqml_linear_model.explain_predict(
new_penguins_df
new_penguins_df, options
).to_pandas()
expected = pd.DataFrame(
{
Expand Down Expand Up @@ -317,14 +318,15 @@ def test_model_predict_explain_with_unnamed_index(
# need to persist through the call to ML.PREDICT
new_penguins_df = new_penguins_df.reset_index()

options = {"top_k_features": 3}
# remove the middle tag number to ensure we're really keeping the unnamed index
new_penguins_df = typing.cast(
bigframes.dataframe.DataFrame,
new_penguins_df[new_penguins_df.tag_number != 1672],
)

predictions = penguins_bqml_linear_model.explain_predict(
new_penguins_df
new_penguins_df, options
).to_pandas()

expected = pd.DataFrame(
Expand Down
30 changes: 30 additions & 0 deletions tests/system/small/ml/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import google.api_core.exceptions
import pandas
import pytest
Expand Down Expand Up @@ -132,6 +134,20 @@ def test_linear_reg_model_predict_explain(penguins_linear_model, new_penguins_df
)


def test_linear_model_predict_explain_top_k_features(
penguins_logistic_model: linear_model.LinearRegression, new_penguins_df
):
top_k_features = 0

with pytest.raises(
ValueError,
match=re.escape(f"top_k_features must be at least 1, but is {top_k_features}."),
):
penguins_logistic_model.predict_explain(
new_penguins_df, top_k_features=top_k_features
).to_pandas()


def test_linear_reg_model_predict_params(
penguins_linear_model: linear_model.LinearRegression, new_penguins_df
):
Expand Down Expand Up @@ -307,6 +323,20 @@ def test_logistic_model_predict(penguins_logistic_model, new_penguins_df):
)


def test_logistic_model_predict_explain_top_k_features(
penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df
):
top_k_features = 0

with pytest.raises(
ValueError,
match=re.escape(f"top_k_features must be at least 1, but is {top_k_features}."),
):
penguins_logistic_model.predict_explain(
new_penguins_df, top_k_features=top_k_features
).to_pandas()


def test_logistic_model_predict_params(
penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df
):
Expand Down
29 changes: 17 additions & 12 deletions tests/unit/ml/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,18 +342,6 @@ def test_ml_predict_correct(
)


def test_ml_explain_predict_correct(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
mock_df: bpd.DataFrame,
):
sql = model_manipulation_sql_generator.ml_explain_predict(source_sql=mock_df.sql)
assert (
sql
== """SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
(input_X_y_sql))"""
)


def test_ml_llm_evaluate_correct(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
mock_df: bpd.DataFrame,
Expand Down Expand Up @@ -462,6 +450,23 @@ def test_ml_generate_embedding_correct(
)


def test_ml_explain_predict_correct(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
mock_df: bpd.DataFrame,
):
sql = model_manipulation_sql_generator.ml_explain_predict(
source_sql=mock_df.sql,
struct_options={"option_key1": 1, "option_key2": 2.25},
)
assert (
sql
== """SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
(input_X_y_sql), STRUCT(
1 AS `option_key1`,
2.25 AS `option_key2`))"""
)


def test_ml_detect_anomalies_correct_sql(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
mock_df: bpd.DataFrame,
Expand Down
Loading