diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py
index dca7e555f6..9b7228fe83 100644
--- a/bigframes/ml/llm.py
+++ b/bigframes/ml/llm.py
@@ -79,6 +79,11 @@
     _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT,
     _GEMINI_2_FLASH_EXP_ENDPOINT,
 )
+_GEMINI_FINE_TUNE_SCORE_ENDPOINTS = (
+    _GEMINI_PRO_ENDPOINT,
+    _GEMINI_1P5_PRO_002_ENDPOINT,
+    _GEMINI_1P5_FLASH_002_ENDPOINT,
+)
 
 _CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet"
 _CLAUDE_3_HAIKU_ENDPOINT = "claude-3-haiku"
@@ -890,7 +895,8 @@ def fit(
         X: utils.ArrayType,
         y: utils.ArrayType,
     ) -> GeminiTextGenerator:
-        """Fine tune GeminiTextGenerator model. Only support "gemini-pro" model for now.
+        """Fine tune GeminiTextGenerator model. Only supports "gemini-pro", "gemini-1.5-pro-002",
+        and "gemini-1.5-flash-002" models for now.
 
         .. note::
 
@@ -908,13 +914,18 @@ def fit(
         Returns:
             GeminiTextGenerator: Fitted estimator.
         """
-        if self._bqml_model.model_name.startswith("gemini-1.5"):
-            raise NotImplementedError("Fit is not supported for gemini-1.5 model.")
+        if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS:
+            raise NotImplementedError(
+                "fit() only supports gemini-pro, "
+                "gemini-1.5-pro-002, or gemini-1.5-flash-002 models."
+            )
 
         X, y = utils.batch_convert_to_dataframe(X, y)
 
         options = self._bqml_options
-        options["endpoint"] = "gemini-1.0-pro-002"
+        options["endpoint"] = (
+            "gemini-1.0-pro-002" if self.model_name == "gemini-pro" else self.model_name
+        )
         options["prompt_col"] = X.columns.tolist()[0]
 
         self._bqml_model = self._bqml_model_factory.create_llm_remote_model(
@@ -1025,7 +1036,7 @@ def score(
             "text_generation", "classification", "summarization", "question_answering"
         ] = "text_generation",
     ) -> bpd.DataFrame:
-        """Calculate evaluation metrics of the model. Only "gemini-pro" model is supported for now.
+        """Calculate evaluation metrics of the model. Only supports "gemini-pro", "gemini-1.5-pro-002", and "gemini-1.5-flash-002" models for now.
 
         .. note::
 
@@ -1057,9 +1068,11 @@ def score(
         if not self._bqml_model:
             raise RuntimeError("A model must be fitted before score")
 
-        # TODO(ashleyxu): Support gemini-1.5 when the rollout is ready. b/344891364.
-        if self._bqml_model.model_name.startswith("gemini-1.5"):
-            raise NotImplementedError("Score is not supported for gemini-1.5 model.")
+        if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS:
+            raise NotImplementedError(
+                "score() only supports gemini-pro, "
+                "gemini-1.5-pro-002, and gemini-1.5-flash-002 models."
+            )
 
         X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session)
 
diff --git a/tests/system/load/test_llm.py b/tests/system/load/test_llm.py
index 9ef60bae0b..45dd1667a6 100644
--- a/tests/system/load/test_llm.py
+++ b/tests/system/load/test_llm.py
@@ -38,12 +38,19 @@ def llm_remote_text_df(session, llm_remote_text_pandas_df):
     return session.read_pandas(llm_remote_text_pandas_df)
 
 
-@pytest.mark.flaky(retries=2)
+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-pro",
+        "gemini-1.5-pro-002",
+        "gemini-1.5-flash-002",
+    ),
+)
 def test_llm_gemini_configure_fit(
-    session, llm_fine_tune_df_default_index, llm_remote_text_df
+    session, model_name, llm_fine_tune_df_default_index, llm_remote_text_df
 ):
     model = llm.GeminiTextGenerator(
-        session=session, model_name="gemini-pro", max_iterations=1
+        session=session, model_name=model_name, max_iterations=1
     )
 
     X_train = llm_fine_tune_df_default_index[["prompt"]]
@@ -69,7 +76,6 @@ def test_llm_gemini_configure_fit(
         ],
         index=3,
     )
-    # TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept
 
 
 @pytest.mark.flaky(retries=2)
diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py
index 1690e8ab4c..a82be38017 100644
--- a/tests/system/small/ml/test_llm.py
+++ b/tests/system/small/ml/test_llm.py
@@ -417,9 +417,16 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index):
     )
 
 
-@pytest.mark.flaky(retries=2)
-def test_llm_gemini_pro_score(llm_fine_tune_df_default_index):
-    model = llm.GeminiTextGenerator(model_name="gemini-pro")
+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-pro",
+        "gemini-1.5-pro-002",
+        "gemini-1.5-flash-002",
+    ),
+)
+def test_llm_gemini_score(llm_fine_tune_df_default_index, model_name):
+    model = llm.GeminiTextGenerator(model_name=model_name)
 
     # Check score to ensure the model was fitted
     score_result = model.score(
@@ -439,9 +446,16 @@ def test_llm_gemini_pro_score(llm_fine_tune_df_default_index):
     )
 
 
-@pytest.mark.flaky(retries=2)
-def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index):
-    model = llm.GeminiTextGenerator(model_name="gemini-pro")
+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-pro",
+        "gemini-1.5-pro-002",
+        "gemini-1.5-flash-002",
+    ),
+)
+def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index, model_name):
+    model = llm.GeminiTextGenerator(model_name=model_name)
 
     # Check score to ensure the model was fitted
     score_result = model.score(
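For context (not part of the patch): a minimal usage sketch of the newly supported fine-tune and score endpoints. The table path and the "prompt"/"label" column names are assumptions, borrowed from the llm_fine_tune_df_default_index fixture used in the tests above.

    import bigframes.pandas as bpd
    from bigframes.ml import llm

    # Hypothetical fine-tuning table with "prompt" and "label" columns.
    df = bpd.read_gbq("my-project.my_dataset.fine_tune_table")

    for model_name in ("gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"):
        # fit() maps "gemini-pro" to the "gemini-1.0-pro-002" tuning endpoint
        # and passes the 1.5 model names through unchanged.
        model = llm.GeminiTextGenerator(model_name=model_name, max_iterations=1)
        model.fit(X=df[["prompt"]], y=df[["label"]])
        score_result = model.score(
            X=df[["prompt"]], y=df[["label"]], task_type="text_generation"
        )
        print(score_result.to_pandas())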