Skip to content

Commit

Permalink
Merge pull request #69 from 20001LastOrder/repeated_action
Browse files Browse the repository at this point in the history
reformulate query if the query is duplicate
  • Loading branch information
20001LastOrder authored Aug 16, 2023
2 parents 99e0363 + 368db99 commit 5ad07e8
Showing 1 changed file with 32 additions and 6 deletions.
38 changes: 32 additions & 6 deletions apps/slackbot/task_agent.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import json
from os import environ
from typing import List, Optional

from pydantic import ValidationError
import openai

from langchain.chains.llm import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.schema import AIMessage, BaseMessage, Document, HumanMessage
from langchain.tools.base import BaseTool
from langchain.tools.human.tool import HumanInputRun
from langchain.vectorstores.base import VectorStoreRetriever
from pydantic import ValidationError

from output_parser import BaseTaskOutputParser, TaskOutputParser
from post_processors import md_link_to_slack
from prompt import SlackBotPrompt
from pydantic import ValidationError


class TaskAgent:
Expand Down Expand Up @@ -93,6 +91,7 @@ def run(self, task: str) -> str:

# Interaction Loop

previous_action = ""
while True:
# Discontinue if continuous limit is reached
loop_count = self.loop_count
Expand All @@ -117,7 +116,7 @@ def run(self, task: str) -> str:
user_input=user_input,
)
except openai.error.APIError as e:
return f"OpenAI API returned an API Error: {e}"
return f"OpenAI API returned an API Error: {e}"
except openai.error.APIConnectionError as e:
return f"Failed to connect to OpenAI API: {e}"
except openai.error.RateLimitError as e:
Expand All @@ -130,7 +129,6 @@ def run(self, task: str) -> str:
return f"OpenAI API Service unavailable: {e}"
except openai.error.InvalidRequestError as e:
return f"OpenAI API invalid request error: {e}"


assistant_reply = self.chain.run(
task=task,
Expand Down Expand Up @@ -177,6 +175,33 @@ def run(self, task: str) -> str:
action = self.output_parser.parse(assistant_reply)
print("action:", action)
tools = {t.name: t for t in self.tools}
if action == previous_action:
if action.name == "Search" or action.name == "Context Search":
print(
"Action name: ", action.name, "\nStart reformulating the query"
)
instruction = (
f"You want to search for useful information to answer the query: {task}."
f"The original query is: {action.args['query']}"
f"Reformulate the query so that it can be used to search for relevant information."
f"Only return one query instead of multiple queries."
f"Reformulated query:\n\n"
)
openai.api_key = environ.get("OPENAI_KEY")
response = openai.Completion.create(
engine="text-davinci-003",
prompt=" ".join(str(i) for i in self.previous_message)
+ "\n"
+ instruction,
temperature=0.7,
max_tokens=1024,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
)
reformulated_query = response["choices"][0]["text"]
action.args["query"] = reformulated_query

if action.name == "finish":
self.loop_count = self.max_iterations
result = "Finished task. "
Expand Down Expand Up @@ -220,6 +245,7 @@ def run(self, task: str) -> str:

# self.memory.add_documents([Document(page_content=memory_to_add)])
self.previous_message.append(HumanMessage(content=memory_to_add))
previous_action = action

def set_user_input(self, user_input: str):
result = f"Command UserInput returned: {user_input}"
Expand Down

0 comments on commit 5ad07e8

Please sign in to comment.