diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index cf2f9b7176e7a..77781058c7340 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -10,7 +10,7 @@
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage
-from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageContentType
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -434,6 +434,22 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData,
         )
         stop = model_config.stop
 
+        vision_enabled = node_data.vision.enabled
+        for prompt_message in prompt_messages:
+            if not isinstance(prompt_message.content, str):
+                prompt_message_content = []
+                for content_item in prompt_message.content:
+                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
+                        prompt_message_content.append(content_item)
+                    elif content_item.type == PromptMessageContentType.TEXT:
+                        prompt_message_content.append(content_item)
+
+                if len(prompt_message_content) > 1:
+                    prompt_message.content = prompt_message_content
+                elif (len(prompt_message_content) == 1
+                      and prompt_message_content[0].type == PromptMessageContentType.TEXT):
+                    prompt_message.content = prompt_message_content[0].data
+
         return prompt_messages, stop
 
     @classmethod
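
Reviewer note: below is a minimal, runnable sketch of the filtering behaviour the second hunk introduces, using stand-in types (an Enum plus SimpleNamespace) rather than Dify's real PromptMessage classes; the helper name filter_content and the sample data are illustrative assumptions, not part of the patch.

    # Sketch only: stand-in types for PromptMessage content items.
    from enum import Enum
    from types import SimpleNamespace


    class PromptMessageContentType(Enum):  # stand-in for the real enum
        TEXT = "text"
        IMAGE = "image"


    def filter_content(content, vision_enabled):
        # Mirrors the loop in _fetch_prompt_messages: text items are always
        # kept; image items are kept only when vision is enabled.
        kept = [
            item for item in content
            if item.type == PromptMessageContentType.TEXT
            or (vision_enabled and item.type == PromptMessageContentType.IMAGE)
        ]
        if len(kept) > 1:
            return kept  # still multimodal: keep the list form
        if len(kept) == 1 and kept[0].type == PromptMessageContentType.TEXT:
            return kept[0].data  # collapse a lone text item to a plain string
        return content  # otherwise left unchanged, as in the patch


    text = SimpleNamespace(type=PromptMessageContentType.TEXT, data="describe this")
    image = SimpleNamespace(type=PromptMessageContentType.IMAGE, data="<base64>")

    assert filter_content([text, image], vision_enabled=True) == [text, image]
    assert filter_content([text, image], vision_enabled=False) == "describe this"

The collapse to a plain string mirrors the patch's final elif: once image items are stripped, a single remaining text item is downgraded to str so the message matches what text-only model providers expect.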