Skip to content

Commit

Permalink
fix: decrease number of saves to MemGPTConfig
Browse files Browse the repository at this point in the history
MemGPTConfig.save is called many times by quickstart and configure,
resulting in confusing results. This collects changes and calls save once.
  • Loading branch information
tombedor committed Jan 30, 2024
1 parent 6edffe0 commit 6606a19
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 38 deletions.
80 changes: 45 additions & 35 deletions memgpt/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,15 @@ def str_to_quickstart_choice(choice_str: str) -> QuickstartChoice:
raise ValueError(f"{choice_str} is not a valid QuickstartChoice. Valid options are: {valid_options}")


def set_config_with_dict(new_config: dict) -> bool:
"""Set the base config using a dict"""
def set_config_with_dict(new_config: dict) -> (MemGPTConfig, bool):
"""_summary_
Args:
new_config (dict): Dict of new config values
Returns:
new_config MemGPTConfig, modified (bool): Returns the new config and a boolean indicating if the config was modified
"""
from memgpt.utils import printd

old_config = MemGPTConfig.load()
Expand Down Expand Up @@ -93,32 +100,7 @@ def set_config_with_dict(new_config: dict) -> bool:
else:
printd(f"Skipping new config {k}: {v} == {new_config[k]}")

if modified:
printd(f"Saving new config file.")
old_config.save()
typer.secho(f"📖 MemGPT configuration file updated!", fg=typer.colors.GREEN)
typer.secho(
"\n".join(
[
f"🧠 model\t-> {old_config.default_llm_config.model}",
f"🖥️ endpoint\t-> {old_config.default_llm_config.model_endpoint}",
]
),
fg=typer.colors.GREEN,
)
return True
else:
typer.secho(f"📖 MemGPT configuration file unchanged.", fg=typer.colors.WHITE)
typer.secho(
"\n".join(
[
f"🧠 model\t-> {old_config.default_llm_config.model}",
f"🖥️ endpoint\t-> {old_config.default_llm_config.model_endpoint}",
]
),
fg=typer.colors.WHITE,
)
return False
return (old_config, modified)


def quickstart(
Expand All @@ -127,7 +109,10 @@ def quickstart(
debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False,
terminal: bool = True,
):
"""Set the base config file with a single command"""
"""Set the base config file with a single command
This function and `configure` should be the ONLY places where MemGPTConfig.save() is called.
"""

# setup logger
utils.DEBUG = debug
Expand All @@ -154,7 +139,7 @@ def quickstart(
config = response.json()
# Output a success message and the first few items in the dictionary as a sample
printd("JSON config file downloaded successfully.")
config_was_modified = set_config_with_dict(config)
new_config, config_was_modified = set_config_with_dict(config)
else:
typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED)

Expand All @@ -165,7 +150,7 @@ def quickstart(
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
printd("Loaded backup config file successfully.")
config_was_modified = set_config_with_dict(backup_config)
new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError:
typer.secho(f"Backup config file not found at {backup_config_path}", fg=typer.colors.RED)
return
Expand All @@ -177,7 +162,7 @@ def quickstart(
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
printd("Loaded config file successfully.")
config_was_modified = set_config_with_dict(backup_config)
new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError:
typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED)
return
Expand All @@ -203,7 +188,7 @@ def quickstart(
config = response.json()
# Output a success message and the first few items in the dictionary as a sample
print("JSON config file downloaded successfully.")
config_was_modified = set_config_with_dict(config)
new_config, config_was_modified = set_config_with_dict(config)
else:
typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED)

Expand All @@ -214,7 +199,7 @@ def quickstart(
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
printd("Loaded backup config file successfully.")
config_was_modified = set_config_with_dict(backup_config)
new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError:
typer.secho(f"Backup config file not found at {backup_config_path}", fg=typer.colors.RED)
return
Expand All @@ -226,14 +211,39 @@ def quickstart(
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
printd("Loaded config file successfully.")
config_was_modified = set_config_with_dict(backup_config)
new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError:
typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED)
return

else:
raise NotImplementedError(backend)

if config_was_modified:
printd(f"Saving new config file.")
new_config.save()
typer.secho(f"📖 MemGPT configuration file updated!", fg=typer.colors.GREEN)
typer.secho(
"\n".join(
[
f"🧠 model\t-> {new_config.default_llm_config.model}",
f"🖥️ endpoint\t-> {new_config.default_llm_config.model_endpoint}",
]
),
fg=typer.colors.GREEN,
)
else:
typer.secho(f"📖 MemGPT configuration file unchanged.", fg=typer.colors.WHITE)
typer.secho(
"\n".join(
[
f"🧠 model\t-> {new_config.default_llm_config.model}",
f"🖥️ endpoint\t-> {new_config.default_llm_config.model_endpoint}",
]
),
fg=typer.colors.WHITE,
)

# 'terminal' = quickstart was run alone, in which case we should guide the user on the next command
if terminal:
if config_was_modified:
Expand Down
6 changes: 4 additions & 2 deletions memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
credentials.azure_key = azure_creds["azure_key"]
credentials.azure_endpoint = azure_creds["azure_endpoint"]
credentials.azure_version = azure_creds["azure_version"]
config.save()

model_endpoint_type = "azure"
model_endpoint = azure_creds["azure_endpoint"]
Expand Down Expand Up @@ -563,7 +562,10 @@ def configure_recall_storage(config: MemGPTConfig, credentials: MemGPTCredential

@app.command()
def configure():
"""Updates default MemGPT configurations"""
"""Updates default MemGPT configurations
This function and quickstart should be the ONLY place where MemGPTConfig.save() is called
"""

# check credentials
credentials = MemGPTCredentials.load()
Expand Down
1 change: 0 additions & 1 deletion memgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def load(cls) -> "MemGPTConfig":
anon_clientid = MemGPTConfig.generate_uuid()
config = cls(anon_clientid=anon_clientid, config_path=config_path)
config.create_config_dir() # create dirs
config.save() # save updated config

return config

Expand Down

0 comments on commit 6606a19

Please sign in to comment.