Skip to content

Commit

Permalink
feat: add ml PCA model params (#474)
Browse files Browse the repository at this point in the history
  • Loading branch information
GarrettWu authored and TrevorBergeron committed Mar 20, 2024
1 parent 55547a8 commit 7d09d43
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 4 deletions.
15 changes: 12 additions & 3 deletions bigframes/ml/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import annotations

from typing import List, Optional, Union
from typing import List, Literal, Optional, Union

import bigframes_vendored.sklearn.decomposition._pca
from google.cloud import bigquery
Expand All @@ -35,21 +35,29 @@ class PCA(
):
__doc__ = bigframes_vendored.sklearn.decomposition._pca.PCA.__doc__

def __init__(
    self,
    n_components: int = 3,
    *,
    svd_solver: Literal["full", "randomized", "auto"] = "auto",
):
    """Store the PCA hyperparameters; no model is created or trained here.

    Args:
        n_components:
            Number of principal components to keep. Sent to BigQuery ML as
            the ``num_principal_components`` option when the model is fitted.
        svd_solver:
            Solver used to calculate the principal components. Keyword-only.
            Sent to BigQuery ML as the ``pca_solver`` option when the model
            is fitted.
    """
    # Hyperparameters are stored exactly as given (no validation or case
    # normalisation). NOTE: `_from_bq` forwards the solver string reported
    # by BigQuery, and the system test expects "RANDOMIZED" after a reload,
    # so values outside the Literal above can legitimately end up here.
    self.n_components = n_components
    self.svd_solver = svd_solver

    # No backing BigQuery ML model yet; `_from_bq` assigns one when an
    # existing model is loaded.
    self._bqml_model: Optional[core.BqmlModel] = None
    self._bqml_model_factory = globals.bqml_model_factory()

@classmethod
def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> PCA:
assert model.model_type == "PCA"

kwargs = {}
kwargs: dict = {}

# See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
last_fitting = model.training_runs[-1]["trainingOptions"]
if "numPrincipalComponents" in last_fitting:
kwargs["n_components"] = int(last_fitting["numPrincipalComponents"])
if "pcaSolver" in last_fitting:
kwargs["svd_solver"] = str(last_fitting["pcaSolver"])

new_pca = cls(**kwargs)
new_pca._bqml_model = core.BqmlModel(session, model)
Expand All @@ -69,6 +77,7 @@ def _fit(
options={
"model_type": "PCA",
"num_principal_components": self.n_components,
"pca_solver": self.svd_solver,
},
)
return self
Expand Down
71 changes: 71 additions & 0 deletions tests/system/large/ml/test_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,74 @@ def test_decomposition_configure_fit_score_predict(
in reloaded_model._bqml_model.model_name
)
assert reloaded_model.n_components == 3


def test_decomposition_configure_fit_score_predict_params(
    session, penguins_df_default_index, dataset_id
):
    """Fit a PCA with non-default params, then score, predict, save and reload it."""
    pca = decomposition.PCA(n_components=5, svd_solver="randomized")
    pca.fit(penguins_df_default_index)

    # Three penguins, indexed by tag number, used for scoring and prediction.
    penguins_pd = pd.DataFrame(
        {
            "tag_number": [1633, 1672, 1690],
            "species": [
                "Adelie Penguin (Pygoscelis adeliae)",
                "Gentoo penguin (Pygoscelis papua)",
                "Adelie Penguin (Pygoscelis adeliae)",
            ],
            "island": ["Dream", "Biscoe", "Torgersen"],
            "culmen_length_mm": [37.8, 46.5, 41.1],
            "culmen_depth_mm": [18.1, 14.8, 18.6],
            "flipper_length_mm": [193.0, 217.0, 189.0],
            "body_mass_g": [3750.0, 5200.0, 3325.0],
            "sex": ["MALE", "FEMALE", "MALE"],
        }
    ).set_index("tag_number")
    new_penguins = session.read_pandas(penguins_pd)

    # Scoring only works on a fitted model, so this doubles as a fit check.
    actual_score = pca.score(new_penguins).to_pandas()
    expected_score = pd.DataFrame(
        {"total_explained_variance_ratio": [0.932897]},
        dtype="Float64",
    )
    # Give the expected frame an Int64 index before comparing.
    int64_index = expected_score.index.astype("Int64")
    expected_score = expected_score.reindex(index=int64_index)

    pd.testing.assert_frame_equal(
        actual_score, expected_score, check_exact=False, rtol=0.1
    )

    # Five requested components -> five principal-component columns.
    actual_components = pca.predict(new_penguins).to_pandas()
    tag_index = pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64")
    expected_components = pd.DataFrame(
        {
            "principal_component_1": [-1.459, 2.258, -1.685],
            "principal_component_2": [-1.120, -1.351, -0.874],
            "principal_component_3": [-0.646, 0.443, -0.704],
            "principal_component_4": [-0.539, 0.234, -0.571],
            "principal_component_5": [-0.876, 0.122, 0.609],
        },
        dtype="Float64",
        index=tag_index,
    )

    tests.system.utils.assert_pandas_df_equal_pca(
        actual_components,
        expected_components,
        check_exact=False,
        rtol=0.1,
    )

    # Save and reload, then check that both n_components and svd_solver
    # survived the round trip.
    model_path = f"{dataset_id}.temp_configured_pca_model"
    reloaded = pca.to_gbq(model_path, replace=True)
    assert model_path in reloaded._bqml_model.model_name
    assert reloaded.n_components == 5
    # The reloaded solver is the upper-case string, not the value passed above.
    assert reloaded.svd_solver == "RANDOMIZED"
4 changes: 3 additions & 1 deletion third_party/bigframes_vendored/sklearn/decomposition/_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ class PCA(BaseEstimator, metaclass=ABCMeta):
truncated SVD.
Args:
n_components (Optional[int], default 3):
n_components (Optional[int], default 3):
Number of components to keep. If n_components is not set, all components
are kept.
svd_solver ("full", "randomized" or "auto", default "auto"):
The solver to use to calculate the principal components. Details: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-pca#pca_solver.
"""

Expand Down

0 comments on commit 7d09d43

Please sign in to comment.