diff --git a/mentat/conversation.py b/mentat/conversation.py index 55a6c9721..397cf71c9 100644 --- a/mentat/conversation.py +++ b/mentat/conversation.py @@ -108,21 +108,20 @@ def add_transcript_message(self, transcript_message: TranscriptMessage): def add_user_message(self, message: str, image: Optional[str] = None): """Used for actual user input messages""" - content: List[ChatCompletionContentPartParam] = [ - { - "type": "text", - "text": message, - }, - ] + content: List[ChatCompletionContentPartParam] | str = message if image: - content.append( + content = [ + { + "type": "text", + "text": message, + }, { "type": "image_url", "image_url": { "url": image, }, }, - ) + ] self.add_transcript_message(UserMessage(message=content, prior_messages=None)) self.add_message(ChatCompletionUserMessageParam(role="user", content=content)) diff --git a/mentat/transcripts.py b/mentat/transcripts.py index 4400ff934..7975faece 100644 --- a/mentat/transcripts.py +++ b/mentat/transcripts.py @@ -9,7 +9,7 @@ class UserMessage(TypedDict): - message: list[ChatCompletionContentPartParam] + message: list[ChatCompletionContentPartParam] | str # We need this field so that it is included when we convert to JSON prior_messages: None diff --git a/tests/conversation_test.py b/tests/conversation_test.py index e3d027ce9..6106c44b8 100644 --- a/tests/conversation_test.py +++ b/tests/conversation_test.py @@ -28,3 +28,29 @@ def test_no_parser_prompt(mock_call_llm_api): assert len(conversation.get_messages()) == 1 config.no_parser_prompt = True assert len(conversation.get_messages()) == 0 + + +def test_add_user_message_with_and_without_image(mock_call_llm_api): + session_context = SESSION_CONTEXT.get() + conversation = session_context.conversation + + # Test with image + test_message = "Hello, World!" + test_image_url = "http://example.com/image.png" + conversation.add_user_message(test_message, test_image_url) + messages_with_image = conversation.get_messages() + assert len(messages_with_image) == 2 # System prompt + user message + user_message_content_with_image = messages_with_image[-1]["content"] + assert len(user_message_content_with_image) == 2 # Text + image + assert user_message_content_with_image[0]["type"] == "text" + assert user_message_content_with_image[0]["text"] == test_message + assert user_message_content_with_image[1]["type"] == "image_url" + assert user_message_content_with_image[1]["image_url"]["url"] == test_image_url + + # Test without image + conversation.clear_messages() + conversation.add_user_message(test_message) + messages_without_image = conversation.get_messages() + assert len(messages_without_image) == 2 # System prompt + user message + user_message_content_without_image = messages_without_image[-1]["content"] + assert user_message_content_without_image == test_message