Skip to content

Commit

Permalink
fix vertex (#1242)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Dec 11, 2024
1 parent adba166 commit 0811b3b
Showing 1 changed file with 20 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

Expand Down Expand Up @@ -41,7 +42,7 @@ class VertexAIGeminiChatGenerator:
messages = [ChatMessage.from_user("Tell me the name of a movie")]
res = gemini_chat.run(messages)
print(res["replies"][0].content)
print(res["replies"][0].text)
>>> The Shawshank Redemption
```
"""
Expand Down Expand Up @@ -209,31 +210,31 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
def _message_to_part(self, message: ChatMessage) -> Part:
if message.role == ChatRole.ASSISTANT and message.name:
p = Part.from_dict({"function_call": {"name": message.name, "args": {}}})
for k, v in message.content.items():
for k, v in json.loads(message.text).items():
p.function_call.args[k] = v
return p
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
return Part.from_text(message.content)
elif message.role == ChatRole.FUNCTION:
return Part.from_function_response(name=message.name, response=message.content)
elif message.role == ChatRole.USER:
return self._convert_part(message.content)
elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT):
return Part.from_text(message.text)
elif message.is_from(ChatRole.FUNCTION):
return Part.from_function_response(name=message.name, response=message.text)
elif message.is_from(ChatRole.USER):
return self._convert_part(message.text)

def _message_to_content(self, message: ChatMessage) -> Content:
if message.role == ChatRole.ASSISTANT and message.name:
if message.is_from(ChatRole.ASSISTANT) and message.name:
part = Part.from_dict({"function_call": {"name": message.name, "args": {}}})
for k, v in message.content.items():
for k, v in json.loads(message.text).items():
part.function_call.args[k] = v
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
part = Part.from_text(message.content)
elif message.role == ChatRole.FUNCTION:
part = Part.from_function_response(name=message.name, response=message.content)
elif message.role == ChatRole.USER:
part = self._convert_part(message.content)
elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT):
part = Part.from_text(message.text)
elif message.is_from(ChatRole.FUNCTION):
part = Part.from_function_response(name=message.name, response=message.text)
elif message.is_from(ChatRole.USER):
part = self._convert_part(message.text)
else:
msg = f"Unsupported message role {message.role}"
raise ValueError(msg)
role = "user" if message.role in [ChatRole.USER, ChatRole.FUNCTION] else "model"
role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model"
return Content(parts=[part], role=role)

@component.output_types(replies=List[ChatMessage])
Expand Down Expand Up @@ -283,7 +284,7 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
elif part.function_call:
metadata["function_call"] = part.function_call
new_message = ChatMessage.from_assistant(
content=dict(part.function_call.args.items()), meta=metadata
content=json.dumps(dict(part.function_call.args)), meta=metadata
)
new_message.name = part.function_call.name
replies.append(new_message)
Expand Down Expand Up @@ -311,7 +312,7 @@ def _get_stream_response(
replies.append(ChatMessage.from_assistant(content, meta=metadata))
elif part.function_call:
metadata["function_call"] = part.function_call
content = dict(part.function_call.args.items())
content = json.dumps(dict(part.function_call.args))
new_message = ChatMessage.from_assistant(content, meta=metadata)
new_message.name = part.function_call.name
replies.append(new_message)
Expand Down

0 comments on commit 0811b3b

Please sign in to comment.