Skip to content

Commit

Permalink
Support models/gemini-1.5-pro-latest https://ai.google.dev/models/gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudotensor committed Mar 27, 2024
1 parent 5b555a8 commit 2424f45
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
10 changes: 10 additions & 0 deletions src/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,23 @@ class LangChainAgent(Enum):
# gemini-1.0-pro
# Maximum input-context token counts per Google Gemini model name.
# Built from (model, limit) pairs; see https://ai.google.dev/models/gemini
google_mapping = dict([
    ("gemini-pro", 30720),
    ("gemini-1.0-pro-latest", 30720),
    ("gemini-pro-vision", 12288),
    ("gemini-1.0-pro-vision-latest", 12288),
    ("gemini-1.0-ultra-latest", 30720),
    ("gemini-ultra", 30720),
    ("gemini-1.5-pro-latest", 1048576),
])

# FIXME: at least via current API:
# Maximum output token counts per Google Gemini model name.
# Built from (model, limit) pairs; mirrors the keys of google_mapping.
google_mapping_outputs = dict([
    ("gemini-pro", 2048),
    ("gemini-1.0-pro-latest", 2048),
    ("gemini-pro-vision", 4096),
    ("gemini-1.0-pro-vision-latest", 4096),
    ("gemini-1.0-ultra-latest", 2048),
    ("gemini-ultra", 2048),
    ("gemini-1.5-pro-latest", 8192),
])

mistralai_mapping = {
Expand Down
2 changes: 1 addition & 1 deletion src/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def is_vision_model(base_model):
return is_gradio_vision_model(base_model) or \
base_model.startswith('claude-3-') or \
base_model in ['gpt-4-vision-preview', 'gpt-4-1106-vision-preview'] or \
base_model in ["gemini-pro-vision"]
base_model in ["gemini-pro-vision", "gemini-1.0-pro-vision-latest", "gemini-1.5-pro-latest"]


def get_prompt(prompt_type, prompt_dict, context, reduced, making_context, return_dict=False,
Expand Down
7 changes: 5 additions & 2 deletions tests/test_client_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4834,7 +4834,9 @@ def test_max_new_tokens(max_new_tokens, temperature):


@wrap_test_forked
@pytest.mark.parametrize("base_model", ['gpt-4-vision-preview', 'gemini-pro-vision', 'claude-3-haiku-20240307', 'liuhaotian/llava-v1.6-34b', 'liuhaotian/llava-v1.6-vicuna-13b'])
@pytest.mark.parametrize("base_model", ['gpt-4-vision-preview', 'gemini-pro-vision',
'gemini-1.5-pro-latest', 'claude-3-haiku-20240307', 'liuhaotian/llava-v1.6-34b',
'liuhaotian/llava-v1.6-vicuna-13b'])
@pytest.mark.parametrize("langchain_mode", ['LLM', 'MyData'])
def test_client1_image_qa(langchain_mode, base_model):
inference_server = os.getenv('TEST_SERVER', 'https://gpt.h2o.ai')
Expand Down Expand Up @@ -4888,7 +4890,8 @@ def test_client1_images_qa_proprietary():

from src.gen import get_inf_models
base_models = get_inf_models(inference_server)
base_models_touse = ['gemini-pro-vision', 'gpt-4-vision-preview', 'claude-3-haiku-20240307']
base_models_touse = ['gemini-pro-vision', 'gemini-1.5-pro-latest', 'gpt-4-vision-preview',
'claude-3-haiku-20240307']
assert len(set(base_models_touse).difference(set(base_models))) == 0
h2ogpt_key = os.environ['H2OGPT_H2OGPT_KEY']

Expand Down

0 comments on commit 2424f45

Please sign in to comment.