Improve RetrieveChat #6

Merged · 11 commits · Sep 27, 2023
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -40,7 +40,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
- pip install -e.[mathchat] datasets pytest
+ pip install -e.[mathchat,retrievechat] datasets pytest
pip uninstall -y openai
- name: Test with pytest
if: matrix.python-version != '3.10'
69 changes: 63 additions & 6 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -1,3 +1,4 @@
import re
import chromadb
from autogen.agentchat.agent import Agent
from autogen.agentchat import UserProxyAgent
@@ -122,6 +123,9 @@ def __init__(
can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
- customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
- customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
- no_update_context (Optional, bool): if True, will not apply `Update Context` for interactive retrieval. Default is False.
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
"""
super().__init__(
@@ -143,11 +147,16 @@ def __init__(
self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True)
self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
self.no_update_context = self._retrieve_config.get("no_update_context", False)
self._context_max_tokens = self._max_tokens * 0.8
self._collection = False # the collection is not created
self._ipython = get_ipython()
self._doc_idx = -1 # the index of the current used doc
self._results = {} # the results of the current query
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply)

@staticmethod
@@ -161,17 +170,24 @@ def get_max_tokens(model="gpt-3.5-turbo"):
else:
return 4000

- def _reset(self):
+ def _reset(self, intermediate=False):
self._doc_idx = -1 # the index of the current used doc
self._results = {} # the results of the current query
if not intermediate:
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc

def _get_context(self, results):
doc_contents = ""
current_tokens = 0
_doc_idx = self._doc_idx
_tmp_retrieve_count = 0
for idx, doc in enumerate(results["documents"][0]):
if idx <= _doc_idx:
continue
if results["ids"][0][idx] in self._doc_ids:
continue
_doc_tokens = num_tokens_from_text(doc)
if _doc_tokens > self._context_max_tokens:
func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
@@ -185,14 +201,19 @@ def _get_context(self, results):
current_tokens += _doc_tokens
doc_contents += doc + "\n"
self._doc_idx = idx
self._doc_ids.append(results["ids"][0][idx])
self._doc_contents.append(doc)
_tmp_retrieve_count += 1
if _tmp_retrieve_count >= self.n_results:
break
return doc_contents
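To make the selection loop above concrete, here is a toy, self-contained sketch of the same dedup-and-budget idea (the results dict mimics Chroma's query output; `approx_tokens` is a crude stand-in for `num_tokens_from_text`, and all values are made up for illustration):

```python
def approx_tokens(text: str) -> int:
    return len(text.split())  # crude stand-in for num_tokens_from_text

# Chroma-style query results; values are made up for illustration.
results = {
    "ids": [["doc_0", "doc_1", "doc_2", "doc_3"]],
    "documents": [["alpha " * 5, "beta " * 500, "gamma " * 5, "delta " * 5]],
}
seen_ids = {"doc_0"}  # already packed into the context in an earlier turn
budget = 100          # analogous to self._context_max_tokens

picked, used = [], 0
for doc_id, doc in zip(results["ids"][0], results["documents"][0]):
    if doc_id in seen_ids:
        continue  # never repeat a doc across context updates
    cost = approx_tokens(doc)
    if cost > budget:
        continue  # skip docs too long to ever fit
    if used + cost > budget:
        break     # context budget exhausted
    picked.append(doc_id)
    used += cost

print(picked)  # ['doc_2', 'doc_3'] -- doc_0 was already seen, doc_1 is over budget
```

Tracking `self._doc_ids` across calls is what lets repeated `Update Context` rounds bring in new material instead of re-sending the same documents.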

def _generate_message(self, doc_contents, task="default"):
if not doc_contents:
print(colored("No more context, will terminate.", "green"), flush=True)
return "TERMINATE"
if self.customized_prompt:
- message = self.customized_prompt + "\nUser's question is: " + self.problem + "\nContext is: " + doc_contents
+ message = self.customized_prompt.format(input_question=self.problem, input_context=doc_contents)
elif task.upper() == "CODE":
message = PROMPT_CODE.format(input_question=self.problem, input_context=doc_contents)
elif task.upper() == "QA":
@@ -214,19 +235,54 @@ def _generate_retrieve_user_reply(
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
- if (
+ update_context_case1 = (
"UPDATE CONTEXT" in message.get("content", "")[-20:].upper()
or "UPDATE CONTEXT" in message.get("content", "")[:20].upper()
- ):
+ )
+ update_context_case2 = (
+     self.customized_answer_prefix and self.customized_answer_prefix not in message.get("content", "").upper()
+ )
+ if (update_context_case1 or update_context_case2) and not self.no_update_context:
print(colored("Updating context and resetting conversation.", "green"), flush=True)
# extract the first sentence in the response as the intermediate answer
_message = message.get("content", "").split("\n")[0].strip()
_intermediate_info = re.split(r"(?<=[.!?])\s+", _message)
self._intermediate_answers.add(_intermediate_info[0])

if update_context_case1:
# try to get more context from the current retrieved doc results because the results may be too long to fit
# in the LLM context.
doc_contents = self._get_context(self._results)

# Always use self.problem as the query text to retrieve docs, but each time we replace the context with the
# next similar docs in the retrieved doc results.
if not doc_contents:
for _tmp_retrieve_count in range(1, 5):
self._reset(intermediate=True)
self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1))
doc_contents = self._get_context(self._results)
if doc_contents:
break
elif update_context_case2:
# Use the current intermediate info as the query text to retrieve docs, and each time we append the top similar
# docs in the retrieved doc results to the context.
for _tmp_retrieve_count in range(5):
self._reset(intermediate=True)
self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1))
self._get_context(self._results)
doc_contents = "\n".join(self._doc_contents) # + "\n" + "\n".join(self._intermediate_answers)
if doc_contents:
break

self.clear_history()
sender.clear_history()
- doc_contents = self._get_context(self._results)
return True, self._generate_message(doc_contents, task=self._task)
- return False, None
+ else:
+     return False, None
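Condensed, the two triggers and the intermediate-answer handling above behave roughly like this standalone sketch (not the exact agent code; note also that each retry widens retrieval as `n_results * (2 * k + 1)`, i.e. 20, 60, 100, ... for `n_results=20`):

```python
import re

def update_context_triggered(content: str, answer_prefix: str = "") -> bool:
    # Case 1: the model asks for more context at either end of its reply.
    case1 = (
        "UPDATE CONTEXT" in content[-20:].upper()
        or "UPDATE CONTEXT" in content[:20].upper()
    )
    # Case 2: a customized answer prefix was configured but never appeared.
    case2 = bool(answer_prefix) and answer_prefix.upper() not in content.upper()
    return case1 or case2

reply = "The capital of France is Paris. More sources may help.\nUPDATE CONTEXT"
print(update_context_triggered(reply))  # True -- the trigger is in the last 20 chars

# The first sentence of the first line is kept as an intermediate answer;
# in case 2 it becomes the next retrieval query.
first_line = reply.split("\n")[0].strip()
print(re.split(r"(?<=[.!?])\s+", first_line)[0])  # 'The capital of France is Paris.'
```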

def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
if not self._collection:
print("Trying to create collection.")
create_vector_db_from_dir(
dir_path=self._docs_path,
max_tokens=self._chunk_token_size,
@@ -263,6 +319,7 @@ def generate_init_message(self, problem: str, n_results: int = 20, search_string
self._reset()
self.retrieve_docs(problem, n_results, search_string)
self.problem = problem
self.n_results = n_results
doc_contents = self._get_context(self._results)
message = self._generate_message(doc_contents, self._task)
return message
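For orientation, an end-to-end sketch of how the updated agent might be driven. The model name, API key, and docs path are placeholders; `initiate_chat` forwards `problem` and `n_results` to `generate_init_message`:

```python
from autogen import AssistantAgent
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

config_list = [{"model": "gpt-4", "api_key": "YOUR_API_KEY"}]  # placeholder credentials
assistant = AssistantAgent(name="assistant", llm_config={"config_list": config_list})

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",  # placeholder path
        # A customized_prompt must now expose both placeholders, since the agent
        # fills it via .format(input_question=..., input_context=...).
        "customized_prompt": "Answer from the context.\nQuestion: {input_question}\nContext: {input_context}",
        "customized_answer_prefix": "the answer is",  # a reply missing this prefix triggers Update Context
        "no_update_context": False,  # keep the interactive Update Context loop on
    },
)

# problem and n_results are stored on the agent and reused whenever the context is updated.
ragproxyagent.initiate_chat(assistant, problem="What triggers Update Context?", n_results=20)
```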
14 changes: 11 additions & 3 deletions autogen/retrieve_utils.py
@@ -207,10 +207,18 @@ def create_vector_db_from_dir(
)

chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
- # updates existing items, or adds them if they don't yet exist.
+ print(f"Found {len(chunks)} chunks.")
+ # upsert in batch of 40000
+ for i in range(0, len(chunks), 40000):
+     collection.upsert(
+         documents=chunks[i : i + 40000],  # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
+         ids=[f"doc_{i}" for i in range(i, i + 40000)],  # unique for each doc
+     )
collection.upsert(
-     documents=chunks,  # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
-     ids=[f"doc_{i}" for i in range(len(chunks))],  # unique for each doc
+     documents=chunks[i : len(chunks)],
+     ids=[f"doc_{i}" for i in range(i, len(chunks))],  # unique for each doc
)
except ValueError as e:
logger.warning(f"{e}")
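The batching above guards against per-call batch limits when upserting large corpora into Chroma. A standalone sketch of the same idea, sized to the actual batch length (the collection name and batch size are illustrative):

```python
import chromadb

client = chromadb.Client()
collection = client.get_or_create_collection("all-my-documents")  # illustrative name

chunks = [f"chunk {i}" for i in range(100_000)]  # stand-in for split_files_to_chunks output
BATCH = 40_000
for start in range(0, len(chunks), BATCH):
    batch = chunks[start : start + BATCH]
    collection.upsert(
        documents=batch,
        ids=[f"doc_{j}" for j in range(start, start + len(batch))],  # unique, stable ids
    )
```

Because `upsert` is idempotent on ids, re-running ingestion updates existing entries instead of duplicating them.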