feat: add ml LogisticRegression model params (#481)
* feat: add ml LogisticRegression model params

* fix tests

* fix tests
GarrettWu authored Mar 22, 2024
1 parent 352cb85 commit f959b65
Showing 5 changed files with 123 additions and 21 deletions.
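Before the per-file diffs, a minimal usage sketch of the new surface. It assumes bigframes DataFrames X_train / y_train like those in the system test below; parameter values mirror that test and are otherwise arbitrary:

from bigframes.ml.linear_model import LogisticRegression

# Exercise the constructor parameters added in this commit.
model = LogisticRegression(
    fit_intercept=False,
    class_weights="balanced",
    l1_reg=0.2,
    l2_reg=0.2,
    max_iterations=30,
    optimize_strategy="batch_gradient_descent",
    learn_rate_strategy="constant",
    learn_rate=0.2,
    tol=0.02,
)
model.fit(X_train, y_train)  # X_train / y_train: assumed bigframes DataFrames
print(model.score(X_train, y_train).to_pandas())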
64 changes: 58 additions & 6 deletions bigframes/ml/linear_model.py
@@ -38,7 +38,9 @@
"learn_rate_strategy": "learnRateStrategy",
"learn_rate": "learnRate",
"early_stop": "earlyStop",
    # To be renamed to tol.
"min_rel_progress": "minRelativeProgress",
"tol": "minRelativeProgress",
"ls_init_learn_rate": "initialLearnRate",
"warm_start": "warmStart",
"calculate_p_values": "calculatePValues",
@@ -59,7 +61,7 @@ def __init__(
*,
optimize_strategy: Literal[
"auto_strategy", "batch_gradient_descent", "normal_equation"
] = "normal_equation",
] = "auto_strategy",
fit_intercept: bool = True,
l1_reg: Optional[float] = None,
l2_reg: float = 0.0,
@@ -139,7 +141,7 @@ def _bqml_options(self) -> dict:
if self.ls_init_learn_rate is not None:
options["ls_init_learn_rate"] = self.ls_init_learn_rate
        # Passing warm_start at all raises an error for the NORMAL_EQUATION optimizer, so only emit it when True
-        if self.warm_start is True:
+        if self.warm_start:
            options["warm_start"] = self.warm_start

return options
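A sketch of the guard's observable effect, assuming LinearRegression's constructor accepts warm_start as the params mapping above implies (this mirrors the identical guard added to LogisticRegression further down):

from bigframes.ml.linear_model import LinearRegression

# warm_start=False: the key is omitted entirely, so NORMAL_EQUATION models still train
assert "warm_start" not in LinearRegression(warm_start=False)._bqml_options

# warm_start=True: the key is emitted; BQML accepts it only for gradient-descent training
assert LinearRegression(warm_start=True)._bqml_options["warm_start"] is True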
@@ -212,10 +214,34 @@ class LogisticRegression(
def __init__(
self,
*,
optimize_strategy: Literal[
"auto_strategy", "batch_gradient_descent", "normal_equation"
] = "auto_strategy",
fit_intercept: bool = True,
l1_reg: Optional[float] = None,
l2_reg: float = 0.0,
max_iterations: int = 20,
warm_start: bool = False,
learn_rate: Optional[float] = None,
learn_rate_strategy: Literal["line_search", "constant"] = "line_search",
tol: float = 0.01,
ls_init_learn_rate: Optional[float] = None,
calculate_p_values: bool = False,
enable_global_explain: bool = False,
class_weights: Optional[Union[Literal["balanced"], Dict[str, float]]] = None,
):
self.optimize_strategy = optimize_strategy
self.fit_intercept = fit_intercept
self.l1_reg = l1_reg
self.l2_reg = l2_reg
self.max_iterations = max_iterations
self.warm_start = warm_start
self.learn_rate = learn_rate
self.learn_rate_strategy = learn_rate_strategy
self.tol = tol
self.ls_init_learn_rate = ls_init_learn_rate
self.calculate_p_values = calculate_p_values
self.enable_global_explain = enable_global_explain
self.class_weights = class_weights
self._auto_class_weight = class_weights == "balanced"
self._bqml_model: Optional[core.BqmlModel] = None
@@ -231,8 +257,16 @@ def _from_bq(

# See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
last_fitting = model.training_runs[-1]["trainingOptions"]
if "fitIntercept" in last_fitting:
kwargs["fit_intercept"] = last_fitting["fitIntercept"]
dummy_logistic = cls()
for bf_param, bf_value in dummy_logistic.__dict__.items():
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
if bqml_param in last_fitting:
# Convert types
kwargs[bf_param] = (
float(last_fitting[bqml_param])
if bf_param in ["l1_reg", "learn_rate", "ls_init_learn_rate"]
else type(bf_value)(last_fitting[bqml_param])
)
if last_fitting["autoClassWeights"]:
kwargs["class_weights"] = "balanced"
# TODO(ashleyxu) support class_weights in the constructor.
@@ -244,16 +278,34 @@
return new_logistic_regression
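The rewritten hunk above swaps field-by-field copying for a generic mapping: a throwaway default instance enumerates the estimator's attributes, and _BQML_PARAMS_MAPPING translates each one into its BQML training-option name. A self-contained sketch of the same idea, using an illustrative two-entry mapping rather than the module's real table:

# Illustrative subset; the real mapping lives at the top of bigframes/ml/linear_model.py
_PARAMS_MAPPING = {"fit_intercept": "fitIntercept", "max_iterations": "maxIterations"}

def kwargs_from_training_run(cls, last_fitting: dict) -> dict:
    """Recover constructor kwargs from a BQML trainingOptions payload (hypothetical helper)."""
    kwargs = {}
    dummy = cls()  # a default instance supplies attribute names and default-value types
    for bf_param, bf_value in dummy.__dict__.items():
        bqml_param = _PARAMS_MAPPING.get(bf_param)
        if bqml_param in last_fitting:
            # BQML serializes many options as strings; cast back via the default's type
            kwargs[bf_param] = type(bf_value)(last_fitting[bqml_param])
    return kwargs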

@property
-    def _bqml_options(self) -> Dict[str, str | int | float | List[str]]:
+    def _bqml_options(self) -> dict:
"""The model options as they will be set for BQML"""
-        return {
+        options = {
"model_type": "LOGISTIC_REG",
"data_split_method": "NO_SPLIT",
"fit_intercept": self.fit_intercept,
"auto_class_weights": self._auto_class_weight,
"optimize_strategy": self.optimize_strategy,
"l2_reg": self.l2_reg,
"max_iterations": self.max_iterations,
"learn_rate_strategy": self.learn_rate_strategy,
"min_rel_progress": self.tol,
"calculate_p_values": self.calculate_p_values,
"enable_global_explain": self.enable_global_explain,
# TODO(ashleyxu): support class_weights (struct array as dict in our API)
# "class_weights": self.class_weights,
}
if self.l1_reg is not None:
options["l1_reg"] = self.l1_reg
if self.learn_rate is not None:
options["learn_rate"] = self.learn_rate
if self.ls_init_learn_rate is not None:
options["ls_init_learn_rate"] = self.ls_init_learn_rate
        # Passing warm_start at all raises an error for the NORMAL_EQUATION optimizer, so only emit it when True
if self.warm_start:
options["warm_start"] = self.warm_start

return options
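With this property in place, a default LogisticRegression emits the full option set instead of the four keys it produced before. A minimal check derived from the constructor defaults above (none of the conditional keys fire, since l1_reg, learn_rate, and ls_init_learn_rate are None and warm_start is False):

from bigframes.ml.linear_model import LogisticRegression

assert LogisticRegression()._bqml_options == {
    "model_type": "LOGISTIC_REG",
    "data_split_method": "NO_SPLIT",
    "fit_intercept": True,
    "auto_class_weights": False,
    "optimize_strategy": "auto_strategy",
    "l2_reg": 0.0,
    "max_iterations": 20,
    "learn_rate_strategy": "line_search",
    "min_rel_progress": 0.01,
    "calculate_p_values": False,
    "enable_global_explain": False,
}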

def _fit(
self,
34 changes: 26 additions & 8 deletions tests/system/large/ml/test_linear_model.py
@@ -184,7 +184,15 @@ def test_logistic_regression_customized_params_fit_score(
penguins_df_default_index, dataset_id
):
model = bigframes.ml.linear_model.LogisticRegression(
-        fit_intercept=False, class_weights="balanced"
+        fit_intercept=False,
+        class_weights="balanced",
+        l2_reg=0.2,
+        tol=0.02,
+        l1_reg=0.2,
+        max_iterations=30,
+        optimize_strategy="batch_gradient_descent",
+        learn_rate_strategy="constant",
+        learn_rate=0.2,
)
df = penguins_df_default_index.dropna()
X_train = df[
@@ -203,12 +211,12 @@
result = model.score(X_train, y_train).to_pandas()
expected = pd.DataFrame(
{
"precision": [0.58483],
"recall": [0.586616],
"accuracy": [0.877246],
"f1_score": [0.58571],
"log_loss": [1.032699],
"roc_auc": [0.924132],
"precision": [0.487],
"recall": [0.602],
"accuracy": [0.464],
"f1_score": [0.379],
"log_loss": [0.972],
"roc_auc": [0.700],
},
dtype="Float64",
)
@@ -223,5 +231,15 @@
f"{dataset_id}.temp_configured_logistic_reg_model"
in reloaded_model._bqml_model.model_name
)
# TODO(garrettwu) optimize_strategy isn't logged in BQML
# assert reloaded_model.optimize_strategy == "BATCH_GRADIENT_DESCENT"
assert reloaded_model.fit_intercept is False
assert reloaded_model.class_weights == "balanced"
assert reloaded_model.calculate_p_values is False
assert reloaded_model.enable_global_explain is False
assert reloaded_model.l1_reg == 0.2
assert reloaded_model.l2_reg == 0.2
assert reloaded_model.ls_init_learn_rate is None
assert reloaded_model.max_iterations == 30
assert reloaded_model.tol == 0.02
assert reloaded_model.learn_rate_strategy == "CONSTANT"
assert reloaded_model.learn_rate == 0.2
18 changes: 13 additions & 5 deletions tests/unit/ml/test_golden_sql.py
@@ -105,7 +105,7 @@ def test_linear_regression_default_fit(
model.fit(mock_X, mock_y)

mock_session._start_query_ml_ddl.assert_called_once_with(
-        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
+        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="auto_strategy",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
)


@@ -115,7 +115,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X,
model.fit(mock_X, mock_y)

mock_session._start_query_ml_ddl.assert_called_once_with(
-        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
+        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="auto_strategy",\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
)


@@ -148,21 +148,29 @@ def test_logistic_regression_default_fit(
model.fit(mock_X, mock_y)

mock_session._start_query_ml_ddl.assert_called_once_with(
-        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=True,\n auto_class_weights=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
+        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy="auto_strategy",\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
)


def test_logistic_regression_params_fit(
bqml_model_factory, mock_session, mock_X, mock_y
):
model = linear_model.LogisticRegression(
-        fit_intercept=False, class_weights="balanced"
+        fit_intercept=False,
+        class_weights="balanced",
+        l2_reg=0.2,
+        tol=0.02,
+        l1_reg=0.2,
+        max_iterations=30,
+        optimize_strategy="batch_gradient_descent",
+        learn_rate_strategy="constant",
+        learn_rate=0.2,
)
model._bqml_model_factory = bqml_model_factory
model.fit(mock_X, mock_y)

mock_session._start_query_ml_ddl.assert_called_once_with(
-        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=False,\n auto_class_weights=True,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
+        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=False,\n auto_class_weights=True,\n optimize_strategy="batch_gradient_descent",\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy="constant",\n min_rel_progress=0.02,\n calculate_p_values=False,\n enable_global_explain=False,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
)


4 changes: 2 additions & 2 deletions third_party/bigframes_vendored/sklearn/linear_model/_base.py
@@ -63,10 +63,10 @@ class LinearRegression(RegressorMixin, LinearModel):
the dataset, and the targets predicted by the linear approximation.
Args:
optimize_strategy (str, default "normal_equation"):
optimize_strategy (str, default "auto_strategy"):
The strategy to train linear regression models. Possible values are
"auto_strategy", "batch_gradient_descent", "normal_equation". Default
to "normal_equation".
to "auto_strategy".
fit_intercept (bool, default True):
Default ``True``. Whether to calculate the intercept for this
model. If set to False, no intercept will be used in calculations
24 changes: 24 additions & 0 deletions third_party/bigframes_vendored/sklearn/linear_model/_logistic.py
@@ -24,6 +24,10 @@ class LogisticRegression(LinearClassifierMixin, BaseEstimator):
"""Logistic Regression (aka logit, MaxEnt) classifier.
Args:
        optimize_strategy (str, default "auto_strategy"):
            The strategy used to train logistic regression models. Possible values are
            "auto_strategy", "batch_gradient_descent", and "normal_equation". Defaults
            to "auto_strategy".
fit_intercept (default True):
Default True. Specifies if a constant (a.k.a. bias or intercept)
should be added to the decision function.
@@ -35,6 +39,26 @@ class LogisticRegression(LinearClassifierMixin, BaseEstimator):
frequencies in the input data as
``n_samples / (n_classes * np.bincount(y))``. Dict isn't
supported now.
        l1_reg (float or None, default None):
            The amount of L1 regularization applied. Defaults to None. Can't be set when optimize_strategy is "normal_equation". If unset, 0 is used.
        l2_reg (float, default 0.0):
            The amount of L2 regularization applied. Defaults to 0.0.
        max_iterations (int, default 20):
            The maximum number of training iterations or steps. Defaults to 20.
        warm_start (bool, default False):
            Determines whether to train a model with new training data, new model options, or both. Unless explicitly overridden, the initial options used to train the model are reused for the warm-start run. Defaults to False.
        learn_rate (float or None, default None):
            The learning rate for gradient descent when learn_rate_strategy="constant". If unset, 0.1 is used. If learn_rate_strategy="line_search", an error is returned.
        learn_rate_strategy (str, default "line_search"):
            The strategy for specifying the learning rate during training. Defaults to "line_search".
        tol (float, default 0.01):
            The minimum relative loss improvement required to continue training when early_stop is set to True. For example, a value of 0.01 specifies that each iteration must reduce the loss by 1% for training to continue. Defaults to 0.01.
        ls_init_learn_rate (float or None, default None):
            The initial learning rate that learn_rate_strategy="line_search" uses. Can only be set when line_search is specified. If unset, 0.1 is used.
        calculate_p_values (bool, default False):
            Specifies whether to compute p-values and standard errors during training. Defaults to False.
        enable_global_explain (bool, default False):
            Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Defaults to False.
"""

def fit(
