diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py
index 35bcf0a33c..a3cd065a55 100644
--- a/bigframes/ml/llm.py
+++ b/bigframes/ml/llm.py
@@ -55,10 +55,14 @@
 _GEMINI_PRO_ENDPOINT = "gemini-pro"
 _GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
 _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
+_GEMINI_1P5_PRO_001_ENDPOINT = "gemini-1.5-pro-001"
+_GEMINI_1P5_FLASH_001_ENDPOINT = "gemini-1.5-flash-001"
 _GEMINI_ENDPOINTS = (
     _GEMINI_PRO_ENDPOINT,
     _GEMINI_1P5_PRO_PREVIEW_ENDPOINT,
     _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT,
+    _GEMINI_1P5_PRO_001_ENDPOINT,
+    _GEMINI_1P5_FLASH_001_ENDPOINT,
 )

 _CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet"
@@ -728,7 +732,7 @@ class GeminiTextGenerator(base.BaseEstimator):

     Args:
         model_name (str, Default to "gemini-pro"):
-            The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514". Default to "gemini-pro".
+            The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514", "gemini-1.5-pro-001" and "gemini-1.5-flash-001". Default to "gemini-pro".

             .. note::
                 "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514" is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
@@ -750,7 +754,11 @@ def __init__(
         self,
         *,
         model_name: Literal[
-            "gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"
+            "gemini-pro",
+            "gemini-1.5-pro-preview-0514",
+            "gemini-1.5-flash-preview-0514",
+            "gemini-1.5-pro-001",
+            "gemini-1.5-flash-001",
         ] = "gemini-pro",
         session: Optional[bigframes.Session] = None,
         connection_name: Optional[str] = None,
diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py
index 7d75f4c65a..4e7e808260 100644
--- a/bigframes/ml/loader.py
+++ b/bigframes/ml/loader.py
@@ -63,6 +63,8 @@
     llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator,
     llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
     llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
+    llm._GEMINI_1P5_PRO_001_ENDPOINT: llm.GeminiTextGenerator,
+    llm._GEMINI_1P5_FLASH_001_ENDPOINT: llm.GeminiTextGenerator,
     llm._CLAUDE_3_HAIKU_ENDPOINT: llm.Claude3TextGenerator,
     llm._CLAUDE_3_SONNET_ENDPOINT: llm.Claude3TextGenerator,
     llm._CLAUDE_3_5_SONNET_ENDPOINT: llm.Claude3TextGenerator,
diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py
index 43e756019d..e3d2b51081 100644
--- a/tests/system/small/ml/test_llm.py
+++ b/tests/system/small/ml/test_llm.py
@@ -324,7 +324,7 @@ def test_create_load_text_embedding_generator_model(
     ("text-embedding-004", "text-multilingual-embedding-002"),
 )
 @pytest.mark.flaky(retries=2)
-def test_gemini_text_embedding_generator_predict_default_params_success(
+def test_text_embedding_generator_predict_default_params_success(
     llm_text_df, model_name, session, bq_connection
 ):
     text_embedding_model = llm.TextEmbeddingGenerator(
@@ -340,7 +340,13 @@

 @pytest.mark.parametrize(
     "model_name",
-    ("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
+    (
+        "gemini-pro",
+        "gemini-1.5-pro-preview-0514",
+        "gemini-1.5-flash-preview-0514",
+        "gemini-1.5-pro-001",
+        "gemini-1.5-flash-001",
+    ),
 )
 def test_create_load_gemini_text_generator_model(
     dataset_id, model_name, session, bq_connection
@@ -362,7 +368,13 @@

 @pytest.mark.parametrize(
     "model_name",
-    ("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
+    (
+        "gemini-pro",
+        "gemini-1.5-pro-preview-0514",
+        "gemini-1.5-flash-preview-0514",
+        "gemini-1.5-pro-001",
+        "gemini-1.5-flash-001",
+    ),
 )
 @pytest.mark.flaky(retries=2)
 def test_gemini_text_generator_predict_default_params_success(
@@ -379,7 +391,13 @@

 @pytest.mark.parametrize(
     "model_name",
-    ("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
+    (
+        "gemini-pro",
+        "gemini-1.5-pro-preview-0514",
+        "gemini-1.5-flash-preview-0514",
+        "gemini-1.5-pro-001",
+        "gemini-1.5-flash-001",
+    ),
 )
 @pytest.mark.flaky(retries=2)
 def test_gemini_text_generator_predict_with_params_success(