Skip to content

Commit

Permalink
feat: isolate test config from main config
Browse files Browse the repository at this point in the history
allow for directing test config to a different variable than the main
app config.

also set the stage for making MemGPTConfig a constant, rather than a
function argument.
  • Loading branch information
tombedor committed Mar 2, 2024
1 parent acb73dc commit e1cecc0
Show file tree
Hide file tree
Showing 16 changed files with 187 additions and 125 deletions.
62 changes: 53 additions & 9 deletions memgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def set_field(config, section, field, value):

@dataclass
class MemGPTConfig:
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") if os.getenv("MEMGPT_CONFIG_PATH") else os.path.join(MEMGPT_DIR, "config")
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")
anon_clientid: str = str(uuid.UUID(int=0))

# preset
Expand Down Expand Up @@ -196,16 +196,51 @@ def save(self):
# model defaults
set_field(config, "model", "model", self.default_llm_config.model)
set_field(config, "model", "model_endpoint", self.default_llm_config.model_endpoint)
set_field(config, "model", "model_endpoint_type", self.default_llm_config.model_endpoint_type)
set_field(
config,
"model",
"model_endpoint_type",
self.default_llm_config.model_endpoint_type,
)
set_field(config, "model", "model_wrapper", self.default_llm_config.model_wrapper)
set_field(config, "model", "context_window", str(self.default_llm_config.context_window))
set_field(
config,
"model",
"context_window",
str(self.default_llm_config.context_window),
)

# embeddings
set_field(config, "embedding", "embedding_endpoint_type", self.default_embedding_config.embedding_endpoint_type)
set_field(config, "embedding", "embedding_endpoint", self.default_embedding_config.embedding_endpoint)
set_field(config, "embedding", "embedding_model", self.default_embedding_config.embedding_model)
set_field(config, "embedding", "embedding_dim", str(self.default_embedding_config.embedding_dim))
set_field(config, "embedding", "embedding_chunk_size", str(self.default_embedding_config.embedding_chunk_size))
set_field(
config,
"embedding",
"embedding_endpoint_type",
self.default_embedding_config.embedding_endpoint_type,
)
set_field(
config,
"embedding",
"embedding_endpoint",
self.default_embedding_config.embedding_endpoint,
)
set_field(
config,
"embedding",
"embedding_model",
self.default_embedding_config.embedding_model,
)
set_field(
config,
"embedding",
"embedding_dim",
str(self.default_embedding_config.embedding_dim),
)
set_field(
config,
"embedding",
"embedding_chunk_size",
str(self.default_embedding_config.embedding_chunk_size),
)

# archival storage
set_field(config, "archival_storage", "type", self.archival_storage_type)
Expand Down Expand Up @@ -253,7 +288,16 @@ def create_config_dir():
if not os.path.exists(MEMGPT_DIR):
os.makedirs(MEMGPT_DIR, exist_ok=True)

folders = ["personas", "humans", "archival", "agents", "functions", "system_prompts", "presets", "settings"]
folders = [
"personas",
"humans",
"archival",
"agents",
"functions",
"system_prompts",
"presets",
"settings",
]

for folder in folders:
if not os.path.exists(os.path.join(MEMGPT_DIR, folder)):
Expand Down
4 changes: 4 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from tests.config import TestMGPTConfig


TEST_MEMGPT_CONFIG = TestMGPTConfig()
7 changes: 7 additions & 0 deletions tests/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os
from memgpt.config import MemGPTConfig
from memgpt.constants import MEMGPT_DIR


class TestMGPTConfig(MemGPTConfig):
config_path: str = os.getenv("TEST_MEMGPT_CONFIG_PATH") or os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")
8 changes: 2 additions & 6 deletions tests/test_agent_function_update.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from collections import UserDict
import json
import os
import inspect
import uuid

from memgpt.config import MemGPTConfig
from memgpt import create_client
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
from memgpt.functions.functions import USER_FUNCTIONS_DIR
from memgpt.utils import assistant_function_to_tool
from memgpt.models import chat_completion_response
from tests import TEST_MEMGPT_CONFIG

from tests.utils import wipe_config, create_config

Expand Down Expand Up @@ -39,10 +37,8 @@ def agent():
# create memgpt client
client = create_client()

config = MemGPTConfig.load()

# ensure user exists
user_id = uuid.UUID(config.anon_clientid)
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
if not client.server.get_user(user_id=user_id):
client.server.create_user({"id": user_id})

Expand Down
5 changes: 2 additions & 3 deletions tests/test_base_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import uuid

from memgpt import create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config, create_config


Expand All @@ -30,8 +30,7 @@ def create_test_agent():
)

global agent_obj
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)


Expand Down
13 changes: 0 additions & 13 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import uuid
import time
import os
import threading

from memgpt import Admin, create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset
from memgpt.functions.functions import load_all_function_sets
from memgpt.prompts import gpt_system
from memgpt.constants import DEFAULT_PRESET

import pytest


from .utils import wipe_config
import uuid


Expand Down Expand Up @@ -116,9 +109,3 @@ def test_user_message(client):
# print(
# f"[2] MESSAGE SEND SUCCESS!!! AGENT {test_agent_state_post_message.id}\n\tmessages={test_agent_state_post_message.state['messages']}"
# )


if __name__ == "__main__":
# test_create_preset()
test_create_agent()
test_user_message()
14 changes: 8 additions & 6 deletions tests/test_different_embedding_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import os

