From 7d09d43599915af0b34b04d4aff9fa3a5a89f73d Mon Sep 17 00:00:00 2001
From: Garrett Wu <6505921+GarrettWu@users.noreply.github.com>
Date: Tue, 19 Mar 2024 21:37:50 -0700
Subject: [PATCH] feat: add ml PCA model params (#474)

---
 bigframes/ml/decomposition.py               | 15 +++-
 tests/system/large/ml/test_decomposition.py | 71 +++++++++++++++++++
 .../sklearn/decomposition/_pca.py           |  4 +-
 3 files changed, 86 insertions(+), 4 deletions(-)

diff --git a/bigframes/ml/decomposition.py b/bigframes/ml/decomposition.py
index 9dc60be78f..36fa28e141 100644
--- a/bigframes/ml/decomposition.py
+++ b/bigframes/ml/decomposition.py
@@ -17,7 +17,7 @@
 
 from __future__ import annotations
 
-from typing import List, Optional, Union
+from typing import List, Literal, Optional, Union
 
 import bigframes_vendored.sklearn.decomposition._pca
 from google.cloud import bigquery
@@ -35,8 +35,14 @@ class PCA(
 ):
     __doc__ = bigframes_vendored.sklearn.decomposition._pca.PCA.__doc__
 
-    def __init__(self, n_components: int = 3):
+    def __init__(
+        self,
+        n_components: int = 3,
+        *,
+        svd_solver: Literal["full", "randomized", "auto"] = "auto",
+    ):
         self.n_components = n_components
+        self.svd_solver = svd_solver
         self._bqml_model: Optional[core.BqmlModel] = None
         self._bqml_model_factory = globals.bqml_model_factory()
 
@@ -44,12 +50,14 @@ def __init__(self, n_components: int = 3):
     def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> PCA:
         assert model.model_type == "PCA"
 
-        kwargs = {}
+        kwargs: dict = {}
 
         # See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
         last_fitting = model.training_runs[-1]["trainingOptions"]
         if "numPrincipalComponents" in last_fitting:
             kwargs["n_components"] = int(last_fitting["numPrincipalComponents"])
+        if "pcaSolver" in last_fitting:
+            kwargs["svd_solver"] = str(last_fitting["pcaSolver"])
 
         new_pca = cls(**kwargs)
         new_pca._bqml_model = core.BqmlModel(session, model)
@@ -69,6 +77,7 @@ def _fit(
             options={
                 "model_type": "PCA",
                 "num_principal_components": self.n_components,
+                "pca_solver": self.svd_solver,
             },
         )
         return self
diff --git a/tests/system/large/ml/test_decomposition.py b/tests/system/large/ml/test_decomposition.py
index 953287def2..7932536e0c 100644
--- a/tests/system/large/ml/test_decomposition.py
+++ b/tests/system/large/ml/test_decomposition.py
@@ -84,3 +84,74 @@ def test_decomposition_configure_fit_score_predict(
         in reloaded_model._bqml_model.model_name
     )
     assert reloaded_model.n_components == 3
+
+
+def test_decomposition_configure_fit_score_predict_params(
+    session, penguins_df_default_index, dataset_id
+):
+    model = decomposition.PCA(n_components=5, svd_solver="randomized")
+    model.fit(penguins_df_default_index)
+
+    new_penguins = session.read_pandas(
+        pd.DataFrame(
+            {
+                "tag_number": [1633, 1672, 1690],
+                "species": [
+                    "Adelie Penguin (Pygoscelis adeliae)",
+                    "Gentoo penguin (Pygoscelis papua)",
+                    "Adelie Penguin (Pygoscelis adeliae)",
+                ],
+                "island": ["Dream", "Biscoe", "Torgersen"],
+                "culmen_length_mm": [37.8, 46.5, 41.1],
+                "culmen_depth_mm": [18.1, 14.8, 18.6],
+                "flipper_length_mm": [193.0, 217.0, 189.0],
+                "body_mass_g": [3750.0, 5200.0, 3325.0],
+                "sex": ["MALE", "FEMALE", "MALE"],
+            }
+        ).set_index("tag_number")
+    )
+
+    # Check score to ensure the model was fitted
+    score_result = model.score(new_penguins).to_pandas()
+    score_expected = pd.DataFrame(
+        {
+            "total_explained_variance_ratio": [0.932897],
+        },
+        dtype="Float64",
+    )
+    score_expected = score_expected.reindex(index=score_expected.index.astype("Int64"))
+
+    pd.testing.assert_frame_equal(
+        score_result, score_expected, check_exact=False, rtol=0.1
+    )
+
+    result = model.predict(new_penguins).to_pandas()
+    expected = pd.DataFrame(
+        {
+            "principal_component_1": [-1.459, 2.258, -1.685],
+            "principal_component_2": [-1.120, -1.351, -0.874],
+            "principal_component_3": [-0.646, 0.443, -0.704],
+            "principal_component_4": [-0.539, 0.234, -0.571],
+            "principal_component_5": [-0.876, 0.122, 0.609],
+        },
+        dtype="Float64",
+        index=pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64"),
+    )
+
+    tests.system.utils.assert_pandas_df_equal_pca(
+        result,
+        expected,
+        check_exact=False,
+        rtol=0.1,
+    )
+
+    # save, load, check n_components to ensure configuration was kept
+    reloaded_model = model.to_gbq(
+        f"{dataset_id}.temp_configured_pca_model", replace=True
+    )
+    assert (
+        f"{dataset_id}.temp_configured_pca_model"
+        in reloaded_model._bqml_model.model_name
+    )
+    assert reloaded_model.n_components == 5
+    assert reloaded_model.svd_solver == "RANDOMIZED"
diff --git a/third_party/bigframes_vendored/sklearn/decomposition/_pca.py b/third_party/bigframes_vendored/sklearn/decomposition/_pca.py
index 30c9c3b0b6..25d67f64c4 100644
--- a/third_party/bigframes_vendored/sklearn/decomposition/_pca.py
+++ b/third_party/bigframes_vendored/sklearn/decomposition/_pca.py
@@ -32,9 +32,11 @@ class PCA(BaseEstimator, metaclass=ABCMeta):
     truncated SVD.
 
     Args:
-        n_components (Optional[int], default 3):
+        n_components (Optional[int], default 3):
            Number of components to keep.
            if n_components is not set all components are kept.
+        svd_solver ("full", "randomized" or "auto", default "auto"):
+            The solver to use to calculate the principal components. Details: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-pca#pca_solver.
 
     """
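
Usage sketch (illustrative; not part of the patch): assuming an authenticated BigQuery DataFrames session and access to the public penguins table (the table path below is an assumption, not taken from this change), the new svd_solver parameter added above would be exercised roughly as follows.

    # Sketch only: exercises the svd_solver option introduced in this patch.
    # Assumes bigframes is installed; the public table path is an assumption.
    import bigframes.pandas as bpd
    from bigframes.ml.decomposition import PCA

    df = bpd.read_gbq("bigquery-public-data.ml_datasets.penguins")

    # "randomized" is forwarded to BigQuery ML as the pca_solver training
    # option (see the options dict in PCA._fit above).
    pca = PCA(n_components=5, svd_solver="randomized")
    pca.fit(df)

    components = pca.predict(df).to_pandas()  # principal_component_1..5
    print(pca.score(df).to_pandas())          # total_explained_variance_ratio

When a fitted model is round-tripped through to_gbq and loaded back via _from_bq, svd_solver reflects the value BQML reports for pcaSolver, which is why the new system test expects the uppercase "RANDOMIZED".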