diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py
index d9a3447bda40f..655a5a1c7c811 100644
--- a/api/core/app_runner/assistant_app_runner.py
+++ b/api/core/app_runner/assistant_app_runner.py
@@ -195,6 +195,10 @@ def run(self, application_generate_entity: ApplicationGenerateEntity,
         if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
             agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
 
+        db.session.refresh(conversation)
+        db.session.refresh(message)
+        db.session.close()
+
         # start agent runner
         if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
             assistant_cot_runner = AssistantCotApplicationRunner(
diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py
index 83f4f6929a165..d3c91337c8f5c 100644
--- a/api/core/app_runner/basic_app_runner.py
+++ b/api/core/app_runner/basic_app_runner.py
@@ -192,6 +192,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity,
             model=app_orchestration_config.model_config.model
         )
 
+        db.session.close()
+
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
             model_parameters=app_orchestration_config.model_config.parameters,
diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py
index 5fd635bc3b3b6..1cc56483ad377 100644
--- a/api/core/app_runner/generate_task_pipeline.py
+++ b/api/core/app_runner/generate_task_pipeline.py
@@ -89,6 +89,10 @@ def process(self, stream: bool) -> Union[dict, Generator]:
         Process generate task pipeline.
         :return:
         """
+        db.session.refresh(self._conversation)
+        db.session.refresh(self._message)
+        db.session.close()
+
         if stream:
             return self._process_stream_response()
         else:
@@ -303,6 +307,7 @@ def _process_stream_response(self) -> Generator:
                     .first()
                 )
                 db.session.refresh(agent_thought)
+                db.session.close()
 
                 if agent_thought:
                     response = {
@@ -330,6 +335,8 @@ def _process_stream_response(self) -> Generator:
                     .filter(MessageFile.id == event.message_file_id)
                     .first()
                 )
+                db.session.close()
+
                 # get extension
                 if '.' in message_file.url:
                     extension = f'.{message_file.url.split(".")[-1]}'
@@ -413,6 +420,7 @@ def _save_message(self, llm_result: LLMResult) -> None:
         usage = llm_result.usage
 
         self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+        self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
 
         self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
         self._message.message_tokens = usage.prompt_tokens
diff --git a/api/core/application_manager.py b/api/core/application_manager.py
index e073eac4b97bb..9aca61c7bb40f 100644
--- a/api/core/application_manager.py
+++ b/api/core/application_manager.py
@@ -201,7 +201,7 @@ def _generate_worker(self, flask_app: Flask,
                 logger.exception("Unknown Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             finally:
-                db.session.remove()
+                db.session.close()
 
     def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                          queue_manager: ApplicationQueueManager,
@@ -233,8 +233,6 @@ def _handle_response(self, application_generate_entity: ApplicationGenerateEntit
             else:
                 logger.exception(e)
                 raise e
-        finally:
-            db.session.remove()
 
     def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
             -> AppOrchestrationConfigEntity:
@@ -651,6 +649,7 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat
 
             db.session.add(conversation)
             db.session.commit()
+            db.session.refresh(conversation)
         else:
             conversation = (
                 db.session.query(Conversation)
@@ -689,6 +688,7 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat
 
         db.session.add(message)
         db.session.commit()
+        db.session.refresh(message)
 
         for file in application_generate_entity.files:
             message_file = MessageFile(
diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py
index 0ee6436d1195e..1d9541070f881 100644
--- a/api/core/features/assistant_base_runner.py
+++ b/api/core/features/assistant_base_runner.py
@@ -114,6 +114,7 @@ def __init__(self, tenant_id: str,
         self.agent_thought_count = db.session.query(MessageAgentThought).filter(
             MessageAgentThought.message_id == self.message.id,
         ).count()
+        db.session.close()
 
         # check if model supports stream tool call
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
@@ -341,13 +342,16 @@ def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[
                 created_by=self.user_id,
             )
             db.session.add(message_file)
+            db.session.commit()
+            db.session.refresh(message_file)
+
             result.append((
                 message_file,
                 message.save_as
             ))
-
-        db.session.commit()
+        db.session.close()
+
 
         return result
 
     def create_agent_thought(self, message_id: str, message: str,
@@ -384,6 +388,8 @@ def create_agent_thought(self, message_id: str, message: str,
 
         db.session.add(thought)
         db.session.commit()
+        db.session.refresh(thought)
+        db.session.close()
 
         self.agent_thought_count += 1
 
@@ -401,6 +407,10 @@ def save_agent_thought(self,
         """
         Save agent thought
         """
+        agent_thought = db.session.query(MessageAgentThought).filter(
+            MessageAgentThought.id == agent_thought.id
+        ).first()
+
         if thought is not None:
             agent_thought.thought = thought
 
@@ -451,6 +461,7 @@ def save_agent_thought(self,
             agent_thought.tool_labels_str = json.dumps(labels)
 
         db.session.commit()
+        db.session.close()
 
     def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
         """
@@ -523,9 +534,14 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab
         """
         convert tool variables to db variables
        """
+        db_variables = db.session.query(ToolConversationVariables).filter(
+            ToolConversationVariables.conversation_id == self.message.conversation_id,
+        ).first()
+
         db_variables.updated_at = datetime.utcnow()
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
         db.session.commit()
+        db.session.close()
 
     def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
@@ -581,4 +597,6 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
                 if message.answer:
                     result.append(AssistantPromptMessage(content=message.answer))
 
+        db.session.close()
+
         return result
\ No newline at end of file