Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add transformers save/load #552

Merged
merged 3 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions bigframes/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,33 @@ def fit(
return self._fit(X, y)


class Transformer(BaseEstimator):
class BaseTransformer(BaseEstimator):
    """Shared base class for transformers that can be saved as BigQuery models."""

    def __init__(self):
        # No BQML model is attached at construction time; to_gbq() refuses to
        # run while this is still unset.
        self._bqml_model: Optional[core.BqmlModel] = None

    _T = TypeVar("_T", bound="BaseTransformer")

    def to_gbq(self: _T, model_name: str, replace: bool = False) -> _T:
        """Save the transformer as a BigQuery model.

        Args:
            model_name (str):
                The name of the model to save to.
            replace (bool, default False):
                Whether to replace the model if it already exists.

        Returns:
            The saved transformer, read back by ``model_name`` through the
            session of the copied model.

        Raises:
            RuntimeError: If no BQML model is attached to this transformer.
        """
        fitted_model = self._bqml_model
        if not fitted_model:
            raise RuntimeError("A transformer must be fitted before it can be saved")

        saved_model = fitted_model.copy(model_name, replace)
        session = saved_model.session
        return session.read_gbq_model(model_name)


class Transformer(BaseTransformer):
"""A BigQuery DataFrames Transformer base class that transforms data.

Also the transformers can be attached to a pipeline with a predictor."""
Expand All @@ -199,7 +225,7 @@ def fit_transform(
return self.fit(X, y).transform(X)


class LabelTransformer(BaseEstimator):
class LabelTransformer(BaseTransformer):
"""A BigQuery DataFrames Label Transformer base class that transforms data.

Also the transformers can be attached to a pipeline with a predictor."""
Expand Down
54 changes: 5 additions & 49 deletions bigframes/ml/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,11 @@
import bigframes_vendored.sklearn.compose._column_transformer
from google.cloud import bigquery

import bigframes
from bigframes import constants
from bigframes.core import log_adapter
from bigframes.ml import base, core, globals, preprocessing, utils
import bigframes.pandas as bpd

_PREPROCESSING_TYPES = Union[
preprocessing.OneHotEncoder,
preprocessing.StandardScaler,
preprocessing.MaxAbsScaler,
preprocessing.MinMaxScaler,
preprocessing.KBinsDiscretizer,
preprocessing.LabelEncoder,
]

_BQML_TRANSFROM_TYPE_MAPPING = types.MappingProxyType(
{
"ML.STANDARD_SCALER": preprocessing.StandardScaler,
Expand All @@ -67,7 +57,7 @@ def __init__(
transformers: List[
Tuple[
str,
_PREPROCESSING_TYPES,
preprocessing.PreprocessingType,
Union[str, List[str]],
]
],
Expand All @@ -82,12 +72,12 @@ def __init__(
@property
def transformers_(
self,
) -> List[Tuple[str, _PREPROCESSING_TYPES, str,]]:
) -> List[Tuple[str, preprocessing.PreprocessingType, str,]]:
"""The collection of transformers as tuples of (name, transformer, column)."""
result: List[
Tuple[
str,
_PREPROCESSING_TYPES,
preprocessing.PreprocessingType,
str,
]
] = []
Expand All @@ -105,15 +95,6 @@ def transformers_(

return result

@classmethod
def _from_bq(
    cls, session: bigframes.Session, model: bigquery.Model
) -> ColumnTransformer:
    """Build a ColumnTransformer from a BigQuery model and attach the model to it.

    Args:
        session: Session used to wrap ``model`` in a ``core.BqmlModel``.
        model: BigQuery model the transformer definition is extracted from.

    Returns:
        The extracted ColumnTransformer with ``_bqml_model`` set.
    """
    transformer = cls._extract_from_bq_model(model)
    # Attach the backing BQML model so the loaded transformer is treated as fitted.
    transformer._bqml_model = core.BqmlModel(session, model)
    return transformer

@classmethod
def _extract_from_bq_model(
cls,
Expand All @@ -125,7 +106,7 @@ def _extract_from_bq_model(
transformers: List[
Tuple[
str,
_PREPROCESSING_TYPES,
preprocessing.PreprocessingType,
Union[str, List[str]],
]
] = []
Expand Down Expand Up @@ -164,15 +145,7 @@ def camel_to_snake(name):

def _merge(
self, bq_model: bigquery.Model
) -> Union[
ColumnTransformer,
preprocessing.StandardScaler,
preprocessing.OneHotEncoder,
preprocessing.MaxAbsScaler,
preprocessing.MinMaxScaler,
preprocessing.KBinsDiscretizer,
preprocessing.LabelEncoder,
]:
) -> Union[ColumnTransformer, preprocessing.PreprocessingType,]:
"""Try to merge the column transformer to a simple transformer. Depends on all the columns in bq_model are transformed with the same transformer."""
transformers = self.transformers_

Expand Down Expand Up @@ -249,20 +222,3 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
bpd.DataFrame,
df[self._output_names],
)

def to_gbq(self, model_name: str, replace: bool = False) -> ColumnTransformer:
    """Save the transformer as a BigQuery model.

    Args:
        model_name (str):
            The name of the model to save to.
        replace (bool, default False):
            Whether to replace the model if it already exists.

    Returns:
        ColumnTransformer: The saved model, read back by ``model_name``
        through the session of the copied model.

    Raises:
        RuntimeError: If no BQML model is attached to this transformer.
    """
    fitted_model = self._bqml_model
    if not fitted_model:
        raise RuntimeError("A transformer must be fitted before it can be saved")

    saved_model = fitted_model.copy(model_name, replace)
    session = saved_model.session
    return session.read_gbq_model(model_name)
11 changes: 9 additions & 2 deletions bigframes/ml/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
from bigframes.ml import (
cluster,
compose,
core,
decomposition,
ensemble,
forecasting,
imported,
linear_model,
llm,
pipeline,
preprocessing,
utils,
)

Expand Down Expand Up @@ -81,6 +83,7 @@ def from_bq(
llm.PaLM2TextEmbeddingGenerator,
pipeline.Pipeline,
compose.ColumnTransformer,
preprocessing.PreprocessingType,
]:
"""Load a BQML model to BigQuery DataFrames ML.

Expand All @@ -107,8 +110,12 @@ def from_bq(


def _transformer_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
    """Load a transformer from a BQML model.

    The transformer definition is first extracted from ``bq_model`` as a
    ``compose.ColumnTransformer`` and then passed through ``_merge``, which may
    return either the ColumnTransformer itself or a single preprocessing
    transformer.

    Args:
        session: Session used to wrap ``bq_model`` in a ``core.BqmlModel``.
        bq_model: BigQuery model to load the transformer from.

    Returns:
        The loaded transformer with ``_bqml_model`` set.
    """
    transformer = compose.ColumnTransformer._extract_from_bq_model(bq_model)._merge(
        bq_model
    )
    # Attach the backing BQML model to whichever transformer type was produced.
    transformer._bqml_model = core.BqmlModel(session, bq_model)

    return transformer


def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
Expand Down
10 changes: 10 additions & 0 deletions bigframes/ml/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,3 +639,13 @@ def transform(self, y: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
bpd.DataFrame,
df[self._output_names],
)


# Type alias: the union of the preprocessing transformer classes listed below.
# Referenced from type hints (e.g. as ``preprocessing.PreprocessingType``) so the
# list of transformer classes is spelled out in one place only.
PreprocessingType = Union[
OneHotEncoder,
StandardScaler,
MaxAbsScaler,
MinMaxScaler,
KBinsDiscretizer,
LabelEncoder,
]
1 change: 1 addition & 0 deletions tests/system/large/ml/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,4 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
("standard_scaler", preprocessing.StandardScaler(), "flipper_length_mm"),
]
assert reloaded_transformer.transformers_ == expected
assert reloaded_transformer._bqml_model is not None
6 changes: 3 additions & 3 deletions tests/system/large/ml/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_pipeline_logistic_regression_fit_score_predict(
)


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_pipeline_xgbregressor_fit_score_predict(session, penguins_df_default_index):
"""Test a supervised model with a minimal preprocessing step"""
pl = pipeline.Pipeline(
Expand Down Expand Up @@ -297,7 +297,7 @@ def test_pipeline_xgbregressor_fit_score_predict(session, penguins_df_default_in
)


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_pipeline_random_forest_classifier_fit_score_predict(
session, penguins_df_default_index
):
Expand Down Expand Up @@ -445,7 +445,7 @@ def test_pipeline_PCA_fit_score_predict(session, penguins_df_default_index):
)


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_pipeline_standard_scaler_kmeans_fit_score_predict(
session, penguins_pandas_df_default_index
):
Expand Down
2 changes: 1 addition & 1 deletion tests/system/small/ml/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def test_remote_model_predict(
)


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_model_generate_text(
bqml_palm2_text_generator_model: core.BqmlModel, llm_text_df
):
Expand Down
24 changes: 12 additions & 12 deletions tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_create_text_generator_32k_model(
assert reloaded_model.connection_name == bq_connection


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_create_text_generator_model_default_session(
bq_connection, llm_text_pandas_df, bigquery_client
):
Expand All @@ -76,7 +76,7 @@ def test_create_text_generator_model_default_session(
assert all(series.str.len() > 20)


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_create_text_generator_32k_model_default_session(
bq_connection, llm_text_pandas_df, bigquery_client
):
Expand All @@ -103,7 +103,7 @@ def test_create_text_generator_32k_model_default_session(
assert all(series.str.len() > 20)


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_create_text_generator_model_default_connection(
llm_text_pandas_df, bigquery_client
):
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_create_text_generator_model_default_connection(


# Marked as flaky only because BQML LLM is in preview: the service has limited capacity and is not yet stable enough.
@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_text_generator_predict_default_params_success(
palm2_text_generator_model, llm_text_df
):
Expand All @@ -142,7 +142,7 @@ def test_text_generator_predict_default_params_success(
assert all(series.str.len() > 20)


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_text_generator_predict_series_default_params_success(
palm2_text_generator_model, llm_text_df
):
Expand All @@ -153,7 +153,7 @@ def test_text_generator_predict_series_default_params_success(
assert all(series.str.len() > 20)


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_text_generator_predict_arbitrary_col_label_success(
palm2_text_generator_model, llm_text_df
):
Expand All @@ -165,7 +165,7 @@ def test_text_generator_predict_arbitrary_col_label_success(
assert all(series.str.len() > 20)


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_text_generator_predict_with_params_success(
palm2_text_generator_model, llm_text_df
):
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_create_text_embedding_generator_multilingual_model_defaults(bq_connecti
assert model._bqml_model is not None


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_embedding_generator_predict_success(
palm2_embedding_generator_model, llm_text_df
):
Expand All @@ -267,7 +267,7 @@ def test_embedding_generator_predict_success(
assert len(value) == 768


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_embedding_generator_multilingual_predict_success(
palm2_embedding_generator_multilingual_model, llm_text_df
):
Expand All @@ -279,7 +279,7 @@ def test_embedding_generator_multilingual_predict_success(
assert len(value) == 768


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_embedding_generator_predict_series_success(
palm2_embedding_generator_model, llm_text_df
):
Expand All @@ -306,7 +306,7 @@ def test_create_gemini_text_generator_model(
assert reloaded_model.connection_name == bq_connection


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_predict_default_params_success(
gemini_text_generator_model, llm_text_df
):
Expand All @@ -317,7 +317,7 @@ def test_gemini_text_generator_predict_default_params_success(
assert all(series.str.len() > 20)


@pytest.mark.flaky(retries=2, delay=120)
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_predict_with_params_success(
gemini_text_generator_model, llm_text_df
):
Expand Down
Loading