Skip to content

Commit

Permalink
feat: support ML.GENERATE_EMBEDDING in PaLM2TextEmbeddingGenerator (
Browse files Browse the repository at this point in the history
#539)

* feat: support ML.GENERATE_EMBEDDING in PaLM2TextEmbeddingGenerator
  • Loading branch information
ashleyxuu authored and Genesis929 committed Apr 9, 2024
1 parent 7a426d8 commit 68607ad
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ def generate_text(
),
)

def generate_text_embedding(
def generate_embedding(
self,
input_data: bpd.DataFrame,
options: Mapping[str, int | float],
) -> bpd.DataFrame:
return self._apply_sql(
input_data,
lambda source_df: self._model_manipulation_sql_generator.ml_generate_text_embedding(
lambda source_df: self._model_manipulation_sql_generator.ml_generate_embedding(
source_df=source_df,
struct_options=options,
),
Expand Down
4 changes: 2 additions & 2 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
_GEMINI_PRO_ENDPOINT = "gemini-pro"

_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
_ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
_ML_EMBED_TEXT_STATUS = "ml_generate_embedding_status"


@log_adapter.class_logger
Expand Down Expand Up @@ -389,7 +389,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
"flatten_json_output": True,
}

df = self._bqml_model.generate_text_embedding(X, options)
df = self._bqml_model.generate_embedding(X, options)

if (df[_ML_EMBED_TEXT_STATUS] != "").any():
warnings.warn(
Expand Down
6 changes: 3 additions & 3 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,12 @@ def ml_generate_text(
return f"""SELECT * FROM ML.GENERATE_TEXT(MODEL `{self._model_name}`,
({self._source_sql(source_df)}), {struct_options_sql})"""

def ml_generate_text_embedding(
def ml_generate_embedding(
self, source_df: bpd.DataFrame, struct_options: Mapping[str, Union[int, float]]
) -> str:
"""Encode ML.GENERATE_TEXT_EMBEDDING for BQML"""
"""Encode ML.GENERATE_EMBEDDING for BQML"""
struct_options_sql = self.struct_options(**struct_options)
return f"""SELECT * FROM ML.GENERATE_TEXT_EMBEDDING(MODEL `{self._model_name}`,
return f"""SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `{self._model_name}`,
({self._source_sql(source_df)}), {struct_options_sql})"""

def ml_detect_anomalies(
Expand Down
12 changes: 6 additions & 6 deletions tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ def test_embedding_generator_predict_success(
):
df = palm2_embedding_generator_model.predict(llm_text_df).to_pandas()
assert df.shape == (3, 4)
assert "text_embedding" in df.columns
series = df["text_embedding"]
assert "ml_generate_embedding_result" in df.columns
series = df["ml_generate_embedding_result"]
value = series[0]
assert len(value) == 768

Expand All @@ -273,8 +273,8 @@ def test_embedding_generator_multilingual_predict_success(
):
df = palm2_embedding_generator_multilingual_model.predict(llm_text_df).to_pandas()
assert df.shape == (3, 4)
assert "text_embedding" in df.columns
series = df["text_embedding"]
assert "ml_generate_embedding_result" in df.columns
series = df["ml_generate_embedding_result"]
value = series[0]
assert len(value) == 768

Expand All @@ -285,8 +285,8 @@ def test_embedding_generator_predict_series_success(
):
df = palm2_embedding_generator_model.predict(llm_text_df["prompt"]).to_pandas()
assert df.shape == (3, 4)
assert "text_embedding" in df.columns
series = df["text_embedding"]
assert "ml_generate_embedding_result" in df.columns
series = df["ml_generate_embedding_result"]
value = series[0]
assert len(value) == 768

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/ml/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,17 +373,17 @@ def test_ml_generate_text_correct(
)


def test_ml_generate_text_embedding_correct(
def test_ml_generate_embedding_correct(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
mock_df: bpd.DataFrame,
):
sql = model_manipulation_sql_generator.ml_generate_text_embedding(
sql = model_manipulation_sql_generator.ml_generate_embedding(
source_df=mock_df,
struct_options={"option_key1": 1, "option_key2": 2.2},
)
assert (
sql
== """SELECT * FROM ML.GENERATE_TEXT_EMBEDDING(MODEL `my_project_id.my_dataset_id.my_model_id`,
== """SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project_id.my_dataset_id.my_model_id`,
(input_X_sql), STRUCT(
1 AS option_key1,
2.2 AS option_key2))"""
Expand Down

0 comments on commit 68607ad

Please sign in to comment.