Skip to content

Commit

Permalink
Extend testing
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudotensor committed Mar 29, 2024
1 parent 54b9fc5 commit e675c6c
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/test_client_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4844,7 +4844,8 @@ def test_max_new_tokens(max_new_tokens, temperature):
@wrap_test_forked
@pytest.mark.parametrize("base_model", vision_models)
@pytest.mark.parametrize("langchain_mode", ['LLM', 'MyData'])
def test_client1_image_qa(langchain_mode, base_model):
@pytest.mark.parametrize("langchain_action", [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value])
def test_client1_image_qa(langchain_action, langchain_mode, base_model):
inference_server = os.getenv('TEST_SERVER', 'https://gpt.h2o.ai')
if inference_server == 'https://gpt.h2o.ai':
auth_kwargs = dict(auth=('guest', 'guest'))
Expand Down Expand Up @@ -4872,6 +4873,7 @@ def test_client1_image_qa(langchain_mode, base_model):
visible_models=base_model,
stream_output=False,
langchain_mode=langchain_mode,
langchain_action=langchain_action,
h2ogpt_key=h2ogpt_key)
try:
res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
Expand All @@ -4897,7 +4899,8 @@ def test_client1_image_qa(langchain_mode, base_model):
@wrap_test_forked
@pytest.mark.parametrize("base_model", vision_models)
@pytest.mark.parametrize("langchain_mode", ['LLM', 'MyData'])
def test_client1_images_qa(langchain_mode, base_model):
@pytest.mark.parametrize("langchain_action", [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value])
def test_client1_images_qa(langchain_action, langchain_mode, base_model):
image_dir = 'pdf_images'
makedirs(image_dir)
os.system('pdftoppm tests/2403.09629.pdf %s/outputname -jpeg' % image_dir)
Expand Down Expand Up @@ -4930,6 +4933,7 @@ def test_client1_images_qa(langchain_mode, base_model):
visible_models=base_model,
stream_output=False,
langchain_mode=langchain_mode,
langchain_action=langchain_action,
h2ogpt_key=h2ogpt_key)
res_dict = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
response = ast.literal_eval(res_dict)['response']
Expand Down

0 comments on commit e675c6c

Please sign in to comment.