Skip to content

Commit

Permalink
feat: optimize db connection when llm invoking (langgenius#2774)
Browse files (browse the repository at this point in the history)
  • Loading branch information
takatost authored Mar 10, 2024
1 parent d8b64c4 commit 4dd156c
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 5 deletions.
4 changes: 4 additions & 0 deletions api/core/app_runner/assistant_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ def run(self, application_generate_entity: ApplicationGenerateEntity,
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

db.session.refresh(conversation)
db.session.refresh(message)
db.session.close()

# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner(
Expand Down
2 changes: 2 additions & 0 deletions api/core/app_runner/basic_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity,
model=app_orchestration_config.model_config.model
)

db.session.close()

invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
Expand Down
8 changes: 8 additions & 0 deletions api/core/app_runner/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def process(self, stream: bool) -> Union[dict, Generator]:
Process generate task pipeline.
:return:
"""
db.session.refresh(self._conversation)
db.session.refresh(self._message)
db.session.close()

if stream:
return self._process_stream_response()
else:
Expand Down Expand Up @@ -303,6 +307,7 @@ def _process_stream_response(self) -> Generator:
.first()
)
db.session.refresh(agent_thought)
db.session.close()

if agent_thought:
response = {
Expand Down Expand Up @@ -330,6 +335,8 @@ def _process_stream_response(self) -> Generator:
.filter(MessageFile.id == event.message_file_id)
.first()
)
db.session.close()

# get extension
if '.' in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}'
Expand Down Expand Up @@ -413,6 +420,7 @@ def _save_message(self, llm_result: LLMResult) -> None:
usage = llm_result.usage

self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
self._message.message_tokens = usage.prompt_tokens
Expand Down
6 changes: 3 additions & 3 deletions api/core/application_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _generate_worker(self, flask_app: Flask,
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()

def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
Expand Down Expand Up @@ -233,8 +233,6 @@ def _handle_response(self, application_generate_entity: ApplicationGenerateEntit
else:
logger.exception(e)
raise e
finally:
db.session.remove()

def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
-> AppOrchestrationConfigEntity:
Expand Down Expand Up @@ -651,6 +649,7 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat

db.session.add(conversation)
db.session.commit()
db.session.refresh(conversation)
else:
conversation = (
db.session.query(Conversation)
Expand Down Expand Up @@ -689,6 +688,7 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat

db.session.add(message)
db.session.commit()
db.session.refresh(message)

for file in application_generate_entity.files:
message_file = MessageFile(
Expand Down
22 changes: 20 additions & 2 deletions api/core/features/assistant_base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(self, tenant_id: str,
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id,
).count()
db.session.close()

# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
Expand Down Expand Up @@ -341,13 +342,16 @@ def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[
created_by=self.user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)

result.append((
message_file,
message.save_as
))

db.session.commit()

db.session.close()

return result

def create_agent_thought(self, message_id: str, message: str,
Expand Down Expand Up @@ -384,6 +388,8 @@ def create_agent_thought(self, message_id: str, message: str,

db.session.add(thought)
db.session.commit()
db.session.refresh(thought)
db.session.close()

self.agent_thought_count += 1

Expand All @@ -401,6 +407,10 @@ def save_agent_thought(self,
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()

if thought is not None:
agent_thought.thought = thought

Expand Down Expand Up @@ -451,6 +461,7 @@ def save_agent_thought(self,
agent_thought.tool_labels_str = json.dumps(labels)

db.session.commit()
db.session.close()

def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
"""
Expand Down Expand Up @@ -523,9 +534,14 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab
"""
convert tool variables to db variables
"""
db_variables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
).first()

db_variables.updated_at = datetime.utcnow()
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.close()

def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Expand Down Expand Up @@ -581,4 +597,6 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
if message.answer:
result.append(AssistantPromptMessage(content=message.answer))

db.session.close()

return result

0 comments on commit 4dd156c

Please sign in to comment.