From f187841fda5589e56d62d104cc13577ce252aa29 Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Thu, 28 Mar 2024 18:30:27 +0000 Subject: [PATCH 1/2] feat: support ML.GENERATE_EMBEDDING in PaLM2TextEmbeddingGenerator --- bigframes/ml/core.py | 4 ++-- bigframes/ml/llm.py | 4 ++-- bigframes/ml/sql.py | 6 +++--- tests/system/small/ml/test_llm.py | 8 ++++---- tests/unit/ml/test_sql.py | 6 +++--- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index 03d9b806b9..04aaeec1bc 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -152,14 +152,14 @@ def generate_text( ), ) - def generate_text_embedding( + def generate_embedding( self, input_data: bpd.DataFrame, options: Mapping[str, int | float], ) -> bpd.DataFrame: return self._apply_sql( input_data, - lambda source_df: self._model_manipulation_sql_generator.ml_generate_text_embedding( + lambda source_df: self._model_manipulation_sql_generator.ml_generate_embedding( source_df=source_df, struct_options=options, ), diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 6c4ae2ea43..9335a247c0 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -44,7 +44,7 @@ _GEMINI_PRO_ENDPOINT = "gemini-pro" _ML_GENERATE_TEXT_STATUS = "ml_generate_text_status" -_ML_EMBED_TEXT_STATUS = "ml_embed_text_status" +_ML_EMBED_TEXT_STATUS = "ml_generate_embedding_status" @log_adapter.class_logger @@ -389,7 +389,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame: "flatten_json_output": True, } - df = self._bqml_model.generate_text_embedding(X, options) + df = self._bqml_model.generate_embedding(X, options) if (df[_ML_EMBED_TEXT_STATUS] != "").any(): warnings.warn( diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index 807fadc06a..fab358cce3 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -270,12 +270,12 @@ def ml_generate_text( return f"""SELECT * FROM ML.GENERATE_TEXT(MODEL `{self._model_name}`, ({self._source_sql(source_df)}), {struct_options_sql})""" - def ml_generate_text_embedding( + def ml_generate_embedding( self, source_df: bpd.DataFrame, struct_options: Mapping[str, Union[int, float]] ) -> str: - """Encode ML.GENERATE_TEXT_EMBEDDING for BQML""" + """Encode ML.GENERATE_EMBEDDING for BQML""" struct_options_sql = self.struct_options(**struct_options) - return f"""SELECT * FROM ML.GENERATE_TEXT_EMBEDDING(MODEL `{self._model_name}`, + return f"""SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `{self._model_name}`, ({self._source_sql(source_df)}), {struct_options_sql})""" def ml_detect_anomalies( diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 4d2ddfe513..f2c96560c5 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -261,8 +261,8 @@ def test_embedding_generator_predict_success( ): df = palm2_embedding_generator_model.predict(llm_text_df).to_pandas() assert df.shape == (3, 4) - assert "text_embedding" in df.columns - series = df["text_embedding"] + assert "ml_generate_embedding_result" in df.columns + series = df["ml_generate_embedding_result"] value = series[0] assert len(value) == 768 @@ -273,8 +273,8 @@ def test_embedding_generator_multilingual_predict_success( ): df = palm2_embedding_generator_multilingual_model.predict(llm_text_df).to_pandas() assert df.shape == (3, 4) - assert "text_embedding" in df.columns - series = df["text_embedding"] + assert "ml_generate_embedding_result" in df.columns + series = df["ml_generate_embedding_result"] value = series[0] assert len(value) == 768 diff --git a/tests/unit/ml/test_sql.py b/tests/unit/ml/test_sql.py index 913bab0379..5b1ff37775 100644 --- a/tests/unit/ml/test_sql.py +++ b/tests/unit/ml/test_sql.py @@ -373,17 +373,17 @@ def test_ml_generate_text_correct( ) -def test_ml_generate_text_embedding_correct( +def test_ml_generate_embedding_correct( model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator, mock_df: bpd.DataFrame, ): - sql = model_manipulation_sql_generator.ml_generate_text_embedding( + sql = model_manipulation_sql_generator.ml_generate_embedding( source_df=mock_df, struct_options={"option_key1": 1, "option_key2": 2.2}, ) assert ( sql - == """SELECT * FROM ML.GENERATE_TEXT_EMBEDDING(MODEL `my_project_id.my_dataset_id.my_model_id`, + == """SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project_id.my_dataset_id.my_model_id`, (input_X_sql), STRUCT( 1 AS option_key1, 2.2 AS option_key2))""" From a432b9df81cf6f54662c7e7806368670fc886288 Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Fri, 29 Mar 2024 04:09:39 +0000 Subject: [PATCH 2/2] fix failed test --- tests/system/small/ml/test_llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index f2c96560c5..2e135bef7b 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -285,8 +285,8 @@ def test_embedding_generator_predict_series_success( ): df = palm2_embedding_generator_model.predict(llm_text_df["prompt"]).to_pandas() assert df.shape == (3, 4) - assert "text_embedding" in df.columns - series = df["text_embedding"] + assert "ml_generate_embedding_result" in df.columns + series = df["ml_generate_embedding_result"] value = series[0] assert len(value) == 768