diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 37b8c5c33a7e..0fc83bdb7593 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -62,26 +62,10 @@ def colored(x, *args, **kwargs): """ -def _is_termination_msg_retrievechat(message): - """Check if a message is a termination message.""" - if isinstance(message, dict): - message = message.get("content") - if message is None: - return False - cb = extract_code(message) - contain_code = False - for c in cb: - if c[0] == "python": - contain_code = True - break - return not contain_code - - class RetrieveUserProxyAgent(UserProxyAgent): def __init__( self, name="RetrieveChatAgent", # default set to RetrieveChatAgent - is_termination_msg: Optional[Callable[[Dict], bool]] = _is_termination_msg_retrievechat, human_input_mode: Optional[str] = "ALWAYS", retrieve_config: Optional[Dict] = None, # config for the retrieve agent **kwargs, @@ -135,7 +119,6 @@ def __init__( """ super().__init__( name=name, - is_termination_msg=is_termination_msg, human_input_mode=human_input_mode, **kwargs, ) @@ -164,7 +147,27 @@ def __init__( self._intermediate_answers = set() # the intermediate answers self._doc_contents = [] # the contents of the current used doc self._doc_ids = [] # the ids of the current used doc - self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply) + self._is_termination_msg = self._is_termination_msg_retrievechat # update the termination message function + self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=1) + + def _is_termination_msg_retrievechat(self, message): + """Check if a message is a termination message. + For code generation, terminate when no code block is detected. Currently only detect python code blocks. + For question answering, terminate when don't update context, i.e., answer is given. + """ + if isinstance(message, dict): + message = message.get("content") + if message is None: + return False + cb = extract_code(message) + contain_code = False + for c in cb: + # todo: support more languages + if c[0] == "python": + contain_code = True + break + update_context_case1, update_context_case2 = self._check_update_context(message) + return not (contain_code or update_context_case1 or update_context_case2) @staticmethod def get_max_tokens(model="gpt-3.5-turbo"): @@ -231,6 +234,13 @@ def _generate_message(self, doc_contents, task="default"): raise NotImplementedError(f"task {task} is not implemented.") return message + def _check_update_context(self, message): + if isinstance(message, dict): + message = message.get("content", "") + update_context_case1 = "UPDATE CONTEXT" in message[-20:].upper() or "UPDATE CONTEXT" in message[:20].upper() + update_context_case2 = self.customized_answer_prefix and self.customized_answer_prefix not in message.upper() + return update_context_case1, update_context_case2 + def _generate_retrieve_user_reply( self, messages: Optional[List[Dict]] = None, @@ -247,13 +257,7 @@ def _generate_retrieve_user_reply( if messages is None: messages = self._oai_messages[sender] message = messages[-1] - update_context_case1 = ( - "UPDATE CONTEXT" in message.get("content", "")[-20:].upper() - or "UPDATE CONTEXT" in message.get("content", "")[:20].upper() - ) - update_context_case2 = ( - self.customized_answer_prefix and self.customized_answer_prefix not in message.get("content", "").upper() - ) + update_context_case1, update_context_case2 = self._check_update_context(message) if (update_context_case1 or update_context_case2) and self.update_context: print(colored("Updating context and resetting conversation.", "green"), flush=True) # extract the first sentence in the response as the intermediate answer