Skip to content

Commit

Permalink
fix!: exclude remote models for .register() (#465)
Browse files Browse the repository at this point in the history
* fix: exclude remote models for .register()

* fix mypy
  • Loading branch information
GarrettWu authored Mar 19, 2024
1 parent 3971bd2 commit 73fe0f8
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 16 deletions.
1 change: 1 addition & 0 deletions bigframes/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __repr__(self):
return prettyprinter.pformat(self)


# TODO(garrettwu): refactor to reflect the actual property. Now the class contains .register() method.
class Predictor(BaseEstimator):
"""A BigQuery DataFrames ML Model base class that can be used to predict outputs."""

Expand Down
6 changes: 3 additions & 3 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@


@log_adapter.class_logger
class PaLM2TextGenerator(base.Predictor):
class PaLM2TextGenerator(base.BaseEstimator):
"""PaLM2 text generator LLM model.
Args:
Expand Down Expand Up @@ -258,7 +258,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator:


@log_adapter.class_logger
class PaLM2TextEmbeddingGenerator(base.Predictor):
class PaLM2TextEmbeddingGenerator(base.BaseEstimator):
"""PaLM2 text embedding generator LLM model.
Args:
Expand Down Expand Up @@ -418,7 +418,7 @@ def to_gbq(


@log_adapter.class_logger
class GeminiTextGenerator(base.Predictor):
class GeminiTextGenerator(base.BaseEstimator):
"""Gemini text generator LLM model.
Args:
Expand Down
17 changes: 4 additions & 13 deletions tests/system/small/ml/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from typing import cast

import pytest

from bigframes.ml import core, imported, linear_model, llm


Expand Down Expand Up @@ -54,19 +56,8 @@ def test_linear_reg_register_with_params(
def test_palm2_text_generator_register(
ephemera_palm2_text_generator_model: llm.PaLM2TextGenerator,
):
model = ephemera_palm2_text_generator_model
model.register()

model_name = "bigframes_" + cast(
str, cast(core.BqmlModel, model._bqml_model).model.model_id
)
# Only registered model contains the field, and the field includes project/dataset. Here only check model_id.
assert (
model_name[:63] # truncated
in cast(core.BqmlModel, model._bqml_model).model.training_runs[-1][
"vertexAiModelId"
]
)
with pytest.raises(AttributeError):
ephemera_palm2_text_generator_model.register() # type: ignore


def test_imported_tensorflow_register(
Expand Down

0 comments on commit 73fe0f8

Please sign in to comment.