diff --git a/memgpt/config.py b/memgpt/config.py index e0987f9f5c..0f0bc41095 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -33,7 +33,7 @@ def set_field(config, section, field, value): @dataclass class MemGPTConfig: - config_path: str = os.getenv("MEMGPT_CONFIG_PATH") if os.getenv("MEMGPT_CONFIG_PATH") else os.path.join(MEMGPT_DIR, "config") + config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config") anon_clientid: str = str(uuid.UUID(int=0)) # preset @@ -196,16 +196,51 @@ def save(self): # model defaults set_field(config, "model", "model", self.default_llm_config.model) set_field(config, "model", "model_endpoint", self.default_llm_config.model_endpoint) - set_field(config, "model", "model_endpoint_type", self.default_llm_config.model_endpoint_type) + set_field( + config, + "model", + "model_endpoint_type", + self.default_llm_config.model_endpoint_type, + ) set_field(config, "model", "model_wrapper", self.default_llm_config.model_wrapper) - set_field(config, "model", "context_window", str(self.default_llm_config.context_window)) + set_field( + config, + "model", + "context_window", + str(self.default_llm_config.context_window), + ) # embeddings - set_field(config, "embedding", "embedding_endpoint_type", self.default_embedding_config.embedding_endpoint_type) - set_field(config, "embedding", "embedding_endpoint", self.default_embedding_config.embedding_endpoint) - set_field(config, "embedding", "embedding_model", self.default_embedding_config.embedding_model) - set_field(config, "embedding", "embedding_dim", str(self.default_embedding_config.embedding_dim)) - set_field(config, "embedding", "embedding_chunk_size", str(self.default_embedding_config.embedding_chunk_size)) + set_field( + config, + "embedding", + "embedding_endpoint_type", + self.default_embedding_config.embedding_endpoint_type, + ) + set_field( + config, + "embedding", + "embedding_endpoint", + self.default_embedding_config.embedding_endpoint, + ) + set_field( + config, + "embedding", + "embedding_model", + self.default_embedding_config.embedding_model, + ) + set_field( + config, + "embedding", + "embedding_dim", + str(self.default_embedding_config.embedding_dim), + ) + set_field( + config, + "embedding", + "embedding_chunk_size", + str(self.default_embedding_config.embedding_chunk_size), + ) # archival storage set_field(config, "archival_storage", "type", self.archival_storage_type) @@ -253,7 +288,16 @@ def create_config_dir(): if not os.path.exists(MEMGPT_DIR): os.makedirs(MEMGPT_DIR, exist_ok=True) - folders = ["personas", "humans", "archival", "agents", "functions", "system_prompts", "presets", "settings"] + folders = [ + "personas", + "humans", + "archival", + "agents", + "functions", + "system_prompts", + "presets", + "settings", + ] for folder in folders: if not os.path.exists(os.path.join(MEMGPT_DIR, folder)): diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb2..9dfe7151fa 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,4 @@ +from tests.config import TestMGPTConfig + + +TEST_MEMGPT_CONFIG = TestMGPTConfig() diff --git a/tests/config.py b/tests/config.py new file mode 100644 index 0000000000..5843ebf80e --- /dev/null +++ b/tests/config.py @@ -0,0 +1,7 @@ +import os +from memgpt.config import MemGPTConfig +from memgpt.constants import MEMGPT_DIR + + +class TestMGPTConfig(MemGPTConfig): + config_path: str = os.getenv("TEST_MEMGPT_CONFIG_PATH") or os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config") diff --git a/tests/test_agent_function_update.py b/tests/test_agent_function_update.py index 45fd3bd20a..ec0f1e9827 100644 --- a/tests/test_agent_function_update.py +++ b/tests/test_agent_function_update.py @@ -1,16 +1,14 @@ -from collections import UserDict import json import os import inspect import uuid -from memgpt.config import MemGPTConfig from memgpt import create_client from memgpt import constants -import memgpt.functions.function_sets.base as base_functions from memgpt.functions.functions import USER_FUNCTIONS_DIR from memgpt.utils import assistant_function_to_tool from memgpt.models import chat_completion_response +from tests import TEST_MEMGPT_CONFIG from tests.utils import wipe_config, create_config @@ -39,10 +37,8 @@ def agent(): # create memgpt client client = create_client() - config = MemGPTConfig.load() - # ensure user exists - user_id = uuid.UUID(config.anon_clientid) + user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid) if not client.server.get_user(user_id=user_id): client.server.create_user({"id": user_id}) diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index c4280d51e1..a9ab69b254 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -2,9 +2,9 @@ import uuid from memgpt import create_client -from memgpt.config import MemGPTConfig from memgpt import constants import memgpt.functions.function_sets.base as base_functions +from tests import TEST_MEMGPT_CONFIG from .utils import wipe_config, create_config @@ -30,8 +30,7 @@ def create_test_agent(): ) global agent_obj - config = MemGPTConfig.load() - user_id = uuid.UUID(config.anon_clientid) + user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid) agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) diff --git a/tests/test_client.py b/tests/test_client.py index 39c3d33f3f..6d551f3f91 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,20 +1,13 @@ import uuid import time -import os import threading from memgpt import Admin, create_client -from memgpt.config import MemGPTConfig -from memgpt import constants -from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset -from memgpt.functions.functions import load_all_function_sets -from memgpt.prompts import gpt_system from memgpt.constants import DEFAULT_PRESET import pytest -from .utils import wipe_config import uuid @@ -116,9 +109,3 @@ def test_user_message(client): # print( # f"[2] MESSAGE SEND SUCCESS!!! AGENT {test_agent_state_post_message.id}\n\tmessages={test_agent_state_post_message.state['messages']}" # ) - - -if __name__ == "__main__": - # test_create_preset() - test_create_agent() - test_user_message() diff --git a/tests/test_different_embedding_size.py b/tests/test_different_embedding_size.py index 6d36942410..ffeb36242c 100644 --- a/tests/test_different_embedding_size.py +++ b/tests/test_different_embedding_size.py @@ -2,11 +2,10 @@ import os from memgpt import create_client -from memgpt.config import MemGPTConfig -from memgpt import constants -from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage +from memgpt.data_types import EmbeddingConfig, Passage from memgpt.embeddings import embedding_model from memgpt.agent_store.storage import StorageConnector, TableType +from tests import TEST_MEMGPT_CONFIG from .utils import wipe_config, create_config import uuid @@ -21,7 +20,11 @@ def generate_passages(user, agent): # Note: the database will filter out rows that do not correspond to agent1 and test_user by default. - texts = ["This is a test passage", "This is another test passage", "Cinderella wept"] + texts = [ + "This is a test passage", + "This is another test passage", + "Cinderella wept", + ] embed_model = embedding_model(agent.embedding_config) orig_embeddings = [] passages = [] @@ -86,8 +89,7 @@ def test_create_user(): hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages) # test passage dimentionality - config = MemGPTConfig.load() - storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, client.user.id) + storage = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, client.user.id) storage.filters = {} # clear filters to be able to get all passages passages = storage.get_all() for passage in passages: diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index 786976c39b..4cf33c9f87 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -9,12 +9,11 @@ from memgpt.cli.cli_load import load_directory # from memgpt.data_sources.connectors import DirectoryConnector, load_data -from memgpt.config import MemGPTConfig from memgpt.credentials import MemGPTCredentials from memgpt.metadata import MetadataStore from memgpt.data_types import User, AgentState, EmbeddingConfig -from memgpt import create_client -from .utils import wipe_config, create_config +from tests import TEST_MEMGPT_CONFIG +from .utils import wipe_config @pytest.fixture(autouse=True) @@ -37,16 +36,20 @@ def recreate_declarative_base(): @pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"]) @pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres"]) -def test_load_directory(metadata_storage_connector, passage_storage_connector, clear_dynamically_created_models, recreate_declarative_base): +def test_load_directory( + metadata_storage_connector, + passage_storage_connector, + clear_dynamically_created_models, + recreate_declarative_base, +): wipe_config() # setup config - config = MemGPTConfig() if metadata_storage_connector == "postgres": if not os.getenv("PGVECTOR_TEST_DB_URL"): print("Skipping test, missing PG URI") return - config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.metadata_storage_type = "postgres" + TEST_MEMGPT_CONFIG.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") + TEST_MEMGPT_CONFIG.metadata_storage_type = "postgres" elif metadata_storage_connector == "sqlite": print("testing sqlite metadata") # nothing to do (should be config defaults) @@ -56,18 +59,18 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c if not os.getenv("PGVECTOR_TEST_DB_URL"): print("Skipping test, missing PG URI") return - config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.archival_storage_type = "postgres" + TEST_MEMGPT_CONFIG.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") + TEST_MEMGPT_CONFIG.archival_storage_type = "postgres" elif passage_storage_connector == "chroma": print("testing chroma passage storage") # nothing to do (should be config defaults) else: raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented") - config.save() + TEST_MEMGPT_CONFIG.save() # create metadata store - ms = MetadataStore(config) - user = User(id=uuid.UUID(config.anon_clientid)) + ms = MetadataStore(TEST_MEMGPT_CONFIG) + user = User(id=uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)) # embedding config if os.getenv("OPENAI_API_KEY"): @@ -93,10 +96,10 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c agent = AgentState( user_id=user.id, name="test_agent", - preset=config.preset, - persona=config.persona, - human=config.human, - llm_config=config.default_llm_config, + preset=TEST_MEMGPT_CONFIG.preset, + persona=TEST_MEMGPT_CONFIG.persona, + human=TEST_MEMGPT_CONFIG.human, + llm_config=TEST_MEMGPT_CONFIG.default_llm_config, embedding_config=embedding_config, ) ms.delete_user(user.id) @@ -109,7 +112,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c print("Creating storage connectors...") user_id = user.id print("User ID", user_id) - passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) + passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id) # load data name = "test_dataset" @@ -121,7 +124,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c print("Resetting tables with delete_table...") passages_conn.delete_table() print("Re-creating tables...") - passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) + passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id) assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}" # test: load directory diff --git a/tests/test_metadata_store.py b/tests/test_metadata_store.py index 852a727e88..e3a630f952 100644 --- a/tests/test_metadata_store.py +++ b/tests/test_metadata_store.py @@ -4,27 +4,26 @@ from memgpt.agent import Agent, save_agent from memgpt.metadata import MetadataStore -from memgpt.config import MemGPTConfig -from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig +from memgpt.data_types import User, AgentState, Source, LLMConfig from memgpt.utils import get_human_text, get_persona_text +from tests import TEST_MEMGPT_CONFIG # @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"]) @pytest.mark.parametrize("storage_connector", ["sqlite"]) def test_storage(storage_connector): - config = MemGPTConfig() if storage_connector == "postgres": if not os.getenv("PGVECTOR_TEST_DB_URL"): print("Skipping test, missing PG URI") return - config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.archival_storage_type = "postgres" - config.recall_storage_type = "postgres" + TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"] + TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"] + TEST_MEMGPT_CONFIG.archival_storage_type = "postgres" + TEST_MEMGPT_CONFIG.recall_storage_type = "postgres" if storage_connector == "sqlite": - config.recall_storage_type = "local" + TEST_MEMGPT_CONFIG.recall_storage_type = "local" - ms = MetadataStore(config) + ms = MetadataStore(TEST_MEMGPT_CONFIG) # generate data user_1 = User() @@ -35,8 +34,8 @@ def test_storage(storage_connector): preset=DEFAULT_PRESET, persona=DEFAULT_PERSONA, human=DEFAULT_HUMAN, - llm_config=config.default_llm_config, - embedding_config=config.default_embedding_config, + llm_config=TEST_MEMGPT_CONFIG.default_llm_config, + embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config, ) source_1 = Source(user_id=user_1.id, name="source_1") @@ -88,7 +87,7 @@ def test_storage(storage_connector): # test: updating # test: update JSON-stored LLMConfig class - print(agent_1.llm_config, config.default_llm_config) + print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config) llm_config = ms.get_agent(agent_1.id).llm_config assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}" assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}" diff --git a/tests/test_migrate.py b/tests/test_migrate.py index fa9dec66c9..955f093853 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -1,6 +1,4 @@ -import os -from memgpt.migrate import migrate_all_agents, migrate_all_sources -from memgpt.config import MemGPTConfig +from memgpt.migrate import migrate_all_agents from .utils import wipe_config from memgpt.server.server import SyncServer import shutil @@ -38,7 +36,12 @@ def test_migrate_0211(): assert len(message_ids) > 0 # assert recall memories exist - messages = server.get_agent_messages(user_id=agent_state.user_id, agent_id=agent_state.id, start=0, count=1000) + messages = server.get_agent_messages( + user_id=agent_state.user_id, + agent_id=agent_state.id, + start=0, + count=1000, + ) assert len(messages) > 0 # for source_name in source_res["migration_candidates"]: diff --git a/tests/test_openai_assistant_api.py b/tests/test_openai_assistant_api.py index 4fc02bf945..71d6d1a433 100644 --- a/tests/test_openai_assistant_api.py +++ b/tests/test_openai_assistant_api.py @@ -5,7 +5,6 @@ from memgpt.server.server import SyncServer from memgpt.server.rest_api.server import app from memgpt.constants import DEFAULT_PRESET -from memgpt.config import MemGPTConfig # TODO: modify this to run against an actual running server # def test_list_messages(): diff --git a/tests/test_server.py b/tests/test_server.py index 0272560438..5dc0da0762 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,13 +4,12 @@ import memgpt.utils as utils from dotenv import load_dotenv +from tests.config import TestMGPTConfig + utils.DEBUG = True -from memgpt.config import MemGPTConfig from memgpt.credentials import MemGPTCredentials from memgpt.server.server import SyncServer -from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User -from memgpt.embeddings import embedding_model -from memgpt.presets.presets import add_default_presets +from memgpt.data_types import EmbeddingConfig, LLMConfig from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector @@ -22,9 +21,10 @@ def server(): # Use os.getenv with a fallback to os.environ.get db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL") + assert db_url, "Missing PGVECTOR_TEST_DB_URL" if os.getenv("OPENAI_API_KEY"): - config = MemGPTConfig( + config = TestMGPTConfig( archival_storage_uri=db_url, recall_storage_uri=db_url, metadata_storage_uri=db_url, @@ -48,7 +48,7 @@ def server(): openai_key=os.getenv("OPENAI_API_KEY"), ) else: # hosted - config = MemGPTConfig( + config = TestMGPTConfig( archival_storage_uri=db_url, recall_storage_uri=db_url, metadata_storage_uri=db_url, @@ -141,7 +141,13 @@ def test_load_data(server, user_id, agent_id): source = server.create_source("test_source", user_id) # load data - archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"] + archival_memories = [ + "alpha", + "Cinderella wore a blue dress", + "Dog eat dog", + "ZZZ", + "Shishir loves indian food", + ] connector = DummyDataConnector(archival_memories) server.load_data(user_id, connector, source.name) @@ -215,10 +221,19 @@ def test_get_archival_memory(server, user_id, agent_id): # test archival memory cursor pagination cursor1, passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text") cursor2, passages_2 = server.get_agent_archival_cursor( - user_id=user_id, agent_id=agent_id, reverse=False, after=cursor1, order_by="text" + user_id=user_id, + agent_id=agent_id, + reverse=False, + after=cursor1, + order_by="text", ) cursor3, passages_3 = server.get_agent_archival_cursor( - user_id=user_id, agent_id=agent_id, reverse=False, before=cursor2, limit=1000, order_by="text" + user_id=user_id, + agent_id=agent_id, + reverse=False, + before=cursor2, + limit=1000, + order_by="text", ) print("p1", [p["text"] for p in passages_1]) print("p2", [p["text"] for p in passages_2]) diff --git a/tests/test_storage.py b/tests/test_storage.py index 96d1e56495..5950d2be67 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -5,8 +5,7 @@ from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.embeddings import embedding_model, query_embedding -from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState, OpenAIEmbeddingConfig -from memgpt.config import MemGPTConfig +from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState from memgpt.credentials import MemGPTCredentials from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.metadata import MetadataStore @@ -15,6 +14,8 @@ from datetime import datetime, timedelta +from tests import TEST_MEMGPT_CONFIG + # Note: the database will filter out rows that do not correspond to agent1 and test_user by default. texts = ["This is a test passage", "This is another test passage", "Cinderella wept"] @@ -104,7 +105,12 @@ def recreate_declarative_base(): # @pytest.mark.parametrize("storage_connector", ["sqlite", "chroma"]) # @pytest.mark.parametrize("storage_connector", ["postgres"]) @pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY]) -def test_storage(storage_connector, table_type, clear_dynamically_created_models, recreate_declarative_base): +def test_storage( + storage_connector, + table_type, + clear_dynamically_created_models, + recreate_declarative_base, +): # setup memgpt config # TODO: set env for different config path @@ -114,35 +120,34 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models # print("Removing messages", globals()['Message']) # del globals()['Message'] - config = MemGPTConfig() if storage_connector == "postgres": if not os.getenv("PGVECTOR_TEST_DB_URL"): print("Skipping test, missing PG URI") return - config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.archival_storage_type = "postgres" - config.recall_storage_type = "postgres" + TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"] + TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"] + TEST_MEMGPT_CONFIG.archival_storage_type = "postgres" + TEST_MEMGPT_CONFIG.recall_storage_type = "postgres" if storage_connector == "lancedb": # TODO: complete lancedb implementation if not os.getenv("LANCEDB_TEST_URL"): print("Skipping test, missing LanceDB URI") return - config.archival_storage_uri = os.getenv("LANCEDB_TEST_URL") - config.recall_storage_uri = os.getenv("LANCEDB_TEST_URL") - config.archival_storage_type = "lancedb" - config.recall_storage_type = "lancedb" + TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["LANCEDB_TEST_URL"] + TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["LANCEDB_TEST_URL"] + TEST_MEMGPT_CONFIG.archival_storage_type = "lancedb" + TEST_MEMGPT_CONFIG.recall_storage_type = "lancedb" if storage_connector == "chroma": if table_type == TableType.RECALL_MEMORY: print("Skipping test, chroma only supported for archival memory") return - config.archival_storage_type = "chroma" - config.archival_storage_path = "./test_chroma" + TEST_MEMGPT_CONFIG.archival_storage_type = "chroma" + TEST_MEMGPT_CONFIG.archival_storage_path = "./test_chroma" if storage_connector == "sqlite": if table_type == TableType.ARCHIVAL_MEMORY: print("Skipping test, sqlite only supported for recall memory") return - config.recall_storage_type = "sqlite" + TEST_MEMGPT_CONFIG.recall_storage_type = "sqlite" # get embedding model embed_model = None @@ -162,27 +167,27 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models embed_model = embedding_model(embedding_config) # create user - ms = MetadataStore(config) + ms = MetadataStore(TEST_MEMGPT_CONFIG) ms.delete_user(user_id) user = User(id=user_id) agent = AgentState( user_id=user_id, name="agent_1", id=agent_1_id, - preset=config.preset, - persona=config.persona, - human=config.human, - llm_config=config.default_llm_config, - embedding_config=config.default_embedding_config, + preset=TEST_MEMGPT_CONFIG.preset, + persona=TEST_MEMGPT_CONFIG.persona, + human=TEST_MEMGPT_CONFIG.human, + llm_config=TEST_MEMGPT_CONFIG.default_llm_config, + embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config, ) ms.create_user(user) ms.create_agent(agent) # create storage connector - conn = StorageConnector.get_storage_connector(table_type, config=config, user_id=user_id, agent_id=agent.id) + conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id) # conn.client.delete_collection(conn.collection.name) # clear out data conn.delete_table() - conn = StorageConnector.get_storage_connector(table_type, config=config, user_id=user_id, agent_id=agent.id) + conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id) # generate data if table_type == TableType.ARCHIVAL_MEMORY: diff --git a/tests/test_summarize.py b/tests/test_summarize.py index 82f617bc88..83a97664c9 100644 --- a/tests/test_summarize.py +++ b/tests/test_summarize.py @@ -2,9 +2,7 @@ import uuid from memgpt import create_client -from memgpt.config import MemGPTConfig from memgpt import constants -import memgpt.functions.function_sets.base as base_functions from .utils import wipe_config, create_config @@ -31,7 +29,6 @@ def create_test_agent(): ) global agent_obj - config = MemGPTConfig.load() agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id) @@ -48,12 +45,16 @@ def test_summarize(): # First send a few messages (5) response = client.user_message( - agent_id=agent_obj.agent_state.id, message="Hey, how's it going? What do you think about this whole shindig" + agent_id=agent_obj.agent_state.id, + message="Hey, how's it going? What do you think about this whole shindig", ) assert response is not None and len(response) > 0 print(f"test_summarize: response={response}") - response = client.user_message(agent_id=agent_obj.agent_state.id, message="Any thoughts on the meaning of life?") + response = client.user_message( + agent_id=agent_obj.agent_state.id, + message="Any thoughts on the meaning of life?", + ) assert response is not None and len(response) > 0 print(f"test_summarize: response={response}") @@ -62,7 +63,8 @@ def test_summarize(): print(f"test_summarize: response={response}") response = client.user_message( - agent_id=agent_obj.agent_state.id, message="Would you be surprised to learn that you're actually conversing with an AI right now?" + agent_id=agent_obj.agent_state.id, + message="Would you be surprised to learn that you're actually conversing with an AI right now?", ) assert response is not None and len(response) > 0 print(f"test_summarize: response={response}") diff --git a/tests/test_websocket_interface.py b/tests/test_websocket_interface.py index 79b445a6b7..5d0671ef52 100644 --- a/tests/test_websocket_interface.py +++ b/tests/test_websocket_interface.py @@ -1,13 +1,11 @@ import os import pytest -from unittest.mock import Mock, AsyncMock, MagicMock +from unittest.mock import AsyncMock -from memgpt.config import MemGPTConfig, AgentConfig +from memgpt.credentials import MemGPTCredentials from memgpt.server.ws_api.interface import SyncWebSocketInterface import memgpt.presets.presets as presets -import memgpt.utils as utils import memgpt.system as system -from memgpt.persistence_manager import LocalStateManager from memgpt.data_types import AgentState @@ -62,10 +60,10 @@ async def test_websockets(): if api_key is None: ws_interface.close() return - config = MemGPTConfig.load() - if config.openai_key is None: - config.openai_key = api_key - config.save() + credentials = MemGPTCredentials.load() + if credentials.openai_key is None: + credentials.openai_key = api_key + credentials.save() # Mock the persistence manager # create agents with defaults diff --git a/tests/utils.py b/tests/utils.py index f108d272f0..c0d2381ec0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,11 +2,10 @@ from typing import Dict, List, Tuple, Iterator import os -from memgpt.config import MemGPTConfig from memgpt.cli.cli import quickstart, QuickstartChoice from memgpt.data_sources.connectors import DataConnector -from memgpt import Admin from memgpt.data_types import Document +from tests import TEST_MEMGPT_CONFIG from .constants import TIMEOUT @@ -37,17 +36,17 @@ def create_config(endpoint="openai"): def wipe_config(): - if MemGPTConfig.exists(): + if TEST_MEMGPT_CONFIG.exists(): # delete if os.getenv("MEMGPT_CONFIG_PATH"): config_path = os.getenv("MEMGPT_CONFIG_PATH") else: - config_path = MemGPTConfig.config_path + config_path = TEST_MEMGPT_CONFIG.config_path # TODO delete file config_path os.remove(config_path) - assert not MemGPTConfig.exists(), "Config should not exist after deletion" + assert not TEST_MEMGPT_CONFIG.exists(), "Config should not exist after deletion" else: - print("No config to wipe", MemGPTConfig.config_path) + print("No config to wipe", TEST_MEMGPT_CONFIG.config_path) def wipe_memgpt_home(): @@ -63,7 +62,7 @@ def wipe_memgpt_home(): os.system(f"mv ~/.memgpt {backup_dir}") # Setup the initial directory - MemGPTConfig.create_config_dir() + TEST_MEMGPT_CONFIG.create_config_dir() def configure_memgpt_localllm():