Skip to content

Commit

Permalink
fix: make "project-id" parameter optional during initialization (#1141)
Browse files Browse the repository at this point in the history
* Make project-id param optional
  • Loading branch information
Amnah199 authored Oct 17, 2024
1 parent 1e1b178 commit 2f12690
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class VertexAIGeminiChatGenerator:
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator
gemini_chat = VertexAIGeminiChatGenerator(project_id=project_id)
gemini_chat = VertexAIGeminiChatGenerator()
messages = [ChatMessage.from_user("Tell me the name of a movie")]
res = gemini_chat.run(messages)
Expand All @@ -50,7 +50,7 @@ def __init__(
self,
*,
model: str = "gemini-1.5-flash",
project_id: str,
project_id: Optional[str] = None,
location: Optional[str] = None,
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
Expand All @@ -65,7 +65,7 @@ def __init__(
    Authenticates using Google Cloud Application Default Credentials (ADC).
For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc).
:param project_id: ID of the GCP project to use.
:param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication.
:param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.
    :param location: The default location to use when making API calls; if not set, uses us-central1.
Defaults to None.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class VertexAIGeminiGenerator:
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator
gemini = VertexAIGeminiGenerator(project_id=project_id)
gemini = VertexAIGeminiGenerator()
result = gemini.run(parts = ["What is the most interesting thing you know?"])
for answer in result["replies"]:
print(answer)
Expand All @@ -54,7 +54,7 @@ def __init__(
self,
*,
model: str = "gemini-1.5-flash",
project_id: str,
project_id: Optional[str] = None,
location: Optional[str] = None,
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
Expand All @@ -69,7 +69,7 @@ def __init__(
    Authenticates using Google Cloud Application Default Credentials (ADC).
For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc).
:param project_id: ID of the GCP project to use.
:param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication.
:param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.
    :param location: The default location to use when making API calls; if not set, uses us-central1.
:param generation_config: The generation config to use.
Expand Down
19 changes: 10 additions & 9 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,12 @@ def test_init(mock_vertexai_init, _mock_generative_model):
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_to_dict(_mock_vertexai_init, _mock_generative_model):

gemini = VertexAIGeminiChatGenerator(
project_id="TestID123",
)
gemini = VertexAIGeminiChatGenerator()
assert gemini.to_dict() == {
"type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
"init_parameters": {
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"project_id": None,
"location": None,
"generation_config": None,
"safety_settings": None,
Expand Down Expand Up @@ -132,6 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):

gemini = VertexAIGeminiChatGenerator(
project_id="TestID123",
location="TestLocation",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
Expand All @@ -144,7 +143,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
"init_parameters": {
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"location": None,
"location": "TestLocation",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
Expand Down Expand Up @@ -194,7 +193,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
{
"type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
"init_parameters": {
"project_id": "TestID123",
"project_id": None,
"model": "gemini-1.5-flash",
"generation_config": None,
"safety_settings": None,
Expand All @@ -205,7 +204,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
)

assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._project_id is None
assert gemini._safety_settings is None
assert gemini._tools is None
assert gemini._tool_config is None
Expand All @@ -221,6 +220,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
"type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
"init_parameters": {
"project_id": "TestID123",
"location": "TestLocation",
"model": "gemini-1.5-flash",
"generation_config": {
"temperature": 0.5,
Expand Down Expand Up @@ -272,6 +272,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):

assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._location == "TestLocation"
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
assert isinstance(gemini._tool_config, ToolConfig)
Expand All @@ -296,7 +297,7 @@ def test_run(mock_generative_model):
ChatMessage.from_system("You are a helpful assistant"),
ChatMessage.from_user("What's the capital of France?"),
]
gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None)
gemini = VertexAIGeminiChatGenerator()
response = gemini.run(messages=messages)

mock_model.send_message.assert_called_once()
Expand All @@ -321,7 +322,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True

gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback)
gemini = VertexAIGeminiChatGenerator(streaming_callback=streaming_callback)
messages = [
ChatMessage.from_system("You are a helpful assistant"),
ChatMessage.from_user("What's the capital of France?"),
Expand Down
21 changes: 12 additions & 9 deletions integrations/google_vertex/tests/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,12 @@ def test_init(mock_vertexai_init, _mock_generative_model):
@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel")
def test_to_dict(_mock_vertexai_init, _mock_generative_model):

gemini = VertexAIGeminiGenerator(
project_id="TestID123",
)
gemini = VertexAIGeminiGenerator()
assert gemini.to_dict() == {
"type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator",
"init_parameters": {
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"project_id": None,
"location": None,
"generation_config": None,
"safety_settings": None,
Expand Down Expand Up @@ -120,6 +118,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):

gemini = VertexAIGeminiGenerator(
project_id="TestID123",
location="TestLocation",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
Expand All @@ -131,7 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
"init_parameters": {
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"location": None,
"location": "TestLocation",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
Expand Down Expand Up @@ -181,7 +180,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
{
"type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator",
"init_parameters": {
"project_id": "TestID123",
"project_id": None,
"location": None,
"model": "gemini-1.5-flash",
"generation_config": None,
"safety_settings": None,
Expand All @@ -194,7 +194,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
)

assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._project_id is None
assert gemini._location is None
assert gemini._safety_settings is None
assert gemini._tools is None
assert gemini._tool_config is None
Expand All @@ -210,6 +211,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
"type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator",
"init_parameters": {
"project_id": "TestID123",
"location": "TestLocation",
"model": "gemini-1.5-flash",
"generation_config": {
"temperature": 0.5,
Expand Down Expand Up @@ -261,6 +263,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):

assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._location == "TestLocation"
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
assert isinstance(gemini._generation_config, GenerationConfig)
Expand All @@ -277,7 +280,7 @@ def test_run(mock_generative_model):
mock_model.generate_content.return_value = MagicMock()
mock_generative_model.return_value = mock_model

gemini = VertexAIGeminiGenerator(project_id="TestID123", location=None)
gemini = VertexAIGeminiGenerator()

response = gemini.run(["What's the weather like today?"])

Expand All @@ -303,7 +306,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True

gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback)
gemini = VertexAIGeminiGenerator(model="gemini-pro", streaming_callback=streaming_callback)
gemini.run(["Come on, stream!"])
assert streaming_callback_called

Expand Down

0 comments on commit 2f12690

Please sign in to comment.