From 73fe0f89a96557afc4225521654978b96a2291b3 Mon Sep 17 00:00:00 2001 From: Garrett Wu <6505921+GarrettWu@users.noreply.github.com> Date: Tue, 19 Mar 2024 11:11:16 -0700 Subject: [PATCH] fix!: exclude remote models for .register() (#465) * fix: exclude remote models for .register() * fix mypy --- bigframes/ml/base.py | 1 + bigframes/ml/llm.py | 6 +++--- tests/system/small/ml/test_register.py | 17 ++++------------- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/bigframes/ml/base.py b/bigframes/ml/base.py index 9001987e9a..e58ed4feef 100644 --- a/bigframes/ml/base.py +++ b/bigframes/ml/base.py @@ -90,6 +90,7 @@ def __repr__(self): return prettyprinter.pformat(self) +# TODO(garrettwu): refactor to reflect the actual property. Now the class contains .register() method. class Predictor(BaseEstimator): """A BigQuery DataFrames ML Model base class that can be used to predict outputs.""" diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 79f6b90bfd..10c3cc51b2 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -48,7 +48,7 @@ @log_adapter.class_logger -class PaLM2TextGenerator(base.Predictor): +class PaLM2TextGenerator(base.BaseEstimator): """PaLM2 text generator LLM model. Args: @@ -258,7 +258,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator: @log_adapter.class_logger -class PaLM2TextEmbeddingGenerator(base.Predictor): +class PaLM2TextEmbeddingGenerator(base.BaseEstimator): """PaLM2 text embedding generator LLM model. Args: @@ -418,7 +418,7 @@ def to_gbq( @log_adapter.class_logger -class GeminiTextGenerator(base.Predictor): +class GeminiTextGenerator(base.BaseEstimator): """Gemini text generator LLM model. Args: diff --git a/tests/system/small/ml/test_register.py b/tests/system/small/ml/test_register.py index bcf1f4a5b0..6d8ff0a712 100644 --- a/tests/system/small/ml/test_register.py +++ b/tests/system/small/ml/test_register.py @@ -14,6 +14,8 @@ from typing import cast +import pytest + from bigframes.ml import core, imported, linear_model, llm @@ -54,19 +56,8 @@ def test_linear_reg_register_with_params( def test_palm2_text_generator_register( ephemera_palm2_text_generator_model: llm.PaLM2TextGenerator, ): - model = ephemera_palm2_text_generator_model - model.register() - - model_name = "bigframes_" + cast( - str, cast(core.BqmlModel, model._bqml_model).model.model_id - ) - # Only registered model contains the field, and the field includes project/dataset. Here only check model_id. - assert ( - model_name[:63] # truncated - in cast(core.BqmlModel, model._bqml_model).model.training_runs[-1][ - "vertexAiModelId" - ] - ) + with pytest.raises(AttributeError): + ephemera_palm2_text_generator_model.register() # type: ignore def test_imported_tensorflow_register(