Skip to content

Commit

Permalink
feat: Add chatrole tests and meta for GeminiChatGenerators (#1090)
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 authored Sep 24, 2024
1 parent 81be502 commit bca32be
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -311,17 +311,25 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess
:param response_body: The response from Google AI request.
:returns: The extracted responses.
"""
replies = []
for candidate in response_body.candidates:
replies: List[ChatMessage] = []
metadata = response_body.to_dict()
for idx, candidate in enumerate(response_body.candidates):
candidate_metadata = metadata["candidates"][idx]
candidate_metadata.pop("content", None) # we remove content from the metadata

for part in candidate.content.parts:
if part.text != "":
replies.append(ChatMessage.from_assistant(part.text))
elif part.function_call is not None:
replies.append(
ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata)
)
elif part.function_call:
candidate_metadata["function_call"] = part.function_call
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=candidate_metadata,
)
)
return replies
Expand All @@ -336,11 +344,26 @@ def _get_stream_response(
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
responses = []
replies: List[ChatMessage] = []
for chunk in stream:
content = chunk.text if len(chunk.parts) > 0 and "text" in chunk.parts[0] else ""
streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict()))
responses.append(content)
content: Union[str, Dict[str, Any]] = ""
metadata = chunk.to_dict() # we store whole chunk as metadata in streaming calls
for candidate in chunk.candidates:
for part in candidate.content.parts:
if part.text != "":
content = part.text
replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None))
elif part.function_call is not None:
metadata["function_call"] = part.function_call
content = dict(part.function_call.args.items())
replies.append(
ChatMessage(
content=content,
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
)

combined_response = "".join(responses).lstrip()
return [ChatMessage.from_assistant(content=combined_response)]
streaming_callback(StreamingChunk(content=content, meta=metadata))
return replies
60 changes: 47 additions & 13 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool
from haystack.dataclasses import StreamingChunk
from haystack.dataclasses.chat_message import ChatMessage
from haystack.dataclasses.chat_message import ChatMessage, ChatRole

from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator

Expand Down Expand Up @@ -207,22 +207,35 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
get_current_weather_func = FunctionDeclaration.from_function(
get_current_weather,
descriptions={
"location": "The city and state, e.g. San Francisco, CA",
"location": "The city, e.g. San Francisco",
"unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
},
)

tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool])
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
res = gemini_chat.run(messages=messages)
assert len(res["replies"]) > 0
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

weather = get_current_weather(**res["replies"][0].content)
messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
# check the first response is a function call
chat_message = response["replies"][0]
assert "function_call" in chat_message.meta
assert chat_message.content == {"location": "Berlin", "unit": "celsius"}

res = gemini_chat.run(messages=messages)
assert len(res["replies"]) > 0
weather = get_current_weather(**chat_message.content)
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.content, str)


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
Expand All @@ -239,18 +252,37 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
get_current_weather_func = FunctionDeclaration.from_function(
get_current_weather,
descriptions={
"location": "The city and state, e.g. San Francisco, CA",
"location": "The city, e.g. San Francisco",
"unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
},
)

tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback)
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
res = gemini_chat.run(messages=messages)
assert len(res["replies"]) > 0
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
assert streaming_callback_called

# check the first response is a function call
chat_message = response["replies"][0]
assert "function_call" in chat_message.meta
assert chat_message.content == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**response["replies"][0].content)
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.content, str)


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
def test_past_conversation():
Expand All @@ -261,5 +293,7 @@ def test_past_conversation():
ChatMessage.from_assistant(content="It's an arithmetic operation."),
ChatMessage.from_user(content="Yeah, but what's the result?"),
]
res = gemini_chat.run(messages=messages)
assert len(res["replies"]) > 0
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,24 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
:param response_body: The response from Vertex AI request.
:returns: The extracted responses.
"""
replies = []
replies: List[ChatMessage] = []
for candidate in response_body.candidates:
metadata = candidate.to_dict()
for part in candidate.content.parts:
# Remove content from metadata
metadata.pop("content", None)
if part._raw_part.text != "":
replies.append(ChatMessage.from_assistant(part.text))
elif part.function_call is not None:
replies.append(
ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)
)
elif part.function_call:
metadata["function_call"] = part.function_call
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
)
return replies
Expand All @@ -254,11 +261,27 @@ def _get_stream_response(
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
responses = []
replies: List[ChatMessage] = []

for chunk in stream:
streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict())
streaming_callback(streaming_chunk)
responses.append(streaming_chunk.content)
content: Union[str, Dict[str, Any]] = ""
metadata = chunk.to_dict() # we store whole chunk as metadata for streaming
for candidate in chunk.candidates:
for part in candidate.content.parts:
if part._raw_part.text:
content = chunk.text
replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata))
elif part.function_call:
metadata["function_call"] = part.function_call
content = dict(part.function_call.args.items())
replies.append(
ChatMessage(
content=content,
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
)
streaming_callback(StreamingChunk(content=content, meta=metadata))

combined_response = "".join(responses).lstrip()
return [ChatMessage.from_assistant(content=combined_response)]
return replies
18 changes: 10 additions & 8 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
from vertexai.generative_models import (
Content,
FunctionDeclaration,
Expand Down Expand Up @@ -249,9 +249,12 @@ def test_run(mock_generative_model):
ChatMessage.from_user("What's the capital of France?"),
]
gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None)
gemini.run(messages=messages)
response = gemini.run(messages=messages)

mock_model.send_message.assert_called_once()
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
Expand All @@ -260,25 +263,24 @@ def test_run_with_streaming_callback(mock_generative_model):
mock_responses = iter(
[MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")]
)

mock_model.send_message.return_value = mock_responses
mock_model.start_chat.return_value = mock_model
mock_generative_model.return_value = mock_model

streaming_callback_called = []

def streaming_callback(chunk: StreamingChunk) -> None:
streaming_callback_called.append(chunk.content)
def streaming_callback(_chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True

gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback)
messages = [
ChatMessage.from_system("You are a helpful assistant"),
ChatMessage.from_user("What's the capital of France?"),
]
gemini.run(messages=messages)

response = gemini.run(messages=messages)
mock_model.send_message.assert_called_once()
assert streaming_callback_called == ["First part", "Second part"]
assert "replies" in response


def test_serialization_deserialization_pipeline():
Expand Down

0 comments on commit bca32be

Please sign in to comment.