Skip to content

Commit

Permalink
feat: isolate test config from main config
Browse files Browse the repository at this point in the history
allow for directing test config to a different variable than the main
app config.

also set the stage for making MemGPTConfig a constant, rather than a
function argument.
  • Loading branch information
tombedor committed Feb 28, 2024
1 parent 397ddb8 commit 5a345f9
Show file tree
Hide file tree
Showing 16 changed files with 89 additions and 108 deletions.
11 changes: 3 additions & 8 deletions memgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@
import uuid
from dataclasses import dataclass, field
import configparser
import typer
import questionary
from typing import Optional

import memgpt
import memgpt.utils as utils
from memgpt.utils import printd, get_schema_diff
from memgpt.functions.functions import load_all_function_sets

from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS, DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
from memgpt.data_types import AgentState, User, LLMConfig, EmbeddingConfig
from memgpt.constants import MEMGPT_DIR, DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
from memgpt.data_types import AgentState, LLMConfig, EmbeddingConfig


# helper functions for writing to configs
Expand All @@ -38,7 +33,7 @@ def set_field(config, section, field, value):

@dataclass
class MemGPTConfig:
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") if os.getenv("MEMGPT_CONFIG_PATH") else os.path.join(MEMGPT_DIR, "config")
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")
anon_clientid: str = None

# preset
Expand Down
4 changes: 4 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from tests.config import TestMGPTConfig


TEST_MEMGPT_CONFIG = TestMGPTConfig()
8 changes: 8 additions & 0 deletions tests/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import os
from memgpt.config import MemGPTConfig
from memgpt.constants import MEMGPT_DIR


class TestMGPTConfig(MemGPTConfig):
config_path: str = os.getenv("TEST_MEMGPT_CONFIG_PATH") or os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")

8 changes: 2 additions & 6 deletions tests/test_agent_function_update.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from collections import UserDict
import json
import os
import inspect
import uuid

from memgpt.config import MemGPTConfig
from memgpt import create_client
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
from memgpt.functions.functions import USER_FUNCTIONS_DIR
from memgpt.utils import assistant_function_to_tool
from memgpt.models import chat_completion_response
from tests import TEST_MEMGPT_CONFIG

from tests.utils import wipe_config, create_config

Expand Down Expand Up @@ -39,10 +37,8 @@ def agent():
# create memgpt client
client = create_client()

config = MemGPTConfig.load()

# ensure user exists
user_id = uuid.UUID(config.anon_clientid)
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
if not client.server.get_user(user_id=user_id):
client.server.create_user({"id": user_id})

Expand Down
5 changes: 2 additions & 3 deletions tests/test_base_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import uuid

from memgpt import create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config, create_config


Expand All @@ -30,8 +30,7 @@ def create_test_agent():
)

global agent_obj
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)


Expand Down
7 changes: 0 additions & 7 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import uuid
import time
import os
import threading

from memgpt import Admin, create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset
from memgpt.functions.functions import load_all_function_sets
from memgpt.prompts import gpt_system
from memgpt.constants import DEFAULT_PRESET

import pytest


from .utils import wipe_config
import uuid


Expand Down
8 changes: 3 additions & 5 deletions tests/test_different_embedding_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import os

from memgpt import create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage
from memgpt.data_types import EmbeddingConfig, Passage
from memgpt.embeddings import embedding_model
from memgpt.agent_store.storage import StorageConnector, TableType
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config, create_config
import uuid

Expand Down Expand Up @@ -86,8 +85,7 @@ def test_create_user():
hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)

# test passage dimentionality
config = MemGPTConfig.load()
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, client.user.id)
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, client.user.id)
storage.filters = {} # clear filters to be able to get all passages
passages = storage.get_all()
for passage in passages:
Expand Down
32 changes: 15 additions & 17 deletions tests/test_load_archival.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from memgpt.cli.cli_load import load_directory

# from memgpt.data_sources.connectors import DirectoryConnector, load_data
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.metadata import MetadataStore
from memgpt.data_types import User, AgentState, EmbeddingConfig
from memgpt import create_client
from .utils import wipe_config, create_config
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config


