From 7b967ee323b393c6c0b6f117b5e4c21dd47dd12e Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Fri, 10 Nov 2023 15:33:40 +0530 Subject: [PATCH 01/52] update --- .../contrib/retrieve_user_proxy_agent.py | 59 ++++++-------- autogen/retriever/__init__.py | 10 +++ autogen/retriever/base.py | 65 +++++++++++++++ autogen/retriever/chromadb.py | 79 +++++++++++++++++++ autogen/retriever/lancedb.py | 9 +++ 5 files changed, 187 insertions(+), 35 deletions(-) create mode 100644 autogen/retriever/__init__.py create mode 100644 autogen/retriever/base.py create mode 100644 autogen/retriever/chromadb.py create mode 100644 autogen/retriever/lancedb.py diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index b24249bbe96..e430e31f8ae 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -1,14 +1,11 @@ import re -try: - import chromadb -except ImportError: - raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`") from autogen.agentchat.agent import Agent from autogen.agentchat import UserProxyAgent from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db from autogen.token_count_utils import count_token from autogen.code_utils import extract_code +from autogen.retriever import get_retriever from typing import Callable, Dict, Optional, Union, List, Tuple, Any from IPython import get_ipython @@ -95,8 +92,7 @@ def __init__( To use default config, set to None. Otherwise, set to a dictionary with the following keys: - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System prompt will be different for different tasks. The default value is `default`, which supports both code and qa. - - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()` - will be used. 
If you want to use other vector db, extend this class and override the `retrieve_docs` function. + - client (Optional, Any): the vectordb client/connection. If key not provided, the Retreiver class should handle it. - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file, or the url to a single file. Default is None, which works only if the collection is already created. - collection_name (Optional, str): the name of the collection. @@ -123,7 +119,7 @@ def __init__( If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered. - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True. - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat. - This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None. + This is the same as that used in retriever. Default is False. Will be set to False if docs_path is None. - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string. The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function. Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models. @@ -132,7 +128,7 @@ def __init__( **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__). Example of overriding retrieve_docs: - If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with below code. + If you have set up a customized vector db, and it's not compatible with retriever, you can easily plug in it with below code. 
```python class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): def query_vector_db( @@ -164,8 +160,9 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ) self._retrieve_config = {} if retrieve_config is None else retrieve_config + self._retriever_type = self._retrieve_config.get("retriever_type", "chromadb") self._task = self._retrieve_config.get("task", "default") - self._client = self._retrieve_config.get("client", chromadb.Client()) + self._client = self._retrieve_config.get("client", None) self._docs_path = self._retrieve_config.get("docs_path", None) self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs") self._model = self._retrieve_config.get("model", "gpt-4") @@ -345,13 +342,9 @@ def _generate_retrieve_user_reply( def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): """Retrieve docs based on the given problem and assign the results to the class property `_results`. - In case you want to customize the retrieval process, such as using a different vector db whose APIs are not - compatible with chromadb or filter results with metadata, you can override this function. Just keep the current - parameters and add your own parameters with default values, and keep the results in below type. Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of - the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer - to `chromadb.api.types.QueryResult` as an example. + the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. 
ids: List[string] documents: List[List[string]] @@ -362,29 +355,25 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = """ if not self._collection or self._get_or_create: print("Trying to create collection.") - self._client = create_vector_db_from_dir( - dir_path=self._docs_path, - max_tokens=self._chunk_token_size, - client=self._client, - collection_name=self._collection_name, - chunk_mode=self._chunk_mode, - must_break_at_empty_line=self._must_break_at_empty_line, - embedding_model=self._embedding_model, - get_or_create=self._get_or_create, - embedding_function=self._embedding_function, - custom_text_split_function=self.custom_text_split_function, - ) + retriever_class = get_retriever(self._retriever_type) + self.retriever = retriever_class( + name=self._collection_name, + embedding_model_name=self._embedding_model, + embedding_function=self._embedding_function, + max_tokens= self._chunk_token_size, + chunk_mode = self._chunk_mode, + must_break_at_empty_line = self._must_break_at_empty_line, + custom_text_split_function = self.custom_text_split_function, + use_existing=not self._get_or_create, + client=self._client + ) self._collection = True self._get_or_create = False - - results = query_vector_db( - query_texts=[problem], - n_results=n_results, - search_string=search_string, - client=self._client, - collection_name=self._collection_name, - embedding_model=self._embedding_model, - embedding_function=self._embedding_function, + self.retriever.ingest_data(self._docs_path) + results = self.retriever.query( + texts=[problem], + top_k=n_results, + filter=search_string, ) self._results = results print("doc_ids: ", results["ids"]) diff --git a/autogen/retriever/__init__.py b/autogen/retriever/__init__.py new file mode 100644 index 00000000000..9f74bee10ab --- /dev/null +++ b/autogen/retriever/__init__.py @@ -0,0 +1,10 @@ +from .chromadb import ChromaDB +from .lancedb import LanceDB + +def get_retriever(type:str): + if type == 
"chromadb": + return ChromaDB + elif type == "lancedb": + return LanceDB + else: + raise ValueError(f"Unknown retriever type {type}") \ No newline at end of file diff --git a/autogen/retriever/base.py b/autogen/retriever/base.py new file mode 100644 index 00000000000..20133f5ff58 --- /dev/null +++ b/autogen/retriever/base.py @@ -0,0 +1,65 @@ +from abc import ABC, abstractmethod +from typing import List, Union, Callable, Any + +class Retriever(ABC): + def __init__(self, path="./db", + name="vectorstore", + embedding_model_name="all-MiniLM-L6-v2", + embedding_function=None, + max_tokens: int = 4000, + chunk_mode: str = "multi_lines", + must_break_at_empty_line: bool = True, + custom_text_split_function: Callable = None, + use_existing=True, + client=None + ): + """ + Args: + path: path to the folder where the database is stored + name: name of the database + embedding_model_name: name of the embedding model to use + embedding_function: function to use to embed the text + max_tokens: maximum number of tokens to embed + chunk_mode: mode to chunk the text. Can be "multi_lines" or "single_line" + must_break_at_empty_line: whether to break the text at empty lines when chunking + custom_text_split_function: custom function to split the text into chunks + """ + self.path = path + self.name = name + self.embedding_model_name = embedding_model_name + self.embedding_function = embedding_function + self.max_tokens = max_tokens + self.chunk_mode = chunk_mode + self.must_break_at_empty_line = must_break_at_empty_line + self.custom_text_split_function = custom_text_split_function + self.use_existing = use_existing + self.client = client + + self.init_db() + + + @abstractmethod + def ingest_data(self, data_dir): + """ + Create a vector database from a directory of files. + Args: + data_dir: path to the directory containing the text files + """ + pass + + @abstractmethod + def query(self, texts: List[str], top_k: int = 10, filter: Any=None): + """ + Query the database. 
+ Args: + query: query string or list of query strings + top_k: number of results to return + """ + pass + + @abstractmethod + def init_db(self): + """ + Initialize the database. + """ + pass \ No newline at end of file diff --git a/autogen/retriever/chromadb.py b/autogen/retriever/chromadb.py new file mode 100644 index 00000000000..8020cc3c0bb --- /dev/null +++ b/autogen/retriever/chromadb.py @@ -0,0 +1,79 @@ +from typing import Callable, List +from .base import Retriever +from autogen.retrieve_utils import ( + split_text_to_chunks, + extract_text_from_pdf, + split_files_to_chunks, + get_files_from_dir +) +try: + import chromadb + if chromadb.__version__ < "0.4.15": + from chromadb.api import API + else: + from chromadb.api import ClientAPI as API + from chromadb.api.types import QueryResult + import chromadb.utils.embedding_functions as ef +except ImportError: + raise ImportError("Please install chromadb: pip install chromadb") + +class ChromaDB(Retriever): + def init_db(self): + if self.client is None: + self.client = chromadb.PersistentClient(path=self.path) + embedding_function = ( + ef.SentenceTransformerEmbeddingFunction(self.embedding_model_name) + if self.embedding_function is None + else embedding_function + ) + self.collection = self.client.create_collection( + self.name, + get_or_create=not self.use_existing, + embedding_function=embedding_function, + # https://github.com/nmslib/hnswlib#supported-distances + # https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184 + # https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md + metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}, # ip, l2, cosine + ) + + def ingest_data(self, data_dir): + """ + Create a vector database from a directory of files. 
+ Args: + data_dir: path to the directory containing the text files + """ + if self.client is None: + self.init_db() + if self.custom_text_split_function is not None: + chunks = split_files_to_chunks( + get_files_from_dir(data_dir), custom_text_split_function=self.custom_text_split_function + ) + else: + chunks = split_files_to_chunks( + get_files_from_dir(data_dir), self.max_tokens, self.chunk_mode, self.must_break_at_empty_line + ) + print(f"Found {len(chunks)} chunks.") # + # Upsert in batch of 40000 or less if the total number of chunks is less than 40000 + for i in range(0, len(chunks), min(40000, len(chunks))): + end_idx = i + min(40000, len(chunks) - i) + self.collection.upsert( + documents=chunks[i:end_idx], + ids=[f"doc_{j}" for j in range(i, end_idx)], # unique for each doc + ) + + def query(self, texts: List[str], top_k: int = 10, filter: str = None): + if self.client is None: + self.init_db() + # the collection's embedding function is always the default one, but we want to use the one we used to create the + # collection. So we compute the embeddings ourselves and pass it to the query function. + embedding_function = ( + ef.SentenceTransformerEmbeddingFunction(self.embedding_model_name) if self.embedding_function is None else self.embedding_function + ) + query_embeddings = embedding_function(texts) + # Query/search n most similar results. 
You can also .get by id + results = self.collection.query( + query_embeddings=query_embeddings, + n_results=top_k, + where_document={"$contains": filter} if filter else None, # optional filter + ) + return results \ No newline at end of file diff --git a/autogen/retriever/lancedb.py b/autogen/retriever/lancedb.py new file mode 100644 index 00000000000..596c54189de --- /dev/null +++ b/autogen/retriever/lancedb.py @@ -0,0 +1,9 @@ +from typing import Callable, List +from .base import Retriever +try: + import lancedb +except ImportError: + raise ImportError("Please install lancedb: pip install lancedb") + +class LanceDB(Retriever): + pass \ No newline at end of file From 144ed87c0e015e0820f41b9e0808d2bdd6435568 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Fri, 10 Nov 2023 15:39:23 +0530 Subject: [PATCH 02/52] update --- .../contrib/retrieve_user_proxy_agent.py | 2 +- autogen/retrieve_utils.py | 136 ------------------ 2 files changed, 1 insertion(+), 137 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index e430e31f8ae..fc7f8c234c9 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -160,7 +160,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ) self._retrieve_config = {} if retrieve_config is None else retrieve_config - self._retriever_type = self._retrieve_config.get("retriever_type", "chromadb") + self._retriever_type = self._retrieve_config.get("retriever_type") self._task = self._retrieve_config.get("task", "default") self._client = self._retrieve_config.get("client", None) self._docs_path = self._retrieve_config.get("docs_path", None) diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index b98ba862d1a..675341411e3 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -3,14 +3,6 @@ import requests from urllib.parse 
import urlparse import glob -import chromadb - -if chromadb.__version__ < "0.4.15": - from chromadb.api import API -else: - from chromadb.api import ClientAPI as API -from chromadb.api.types import QueryResult -import chromadb.utils.embedding_functions as ef import logging import pypdf from autogen.token_count_utils import count_token @@ -216,131 +208,3 @@ def is_url(string: str): return all([result.scheme, result.netloc]) except ValueError: return False - - -def create_vector_db_from_dir( - dir_path: str, - max_tokens: int = 4000, - client: API = None, - db_path: str = "/tmp/chromadb.db", - collection_name: str = "all-my-documents", - get_or_create: bool = False, - chunk_mode: str = "multi_lines", - must_break_at_empty_line: bool = True, - embedding_model: str = "all-MiniLM-L6-v2", - embedding_function: Callable = None, - custom_text_split_function: Callable = None, -) -> API: - """Create a vector db from all the files in a given directory, the directory can also be a single file or a url to - a single file. We support chromadb compatible APIs to create the vector db, this function is not required if - you prepared your own vector db. - - Args: - dir_path (str): the path to the directory, file or url. - max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000. - client (Optional, API): the chromadb client. Default is None. - db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db". - collection_name (Optional, str): the name of the collection. Default is "all-my-documents". - get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection - will be recreated if it already exists. - chunk_mode (Optional, str): the chunk mode. Default is "multi_lines". - must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True. - embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". 
Will be ignored if - embedding_function is not None. - embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with - the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding - functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`. - - Returns: - API: the chromadb client. - """ - if client is None: - client = chromadb.PersistentClient(path=db_path) - try: - embedding_function = ( - ef.SentenceTransformerEmbeddingFunction(embedding_model) - if embedding_function is None - else embedding_function - ) - collection = client.create_collection( - collection_name, - get_or_create=get_or_create, - embedding_function=embedding_function, - # https://github.com/nmslib/hnswlib#supported-distances - # https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184 - # https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md - metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}, # ip, l2, cosine - ) - - if custom_text_split_function is not None: - chunks = split_files_to_chunks( - get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function - ) - else: - chunks = split_files_to_chunks( - get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line - ) - logger.info(f"Found {len(chunks)} chunks.") - # Upsert in batch of 40000 or less if the total number of chunks is less than 40000 - for i in range(0, len(chunks), min(40000, len(chunks))): - end_idx = i + min(40000, len(chunks) - i) - collection.upsert( - documents=chunks[i:end_idx], - ids=[f"doc_{j}" for j in range(i, end_idx)], # unique for each doc - ) - except ValueError as e: - logger.warning(f"{e}") - return client - - -def query_vector_db( - query_texts: List[str], - n_results: int = 10, - client: API = None, - db_path: str = "/tmp/chromadb.db", - 
collection_name: str = "all-my-documents", - search_string: str = "", - embedding_model: str = "all-MiniLM-L6-v2", - embedding_function: Callable = None, -) -> QueryResult: - """Query a vector db. We support chromadb compatible APIs, it's not required if you prepared your own vector db - and query function. - - Args: - query_texts (List[str]): the query texts. - n_results (Optional, int): the number of results to return. Default is 10. - client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used. - db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db". - collection_name (Optional, str): the name of the collection. Default is "all-my-documents". - search_string (Optional, str): the search string. Default is "". - embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if - embedding_function is not None. - embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with - the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding - functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`. - - Returns: - QueryResult: the query result. The format is: - class QueryResult(TypedDict): - ids: List[IDs] - embeddings: Optional[List[List[Embedding]]] - documents: Optional[List[List[Document]]] - metadatas: Optional[List[List[Metadata]]] - distances: Optional[List[List[float]]] - """ - if client is None: - client = chromadb.PersistentClient(path=db_path) - # the collection's embedding function is always the default one, but we want to use the one we used to create the - # collection. So we compute the embeddings ourselves and pass it to the query function. 
- collection = client.get_collection(collection_name) - embedding_function = ( - ef.SentenceTransformerEmbeddingFunction(embedding_model) if embedding_function is None else embedding_function - ) - query_embeddings = embedding_function(query_texts) - # Query/search n most similar results. You can also .get by id - results = collection.query( - query_embeddings=query_embeddings, - n_results=n_results, - where_document={"$contains": search_string} if search_string else None, # optional filter - ) - return results From 4baf0ae20bf8cb464e2af0e50f10ff872e905e58 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Thu, 16 Nov 2023 16:44:50 +0530 Subject: [PATCH 03/52] update --- .../qdrant_retrieve_user_proxy_agent.py | 2 +- .../contrib/retrieve_user_proxy_agent.py | 1 - autogen/retriever/__init__.py | 7 +- autogen/retriever/chromadb.py | 6 +- autogen/retriever/lancedb.py | 94 ++++++++++++++++++- autogen/{ => retriever}/retrieve_utils.py | 0 6 files changed, 103 insertions(+), 7 deletions(-) rename autogen/{ => retriever}/retrieve_utils.py (100%) diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py index e0bb8d8216f..502f7ea711f 100644 --- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py @@ -1,7 +1,7 @@ from typing import Callable, Dict, List, Optional from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent -from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks +from autogen.autogen.retriever.retrieve_utils import get_files_from_dir, split_files_to_chunks import logging logger = logging.getLogger(__name__) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 07753dc10fa..728dc7161b6 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ 
b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -2,7 +2,6 @@ from autogen.agentchat.agent import Agent from autogen.agentchat import UserProxyAgent -from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db from autogen.token_count_utils import count_token from autogen.code_utils import extract_code from autogen.retriever import get_retriever diff --git a/autogen/retriever/__init__.py b/autogen/retriever/__init__.py index 9f74bee10ab..4dadba1cd18 100644 --- a/autogen/retriever/__init__.py +++ b/autogen/retriever/__init__.py @@ -1,7 +1,12 @@ +from typing import Optional from .chromadb import ChromaDB from .lancedb import LanceDB -def get_retriever(type:str): +DEFAULT_RETRIEVER = "lancedb" + +def get_retriever(type: Optional[str]=None): + """Return a retriever instance.""" + type = type or DEFAULT_RETRIEVER if type == "chromadb": return ChromaDB elif type == "lancedb": diff --git a/autogen/retriever/chromadb.py b/autogen/retriever/chromadb.py index 8020cc3c0bb..83c7283a05d 100644 --- a/autogen/retriever/chromadb.py +++ b/autogen/retriever/chromadb.py @@ -1,6 +1,6 @@ -from typing import Callable, List +from typing import List from .base import Retriever -from autogen.retrieve_utils import ( +from autogen.retriever.retrieve_utils import ( split_text_to_chunks, extract_text_from_pdf, split_files_to_chunks, @@ -76,4 +76,4 @@ def query(self, texts: List[str], top_k: int = 10, filter: str = None): n_results=top_k, where_document={"$contains": filter} if filter else None, # optional filter ) - return results \ No newline at end of file + return results diff --git a/autogen/retriever/lancedb.py b/autogen/retriever/lancedb.py index 596c54189de..374b57b84ad 100644 --- a/autogen/retriever/lancedb.py +++ b/autogen/retriever/lancedb.py @@ -1,9 +1,101 @@ from typing import Callable, List +from collections import defaultdict from .base import Retriever try: import lancedb + from lancedb.embeddings import get_registry, EmbeddingFunction, 
with_embeddings + from lancedb.pydantic import LanceModel, Vector + import pyarrow as pa except ImportError: raise ImportError("Please install lancedb: pip install lancedb") +from typing import List +from .base import Retriever +from autogen.retriever.retrieve_utils import ( + split_text_to_chunks, + extract_text_from_pdf, + split_files_to_chunks, + get_files_from_dir +) + + class LanceDB(Retriever): - pass \ No newline at end of file + db = None + def init_db(self): + if self.db is None: + self.db = lancedb.connect(self.path) + self.embedding_function = ( + get_registry().get("sentence-transformers").create(name=self.embedding_model_name) + if self.embedding_function is None + else self.embedding_function + ) + if self.use_existing and self.name in self.db.table_names(): + self.table = self.db.open_table(self.name) + else: + schema = self._get_schema(self.embedding_function) + self.table = self.db.create_table(self.name, schema=schema) + + def ingest_data(self, data_dir): + """ + Create a vector database from a directory of files. 
+ Args: + data_dir: path to the directory containing the text files + """ + if self.client is None: + self.init_db() + if self.custom_text_split_function is not None: + chunks = split_files_to_chunks( + get_files_from_dir(data_dir), custom_text_split_function=self.custom_text_split_function + ) + else: + chunks = split_files_to_chunks( + get_files_from_dir(data_dir), self.max_tokens, self.chunk_mode, self.must_break_at_empty_line + ) + print(f"Found {len(chunks)} chunks.") # + data = [ {"documents": docs, "ids": idx } for idx, docs in enumerate(chunks) ] + if isinstance(self.embedding_function, EmbeddingFunction): # this means we are using embedding API + self.table.add(data) + elif isinstance(self.embedding_function, Callable): + pa_table = pa.Table.from_pylist(data) + data = with_embeddings(self.embedding_function, pa_table) + self.table.add(data) + + + def query(self, texts: List[str], top_k: int = 10, filter: str = None): + if self.client is None: + self.init_db() + texts = [texts] if isinstance(texts, str) else texts + results = defaultdict(list) + for text in texts: + query = self.embedding_function(text) if isinstance(self.embedding_function, Callable) else text + print("query: ", query) + result = self.table.search(query).where(f"documents LIKE '%{filter}%'").limit(top_k).to_arrow().to_pydict() + for k, v in result.items(): + results[k].append(v) + + return results + + def _get_schema(self, embedding_function): + if isinstance(embedding_function, EmbeddingFunction): + class Schema(LanceModel): + vector: Vector(embedding_function.ndims()) = embedding_function.VectorField() + documents: str = embedding_function.SourceField() + ids: str + + return Schema + elif isinstance(embedding_function, Callable): + dim = embedding_function("test").shape[0] # TODO: check this + schema = pa.schema( + [ + pa.field("Vector", pa.list_(pa.float32(), dim)), + pa.field("documents", pa.string()), + pa.field("ids", pa.string()), + ] + ) + return schema + else: + raise ValueError( 
+ "embedding_function should be a callable or an EmbeddingFunction instance" + ) + + diff --git a/autogen/retrieve_utils.py b/autogen/retriever/retrieve_utils.py similarity index 100% rename from autogen/retrieve_utils.py rename to autogen/retriever/retrieve_utils.py From 3f547d466b6d3012476345e428762cd8e6a34b23 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Fri, 17 Nov 2023 15:44:23 +0530 Subject: [PATCH 04/52] add tests --- .../contrib/retrieve_user_proxy_agent.py | 2 +- .../contrib}/retriever/__init__.py | 1 + .../{ => agentchat/contrib}/retriever/base.py | 2 +- .../contrib}/retriever/chromadb.py | 4 +- .../contrib}/retriever/lancedb.py | 10 +- .../contrib}/retriever/retrieve_utils.py | 0 .../contrib/retrievers/test_chromadb.py | 33 +++++ .../contrib/retrievers/test_lancedb.py | 33 +++++ test/test_retrieve_utils.py | 133 ++---------------- 9 files changed, 94 insertions(+), 124 deletions(-) rename autogen/{ => agentchat/contrib}/retriever/__init__.py (89%) rename autogen/{ => agentchat/contrib}/retriever/base.py (98%) rename autogen/{ => agentchat/contrib}/retriever/chromadb.py (97%) rename autogen/{ => agentchat/contrib}/retriever/lancedb.py (90%) rename autogen/{ => agentchat/contrib}/retriever/retrieve_utils.py (100%) create mode 100644 test/agentchat/contrib/retrievers/test_chromadb.py create mode 100644 test/agentchat/contrib/retrievers/test_lancedb.py diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 728dc7161b6..df9a2f53c17 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -4,7 +4,7 @@ from autogen.agentchat import UserProxyAgent from autogen.token_count_utils import count_token from autogen.code_utils import extract_code -from autogen.retriever import get_retriever +from autogen.agentchat.contrib.retriever import get_retriever from typing import Callable, Dict, Optional, Union, List, 
Tuple, Any from IPython import get_ipython diff --git a/autogen/retriever/__init__.py b/autogen/agentchat/contrib/retriever/__init__.py similarity index 89% rename from autogen/retriever/__init__.py rename to autogen/agentchat/contrib/retriever/__init__.py index 4dadba1cd18..e2b55785914 100644 --- a/autogen/retriever/__init__.py +++ b/autogen/agentchat/contrib/retriever/__init__.py @@ -2,6 +2,7 @@ from .chromadb import ChromaDB from .lancedb import LanceDB +AVILABLE_RETRIEVERS = ["lanchedb", "chromadb"] DEFAULT_RETRIEVER = "lancedb" def get_retriever(type: Optional[str]=None): diff --git a/autogen/retriever/base.py b/autogen/agentchat/contrib/retriever/base.py similarity index 98% rename from autogen/retriever/base.py rename to autogen/agentchat/contrib/retriever/base.py index 20133f5ff58..2e886e83f13 100644 --- a/autogen/retriever/base.py +++ b/autogen/agentchat/contrib/retriever/base.py @@ -10,7 +10,7 @@ def __init__(self, path="./db", chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True, custom_text_split_function: Callable = None, - use_existing=True, + use_existing=False, client=None ): """ diff --git a/autogen/retriever/chromadb.py b/autogen/agentchat/contrib/retriever/chromadb.py similarity index 97% rename from autogen/retriever/chromadb.py rename to autogen/agentchat/contrib/retriever/chromadb.py index 83c7283a05d..4ba27be0b95 100644 --- a/autogen/retriever/chromadb.py +++ b/autogen/agentchat/contrib/retriever/chromadb.py @@ -1,6 +1,6 @@ from typing import List from .base import Retriever -from autogen.retriever.retrieve_utils import ( +from .retrieve_utils import ( split_text_to_chunks, extract_text_from_pdf, split_files_to_chunks, @@ -28,7 +28,7 @@ def init_db(self): ) self.collection = self.client.create_collection( self.name, - get_or_create=not self.use_existing, + get_or_create=self.use_existing, embedding_function=embedding_function, # https://github.com/nmslib/hnswlib#supported-distances # 
https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184 diff --git a/autogen/retriever/lancedb.py b/autogen/agentchat/contrib/retriever/lancedb.py similarity index 90% rename from autogen/retriever/lancedb.py rename to autogen/agentchat/contrib/retriever/lancedb.py index 374b57b84ad..31dd846ddbb 100644 --- a/autogen/retriever/lancedb.py +++ b/autogen/agentchat/contrib/retriever/lancedb.py @@ -11,7 +11,8 @@ from typing import List from .base import Retriever -from autogen.retriever.retrieve_utils import ( +from autogen import logger +from .retrieve_utils import ( split_text_to_chunks, extract_text_from_pdf, split_files_to_chunks, @@ -31,7 +32,9 @@ def init_db(self): ) if self.use_existing and self.name in self.db.table_names(): self.table = self.db.open_table(self.name) + logger.info(f"Reusing existing table {self.name}") else: + logger.info(f"Creating new table {self.name}") schema = self._get_schema(self.embedding_function) self.table = self.db.create_table(self.name, schema=schema) @@ -69,7 +72,10 @@ def query(self, texts: List[str], top_k: int = 10, filter: str = None): for text in texts: query = self.embedding_function(text) if isinstance(self.embedding_function, Callable) else text print("query: ", query) - result = self.table.search(query).where(f"documents LIKE '%{filter}%'").limit(top_k).to_arrow().to_pydict() + result = self.table.search(query) + if filter is not None: + result = result.where(f"documents LIKE '%{filter}%'") + result = result.limit(top_k).to_arrow().to_pydict() for k, v in result.items(): results[k].append(v) diff --git a/autogen/retriever/retrieve_utils.py b/autogen/agentchat/contrib/retriever/retrieve_utils.py similarity index 100% rename from autogen/retriever/retrieve_utils.py rename to autogen/agentchat/contrib/retriever/retrieve_utils.py diff --git a/test/agentchat/contrib/retrievers/test_chromadb.py b/test/agentchat/contrib/retrievers/test_chromadb.py new 
file mode 100644 index 00000000000..d6b9dce11b9 --- /dev/null +++ b/test/agentchat/contrib/retrievers/test_chromadb.py @@ -0,0 +1,33 @@ +import os +import pytest +from autogen.agentchat.contrib.retriever.retrieve_utils import ( + split_text_to_chunks, + extract_text_from_pdf, + split_files_to_chunks, + get_files_from_dir, + is_url, +) +from autogen.agentchat.contrib.retriever.chromadb import ChromaDB +try: + import chromadb +except ImportError: + skip = True +else: + skip = False + +test_dir = os.path.join(os.path.dirname(__file__), "test_files") + +@pytest.mark.skipif(skip, reason="chromadb is not installed") +def test_chromadb(): + db_path = "/tmp/test_retrieve_utils_chromadb.db" + client = chromadb.PersistentClient(path=db_path) + if os.path.exists(db_path): + vectorstore = ChromaDB(path=db_path, use_existing=True) + else: + vectorstore = ChromaDB(path=db_path) + vectorstore.ingest_data(test_dir) + + assert client.get_collection("vectorstore") + + results = vectorstore.query(["autogen"]) + assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) diff --git a/test/agentchat/contrib/retrievers/test_lancedb.py b/test/agentchat/contrib/retrievers/test_lancedb.py new file mode 100644 index 00000000000..dee605df5f9 --- /dev/null +++ b/test/agentchat/contrib/retrievers/test_lancedb.py @@ -0,0 +1,33 @@ +import os +import pytest +from autogen.agentchat.contrib.retriever.retrieve_utils import ( + split_text_to_chunks, + extract_text_from_pdf, + split_files_to_chunks, + get_files_from_dir, + is_url, +) +from autogen.agentchat.contrib.retriever.lancedb import LanceDB +try: + import lancedb +except ImportError: + skip = True +else: + skip = False + +test_dir = os.path.join(os.path.dirname(__file__), "test_files") + +@pytest.mark.skipif(skip, reason="lancedb is not installed") +def test_lancedb(): + db_path = "/tmp/test_lancedb_store" + db = lancedb.connect(db_path) + if os.path.exists(db_path): + vectorstore = 
LanceDB(path=db_path, use_existing=True) + else: + vectorstore = LanceDB(path=db_path) + vectorstore.ingest_data(test_dir) + + assert "vectorstore" in db.table_names() + + results = vectorstore.query(["autogen"]) + assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py index b85356ef491..98fc9576746 100644 --- a/test/test_retrieve_utils.py +++ b/test/test_retrieve_utils.py @@ -1,25 +1,25 @@ """ Unit test for retrieve_utils.py """ +import os +import sys +import pytest + try: - import chromadb - from autogen.retrieve_utils import ( + from autogen.agentchat.contrib.retriever.retrieve_utils import ( split_text_to_chunks, extract_text_from_pdf, split_files_to_chunks, get_files_from_dir, is_url, - create_vector_db_from_dir, - query_vector_db, ) + from autogen.agentchat.contrib.retriever import DEFAULT_RETRIEVER, get_retriever from autogen.token_count_utils import count_token + Retriever = get_retriever(DEFAULT_RETRIEVER) except ImportError: skip = True else: skip = False -import os -import sys -import pytest try: from unstructured.partition.auto import partition @@ -71,127 +71,24 @@ def test_is_url(self): assert is_url("https://www.example.com") assert not is_url("not_a_url") - def test_create_vector_db_from_dir(self): - db_path = "/tmp/test_retrieve_utils_chromadb.db" - if os.path.exists(db_path): - client = chromadb.PersistentClient(path=db_path) - else: - client = chromadb.PersistentClient(path=db_path) - create_vector_db_from_dir(test_dir, client=client) - - assert client.get_collection("all-my-documents") - - def test_query_vector_db(self): - db_path = "/tmp/test_retrieve_utils_chromadb.db" - if os.path.exists(db_path): - client = chromadb.PersistentClient(path=db_path) - else: # If the database does not exist, create it first - client = chromadb.PersistentClient(path=db_path) - create_vector_db_from_dir(test_dir, client=client) - - results = 
query_vector_db(["autogen"], client=client) - assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) - - def test_custom_vector_db(self): - try: - import lancedb - except ImportError: - return - from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent - - db_path = "/tmp/lancedb" - - def create_lancedb(): - db = lancedb.connect(db_path) - data = [ - {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"}, - {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"}, - {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"}, - {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"}, - {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"}, - {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"}, - ] - try: - db.create_table("my_table", data) - except OSError: - pass - - class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): - def query_vector_db( - self, - query_texts, - n_results=10, - search_string="", - ): - if query_texts: - vector = [0.1, 0.3] - db = lancedb.connect(db_path) - table = db.open_table("my_table") - query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df() - return {"ids": [query["id"].tolist()], "documents": [query["documents"].tolist()]} - - def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): - results = self.query_vector_db( - query_texts=[problem], - n_results=n_results, - search_string=search_string, - ) - - self._results = results - print("doc_ids: ", results["ids"]) - - ragragproxyagent = MyRetrieveUserProxyAgent( - name="ragproxyagent", - human_input_mode="NEVER", - max_consecutive_auto_reply=2, - retrieve_config={ - "task": "qa", - "chunk_token_size": 2000, - "client": "__", - "embedding_model": "all-mpnet-base-v2", - }, - ) - - 
create_lancedb() - ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark") - assert ragragproxyagent._results["ids"] == [[3, 1, 5]] def test_custom_text_split_function(self): def custom_text_split_function(text): return [text[: len(text) // 2], text[len(text) // 2 :]] - db_path = "/tmp/test_retrieve_utils_chromadb.db" - client = chromadb.PersistentClient(path=db_path) - create_vector_db_from_dir( - os.path.join(test_dir, "example.txt"), - client=client, - collection_name="mytestcollection", - custom_text_split_function=custom_text_split_function, - get_or_create=True, - ) - results = query_vector_db(["autogen"], client=client, collection_name="mytestcollection", n_results=1) + db_path = "/tmp/test_retrieve_utils" + retriever = Retriever(path=db_path, name="mytestcollection", custom_text_split_function=custom_text_split_function, use_existing=False) + retriever.ingest_data( os.path.join(test_dir, "example.txt")) + results = retriever.query(["autogen"], top_k=1) assert ( "AutoGen is an advanced tool designed to assist developers in harnessing the capabilities" in results.get("documents")[0][0] ) def test_retrieve_utils(self): - client = chromadb.PersistentClient(path="/tmp/chromadb") - create_vector_db_from_dir( - dir_path="./website/docs", - client=client, - collection_name="autogen-docs", - get_or_create=True, - ) - results = query_vector_db( - query_texts=[ - "How can I use AutoGen UserProxyAgent and AssistantAgent to do code generation?", - ], - n_results=4, - client=client, - collection_name="autogen-docs", - search_string="AutoGen", - ) + retriever = Retriever(path="/tmp/chromadb", name="autogen-docs", use_existing=False) + retriever.ingest_data("./website/docs") + results = retriever.query(["autogen"], top_k=4, filter="AutoGen") print(results["ids"][0]) assert len(results["ids"][0]) == 4 @@ -208,7 +105,7 @@ def test_unstructured(self): isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist 
developers" in chunk.strip() for chunk in chunks ) - + if __name__ == "__main__": pytest.main() From 77aa60f754c3a0afaca435ecaf7496355717de01 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Fri, 17 Nov 2023 20:44:36 +0530 Subject: [PATCH 05/52] update tests --- autogen/agentchat/contrib/retriever/lancedb.py | 6 +++--- test/test_retrieve_utils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/retriever/lancedb.py b/autogen/agentchat/contrib/retriever/lancedb.py index 31dd846ddbb..982309997f0 100644 --- a/autogen/agentchat/contrib/retriever/lancedb.py +++ b/autogen/agentchat/contrib/retriever/lancedb.py @@ -36,7 +36,7 @@ def init_db(self): else: logger.info(f"Creating new table {self.name}") schema = self._get_schema(self.embedding_function) - self.table = self.db.create_table(self.name, schema=schema) + self.table = self.db.create_table(self.name, schema=schema, mode="overwrite") def ingest_data(self, data_dir): """ @@ -44,7 +44,7 @@ def ingest_data(self, data_dir): Args: data_dir: path to the directory containing the text files """ - if self.client is None: + if self.db is None: self.init_db() if self.custom_text_split_function is not None: chunks = split_files_to_chunks( @@ -65,7 +65,7 @@ def ingest_data(self, data_dir): def query(self, texts: List[str], top_k: int = 10, filter: str = None): - if self.client is None: + if self.db is None: self.init_db() texts = [texts] if isinstance(texts, str) else texts results = defaultdict(list) diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py index 98fc9576746..39e58e330dd 100644 --- a/test/test_retrieve_utils.py +++ b/test/test_retrieve_utils.py @@ -78,7 +78,7 @@ def custom_text_split_function(text): db_path = "/tmp/test_retrieve_utils" retriever = Retriever(path=db_path, name="mytestcollection", custom_text_split_function=custom_text_split_function, use_existing=False) - retriever.ingest_data( os.path.join(test_dir, "example.txt")) + 
retriever.ingest_data(os.path.join(test_dir, "example.txt")) results = retriever.query(["autogen"], top_k=1) assert ( "AutoGen is an advanced tool designed to assist developers in harnessing the capabilities" From f400fcbcfef2dcf2155498cf4bbf5146e1cf7632 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Fri, 17 Nov 2023 21:06:35 +0530 Subject: [PATCH 06/52] format --- .../contrib/retrieve_user_proxy_agent.py | 22 ++++++------ .../agentchat/contrib/retriever/__init__.py | 5 +-- autogen/agentchat/contrib/retriever/base.py | 34 ++++++++++--------- .../agentchat/contrib/retriever/chromadb.py | 20 +++++------ .../agentchat/contrib/retriever/lancedb.py | 29 ++++++---------- 5 files changed, 53 insertions(+), 57 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index df9a2f53c17..b644473d8f9 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -341,7 +341,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = """Retrieve docs based on the given problem and assign the results to the class property `_results`. Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of - the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. + the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. 
ids: List[string] documents: List[List[string]] @@ -354,16 +354,16 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = print("Trying to create collection.") retriever_class = get_retriever(self._retriever_type) self.retriever = retriever_class( - name=self._collection_name, - embedding_model_name=self._embedding_model, - embedding_function=self._embedding_function, - max_tokens= self._chunk_token_size, - chunk_mode = self._chunk_mode, - must_break_at_empty_line = self._must_break_at_empty_line, - custom_text_split_function = self.custom_text_split_function, - use_existing=not self._get_or_create, - client=self._client - ) + name=self._collection_name, + embedding_model_name=self._embedding_model, + embedding_function=self._embedding_function, + max_tokens=self._chunk_token_size, + chunk_mode=self._chunk_mode, + must_break_at_empty_line=self._must_break_at_empty_line, + custom_text_split_function=self.custom_text_split_function, + use_existing=not self._get_or_create, + client=self._client, + ) self._collection = True self._get_or_create = False self.retriever.ingest_data(self._docs_path) diff --git a/autogen/agentchat/contrib/retriever/__init__.py b/autogen/agentchat/contrib/retriever/__init__.py index e2b55785914..9bfb09598a6 100644 --- a/autogen/agentchat/contrib/retriever/__init__.py +++ b/autogen/agentchat/contrib/retriever/__init__.py @@ -5,7 +5,8 @@ AVILABLE_RETRIEVERS = ["lanchedb", "chromadb"] DEFAULT_RETRIEVER = "lancedb" -def get_retriever(type: Optional[str]=None): + +def get_retriever(type: Optional[str] = None): """Return a retriever instance.""" type = type or DEFAULT_RETRIEVER if type == "chromadb": @@ -13,4 +14,4 @@ def get_retriever(type: Optional[str]=None): elif type == "lancedb": return LanceDB else: - raise ValueError(f"Unknown retriever type {type}") \ No newline at end of file + raise ValueError(f"Unknown retriever type {type}") diff --git a/autogen/agentchat/contrib/retriever/base.py 
b/autogen/agentchat/contrib/retriever/base.py index 2e886e83f13..4c5dfd69f44 100644 --- a/autogen/agentchat/contrib/retriever/base.py +++ b/autogen/agentchat/contrib/retriever/base.py @@ -1,18 +1,21 @@ from abc import ABC, abstractmethod from typing import List, Union, Callable, Any + class Retriever(ABC): - def __init__(self, path="./db", - name="vectorstore", - embedding_model_name="all-MiniLM-L6-v2", - embedding_function=None, - max_tokens: int = 4000, - chunk_mode: str = "multi_lines", - must_break_at_empty_line: bool = True, - custom_text_split_function: Callable = None, - use_existing=False, - client=None - ): + def __init__( + self, + path="./db", + name="vectorstore", + embedding_model_name="all-MiniLM-L6-v2", + embedding_function=None, + max_tokens: int = 4000, + chunk_mode: str = "multi_lines", + must_break_at_empty_line: bool = True, + custom_text_split_function: Callable = None, + use_existing=False, + client=None, + ): """ Args: path: path to the folder where the database is stored @@ -37,7 +40,6 @@ def __init__(self, path="./db", self.init_db() - @abstractmethod def ingest_data(self, data_dir): """ @@ -46,9 +48,9 @@ def ingest_data(self, data_dir): data_dir: path to the directory containing the text files """ pass - + @abstractmethod - def query(self, texts: List[str], top_k: int = 10, filter: Any=None): + def query(self, texts: List[str], top_k: int = 10, filter: Any = None): """ Query the database. Args: @@ -56,10 +58,10 @@ def query(self, texts: List[str], top_k: int = 10, filter: Any=None): top_k: number of results to return """ pass - + @abstractmethod def init_db(self): """ Initialize the database. 
""" - pass \ No newline at end of file + pass diff --git a/autogen/agentchat/contrib/retriever/chromadb.py b/autogen/agentchat/contrib/retriever/chromadb.py index 4ba27be0b95..4be8cdc6ab2 100644 --- a/autogen/agentchat/contrib/retriever/chromadb.py +++ b/autogen/agentchat/contrib/retriever/chromadb.py @@ -1,13 +1,10 @@ from typing import List from .base import Retriever -from .retrieve_utils import ( - split_text_to_chunks, - extract_text_from_pdf, - split_files_to_chunks, - get_files_from_dir -) +from .retrieve_utils import split_text_to_chunks, extract_text_from_pdf, split_files_to_chunks, get_files_from_dir + try: import chromadb + if chromadb.__version__ < "0.4.15": from chromadb.api import API else: @@ -17,6 +14,7 @@ except ImportError: raise ImportError("Please install chromadb: pip install chromadb") + class ChromaDB(Retriever): def init_db(self): if self.client is None: @@ -35,7 +33,7 @@ def init_db(self): # https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}, # ip, l2, cosine ) - + def ingest_data(self, data_dir): """ Create a vector database from a directory of files. 
@@ -52,14 +50,14 @@ def ingest_data(self, data_dir): chunks = split_files_to_chunks( get_files_from_dir(data_dir), self.max_tokens, self.chunk_mode, self.must_break_at_empty_line ) - print(f"Found {len(chunks)} chunks.") # + print(f"Found {len(chunks)} chunks.") # # Upsert in batch of 40000 or less if the total number of chunks is less than 40000 for i in range(0, len(chunks), min(40000, len(chunks))): end_idx = i + min(40000, len(chunks) - i) self.collection.upsert( documents=chunks[i:end_idx], ids=[f"doc_{j}" for j in range(i, end_idx)], # unique for each doc - ) + ) def query(self, texts: List[str], top_k: int = 10, filter: str = None): if self.client is None: @@ -67,7 +65,9 @@ def query(self, texts: List[str], top_k: int = 10, filter: str = None): # the collection's embedding function is always the default one, but we want to use the one we used to create the # collection. So we compute the embeddings ourselves and pass it to the query function. embedding_function = ( - ef.SentenceTransformerEmbeddingFunction(self.embedding_model_name) if self.embedding_function is None else self.embedding_function + ef.SentenceTransformerEmbeddingFunction(self.embedding_model_name) + if self.embedding_function is None + else self.embedding_function ) query_embeddings = embedding_function(texts) # Query/search n most similar results. 
You can also .get by id diff --git a/autogen/agentchat/contrib/retriever/lancedb.py b/autogen/agentchat/contrib/retriever/lancedb.py index 982309997f0..cf19dd74f8e 100644 --- a/autogen/agentchat/contrib/retriever/lancedb.py +++ b/autogen/agentchat/contrib/retriever/lancedb.py @@ -1,6 +1,7 @@ from typing import Callable, List from collections import defaultdict from .base import Retriever + try: import lancedb from lancedb.embeddings import get_registry, EmbeddingFunction, with_embeddings @@ -12,16 +13,12 @@ from typing import List from .base import Retriever from autogen import logger -from .retrieve_utils import ( - split_text_to_chunks, - extract_text_from_pdf, - split_files_to_chunks, - get_files_from_dir -) +from .retrieve_utils import split_text_to_chunks, extract_text_from_pdf, split_files_to_chunks, get_files_from_dir class LanceDB(Retriever): db = None + def init_db(self): if self.db is None: self.db = lancedb.connect(self.path) @@ -37,7 +34,7 @@ def init_db(self): logger.info(f"Creating new table {self.name}") schema = self._get_schema(self.embedding_function) self.table = self.db.create_table(self.name, schema=schema, mode="overwrite") - + def ingest_data(self, data_dir): """ Create a vector database from a directory of files. 
@@ -54,16 +51,15 @@ def ingest_data(self, data_dir): chunks = split_files_to_chunks( get_files_from_dir(data_dir), self.max_tokens, self.chunk_mode, self.must_break_at_empty_line ) - print(f"Found {len(chunks)} chunks.") # - data = [ {"documents": docs, "ids": idx } for idx, docs in enumerate(chunks) ] - if isinstance(self.embedding_function, EmbeddingFunction): # this means we are using embedding API + print(f"Found {len(chunks)} chunks.") # + data = [{"documents": docs, "ids": idx} for idx, docs in enumerate(chunks)] + if isinstance(self.embedding_function, EmbeddingFunction): # this means we are using embedding API self.table.add(data) elif isinstance(self.embedding_function, Callable): pa_table = pa.Table.from_pylist(data) data = with_embeddings(self.embedding_function, pa_table) self.table.add(data) - def query(self, texts: List[str], top_k: int = 10, filter: str = None): if self.db is None: self.init_db() @@ -78,11 +74,12 @@ def query(self, texts: List[str], top_k: int = 10, filter: str = None): result = result.limit(top_k).to_arrow().to_pydict() for k, v in result.items(): results[k].append(v) - + return results def _get_schema(self, embedding_function): if isinstance(embedding_function, EmbeddingFunction): + class Schema(LanceModel): vector: Vector(embedding_function.ndims()) = embedding_function.VectorField() documents: str = embedding_function.SourceField() @@ -90,7 +87,7 @@ class Schema(LanceModel): return Schema elif isinstance(embedding_function, Callable): - dim = embedding_function("test").shape[0] # TODO: check this + dim = embedding_function("test").shape[0] # TODO: check this schema = pa.schema( [ pa.field("Vector", pa.list_(pa.float32(), dim)), @@ -100,8 +97,4 @@ class Schema(LanceModel): ) return schema else: - raise ValueError( - "embedding_function should be a callable or an EmbeddingFunction instance" - ) - - + raise ValueError("embedding_function should be a callable or an EmbeddingFunction instance") From 
65ad434b3098532b6f48d25d98e0963f6ae53bed Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Mon, 20 Nov 2023 13:34:11 +0530 Subject: [PATCH 07/52] update --- autogen/agentchat/__init__.py | 2 ++ autogen/agentchat/contrib/retriever/lancedb.py | 7 +++---- autogen/retrieve_utils.py | 4 ++++ setup.py | 4 ++-- test/agentchat/contrib/retrievers/test_chromadb.py | 2 +- test/agentchat/contrib/retrievers/test_lancedb.py | 2 +- 6 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 autogen/retrieve_utils.py diff --git a/autogen/agentchat/__init__.py b/autogen/agentchat/__init__.py index 3db1db73a55..7c7a256dca1 100644 --- a/autogen/agentchat/__init__.py +++ b/autogen/agentchat/__init__.py @@ -3,6 +3,7 @@ from .conversable_agent import ConversableAgent from .groupchat import GroupChat, GroupChatManager from .user_proxy_agent import UserProxyAgent +from .contrib.retriever import retrieve_utils __all__ = [ "Agent", @@ -11,4 +12,5 @@ "UserProxyAgent", "GroupChat", "GroupChatManager", + "retrieve_utils" ] diff --git a/autogen/agentchat/contrib/retriever/lancedb.py b/autogen/agentchat/contrib/retriever/lancedb.py index cf19dd74f8e..41a2d21b56a 100644 --- a/autogen/agentchat/contrib/retriever/lancedb.py +++ b/autogen/agentchat/contrib/retriever/lancedb.py @@ -12,7 +12,6 @@ from typing import List from .base import Retriever -from autogen import logger from .retrieve_utils import split_text_to_chunks, extract_text_from_pdf, split_files_to_chunks, get_files_from_dir @@ -23,15 +22,15 @@ def init_db(self): if self.db is None: self.db = lancedb.connect(self.path) self.embedding_function = ( - get_registry().get("sentence-transformers").create(name=self.embedding_model_name) + get_registry().get("sentence-transformers").create(name=self.embedding_model_name, show_progress_bar=True) if self.embedding_function is None else self.embedding_function ) if self.use_existing and self.name in self.db.table_names(): self.table = self.db.open_table(self.name) - logger.info(f"Reusing 
existing table {self.name}") + #logger.info(f"Reusing existing table {self.name}") else: - logger.info(f"Creating new table {self.name}") + #logger.info(f"Creating new table {self.name}") schema = self._get_schema(self.embedding_function) self.table = self.db.create_table(self.name, schema=schema, mode="overwrite") diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py new file mode 100644 index 00000000000..3b94a2b2814 --- /dev/null +++ b/autogen/retrieve_utils.py @@ -0,0 +1,4 @@ +from . import logger +from .agentchat.contrib.retriever.retrieve_utils import * + +logger.warning("This module is deprecated. Please use autogen.agentchat.contrib.retriever.retrieve_utils instead.") \ No newline at end of file diff --git a/setup.py b/setup.py index ab7ab28ade1..479dc599c27 100644 --- a/setup.py +++ b/setup.py @@ -50,8 +50,8 @@ ], "blendsearch": ["flaml[blendsearch]"], "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"], - "retrievechat": ["chromadb", "sentence_transformers", "pypdf", "ipython"], - "teachable": ["chromadb"], + "retrievechat": ["chromadb", "lancedb", "sentence_transformers", "pypdf", "ipython"], + "teachable": ["chromadb", "lancedb"], "lmm": ["replicate", "pillow"], }, classifiers=[ diff --git a/test/agentchat/contrib/retrievers/test_chromadb.py b/test/agentchat/contrib/retrievers/test_chromadb.py index d6b9dce11b9..735e38cd74f 100644 --- a/test/agentchat/contrib/retrievers/test_chromadb.py +++ b/test/agentchat/contrib/retrievers/test_chromadb.py @@ -7,8 +7,8 @@ get_files_from_dir, is_url, ) -from autogen.agentchat.contrib.retriever.chromadb import ChromaDB try: + from autogen.agentchat.contrib.retriever.chromadb import ChromaDB import chromadb except ImportError: skip = True diff --git a/test/agentchat/contrib/retrievers/test_lancedb.py b/test/agentchat/contrib/retrievers/test_lancedb.py index dee605df5f9..64f9b1efe07 100644 --- a/test/agentchat/contrib/retrievers/test_lancedb.py +++ b/test/agentchat/contrib/retrievers/test_lancedb.py @@ 
-7,8 +7,8 @@ get_files_from_dir, is_url, ) -from autogen.agentchat.contrib.retriever.lancedb import LanceDB try: + from autogen.agentchat.contrib.retriever.lancedb import LanceDB import lancedb except ImportError: skip = True From 369de5390e5d99b8e7953ff7a907509ea4152b04 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Mon, 20 Nov 2023 13:49:23 +0530 Subject: [PATCH 08/52] update --- autogen/agentchat/contrib/retriever/retrieve_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/retriever/retrieve_utils.py b/autogen/agentchat/contrib/retriever/retrieve_utils.py index 675341411e3..a254d2f0294 100644 --- a/autogen/agentchat/contrib/retriever/retrieve_utils.py +++ b/autogen/agentchat/contrib/retriever/retrieve_utils.py @@ -4,7 +4,6 @@ from urllib.parse import urlparse import glob import logging -import pypdf from autogen.token_count_utils import count_token try: @@ -91,6 +90,8 @@ def split_text_to_chunks( def extract_text_from_pdf(file: str) -> str: """Extract text from PDF files""" + import pypdf # optional dependency + text = "" with open(file, "rb") as f: reader = pypdf.PdfReader(f) From af4b2479af349a37f62b57ab91dee7e7c6d89622 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Mon, 20 Nov 2023 13:52:28 +0530 Subject: [PATCH 09/52] update --- autogen/agentchat/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/autogen/agentchat/__init__.py b/autogen/agentchat/__init__.py index 7c7a256dca1..3db1db73a55 100644 --- a/autogen/agentchat/__init__.py +++ b/autogen/agentchat/__init__.py @@ -3,7 +3,6 @@ from .conversable_agent import ConversableAgent from .groupchat import GroupChat, GroupChatManager from .user_proxy_agent import UserProxyAgent -from .contrib.retriever import retrieve_utils __all__ = [ "Agent", @@ -12,5 +11,4 @@ "UserProxyAgent", "GroupChat", "GroupChatManager", - "retrieve_utils" ] From 5b9a43ed79146af1c78209e9801d22b06056656c Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Mon, 20 
Nov 2023 14:18:41 +0530 Subject: [PATCH 10/52] update --- autogen/agentchat/contrib/retriever/lancedb.py | 4 ++-- autogen/agentchat/contrib/retriever/retrieve_utils.py | 2 +- autogen/retrieve_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/retriever/lancedb.py b/autogen/agentchat/contrib/retriever/lancedb.py index 41a2d21b56a..00417f806b7 100644 --- a/autogen/agentchat/contrib/retriever/lancedb.py +++ b/autogen/agentchat/contrib/retriever/lancedb.py @@ -28,9 +28,9 @@ def init_db(self): ) if self.use_existing and self.name in self.db.table_names(): self.table = self.db.open_table(self.name) - #logger.info(f"Reusing existing table {self.name}") + # logger.info(f"Reusing existing table {self.name}") else: - #logger.info(f"Creating new table {self.name}") + # logger.info(f"Creating new table {self.name}") schema = self._get_schema(self.embedding_function) self.table = self.db.create_table(self.name, schema=schema, mode="overwrite") diff --git a/autogen/agentchat/contrib/retriever/retrieve_utils.py b/autogen/agentchat/contrib/retriever/retrieve_utils.py index a254d2f0294..b02ab66d6a3 100644 --- a/autogen/agentchat/contrib/retriever/retrieve_utils.py +++ b/autogen/agentchat/contrib/retriever/retrieve_utils.py @@ -90,7 +90,7 @@ def split_text_to_chunks( def extract_text_from_pdf(file: str) -> str: """Extract text from PDF files""" - import pypdf # optional dependency + import pypdf # optional dependency text = "" with open(file, "rb") as f: diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index 3b94a2b2814..6e92a938c27 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -1,4 +1,4 @@ from . import logger from .agentchat.contrib.retriever.retrieve_utils import * -logger.warning("This module is deprecated. Please use autogen.agentchat.contrib.retriever.retrieve_utils instead.") \ No newline at end of file +logger.warning("This module is deprecated. 
Please use autogen.agentchat.contrib.retriever.retrieve_utils instead.") From 9978196e8dc4a891590727d90a05240220264d14 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Mon, 20 Nov 2023 14:31:11 +0530 Subject: [PATCH 11/52] update --- autogen/agentchat/contrib/retriever/chromadb.py | 2 +- autogen/agentchat/contrib/retriever/lancedb.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/autogen/agentchat/contrib/retriever/chromadb.py b/autogen/agentchat/contrib/retriever/chromadb.py index 4be8cdc6ab2..05b0951970f 100644 --- a/autogen/agentchat/contrib/retriever/chromadb.py +++ b/autogen/agentchat/contrib/retriever/chromadb.py @@ -22,7 +22,7 @@ def init_db(self): embedding_function = ( ef.SentenceTransformerEmbeddingFunction(self.embedding_model_name) if self.embedding_function is None - else embedding_function + else self.embedding_function ) self.collection = self.client.create_collection( self.name, diff --git a/autogen/agentchat/contrib/retriever/lancedb.py b/autogen/agentchat/contrib/retriever/lancedb.py index 00417f806b7..cf7af621f1b 100644 --- a/autogen/agentchat/contrib/retriever/lancedb.py +++ b/autogen/agentchat/contrib/retriever/lancedb.py @@ -1,6 +1,7 @@ from typing import Callable, List from collections import defaultdict from .base import Retriever +from .retrieve_utils import split_text_to_chunks, extract_text_from_pdf, split_files_to_chunks, get_files_from_dir try: import lancedb @@ -10,10 +11,6 @@ except ImportError: raise ImportError("Please install lancedb: pip install lancedb") -from typing import List -from .base import Retriever -from .retrieve_utils import split_text_to_chunks, extract_text_from_pdf, split_files_to_chunks, get_files_from_dir - class LanceDB(Retriever): db = None From f2739cfb2b05d0b5f4bd2269cae707a39be21a65 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Mon, 20 Nov 2023 20:49:15 +0530 Subject: [PATCH 12/52] update --- test/agentchat/contrib/retrievers/test_chromadb.py | 8 +++++--- 
test/agentchat/contrib/retrievers/test_lancedb.py | 8 +++++--- test/test_retrieve_utils.py | 11 ++++++++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/test/agentchat/contrib/retrievers/test_chromadb.py b/test/agentchat/contrib/retrievers/test_chromadb.py index 735e38cd74f..f7325051725 100644 --- a/test/agentchat/contrib/retrievers/test_chromadb.py +++ b/test/agentchat/contrib/retrievers/test_chromadb.py @@ -7,6 +7,7 @@ get_files_from_dir, is_url, ) + try: from autogen.agentchat.contrib.retriever.chromadb import ChromaDB import chromadb @@ -14,9 +15,10 @@ skip = True else: skip = False - + test_dir = os.path.join(os.path.dirname(__file__), "test_files") + @pytest.mark.skipif(skip, reason="chromadb is not installed") def test_chromadb(): db_path = "/tmp/test_retrieve_utils_chromadb.db" @@ -26,8 +28,8 @@ def test_chromadb(): else: vectorstore = ChromaDB(path=db_path) vectorstore.ingest_data(test_dir) - + assert client.get_collection("vectorstore") - + results = vectorstore.query(["autogen"]) assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) diff --git a/test/agentchat/contrib/retrievers/test_lancedb.py b/test/agentchat/contrib/retrievers/test_lancedb.py index 64f9b1efe07..5eb82eab041 100644 --- a/test/agentchat/contrib/retrievers/test_lancedb.py +++ b/test/agentchat/contrib/retrievers/test_lancedb.py @@ -7,6 +7,7 @@ get_files_from_dir, is_url, ) + try: from autogen.agentchat.contrib.retriever.lancedb import LanceDB import lancedb @@ -14,9 +15,10 @@ skip = True else: skip = False - + test_dir = os.path.join(os.path.dirname(__file__), "test_files") + @pytest.mark.skipif(skip, reason="lancedb is not installed") def test_lancedb(): db_path = "/tmp/test_lancedb_store" @@ -26,8 +28,8 @@ def test_lancedb(): else: vectorstore = LanceDB(path=db_path) vectorstore.ingest_data(test_dir) - + assert "vectorstore" in db.table_names() - + results = vectorstore.query(["autogen"]) assert isinstance(results, 
dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py index 39e58e330dd..5c4ca131240 100644 --- a/test/test_retrieve_utils.py +++ b/test/test_retrieve_utils.py @@ -15,6 +15,7 @@ ) from autogen.agentchat.contrib.retriever import DEFAULT_RETRIEVER, get_retriever from autogen.token_count_utils import count_token + Retriever = get_retriever(DEFAULT_RETRIEVER) except ImportError: skip = True @@ -71,13 +72,17 @@ def test_is_url(self): assert is_url("https://www.example.com") assert not is_url("not_a_url") - def test_custom_text_split_function(self): def custom_text_split_function(text): return [text[: len(text) // 2], text[len(text) // 2 :]] db_path = "/tmp/test_retrieve_utils" - retriever = Retriever(path=db_path, name="mytestcollection", custom_text_split_function=custom_text_split_function, use_existing=False) + retriever = Retriever( + path=db_path, + name="mytestcollection", + custom_text_split_function=custom_text_split_function, + use_existing=False, + ) retriever.ingest_data(os.path.join(test_dir, "example.txt")) results = retriever.query(["autogen"], top_k=1) assert ( @@ -105,7 +110,7 @@ def test_unstructured(self): isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip() for chunk in chunks ) - + if __name__ == "__main__": pytest.main() From 079886aedb1e81bc7857052003810e3c09a2619a Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 21 Nov 2023 22:07:15 +0530 Subject: [PATCH 13/52] update --- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py index ad434d80ab6..2ba4b12ad23 100644 --- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +++ 
b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py @@ -1,7 +1,7 @@ from typing import Callable, Dict, List, Optional from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent -from autogen.autogen.retriever.retrieve_utils import get_files_from_dir, split_files_to_chunks +from autogen.agentchat.contrib.retrieve_utils import get_files_from_dir, split_files_to_chunks, TEXT_FORMATS import logging logger = logging.getLogger(__name__) From 69031778a5b0131b80b11f8d7aec2bdf84f017bb Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 21 Nov 2023 22:12:26 +0530 Subject: [PATCH 14/52] update --- test/test_retrieve_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py index 789067afd0e..ad6ad3df9e6 100644 --- a/test/test_retrieve_utils.py +++ b/test/test_retrieve_utils.py @@ -104,7 +104,6 @@ def custom_text_split_function(text): name="mytestcollection", custom_text_split_function=custom_text_split_function, use_existing=False, - get_or_create=True, recursive=False, ) retriever.ingest_data(os.path.join(test_dir, "example.txt")) From 8cf4cb21cb0b6c2c3cad382063db1f6833ee4f25 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Mon, 27 Nov 2023 01:27:34 +0530 Subject: [PATCH 15/52] update --- autogen/__init__.py | 1 - .../contrib/retrieve_user_proxy_agent.py | 23 +- autogen/agentchat/contrib/retriever/base.py | 19 +- .../agentchat/contrib/retriever/chromadb.py | 46 +- .../agentchat/contrib/retriever/lancedb.py | 22 +- notebook/agentchat_RetrieveChat.ipynb | 1588 ++++------------- setup.py | 4 +- .../contrib/retrievers/test_chromadb.py | 23 +- .../contrib/retrievers/test_lancedb.py | 11 +- 9 files changed, 477 insertions(+), 1260 deletions(-) diff --git a/autogen/__init__.py b/autogen/__init__.py index 3002ad5df8e..5d3a8a14b5e 100644 --- a/autogen/__init__.py +++ b/autogen/__init__.py @@ -4,7 +4,6 @@ from .agentchat import * from .code_utils import DEFAULT_MODEL, FAST_MODEL 
- # Set the root logger. logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index a29f79ff300..267c3dcf6ea 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -2,7 +2,7 @@ from autogen.agentchat.agent import Agent from autogen.agentchat import UserProxyAgent -from autogen.agentchat.contrib.retriever.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS +from autogen.agentchat.contrib.retriever.retrieve_utils import TEXT_FORMATS from autogen.token_count_utils import count_token from autogen.code_utils import extract_code from autogen.agentchat.contrib.retriever import get_retriever @@ -161,9 +161,10 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = human_input_mode=human_input_mode, **kwargs, ) - + self.retriever = None self._retrieve_config = {} if retrieve_config is None else retrieve_config self._retriever_type = self._retrieve_config.get("retriever_type") + self._retriever_path = self._retrieve_config.get("retriever_path", "~/autogen") self._task = self._retrieve_config.get("task", "default") self._client = self._retrieve_config.get("client", None) self._docs_path = self._retrieve_config.get("docs_path", None) @@ -363,10 +364,10 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = n_results (int): the number of results to be retrieved. search_string (str): only docs containing this string will be retrieved. 
""" - if not self._collection or not self._get_or_create: - print("Trying to create collection.") + if not self.retriever: retriever_class = get_retriever(self._retriever_type) self.retriever = retriever_class( + path=self._retriever_path, name=self._collection_name, embedding_model_name=self._embedding_model, embedding_function=self._embedding_function, @@ -374,14 +375,20 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = chunk_mode=self._chunk_mode, must_break_at_empty_line=self._must_break_at_empty_line, custom_text_split_function=self.custom_text_split_function, - use_existing=not self._get_or_create, client=self._client, custom_text_types=self._custom_text_types, recursive=self._recursive, ) - self._collection = True - self._get_or_create = False - self.retriever.ingest_data(self._docs_path) + if not self.retriever.index_exists() or not self._get_or_create: + print("Trying to create index.") # TODO: logger + self.retriever.ingest_data(self._docs_path) + elif self._get_or_create: + if self.retriever.index_exists(): + print("Trying to use existing collection.") # TODO: logger + self.retriever.use_existing_index() + else: + raise Exception("Requested to use existing index but it is not found!") + results = self.retriever.query( texts=[problem], top_k=n_results, diff --git a/autogen/agentchat/contrib/retriever/base.py b/autogen/agentchat/contrib/retriever/base.py index 4c5dfd69f44..56a0c64f1b3 100644 --- a/autogen/agentchat/contrib/retriever/base.py +++ b/autogen/agentchat/contrib/retriever/base.py @@ -13,8 +13,10 @@ def __init__( chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True, custom_text_split_function: Callable = None, - use_existing=False, client=None, + # TODO: add support for custom text types and recurisive + custom_text_types: str = None, + recursive: bool = True, ): """ Args: @@ -35,7 +37,6 @@ def __init__( self.chunk_mode = chunk_mode self.must_break_at_empty_line = must_break_at_empty_line 
self.custom_text_split_function = custom_text_split_function - self.use_existing = use_existing self.client = client self.init_db() @@ -49,6 +50,13 @@ def ingest_data(self, data_dir): """ pass + @abstractmethod + def use_existing_index(self): + """ + Open an existing index. + """ + pass + @abstractmethod def query(self, texts: List[str], top_k: int = 10, filter: Any = None): """ @@ -65,3 +73,10 @@ def init_db(self): Initialize the database. """ pass + + @abstractmethod + def index_exists(self): + """ + Check if the index exists in the database. + """ + pass diff --git a/autogen/agentchat/contrib/retriever/chromadb.py b/autogen/agentchat/contrib/retriever/chromadb.py index 05b0951970f..bf1deaf6605 100644 --- a/autogen/agentchat/contrib/retriever/chromadb.py +++ b/autogen/agentchat/contrib/retriever/chromadb.py @@ -17,31 +17,30 @@ class ChromaDB(Retriever): def init_db(self): - if self.client is None: - self.client = chromadb.PersistentClient(path=self.path) - embedding_function = ( + self.client = chromadb.PersistentClient(path=self.path) + self.embedding_function = ( ef.SentenceTransformerEmbeddingFunction(self.embedding_model_name) if self.embedding_function is None else self.embedding_function ) + self.collection = None + + def ingest_data(self, data_dir): + """ + Create a vector database from a directory of files. 
+ Args: + data_dir: path to the directory containing the text files + """ + self.collection = self.client.create_collection( self.name, - get_or_create=self.use_existing, - embedding_function=embedding_function, + embedding_function=self.embedding_function, # https://github.com/nmslib/hnswlib#supported-distances # https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184 # https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}, # ip, l2, cosine ) - def ingest_data(self, data_dir): - """ - Create a vector database from a directory of files. - Args: - data_dir: path to the directory containing the text files - """ - if self.client is None: - self.init_db() if self.custom_text_split_function is not None: chunks = split_files_to_chunks( get_files_from_dir(data_dir), custom_text_split_function=self.custom_text_split_function @@ -59,17 +58,14 @@ def ingest_data(self, data_dir): ids=[f"doc_{j}" for j in range(i, end_idx)], # unique for each doc ) + def use_existing_index(self): + self.collection = self.client.get_collection(name=self.name, embedding_function=self.embedding_function) + def query(self, texts: List[str], top_k: int = 10, filter: str = None): - if self.client is None: - self.init_db() # the collection's embedding function is always the default one, but we want to use the one we used to create the # collection. So we compute the embeddings ourselves and pass it to the query function. - embedding_function = ( - ef.SentenceTransformerEmbeddingFunction(self.embedding_model_name) - if self.embedding_function is None - else self.embedding_function - ) - query_embeddings = embedding_function(texts) + + query_embeddings = self.embedding_function(texts) # Query/search n most similar results. 
You can also .get by id results = self.collection.query( query_embeddings=query_embeddings, @@ -77,3 +73,11 @@ def query(self, texts: List[str], top_k: int = 10, filter: str = None): where_document={"$contains": filter} if filter else None, # optional filter ) return results + + def index_exists(self): + try: + self.client.get_collection(name=self.name, embedding_function=self.embedding_function) + # Not sure if there's an explicit way to check if a collection exists for chromadb + return True + except Exception: + return False diff --git a/autogen/agentchat/contrib/retriever/lancedb.py b/autogen/agentchat/contrib/retriever/lancedb.py index cf7af621f1b..b2502ee18b0 100644 --- a/autogen/agentchat/contrib/retriever/lancedb.py +++ b/autogen/agentchat/contrib/retriever/lancedb.py @@ -14,22 +14,15 @@ class LanceDB(Retriever): db = None + table = None def init_db(self): - if self.db is None: - self.db = lancedb.connect(self.path) + self.db = lancedb.connect(self.path) self.embedding_function = ( get_registry().get("sentence-transformers").create(name=self.embedding_model_name, show_progress_bar=True) if self.embedding_function is None else self.embedding_function ) - if self.use_existing and self.name in self.db.table_names(): - self.table = self.db.open_table(self.name) - # logger.info(f"Reusing existing table {self.name}") - else: - # logger.info(f"Creating new table {self.name}") - schema = self._get_schema(self.embedding_function) - self.table = self.db.create_table(self.name, schema=schema, mode="overwrite") def ingest_data(self, data_dir): """ @@ -37,8 +30,9 @@ def ingest_data(self, data_dir): Args: data_dir: path to the directory containing the text files """ - if self.db is None: - self.init_db() + schema = self._get_schema(self.embedding_function) + self.table = self.db.create_table(self.name, schema=schema, mode="overwrite") + if self.custom_text_split_function is not None: chunks = split_files_to_chunks( get_files_from_dir(data_dir), 
custom_text_split_function=self.custom_text_split_function @@ -56,6 +50,9 @@ def ingest_data(self, data_dir): data = with_embeddings(self.embedding_function, pa_table) self.table.add(data) + def use_existing_index(self): + self.table = self.db.open_table(self.name) + def query(self, texts: List[str], top_k: int = 10, filter: str = None): if self.db is None: self.init_db() @@ -73,6 +70,9 @@ def query(self, texts: List[str], top_k: int = 10, filter: str = None): return results + def index_exists(self): + return self.name in self.db.table_names() + def _get_schema(self, embedding_function): if isinstance(embedding_function, EmbeddingFunction): diff --git a/notebook/agentchat_RetrieveChat.ipynb b/notebook/agentchat_RetrieveChat.ipynb index 8b81a2ec264..5638b5e59a9 100644 --- a/notebook/agentchat_RetrieveChat.ipynb +++ b/notebook/agentchat_RetrieveChat.ipynb @@ -67,14 +67,14 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "models to use: ['gpt-35-turbo']\n" + "models to use: ['gpt-4']\n" ] } ], @@ -82,7 +82,7 @@ "import autogen\n", "\n", "config_list = autogen.config_list_from_json(\n", - " env_or_file=\"OAI_CONFIG_LIST\",\n", + " env_or_file=\"../OAI_CONFIG_LIST\",\n", " file_location=\".\",\n", " filter_dict={\n", " \"model\": {\n", @@ -148,15 +148,22 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 2, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "This module is deprecated. 
Please use autogen.agentchat.contrib.retriever.retrieve_utils instead.\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ "Accepted file formats for `docs_path`:\n", - "['xml', 'htm', 'msg', 'docx', 'org', 'pptx', 'jsonl', 'txt', 'tsv', 'yml', 'json', 'md', 'pdf', 'xlsx', 'csv', 'html', 'log', 'yaml', 'doc', 'odt', 'rtf', 'ppt', 'epub', 'rst']\n" + "['txt', 'json', 'csv', 'tsv', 'md', 'html', 'htm', 'rtf', 'rst', 'jsonl', 'log', 'xml', 'yaml', 'yml', 'pdf']\n" ] } ], @@ -171,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -212,10 +219,10 @@ " \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Research.md\",\n", " os.path.join(os.path.abspath(''), \"..\", \"website\", \"docs\"),\n", " ],\n", + " \"retriever_path\": \"~/test\",\n", " \"custom_text_types\": [\"mdx\"],\n", " \"chunk_token_size\": 2000,\n", " \"model\": config_list[0][\"model\"],\n", - " \"client\": chromadb.PersistentClient(path=\"/tmp/chromadb\"),\n", " \"embedding_model\": \"all-mpnet-base-v2\",\n", " \"get_or_create\": True, # set to False if you don't want to reuse an existing collection, but you'll need to remove the collection manually\n", " },\n", @@ -240,36 +247,31 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 4, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:autogen.retrieve_utils:Found 2 chunks.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Trying to create collection.\n" + "Trying to use existing collection.\n", + "query: How can I use FLAML to perform a classification task and use spark to do parallel training. 
Train 30 seconds and force cancel jobs if time limit is reached.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 20 is greater than number of elements in index 2, updating n_results = 2\n" + "/Users/ayushchaurasia/Documents/autogen/autogen/env/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "doc_ids: [['doc_0']]\n", - "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n", + "doc_ids: [['0']]\n", + "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", @@ -409,42 +411,51 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "You can use FLAML's `lgbm_spark` estimator for classification tasks and activate Spark as the parallel backend during training by setting `use_spark` to `True`. Here is an example code snippet:\n", + "You can use the provided FLAML API along with Spark for distributed training and parallel jobs. FLAML integrates Spark ML estimators for AutoML and offers utilities to prepare your data in the required format. 
This includes the `to_pandas_on_spark` function for converting your data into a pandas-on-spark dataframe, and the `VectorAssembler` for merging all feature columns into a single vector column.\n", + "\n", + "Here's an example of how you can perform a classification task with FLAML and Spark and force cancel jobs if the time limit is reached:\n", "\n", "```python\n", - "import flaml\n", + "import pandas as pd\n", "from flaml.automl.spark.utils import to_pandas_on_spark\n", "from pyspark.ml.feature import VectorAssembler\n", + "import flaml\n", + "\n", + "# Creating a dictionary\n", + "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n", + "\n", + "# Creating a pandas DataFrame\n", + "dataframe = pd.DataFrame(data)\n", + "label = \"Price\"\n", "\n", - "# Assuming you have a Spark DataFrame named 'df' that contains your data\n", - "dataframe = df.toPandas()\n", - "label = \"target\"\n", + "# Convert to pandas-on-spark dataframe\n", "psdf = to_pandas_on_spark(dataframe)\n", "\n", + "# Prepare features using VectorAssembler\n", "columns = psdf.columns\n", "feature_cols = [col for col in columns if col != label]\n", "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", + "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\", label]\n", "\n", - "# configure and run AutoML\n", - "automl = flaml.AutoML()\n", - "settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"accuracy\",\n", - " \"estimator_list\": [\"lgbm_spark\"],\n", - " \"task\": \"classification\",\n", - " \"n_jobs\": -1, # Use all available CPUs\n", - " \"use_spark\": True, # Use Spark as the parallel backend\n", - " \"force_cancel\": True # Halt Spark jobs that run for longer than the time budget\n", + "# Define FLAML settings\n", + 
"automl_settings = {\n", + " \"time_budget\": 30, # Train for 30 seconds\n", + " \"metric\": \"accuracy\", # Evaluation metric\n", + " \"task\": \"classification\", # Type of task\n", + " \"n_concurrent_trials\": 2, # Number of concurrent trials\n", + " \"use_spark\": True, # Use spark for parallel training\n", + " \"force_cancel\": True, # Force cancel jobs if time limit is reached\n", "}\n", - "automl.fit(\n", - " dataframe=psdf,\n", - " label=label,\n", - " **settings,\n", - ")\n", + "\n", + "automl = flaml.AutoML()\n", + "\n", + "# Train with FLAML and Spark with a classification task\n", + "automl.fit(dataframe=psdf, label=label, **automl_settings)\n", "```\n", "\n", - "Note that you should not use `use_spark` if you are working with Spark data, because SparkML models already run in parallel.\n", + "Please note that this is a basic example. FLAML has many more options available for tuning the models such as the estimator_list option for specifying desired models to try. Also note that the Spark environment needs to be properly set up for running this code.\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", @@ -457,61 +468,24 @@ "UPDATE CONTEXT\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 60 is greater than number of elements in index 2, updating n_results = 2\n", - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 100 is greater than number of elements in index 2, updating n_results = 2\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "doc_ids: [['doc_0']]\n", - "doc_ids: [['doc_0']]\n" - ] - }, - { - "name": "stderr", - 
"output_type": "stream", - "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 140 is greater than number of elements in index 2, updating n_results = 2\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "doc_ids: [['doc_0']]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 180 is greater than number of elements in index 2, updating n_results = 2\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "doc_ids: [['doc_0']]\n", + "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n", + "Trying to use existing collection.\n", + "query: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n", + "doc_ids: [['0']]\n", + "Trying to use existing collection.\n", + "query: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n", + "doc_ids: [['0']]\n", + "Trying to use existing collection.\n", + "query: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n", + "doc_ids: [['0']]\n", + "Trying to use existing collection.\n", + "query: How can I use FLAML to perform a classification task and use spark to do parallel training. 
Train 30 seconds and force cancel jobs if time limit is reached.\n", + "doc_ids: [['0']]\n", "\u001b[32mNo more context, will terminate.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "TERMINATE\n", "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "TERMINATE\n", - "\n", "--------------------------------------------------------------------------------\n" ] } @@ -545,23 +519,18 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 5, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 20 is greater than number of elements in index 2, updating n_results = 2\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "doc_ids: [['doc_0', 'doc_1']]\n", - "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", + "Trying to use existing collection.\n", + "query: Who is the author of FLAML?\n", + "doc_ids: [['0', '1']]\n", + "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", @@ -814,7 +783,51 @@ "\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", + "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", + "\n", + "The primary authors of FLAML, or Fast Lightweight AutoML, are Chi Wang and Qingyun Wu. They developed this library at Microsoft Research. Other contributors include Markus Weimer and Erkang Zhu. 
They have published several research papers on various aspects of FLAML, which further discuss the technical details and innovative techniques used in this AutoML library.\n", + "\n", + "--------------------------------------------------------------------------------\n" + ] + } + ], + "source": [ + "# reset the assistant. Always reset the assistant before starting a new conversation.\n", + "assistant.reset()\n", + "\n", + "qa_problem = \"Who is the author of FLAML?\"\n", + "ragproxyagent.initiate_chat(assistant, problem=qa_problem)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Example 3\n", + "\n", + "[back to top](#toc)\n", + "\n", + "Use RetrieveChat to help generate sample code and ask for human-in-loop feedbacks.\n", + "\n", + "Problem: how to build a time series forecasting model for stock price using FLAML?" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trying to use existing collection.\n", + "query: how to build a time series forecasting model for stock price using FLAML?\n", + "doc_ids: [['0', '1']]\n", + "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented coding assistant. 
You answer user's questions based on your own knowledge and the\n", @@ -827,7 +840,7 @@ "# your code\n", "```\n", "\n", - "User's question is: Who is the author of FLAML?\n", + "User's question is: how to build a time series forecasting model for stock price using FLAML?\n", "\n", "Context is: # Integrate - Spark\n", "\n", @@ -1069,18 +1082,71 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "The authors of FLAML are Chi Wang, Qingyun Wu, Markus Weimer, and Erkang Zhu.\n", + "To build a forecasting model for stock price using FLAML, we first need to ensure we have time-series data related to stock price in a pandas data frame format. Then, we convert it to pandas-on-spark data frame using the function `to_pandas_on_spark` provided by FLAML. And then, we use `VectorAssembler` to merge all feature columns into one vector column.\n", + "\n", + "Following the data preprocessing, we integrate Spark ML with the FLAML AutoML model. The provided FLAML AutoML model using Spark is `lgbm_spark`. 
Set up the settings such as time budget, metric, and task among the others, and then call the `fit()` method of FLAML AutoML model.\n", + "\n", + "Below is the example in Python:\n", + "\n", + "```python\n", + "import pandas as pd\n", + "from flaml.automl import AutoML\n", + "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "from pyspark.ml.feature import VectorAssembler\n", + "\n", + "# Assuming data is a dataframe containing your time series data\n", + "\n", + "# Convert to pandas-on-spark dataframe\n", + "psdf = to_pandas_on_spark(data)\n", + "\n", + "# Assume label is the name of the column in data you want to predict\n", + "label = \"Price\"\n", + "\n", + "columns = psdf.columns\n", + "feature_cols = [col for col in columns if col != label]\n", + "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", + "\n", + "# Transform data\n", + "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\", label]\n", + "\n", + "# Initialize and setup automl object\n", + "automl = AutoML()\n", + "settings = {\n", + " \"time_budget\": 100, # in seconds\n", + " \"metric\": \"mae\",\n", + " \"task\": \"forecast\",\n", + " \"estimator_list\": [\"lgbm_spark\"],\n", + "}\n", + "\n", + "# Fit automl model\n", + "automl.fit(dataframe=psdf, label=label, **settings)\n", + "```\n", + "\n", + "Replace `data` with the pandas dataframe containing time series information related to the stock price. \"Price\" should be replaced by the column you want to predict in the data frame.\n", + "\n", + "Please note, the SparkML LightGBM model `lgbm_spark` is more suitable for regression tasks. If your stock price dataset is a time series classification task, you might want to choose a different model that supports classification task.\n", + "\n", + "Additionally, FLAML AutoML settings are flexible so you can adjust the settings as needed. 
For example, you can increase the time budget if you want FLAML to try more configurations, or change the metric to another error metric like \"mse\".\n", "\n", "--------------------------------------------------------------------------------\n" ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "Provide feedback to assistant. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: exit\n" + ] } ], "source": [ "# reset the assistant. Always reset the assistant before starting a new conversation.\n", "assistant.reset()\n", "\n", - "qa_problem = \"Who is the author of FLAML?\"\n", - "ragproxyagent.initiate_chat(assistant, problem=qa_problem)" + "# set `human_input_mode` to be `ALWAYS`, so the agent will ask for human input at every step.\n", + "ragproxyagent.human_input_mode = \"ALWAYS\"\n", + "code_problem = \"how to build a time series forecasting model for stock price using FLAML?\"\n", + "ragproxyagent.initiate_chat(assistant, problem=code_problem)" ] }, { @@ -1088,35 +1154,30 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n", - "### Example 3\n", + "\n", + "### Example 4\n", "\n", "[back to top](#toc)\n", "\n", - "Use RetrieveChat to help generate sample code and ask for human-in-loop feedbacks.\n", + "Use RetrieveChat to answer a question and ask for human-in-loop feedbacks.\n", "\n", - "Problem: how to build a time series forecasting model for stock price using FLAML?" + "Problem: Is there a function named `tune_automl` in FLAML?" 
] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 7, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 20 is greater than number of elements in index 2, updating n_results = 2\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "doc_ids: [['doc_0', 'doc_1']]\n", - "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", + "Trying to use existing collection.\n", + "query: Is there a function named `tune_automl` in FLAML?\n", + "doc_ids: [['0', '1']]\n", + "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", @@ -1129,7 +1190,7 @@ "# your code\n", "```\n", "\n", - "User's question is: how to build a time series forecasting model for stock price using FLAML?\n", + "User's question is: Is there a function named `tune_automl` in FLAML?\n", "\n", "Context is: # Integrate - Spark\n", "\n", @@ -1371,34 +1432,17 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "To build a time series forecasting model for stock price using FLAML, you can use the `lgbm_spark` estimator and organize your data in the required format. First, use `to_pandas_on_spark` function to convert your data into a pandas-on-spark dataframe/series, which Spark estimators require. Next, you should use `VectorAssembler` to merge all feature columns into a single vector column. Finally, use `flaml.AutoML` to try different configurations for the `lgbm_spark` model. 
Here is an example code snippet: \n", + "From the provided context, it does not appear that there is a function named `tune_automl` in FLAML. Instead, the `AutoML` class is instantiated and its `fit` method is used to conduct the automated machine learning process. This process includes hyperparameter tuning, but it is not conducted with a standalone `tune_automl` function. This is true for both the general use case of FLAML, as well as the specific case of using FLAML with Spark, as outlined in the provided context.\n", + "\n", + "If more context is given or if you are referring to a different version or extension of FLAML, the answer might be different.\n", "\n", "```python\n", "import flaml\n", - "import pandas as pd\n", - "from flaml.automl.spark.utils import to_pandas_on_spark\n", - "from pyspark.ml.feature import VectorAssembler\n", - "\n", - "# load your stock price data into a pandas dataframe\n", - "data = pd.read_csv('stock_price.csv')\n", - "\n", - "# specify label column name\n", - "label = 'price'\n", - "\n", - "# convert pandas dataframe to pandas-on-spark dataframe\n", - "psdf = to_pandas_on_spark(data)\n", - "\n", - "# merge feature columns as a single vector column\n", - "feature_cols = [col for col in psdf.columns if col != label]\n", - "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "\n", - "# start an AutoML experiment with lgbm_spark estimator\n", "automl = flaml.AutoML()\n", "settings = {\n", " \"time_budget\": 30,\n", " \"metric\": \"r2\",\n", - " \"estimator_list\": [\"lgbm_spark\"],\n", + " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n", " \"task\": \"regression\",\n", "}\n", "\n", @@ -1413,77 +1457,10 @@ ] }, { - "name": "stdout", + "name": "stdin", "output_type": "stream", "text": [ - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "I want the time_budget to be 10 mins\n", 
- "\n", - "--------------------------------------------------------------------------------\n", - "I want the time_budget to be 10 mins\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", - "\n", - "You can change the `time_budget` parameter in the `settings` dictionary to 10 minutes (600 seconds) like this:\n", - "\n", - "```python\n", - "import flaml\n", - "import pandas as pd\n", - "from flaml.automl.spark.utils import to_pandas_on_spark\n", - "from pyspark.ml.feature import VectorAssembler\n", - "\n", - "# load your stock price data into a pandas dataframe\n", - "data = pd.read_csv('stock_price.csv')\n", - "\n", - "# specify label column name\n", - "label = 'price'\n", - "\n", - "# convert pandas dataframe to pandas-on-spark dataframe\n", - "psdf = to_pandas_on_spark(data)\n", - "\n", - "# merge feature columns as a single vector column\n", - "feature_cols = [col for col in psdf.columns if col != label]\n", - "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "\n", - "# start an AutoML experiment with lgbm_spark estimator and time_budget of 10 mins\n", - "automl = flaml.AutoML()\n", - "settings = {\n", - " \"time_budget\": 600, # time_budget in seconds\n", - " \"metric\": \"r2\",\n", - " \"estimator_list\": [\"lgbm_spark\"],\n", - " \"task\": \"regression\",\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=psdf,\n", - " label=label,\n", - " **settings,\n", - ")\n", - "```\n", - "\n", - "\n", - "In this example, the `time_budget` parameter is set to 600, which represents the number of seconds the FLAML AutoML experiment will run. 
You can adjust this value to control the total time spent on the experiment.\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[31m\n", - ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n", - "\u001b[31m\n", - ">>>>>>>> USING AUTO REPLY...\u001b[0m\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", - "\n", - "Is there anything else I can help you with?\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[31m\n", - ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n" + "Provide feedback to assistant. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: exit\n" ] } ], @@ -1493,567 +1470,8 @@ "\n", "# set `human_input_mode` to be `ALWAYS`, so the agent will ask for human input at every step.\n", "ragproxyagent.human_input_mode = \"ALWAYS\"\n", - "code_problem = \"how to build a time series forecasting model for stock price using FLAML?\"\n", - "ragproxyagent.initiate_chat(assistant, problem=code_problem)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "### Example 4\n", - "\n", - "[back to top](#toc)\n", - "\n", - "Use RetrieveChat to answer a question and ask for human-in-loop feedbacks.\n", - "\n", - "Problem: Is there a function named `tune_automl` in FLAML?" 
- ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 20 is greater than number of elements in index 2, updating n_results = 2\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "doc_ids: [['doc_0', 'doc_1']]\n", - "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", - "context provided by the user.\n", - "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", - "For code generation, you must obey the following rules:\n", - "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", - "Rule 2. You must follow the formats below to write your code:\n", - "```language\n", - "# your code\n", - "```\n", - "\n", - "User's question is: Is there a function named `tune_automl` in FLAML?\n", - "\n", - "Context is: # Integrate - Spark\n", - "\n", - "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", - "- Use Spark ML estimators for AutoML.\n", - "- Use Spark to run training in parallel spark jobs.\n", - "\n", - "## Spark ML Estimators\n", - "\n", - "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n", - "\n", - "### Data\n", - "\n", - "For Spark estimators, AutoML only consumes Spark data. 
FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n", - "\n", - "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", - "\n", - "This function also accepts optional arguments `index_col` and `default_index_type`.\n", - "- `index_col` is the column name to use as the index, default is None.\n", - "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", - "\n", - "Here is an example code snippet for Spark Data:\n", - "\n", - "```python\n", - "import pandas as pd\n", - "from flaml.automl.spark.utils import to_pandas_on_spark\n", - "# Creating a dictionary\n", - "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n", - "\n", - "# Creating a pandas DataFrame\n", - "dataframe = pd.DataFrame(data)\n", - "label = \"Price\"\n", - "\n", - "# Convert to pandas-on-spark dataframe\n", - "psdf = to_pandas_on_spark(dataframe)\n", - "```\n", - "\n", - "To use Spark ML models you need to format your data appropriately. 
Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", - "\n", - "Here is an example of how to use it:\n", - "```python\n", - "from pyspark.ml.feature import VectorAssembler\n", - "columns = psdf.columns\n", - "feature_cols = [col for col in columns if col != label]\n", - "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "```\n", - "\n", - "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", - "\n", - "### Estimators\n", - "#### Model List\n", - "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", - "\n", - "#### Usage\n", - "First, prepare your data in the required format as described in the previous section.\n", - "\n", - "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. 
If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", - "\n", - "Here is an example code snippet using SparkML models in AutoML:\n", - "\n", - "```python\n", - "import flaml\n", - "# prepare your data in pandas-on-spark format as we previously mentioned\n", - "\n", - "automl = flaml.AutoML()\n", - "settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n", - " \"task\": \"regression\",\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=psdf,\n", - " label=label,\n", - " **settings,\n", - ")\n", - "```\n", - "\n", - "\n", - "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", - "\n", - "## Parallel Spark Jobs\n", - "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", - "\n", - "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", - "\n", - "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", - "\n", - "\n", - "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. 
This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", - "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", - "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", - "\n", - "An example code snippet for using parallel Spark jobs:\n", - "```python\n", - "import flaml\n", - "automl_experiment = flaml.AutoML()\n", - "automl_settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"task\": \"regression\",\n", - " \"n_concurrent_trials\": 2,\n", - " \"use_spark\": True,\n", - " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=dataframe,\n", - " label=label,\n", - " **automl_settings,\n", - ")\n", - "```\n", - "\n", - "\n", - "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", - "\n", - "# Research\n", - "\n", - "For technical details, please check our research publications.\n", - "\n", - "* [FLAML: A Fast and Lightweight AutoML 
Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2021flaml,\n", - " title={FLAML: A Fast and Lightweight AutoML Library},\n", - " author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\n", - " year={2021},\n", - " booktitle={MLSys},\n", - "}\n", - "```\n", - "\n", - "* [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2021cfo,\n", - " title={Frugal Optimization for Cost-related Hyperparameters},\n", - " author={Qingyun Wu and Chi Wang and Silu Huang},\n", - " year={2021},\n", - " booktitle={AAAI},\n", - "}\n", - "```\n", - "\n", - "* [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2021blendsearch,\n", - " title={Economical Hyperparameter Optimization With Blended Search Strategy},\n", - " author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\n", - " year={2021},\n", - " booktitle={ICLR},\n", - "}\n", - "```\n", - "\n", - "* [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. 
ACL 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{liuwang2021hpolm,\n", - " title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\n", - " author={Susan Xueqing Liu and Chi Wang},\n", - " year={2021},\n", - " booktitle={ACL},\n", - "}\n", - "```\n", - "\n", - "* [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2021chacha,\n", - " title={ChaCha for Online AutoML},\n", - " author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\n", - " year={2021},\n", - " booktitle={ICML},\n", - "}\n", - "```\n", - "\n", - "* [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", - "\n", - "```bibtex\n", - "@inproceedings{wuwang2021fairautoml,\n", - " title={Fair AutoML},\n", - " author={Qingyun Wu and Chi Wang},\n", - " year={2021},\n", - " booktitle={ArXiv preprint arXiv:2111.06495},\n", - "}\n", - "```\n", - "\n", - "* [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", - "\n", - "```bibtex\n", - "@inproceedings{kayaliwang2022default,\n", - " title={Mining Robust Default Configurations for Resource-constrained AutoML},\n", - " author={Moe Kayali and Chi Wang},\n", - " year={2022},\n", - " booktitle={ArXiv preprint arXiv:2202.09927},\n", - "}\n", - "```\n", - "\n", - "* [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. 
ICLR 2023 (notable-top-5%).\n", - "\n", - "```bibtex\n", - "@inproceedings{zhang2023targeted,\n", - " title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\n", - " author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\n", - " booktitle={International Conference on Learning Representations},\n", - " year={2023},\n", - " url={https://openreview.net/forum?id=0Ij9_q567Ma},\n", - "}\n", - "```\n", - "\n", - "* [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2023EcoOptiGen,\n", - " title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\n", - " author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\n", - " year={2023},\n", - " booktitle={ArXiv preprint arXiv:2303.04673},\n", - "}\n", - "```\n", - "\n", - "* [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. 
ArXiv preprint arXiv:2306.01337 (2023).\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2023empirical,\n", - " title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\n", - " author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\n", - " year={2023},\n", - " booktitle={ArXiv preprint arXiv:2306.01337},\n", - "}\n", - "```\n", - "\n", - "\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", - "context provided by the user.\n", - "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", - "For code generation, you must obey the following rules:\n", - "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", - "Rule 2. You must follow the formats below to write your code:\n", - "```language\n", - "# your code\n", - "```\n", - "\n", - "User's question is: Is there a function named `tune_automl` in FLAML?\n", - "\n", - "Context is: # Integrate - Spark\n", - "\n", - "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", - "- Use Spark ML estimators for AutoML.\n", - "- Use Spark to run training in parallel spark jobs.\n", - "\n", - "## Spark ML Estimators\n", - "\n", - "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n", - "\n", - "### Data\n", - "\n", - "For Spark estimators, AutoML only consumes Spark data. 
FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n", - "\n", - "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", - "\n", - "This function also accepts optional arguments `index_col` and `default_index_type`.\n", - "- `index_col` is the column name to use as the index, default is None.\n", - "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", - "\n", - "Here is an example code snippet for Spark Data:\n", - "\n", - "```python\n", - "import pandas as pd\n", - "from flaml.automl.spark.utils import to_pandas_on_spark\n", - "# Creating a dictionary\n", - "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n", - "\n", - "# Creating a pandas DataFrame\n", - "dataframe = pd.DataFrame(data)\n", - "label = \"Price\"\n", - "\n", - "# Convert to pandas-on-spark dataframe\n", - "psdf = to_pandas_on_spark(dataframe)\n", - "```\n", - "\n", - "To use Spark ML models you need to format your data appropriately. 
Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", - "\n", - "Here is an example of how to use it:\n", - "```python\n", - "from pyspark.ml.feature import VectorAssembler\n", - "columns = psdf.columns\n", - "feature_cols = [col for col in columns if col != label]\n", - "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "```\n", - "\n", - "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", - "\n", - "### Estimators\n", - "#### Model List\n", - "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", - "\n", - "#### Usage\n", - "First, prepare your data in the required format as described in the previous section.\n", - "\n", - "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. 
If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", - "\n", - "Here is an example code snippet using SparkML models in AutoML:\n", - "\n", - "```python\n", - "import flaml\n", - "# prepare your data in pandas-on-spark format as we previously mentioned\n", - "\n", - "automl = flaml.AutoML()\n", - "settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n", - " \"task\": \"regression\",\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=psdf,\n", - " label=label,\n", - " **settings,\n", - ")\n", - "```\n", - "\n", - "\n", - "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", - "\n", - "## Parallel Spark Jobs\n", - "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", - "\n", - "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", - "\n", - "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", - "\n", - "\n", - "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. 
This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", - "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", - "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", - "\n", - "An example code snippet for using parallel Spark jobs:\n", - "```python\n", - "import flaml\n", - "automl_experiment = flaml.AutoML()\n", - "automl_settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"task\": \"regression\",\n", - " \"n_concurrent_trials\": 2,\n", - " \"use_spark\": True,\n", - " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=dataframe,\n", - " label=label,\n", - " **automl_settings,\n", - ")\n", - "```\n", - "\n", - "\n", - "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", - "\n", - "# Research\n", - "\n", - "For technical details, please check our research publications.\n", - "\n", - "* [FLAML: A Fast and Lightweight AutoML 
Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2021flaml,\n", - " title={FLAML: A Fast and Lightweight AutoML Library},\n", - " author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\n", - " year={2021},\n", - " booktitle={MLSys},\n", - "}\n", - "```\n", - "\n", - "* [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2021cfo,\n", - " title={Frugal Optimization for Cost-related Hyperparameters},\n", - " author={Qingyun Wu and Chi Wang and Silu Huang},\n", - " year={2021},\n", - " booktitle={AAAI},\n", - "}\n", - "```\n", - "\n", - "* [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2021blendsearch,\n", - " title={Economical Hyperparameter Optimization With Blended Search Strategy},\n", - " author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\n", - " year={2021},\n", - " booktitle={ICLR},\n", - "}\n", - "```\n", - "\n", - "* [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. 
ACL 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{liuwang2021hpolm,\n", - " title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\n", - " author={Susan Xueqing Liu and Chi Wang},\n", - " year={2021},\n", - " booktitle={ACL},\n", - "}\n", - "```\n", - "\n", - "* [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2021chacha,\n", - " title={ChaCha for Online AutoML},\n", - " author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\n", - " year={2021},\n", - " booktitle={ICML},\n", - "}\n", - "```\n", - "\n", - "* [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", - "\n", - "```bibtex\n", - "@inproceedings{wuwang2021fairautoml,\n", - " title={Fair AutoML},\n", - " author={Qingyun Wu and Chi Wang},\n", - " year={2021},\n", - " booktitle={ArXiv preprint arXiv:2111.06495},\n", - "}\n", - "```\n", - "\n", - "* [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", - "\n", - "```bibtex\n", - "@inproceedings{kayaliwang2022default,\n", - " title={Mining Robust Default Configurations for Resource-constrained AutoML},\n", - " author={Moe Kayali and Chi Wang},\n", - " year={2022},\n", - " booktitle={ArXiv preprint arXiv:2202.09927},\n", - "}\n", - "```\n", - "\n", - "* [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. 
ICLR 2023 (notable-top-5%).\n", - "\n", - "```bibtex\n", - "@inproceedings{zhang2023targeted,\n", - " title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\n", - " author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\n", - " booktitle={International Conference on Learning Representations},\n", - " year={2023},\n", - " url={https://openreview.net/forum?id=0Ij9_q567Ma},\n", - "}\n", - "```\n", - "\n", - "* [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2023EcoOptiGen,\n", - " title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\n", - " author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\n", - " year={2023},\n", - " booktitle={ArXiv preprint arXiv:2303.04673},\n", - "}\n", - "```\n", - "\n", - "* [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2023empirical,\n", - " title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\n", - " author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\n", - " year={2023},\n", - " booktitle={ArXiv preprint arXiv:2306.01337},\n", - "}\n", - "```\n", - "\n", - "\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", - "\n", - "There is no function named `tune_automl` in FLAML. 
However, FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark: \n", - "- Use Spark ML Estimators for AutoML.\n", - "- Use Spark to run training in parallel Spark jobs.\n", - "\n", - "--------------------------------------------------------------------------------\n" - ] - } - ], - "source": [ - "# reset the assistant. Always reset the assistant before starting a new conversation.\n", - "assistant.reset()\n", - "\n", - "# set `human_input_mode` to be `ALWAYS`, so the agent will ask for human input at every step.\n", - "ragproxyagent.human_input_mode = \"ALWAYS\"\n", - "qa_problem = \"Is there a function named `tune_automl` in FLAML?\"\n", - "ragproxyagent.initiate_chat(assistant, problem=qa_problem) # type \"exit\" to exit the conversation" + "qa_problem = \"Is there a function named `tune_automl` in FLAML?\"\n", + "ragproxyagent.initiate_chat(assistant, problem=qa_problem) # type \"exit\" to exit the conversation" ] }, { @@ -2073,7 +1491,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -2082,7 +1500,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -2099,17 +1517,17 @@ " \"docs_path\": corpus_file,\n", " \"chunk_token_size\": 2000,\n", " \"model\": config_list[0][\"model\"],\n", - " \"client\": chromadb.PersistentClient(path=\"/tmp/chromadb\"),\n", " \"collection_name\": \"natural-questions\",\n", " \"chunk_mode\": \"one_line\",\n", " \"embedding_model\": \"all-MiniLM-L6-v2\",\n", + " \"get_or_create\": True\n", " },\n", ")" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -2140,7 +1558,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -2152,130 +1570,28 @@ ">>>>>>>>>>>> Below are outputs of Case 1 <<<<<<<<<<<<\n", "\n", "\n", - 
"Trying to create collection.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
Film Year Fuck count Minutes Uses / mi ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
Character Ultimate Avengers Ultimate Avengers 2 I ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
Position Country Town / City PM2. 5 PM ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
Rank Country ( or dependent territory ) Population
Rank State Gross collections ( in thousands ) Rev ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t < ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
Date Province Mag . MMI Deaths
City River State
Gangakhed ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
Player Pos . Team Career start Career ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t ABO and Rh blood type distribution by country ( population averages )
Country
Total area Land area Performance in the European Cup and UEFA Champions League by club
  • ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Rank City State Land area ( sq mi ) La ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    # Country Name International goals Cap ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Rank City Image Population Definition ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Rank Team Won Lost Tied Pct ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Territory Rights holder Ref
    Asia
    ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    ( hide ) Rank Nat Name Years Goals
    Total area Land area
    Bids by school Most recent
    Rank Name Nation TP SP
    2014 Rank City 2014 Estimate 2010 Census
    S.No . Year Name
    1961
    Densities of various materials covering a range of values
    Material ρ ( ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Club Season League Nation ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Rank ( 2016 ) Airports ( large hubs ) IATA Code M ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    City Region / State Country Park name ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Year Winner ( nationally ) Votes Percent
    Compound SERT NET DAT 5 - HT
    Rank Name Industry Revenue ( USD millions )
    ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Rank Name Name in Georgian Population 1989
    Country The World Factbook World Res ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Rank Country Area ( km2 ) Notes
    ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Rank Country Area ( km2 ) Notes
    Date State ( s ) Magnitude Fatalities ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t < ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Artist # Gold # Platinum # Multi-Platinum
    Name Number of locations Revenue
    Name Country Region Depth ( meters ) < ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t
    Rank Player ( 2017 HRs ) HR
    ...\n", - "max_tokens is too small to fit a single line of text. Breaking this line:\n", - "\t ...\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "doc_ids: [['doc_0', 'doc_3334', 'doc_720', 'doc_2732', 'doc_2510', 'doc_5084', 'doc_5068', 'doc_3727', 'doc_1938', 'doc_4689', 'doc_5249', 'doc_1751', 'doc_480', 'doc_3989', 'doc_2115', 'doc_1233', 'doc_2264', 'doc_633', 'doc_2376', 'doc_2293', 'doc_5274', 'doc_5213', 'doc_3991', 'doc_2880', 'doc_2737', 'doc_1257', 'doc_1748', 'doc_2038', 'doc_4073', 'doc_2876']]\n", - "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3334 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_720 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2732 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2510 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5084 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5068 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3727 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1938 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4689 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5249 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1751 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_480 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3989 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3334 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_720 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2732 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2510 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5084 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5068 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3727 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1938 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4689 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5249 to 
context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1751 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_480 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3989 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2115 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1233 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2264 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_633 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2376 to context.\u001b[0m\n", + "Trying to use existing collection.\n", + "query: what is non controlling interest on balance sheet\n", + "doc_ids: [['0', '3334', '720', '2732', '2510', '5084', '5068', '3727', '1938', '4689', '5249', '1751', '480', '3989', '2115', '1233', '2264', '633', '2376', '2293', '5274', '4842', '5213', '3991', '2880', '2737', '1257', '1748', '2038', '4073']]\n", + "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3334 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 720 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2732 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2510 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 5084 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 5068 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3727 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1938 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4689 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 5249 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1751 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 480 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3989 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2115 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1233 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2264 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 633 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2376 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", 
"\n", "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the\n", @@ -2310,7 +1626,7 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "Non controlling interest on balance sheet refers to the portion of a subsidiary corporation's stock that is not owned by the parent corporation. It represents ownership of less than 50% of the outstanding shares. It is shown as a separate line item in the equity section of the balance sheet.\n", + "Non-controlling interest, also known as minority interest, on a balance sheet is the portion of a subsidiary corporation's stock not owned by the parent corporation.\n", "\n", "--------------------------------------------------------------------------------\n", "\n", @@ -2318,32 +1634,37 @@ ">>>>>>>>>>>> Below are outputs of Case 2 <<<<<<<<<<<<\n", "\n", "\n", - "doc_ids: [['doc_1', 'doc_1097', 'doc_4221', 'doc_4972', 'doc_1352', 'doc_96', 'doc_988', 'doc_2370', 'doc_2414', 'doc_5038', 'doc_302', 'doc_1608', 'doc_980', 'doc_2112', 'doc_562', 'doc_4204', 'doc_3298', 'doc_2995', 'doc_3978', 'doc_1258', 'doc_2971', 'doc_2171', 'doc_1065', 'doc_17', 'doc_2683', 'doc_87', 'doc_1767', 'doc_158', 'doc_482', 'doc_3850']]\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1097 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4221 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4972 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1352 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_96 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_988 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2370 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2414 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5038 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_302 to context.\u001b[0m\n", - "\u001b[32mAdding 
doc_id doc_1608 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_980 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2112 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_562 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4204 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3298 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2995 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3978 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1258 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2971 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2171 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1065 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_17 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2683 to context.\u001b[0m\n", + "Trying to use existing collection.\n", + "query: how many episodes are in chicago fire season 4\n", + "doc_ids: [['1', '1097', '4221', '4972', '1352', '4974', '96', '4301', '988', '2370', '2414', '5038', '302', '1608', '980', '2112', '1699', '562', '4204', '3298', '2995', '3978', '1258', '2971', '2171', '1065', '17', '2683', '87', '1767']]\n", + "\u001b[32mAdding doc_id 1 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1097 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4221 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4972 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1352 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4974 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 96 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4301 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 988 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2370 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2414 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 5038 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 302 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1608 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 
980 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2112 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1699 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 562 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4204 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3298 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2995 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3978 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1258 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2971 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2171 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1065 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 17 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2683 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the\n", @@ -2358,7 +1679,9 @@ "

    The fourth season began airing on October 10 , 2017 , on The CW .

    \n", "

    The fifth season of Chicago P.D. , an American police drama television series with executive producer Dick Wolf , and producers Derek Haas , Michael Brandt , and Rick Eid , premiered on September 27 , 2017 . This season featured its 100th episode .

    \n", "

    This was the city of Chicago 's first professional sports championship since the Chicago Fire won MLS Cup ' 98 ( which came four months after the Chicago Bulls ' sixth NBA championship that year ) . The next major Chicago sports championship came in 2010 , when the NHL 's Chicago Blackhawks ended a 49 - year Stanley Cup title drought . With the Chicago Bears ' win in Super Bowl XX and the Chicago Cubs ' own World Series championship in 2016 , all Chicago sports teams have won at least one major championship since 1985 . Meanwhile , the Astros themselves made it back to the World Series in 2017 , but this time as an AL team , where they defeated the Los Angeles Dodgers in seven games , resulting in Houston 's first professional sports championship since the 2006 -- 07 Houston Dynamo won their back - to - back MLS Championships .

    \n", + "
    No . Athlete Nation Sport Years
    Chicago P.D. ( season 5 )
    Chicago P.D. Season 5 poster
    Country of origin United States
    No. of episodes 20
    Release
    Original network NBC
    Original release September 27 , 2017 ( 2017 - 09 - 27 ) -- present
    Season chronology
    ← Previous Season 4
    List of Chicago P.D. episodes
    \n", "

    The season was ordered in May 2017 , and production began the following month . Ben McKenzie stars as Gordon , alongside Donal Logue , David Mazouz , Morena Baccarin , Sean Pertwee , Robin Lord Taylor , Erin Richards , Camren Bicondova , Cory Michael Smith , Jessica Lucas , Chris Chalk , Drew Powell , Crystal Reed and Alexander Siddig . The fourth season premiered on September 21 , 2017 , on Fox , while the second half premiered on March 1 , 2018 .

    \n", + "

    The Eagle Creek Fire was a destructive wildfire in the Columbia River Gorge in the U.S. states of Oregon and Washington . The fire was started on September 2 , 2017 , reportedly caused by teenagers igniting fireworks during a burn ban . In mid-September , highway closures and local evacuations were gradually being lifted . As of September 28 , 2017 , the fire had consumed 48,831 acres ( 19,761 ha ) and was 46 % contained . In late October , fire growth was slowed by rain . On November 30 , 2017 , the fire was declared fully contained but not yet completely out .

    \n", "

    As of May 24 , 2017 , 58 episodes of The 100 have aired , concluding the fourth season . In March 2017 , The CW renewed the series for a fifth season , set to premiere on April 24 , 2018 .

    \n", "

    The fifth book , River of Fire , is scheduled to be released on April 10 , 2018 .

    \n", "

    On September 10 , 2013 , AMC officially cancelled the series after 38 episodes and three seasons . However , on November 15 , 2013 , Netflix ordered a fourth and final season of six episodes , that was released on Netflix on August 1 , 2014 .

    \n", @@ -2367,6 +1690,7 @@ "

    The first season consisted of eight one - hour - long episodes which were released worldwide on Netflix on July 15 , 2016 , in Ultra HD 4K . The second season , consisting of nine episodes , was released on October 27 , 2017 in HDR . A teaser for the second season , which also announced the release date , aired during Super Bowl LI .

    \n", "

    `` Two Days Before the Day After Tomorrow '' is the eighth episode in the ninth season of the American animated television series South Park . The 133rd overall episode overall , it originally aired on Comedy Central in the United States on October 19 , 2005 . In the episode , Stan and Cartman accidentally destroy a dam , causing the town of Beaverton to be destroyed .

    \n", "

    The fourth season consists of a double order of twenty episodes , split into two parts of ten episodes ; the second half premiered on November 30 , 2016 . The season follows the battles between Ragnar and Rollo in Francia , Bjorn 's raid into the Mediterranean , and the Viking invasion of England . It concluded in its entirety on February 1 , 2017 .

    \n", + "
    • Elizabeth Banks as Gail Abernathy - McKadden - Feinberger , an a cappella commentator making an insulting documentary about The Bellas
    • John Michael Higgins as John Smith , an a cappella commentator making an insulting documentary about The Bellas
    • John Lithgow as Fergus Hobart , Fat Amy 's estranged criminal father
    • Matt Lanter as Chicago Walp , a U.S. soldier guiding the Bellas during the tour , and Chloe 's love interest .
    • Guy Burnet as Theo , DJ Khaled 's music producer , who takes a liking to Beca
    • DJ Khaled as himself
    • Troy Ian Hall as Zeke , a U.S. soldier , partners with Chicago
    • Michael Rose as Aubrey 's father
    • Jessica Chaffin as Evan
    • Moises Arias as Pimp - Lo
    • Ruby Rose , Andy Allo , Venzella Joy Williams , and Hannah Fairlight as Calamity , Serenity , Charity , and Veracity , respectively , members of the band Evermoist
    • Whiskey Shivers as Saddle Up , a country - bluegrass - based band competing against the Bellas
    • Trinidad James and D.J. Looney as Young Sparrow and DJ Dragon Nutz , respectively
    \n", "

    This is an episode list for Sabrina the Teenage Witch , an American sitcom that debuted on ABC in 1996 . From Season 5 , the program was aired on The WB . The series ran for seven seasons totaling 163 episodes . It originally premiered on September 27 , 1996 on ABC and ended on April 24 , 2003 on The WB .

    \n", "

    Hart of Dixie was renewed by The CW for 10 episode season on May 8 , 2014 . The show 's fourth and final season premiered on November 15 , 2014 . The series was later cancelled on May 7 , 2015 .

    \n", "

    The Burning Maze is the third book in the series . It is scheduled to be released on May 1 , 2018 .

    \n", @@ -2384,7 +1708,7 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "There are 23 episodes in Chicago Fire season 4.\n", + "The fourth season of Chicago Fire contained 23 episodes.\n", "\n", "--------------------------------------------------------------------------------\n", "\n", @@ -2392,28 +1716,30 @@ ">>>>>>>>>>>> Below are outputs of Case 3 <<<<<<<<<<<<\n", "\n", "\n", - "doc_ids: [['doc_47', 'doc_45', 'doc_2570', 'doc_2851', 'doc_4033', 'doc_5320', 'doc_3849', 'doc_4172', 'doc_3202', 'doc_2282', 'doc_1896', 'doc_949', 'doc_103', 'doc_1552', 'doc_2791', 'doc_392', 'doc_1175', 'doc_5315', 'doc_832', 'doc_3185', 'doc_2532', 'doc_3409', 'doc_824', 'doc_4075', 'doc_1201', 'doc_4116', 'doc_1448', 'doc_2545', 'doc_2251', 'doc_2485']]\n", - "\u001b[32mAdding doc_id doc_47 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_45 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2570 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2851 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4033 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5320 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3849 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4172 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3202 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2282 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1896 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_949 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_103 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1552 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2791 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_392 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1175 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5315 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_832 to 
context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3185 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2532 to context.\u001b[0m\n", + "Trying to use existing collection.\n", + "query: what are bulls used for on a farm\n", + "doc_ids: [['47', '45', '2570', '2851', '4033', '5320', '3849', '4172', '3202', '2282', '1896', '949', '103', '1552', '2791', '392', '1175', '5315', '832', '3185', '2532', '3409', '824', '2814', '4075', '2815', '1201', '4116', '1448', '5293']]\n", + "\u001b[32mAdding doc_id 47 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 45 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2570 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2851 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4033 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 5320 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3849 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4172 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3202 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2282 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1896 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 949 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 103 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1552 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2791 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 392 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1175 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 5315 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 832 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3185 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2532 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented chatbot. 
You answer user's questions based on your own knowledge and the\n", @@ -2450,7 +1776,7 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "Bulls are used for breeding and often kept for their semen to sell for AI purposes. Some male cattle are also kept as work oxen for haulage. The vast majority, however, are slaughtered for meat before the age of three years.\n", + "Bulls on a farm are used for breeding purposes, with their semen often used for artificial insemination (AI) to maintain and improve the quality of the herd. Some bulls are kept specifically for semen collection, which can be profitable. Additionally, majority of male cattle, including bulls, are slaughtered for meat before the age of three years, providing leaner muscle compared to castrated males or females.\n", "\n", "--------------------------------------------------------------------------------\n", "\n", @@ -2458,16 +1784,18 @@ ">>>>>>>>>>>> Below are outputs of Case 4 <<<<<<<<<<<<\n", "\n", "\n", - "doc_ids: [['doc_3031', 'doc_819', 'doc_4521', 'doc_3980', 'doc_3423', 'doc_5275', 'doc_745', 'doc_753', 'doc_3562', 'doc_4139', 'doc_3678', 'doc_4931', 'doc_2347', 'doc_1115', 'doc_2806', 'doc_5204', 'doc_2707', 'doc_3653', 'doc_1122', 'doc_2398', 'doc_309', 'doc_3891', 'doc_2087', 'doc_330', 'doc_4844', 'doc_2155', 'doc_2674', 'doc_5357', 'doc_1581', 'doc_9']]\n", - "\u001b[32mAdding doc_id doc_3031 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_819 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4521 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3980 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3423 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5275 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_745 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_753 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3562 to context.\u001b[0m\n", 
+ "Trying to use existing collection.\n", + "query: has been honoured with the wisden leading cricketer in the world award for 2016\n", + "doc_ids: [['3031', '819', '4521', '3980', '3423', '5275', '745', '753', '3562', '4139', '3678', '4931', '2347', '1115', '2806', '5204', '2707', '3653', '1122', '2398', '309', '3891', '2087', '330', '4844', '2155', '2987', '2674', '5357', '1581']]\n", + "\u001b[32mAdding doc_id 3031 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 819 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4521 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3980 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3423 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 5275 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 745 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 753 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3562 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the\n", @@ -2492,19 +1820,19 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "I'm sorry, I couldn't find any information about who has been honoured with the Wisden Leading Cricketer in the World award for 2016. 
UPDATE CONTEXT.\n", + "UPDATE CONTEXT\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4139 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3678 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4931 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2347 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1115 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2806 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_5204 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2707 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3653 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4139 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3678 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4931 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2347 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1115 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2806 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 5204 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2707 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3653 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the\n", @@ -2529,17 +1857,17 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "UPDATE CONTEXT. 
The current context does not provide information related to the question.\n", + "UPDATE CONTEXT\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1122 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2398 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_309 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3891 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2087 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_330 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4844 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1122 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2398 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 309 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3891 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2087 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 330 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 4844 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the\n", @@ -2562,7 +1890,7 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "I'm sorry, the provided context doesn't contain information about any cricketer being honored with the Wisden Leading Cricketer in the World award for 2016. 
UPDATE CONTEXT if you have any other query.\n", + "Virat Kohli\n", "\n", "--------------------------------------------------------------------------------\n", "\n", @@ -2570,20 +1898,23 @@ ">>>>>>>>>>>> Below are outputs of Case 5 <<<<<<<<<<<<\n", "\n", "\n", - "doc_ids: [['doc_20', 'doc_2943', 'doc_2059', 'doc_3293', 'doc_4056', 'doc_1914', 'doc_2749', 'doc_1796', 'doc_3468', 'doc_1793', 'doc_876', 'doc_2577', 'doc_27', 'doc_366', 'doc_321', 'doc_3103', 'doc_715', 'doc_3534', 'doc_142', 'doc_5337', 'doc_2426', 'doc_5346', 'doc_3021', 'doc_1596', 'doc_316', 'doc_1103', 'doc_1602', 'doc_1677', 'doc_1670', 'doc_2853']]\n", - "\u001b[32mAdding doc_id doc_20 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2943 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2059 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3293 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_4056 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1914 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2749 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1796 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_3468 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1793 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_876 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_2577 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_27 to context.\u001b[0m\n", + "Trying to use existing collection.\n", + "query: who carried the usa flag in opening ceremony\n", + "doc_ids: [['20', '2943', '2059', '3293', '4056', '1914', '2749', '1796', '3468', '1793', '876', '2577', '27', '2780', '366', '2574', '321', '3103', '715', '3534', '142', '5337', '2426', '5346', '3021', '1596', '316', '2343', '1103', '1602']]\n", + "\u001b[32mAdding doc_id 20 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2943 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2059 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3293 to context.\u001b[0m\n", + 
"\u001b[32mAdding doc_id 4056 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1914 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2749 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1796 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 3468 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 1793 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 876 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2577 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 27 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2780 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the\n", @@ -2606,17 +1937,19 @@ "

    The United States Oath of Allegiance , officially referred to as the `` Oath of Allegiance , '' 8 C.F.R. Part 337 ( 2008 ) , is an allegiance oath that must be taken by all immigrants who wish to become United States citizens .

    \n", "

    During the first half of the 19th century , seven stars were added to the flag to represent the seven signatories to the Venezuelan declaration of independence , being the provinces of Caracas , Cumaná , Barcelona , Barinas , Margarita , Mérida , and Trujillo .

    \n", "

    With the annexation of Hawaii in 1898 and the seizure of Guam and the Philippines during the Spanish -- American War that same year , the United States began to consider unclaimed and uninhabited Wake Island , located approximately halfway between Honolulu and Manila , as a good location for a telegraph cable station and coaling station for refueling warships of the rapidly expanding United States Navy and passing merchant and passenger steamships . On July 4 , 1898 , United States Army Brigadier General Francis V. Greene of the 2nd Brigade , Philippine Expeditionary Force , of the Eighth Army Corps , stopped at Wake Island and raised the American flag while en route to the Philippines on the steamship liner SS China .

    \n", + "

    On Opening Day , April 9 , 1965 , a sold - out crowd of 47,879 watched an exhibition game between the Houston Astros and the New York Yankees . President Lyndon B. Johnson and his wife Lady Bird were in attendance , as well as Texas Governor John Connally and Houston Mayor Louie Welch . Governor Connally tossed out the first ball for the first game ever played indoors . Dick `` Turk '' Farrell of the Astros threw the first pitch . Mickey Mantle had both the first hit ( a single ) and the first home run in the Astrodome . The Astros beat the Yankees that night , 2 - 1 .

    \n", "\n", "\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "I don't have the answer with the provided context. UPDATE CONTEXT.\n", + "UPDATE CONTEXT\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_366 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 366 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 2574 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the\n", @@ -2627,13 +1960,14 @@ "User's question is: who carried the usa flag in opening ceremony\n", "\n", "Context is: \n", + "

    The opening ceremony of the 2018 Winter Olympics was held at the Pyeongchang Olympic Stadium in Pyeongchang , South Korea on 9 February 2018 . It began at 20 : 00 KST and finished at approximately 22 : 20 KST . The Games were officially opened by President of the Republic of Korea Moon Jae - in .

    \n", "\n", "\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "Erin Hamlin carried the USA flag in the opening ceremony.\n", + "Erin Hamlin\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -2675,7 +2009,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -2720,7 +2054,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -2737,19 +2071,19 @@ " \"docs_path\": corpus_file,\n", " \"chunk_token_size\": 2000,\n", " \"model\": config_list[0][\"model\"],\n", - " \"client\": chromadb.PersistentClient(path=\"/tmp/chromadb\"),\n", " \"collection_name\": \"2wikimultihopqa\",\n", " \"chunk_mode\": \"one_line\",\n", " \"embedding_model\": \"all-MiniLM-L6-v2\",\n", " \"customized_prompt\": PROMPT_MULTIHOP,\n", " \"customized_answer_prefix\": \"the answer is\",\n", + " \"get_or_create\": True\n", " },\n", ")" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -2777,7 +2111,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2789,13 +2123,14 @@ ">>>>>>>>>>>> Below are outputs of Case 1 <<<<<<<<<<<<\n", "\n", "\n", - "Trying to create collection.\n" + "Trying to create index.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ + "[2023-11-26T19:32:19Z WARN lance::dataset] No existing dataset at /Users/ayushchaurasia/autogen/2wikimultihopqa.lance, it will be created\n", "max_tokens is too small to fit a single line of text. Breaking this line:\n", "\tClyde Thompson: Clyde Thompson( 1910 – July 1, 1979) was an American prisoner turned chaplain. He is ...\n", "max_tokens is too small to fit a single line of text. 
Breaking this line:\n", @@ -2806,77 +2141,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "doc_ids: [['doc_12', 'doc_11', 'doc_16', 'doc_19', 'doc_13116', 'doc_14', 'doc_13', 'doc_18', 'doc_977', 'doc_10']]\n", - "\u001b[32mAdding doc_id doc_12 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_11 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_16 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_19 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_13116 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_14 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_13 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_18 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_977 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_10 to context.\u001b[0m\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the context provided by the user. You must think step-by-step.\n", - "First, please learn the following examples of context and question pairs and their corresponding answers.\n", - "\n", - "Context:\n", - "Kurram Garhi: Kurram Garhi is a small village located near the city of Bannu, which is the part of Khyber Pakhtunkhwa province of Pakistan. Its population is approximately 35000.\n", - "Trojkrsti: Trojkrsti is a village in Municipality of Prilep, Republic of Macedonia.\n", - "Q: Are both Kurram Garhi and Trojkrsti located in the same country?\n", - "A: Kurram Garhi is located in the country of Pakistan. Trojkrsti is located in the country of Republic of Macedonia. Thus, they are not in the same country. So the answer is: no.\n", - "\n", - "\n", - "Context:\n", - "Early Side of Later: Early Side of Later is the third studio album by English singer- songwriter Matt Goss. It was released on 21 June 2004 by Concept Music and reached No. 
78 on the UK Albums Chart.\n", - "What's Inside: What's Inside is the fourteenth studio album by British singer- songwriter Joan Armatrading.\n", - "Q: Which album was released earlier, What'S Inside or Cassandra'S Dream (Album)?\n", - "A: What's Inside was released in the year 1995. Cassandra's Dream (album) was released in the year 2008. Thus, of the two, the album to release earlier is What's Inside. So the answer is: What's Inside.\n", - "\n", - "\n", - "Context:\n", - "Maria Alexandrovna (Marie of Hesse): Maria Alexandrovna , born Princess Marie of Hesse and by Rhine (8 August 1824 – 3 June 1880) was Empress of Russia as the first wife of Emperor Alexander II.\n", - "Grand Duke Alexei Alexandrovich of Russia: Grand Duke Alexei Alexandrovich of Russia,(Russian: Алексей Александрович; 14 January 1850 (2 January O.S.) in St. Petersburg – 14 November 1908 in Paris) was the fifth child and the fourth son of Alexander II of Russia and his first wife Maria Alexandrovna (Marie of Hesse).\n", - "Q: What is the cause of death of Grand Duke Alexei Alexandrovich Of Russia's mother?\n", - "A: The mother of Grand Duke Alexei Alexandrovich of Russia is Maria Alexandrovna. Maria Alexandrovna died from tuberculosis. So the answer is: tuberculosis.\n", - "\n", - "\n", - "Context:\n", - "Laughter in Hell: Laughter in Hell is a 1933 American Pre-Code drama film directed by Edward L. Cahn and starring Pat O'Brien. The film's title was typical of the sensationalistic titles of many Pre-Code films.\n", - "Edward L. Cahn: Edward L. Cahn (February 12, 1899 – August 25, 1963) was an American film director.\n", - "Q: When did the director of film Laughter In Hell die?\n", - "A: The film Laughter In Hell was directed by Edward L. Cahn. Edward L. Cahn died on August 25, 1963. 
So the answer is: August 25, 1963.\n", - "\n", - "Second, please complete the answer by thinking step-by-step.\n", - "\n", - "Context:\n", - "The Mask of Fu Manchu: The Mask of Fu Manchu is a 1932 pre-Code adventure film directed by Charles Brabin. It was written by Irene Kuhn, Edgar Allan Woolf and John Willard based on the 1932 novel of the same name by Sax Rohmer. Starring Boris Karloff as Fu Manchu, and featuring Myrna Loy as his depraved daughter, the movie revolves around Fu Manchu's quest for the golden sword and mask of Genghis Khan. Lewis Stone plays his nemesis. Dr. Petrie is absent from this film.\n", - "The Mysterious Dr. Fu Manchu: The Mysterious Dr. Fu Manchu is a 1929 American pre-Code drama film directed by Rowland V. Lee and starring Warner Oland as Dr. Fu Manchu. It was the first Fu Manchu film of the talkie era. Since this was during the transition period to sound, a silent version was also released in the United States.\n", - "The Face of Fu Manchu: The Face of Fu Manchu is a 1965 thriller film directed by Don Sharp and based on the characters created by Sax Rohmer. It stars Christopher Lee as the eponymous villain, a Chinese criminal mastermind, and Nigel Green as his pursuing rival Nayland Smith, a Scotland Yard detective. The film was a British- West German co-production, and was the first in a five- part series starring Lee and produced by Harry Alan Towers for Constantin Film, the second of which was\" The Brides of Fu Manchu\" released the next year, with the final entry being\" The Castle of Fu Manchu\" in 1969. It was shot in Technicolor and Techniscope, on- location in County Dublin, Ireland.\n", - "The Return of Dr. Fu Manchu: The Return of Dr. Fu Manchu is a 1930 American pre-Code film directed by Rowland V. Lee. It is the second of three films starring Warner Oland as the fiendish Fu Manchu, who returns from apparent death in the previous film,\" The Mysterious Dr. 
Fu Manchu\"( 1929), to seek revenge on those he holds responsible for the death of his wife and child.\n", - "The Vengeance of Fu Manchu: The Vengeance of Fu Manchu is a 1967 British film directed by Jeremy Summers and starring Christopher Lee, Horst Frank, Douglas Wilmer and Tsai Chin. It was the third British/ West German Constantin Film co-production of the Dr. Fu Manchu series and the first to be filmed in Hong Kong. It was generally released in the U.K. through Warner- Pathé( as a support feature to the Lindsay Shonteff film\" The Million Eyes of Sumuru\") on 3 December 1967.\n", - "The Brides of Fu Manchu: The Brides of Fu Manchu is a 1966 British/ West German Constantin Film co-production adventure crime film based on the fictional Chinese villain Dr. Fu Manchu, created by Sax Rohmer. It was the second film in a series, and was preceded by\" The Face of Fu ManchuThe Vengeance of Fu Manchu\" followed in 1967,\" The Blood of Fu Manchu\" in 1968, and\" The Castle of Fu Manchu\" in 1969. It was produced by Harry Alan Towers for Hallam Productions. Like the first film, it was directed by Don Sharp, and starred Christopher Lee as Fu Manchu. Nigel Green was replaced by Douglas Wilmer as Scotland Yard detective Nayland Smith. The action takes place mainly in London, where much of the location filming took place.\n", - "The Castle of Fu Manchu: The Castle of Fu Manchu( also known as The Torture Chamber of Dr. Fu Manchu and also known by its German title Die Folterkammer des Dr. Fu Man Chu) is a 1969 film and the fifth and final Dr. Fu Manchu film with Christopher Lee portraying the title character.\n", - "The Blood of Fu Manchu: The Blood of Fu Manchu, also known as Fu Manchu and the Kiss of Death, Kiss of Death, Kiss and Kill( U.S. title) and Against All Odds( original U.S. video title), is a 1968 British adventure crime film directed by Jesús Franco, based on the fictional Asian villain Dr. Fu Manchu created by Sax Rohmer. 
It was the fourth film in a series, and was preceded by\" The Vengeance of Fu Manchu The Castle of Fu Manchu\" followed in 1969. It was produced by Harry Alan Towers for Udastex Films. It starred Christopher Lee as Dr. Fu Manchu, Richard Greene as Scotland Yard detective Nayland Smith, and Howard Marion- Crawford as Dr. Petrie. The movie was filmed in Spain and Brazil. Shirley Eaton appears in a scene that she claimed she was never paid for; apparently, the director Jesús Franco had inserted some stock footage of her from one of her films(\" The Girl from Rio\"( 1968)) into the film without telling her. She only found out years later that she had been in a Fu Manchu film.\n", - "Don Sharp: Donald Herman Sharp( 19 April 192114 December 2011) was an Australian- born British film director. His best known films were made for Hammer in the 1960s, and included\" The Kiss of the Vampire\"( 1962) and\" Rasputin, the Mad Monk\"( 1966). In 1965 he directed\" The Face of Fu Manchu\", based on the character created by Sax Rohmer, and starring Christopher Lee. Sharp also directed the sequel\" The Brides of Fu Manchu\"( 1966). In the 1980s he was also responsible for several hugely popular miniseries adapted from the novels of Barbara Taylor Bradford.\n", - "Blind Shaft: Blind Shaft is a 2003 film about a pair of brutal con artists operating in the illegal coal mines of present- day northern China. 
The film was written and directed by Li Yang( 李杨), and is based on Chinese writer Liu Qingbang's short novel\" Shen MuSacred Wood\").\n", - "\n", - "Q: Which film came out first, Blind Shaft or The Mask Of Fu Manchu?\n", - "A:\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[32mAdding doc_id doc_11 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_16 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_19 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_13116 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_14 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_13 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_18 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_977 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_10 to context.\u001b[0m\n", + "Found 57090 chunks.\n", + "query: Which film came out first, Blind Shaft or The Mask Of Fu Manchu?\n", + "doc_ids: [['12', '11', '16', '19', '13116', '14', '13', '18', '977', '10']]\n", + "\u001b[32mAdding doc_id 12 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 11 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 16 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 19 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 13116 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 14 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 13 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 18 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 977 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 10 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the context provided by the user. 
You must think step-by-step.\n", @@ -2930,7 +2207,7 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "Blind Shaft is a 2003 film while The Mask of Fu Manchu is a 1932 pre-Code adventure film. Thus, The Mask of Fu Manchu came out earlier than Blind Shaft. So the answer is: The Mask of Fu Manchu.\n", + "The film The Mask of Fu Manchu was released in the year 1932. Blind Shaft was released in the year 2003. Thus, of the two, the film to release earlier is The Mask of Fu Manchu. So the answer is: The Mask of Fu Manchu.\n", "\n", "--------------------------------------------------------------------------------\n", "\n", @@ -2938,134 +2215,24 @@ ">>>>>>>>>>>> Below are outputs of Case 2 <<<<<<<<<<<<\n", "\n", "\n", - "doc_ids: [['doc_74', 'doc_76', 'doc_68', 'doc_42890', 'doc_75', 'doc_19596', 'doc_45135', 'doc_995', 'doc_7274', 'doc_23187']]\n", - "\u001b[32mAdding doc_id doc_74 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_76 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_68 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_42890 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_75 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_19596 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_45135 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_995 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_7274 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_23187 to context.\u001b[0m\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the context provided by the user. 
You must think step-by-step.\n", - "First, please learn the following examples of context and question pairs and their corresponding answers.\n", - "\n", - "Context:\n", - "Kurram Garhi: Kurram Garhi is a small village located near the city of Bannu, which is the part of Khyber Pakhtunkhwa province of Pakistan. Its population is approximately 35000.\n", - "Trojkrsti: Trojkrsti is a village in Municipality of Prilep, Republic of Macedonia.\n", - "Q: Are both Kurram Garhi and Trojkrsti located in the same country?\n", - "A: Kurram Garhi is located in the country of Pakistan. Trojkrsti is located in the country of Republic of Macedonia. Thus, they are not in the same country. So the answer is: no.\n", - "\n", - "\n", - "Context:\n", - "Early Side of Later: Early Side of Later is the third studio album by English singer- songwriter Matt Goss. It was released on 21 June 2004 by Concept Music and reached No. 78 on the UK Albums Chart.\n", - "What's Inside: What's Inside is the fourteenth studio album by British singer- songwriter Joan Armatrading.\n", - "Q: Which album was released earlier, What'S Inside or Cassandra'S Dream (Album)?\n", - "A: What's Inside was released in the year 1995. Cassandra's Dream (album) was released in the year 2008. Thus, of the two, the album to release earlier is What's Inside. So the answer is: What's Inside.\n", - "\n", - "\n", - "Context:\n", - "Maria Alexandrovna (Marie of Hesse): Maria Alexandrovna , born Princess Marie of Hesse and by Rhine (8 August 1824 – 3 June 1880) was Empress of Russia as the first wife of Emperor Alexander II.\n", - "Grand Duke Alexei Alexandrovich of Russia: Grand Duke Alexei Alexandrovich of Russia,(Russian: Алексей Александрович; 14 January 1850 (2 January O.S.) in St. 
Petersburg – 14 November 1908 in Paris) was the fifth child and the fourth son of Alexander II of Russia and his first wife Maria Alexandrovna (Marie of Hesse).\n", - "Q: What is the cause of death of Grand Duke Alexei Alexandrovich Of Russia's mother?\n", - "A: The mother of Grand Duke Alexei Alexandrovich of Russia is Maria Alexandrovna. Maria Alexandrovna died from tuberculosis. So the answer is: tuberculosis.\n", - "\n", - "\n", - "Context:\n", - "Laughter in Hell: Laughter in Hell is a 1933 American Pre-Code drama film directed by Edward L. Cahn and starring Pat O'Brien. The film's title was typical of the sensationalistic titles of many Pre-Code films.\n", - "Edward L. Cahn: Edward L. Cahn (February 12, 1899 – August 25, 1963) was an American film director.\n", - "Q: When did the director of film Laughter In Hell die?\n", - "A: The film Laughter In Hell was directed by Edward L. Cahn. Edward L. Cahn died on August 25, 1963. So the answer is: August 25, 1963.\n", - "\n", - "Second, please complete the answer by thinking step-by-step.\n", - "\n", - "Context:\n", - "Seoul High School: Seoul High School( Hangul: 서울고등학교) is a public high school located in the heart of Seoul, South Korea.\n", - "North Marion High School (Oregon): North Marion High School is a public high school in Aurora, Oregon, United States. The school is part of the North Marion School District with all four schools being located on the same campus. The school draws students from the cities of Aurora, Hubbard, and Donald as well as the communities of Broadacres and Butteville.\n", - "Marion High School (Kansas): Marion High School is a public high school in Marion, Kansas, USA. 
It is one of three schools operated by Marion USD 408, and is the sole high school in the district.\n", - "Northwest High School: Northwest High School or North West High School may refer to:\n", - "Marion High School (Indiana): Marion High School is a high school in Marion, Indiana with more than 1,000 students.\n", - "Macon County High School: Macon County High School is located in Montezuma, Georgia, United States, which is a part of Macon County. Enrollment as of the 2017- 2018 school year is 491.\n", - "Canyon High School (Ogden, Utah): Canyon High School was a high school in Ogden, Utah.\n", - "Northside High School: Northside High School or North Side High School or Northside Christian School or similar can refer to:\n", - "Springs Boys' High School: Springs Boys' High School is a high school in Springs, Gauteng, South Africa.\n", - "International School of Koje: International School of Koje( ISK) is a privately funded international school located in Geoje, South Korea.\n", - "\n", - "Q: Are North Marion High School (Oregon) and Seoul High School both located in the same country?\n", - "A:\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", - "\n", - "No, North Marion High School (Oregon) is located in the United States, specifically in the state of Oregon, while Seoul High School is located in South Korea. 
So they are not in the same country.\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n", - "doc_ids: [['doc_76', 'doc_68', 'doc_74', 'doc_75', 'doc_19596', 'doc_42890', 'doc_24819', 'doc_69', 'doc_995', 'doc_7274']]\n", - "\u001b[32mAdding doc_id doc_24819 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_69 to context.\u001b[0m\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the context provided by the user. You must think step-by-step.\n", - "First, please learn the following examples of context and question pairs and their corresponding answers.\n", - "\n", - "Context:\n", - "Kurram Garhi: Kurram Garhi is a small village located near the city of Bannu, which is the part of Khyber Pakhtunkhwa province of Pakistan. Its population is approximately 35000.\n", - "Trojkrsti: Trojkrsti is a village in Municipality of Prilep, Republic of Macedonia.\n", - "Q: Are both Kurram Garhi and Trojkrsti located in the same country?\n", - "A: Kurram Garhi is located in the country of Pakistan. Trojkrsti is located in the country of Republic of Macedonia. Thus, they are not in the same country. So the answer is: no.\n", - "\n", - "\n", - "Context:\n", - "Early Side of Later: Early Side of Later is the third studio album by English singer- songwriter Matt Goss. It was released on 21 June 2004 by Concept Music and reached No. 78 on the UK Albums Chart.\n", - "What's Inside: What's Inside is the fourteenth studio album by British singer- songwriter Joan Armatrading.\n", - "Q: Which album was released earlier, What'S Inside or Cassandra'S Dream (Album)?\n", - "A: What's Inside was released in the year 1995. Cassandra's Dream (album) was released in the year 2008. Thus, of the two, the album to release earlier is What's Inside. 
So the answer is: What's Inside.\n", - "\n", - "\n", - "Context:\n", - "Maria Alexandrovna (Marie of Hesse): Maria Alexandrovna , born Princess Marie of Hesse and by Rhine (8 August 1824 – 3 June 1880) was Empress of Russia as the first wife of Emperor Alexander II.\n", - "Grand Duke Alexei Alexandrovich of Russia: Grand Duke Alexei Alexandrovich of Russia,(Russian: Алексей Александрович; 14 January 1850 (2 January O.S.) in St. Petersburg – 14 November 1908 in Paris) was the fifth child and the fourth son of Alexander II of Russia and his first wife Maria Alexandrovna (Marie of Hesse).\n", - "Q: What is the cause of death of Grand Duke Alexei Alexandrovich Of Russia's mother?\n", - "A: The mother of Grand Duke Alexei Alexandrovich of Russia is Maria Alexandrovna. Maria Alexandrovna died from tuberculosis. So the answer is: tuberculosis.\n", - "\n", - "\n", - "Context:\n", - "Laughter in Hell: Laughter in Hell is a 1933 American Pre-Code drama film directed by Edward L. Cahn and starring Pat O'Brien. The film's title was typical of the sensationalistic titles of many Pre-Code films.\n", - "Edward L. Cahn: Edward L. Cahn (February 12, 1899 – August 25, 1963) was an American film director.\n", - "Q: When did the director of film Laughter In Hell die?\n", - "A: The film Laughter In Hell was directed by Edward L. Cahn. Edward L. Cahn died on August 25, 1963. So the answer is: August 25, 1963.\n", - "\n", - "Second, please complete the answer by thinking step-by-step.\n", - "\n", - "Context:\n", - "Seoul High School: Seoul High School( Hangul: 서울고등학교) is a public high school located in the heart of Seoul, South Korea.\n", - "North Marion High School (Oregon): North Marion High School is a public high school in Aurora, Oregon, United States. The school is part of the North Marion School District with all four schools being located on the same campus. 
The school draws students from the cities of Aurora, Hubbard, and Donald as well as the communities of Broadacres and Butteville.\n", - "Marion High School (Kansas): Marion High School is a public high school in Marion, Kansas, USA. It is one of three schools operated by Marion USD 408, and is the sole high school in the district.\n", - "Northwest High School: Northwest High School or North West High School may refer to:\n", - "Marion High School (Indiana): Marion High School is a high school in Marion, Indiana with more than 1,000 students.\n", - "Macon County High School: Macon County High School is located in Montezuma, Georgia, United States, which is a part of Macon County. Enrollment as of the 2017- 2018 school year is 491.\n", - "Canyon High School (Ogden, Utah): Canyon High School was a high school in Ogden, Utah.\n", - "Northside High School: Northside High School or North Side High School or Northside Christian School or similar can refer to:\n", - "Springs Boys' High School: Springs Boys' High School is a high school in Springs, Gauteng, South Africa.\n", - "International School of Koje: International School of Koje( ISK) is a privately funded international school located in Geoje, South Korea.\n", - "Anderson High School (Anderson, Indiana): Anderson High School is a public high school located in Anderson, Indiana.\n", - "North Marion High School (West Virginia): North Marion High School is a public Double A (\"AA\") high school in the U.S. state of West Virginia, with a current enrollment of 851 students. North Marion High School is located approximately 4 miles from Farmington, West Virginia on US Route 250 north. While it is closer to the city of Mannington, West Virginia, and is often considered to be located in Rachel, West Virginia, the school mailing address is Farmington. Rachel is a small coal mining community located adjacent to the school, and is an unincorporated municipality. 
North Marion High School is represented as \"Grantville High School\" in the popular alternative history novel \"1632\" by writer Eric Flint. The novel is set in the fictional town of Grantville, which is based on the real town and surroundings of Mannington.\n", - "Q: Are North Marion High School (Oregon) and Seoul High School both located in the same country?\n", - "A:\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", - "\n", - "North Marion High School (Oregon) is located in the country of United States. Seoul High School is located in the country of South Korea. Thus, they are not in the same country. So the answer is: no.\n", - "\n", - "--------------------------------------------------------------------------------\n" + "Trying to create index.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "max_tokens is too small to fit a single line of text. Breaking this line:\n", + "\tClyde Thompson: Clyde Thompson( 1910 – July 1, 1979) was an American prisoner turned chaplain. He is ...\n", + "max_tokens is too small to fit a single line of text. 
Breaking this line:\n", + "\tAustralian Historical Monographs: The Australian Historical Monographs are a series of Historical st ...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 57090 chunks.\n" ] } ], @@ -3079,6 +2246,13 @@ " qa_problem = questions[i]\n", " ragproxyagent.initiate_chat(assistant, problem=qa_problem, n_results=10)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -3097,7 +2271,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/setup.py b/setup.py index 042e4080da9..465d32e8268 100644 --- a/setup.py +++ b/setup.py @@ -49,8 +49,8 @@ ], "blendsearch": ["flaml[blendsearch]"], "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"], - "retrievechat": ["chromadb", "lancedb", "sentence_transformers", "pypdf", "ipython"], - "teachable": ["chromadb", "lancedb"], + "retrievechat": ["lancedb", "sentence_transformers", "pypdf", "ipython"], + "teachable": ["lancedb"], "lmm": ["replicate", "pillow"], "graphs": ["networkx~=3.2.1", "matplotlib~=3.8.1"], }, diff --git a/test/agentchat/contrib/retrievers/test_chromadb.py b/test/agentchat/contrib/retrievers/test_chromadb.py index f7325051725..e194ffd879f 100644 --- a/test/agentchat/contrib/retrievers/test_chromadb.py +++ b/test/agentchat/contrib/retrievers/test_chromadb.py @@ -1,5 +1,6 @@ import os import pytest +from pathlib import Path from autogen.agentchat.contrib.retriever.retrieve_utils import ( split_text_to_chunks, extract_text_from_pdf, @@ -20,16 +21,24 @@ @pytest.mark.skipif(skip, reason="chromadb is not installed") -def test_chromadb(): - db_path = "/tmp/test_retrieve_utils_chromadb.db" - client = chromadb.PersistentClient(path=db_path) - if os.path.exists(db_path): - vectorstore = ChromaDB(path=db_path, use_existing=True) - else: - vectorstore = ChromaDB(path=db_path) +def 
test_chromadb(tmpdir): + # Test index creation and querying + client = chromadb.PersistentClient(path=tmpdir) + vectorstore = ChromaDB(path=tmpdir) + vectorstore.ingest_data(test_dir) assert client.get_collection("vectorstore") results = vectorstore.query(["autogen"]) assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) + + # Test index_exists() + db_path = "/tmp/test_retrieve_utils_chromadb.db" + vectorstore = ChromaDB(path=db_path) + assert vectorstore.index_exists() + + # Test use_existing_index() + assert vectorstore.collection is None + vectorstore.use_existing_index() + assert vectorstore.collection is not None diff --git a/test/agentchat/contrib/retrievers/test_lancedb.py b/test/agentchat/contrib/retrievers/test_lancedb.py index 5eb82eab041..f80049741f8 100644 --- a/test/agentchat/contrib/retrievers/test_lancedb.py +++ b/test/agentchat/contrib/retrievers/test_lancedb.py @@ -24,7 +24,7 @@ def test_lancedb(): db_path = "/tmp/test_lancedb_store" db = lancedb.connect(db_path) if os.path.exists(db_path): - vectorstore = LanceDB(path=db_path, use_existing=True) + vectorstore = LanceDB(path=db_path) else: vectorstore = LanceDB(path=db_path) vectorstore.ingest_data(test_dir) @@ -33,3 +33,12 @@ def test_lancedb(): results = vectorstore.query(["autogen"]) assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) + + # Test index_exists() + vectorstore = LanceDB(path=db_path) + assert vectorstore.index_exists() + + # Test use_existing_index() + assert vectorstore.table is None + vectorstore.use_existing_index() + assert vectorstore.table is not None From 055e7d7716121dfac5c053460dee6254efc4ad94 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Mon, 27 Nov 2023 01:41:48 +0530 Subject: [PATCH 16/52] update --- autogen/agentchat/contrib/retriever/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/autogen/agentchat/contrib/retriever/__init__.py b/autogen/agentchat/contrib/retriever/__init__.py index 9bfb09598a6..f4e40217f26 100644 --- a/autogen/agentchat/contrib/retriever/__init__.py +++ b/autogen/agentchat/contrib/retriever/__init__.py @@ -1,6 +1,4 @@ from typing import Optional -from .chromadb import ChromaDB -from .lancedb import LanceDB AVILABLE_RETRIEVERS = ["lanchedb", "chromadb"] DEFAULT_RETRIEVER = "lancedb" @@ -10,8 +8,12 @@ def get_retriever(type: Optional[str] = None): """Return a retriever instance.""" type = type or DEFAULT_RETRIEVER if type == "chromadb": + from .chromadb import ChromaDB + return ChromaDB elif type == "lancedb": + from .lancedb import LanceDB + return LanceDB else: raise ValueError(f"Unknown retriever type {type}") From efc64f855f3a30b27a665e6c3893bca7e0e9b167 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 28 Nov 2023 23:43:48 +0530 Subject: [PATCH 17/52] Update autogen/agentchat/contrib/retriever/base.py Co-authored-by: Li Jiang --- autogen/agentchat/contrib/retriever/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/retriever/base.py b/autogen/agentchat/contrib/retriever/base.py index 56a0c64f1b3..6cbe9abf4fd 100644 --- a/autogen/agentchat/contrib/retriever/base.py +++ b/autogen/agentchat/contrib/retriever/base.py @@ -26,7 +26,8 @@ def __init__( embedding_function: function to use to embed the text max_tokens: maximum number of tokens to embed chunk_mode: mode to chunk the text. Can be "multi_lines" or "single_line" - must_break_at_empty_line: whether to break the text at empty lines when chunking + must_break_at_empty_line: chunk will only break at empty line if True. Default is True. + If chunk_mode is "one_line", this parameter will be ignored. 
custom_text_split_function: custom function to split the text into chunks """ self.path = path From 4f90d318bb290e4d5c3990f72862b6c4e1ed9bf4 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 28 Nov 2023 23:45:42 +0530 Subject: [PATCH 18/52] Update autogen/agentchat/contrib/retrieve_user_proxy_agent.py Co-authored-by: Li Jiang --- autogen/agentchat/contrib/retrieve_user_proxy_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 267c3dcf6ea..34f38388261 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -120,7 +120,7 @@ def __init__( - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True. - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat. This is the same as that used in retriever. Default is False. Will be set to False if docs_path is None. - - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string. + - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string. The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function. Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models. - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings. 
From 10f1b233486741a5fb3e63390c92443dcd76c14d Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 28 Nov 2023 23:46:30 +0530 Subject: [PATCH 19/52] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 465d32e8268..3e4e92323b9 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ "blendsearch": ["flaml[blendsearch]"], "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"], "retrievechat": ["lancedb", "sentence_transformers", "pypdf", "ipython"], - "teachable": ["lancedb"], + "teachable": ["chromadb"], "lmm": ["replicate", "pillow"], "graphs": ["networkx~=3.2.1", "matplotlib~=3.8.1"], }, From 8a686407307732730bbeadb3bf418842f9334ffc Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 28 Nov 2023 23:47:14 +0530 Subject: [PATCH 20/52] Update autogen/agentchat/contrib/retrieve_user_proxy_agent.py Co-authored-by: Li Jiang --- .../contrib/retrieve_user_proxy_agent.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 34f38388261..4ec58274f7e 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -379,15 +379,16 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = custom_text_types=self._custom_text_types, recursive=self._recursive, ) - if not self.retriever.index_exists() or not self._get_or_create: - print("Trying to create index.") # TODO: logger - self.retriever.ingest_data(self._docs_path) - elif self._get_or_create: - if self.retriever.index_exists(): - print("Trying to use existing collection.") # TODO: logger - self.retriever.use_existing_index() + if not self.retriever.index_exists() or self._get_or_create: + if not self.retriever.index_exists(): + print("Trying to create index.") # TODO: logger + 
self.retriever.ingest_data(self._docs_path, overwrite=False) else: - raise Exception("Requested to use existing index but it is not found!") + print("Trying to recreate index.") # TODO: logger + self.retriever.ingest_data(self._docs_path, overwrite=True) + else: + print("Trying to use existing collection.") # TODO: logger + self.retriever.use_existing_index() results = self.retriever.query( texts=[problem], From 1d0695553bff62c457e071fceafcba7820d9c406 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Thu, 30 Nov 2023 21:01:42 +0530 Subject: [PATCH 21/52] update --- .../contrib/retrieve_user_proxy_agent.py | 64 ++++++++-- autogen/agentchat/contrib/retriever/base.py | 4 +- .../agentchat/contrib/retriever/chromadb.py | 4 +- .../agentchat/contrib/retriever/lancedb.py | 5 +- notebook/agentchat_RetrieveChat.ipynb | 115 +++++++++++++++--- 5 files changed, 162 insertions(+), 30 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 4ec58274f7e..3ad5c336785 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -6,6 +6,7 @@ from autogen.token_count_utils import count_token from autogen.code_utils import extract_code from autogen.agentchat.contrib.retriever import get_retriever +from autogen import logger from typing import Callable, Dict, Optional, Union, List, Tuple, Any from IPython import get_ipython @@ -118,7 +119,10 @@ def __init__( - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "". If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered. - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True. - - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat. 
+ - db_mode (Optional, str): the mode to create the vector db. Possible values are "get", "recreate", "create". Default is "recreate" to + keep the workflow less error-prone. If "get", will try to get an existing collection. If "recreate", will recreate a collection + if the collection already exists. If "create", will create a collection if the collection doesn't exist. + - get_or_create (Optional, bool): [Depricated]if True, will create/recreate a collection for the retrieve chat. This is the same as that used in retriever. Default is False. Will be set to False if docs_path is None. - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string. The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function. @@ -179,7 +183,6 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = self.customized_prompt = self._retrieve_config.get("customized_prompt", None) self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper() self.update_context = self._retrieve_config.get("update_context", True) - self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token) self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None) self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS) @@ -193,6 +196,26 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = self._doc_contents = [] # the contents of the current used doc self._doc_ids = [] # the ids of the current used doc self._search_string = "" # the search string used in the current query + self._db_mode = self._retrieve_config.get("db_mode") + self._get_or_create = 
self._retrieve_config.get("get_or_create") + if self._db_mode and self._get_or_create: + logger.warning( + colored( + "Warning: db_mode and get_or_create are both set. get_or_create will be ignored. get_or_create is depricated", + "yellow", + ) + ) + self._get_or_create = None + elif self._db_mode is None and self._get_or_create is None: # if both not set, set db_mode's default value + self._db_mode = "recreate" + elif self._get_or_create: + logger.warning( + colored( + "Warning: get_or_create is depricated and will be removed from future versions. Use `db_mode` instead", + "yellow", + ) + ) + # update the termination message function self._is_termination_msg = ( self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg @@ -379,16 +402,37 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = custom_text_types=self._custom_text_types, recursive=self._recursive, ) - if not self.retriever.index_exists() or self._get_or_create: - if not self.retriever.index_exists(): - print("Trying to create index.") # TODO: logger + if self._db_mode: + if self._db_mode not in ["get", "recreate", "create"]: + raise ValueError( + f"db_mode {self._db_mode} is not supported. Possible values are 'get', 'recreate', 'create'." + ) + if self._db_mode == "get": + if ( + not self.retriever.index_exists + ): # warn users if the index doesn't exist. Maybe we can even raise here + raise ValueError("The index doesn't exist. Please set db_mode to 'recreate' or 'create'.") + self.retriever.use_existing_index() + elif self._db_mode == "recreate": + logger.info("Trying to create index. If the index already exists, it will be recreated.") + self.retriever.ingest_data(self._docs_path, overwrite=True) + elif self._db_mode == "create": + logger.info("Trying to create index.") + if self.retriever.index_exists: + raise ValueError("The index already exists. 
Please set db_mode to 'get' or 'recreate'.") self.retriever.ingest_data(self._docs_path, overwrite=False) + + elif self._get_or_create: + if not self.retriever.index_exists or self._get_or_create: + if not self.retriever.index_exists: + logger.info("Trying to create index.") + self.retriever.ingest_data(self._docs_path, overwrite=False) + else: + logger.info("Trying to recreate index.") + self.retriever.ingest_data(self._docs_path, overwrite=True) else: - print("Trying to recreate index.") # TODO: logger - self.retriever.ingest_data(self._docs_path, overwrite=True) - else: - print("Trying to use existing collection.") # TODO: logger - self.retriever.use_existing_index() + logger.info("Trying to use existing collection.") + self.retriever.use_existing_index() results = self.retriever.query( texts=[problem], diff --git a/autogen/agentchat/contrib/retriever/base.py b/autogen/agentchat/contrib/retriever/base.py index 6cbe9abf4fd..63d28ae8996 100644 --- a/autogen/agentchat/contrib/retriever/base.py +++ b/autogen/agentchat/contrib/retriever/base.py @@ -43,11 +43,12 @@ def __init__( self.init_db() @abstractmethod - def ingest_data(self, data_dir): + def ingest_data(self, data_dir, overwrite: bool = False): """ Create a vector database from a directory of files. Args: data_dir: path to the directory containing the text files + overwrite: overwrite the existing database if True """ pass @@ -75,6 +76,7 @@ def init_db(self): """ pass + @property @abstractmethod def index_exists(self): """ diff --git a/autogen/agentchat/contrib/retriever/chromadb.py b/autogen/agentchat/contrib/retriever/chromadb.py index bf1deaf6605..e0c0a388aa3 100644 --- a/autogen/agentchat/contrib/retriever/chromadb.py +++ b/autogen/agentchat/contrib/retriever/chromadb.py @@ -25,7 +25,7 @@ def init_db(self): ) self.collection = None - def ingest_data(self, data_dir): + def ingest_data(self, data_dir, overwrite: bool = False): """ Create a vector database from a directory of files. 
Args: @@ -35,6 +35,7 @@ def ingest_data(self, data_dir): self.collection = self.client.create_collection( self.name, embedding_function=self.embedding_function, + get_or_create=overwrite, # https://github.com/nmslib/hnswlib#supported-distances # https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184 # https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md @@ -74,6 +75,7 @@ def query(self, texts: List[str], top_k: int = 10, filter: str = None): ) return results + @property def index_exists(self): try: self.client.get_collection(name=self.name, embedding_function=self.embedding_function) diff --git a/autogen/agentchat/contrib/retriever/lancedb.py b/autogen/agentchat/contrib/retriever/lancedb.py index b2502ee18b0..53724ca840d 100644 --- a/autogen/agentchat/contrib/retriever/lancedb.py +++ b/autogen/agentchat/contrib/retriever/lancedb.py @@ -24,14 +24,14 @@ def init_db(self): else self.embedding_function ) - def ingest_data(self, data_dir): + def ingest_data(self, data_dir, overwrite: bool = False): """ Create a vector database from a directory of files. 
Args: data_dir: path to the directory containing the text files """ schema = self._get_schema(self.embedding_function) - self.table = self.db.create_table(self.name, schema=schema, mode="overwrite") + self.table = self.db.create_table(self.name, schema=schema, mode="overwrite" if overwrite else "create") if self.custom_text_split_function is not None: chunks = split_files_to_chunks( @@ -70,6 +70,7 @@ def query(self, texts: List[str], top_k: int = 10, filter: str = None): return results + @property def index_exists(self): return self.name in self.db.table_names() diff --git a/notebook/agentchat_RetrieveChat.ipynb b/notebook/agentchat_RetrieveChat.ipynb index 5638b5e59a9..34eecfe7b71 100644 --- a/notebook/agentchat_RetrieveChat.ipynb +++ b/notebook/agentchat_RetrieveChat.ipynb @@ -178,13 +178,12 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent\n", "from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent\n", - "import chromadb\n", "import os\n", "\n", "# 1. 
create an RetrieveAssistantAgent instance named \"assistant\"\n", @@ -224,7 +223,7 @@ " \"chunk_token_size\": 2000,\n", " \"model\": config_list[0][\"model\"],\n", " \"embedding_model\": \"all-mpnet-base-v2\",\n", - " \"get_or_create\": True, # set to False if you don't want to reuse an existing collection, but you'll need to remove the collection manually\n", + " \"db_mode\": \"recreate\", # \"get\", \"create\", \"recreate\".\n", " },\n", " code_execution_config=False, # set to False if you don't want to execute the code\n", ")" @@ -247,29 +246,29 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Trying to use existing collection.\n", - "query: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n" + "Trying to create index. If the index already exists, it will be recreated.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "/Users/ayushchaurasia/Documents/autogen/autogen/env/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + "File /Users/ayushchaurasia/Documents/autogen/autogen/notebook/../website/docs does not exist. Skipping.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "Found 2 chunks.\n", + "query: How can I use FLAML to perform a classification task and use spark to do parallel training. 
Train 30 seconds and force cancel jobs if time limit is reached.\n", "doc_ids: [['0']]\n", "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", @@ -469,16 +468,72 @@ "\n", "--------------------------------------------------------------------------------\n", "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n", - "Trying to use existing collection.\n", + "Trying to create index. If the index already exists, it will be recreated.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "File /Users/ayushchaurasia/Documents/autogen/autogen/notebook/../website/docs does not exist. Skipping.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2 chunks.\n", "query: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n", "doc_ids: [['0']]\n", - "Trying to use existing collection.\n", + "Trying to create index. If the index already exists, it will be recreated.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "File /Users/ayushchaurasia/Documents/autogen/autogen/notebook/../website/docs does not exist. Skipping.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2 chunks.\n", "query: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n", "doc_ids: [['0']]\n", - "Trying to use existing collection.\n", + "Trying to create index. If the index already exists, it will be recreated.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "File /Users/ayushchaurasia/Documents/autogen/autogen/notebook/../website/docs does not exist. 
Skipping.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2 chunks.\n", "query: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n", "doc_ids: [['0']]\n", - "Trying to use existing collection.\n", + "Trying to create index. If the index already exists, it will be recreated.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "File /Users/ayushchaurasia/Documents/autogen/autogen/notebook/../website/docs does not exist. Skipping.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2 chunks.\n", "query: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n", "doc_ids: [['0']]\n", "\u001b[32mNo more context, will terminate.\u001b[0m\n", @@ -519,14 +574,28 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Trying to use existing collection.\n", + "Trying to create index. If the index already exists, it will be recreated.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "File /Users/ayushchaurasia/Documents/autogen/autogen/notebook/../website/docs does not exist. Skipping.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2 chunks.\n", "query: Who is the author of FLAML?\n", "doc_ids: [['0', '1']]\n", "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", @@ -816,14 +885,28 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Trying to use existing collection.\n", + "Trying to create index. 
If the index already exists, it will be recreated.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "File /Users/ayushchaurasia/Documents/autogen/autogen/notebook/../website/docs does not exist. Skipping.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2 chunks.\n", "query: how to build a time series forecasting model for stock price using FLAML?\n", "doc_ids: [['0', '1']]\n", "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", From 01d305fb869ff2d0f36beb0c546280fc617000aa Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Thu, 30 Nov 2023 21:20:16 +0530 Subject: [PATCH 22/52] update tests --- test/agentchat/contrib/retrievers/test_chromadb.py | 5 ++--- test/agentchat/contrib/retrievers/test_lancedb.py | 14 +++++--------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/test/agentchat/contrib/retrievers/test_chromadb.py b/test/agentchat/contrib/retrievers/test_chromadb.py index e194ffd879f..b2fbcbf5bd8 100644 --- a/test/agentchat/contrib/retrievers/test_chromadb.py +++ b/test/agentchat/contrib/retrievers/test_chromadb.py @@ -34,9 +34,8 @@ def test_chromadb(tmpdir): assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) # Test index_exists() - db_path = "/tmp/test_retrieve_utils_chromadb.db" - vectorstore = ChromaDB(path=db_path) - assert vectorstore.index_exists() + vectorstore = ChromaDB(path=tmpdir) + assert vectorstore.index_exists # Test use_existing_index() assert vectorstore.collection is None diff --git a/test/agentchat/contrib/retrievers/test_lancedb.py b/test/agentchat/contrib/retrievers/test_lancedb.py index f80049741f8..eefdd7cd1f4 100644 --- a/test/agentchat/contrib/retrievers/test_lancedb.py +++ b/test/agentchat/contrib/retrievers/test_lancedb.py @@ -20,13 +20,9 @@ @pytest.mark.skipif(skip, reason="lancedb is not installed") -def test_lancedb(): - db_path = "/tmp/test_lancedb_store" - db = lancedb.connect(db_path) - if 
os.path.exists(db_path): - vectorstore = LanceDB(path=db_path) - else: - vectorstore = LanceDB(path=db_path) +def test_lancedb(tmpdir): + db = lancedb.connect(str(tmpdir)) + vectorstore = LanceDB(path=str(tmpdir)) vectorstore.ingest_data(test_dir) assert "vectorstore" in db.table_names() @@ -35,8 +31,8 @@ def test_lancedb(): assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) # Test index_exists() - vectorstore = LanceDB(path=db_path) - assert vectorstore.index_exists() + vectorstore = LanceDB(path=str(tmpdir)) + assert vectorstore.index_exists # Test use_existing_index() assert vectorstore.table is None From b4cd6c41bb610c28f9f999a8de68a872b039c493 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Thu, 30 Nov 2023 21:30:09 +0530 Subject: [PATCH 23/52] move retrieve utils --- .../agentchat/contrib/retriever/__init__.py | 20 +----------------- .../contrib/retriever/retrieve_utils.py | 21 ++++++++++++++++++- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/autogen/agentchat/contrib/retriever/__init__.py b/autogen/agentchat/contrib/retriever/__init__.py index f4e40217f26..389bed28dd0 100644 --- a/autogen/agentchat/contrib/retriever/__init__.py +++ b/autogen/agentchat/contrib/retriever/__init__.py @@ -1,19 +1 @@ -from typing import Optional - -AVILABLE_RETRIEVERS = ["lanchedb", "chromadb"] -DEFAULT_RETRIEVER = "lancedb" - - -def get_retriever(type: Optional[str] = None): - """Return a retriever instance.""" - type = type or DEFAULT_RETRIEVER - if type == "chromadb": - from .chromadb import ChromaDB - - return ChromaDB - elif type == "lancedb": - from .lancedb import LanceDB - - return LanceDB - else: - raise ValueError(f"Unknown retriever type {type}") +from .retrieve_utils import get_retriever diff --git a/autogen/agentchat/contrib/retriever/retrieve_utils.py b/autogen/agentchat/contrib/retriever/retrieve_utils.py index b02ab66d6a3..5ed2304fc82 100644 --- 
a/autogen/agentchat/contrib/retriever/retrieve_utils.py +++ b/autogen/agentchat/contrib/retriever/retrieve_utils.py @@ -1,4 +1,4 @@ -from typing import List, Union, Callable +from typing import List, Union, Callable, Optional import os import requests from urllib.parse import urlparse @@ -209,3 +209,22 @@ def is_url(string: str): return all([result.scheme, result.netloc]) except ValueError: return False + + +AVILABLE_RETRIEVERS = ["lanchedb", "chromadb"] +DEFAULT_RETRIEVER = "lancedb" + + +def get_retriever(type: Optional[str] = None): + """Return a retriever instance.""" + type = type or DEFAULT_RETRIEVER + if type == "chromadb": + from .chromadb import ChromaDB + + return ChromaDB + elif type == "lancedb": + from .lancedb import LanceDB + + return LanceDB + else: + raise ValueError(f"Unknown retriever type {type}") From 9beb1be52d22da34068dc19052a071ff33b7fbed Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Thu, 30 Nov 2023 21:34:37 +0530 Subject: [PATCH 24/52] make qdrant work --- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py index e45d1fa6aa1..9a8b141cd30 100644 --- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py @@ -1,7 +1,7 @@ from typing import Callable, Dict, List, Optional from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent -from autogen.agentchat.contrib.retrieve_utils import get_files_from_dir, split_files_to_chunks, TEXT_FORMATS +from autogen.agentchat.contrib.retriever.retrieve_utils import get_files_from_dir, split_files_to_chunks, TEXT_FORMATS import logging logger = logging.getLogger(__name__) From f1ccc4be4f5a7e2c7499100b478e845fa03183a9 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Thu, 30 Nov 2023 21:59:23 +0530 
Subject: [PATCH 25/52] update test dir --- test/agentchat/contrib/retrievers/test_chromadb.py | 2 +- test/agentchat/contrib/retrievers/test_lancedb.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/agentchat/contrib/retrievers/test_chromadb.py b/test/agentchat/contrib/retrievers/test_chromadb.py index b2fbcbf5bd8..3888f115d8d 100644 --- a/test/agentchat/contrib/retrievers/test_chromadb.py +++ b/test/agentchat/contrib/retrievers/test_chromadb.py @@ -17,7 +17,7 @@ else: skip = False -test_dir = os.path.join(os.path.dirname(__file__), "test_files") +test_dir = Path(__file__).parent.parent.parent.parent / "test_files" @pytest.mark.skipif(skip, reason="chromadb is not installed") diff --git a/test/agentchat/contrib/retrievers/test_lancedb.py b/test/agentchat/contrib/retrievers/test_lancedb.py index eefdd7cd1f4..bac3031fd8c 100644 --- a/test/agentchat/contrib/retrievers/test_lancedb.py +++ b/test/agentchat/contrib/retrievers/test_lancedb.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import pytest from autogen.agentchat.contrib.retriever.retrieve_utils import ( split_text_to_chunks, @@ -16,14 +17,15 @@ else: skip = False -test_dir = os.path.join(os.path.dirname(__file__), "test_files") +# test_dir is 2 directories above this file +test_dir = Path(__file__).parent.parent.parent.parent / "test_files" @pytest.mark.skipif(skip, reason="lancedb is not installed") def test_lancedb(tmpdir): db = lancedb.connect(str(tmpdir)) vectorstore = LanceDB(path=str(tmpdir)) - vectorstore.ingest_data(test_dir) + vectorstore.ingest_data(str(test_dir)) assert "vectorstore" in db.table_names() From 3df587374508f4d7dc694a42eecce64138fd6d2a Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Mon, 11 Dec 2023 17:59:41 +0530 Subject: [PATCH 26/52] Update autogen/agentchat/contrib/retrieve_user_proxy_agent.py Co-authored-by: Li Jiang --- autogen/agentchat/contrib/retrieve_user_proxy_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index e0d61e4798c..c66e68fac1a 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -123,7 +123,7 @@ def __init__( - db_mode (Optional, str): the mode to create the vector db. Possible values are "get", "recreate", "create". Default is "recreate" to keep the workflow less error-prone. If "get", will try to get an existing collection. If "recreate", will recreate a collection if the collection already exists. If "create", will create a collection if the collection doesn't exist. - - get_or_create (Optional, bool): [Depricated]if True, will create/recreate a collection for the retrieve chat. + - get_or_create (Optional, bool): [Depricated] if True, will create/recreate a collection for the retrieve chat. This is the same as that used in retriever. Default is False. Will be set to False if docs_path is None. - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string. The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function. 
From dad693bf60d39822f01652e3b3f713ae1a51d312 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Mon, 11 Dec 2023 18:00:04 +0530 Subject: [PATCH 27/52] Update autogen/agentchat/contrib/retrieve_user_proxy_agent.py Co-authored-by: Li Jiang --- .../agentchat/contrib/retrieve_user_proxy_agent.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index c66e68fac1a..892f914ab87 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -429,17 +429,13 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = raise ValueError("The index already exists. Please set db_mode to 'get' or 'recreate'.") self.retriever.ingest_data(self._docs_path, overwrite=False) - elif self._get_or_create: - if not self.retriever.index_exists or self._get_or_create: - if not self.retriever.index_exists: - logger.info("Trying to create index.") - self.retriever.ingest_data(self._docs_path, overwrite=False) - else: - logger.info("Trying to recreate index.") - self.retriever.ingest_data(self._docs_path, overwrite=True) - else: + elif self._get_or_create is not None: + if self._get_or_create and self.retriever.index_exists: logger.info("Trying to use existing collection.") self.retriever.use_existing_index() + else: + logger.info("Trying to create index.") + self.retriever.ingest_data(self._docs_path, overwrite=False) results = self.retriever.query( texts=[problem], From b3322fd310cdfa6508ada5bd6ad9fb74de2a7833 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Mon, 11 Dec 2023 18:02:50 +0530 Subject: [PATCH 28/52] Update autogen/agentchat/contrib/retrieve_user_proxy_agent.py Co-authored-by: Li Jiang --- autogen/agentchat/contrib/retrieve_user_proxy_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 892f914ab87..42ec21b4f73 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -205,7 +205,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = self._search_string = "" # the search string used in the current query self._db_mode = self._retrieve_config.get("db_mode") self._get_or_create = self._retrieve_config.get("get_or_create") - if self._db_mode and self._get_or_create: + if self._db_mode is not None and self._get_or_create is not None: logger.warning( colored( "Warning: db_mode and get_or_create are both set. get_or_create will be ignored. get_or_create is depricated", From a1a785747bbd63614eca416b5934fe752fc39760 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 08:59:29 +0530 Subject: [PATCH 29/52] upadte testing --- .../contrib/retrieve_user_proxy_agent.py | 6 +-- .../agentchat/contrib/retriever/chromadb.py | 2 + .../agentchat/contrib/retriever/lancedb.py | 6 +-- .../contrib/retrievers/test_chromadb.py | 5 ++- .../contrib/retrievers/test_lancedb.py | 12 +++++- .../contrib/retrievers/test_utils.py | 41 +++++++++++++++++++ 6 files changed, 64 insertions(+), 8 deletions(-) create mode 100644 test/agentchat/contrib/retrievers/test_utils.py diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 42ec21b4f73..3eda225fb63 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -92,6 +92,8 @@ def __init__( The dict can contain the following keys: "content", "role", "name", "function_call". retrieve_config (dict or None): config for the retrieve agent. To use default config, set to None. 
Otherwise, set to a dictionary with the following keys: + - retriever_type (Optional, str): the type of the retriever. + - retriever_path (Optional, str): the path to use for retriever-related operations. Default is `~/autogen`. - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System prompt will be different for different tasks. The default value is `default`, which supports both code and qa. - client (Optional, Any): the vectordb client/connection. If key not provided, the Retreiver class should handle it. @@ -415,9 +417,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = f"db_mode {self._db_mode} is not supported. Possible values are 'get', 'recreate', 'create'." ) if self._db_mode == "get": - if ( - not self.retriever.index_exists - ): # warn users if the index doesn't exist. Maybe we can even raise here + if not self.retriever.index_exists: raise ValueError("The index doesn't exist. Please set db_mode to 'recreate' or 'create'.") self.retriever.use_existing_index() elif self._db_mode == "recreate": diff --git a/autogen/agentchat/contrib/retriever/chromadb.py b/autogen/agentchat/contrib/retriever/chromadb.py index e0c0a388aa3..59295088fe3 100644 --- a/autogen/agentchat/contrib/retriever/chromadb.py +++ b/autogen/agentchat/contrib/retriever/chromadb.py @@ -31,6 +31,8 @@ def ingest_data(self, data_dir, overwrite: bool = False): Args: data_dir: path to the directory containing the text files """ + if overwrite is True and self.index_exists: + self.client.delete_collection(name=self.name) self.collection = self.client.create_collection( self.name, diff --git a/autogen/agentchat/contrib/retriever/lancedb.py b/autogen/agentchat/contrib/retriever/lancedb.py index 53724ca840d..e63a6ee5db5 100644 --- a/autogen/agentchat/contrib/retriever/lancedb.py +++ b/autogen/agentchat/contrib/retriever/lancedb.py @@ -47,7 +47,7 @@ def ingest_data(self, data_dir, overwrite: bool = False): 
self.table.add(data) elif isinstance(self.embedding_function, Callable): pa_table = pa.Table.from_pylist(data) - data = with_embeddings(self.embedding_function, pa_table) + data = with_embeddings(self.embedding_function, pa_table, column="documents") self.table.add(data) def use_existing_index(self): @@ -84,10 +84,10 @@ class Schema(LanceModel): return Schema elif isinstance(embedding_function, Callable): - dim = embedding_function("test").shape[0] # TODO: check this + dim = embedding_function("test")[0].shape[0] # TODO: check this schema = pa.schema( [ - pa.field("Vector", pa.list_(pa.float32(), dim)), + pa.field("vector", pa.list_(pa.float32(), dim)), pa.field("documents", pa.string()), pa.field("ids", pa.string()), ] diff --git a/test/agentchat/contrib/retrievers/test_chromadb.py b/test/agentchat/contrib/retrievers/test_chromadb.py index 3888f115d8d..6e8b37a8c8d 100644 --- a/test/agentchat/contrib/retrievers/test_chromadb.py +++ b/test/agentchat/contrib/retrievers/test_chromadb.py @@ -26,7 +26,7 @@ def test_chromadb(tmpdir): client = chromadb.PersistentClient(path=tmpdir) vectorstore = ChromaDB(path=tmpdir) - vectorstore.ingest_data(test_dir) + vectorstore.ingest_data(str(test_dir)) assert client.get_collection("vectorstore") @@ -41,3 +41,6 @@ def test_chromadb(tmpdir): assert vectorstore.collection is None vectorstore.use_existing_index() assert vectorstore.collection is not None + + vectorstore.ingest_data(str(test_dir), overwrite=True) + vectorstore.query(["hello"]) diff --git a/test/agentchat/contrib/retrievers/test_lancedb.py b/test/agentchat/contrib/retrievers/test_lancedb.py index bac3031fd8c..66868f4b05f 100644 --- a/test/agentchat/contrib/retrievers/test_lancedb.py +++ b/test/agentchat/contrib/retrievers/test_lancedb.py @@ -1,4 +1,4 @@ -import os +import numpy as np from pathlib import Path import pytest from autogen.agentchat.contrib.retriever.retrieve_utils import ( @@ -21,6 +21,10 @@ test_dir = Path(__file__).parent.parent.parent.parent / 
"test_files" +def embedding_fcn(texts): + return [np.array([0, 0]) for _ in texts] + + @pytest.mark.skipif(skip, reason="lancedb is not installed") def test_lancedb(tmpdir): db = lancedb.connect(str(tmpdir)) @@ -40,3 +44,9 @@ def test_lancedb(tmpdir): assert vectorstore.table is None vectorstore.use_existing_index() assert vectorstore.table is not None + + vectorstore.ingest_data(str(test_dir), overwrite=True) + vectorstore.query(["hello"]) + + vectorstore = LanceDB(path=str(tmpdir), embedding_function=embedding_fcn) + vectorstore.ingest_data(str(test_dir), overwrite=True) diff --git a/test/agentchat/contrib/retrievers/test_utils.py b/test/agentchat/contrib/retrievers/test_utils.py new file mode 100644 index 00000000000..9bcd929716b --- /dev/null +++ b/test/agentchat/contrib/retrievers/test_utils.py @@ -0,0 +1,41 @@ +from pathlib import Path +from autogen.agentchat.contrib.retriever.retrieve_utils import ( + split_text_to_chunks, + split_files_to_chunks, + get_file_from_url, + get_files_from_dir, + is_url, +) + +test_dir = Path(__file__).parent.parent.parent.parent / "test_files" + + +def test_split_text_to_chunks(): + text = "Hello, World! This is a test of the split_text_to_chunks() function." 
+ chunks = split_text_to_chunks(text) + assert len(chunks) == 1 + chunks = split_text_to_chunks(text, max_tokens=10) + assert len(chunks) == 2 + + +def test_split_files_to_chunks(): + files = [test_dir / "example.txt"] + chunks = split_files_to_chunks(files) + assert len(chunks) == 1 + chunks = split_files_to_chunks(files, max_tokens=50) + assert len(chunks) == 2 + + +def test_get_files_from_dir(): + files = get_files_from_dir(str(test_dir)) + assert len(files) == 8 + + +def test_is_url(): + assert is_url("https://google.com") + assert not is_url("google") + + +def test_get_file_from_url(): + file = get_file_from_url("https://google.com") + assert file is not None From b9deeafb20542f80caf372f5c295c842a5166eec Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 09:09:12 +0530 Subject: [PATCH 30/52] rename --- .../contrib/retrievers/{test_utils.py => test_retriever_utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/agentchat/contrib/retrievers/{test_utils.py => test_retriever_utils.py} (100%) diff --git a/test/agentchat/contrib/retrievers/test_utils.py b/test/agentchat/contrib/retrievers/test_retriever_utils.py similarity index 100% rename from test/agentchat/contrib/retrievers/test_utils.py rename to test/agentchat/contrib/retrievers/test_retriever_utils.py From b669aee7ff5753c934d8a98f14d737c162788534 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 09:29:43 +0530 Subject: [PATCH 31/52] improve coverage --- .../retrievers/test_retriever_utils.py | 41 ------------------- test/agentchat/contrib/test_retrievechat.py | 25 +++++++---- 2 files changed, 18 insertions(+), 48 deletions(-) delete mode 100644 test/agentchat/contrib/retrievers/test_retriever_utils.py diff --git a/test/agentchat/contrib/retrievers/test_retriever_utils.py b/test/agentchat/contrib/retrievers/test_retriever_utils.py deleted file mode 100644 index 9bcd929716b..00000000000 --- a/test/agentchat/contrib/retrievers/test_retriever_utils.py 
+++ /dev/null @@ -1,41 +0,0 @@ -from pathlib import Path -from autogen.agentchat.contrib.retriever.retrieve_utils import ( - split_text_to_chunks, - split_files_to_chunks, - get_file_from_url, - get_files_from_dir, - is_url, -) - -test_dir = Path(__file__).parent.parent.parent.parent / "test_files" - - -def test_split_text_to_chunks(): - text = "Hello, World! This is a test of the split_text_to_chunks() function." - chunks = split_text_to_chunks(text) - assert len(chunks) == 1 - chunks = split_text_to_chunks(text, max_tokens=10) - assert len(chunks) == 2 - - -def test_split_files_to_chunks(): - files = [test_dir / "example.txt"] - chunks = split_files_to_chunks(files) - assert len(chunks) == 1 - chunks = split_files_to_chunks(files, max_tokens=50) - assert len(chunks) == 2 - - -def test_get_files_from_dir(): - files = get_files_from_dir(str(test_dir)) - assert len(files) == 8 - - -def test_is_url(): - assert is_url("https://google.com") - assert not is_url("google") - - -def test_get_file_from_url(): - file = get_file_from_url("https://google.com") - assert file is not None diff --git a/test/agentchat/contrib/test_retrievechat.py b/test/agentchat/contrib/test_retrievechat.py index 574e3571b62..7a14e27bece 100644 --- a/test/agentchat/contrib/test_retrievechat.py +++ b/test/agentchat/contrib/test_retrievechat.py @@ -14,10 +14,6 @@ from autogen.agentchat.contrib.retrieve_user_proxy_agent import ( RetrieveUserProxyAgent, ) - import chromadb - from chromadb.utils import embedding_functions as ef - - skip_test = False except ImportError: skip_test = True @@ -45,7 +41,6 @@ def test_retrievechat(): }, ) - sentence_transformer_ef = ef.SentenceTransformerEmbeddingFunction() ragproxyagent = RetrieveUserProxyAgent( name="ragproxyagent", human_input_mode="NEVER", @@ -54,8 +49,6 @@ def test_retrievechat(): "docs_path": "./website/docs", "chunk_token_size": 2000, "model": config_list[0]["model"], - "client": chromadb.PersistentClient(path="/tmp/chromadb"), - "embedding_function": 
sentence_transformer_ef, "get_or_create": True, }, ) @@ -67,6 +60,24 @@ def test_retrievechat(): print(conversations) + # db_mode + ragproxyagent = RetrieveUserProxyAgent( + name="ragproxyagent", + human_input_mode="NEVER", + max_consecutive_auto_reply=2, + retrieve_config={ + "docs_path": "./website/docs", + "chunk_token_size": 2000, + "model": config_list[0]["model"], + "db_mode": "recreate", + }, + ) + + assistant.reset() + + code_problem = "How can I use FLAML to perform a classification task, set use_spark=True, train 30 seconds and force cancel jobs if time limit is reached." + ragproxyagent.initiate_chat(assistant, problem=code_problem, search_string="spark", silent=True) + @pytest.mark.skipif( sys.platform in ["darwin", "win32"] or skip_test, From 154fa3dafb201c69d9a69a9acf7cc7f169e15e58 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 09:39:58 +0530 Subject: [PATCH 32/52] update --- test/agentchat/contrib/test_retrievechat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/agentchat/contrib/test_retrievechat.py b/test/agentchat/contrib/test_retrievechat.py index 7a14e27bece..66393d6dd27 100644 --- a/test/agentchat/contrib/test_retrievechat.py +++ b/test/agentchat/contrib/test_retrievechat.py @@ -91,7 +91,7 @@ def test_retrieve_config(caplog): max_consecutive_auto_reply=2, retrieve_config={ "chunk_token_size": 2000, - "get_or_create": True, + "db_mode": "recreate", }, ) From c1c6532d40d8e1e4867643c7358c53dc732a5429 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 10:08:29 +0530 Subject: [PATCH 33/52] update notebook --- notebook/agentchat_groupchat_RAG.ipynb | 500 +++++++++++-------------- 1 file changed, 221 insertions(+), 279 deletions(-) diff --git a/notebook/agentchat_groupchat_RAG.ipynb b/notebook/agentchat_groupchat_RAG.ipynb index c68b3181950..87905ff2dcd 100644 --- a/notebook/agentchat_groupchat_RAG.ipynb +++ b/notebook/agentchat_groupchat_RAG.ipynb @@ -55,7 +55,7 @@ "name": "stdout", 
"output_type": "stream", "text": [ - "LLM models: ['gpt-35-turbo', 'gpt-35-turbo-0613']\n" + "LLM models: ['gpt-4']\n" ] } ], @@ -110,20 +110,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 19, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "/home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages/torch/cuda/__init__.py:138: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11060). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. 
(Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)\n", - " return torch._C._cuda_getDeviceCount() > 0\n" - ] - } - ], + "outputs": [], "source": [ "from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent\n", "from autogen import AssistantAgent\n", @@ -159,9 +148,8 @@ " \"docs_path\": \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Examples/Integrate%20-%20Spark.md\",\n", " \"chunk_token_size\": 1000,\n", " \"model\": config_list[0][\"model\"],\n", - " \"client\": chromadb.PersistentClient(path=\"/tmp/chromadb\"),\n", " \"collection_name\": \"groupchat\",\n", - " \"get_or_create\": True,\n", + " \"db_mode\": \"recreate\",\n", " },\n", " code_execution_config=False, # we don't want to execute code in this case.\n", ")\n", @@ -317,7 +305,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -331,88 +319,67 @@ "--------------------------------------------------------------------------------\n", "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n", "\n", - "To use Spark for parallel training in FLAML, you can use the `SparkTrials` class provided by FLAML. Here is a sample code:\n", + "Sure, you can use PySpark to parallelize the training process in FLAML. 
Here's a simple example of how you can do it:\n", "\n", "```python\n", "from flaml import AutoML\n", - "from flaml.data import load_credit\n", - "from flaml.model import SparkTrials\n", + "from pyspark.sql import SparkSession\n", + "from pyspark.sql.functions import col\n", + "from sklearn.datasets import load_boston\n", + "import pandas as pd\n", "\n", - "# Load data\n", - "X_train, y_train, X_test, y_test = load_credit()\n", + "# Initialize SparkSession\n", + "spark = SparkSession.builder \\\n", + " .appName(\"Parallel Training with FLAML\") \\\n", + " .getOrCreate()\n", "\n", - "# Define the search space\n", - "search_space = {\n", - " \"n_estimators\": {\"domain\": range(10, 100)},\n", - " \"max_depth\": {\"domain\": range(6, 10)},\n", - " \"learning_rate\": {\"domain\": (0.01, 0.1, 1)},\n", - "}\n", + "# Load the dataset\n", + "boston = load_boston()\n", + "df = pd.DataFrame(boston.data, columns=boston.feature_names)\n", + "df['target'] = pd.Series(boston.target)\n", "\n", - "# Create an AutoML instance with SparkTrials\n", - "automl = AutoML(\n", - " search_space=search_space,\n", - " task=\"classification\",\n", - " n_jobs=1,\n", - " ensemble_size=0,\n", - " max_trials=10,\n", - " trials=SparkTrials(parallelism=2),\n", - ")\n", + "# Convert the pandas dataframe to spark dataframe\n", + "sdf = spark.createDataFrame(df)\n", "\n", - "# Train the model\n", - "automl.fit(X_train=X_train, y_train=y_train)\n", + "# Split the data into training and test sets\n", + "train, test = sdf.randomSplit([0.8, 0.2])\n", "\n", - "# Evaluate the model\n", - "print(\"Best model:\", automl.best_model)\n", - "print(\"Best hyperparameters:\", automl.best_config)\n", - "print(\"Test accuracy:\", automl.score(X_test=X_test, y_test=y_test))\n", + "# Convert the spark dataframes back to pandas dataframes\n", + "train = train.select(\"*\").toPandas()\n", + "test = test.select(\"*\").toPandas()\n", "\n", - "# Terminate\n", - "TERMINATE\n", - "```\n", - "\n", - "In this code, we first 
load the credit dataset. Then, we define the search space for the hyperparameters. We create an `AutoML` instance with `SparkTrials` as the `trials` parameter. We set the `parallelism` parameter to 2, which means that FLAML will use 2 Spark workers to run the trials in parallel. Finally, we fit the model and evaluate it.\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mCode_Reviewer\u001b[0m (to chat_manager):\n", + "# Initialize the AutoML instance\n", + "automl = AutoML()\n", "\n", - "Great! That's a clear and concise example. No further questions from my side.\n", + "# Specify the task as regression and the metric to optimize as rmse\n", + "automl_settings = {\n", + " \"time_budget\": 120, # in seconds\n", + " \"metric\": 'rmse',\n", + " \"task\": 'regression',\n", + " \"log_file_name\": \"boston.log\",\n", + "}\n", "\n", - "--------------------------------------------------------------------------------\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n", + "# Train the model\n", + "automl.fit(train.drop('target', axis=1), train['target'], **automl_settings)\n", "\n", - "Thank you! Let me know if you have any other questions.\n", + "# Predict on the test data\n", + "preds = automl.predict(test.drop('target', axis=1))\n", "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mBoss\u001b[0m (to chat_manager):\n", + "# Print the best model and metric\n", + "print('Best ML model:', automl.best_estimator)\n", + "print('Best metric:', automl.best_loss)\n", + "```\n", "\n", - "Reply `TERMINATE` if the task is done.\n", + "This script loads the Boston housing dataset, splits it into training and test sets, and then uses FLAML's AutoML to find the best model and hyperparameters. 
The training process is parallelized using PySpark.\n", "\n", - "--------------------------------------------------------------------------------\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GroupChat select_speaker failed to resolve the next speaker's name. Speaker selection will default to the next speaker in the list. This is because the speaker selection OAI call returned:\n", - "The next role to play is not specified in the conversation. Please provide more information.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n", + "Please note that this is a simple example and might not fully utilize the capabilities of Spark. For larger datasets and more complex scenarios, you might need to customize this script to better suit your needs.\n", "\n", - "TERMINATE\n", + "Also, please ensure that you have the necessary packages installed in your environment. You can install them using pip:\n", "\n", - "--------------------------------------------------------------------------------\n", + "```bash\n", + "pip install flaml[forecast]\n", + "pip install pyspark\n", + "```\n", "TERMINATE\n", "\n", "--------------------------------------------------------------------------------\n" @@ -434,30 +401,25 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 20, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Trying to create collection.\n" - ] - }, { "name": "stderr", "output_type": "stream", "text": [ - "Number of requested results 3 is greater than number of elements in index 2, updating n_results = 2\n" + "INFO:autogen:Trying to create index. 
If the index already exists, it will be recreated.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "doc_ids: [['doc_0', 'doc_1']]\n", - "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", + "Found 2 chunks.\n", + "query: How to use spark for parallel training in FLAML? Give me sample code.\n", + "doc_ids: [['1', '0']]\n", + "\u001b[32mAdding doc_id 1 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", "\u001b[33mBoss_Assistant\u001b[0m (to chat_manager):\n", "\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", @@ -472,29 +434,47 @@ "\n", "User's question is: How to use spark for parallel training in FLAML? Give me sample code.\n", "\n", - "Context is: # Integrate - Spark\n", + "Context is: \n", + "use_spark: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable FLAML_MAX_CONCURRENT to override the detected num_executors. The final number of concurrent trials will be the minimum of n_concurrent_trials and num_executors.\n", + "n_concurrent_trials: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", + "force_cancel: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. 
Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", + "An example code snippet for using parallel Spark jobs:\n", "\n", - "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", - "- Use Spark ML estimators for AutoML.\n", - "- Use Spark to run training in parallel spark jobs.\n", + "import flaml\n", + "automl_experiment = flaml.AutoML()\n", + "automl_settings = {\n", + " \"time_budget\": 30,\n", + " \"metric\": \"r2\",\n", + " \"task\": \"regression\",\n", + " \"n_concurrent_trials\": 2,\n", + " \"use_spark\": True,\n", + " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", + "}\n", "\n", - "## Spark ML Estimators\n", + "automl.fit(\n", + " dataframe=dataframe,\n", + " label=label,\n", + " **automl_settings,\n", + ")\n", + "Integrate - Spark\n", + "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", "\n", + "Use Spark ML estimators for AutoML.\n", + "Use Spark to run training in parallel spark jobs.\n", + "Spark ML Estimators\n", "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n", "\n", - "### Data\n", + "Data\n", + "For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function to_pandas_on_spark in the flaml.automl.spark.utils module to convert your data into a pandas-on-spark (pyspark.pandas) dataframe/series, which Spark estimators require.\n", "\n", - "For Spark estimators, AutoML only consumes Spark data. 
FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n", + "This utility function takes data in the form of a pandas.Dataframe or pyspark.sql.Dataframe and converts it into a pandas-on-spark dataframe. It also takes pandas.Series or pyspark.sql.Dataframe and converts it into a pandas-on-spark series. If you pass in a pyspark.pandas.Dataframe, it will not make any changes.\n", "\n", - "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", - "\n", - "This function also accepts optional arguments `index_col` and `default_index_type`.\n", - "- `index_col` is the column name to use as the index, default is None.\n", - "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", + "This function also accepts optional arguments index_col and default_index_type.\n", "\n", + "index_col is the column name to use as the index, default is None.\n", + "default_index_type is the default index type, default is \"distributed-sequence\". 
More info about default index type could be found on Spark official documentation\n", "Here is an example code snippet for Spark Data:\n", "\n", - "```python\n", "import pandas as pd\n", "from flaml.automl.spark.utils import to_pandas_on_spark\n", "# Creating a dictionary\n", @@ -508,33 +488,27 @@ "\n", "# Convert to pandas-on-spark dataframe\n", "psdf = to_pandas_on_spark(dataframe)\n", - "```\n", - "\n", - "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", + "To use Spark ML models you need to format your data appropriately. Specifically, use VectorAssembler to merge all feature columns into a single vector column.\n", "\n", "Here is an example of how to use it:\n", - "```python\n", + "\n", "from pyspark.ml.feature import VectorAssembler\n", "columns = psdf.columns\n", "feature_cols = [col for col in columns if col != label]\n", "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "```\n", - "\n", - "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", + "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using X_train, y_train or dataframe, label.\n", "\n", - "### Estimators\n", - "#### Model List\n", - "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", - "\n", - "#### Usage\n", + "Estimators\n", + "Model List\n", + "lgbm_spark: The class for fine-tuning Spark version LightGBM models, using SynapseML API.\n", + "Usage\n", "First, prepare your data in the 
required format as described in the previous section.\n", "\n", - "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", + "By including the models you intend to try in the estimators_list argument to flaml.automl, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the _spark postfix by default, even if you haven't specified them.\n", "\n", "Here is an example code snippet using SparkML models in AutoML:\n", "\n", - "```python\n", "import flaml\n", "# prepare your data in pandas-on-spark format as we previously mentioned\n", "\n", @@ -551,78 +525,105 @@ " label=label,\n", " **settings,\n", ")\n", - "```\n", - "\n", - "\n", - "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", + "Link to notebook | Open in colab\n", "\n", - "## Parallel Spark Jobs\n", - "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", + "Parallel Spark Jobs\n", + "You can activate Spark as the parallel backend during parallel tuning in both AutoML and Hyperparameter Tuning, by setting the use_spark to true. 
FLAML will dispatch your job to the distributed Spark backend using joblib-spark.\n", "\n", - "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", + "Please note that you should not set use_spark to true when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with use_spark again.\n", "\n", "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", "\n", "\n", - "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", - "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", - "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. 
Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", - "\n", - "An example code snippet for using parallel Spark jobs:\n", - "```python\n", - "import flaml\n", - "automl_experiment = flaml.AutoML()\n", - "automl_settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"task\": \"regression\",\n", - " \"n_concurrent_trials\": 2,\n", - " \"use_spark\": True,\n", - " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", - "}\n", "\n", - "automl.fit(\n", - " dataframe=dataframe,\n", - " label=label,\n", - " **automl_settings,\n", - ")\n", - "```\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n", "\n", + "Based on the context provided, here is a sample code on how to use Spark for parallel training in FLAML:\n", "\n", - "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", + "```python\n", + "# Import necessary libraries\n", + "import pandas as pd\n", + "from flaml.automl import AutoML\n", + "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "from pyspark.ml.feature import VectorAssembler\n", "\n", + "# Creating a dictionary\n", + "data = {\"Square_Feet\": [800, 1200, 
1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n", "\n", + "# Creating a pandas DataFrame\n", + "dataframe = pd.DataFrame(data)\n", + "label = \"Price\"\n", "\n", + "# Convert to pandas-on-spark dataframe\n", + "psdf = to_pandas_on_spark(dataframe)\n", "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n", + "# Use VectorAssembler to merge all feature columns into a single vector column\n", + "columns = psdf.columns\n", + "feature_cols = [col for col in columns if col != label]\n", + "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", + "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", "\n", - "To use Spark for parallel training in FLAML, you can activate Spark as the parallel backend during parallel tuning in both AutoML and Hyperparameter Tuning, by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using `joblib-spark`. 
Here is an example code snippet for using parallel Spark jobs:\n", + "# Initialize AutoML\n", + "automl = AutoML()\n", "\n", - "```python\n", - "import flaml\n", - "automl_experiment = flaml.AutoML()\n", - "automl_settings = {\n", + "# Define settings\n", + "settings = {\n", " \"time_budget\": 30,\n", " \"metric\": \"r2\",\n", " \"task\": \"regression\",\n", " \"n_concurrent_trials\": 2,\n", " \"use_spark\": True,\n", - " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", + " \"force_cancel\": True,\n", "}\n", "\n", + "# Fit the model\n", "automl.fit(\n", - " dataframe=dataframe,\n", + " dataframe=psdf,\n", " label=label,\n", - " **automl_settings,\n", + " **settings,\n", ")\n", "```\n", "\n", - "Note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", + "This code first creates a pandas DataFrame and converts it to a pandas-on-spark DataFrame. Then, it uses VectorAssembler to merge all feature columns into a single vector column. After that, it initializes AutoML and defines the settings for the model. Finally, it fits the model using the defined settings. \n", "\n", - "I hope this helps! Let me know if you have any further questions.\n", + "Please note that you need to have a running Spark session to use Spark for parallel training in FLAML.\n", "\n", - "--------------------------------------------------------------------------------\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n", "\n", "TERMINATE\n", @@ -647,7 +648,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -658,136 +659,70 @@ "\n", "How to use spark for parallel training in FLAML? Give me sample code.\n", "\n", - "--------------------------------------------------------------------------------\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n", - "\n", - "To use Spark for parallel training in FLAML, you can follow these steps:\n", - "\n", - "1. Install PySpark and FLAML on your machine.\n", - "2. Start a Spark cluster using the `pyspark` command.\n", - "3. Import the necessary libraries and initialize a SparkSession object.\n", - "4. Load your data into a Spark DataFrame.\n", - "5. Define your search space and search strategy using FLAML's API.\n", - "6. Create a SparkEstimator object and pass it to FLAML's `fit()` method.\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n", "\n", - "Here's some sample code to get you started:\n", + "Sure, you can use PySpark to parallelize the training process in FLAML. 
Here's a simple example of how you can do it:\n", "\n", "```python\n", - "from pyspark.sql import SparkSession\n", "from flaml import AutoML\n", - "from flaml.data import get_output_from_log\n", - "\n", - "# Initialize a SparkSession object\n", - "spark = SparkSession.builder.appName(\"FLAML-Spark\").getOrCreate()\n", - "\n", - "# Load your data into a Spark DataFrame\n", - "data = spark.read.format(\"csv\").option(\"header\", \"true\").load(\"path/to/data.csv\")\n", - "\n", - "# Define your search space and search strategy\n", - "search_space = {\n", - " \"n_estimators\": {\"domain\": range(10, 100)},\n", - " \"max_depth\": {\"domain\": range(1, 10)},\n", - " \"learning_rate\": {\"domain\": [0.001, 0.01, 0.1]},\n", - "}\n", - "search_strategy = \"skopt\"\n", - "\n", - "# Create a SparkEstimator object\n", - "from pyspark.ml.classification import GBTClassifier\n", - "estimator = GBTClassifier()\n", - "\n", - "# Pass the SparkEstimator object to FLAML's fit() method\n", - "automl = AutoML()\n", - "automl.fit(\n", - " X_train=data,\n", - " estimator=estimator,\n", - " task=\"classification\",\n", - " search_space=search_space,\n", - " search_alg=search_strategy,\n", - " n_jobs=-1,\n", - ")\n", - "\n", - "# Get the best model and its hyperparameters\n", - "best_model = automl.model\n", - "best_params = automl.best_config\n", - "\n", - "# Print the results\n", - "print(f\"Best model: {best_model}\")\n", - "print(f\"Best hyperparameters: {best_params}\")\n", - "\n", - "# Stop the SparkSession object\n", - "spark.stop()\n", - "```\n", - "\n", - "Note that the `n_jobs` parameter is set to `-1` to use all available cores on the Spark cluster. You can adjust this value to control the level of parallelism. Also, the `get_output_from_log()` function can be used to extract the results from the FLAML log file. 
\n", - "\n", - "TERMINATE\n", + "from pyspark.sql import SparkSession\n", + "from pyspark.sql.functions import col\n", + "from sklearn.datasets import load_boston\n", + "import pandas as pd\n", "\n", - "--------------------------------------------------------------------------------\n", - "To use Spark for parallel training in FLAML, you can follow these steps:\n", + "# Initialize SparkSession\n", + "spark = SparkSession.builder \\\n", + " .appName(\"Parallel Training with FLAML\") \\\n", + " .getOrCreate()\n", "\n", - "1. Install PySpark and FLAML on your machine.\n", - "2. Start a Spark cluster using the `pyspark` command.\n", - "3. Import the necessary libraries and initialize a SparkSession object.\n", - "4. Load your data into a Spark DataFrame.\n", - "5. Define your search space and search strategy using FLAML's API.\n", - "6. Create a SparkEstimator object and pass it to FLAML's `fit()` method.\n", + "# Load the dataset\n", + "boston = load_boston()\n", + "df = pd.DataFrame(boston.data, columns=boston.feature_names)\n", + "df['target'] = pd.Series(boston.target)\n", "\n", - "Here's some sample code to get you started:\n", + "# Convert the pandas dataframe to spark dataframe\n", + "sdf = spark.createDataFrame(df)\n", "\n", - "```python\n", - "from pyspark.sql import SparkSession\n", - "from flaml import AutoML\n", - "from flaml.data import get_output_from_log\n", + "# Split the data into training and test sets\n", + "train, test = sdf.randomSplit([0.8, 0.2])\n", "\n", - "# Initialize a SparkSession object\n", - "spark = SparkSession.builder.appName(\"FLAML-Spark\").getOrCreate()\n", + "# Convert the spark dataframes back to pandas dataframes\n", + "train = train.select(\"*\").toPandas()\n", + "test = test.select(\"*\").toPandas()\n", "\n", - "# Load your data into a Spark DataFrame\n", - "data = spark.read.format(\"csv\").option(\"header\", \"true\").load(\"path/to/data.csv\")\n", + "# Initialize the AutoML instance\n", + "automl = AutoML()\n", "\n", - "# 
Define your search space and search strategy\n", - "search_space = {\n", - " \"n_estimators\": {\"domain\": range(10, 100)},\n", - " \"max_depth\": {\"domain\": range(1, 10)},\n", - " \"learning_rate\": {\"domain\": [0.001, 0.01, 0.1]},\n", + "# Specify the task as regression and the metric to optimize as rmse\n", + "automl_settings = {\n", + " \"time_budget\": 120, # in seconds\n", + " \"metric\": 'rmse',\n", + " \"task\": 'regression',\n", + " \"log_file_name\": \"boston.log\",\n", "}\n", - "search_strategy = \"skopt\"\n", "\n", - "# Create a SparkEstimator object\n", - "from pyspark.ml.classification import GBTClassifier\n", - "estimator = GBTClassifier()\n", + "# Train the model\n", + "automl.fit(train.drop('target', axis=1), train['target'], **automl_settings)\n", "\n", - "# Pass the SparkEstimator object to FLAML's fit() method\n", - "automl = AutoML()\n", - "automl.fit(\n", - " X_train=data,\n", - " estimator=estimator,\n", - " task=\"classification\",\n", - " search_space=search_space,\n", - " search_alg=search_strategy,\n", - " n_jobs=-1,\n", - ")\n", + "# Predict on the test data\n", + "preds = automl.predict(test.drop('target', axis=1))\n", "\n", - "# Get the best model and its hyperparameters\n", - "best_model = automl.model\n", - "best_params = automl.best_config\n", + "# Print the best model and metric\n", + "print('Best ML model:', automl.best_estimator)\n", + "print('Best metric:', automl.best_loss)\n", + "```\n", "\n", - "# Print the results\n", - "print(f\"Best model: {best_model}\")\n", - "print(f\"Best hyperparameters: {best_params}\")\n", + "This script loads the Boston housing dataset, splits it into training and test sets, and then uses FLAML's AutoML to find the best model and hyperparameters. The training process is parallelized using PySpark.\n", "\n", - "# Stop the SparkSession object\n", - "spark.stop()\n", - "```\n", + "Please note that this is a simple example and might not fully utilize the capabilities of Spark. 
For larger datasets and more complex scenarios, you might need to customize this script to better suit your needs.\n", "\n", - "Note that the `n_jobs` parameter is set to `-1` to use all available cores on the Spark cluster. You can adjust this value to control the level of parallelism. Also, the `get_output_from_log()` function can be used to extract the results from the FLAML log file. \n", + "Also, please ensure that you have the necessary packages installed in your environment. You can install them using pip:\n", "\n", + "```bash\n", + "pip install flaml[forecast]\n", + "pip install pyspark\n", + "```\n", "TERMINATE\n", "\n", "--------------------------------------------------------------------------------\n" @@ -797,11 +732,18 @@ "source": [ "call_rag_chat()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "flaml", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -815,9 +757,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.9.6" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } From 1cb69318a00e93b501bf0805677109c60813cf67 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 10:14:13 +0530 Subject: [PATCH 34/52] update --- .github/workflows/contrib-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 7e1eaa7d85b..7406f8adc42 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -45,6 +45,7 @@ jobs: - name: Install packages and dependencies for RetrieveChat run: | pip install -e .[retrievechat] + pip install chromadb pip uninstall -y openai - name: Test RetrieveChat run: | From 1f02f65299d58745b6c322ed894d2cdefa1546ab Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 
Dec 2023 16:25:56 +0530 Subject: [PATCH 35/52] update docstring --- autogen/agentchat/contrib/retrieve_user_proxy_agent.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 3eda225fb63..3681ac75820 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -125,6 +125,10 @@ def __init__( - db_mode (Optional, str): the mode to create the vector db. Possible values are "get", "recreate", "create". Default is "recreate" to keep the workflow less error-prone. If "get", will try to get an existing collection. If "recreate", will recreate a collection if the collection already exists. If "create", will create a collection if the collection doesn't exist. + Raises ValueError if: + * the collection doesn't exist and "get" is used. + * the collection already exists and "create" is used. + - get_or_create (Optional, bool): [Depricated] if True, will create/recreate a collection for the retrieve chat. This is the same as that used in retriever. Default is False. Will be set to False if docs_path is None. - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string. 
From 96a4136735b06e96e8b0b515dd744797d8e68e27 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 19:04:31 +0800 Subject: [PATCH 36/52] Test hide bot comments --- .github/workflows/contrib-openai.yml | 4 ++++ .github/workflows/openai.yml | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml index 467d5270c8e..6f03cfa1f65 100644 --- a/.github/workflows/contrib-openai.yml +++ b/.github/workflows/contrib-openai.yml @@ -26,6 +26,10 @@ jobs: uses: actions/checkout@v3 with: ref: ${{ github.event.pull_request.head.sha }} + - name: Hide bot comments + uses: kanga333/comment-hider@master + with: + github_token: ${{ secrets.GITHUB_TOKEN }} - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: diff --git a/.github/workflows/openai.yml b/.github/workflows/openai.yml index de5f743855f..935cff50d8d 100644 --- a/.github/workflows/openai.yml +++ b/.github/workflows/openai.yml @@ -27,6 +27,10 @@ jobs: uses: actions/checkout@v3 with: ref: ${{ github.event.pull_request.head.sha }} + - name: Hide bot comments + uses: kanga333/comment-hider@master + with: + github_token: ${{ secrets.GITHUB_TOKEN }} - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: From 274950ee3079a2389aff6df3fd9453e97d68c252 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 19:16:45 +0800 Subject: [PATCH 37/52] Revert "Test hide bot comments" This reverts commit 96a4136735b06e96e8b0b515dd744797d8e68e27. 
--- .github/workflows/contrib-openai.yml | 4 ---- .github/workflows/openai.yml | 4 ---- 2 files changed, 8 deletions(-) diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml index 6f03cfa1f65..467d5270c8e 100644 --- a/.github/workflows/contrib-openai.yml +++ b/.github/workflows/contrib-openai.yml @@ -26,10 +26,6 @@ jobs: uses: actions/checkout@v3 with: ref: ${{ github.event.pull_request.head.sha }} - - name: Hide bot comments - uses: kanga333/comment-hider@master - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: diff --git a/.github/workflows/openai.yml b/.github/workflows/openai.yml index 935cff50d8d..de5f743855f 100644 --- a/.github/workflows/openai.yml +++ b/.github/workflows/openai.yml @@ -27,10 +27,6 @@ jobs: uses: actions/checkout@v3 with: ref: ${{ github.event.pull_request.head.sha }} - - name: Hide bot comments - uses: kanga333/comment-hider@master - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: From 795dfbd8515df3be5f077b1ef8a98e5171e2cac6 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 19:17:15 +0800 Subject: [PATCH 38/52] Add hide bot comments --- .github/workflows/build.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8d4b84a301a..f2830d090fe 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,6 +30,8 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.8", "3.9", "3.10", "3.11"] steps: + - name: Hide comment + uses: int128/hide-comment-action@v1 - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 From 0de53a3e852303108f55124fb0cfae857da6c1a0 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 19:20:34 +0800 Subject: [PATCH 39/52] Revert 
"Add hide bot comments" This reverts commit 795dfbd8515df3be5f077b1ef8a98e5171e2cac6. --- .github/workflows/build.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f2830d090fe..8d4b84a301a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,8 +30,6 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - name: Hide comment - uses: int128/hide-comment-action@v1 - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 From bdf45cff0c04665f6109ec51b43a52af00e8fc91 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 19:21:37 +0800 Subject: [PATCH 40/52] Add comment-hider --- .github/workflows/build.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8d4b84a301a..3846b27d8a2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,6 +31,10 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 + - uses: kanga333/comment-hider@master + name: Hide bot comments + with: + github_token: ${{ secrets.GITHUB_TOKEN }} - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: From 40edc959760016071f5d57f1f337dfcd560cedac Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 19:26:52 +0800 Subject: [PATCH 41/52] Revert "Add comment-hider" This reverts commit bdf45cff0c04665f6109ec51b43a52af00e8fc91. 
--- .github/workflows/build.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3846b27d8a2..8d4b84a301a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,10 +31,6 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 - - uses: kanga333/comment-hider@master - name: Hide bot comments - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: From 9648e4b9c1707050cae90a01af410dc5bf6a1e8a Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 19:28:49 +0800 Subject: [PATCH 42/52] Add hide-comment-action --- .github/workflows/build.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8d4b84a301a..64954856dfc 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,6 +30,12 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.8", "3.9", "3.10", "3.11"] steps: + - uses: int128/hide-comment-action@v1 + with: + authors: | + github-actions + github-actions[bot] + token: ${{ secrets.GITHUB_TOKEN }} - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 From 27d8d30d023194df61fd96d175bed4772c9be409 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 19:49:09 +0800 Subject: [PATCH 43/52] Revert "Add hide-comment-action" This reverts commit 9648e4b9c1707050cae90a01af410dc5bf6a1e8a. 
--- .github/workflows/build.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 64954856dfc..8d4b84a301a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,12 +30,6 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: int128/hide-comment-action@v1 - with: - authors: | - github-actions - github-actions[bot] - token: ${{ secrets.GITHUB_TOKEN }} - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 From 5f341a62e44bab0c3448e04c02116b4932331a4c Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 20:09:41 +0800 Subject: [PATCH 44/52] Update coverage for retrievechat --- .github/workflows/contrib-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 7406f8adc42..c6e4b37d808 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -54,7 +54,7 @@ jobs: if: matrix.python-version == '3.10' run: | pip install coverage>=5.3 - coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib + coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py coverage xml - name: Upload coverage to Codecov if: matrix.python-version == '3.10' From 38eac1427c91caf7bc154da2905687b2bc698987 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 20:40:24 +0800 Subject: [PATCH 45/52] Move retrieve tests to the same folder --- .github/workflows/contrib-openai.yml | 3 ++- .github/workflows/contrib-tests.yml | 4 ++-- .../contrib/{ => retrievers}/test_qdrant_retrievechat.py | 3 ++- .../contrib/retrievers}/test_retrieve_utils.py | 3 ++- .../contrib/{ => retrievers}/test_retrievechat.py | 9 ++++++--- 5 files changed, 14 
insertions(+), 8 deletions(-) rename test/agentchat/contrib/{ => retrievers}/test_qdrant_retrievechat.py (97%) rename test/{ => agentchat/contrib/retrievers}/test_retrieve_utils.py (98%) rename test/agentchat/contrib/{ => retrievers}/test_retrievechat.py (96%) diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml index 467d5270c8e..78770792fb4 100644 --- a/.github/workflows/contrib-openai.yml +++ b/.github/workflows/contrib-openai.yml @@ -42,6 +42,7 @@ jobs: pip install docker pip install qdrant_client[fastembed] pip install -e .[retrievechat] + pip install chromadb - name: Coverage env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -49,7 +50,7 @@ jobs: AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }} OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }} run: | - coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py + coverage run -a -m pytest test/agentchat/contrib/retrievers coverage xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index c6e4b37d808..a042b93dec2 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -49,12 +49,12 @@ jobs: pip uninstall -y openai - name: Test RetrieveChat run: | - pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py + pytest test/agentchat/contrib/retrievers - name: Coverage if: matrix.python-version == '3.10' run: | pip install coverage>=5.3 - coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py + coverage run -a -m pytest test/agentchat/contrib/retrievers coverage xml - name: Upload coverage to Codecov if: matrix.python-version == '3.10' diff --git a/test/agentchat/contrib/test_qdrant_retrievechat.py 
b/test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py similarity index 97% rename from test/agentchat/contrib/test_qdrant_retrievechat.py rename to test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py index 1d3c5afd6af..97ec4454c3b 100644 --- a/test/agentchat/contrib/test_qdrant_retrievechat.py +++ b/test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py @@ -1,5 +1,6 @@ import os import sys +from pathlib import Path import pytest from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent from autogen import config_list_from_json @@ -27,7 +28,7 @@ except ImportError: OPENAI_INSTALLED = False -test_dir = os.path.join(os.path.dirname(__file__), "../..", "test_files") +test_dir = Path(__file__).parent.parent.parent.parent / "test_files" @pytest.mark.skipif( diff --git a/test/test_retrieve_utils.py b/test/agentchat/contrib/retrievers/test_retrieve_utils.py similarity index 98% rename from test/test_retrieve_utils.py rename to test/agentchat/contrib/retrievers/test_retrieve_utils.py index ad6ad3df9e6..bbea7244257 100644 --- a/test/test_retrieve_utils.py +++ b/test/agentchat/contrib/retrievers/test_retrieve_utils.py @@ -3,6 +3,7 @@ """ import os import sys +from pathlib import Path import pytest try: @@ -29,7 +30,7 @@ except ImportError: HAS_UNSTRUCTURED = False -test_dir = os.path.join(os.path.dirname(__file__), "test_files") +test_dir = Path(__file__).parent.parent.parent.parent / "test_files" expected_text = """AutoGen is an advanced tool designed to assist developers in harnessing the capabilities of Large Language Models (LLMs) for various applications. 
The primary purpose of AutoGen is to automate and simplify the process of building applications that leverage the power of LLMs, allowing for seamless diff --git a/test/agentchat/contrib/test_retrievechat.py b/test/agentchat/contrib/retrievers/test_retrievechat.py similarity index 96% rename from test/agentchat/contrib/test_retrievechat.py rename to test/agentchat/contrib/retrievers/test_retrievechat.py index 66393d6dd27..cda7417a304 100644 --- a/test/agentchat/contrib/test_retrievechat.py +++ b/test/agentchat/contrib/retrievers/test_retrievechat.py @@ -1,11 +1,9 @@ import pytest import os import sys +from pathlib import Path import autogen -sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402 - try: import openai from autogen.agentchat.contrib.retrieve_assistant_agent import ( @@ -14,9 +12,14 @@ from autogen.agentchat.contrib.retrieve_user_proxy_agent import ( RetrieveUserProxyAgent, ) + + skip_test = False except ImportError: skip_test = True +KEY_LOC = "notebook" +OAI_CONFIG_LIST = "OAI_CONFIG_LIST" + @pytest.mark.skipif( sys.platform in ["darwin", "win32"] or skip_test, From d1e60788eddba2fdd9937b7bd484e5a704d20d15 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 21:01:06 +0800 Subject: [PATCH 46/52] Sync changes in main branch --- .../agentchat/contrib/retrieve_user_proxy_agent.py | 6 +++--- .../agentchat/contrib/retriever/retrieve_utils.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 3681ac75820..80a1c03b902 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -97,8 +97,9 @@ def __init__( - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". 
System prompt will be different for different tasks. The default value is `default`, which supports both code and qa. - client (Optional, Any): the vectordb client/connection. If key not provided, the Retreiver class should handle it. - - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file, - or the url to a single file. Default is None, which works only if the collection is already created. + - docs_path (Optional, Union[str, List[str]]): the path to the docs directory. It can also be the path to a single file, + the url to a single file or a list of directories, files and urls. + Default is None, which works only if the collection is already created. - collection_name (Optional, str): the name of the collection. If key not provided, a default name `autogen-docs` will be used. - model (Optional, str): the model to use for the retrieve chat. @@ -128,7 +129,6 @@ def __init__( Raises ValueError if: * the collection doesn't exist and "get" is used. * the collection already exists and "create" is used. - - get_or_create (Optional, bool): [Depricated] if True, will create/recreate a collection for the retrieve chat. This is the same as that used in retriever. Default is False. Will be set to False if docs_path is None. - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string. diff --git a/autogen/agentchat/contrib/retriever/retrieve_utils.py b/autogen/agentchat/contrib/retriever/retrieve_utils.py index 5ed2304fc82..3b68fa84758 100644 --- a/autogen/agentchat/contrib/retriever/retrieve_utils.py +++ b/autogen/agentchat/contrib/retriever/retrieve_utils.py @@ -31,7 +31,17 @@ "yml", "pdf", ] -UNSTRUCTURED_FORMATS = ["docx", "doc", "odt", "pptx", "ppt", "xlsx", "eml", "msg", "epub"] +UNSTRUCTURED_FORMATS = [ + "docx", + "doc", + "odt", + "pptx", + "ppt", + "xlsx", + "eml", + "msg", + "epub", +] # These formats will be parsed by the 'unstructured' library, if installed. 
if HAS_UNSTRUCTURED: TEXT_FORMATS += UNSTRUCTURED_FORMATS TEXT_FORMATS = list(set(TEXT_FORMATS)) From 9397974e4e3fe90aa488a444718aad559c950803 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 12 Dec 2023 21:24:17 +0800 Subject: [PATCH 47/52] Fix import error in tests --- .../contrib/retrievers/test_qdrant_retrievechat.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py b/test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py index 97ec4454c3b..18b62447684 100644 --- a/test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py +++ b/test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py @@ -5,9 +5,6 @@ from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent from autogen import config_list_from_json -sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402 - try: from qdrant_client import QdrantClient from autogen.agentchat.contrib.qdrant_retrieve_user_proxy_agent import ( @@ -28,6 +25,9 @@ except ImportError: OPENAI_INSTALLED = False + +KEY_LOC = "notebook" +OAI_CONFIG_LIST = "OAI_CONFIG_LIST" test_dir = Path(__file__).parent.parent.parent.parent / "test_files" From 6a5ea67b5b06c211650b79c49a8b460ddae8d7e1 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 20:06:09 +0530 Subject: [PATCH 48/52] update --- .../contrib/retrieve_user_proxy_agent.py | 2 +- autogen/agentchat/contrib/retriever/base.py | 7 +- .../agentchat/contrib/retriever/chromadb.py | 4 +- .../agentchat/contrib/retriever/lancedb.py | 6 +- test/agentchat/contrib/test_retrievechat.py | 27 +++----- test/test_retrieve_utils.py | 65 +++++++++++++++++++ 6 files changed, 84 insertions(+), 27 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 3681ac75820..6d11612c742 100644 --- 
a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -142,7 +142,7 @@ def __init__( **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__). Example of overriding retrieve_docs: - If you have set up a customized vector db, and it's not compatible with retriever, you can easily plug in it with below code. + If you want to set up a customized vector db, and it's not compatible with retriever, you can easily plug in it with below code. ```python class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): def query_vector_db( diff --git a/autogen/agentchat/contrib/retriever/base.py b/autogen/agentchat/contrib/retriever/base.py index 63d28ae8996..afb33e68d17 100644 --- a/autogen/agentchat/contrib/retriever/base.py +++ b/autogen/agentchat/contrib/retriever/base.py @@ -39,6 +39,8 @@ def __init__( self.must_break_at_empty_line = must_break_at_empty_line self.custom_text_split_function = custom_text_split_function self.client = client + self.custom_text_types = custom_text_types + self.recursive = recursive self.init_db() @@ -60,12 +62,13 @@ def use_existing_index(self): pass @abstractmethod - def query(self, texts: List[str], top_k: int = 10, filter: Any = None): + def query(self, texts: List[str], top_k: int = 10, search_string: Any = None): """ Query the database. 
Args: - query: query string or list of query strings + texts: list of texts to query top_k: number of results to return + search_string: string to filter the results """ pass diff --git a/autogen/agentchat/contrib/retriever/chromadb.py b/autogen/agentchat/contrib/retriever/chromadb.py index 59295088fe3..dc1e81db028 100644 --- a/autogen/agentchat/contrib/retriever/chromadb.py +++ b/autogen/agentchat/contrib/retriever/chromadb.py @@ -64,7 +64,7 @@ def ingest_data(self, data_dir, overwrite: bool = False): def use_existing_index(self): self.collection = self.client.get_collection(name=self.name, embedding_function=self.embedding_function) - def query(self, texts: List[str], top_k: int = 10, filter: str = None): + def query(self, texts: List[str], top_k: int = 10, search_string: str = None): # the collection's embedding function is always the default one, but we want to use the one we used to create the # collection. So we compute the embeddings ourselves and pass it to the query function. @@ -73,7 +73,7 @@ def query(self, texts: List[str], top_k: int = 10, filter: str = None): results = self.collection.query( query_embeddings=query_embeddings, n_results=top_k, - where_document={"$contains": filter} if filter else None, # optional filter + where_document={"$contains": search_string} if search_string else None, # optional filter ) return results diff --git a/autogen/agentchat/contrib/retriever/lancedb.py b/autogen/agentchat/contrib/retriever/lancedb.py index e63a6ee5db5..88dd3ccba6d 100644 --- a/autogen/agentchat/contrib/retriever/lancedb.py +++ b/autogen/agentchat/contrib/retriever/lancedb.py @@ -53,7 +53,7 @@ def ingest_data(self, data_dir, overwrite: bool = False): def use_existing_index(self): self.table = self.db.open_table(self.name) - def query(self, texts: List[str], top_k: int = 10, filter: str = None): + def query(self, texts: List[str], top_k: int = 10, search_string: str = None): if self.db is None: self.init_db() texts = [texts] if isinstance(texts, str) 
else texts @@ -62,8 +62,8 @@ def query(self, texts: List[str], top_k: int = 10, filter: str = None): query = self.embedding_function(text) if isinstance(self.embedding_function, Callable) else text print("query: ", query) result = self.table.search(query) - if filter is not None: - result = result.where(f"documents LIKE '%{filter}%'") + if search_string is not None: + result = result.where(f"documents LIKE '%{search_string}%'") result = result.limit(top_k).to_arrow().to_pydict() for k, v in result.items(): results[k].append(v) diff --git a/test/agentchat/contrib/test_retrievechat.py b/test/agentchat/contrib/test_retrievechat.py index 66393d6dd27..e602aeaff0d 100644 --- a/test/agentchat/contrib/test_retrievechat.py +++ b/test/agentchat/contrib/test_retrievechat.py @@ -14,6 +14,10 @@ from autogen.agentchat.contrib.retrieve_user_proxy_agent import ( RetrieveUserProxyAgent, ) + import chromadb + from chromadb.utils import embedding_functions as ef + + skip_test = False except ImportError: skip_test = True @@ -41,14 +45,17 @@ def test_retrievechat(): }, ) + sentence_transformer_ef = ef.SentenceTransformerEmbeddingFunction() ragproxyagent = RetrieveUserProxyAgent( name="ragproxyagent", human_input_mode="NEVER", max_consecutive_auto_reply=2, retrieve_config={ + "client": chromadb.PersistentClient(path="/tmp/chromadb"), "docs_path": "./website/docs", "chunk_token_size": 2000, "model": config_list[0]["model"], + "embedding_function": sentence_transformer_ef, "get_or_create": True, }, ) @@ -60,24 +67,6 @@ def test_retrievechat(): print(conversations) - # db_mode - ragproxyagent = RetrieveUserProxyAgent( - name="ragproxyagent", - human_input_mode="NEVER", - max_consecutive_auto_reply=2, - retrieve_config={ - "docs_path": "./website/docs", - "chunk_token_size": 2000, - "model": config_list[0]["model"], - "db_mode": "recreate", - }, - ) - - assistant.reset() - - code_problem = "How can I use FLAML to perform a classification task, set use_spark=True, train 30 seconds and force 
cancel jobs if time limit is reached." - ragproxyagent.initiate_chat(assistant, problem=code_problem, search_string="spark", silent=True) - @pytest.mark.skipif( sys.platform in ["darwin", "win32"] or skip_test, @@ -91,7 +80,7 @@ def test_retrieve_config(caplog): max_consecutive_auto_reply=2, retrieve_config={ "chunk_token_size": 2000, - "db_mode": "recreate", + "get_or_create": True, }, ) diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py index ad6ad3df9e6..8571de67cce 100644 --- a/test/test_retrieve_utils.py +++ b/test/test_retrieve_utils.py @@ -136,6 +136,71 @@ def test_unstructured(self): ) +def test_custom_vector_db(self): + try: + import lancedb + except ImportError: + return + from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent + + db_path = "/tmp/lancedb" + + def create_lancedb(): + db = lancedb.connect(db_path) + data = [ + {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"}, + {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"}, + {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"}, + {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"}, + {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"}, + {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"}, + ] + try: + db.create_table("my_table", data) + except OSError: + pass + + class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): + def query_vector_db( + self, + query_texts, + n_results=10, + search_string="", + ): + if query_texts: + vector = [0.1, 0.3] + db = lancedb.connect(db_path) + table = db.open_table("my_table") + query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df() + return {"ids": [query["id"].tolist()], "documents": [query["documents"].tolist()]} + + def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): + 
results = self.query_vector_db( + query_texts=[problem], + n_results=n_results, + search_string=search_string, + ) + + self._results = results + print("doc_ids: ", results["ids"]) + + ragragproxyagent = MyRetrieveUserProxyAgent( + name="ragproxyagent", + human_input_mode="NEVER", + max_consecutive_auto_reply=2, + retrieve_config={ + "task": "qa", + "chunk_token_size": 2000, + "client": "__", + "embedding_model": "all-mpnet-base-v2", + }, + ) + + create_lancedb() + ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark") + assert ragragproxyagent._results["ids"] == [[3, 1, 5]] + + if __name__ == "__main__": pytest.main() From 232eabad5e113f61e63018bef3816b6c12c92bed Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 20:10:16 +0530 Subject: [PATCH 49/52] add custom_vectordb test --- .../contrib/retrievers/test_retrieve_utils.py | 121 +++++++++--------- 1 file changed, 60 insertions(+), 61 deletions(-) diff --git a/test/agentchat/contrib/retrievers/test_retrieve_utils.py b/test/agentchat/contrib/retrievers/test_retrieve_utils.py index bb2c9a8ce8c..8dc657cb986 100644 --- a/test/agentchat/contrib/retrievers/test_retrieve_utils.py +++ b/test/agentchat/contrib/retrievers/test_retrieve_utils.py @@ -136,70 +136,69 @@ def test_unstructured(self): for chunk in chunks ) - -def test_custom_vector_db(self): - try: - import lancedb - except ImportError: - return - from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent - - db_path = "/tmp/lancedb" - - def create_lancedb(): - db = lancedb.connect(db_path) - data = [ - {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"}, - {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"}, - {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"}, - {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"}, - {"vector": [2.1, 1.3], "id": 5, "documents": 
"This is a fifth test document spark"}, - {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"}, - ] + def test_custom_vector_db(self): try: - db.create_table("my_table", data) - except OSError: - pass - - class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): - def query_vector_db( - self, - query_texts, - n_results=10, - search_string="", - ): - if query_texts: - vector = [0.1, 0.3] + import lancedb + except ImportError: + return + from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent + + db_path = "/tmp/lancedb" + + def create_lancedb(): db = lancedb.connect(db_path) - table = db.open_table("my_table") - query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df() - return {"ids": [query["id"].tolist()], "documents": [query["documents"].tolist()]} - - def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): - results = self.query_vector_db( - query_texts=[problem], - n_results=n_results, - search_string=search_string, - ) - - self._results = results - print("doc_ids: ", results["ids"]) - - ragragproxyagent = MyRetrieveUserProxyAgent( - name="ragproxyagent", - human_input_mode="NEVER", - max_consecutive_auto_reply=2, - retrieve_config={ - "task": "qa", - "chunk_token_size": 2000, - "client": "__", - "embedding_model": "all-mpnet-base-v2", - }, - ) + data = [ + {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"}, + {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"}, + {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"}, + {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"}, + {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"}, + {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"}, + ] + try: + db.create_table("my_table", data) + except OSError: + pass + + class 
MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): + def query_vector_db( + self, + query_texts, + n_results=10, + search_string="", + ): + if query_texts: + vector = [0.1, 0.3] + db = lancedb.connect(db_path) + table = db.open_table("my_table") + query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df() + return {"ids": [query["id"].tolist()], "documents": [query["documents"].tolist()]} + + def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): + results = self.query_vector_db( + query_texts=[problem], + n_results=n_results, + search_string=search_string, + ) + + self._results = results + print("doc_ids: ", results["ids"]) + + ragragproxyagent = MyRetrieveUserProxyAgent( + name="ragproxyagent", + human_input_mode="NEVER", + max_consecutive_auto_reply=2, + retrieve_config={ + "task": "qa", + "chunk_token_size": 2000, + "client": "__", + "embedding_model": "all-mpnet-base-v2", + }, + ) - create_lancedb() - ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark") - assert ragragproxyagent._results["ids"] == [[3, 1, 5]] + create_lancedb() + ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark") + assert ragragproxyagent._results["ids"] == [[3, 1, 5]] if __name__ == "__main__": From b37e03fc6839c8a97377e3d5963655aa3e752afe Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 20:21:17 +0530 Subject: [PATCH 50/52] update test --- test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py b/test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py index 18b62447684..e0dee0b21a4 100644 --- a/test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py +++ b/test/agentchat/contrib/retrievers/test_qdrant_retrievechat.py @@ -91,7 +91,7 @@ def test_qdrant_filter(): 
@pytest.mark.skipif(not QDRANT_INSTALLED, reason="qdrant_client is not installed") def test_qdrant_search(): client = QdrantClient(":memory:") - create_qdrant_from_dir(test_dir, client=client) + create_qdrant_from_dir(str(test_dir), client=client) assert client.get_collection("all-my-documents") From 680e37d0b46e2d5e4a05f7dde27ba3723e3af827 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 20:33:39 +0530 Subject: [PATCH 51/52] update dosctring --- autogen/agentchat/contrib/retriever/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/autogen/agentchat/contrib/retriever/base.py b/autogen/agentchat/contrib/retriever/base.py index afb33e68d17..f9a8331b145 100644 --- a/autogen/agentchat/contrib/retriever/base.py +++ b/autogen/agentchat/contrib/retriever/base.py @@ -29,6 +29,9 @@ def __init__( must_break_at_empty_line: chunk will only break at empty line if True. Default is True. If chunk_mode is "one_line", this parameter will be ignored. custom_text_split_function: custom function to split the text into chunks + client: client to use to connect to the database + custom_text_types: custom text types to ingest + recursive: whether to recursively ingest the files in the directory """ self.path = path self.name = name From ccef9ca39b1327d1a3866462a9d948e48a4abfa9 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Tue, 12 Dec 2023 20:47:55 +0530 Subject: [PATCH 52/52] update test --- test/agentchat/contrib/retrievers/test_lancedb.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/test/agentchat/contrib/retrievers/test_lancedb.py b/test/agentchat/contrib/retrievers/test_lancedb.py index 66868f4b05f..c8450926b7d 100644 --- a/test/agentchat/contrib/retrievers/test_lancedb.py +++ b/test/agentchat/contrib/retrievers/test_lancedb.py @@ -1,13 +1,6 @@ import numpy as np from pathlib import Path import pytest -from autogen.agentchat.contrib.retriever.retrieve_utils import ( - split_text_to_chunks, - extract_text_from_pdf, - 
split_files_to_chunks, - get_files_from_dir, - is_url, -) try: from autogen.agentchat.contrib.retriever.lancedb import LanceDB @@ -17,7 +10,7 @@ else: skip = False -# test_dir is 2 directories above this file + test_dir = Path(__file__).parent.parent.parent.parent / "test_files"
    # Event year Season Ceremony Flag bearer Sex State / Country Sport
    62 2018 Winter Closing Diggins , Jessica Jessica Diggins Minnesota Cross-country skiing
    61 2018 Winter Opening Hamlin , Erin Erin Hamlin New York Luge
    60 2016 Summer Closing Biles , Simone Simone Biles Texas Gymnastics
    59 2016 Summer Opening Phelps , Michael Michael Phelps Maryland Swimming
    58 2014 Winter Closing Chu , Julie Julie Chu Connecticut Hockey
    57 2014 Winter Opening Lodwick , Todd Todd Lodwick Colorado Nordic combined
    56 2012 Summer Closing Nellum , Bryshon Bryshon Nellum California Athletics
    55 2012 Summer Opening Zagunis , Mariel Mariel Zagunis Oregon Fencing
    54 Winter Closing Demong , Bill Bill Demong New York Nordic combined
    53 Winter Opening Grimmette , Mark Mark Grimmette Michigan Luge
    52 2008 Summer Closing Lorig , Khatuna Khatuna Lorig Georgia ( country ) Archery
    51 2008 Summer Opening Lomong , Lopez Lopez Lomong Sudan ( now South Sudan ) Athletics
    50 2006 Winter Closing Cheek , Joey Joey Cheek North Carolina Speed skating
    49 2006 Winter Opening Witty , Chris Chris Witty Wisconsin Speed skating
    48 Summer Closing Hamm , Mia Mia Hamm Texas Women 's soccer
    47 Summer Opening Staley , Dawn Dawn Staley Pennsylvania Basketball
    46 2002 Winter Closing Shimer , Brian Brian Shimer Florida Bobsleigh
    45 2002 Winter Opening Peterson , Amy Amy Peterson Minnesota Short track speed skating
    44 2000 Summer Closing Gardner , Rulon Rulon Gardner Wyoming Wrestling
    43 2000 Summer Opening Meidl , Cliff Cliff Meidl California Canoeing
    42 1998 Winter Closing Granato , Cammi Cammi Granato Illinois Hockey
    41 1998 Winter Opening Flaim , Eric Eric Flaim Massachusetts Speed skating
    40 Summer Closing Matz , Michael Michael Matz Pennsylvania Equestrian
    39 Summer Opening Baumgartner , Bruce Bruce Baumgartner New Jersey Wrestling
    38 1994 Winter Closing Jansen , Dan Dan Jansen Wisconsin Speed skating
    37 1994 Winter Opening Myler , Cammy Cammy Myler New York