From 59088faa28b21ad314bb8bbf87101750904e8260 Mon Sep 17 00:00:00 2001 From: bryant-nn Date: Thu, 14 Nov 2024 15:49:49 +0800 Subject: [PATCH] Enable embedding_function in RetrieveUserProxyAgent to create initial vector_db --- .../contrib/retrieve_user_proxy_agent.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index ee8f74bb9a6..812651aaec5 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -379,11 +379,24 @@ def _init_db(self): ) chunk_ids_set = set(chunk_ids) chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set] - docs = [ - Document(id=chunk_ids[idx], content=chunks[idx], metadata=sources[idx]) - for idx in chunk_ids_set_idx - if chunk_ids[idx] not in all_docs_ids - ] + + if self._embedding_function is None: + docs = [ + Document(id=chunk_ids[idx], content=chunks[idx], metadata=sources[idx]) + for idx in chunk_ids_set_idx + if chunk_ids[idx] not in all_docs_ids + ] + else: + docs = [ + Document( + id=chunk_ids[idx], + content=chunks[idx], + metadata=sources[idx], + embedding=self._embedding_function([chunks[idx]])[0] + ) + for idx in chunk_ids_set_idx + if chunk_ids[idx] not in all_docs_ids + ] self._vector_db.insert_docs(docs=docs, collection_name=self._collection_name, upsert=True)