@pytest.fixture(autouse=True)
Expand All @@ -40,13 +39,12 @@ def recreate_declarative_base():
def test_load_directory(metadata_storage_connector, passage_storage_connector, clear_dynamically_created_models, recreate_declarative_base):
wipe_config()
# setup config
config = MemGPTConfig()
if metadata_storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.metadata_storage_type = "postgres"
TEST_MEMGPT_CONFIG.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
TEST_MEMGPT_CONFIG.metadata_storage_type = "postgres"
elif metadata_storage_connector == "sqlite":
print("testing sqlite metadata")
# nothing to do (should be config defaults)
Expand All @@ -56,18 +54,18 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
TEST_MEMGPT_CONFIG.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
elif passage_storage_connector == "chroma":
print("testing chroma passage storage")
# nothing to do (should be config defaults)
else:
raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented")
config.save()
TEST_MEMGPT_CONFIG.save()

# create metadata store
ms = MetadataStore(config)
user = User(id=uuid.UUID(config.anon_clientid))
ms = MetadataStore(TEST_MEMGPT_CONFIG)
user = User(id=uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid))

# embedding config
if os.getenv("OPENAI_API_KEY"):
Expand All @@ -93,10 +91,10 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
agent = AgentState(
user_id=user.id,
name="test_agent",
preset=config.preset,
persona=config.persona,
human=config.human,
llm_config=config.default_llm_config,
preset=TEST_MEMGPT_CONFIG.preset,
persona=TEST_MEMGPT_CONFIG.persona,
human=TEST_MEMGPT_CONFIG.human,
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=embedding_config,
)
ms.delete_user(user.id)
Expand All @@ -109,7 +107,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
print("Creating storage connectors...")
user_id = user.id
print("User ID", user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)

# load data
name = "test_dataset"
Expand All @@ -121,7 +119,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
print("Resetting tables with delete_table...")
passages_conn.delete_table()
print("Re-creating tables...")
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)
assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}"

# test: load directory
Expand Down
23 changes: 11 additions & 12 deletions tests/test_metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,25 @@
import pytest

from memgpt.metadata import MetadataStore
from memgpt.config import MemGPTConfig
from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig
from memgpt.data_types import User, AgentState, Source, LLMConfig
from tests import TEST_MEMGPT_CONFIG


# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
@pytest.mark.parametrize("storage_connector", ["sqlite"])
def test_storage(storage_connector):
config = MemGPTConfig()
if storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
config.recall_storage_type = "postgres"
TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
if storage_connector == "sqlite":
config.recall_storage_type = "local"
TEST_MEMGPT_CONFIG.recall_storage_type = "local"

ms = MetadataStore(config)
ms = MetadataStore(TEST_MEMGPT_CONFIG)

# generate data
user_1 = User()
Expand All @@ -33,8 +32,8 @@ def test_storage(storage_connector):
preset=DEFAULT_PRESET,
persona=DEFAULT_PERSONA,
human=DEFAULT_HUMAN,
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)
source_1 = Source(user_id=user_1.id, name="source_1")

Expand All @@ -53,7 +52,7 @@ def test_storage(storage_connector):
# test: updating

# test: update JSON-stored LLMConfig class
print(agent_1.llm_config, config.default_llm_config)
print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config)
llm_config = ms.get_agent(agent_1.id).llm_config
assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}"
assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}"
Expand Down
4 changes: 1 addition & 3 deletions tests/test_migrate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.config import MemGPTConfig
from memgpt.migrate import migrate_all_agents
from .utils import wipe_config
from memgpt.server.server import SyncServer
import shutil
Expand Down
1 change: 0 additions & 1 deletion tests/test_openai_assistant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.server import app
from memgpt.constants import DEFAULT_PRESET
from memgpt.config import MemGPTConfig

# TODO: modify this to run against an actual running server
# def test_list_messages():
Expand Down
12 changes: 6 additions & 6 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import memgpt.utils as utils
from dotenv import load_dotenv

from tests.config import TestMGPTConfig

utils.DEBUG = True
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.server.server import SyncServer
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
from memgpt.embeddings import embedding_model
from memgpt.presets.presets import add_default_presets
from memgpt.data_types import EmbeddingConfig, LLMConfig
from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector


Expand All @@ -22,9 +21,10 @@ def server():

# Use os.getenv with a fallback to os.environ.get
db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL")
assert db_url, "Missing PGVECTOR_TEST_DB_URL"

if os.getenv("OPENAI_API_KEY"):
config = MemGPTConfig(
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
Expand All @@ -48,7 +48,7 @@ def server():
openai_key=os.getenv("OPENAI_API_KEY"),
)
else: # hosted
config = MemGPTConfig(
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
Expand Down
Loading

0 comments on commit 5a345f9

Please sign in to comment.