from memgpt import create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage
from memgpt.data_types import EmbeddingConfig, Passage
from memgpt.embeddings import embedding_model
from memgpt.agent_store.storage import StorageConnector, TableType
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config, create_config
import uuid

Expand All @@ -21,7 +20,11 @@

def generate_passages(user, agent):
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
texts = [
"This is a test passage",
"This is another test passage",
"Cinderella wept",
]
embed_model = embedding_model(agent.embedding_config)
orig_embeddings = []
passages = []
Expand Down Expand Up @@ -86,8 +89,7 @@ def test_create_user():
hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)

# test passage dimentionality
config = MemGPTConfig.load()
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, client.user.id)
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, client.user.id)
storage.filters = {} # clear filters to be able to get all passages
passages = storage.get_all()
for passage in passages:
Expand Down
39 changes: 21 additions & 18 deletions tests/test_load_archival.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from memgpt.cli.cli_load import load_directory

# from memgpt.data_sources.connectors import DirectoryConnector, load_data
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.metadata import MetadataStore
from memgpt.data_types import User, AgentState, EmbeddingConfig
from memgpt import create_client
from .utils import wipe_config, create_config
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config


@pytest.fixture(autouse=True)
Expand All @@ -37,16 +36,20 @@ def recreate_declarative_base():

@pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"])
@pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres"])
def test_load_directory(metadata_storage_connector, passage_storage_connector, clear_dynamically_created_models, recreate_declarative_base):
def test_load_directory(
metadata_storage_connector,
passage_storage_connector,
clear_dynamically_created_models,
recreate_declarative_base,
):
wipe_config()
# setup config
config = MemGPTConfig()
if metadata_storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.metadata_storage_type = "postgres"
TEST_MEMGPT_CONFIG.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
TEST_MEMGPT_CONFIG.metadata_storage_type = "postgres"
elif metadata_storage_connector == "sqlite":
print("testing sqlite metadata")
# nothing to do (should be config defaults)
Expand All @@ -56,18 +59,18 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
TEST_MEMGPT_CONFIG.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
elif passage_storage_connector == "chroma":
print("testing chroma passage storage")
# nothing to do (should be config defaults)
else:
raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented")
config.save()
TEST_MEMGPT_CONFIG.save()

# create metadata store
ms = MetadataStore(config)
user = User(id=uuid.UUID(config.anon_clientid))
ms = MetadataStore(TEST_MEMGPT_CONFIG)
user = User(id=uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid))

# embedding config
if os.getenv("OPENAI_API_KEY"):
Expand All @@ -93,10 +96,10 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
agent = AgentState(
user_id=user.id,
name="test_agent",
preset=config.preset,
persona=config.persona,
human=config.human,
llm_config=config.default_llm_config,
preset=TEST_MEMGPT_CONFIG.preset,
persona=TEST_MEMGPT_CONFIG.persona,
human=TEST_MEMGPT_CONFIG.human,
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=embedding_config,
)
ms.delete_user(user.id)
Expand All @@ -109,7 +112,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
print("Creating storage connectors...")
user_id = user.id
print("User ID", user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)

# load data
name = "test_dataset"
Expand All @@ -121,7 +124,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
print("Resetting tables with delete_table...")
passages_conn.delete_table()
print("Re-creating tables...")
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)
assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}"

# test: load directory
Expand Down
23 changes: 11 additions & 12 deletions tests/test_metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,26 @@

from memgpt.agent import Agent, save_agent
from memgpt.metadata import MetadataStore
from memgpt.config import MemGPTConfig
from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig
from memgpt.data_types import User, AgentState, Source, LLMConfig
from memgpt.utils import get_human_text, get_persona_text
from tests import TEST_MEMGPT_CONFIG


# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
@pytest.mark.parametrize("storage_connector", ["sqlite"])
def test_storage(storage_connector):
config = MemGPTConfig()
if storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
config.recall_storage_type = "postgres"
TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
if storage_connector == "sqlite":
config.recall_storage_type = "local"
TEST_MEMGPT_CONFIG.recall_storage_type = "local"

ms = MetadataStore(config)
ms = MetadataStore(TEST_MEMGPT_CONFIG)

# generate data
user_1 = User()
Expand All @@ -35,8 +34,8 @@ def test_storage(storage_connector):
preset=DEFAULT_PRESET,
persona=DEFAULT_PERSONA,
human=DEFAULT_HUMAN,
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)
source_1 = Source(user_id=user_1.id, name="source_1")

Expand Down Expand Up @@ -88,7 +87,7 @@ def test_storage(storage_connector):
# test: updating

# test: update JSON-stored LLMConfig class
print(agent_1.llm_config, config.default_llm_config)
print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config)
llm_config = ms.get_agent(agent_1.id).llm_config
assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}"
assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}"
Expand Down
11 changes: 7 additions & 4 deletions tests/test_migrate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.config import MemGPTConfig
from memgpt.migrate import migrate_all_agents
from .utils import wipe_config
from memgpt.server.server import SyncServer
import shutil
Expand Down Expand Up @@ -38,7 +36,12 @@ def test_migrate_0211():
assert len(message_ids) > 0

# assert recall memories exist
messages = server.get_agent_messages(user_id=agent_state.user_id, agent_id=agent_state.id, start=0, count=1000)
messages = server.get_agent_messages(
user_id=agent_state.user_id,
agent_id=agent_state.id,
start=0,
count=1000,
)
assert len(messages) > 0

# for source_name in source_res["migration_candidates"]:
Expand Down
Loading

0 comments on commit e1cecc0

Please sign in to comment.