From 2f12690ba6fae91168992ddaffc0a228ee49bc79 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 17 Oct 2024 16:21:26 +0200 Subject: [PATCH] fix: make "project-id" parameter optional during initialization (#1141) * Make project-id param optional --- .../generators/google_vertex/chat/gemini.py | 6 +++--- .../generators/google_vertex/gemini.py | 6 +++--- .../google_vertex/tests/chat/test_gemini.py | 19 +++++++++-------- .../google_vertex/tests/test_gemini.py | 21 +++++++++++-------- 4 files changed, 28 insertions(+), 24 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index f09692daf..c52f76dc6 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -36,7 +36,7 @@ class VertexAIGeminiChatGenerator: from haystack.dataclasses import ChatMessage from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator - gemini_chat = VertexAIGeminiChatGenerator(project_id=project_id) + gemini_chat = VertexAIGeminiChatGenerator() messages = [ChatMessage.from_user("Tell me the name of a movie")] res = gemini_chat.run(messages) @@ -50,7 +50,7 @@ def __init__( self, *, model: str = "gemini-1.5-flash", - project_id: str, + project_id: Optional[str] = None, location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, @@ -65,7 +65,7 @@ def __init__( Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 2b1c1b477..737f2e668 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -32,7 +32,7 @@ class VertexAIGeminiGenerator: from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator - gemini = VertexAIGeminiGenerator(project_id=project_id) + gemini = VertexAIGeminiGenerator() result = gemini.run(parts = ["What is the most interesting thing you know?"]) for answer in result["replies"]: print(answer) @@ -54,7 +54,7 @@ def __init__( self, *, model: str = "gemini-1.5-flash", - project_id: str, + project_id: Optional[str] = None, location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, @@ -69,7 +69,7 @@ def __init__( Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. :param generation_config: The generation config to use. diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 6b1308dab..0d77bd9c6 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -90,14 +90,12 @@ def test_init(mock_vertexai_init, _mock_generative_model): @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): - gemini = VertexAIGeminiChatGenerator( - project_id="TestID123", - ) + gemini = VertexAIGeminiChatGenerator() assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { "model": "gemini-1.5-flash", - "project_id": "TestID123", + "project_id": None, "location": None, "generation_config": None, "safety_settings": None, @@ -132,6 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): gemini = VertexAIGeminiChatGenerator( project_id="TestID123", + location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, tools=[tool], @@ -144,7 +143,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "init_parameters": { "model": "gemini-1.5-flash", "project_id": "TestID123", - "location": None, + "location": "TestLocation", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -194,7 +193,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): { "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { - "project_id": "TestID123", + "project_id": None, "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, @@ -205,7 +204,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): ) assert gemini._model_name == "gemini-1.5-flash" - assert gemini._project_id == "TestID123" + assert gemini._project_id is None assert gemini._safety_settings is None assert gemini._tools is None assert gemini._tool_config is None @@ -221,6 +220,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { "project_id": "TestID123", + "location": "TestLocation", "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, @@ -272,6 +272,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._model_name == "gemini-1.5-flash" assert gemini._project_id == "TestID123" + assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._tool_config, ToolConfig) @@ -296,7 +297,7 @@ def test_run(mock_generative_model): ChatMessage.from_system("You are a helpful assistant"), ChatMessage.from_user("What's the capital of France?"), ] - gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None) + gemini = VertexAIGeminiChatGenerator() response = gemini.run(messages=messages) mock_model.send_message.assert_called_once() @@ -321,7 +322,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None: nonlocal streaming_callback_called streaming_callback_called = True - gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback) + gemini = VertexAIGeminiChatGenerator(streaming_callback=streaming_callback) messages = [ ChatMessage.from_system("You are a helpful assistant"), ChatMessage.from_user("What's the capital of France?"), diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 9ec3529d7..b3d6dd5f5 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -78,14 +78,12 @@ def test_init(mock_vertexai_init, _mock_generative_model): @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): - gemini = VertexAIGeminiGenerator( - project_id="TestID123", - ) + gemini = VertexAIGeminiGenerator() assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { "model": "gemini-1.5-flash", - "project_id": "TestID123", + "project_id": None, "location": None, "generation_config": None, "safety_settings": None, @@ -120,6 +118,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): gemini = VertexAIGeminiGenerator( project_id="TestID123", + location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, tools=[tool], @@ -131,7 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "init_parameters": { "model": "gemini-1.5-flash", "project_id": "TestID123", - "location": None, + "location": "TestLocation", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -181,7 +180,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { - "project_id": "TestID123", + "project_id": None, + "location": None, "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, @@ -194,7 +194,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): ) assert gemini._model_name == "gemini-1.5-flash" - assert gemini._project_id == "TestID123" + assert gemini._project_id is None + assert gemini._location is None assert gemini._safety_settings is None assert gemini._tools is None assert gemini._tool_config is None @@ -210,6 +211,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { "project_id": "TestID123", + "location": "TestLocation", "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, @@ -261,6 +263,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._model_name == "gemini-1.5-flash" assert gemini._project_id == "TestID123" + assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) @@ -277,7 +280,7 @@ def test_run(mock_generative_model): mock_model.generate_content.return_value = MagicMock() mock_generative_model.return_value = mock_model - gemini = VertexAIGeminiGenerator(project_id="TestID123", location=None) + gemini = VertexAIGeminiGenerator() response = gemini.run(["What's the weather like today?"]) @@ -303,7 +306,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None: nonlocal streaming_callback_called streaming_callback_called = True - gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback) + gemini = VertexAIGeminiGenerator(model="gemini-pro", streaming_callback=streaming_callback) gemini.run(["Come on, stream!"]) assert streaming_callback_called