Skip to content

Commit

Permalink
feat: isolate test config from main config (#1063)
Browse files Browse the repository at this point in the history
Co-authored-by: Charles Packer <[email protected]>
  • Loading branch information
tombedor and cpacker authored Mar 6, 2024
1 parent 23702f6 commit 503e812
Show file tree
Hide file tree
Showing 17 changed files with 196 additions and 134 deletions.
1 change: 0 additions & 1 deletion memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ def __init__(
messages_total: Optional[int] = None, # TODO remove?
first_message_verify_mono: bool = True, # TODO move to config?
):

# An agent can be created from a Preset object
if preset is not None:
assert agent_state is None, "Can create an agent from a Preset or AgentState (but both were provided)"
Expand Down
62 changes: 53 additions & 9 deletions memgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def set_field(config, section, field, value):

@dataclass
class MemGPTConfig:
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") if os.getenv("MEMGPT_CONFIG_PATH") else os.path.join(MEMGPT_DIR, "config")
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")
anon_clientid: str = str(uuid.UUID(int=0))

# preset
Expand Down Expand Up @@ -196,16 +196,51 @@ def save(self):
# model defaults
set_field(config, "model", "model", self.default_llm_config.model)
set_field(config, "model", "model_endpoint", self.default_llm_config.model_endpoint)
set_field(config, "model", "model_endpoint_type", self.default_llm_config.model_endpoint_type)
set_field(
config,
"model",
"model_endpoint_type",
self.default_llm_config.model_endpoint_type,
)
set_field(config, "model", "model_wrapper", self.default_llm_config.model_wrapper)
set_field(config, "model", "context_window", str(self.default_llm_config.context_window))
set_field(
config,
"model",
"context_window",
str(self.default_llm_config.context_window),
)

# embeddings
set_field(config, "embedding", "embedding_endpoint_type", self.default_embedding_config.embedding_endpoint_type)
set_field(config, "embedding", "embedding_endpoint", self.default_embedding_config.embedding_endpoint)
set_field(config, "embedding", "embedding_model", self.default_embedding_config.embedding_model)
set_field(config, "embedding", "embedding_dim", str(self.default_embedding_config.embedding_dim))
set_field(config, "embedding", "embedding_chunk_size", str(self.default_embedding_config.embedding_chunk_size))
set_field(
config,
"embedding",
"embedding_endpoint_type",
self.default_embedding_config.embedding_endpoint_type,
)
set_field(
config,
"embedding",
"embedding_endpoint",
self.default_embedding_config.embedding_endpoint,
)
set_field(
config,
"embedding",
"embedding_model",
self.default_embedding_config.embedding_model,
)
set_field(
config,
"embedding",
"embedding_dim",
str(self.default_embedding_config.embedding_dim),
)
set_field(
config,
"embedding",
"embedding_chunk_size",
str(self.default_embedding_config.embedding_chunk_size),
)

# archival storage
set_field(config, "archival_storage", "type", self.archival_storage_type)
Expand Down Expand Up @@ -253,7 +288,16 @@ def create_config_dir():
if not os.path.exists(MEMGPT_DIR):
os.makedirs(MEMGPT_DIR, exist_ok=True)

folders = ["personas", "humans", "archival", "agents", "functions", "system_prompts", "presets", "settings"]
folders = [
"personas",
"humans",
"archival",
"agents",
"functions",
"system_prompts",
"presets",
"settings",
]

for folder in folders:
if not os.path.exists(os.path.join(MEMGPT_DIR, folder)):
Expand Down
4 changes: 4 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from tests.config import TestMGPTConfig


# Module-level config instance shared by the test suite
# (imported as `from tests import TEST_MEMGPT_CONFIG`).
# NOTE: some tests assign attributes on this instance and call .save(), so
# such changes are visible to any later test running in the same process.
TEST_MEMGPT_CONFIG = TestMGPTConfig()
7 changes: 7 additions & 0 deletions tests/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os
from dataclasses import dataclass

from memgpt.config import MemGPTConfig
from memgpt.constants import MEMGPT_DIR


@dataclass
class TestMGPTConfig(MemGPTConfig):
    """MemGPT configuration used by the test suite.

    Only the default of ``config_path`` differs from ``MemGPTConfig``. It is
    resolved from ``TEST_MEMGPT_CONFIG_PATH``, then ``MEMGPT_CONFIG_PATH``,
    and finally falls back to ``<MEMGPT_DIR>/config``.
    """

    # The subclass must itself be a dataclass. Without the decorator, the
    # __init__ inherited from MemGPTConfig assigns the *parent's* default to
    # every instance, which shadows this class-level override and turns
    # TEST_MEMGPT_CONFIG_PATH into a no-op.
    # The environment variables are read once, when this module is imported.
    config_path: str = os.getenv("TEST_MEMGPT_CONFIG_PATH") or os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")
8 changes: 2 additions & 6 deletions tests/test_agent_function_update.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from collections import UserDict
import json
import os
import inspect
import uuid

from memgpt.config import MemGPTConfig
from memgpt import create_client
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
from memgpt.functions.functions import USER_FUNCTIONS_DIR
from memgpt.utils import assistant_function_to_tool
from memgpt.models import chat_completion_response
from tests import TEST_MEMGPT_CONFIG

from tests.utils import wipe_config, create_config

Expand Down Expand Up @@ -39,10 +37,8 @@ def agent():
# create memgpt client
client = create_client()

config = MemGPTConfig.load()

# ensure user exists
user_id = uuid.UUID(config.anon_clientid)
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
if not client.server.get_user(user_id=user_id):
client.server.create_user({"id": user_id})

Expand Down
5 changes: 2 additions & 3 deletions tests/test_base_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import uuid

from memgpt import create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config, create_config


Expand All @@ -30,8 +30,7 @@ def create_test_agent():
)

