Skip to content

Commit

Permalink
Improve RetrieveChat (microsoft#6)
Browse files: browse the repository at this point in the history
* Upsert in batches

* Improve update context, support customized answer prefix

* Update tests

* Update intermediate answer

* Fix duplicate intermediate answer, add example 6 to notebook

* Add notebook results

* Works better without intermediate answers in the context

* Bump version to 0.1.2

* Remove commented code and add descriptions to _generate_retrieve_user_reply

---------

Co-authored-by: Qingyun Wu <[email protected]>
  • Loading branch information
thinkall and qingyun-wu authored Sep 27, 2023
1 parent 0e6e5db commit f2e17e5
Show file tree
Hide file tree
Showing 5 changed files with 538 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
pip install -e.[mathchat] datasets pytest
pip install -e.[mathchat,retrievechat] datasets pytest
pip uninstall -y openai
- name: Test with pytest
if: matrix.python-version != '3.10'
Expand Down
89 changes: 68 additions & 21 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import chromadb
from autogen.agentchat.agent import Agent
from autogen.agentchat import UserProxyAgent
Expand Down Expand Up @@ -122,6 +123,9 @@ def __init__(
can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
- customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
- customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
- no_update_context (Optional, bool): if True, will not apply `Update Context` for interactive retrieval. Default is False.
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
"""
super().__init__(
Expand All @@ -143,11 +147,16 @@ def __init__(
self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True)
self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
self.no_update_context = self._retrieve_config.get("no_update_context", False)
self._context_max_tokens = self._max_tokens * 0.8
self._collection = False # the collection is not created
self._ipython = get_ipython()
self._doc_idx = -1 # the index of the current used doc
self._results = {} # the results of the current query
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply)

@staticmethod
Expand All @@ -161,17 +170,24 @@ def get_max_tokens(model="gpt-3.5-turbo"):
else:
return 4000

def _reset(self):
def _reset(self, intermediate=False):
self._doc_idx = -1 # the index of the current used doc
self._results = {} # the results of the current query
if not intermediate:
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc

def _get_context(self, results):
doc_contents = ""
current_tokens = 0
_doc_idx = self._doc_idx
_tmp_retrieve_count = 0
for idx, doc in enumerate(results["documents"][0]):
if idx <= _doc_idx:
continue
if results["ids"][0][idx] in self._doc_ids:
continue
_doc_tokens = num_tokens_from_text(doc)
if _doc_tokens > self._context_max_tokens:
func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
Expand All @@ -185,14 +201,19 @@ def _get_context(self, results):
current_tokens += _doc_tokens
doc_contents += doc + "\n"
self._doc_idx = idx
self._doc_ids.append(results["ids"][0][idx])
self._doc_contents.append(doc)
_tmp_retrieve_count += 1
if _tmp_retrieve_count >= self.n_results:
break
return doc_contents

def _generate_message(self, doc_contents, task="default"):
if not doc_contents:
print(colored("No more context, will terminate.", "green"), flush=True)
return "TERMINATE"
if self.customized_prompt:
message = self.customized_prompt + "\nUser's question is: " + self.problem + "\nContext is: " + doc_contents
message = self.customized_prompt.format(input_question=self.problem, input_context=doc_contents)
elif task.upper() == "CODE":
message = PROMPT_CODE.format(input_question=self.problem, input_context=doc_contents)
elif task.upper() == "QA":
Expand All @@ -209,24 +230,64 @@ def _generate_retrieve_user_reply(
sender: Optional[Agent] = None,
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
"""In this function, we will update the context and reset the conversation based on different conditions.
We'll update the context and reset the conversation if no_update_context is False and either of the following:
(1) the last message contains "UPDATE CONTEXT",
(2) the last message doesn't contain "UPDATE CONTEXT" and the customized_answer_prefix is not in the message.
"""
if config is None:
config = self
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
if (
update_context_case1 = (
"UPDATE CONTEXT" in message.get("content", "")[-20:].upper()
or "UPDATE CONTEXT" in message.get("content", "")[:20].upper()
):
)
update_context_case2 = (
self.customized_answer_prefix and self.customized_answer_prefix not in message.get("content", "").upper()
)
if (update_context_case1 or update_context_case2) and not self.no_update_context:
print(colored("Updating context and resetting conversation.", "green"), flush=True)
# extract the first sentence in the response as the intermediate answer
_message = message.get("content", "").split("\n")[0].strip()
_intermediate_info = re.split(r"(?<=[.!?])\s+", _message)
self._intermediate_answers.add(_intermediate_info[0])

if update_context_case1:
# try to get more context from the current retrieved doc results because the results may be too long to fit
# in the LLM context.
doc_contents = self._get_context(self._results)

# Always use self.problem as the query text to retrieve docs, but each time we replace the context with the
# next similar docs in the retrieved doc results.
if not doc_contents:
for _tmp_retrieve_count in range(1, 5):
self._reset(intermediate=True)
self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1))
doc_contents = self._get_context(self._results)
if doc_contents:
break
elif update_context_case2:
# Use the current intermediate info as the query text to retrieve docs, and each time we append the top similar
# docs in the retrieved doc results to the context.
for _tmp_retrieve_count in range(5):
self._reset(intermediate=True)
self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1))
self._get_context(self._results)
doc_contents = "\n".join(self._doc_contents) # + "\n" + "\n".join(self._intermediate_answers)
if doc_contents:
break

self.clear_history()
sender.clear_history()
doc_contents = self._get_context(self._results)
return True, self._generate_message(doc_contents, task=self._task)
return False, None
else:
return False, None

def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
if not self._collection:
print("Trying to create collection.")
create_vector_db_from_dir(
dir_path=self._docs_path,
max_tokens=self._chunk_token_size,
Expand Down Expand Up @@ -263,6 +324,7 @@ def generate_init_message(self, problem: str, n_results: int = 20, search_string
self._reset()
self.retrieve_docs(problem, n_results, search_string)
self.problem = problem
self.n_results = n_results
doc_contents = self._get_context(self._results)
message = self._generate_message(doc_contents, self._task)
return message
Expand All @@ -278,21 +340,6 @@ def run_code(self, code, **kwargs):
if self._ipython is None or lang != "python":
return super().run_code(code, **kwargs)
else:
# # capture may not work as expected
# result = self._ipython.run_cell("%%capture --no-display cap\n" + code)
# log = self._ipython.ev("cap.stdout")
# log += self._ipython.ev("cap.stderr")
# if result.result is not None:
# log += str(result.result)
# exitcode = 0 if result.success else 1
# if result.error_before_exec is not None:
# log += f"\n{result.error_before_exec}"
# exitcode = 1
# if result.error_in_exec is not None:
# log += f"\n{result.error_in_exec}"
# exitcode = 1
# return exitcode, log, None

result = self._ipython.run_cell(code)
log = str(result.result)
exitcode = 0 if result.success else 1
Expand Down
14 changes: 11 additions & 3 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,18 @@ def create_vector_db_from_dir(
)

chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
# updates existing items, or adds them if they don't yet exist.
print(f"Found {len(chunks)} chunks.")
# upsert in batch of 40000
for i in range(0, len(chunks), 40000):
collection.upsert(
documents=chunks[
i : i + 40000
], # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
ids=[f"doc_{i}" for i in range(i, i + 40000)], # unique for each doc
)
collection.upsert(
documents=chunks, # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
ids=[f"doc_{i}" for i in range(len(chunks))], # unique for each doc
documents=chunks[i : len(chunks)],
ids=[f"doc_{i}" for i in range(i, len(chunks))], # unique for each doc
)
except ValueError as e:
logger.warning(f"{e}")
Expand Down
Loading

0 comments on commit f2e17e5

Please sign in to comment.