Skip to content

Commit

Permalink
feat: add Gemini 1.5 stable models support (#945)
Browse files Browse the repository at this point in the history
* feat: add Gemini 1.5 stable models support

* add to loader
  • Loading branch information
GarrettWu authored and arwas11 committed Sep 9, 2024
1 parent a7985d2 commit 6701dd7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 6 deletions.
12 changes: 10 additions & 2 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,14 @@
_GEMINI_PRO_ENDPOINT = "gemini-pro"
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
_GEMINI_1P5_PRO_001_ENDPOINT = "gemini-1.5-pro-001"
_GEMINI_1P5_FLASH_001_ENDPOINT = "gemini-1.5-flash-001"
_GEMINI_ENDPOINTS = (
_GEMINI_PRO_ENDPOINT,
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT,
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT,
_GEMINI_1P5_PRO_001_ENDPOINT,
_GEMINI_1P5_FLASH_001_ENDPOINT,
)

_CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet"
Expand Down Expand Up @@ -728,7 +732,7 @@ class GeminiTextGenerator(base.BaseEstimator):
Args:
model_name (str, Default to "gemini-pro"):
The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514". Default to "gemini-pro".
The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514", "gemini-1.5-pro-001" and "gemini-1.5-flash-001". Default to "gemini-pro".
.. note::
"gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514" is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
Expand All @@ -750,7 +754,11 @@ def __init__(
self,
*,
model_name: Literal[
"gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"
"gemini-pro",
"gemini-1.5-pro-preview-0514",
"gemini-1.5-flash-preview-0514",
"gemini-1.5-pro-001",
"gemini-1.5-flash-001",
] = "gemini-pro",
session: Optional[bigframes.Session] = None,
connection_name: Optional[str] = None,
Expand Down
2 changes: 2 additions & 0 deletions bigframes/ml/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator,
llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
llm._GEMINI_1P5_PRO_001_ENDPOINT: llm.GeminiTextGenerator,
llm._GEMINI_1P5_FLASH_001_ENDPOINT: llm.GeminiTextGenerator,
llm._CLAUDE_3_HAIKU_ENDPOINT: llm.Claude3TextGenerator,
llm._CLAUDE_3_SONNET_ENDPOINT: llm.Claude3TextGenerator,
llm._CLAUDE_3_5_SONNET_ENDPOINT: llm.Claude3TextGenerator,
Expand Down
26 changes: 22 additions & 4 deletions tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def test_create_load_text_embedding_generator_model(
("text-embedding-004", "text-multilingual-embedding-002"),
)
@pytest.mark.flaky(retries=2)
def test_gemini_text_embedding_generator_predict_default_params_success(
def test_text_embedding_generator_predict_default_params_success(
llm_text_df, model_name, session, bq_connection
):
text_embedding_model = llm.TextEmbeddingGenerator(
Expand All @@ -340,7 +340,13 @@ def test_gemini_text_embedding_generator_predict_default_params_success(

@pytest.mark.parametrize(
"model_name",
("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
(
"gemini-pro",
"gemini-1.5-pro-preview-0514",
"gemini-1.5-flash-preview-0514",
"gemini-1.5-pro-001",
"gemini-1.5-flash-001",
),
)
def test_create_load_gemini_text_generator_model(
dataset_id, model_name, session, bq_connection
Expand All @@ -362,7 +368,13 @@ def test_create_load_gemini_text_generator_model(

@pytest.mark.parametrize(
"model_name",
("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
(
"gemini-pro",
"gemini-1.5-pro-preview-0514",
"gemini-1.5-flash-preview-0514",
"gemini-1.5-pro-001",
"gemini-1.5-flash-001",
),
)
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_predict_default_params_success(
Expand All @@ -379,7 +391,13 @@ def test_gemini_text_generator_predict_default_params_success(

@pytest.mark.parametrize(
"model_name",
("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
(
"gemini-pro",
"gemini-1.5-pro-preview-0514",
"gemini-1.5-flash-preview-0514",
"gemini-1.5-pro-001",
"gemini-1.5-flash-001",
),
)
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_predict_with_params_success(
Expand Down

0 comments on commit 6701dd7

Please sign in to comment.