global agent_obj
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)


Expand Down
13 changes: 0 additions & 13 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import uuid
import time
import os
import threading

from memgpt import Admin, create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset
from memgpt.functions.functions import load_all_function_sets
from memgpt.prompts import gpt_system
from memgpt.constants import DEFAULT_PRESET

import pytest


from .utils import wipe_config
import uuid


Expand Down Expand Up @@ -116,9 +109,3 @@ def test_user_message(client):
# print(
# f"[2] MESSAGE SEND SUCCESS!!! AGENT {test_agent_state_post_message.id}\n\tmessages={test_agent_state_post_message.state['messages']}"
# )


if __name__ == "__main__":
# test_create_preset()
test_create_agent()
test_user_message()
14 changes: 8 additions & 6 deletions tests/test_different_embedding_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import os

from memgpt import create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage
from memgpt.data_types import EmbeddingConfig, Passage
from memgpt.embeddings import embedding_model
from memgpt.agent_store.storage import StorageConnector, TableType
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config, create_config
import uuid

Expand All @@ -21,7 +20,11 @@

def generate_passages(user, agent):
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
texts = [
"This is a test passage",
"This is another test passage",
"Cinderella wept",
]
embed_model = embedding_model(agent.embedding_config)
orig_embeddings = []
passages = []
Expand Down Expand Up @@ -86,8 +89,7 @@ def test_create_user():
hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)

# test passage dimensionality
config = MemGPTConfig.load()
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, client.user.id)
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, client.user.id)
storage.filters = {} # clear filters to be able to get all passages
passages = storage.get_all()
for passage in passages:
Expand Down
47 changes: 26 additions & 21 deletions tests/test_load_archival.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from memgpt.cli.cli_load import load_directory

# from memgpt.data_sources.connectors import DirectoryConnector, load_data
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.metadata import MetadataStore
from memgpt.data_types import User, AgentState, EmbeddingConfig
from memgpt import create_client
from .utils import wipe_config, create_config
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config


@pytest.fixture(autouse=True)
Expand All @@ -37,16 +36,20 @@ def recreate_declarative_base():

@pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"])
@pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres"])
def test_load_directory(metadata_storage_connector, passage_storage_connector, clear_dynamically_created_models, recreate_declarative_base):
def test_load_directory(
metadata_storage_connector,
passage_storage_connector,
clear_dynamically_created_models,
recreate_declarative_base,
):
wipe_config()
# setup config
config = MemGPTConfig()
if metadata_storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.metadata_storage_type = "postgres"
TEST_MEMGPT_CONFIG.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
TEST_MEMGPT_CONFIG.metadata_storage_type = "postgres"
elif metadata_storage_connector == "sqlite":
print("testing sqlite metadata")
# nothing to do (should be config defaults)
Expand All @@ -56,18 +59,18 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
TEST_MEMGPT_CONFIG.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
elif passage_storage_connector == "chroma":
print("testing chroma passage storage")
# nothing to do (should be config defaults)
else:
raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented")
config.save()
TEST_MEMGPT_CONFIG.save()

# create metadata store
ms = MetadataStore(config)
user = User(id=uuid.UUID(config.anon_clientid))
ms = MetadataStore(TEST_MEMGPT_CONFIG)
user = User(id=uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid))

# embedding config
if os.getenv("OPENAI_API_KEY"):
Expand Down Expand Up @@ -100,18 +103,20 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
)

# write out the config so that the 'load' command will use it (CLI commands pull from config)
config.default_embedding_config = embedding_config
config.save()
TEST_MEMGPT_CONFIG.default_embedding_config = embedding_config
TEST_MEMGPT_CONFIG.save()
# config.default_embedding_config = embedding_config
# config.save()

# create user and agent
agent = AgentState(
user_id=user.id,
name="test_agent",
preset=config.preset,
persona=config.persona,
human=config.human,
llm_config=config.default_llm_config,
embedding_config=embedding_config,
preset=TEST_MEMGPT_CONFIG.preset,
persona=TEST_MEMGPT_CONFIG.persona,
human=TEST_MEMGPT_CONFIG.human,
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)
ms.delete_user(user.id)
ms.create_user(user)
Expand All @@ -123,7 +128,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
print("Creating storage connectors...")
user_id = user.id
print("User ID", user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)

# load data
name = "test_dataset"
Expand All @@ -135,7 +140,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
print("Resetting tables with delete_table...")
passages_conn.delete_table()
print("Re-creating tables...")
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)
assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}"

# test: load directory
Expand Down
28 changes: 11 additions & 17 deletions tests/test_metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,29 @@

from memgpt.agent import Agent, save_agent
from memgpt.metadata import MetadataStore
from memgpt.config import MemGPTConfig
from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig
from memgpt.data_types import User, AgentState, Source, LLMConfig
from memgpt.utils import get_human_text, get_persona_text
from tests import TEST_MEMGPT_CONFIG
from memgpt.presets.presets import add_default_presets, add_default_humans_and_personas

from memgpt.models.pydantic_models import HumanModel, PersonaModel

from memgpt.models.pydantic_models import HumanModel, PersonaModel


# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
@pytest.mark.parametrize("storage_connector", ["sqlite"])
def test_storage(storage_connector):

from memgpt.presets.presets import add_default_presets

config = MemGPTConfig()
if storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
config.recall_storage_type = "postgres"
TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
if storage_connector == "sqlite":
config.recall_storage_type = "local"
TEST_MEMGPT_CONFIG.recall_storage_type = "local"

ms = MetadataStore(config)
ms = MetadataStore(TEST_MEMGPT_CONFIG)

# users
user_1 = User()
Expand All @@ -57,8 +51,8 @@ def test_storage(storage_connector):
preset=DEFAULT_PRESET,
persona=DEFAULT_PERSONA,
human=DEFAULT_HUMAN,
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)
source_1 = Source(user_id=user_1.id, name="source_1")

Expand Down Expand Up @@ -108,7 +102,7 @@ def test_storage(storage_connector):
# test: updating

# test: update JSON-stored LLMConfig class
print(agent_1.llm_config, config.default_llm_config)
print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config)
llm_config = ms.get_agent(agent_1.id).llm_config
assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}"
assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}"
Expand Down
Loading

0 comments on commit 503e812

Please sign in to comment.