diff --git a/memgpt/config.py b/memgpt/config.py index 9886d90f74..a75a4a432d 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -5,17 +5,12 @@ import uuid from dataclasses import dataclass, field import configparser -import typer -import questionary -from typing import Optional import memgpt import memgpt.utils as utils -from memgpt.utils import printd, get_schema_diff -from memgpt.functions.functions import load_all_function_sets -from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS, DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET -from memgpt.data_types import AgentState, User, LLMConfig, EmbeddingConfig +from memgpt.constants import MEMGPT_DIR, DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET +from memgpt.data_types import AgentState, LLMConfig, EmbeddingConfig # helper functions for writing to configs @@ -38,7 +33,7 @@ def set_field(config, section, field, value): @dataclass class MemGPTConfig: - config_path: str = os.getenv("MEMGPT_CONFIG_PATH") if os.getenv("MEMGPT_CONFIG_PATH") else os.path.join(MEMGPT_DIR, "config") + config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config") anon_clientid: str = None # preset diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb2..b077b4e3bf 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,4 @@ +from tests.config import TestMGPTConfig + + +TEST_MEMGPT_CONFIG = TestMGPTConfig() \ No newline at end of file diff --git a/tests/config.py b/tests/config.py new file mode 100644 index 0000000000..d5d6aa2f99 --- /dev/null +++ b/tests/config.py @@ -0,0 +1,8 @@ +import os +from memgpt.config import MemGPTConfig +from memgpt.constants import MEMGPT_DIR + + +class TestMGPTConfig(MemGPTConfig): + config_path: str = os.getenv("TEST_MEMGPT_CONFIG_PATH") or os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config") + \ No newline at end of file diff --git a/tests/test_agent_function_update.py b/tests/test_agent_function_update.py index 45fd3bd20a..ec0f1e9827 100644 --- a/tests/test_agent_function_update.py +++ b/tests/test_agent_function_update.py @@ -1,16 +1,14 @@ -from collections import UserDict import json import os import inspect import uuid -from memgpt.config import MemGPTConfig from memgpt import create_client from memgpt import constants -import memgpt.functions.function_sets.base as base_functions from memgpt.functions.functions import USER_FUNCTIONS_DIR from memgpt.utils import assistant_function_to_tool from memgpt.models import chat_completion_response +from tests import TEST_MEMGPT_CONFIG from tests.utils import wipe_config, create_config @@ -39,10 +37,8 @@ def agent(): # create memgpt client client = create_client() - config = MemGPTConfig.load() - # ensure user exists - user_id = uuid.UUID(config.anon_clientid) + user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid) if not client.server.get_user(user_id=user_id): client.server.create_user({"id": user_id}) diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index c4280d51e1..a9ab69b254 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -2,9 +2,9 @@ import uuid from memgpt import create_client -from memgpt.config import MemGPTConfig from memgpt import constants import memgpt.functions.function_sets.base as base_functions +from tests import TEST_MEMGPT_CONFIG from .utils import wipe_config, create_config @@ -30,8 +30,7 @@ def create_test_agent(): ) global agent_obj - config = MemGPTConfig.load() - user_id = uuid.UUID(config.anon_clientid) + user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid) agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) diff --git a/tests/test_client.py b/tests/test_client.py index 39c3d33f3f..088e17041f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,20 +1,13 @@ import uuid import time -import os import threading from memgpt import Admin, create_client -from memgpt.config import MemGPTConfig -from memgpt import constants -from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset -from memgpt.functions.functions import load_all_function_sets -from memgpt.prompts import gpt_system from memgpt.constants import DEFAULT_PRESET import pytest -from .utils import wipe_config import uuid diff --git a/tests/test_different_embedding_size.py b/tests/test_different_embedding_size.py index 6d36942410..bfbbf875af 100644 --- a/tests/test_different_embedding_size.py +++ b/tests/test_different_embedding_size.py @@ -2,11 +2,10 @@ import os from memgpt import create_client -from memgpt.config import MemGPTConfig -from memgpt import constants -from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage +from memgpt.data_types import EmbeddingConfig, Passage from memgpt.embeddings import embedding_model from memgpt.agent_store.storage import StorageConnector, TableType +from tests import TEST_MEMGPT_CONFIG from .utils import wipe_config, create_config import uuid @@ -86,8 +85,7 @@ def test_create_user(): hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages) # test passage dimentionality - config = MemGPTConfig.load() - storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, client.user.id) + storage = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, client.user.id) storage.filters = {} # clear filters to be able to get all passages passages = storage.get_all() for passage in passages: diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index 786976c39b..60f8f03a4f 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -9,12 +9,11 @@ from memgpt.cli.cli_load import load_directory # from memgpt.data_sources.connectors import DirectoryConnector, load_data -from memgpt.config import MemGPTConfig from memgpt.credentials import MemGPTCredentials from memgpt.metadata import MetadataStore from memgpt.data_types import User, AgentState, EmbeddingConfig -from memgpt import create_client -from .utils import wipe_config, create_config +from tests import TEST_MEMGPT_CONFIG +from .utils import wipe_config @pytest.fixture(autouse=True) @@ -40,13 +39,12 @@ def recreate_declarative_base(): def test_load_directory(metadata_storage_connector, passage_storage_connector, clear_dynamically_created_models, recreate_declarative_base): wipe_config() # setup config - config = MemGPTConfig() if metadata_storage_connector == "postgres": if not os.getenv("PGVECTOR_TEST_DB_URL"): print("Skipping test, missing PG URI") return - config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.metadata_storage_type = "postgres" + TEST_MEMGPT_CONFIG.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") + TEST_MEMGPT_CONFIG.metadata_storage_type = "postgres" elif metadata_storage_connector == "sqlite": print("testing sqlite metadata") # nothing to do (should be config defaults) @@ -56,18 +54,18 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c if not os.getenv("PGVECTOR_TEST_DB_URL"): print("Skipping test, missing PG URI") return - config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.archival_storage_type = "postgres" + TEST_MEMGPT_CONFIG.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") + TEST_MEMGPT_CONFIG.archival_storage_type = "postgres" elif passage_storage_connector == "chroma": print("testing chroma passage storage") # nothing to do (should be config defaults) else: raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented") - config.save() + TEST_MEMGPT_CONFIG.save() # create metadata store - ms = MetadataStore(config) - user = User(id=uuid.UUID(config.anon_clientid)) + ms = MetadataStore(TEST_MEMGPT_CONFIG) + user = User(id=uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)) # embedding config if os.getenv("OPENAI_API_KEY"): @@ -93,10 +91,10 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c agent = AgentState( user_id=user.id, name="test_agent", - preset=config.preset, - persona=config.persona, - human=config.human, - llm_config=config.default_llm_config, + preset=TEST_MEMGPT_CONFIG.preset, + persona=TEST_MEMGPT_CONFIG.persona, + human=TEST_MEMGPT_CONFIG.human, + llm_config=TEST_MEMGPT_CONFIG.default_llm_config, embedding_config=embedding_config, ) ms.delete_user(user.id) @@ -109,7 +107,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c print("Creating storage connectors...") user_id = user.id print("User ID", user_id) - passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) + passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id) # load data name = "test_dataset" @@ -121,7 +119,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c print("Resetting tables with delete_table...") passages_conn.delete_table() print("Re-creating tables...") - passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) + passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id) assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}" # test: load directory diff --git a/tests/test_metadata_store.py b/tests/test_metadata_store.py index cac9c0cb17..c7aa40e993 100644 --- a/tests/test_metadata_store.py +++ b/tests/test_metadata_store.py @@ -3,26 +3,25 @@ import pytest from memgpt.metadata import MetadataStore -from memgpt.config import MemGPTConfig -from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig +from memgpt.data_types import User, AgentState, Source, LLMConfig +from tests import TEST_MEMGPT_CONFIG # @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"]) @pytest.mark.parametrize("storage_connector", ["sqlite"]) def test_storage(storage_connector): - config = MemGPTConfig() if storage_connector == "postgres": if not os.getenv("PGVECTOR_TEST_DB_URL"): print("Skipping test, missing PG URI") return - config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.archival_storage_type = "postgres" - config.recall_storage_type = "postgres" + TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"] + TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"] + TEST_MEMGPT_CONFIG.archival_storage_type = "postgres" + TEST_MEMGPT_CONFIG.recall_storage_type = "postgres" if storage_connector == "sqlite": - config.recall_storage_type = "local" + TEST_MEMGPT_CONFIG.recall_storage_type = "local" - ms = MetadataStore(config) + ms = MetadataStore(TEST_MEMGPT_CONFIG) # generate data user_1 = User() @@ -33,8 +32,8 @@ def test_storage(storage_connector): preset=DEFAULT_PRESET, persona=DEFAULT_PERSONA, human=DEFAULT_HUMAN, - llm_config=config.default_llm_config, - embedding_config=config.default_embedding_config, + llm_config=TEST_MEMGPT_CONFIG.default_llm_config, + embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config, ) source_1 = Source(user_id=user_1.id, name="source_1") @@ -53,7 +52,7 @@ def test_storage(storage_connector): # test: updating # test: update JSON-stored LLMConfig class - print(agent_1.llm_config, config.default_llm_config) + print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config) llm_config = ms.get_agent(agent_1.id).llm_config assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}" assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}" diff --git a/tests/test_migrate.py b/tests/test_migrate.py index fa9dec66c9..708b1a5459 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -1,6 +1,4 @@ -import os -from memgpt.migrate import migrate_all_agents, migrate_all_sources -from memgpt.config import MemGPTConfig +from memgpt.migrate import migrate_all_agents from .utils import wipe_config from memgpt.server.server import SyncServer import shutil diff --git a/tests/test_openai_assistant_api.py b/tests/test_openai_assistant_api.py index 4fc02bf945..71d6d1a433 100644 --- a/tests/test_openai_assistant_api.py +++ b/tests/test_openai_assistant_api.py @@ -5,7 +5,6 @@ from memgpt.server.server import SyncServer from memgpt.server.rest_api.server import app from memgpt.constants import DEFAULT_PRESET -from memgpt.config import MemGPTConfig # TODO: modify this to run against an actual running server # def test_list_messages(): diff --git a/tests/test_server.py b/tests/test_server.py index 0272560438..d0048a80c9 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,13 +4,12 @@ import memgpt.utils as utils from dotenv import load_dotenv +from tests.config import TestMGPTConfig + utils.DEBUG = True -from memgpt.config import MemGPTConfig from memgpt.credentials import MemGPTCredentials from memgpt.server.server import SyncServer -from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User -from memgpt.embeddings import embedding_model -from memgpt.presets.presets import add_default_presets +from memgpt.data_types import EmbeddingConfig, LLMConfig from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector @@ -22,9 +21,10 @@ def server(): # Use os.getenv with a fallback to os.environ.get db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL") + assert db_url, "Missing PGVECTOR_TEST_DB_URL" if os.getenv("OPENAI_API_KEY"): - config = MemGPTConfig( + config = TestMGPTConfig( archival_storage_uri=db_url, recall_storage_uri=db_url, metadata_storage_uri=db_url, @@ -48,7 +48,7 @@ def server(): openai_key=os.getenv("OPENAI_API_KEY"), ) else: # hosted - config = MemGPTConfig( + config = TestMGPTConfig( archival_storage_uri=db_url, recall_storage_uri=db_url, metadata_storage_uri=db_url, diff --git a/tests/test_storage.py b/tests/test_storage.py index 96d1e56495..fc7013d354 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -5,8 +5,7 @@ from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.embeddings import embedding_model, query_embedding -from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState, OpenAIEmbeddingConfig -from memgpt.config import MemGPTConfig +from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState from memgpt.credentials import MemGPTCredentials from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.metadata import MetadataStore @@ -15,6 +14,8 @@ from datetime import datetime, timedelta +from tests import TEST_MEMGPT_CONFIG + # Note: the database will filter out rows that do not correspond to agent1 and test_user by default. texts = ["This is a test passage", "This is another test passage", "Cinderella wept"] @@ -114,35 +115,34 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models # print("Removing messages", globals()['Message']) # del globals()['Message'] - config = MemGPTConfig() if storage_connector == "postgres": if not os.getenv("PGVECTOR_TEST_DB_URL"): print("Skipping test, missing PG URI") return - config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.archival_storage_type = "postgres" - config.recall_storage_type = "postgres" + TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"] + TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"] + TEST_MEMGPT_CONFIG.archival_storage_type = "postgres" + TEST_MEMGPT_CONFIG.recall_storage_type = "postgres" if storage_connector == "lancedb": # TODO: complete lancedb implementation if not os.getenv("LANCEDB_TEST_URL"): print("Skipping test, missing LanceDB URI") return - config.archival_storage_uri = os.getenv("LANCEDB_TEST_URL") - config.recall_storage_uri = os.getenv("LANCEDB_TEST_URL") - config.archival_storage_type = "lancedb" - config.recall_storage_type = "lancedb" + TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["LANCEDB_TEST_URL"] + TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["LANCEDB_TEST_URL"] + TEST_MEMGPT_CONFIG.archival_storage_type = "lancedb" + TEST_MEMGPT_CONFIG.recall_storage_type = "lancedb" if storage_connector == "chroma": if table_type == TableType.RECALL_MEMORY: print("Skipping test, chroma only supported for archival memory") return - config.archival_storage_type = "chroma" - config.archival_storage_path = "./test_chroma" + TEST_MEMGPT_CONFIG.archival_storage_type = "chroma" + TEST_MEMGPT_CONFIG.archival_storage_path = "./test_chroma" if storage_connector == "sqlite": if table_type == TableType.ARCHIVAL_MEMORY: print("Skipping test, sqlite only supported for recall memory") return - config.recall_storage_type = "sqlite" + TEST_MEMGPT_CONFIG.recall_storage_type = "sqlite" # get embedding model embed_model = None @@ -162,27 +162,27 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models embed_model = embedding_model(embedding_config) # create user - ms = MetadataStore(config) + ms = MetadataStore(TEST_MEMGPT_CONFIG) ms.delete_user(user_id) user = User(id=user_id) agent = AgentState( user_id=user_id, name="agent_1", id=agent_1_id, - preset=config.preset, - persona=config.persona, - human=config.human, - llm_config=config.default_llm_config, - embedding_config=config.default_embedding_config, + preset=TEST_MEMGPT_CONFIG.preset, + persona=TEST_MEMGPT_CONFIG.persona, + human=TEST_MEMGPT_CONFIG.human, + llm_config=TEST_MEMGPT_CONFIG.default_llm_config, + embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config, ) ms.create_user(user) ms.create_agent(agent) # create storage connector - conn = StorageConnector.get_storage_connector(table_type, config=config, user_id=user_id, agent_id=agent.id) + conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id) # conn.client.delete_collection(conn.collection.name) # clear out data conn.delete_table() - conn = StorageConnector.get_storage_connector(table_type, config=config, user_id=user_id, agent_id=agent.id) + conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id) # generate data if table_type == TableType.ARCHIVAL_MEMORY: diff --git a/tests/test_summarize.py b/tests/test_summarize.py index 82f617bc88..4f8f0657a5 100644 --- a/tests/test_summarize.py +++ b/tests/test_summarize.py @@ -2,9 +2,7 @@ import uuid from memgpt import create_client -from memgpt.config import MemGPTConfig from memgpt import constants -import memgpt.functions.function_sets.base as base_functions from .utils import wipe_config, create_config @@ -31,7 +29,6 @@ def create_test_agent(): ) global agent_obj - config = MemGPTConfig.load() agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id) diff --git a/tests/test_websocket_interface.py b/tests/test_websocket_interface.py index 79b445a6b7..5d0671ef52 100644 --- a/tests/test_websocket_interface.py +++ b/tests/test_websocket_interface.py @@ -1,13 +1,11 @@ import os import pytest -from unittest.mock import Mock, AsyncMock, MagicMock +from unittest.mock import AsyncMock -from memgpt.config import MemGPTConfig, AgentConfig +from memgpt.credentials import MemGPTCredentials from memgpt.server.ws_api.interface import SyncWebSocketInterface import memgpt.presets.presets as presets -import memgpt.utils as utils import memgpt.system as system -from memgpt.persistence_manager import LocalStateManager from memgpt.data_types import AgentState @@ -62,10 +60,10 @@ async def test_websockets(): if api_key is None: ws_interface.close() return - config = MemGPTConfig.load() - if config.openai_key is None: - config.openai_key = api_key - config.save() + credentials = MemGPTCredentials.load() + if credentials.openai_key is None: + credentials.openai_key = api_key + credentials.save() # Mock the persistence manager # create agents with defaults diff --git a/tests/utils.py b/tests/utils.py index 973c420e08..16e3b88b49 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,11 +2,10 @@ from typing import Dict, List, Tuple, Iterator import os -from memgpt.config import MemGPTConfig from memgpt.cli.cli import quickstart, QuickstartChoice from memgpt.data_sources.connectors import DataConnector -from memgpt import Admin from memgpt.data_types import Document +from tests import TEST_MEMGPT_CONFIG from .constants import TIMEOUT @@ -38,17 +37,17 @@ def create_config(endpoint="openai"): def wipe_config(): - if MemGPTConfig.exists(): + if TEST_MEMGPT_CONFIG.exists(): # delete if os.getenv("MEMGPT_CONFIG_PATH"): config_path = os.getenv("MEMGPT_CONFIG_PATH") else: - config_path = MemGPTConfig.config_path + config_path = TEST_MEMGPT_CONFIG.config_path # TODO delete file config_path os.remove(config_path) - assert not MemGPTConfig.exists(), "Config should not exist after deletion" + assert not TEST_MEMGPT_CONFIG.exists(), "Config should not exist after deletion" else: - print("No config to wipe", MemGPTConfig.config_path) + print("No config to wipe", TEST_MEMGPT_CONFIG.config_path) def wipe_memgpt_home(): @@ -64,7 +63,7 @@ def wipe_memgpt_home(): os.system(f"mv ~/.memgpt {backup_dir}") # Setup the initial directory - MemGPTConfig.create_config_dir() + TEST_MEMGPT_CONFIG.create_config_dir() def configure_memgpt_localllm():