From 3afa5b208f6444a5c00aa3a97df5344fd2cd56cc Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 10 Sep 2024 22:15:50 +0800 Subject: [PATCH] feat: add from_variable_selector for stream chunk / message event (#8228) --- .../app/apps/advanced_chat/generate_task_pipeline.py | 4 +++- api/core/app/apps/workflow/generate_task_pipeline.py | 11 ++++++++--- api/core/app/entities/task_entities.py | 2 ++ api/core/app/task_pipeline/message_cycle_manage.py | 11 +++++++++-- .../workflow/nodes/answer/answer_stream_processor.py | 1 + 5 files changed, 23 insertions(+), 6 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f3e1a49cc2243a..8f65a670c35949 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -451,7 +451,9 @@ def _process_stream_response( tts_publisher.publish(message=queue_message) self._task_state.answer += delta_text - yield self._message_to_stream_response(delta_text, self._message.id) + yield self._message_to_stream_response( + answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector + ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation yield self._message_replace_to_stream_response(answer=event.text) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 904b6493811c07..215d02bdddfe6e 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -376,7 +376,9 @@ def _process_stream_response( tts_publisher.publish(message=queue_message) self._task_state.answer += delta_text - yield self._text_chunk_to_stream_response(delta_text) + yield self._text_chunk_to_stream_response( + delta_text, from_variable_selector=event.from_variable_selector + ) else: continue @@ -412,14 +414,17 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: db.session.commit() db.session.close() - def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse: + def _text_chunk_to_stream_response( + self, text: str, from_variable_selector: Optional[list[str]] = None + ) -> TextChunkStreamResponse: """ Handle completed event. :param text: text :return: """ response = TextChunkStreamResponse( - task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text) + task_id=self._application_generate_entity.task_id, + data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector), ) return response diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 0135c97172b486..49e5f55ebc2bde 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -90,6 +90,7 @@ class MessageStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE id: str answer: str + from_variable_selector: Optional[list[str]] = None class MessageAudioStreamResponse(StreamResponse): @@ -479,6 +480,7 @@ class Data(BaseModel): """ text: str + from_variable_selector: Optional[list[str]] = None event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 011daba6878ffa..5872e00740b61c 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -153,14 +153,21 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti return None - def _message_to_stream_response(self, answer: str, message_id: str) -> MessageStreamResponse: + def _message_to_stream_response( + self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None + ) -> MessageStreamResponse: """ Message to stream response. :param answer: answer :param message_id: message id :return: """ - return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer) + return MessageStreamResponse( + task_id=self._application_generate_entity.task_id, + id=message_id, + answer=answer, + from_variable_selector=from_variable_selector, + ) def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: """ diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 9776ce5810dbaa..32dbf436ec32da 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -108,6 +108,7 @@ def _generate_stream_outputs_when_node_finished( route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + from_variable_selector=[answer_node_id, "answer"], ) else: route_chunk = cast(VarGenerateRouteChunk, route_chunk)