diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index c660fa85d6d..c29ced376a1 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -254,7 +254,10 @@ def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMA def get_file_from_url(url: str, save_path: str = None): """Download a file from a URL.""" if save_path is None: + os.makedirs("/tmp/chromadb", exist_ok=True) save_path = os.path.join("/tmp/chromadb", os.path.basename(url)) + else: + os.makedirs(os.path.dirname(save_path), exist_ok=True) with requests.get(url, stream=True) as r: r.raise_for_status() with open(save_path, "wb") as f: diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py index a1c70d9cf28..81fb1a0969a 100644 --- a/test/test_retrieve_utils.py +++ b/test/test_retrieve_utils.py @@ -140,7 +140,7 @@ def query_vector_db( db = lancedb.connect(db_path) table = db.open_table("my_table") query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df() - return {"ids": query["id"].tolist(), "documents": query["documents"].tolist()} + return {"ids": [query["id"].tolist()], "documents": [query["documents"].tolist()]} def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): results = self.query_vector_db( @@ -166,7 +166,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = create_lancedb() ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark") - assert ragragproxyagent._results["ids"] == [3, 1, 5] + assert ragragproxyagent._results["ids"] == [[3, 1, 5]] def test_custom_text_split_function(self): def custom_text_split_function(text):