From 6606a19b69ec0a3c6df459f90706df8ed790655a Mon Sep 17 00:00:00 2001 From: Tom Bedor Date: Tue, 30 Jan 2024 14:12:50 -0800 Subject: [PATCH] fix: decrease number of saves to MemGPTConfig MemGPTConfig.save is called many times by quickstart and configure, resulting in confusing results. This collects changes and calls save once. --- memgpt/cli/cli.py | 80 ++++++++++++++++++++++------------------ memgpt/cli/cli_config.py | 6 ++- memgpt/config.py | 1 - 3 files changed, 49 insertions(+), 38 deletions(-) diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index cce055e33b..727dbf6c89 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -55,8 +55,15 @@ def str_to_quickstart_choice(choice_str: str) -> QuickstartChoice: raise ValueError(f"{choice_str} is not a valid QuickstartChoice. Valid options are: {valid_options}") -def set_config_with_dict(new_config: dict) -> bool: - """Set the base config using a dict""" +def set_config_with_dict(new_config: dict) -> (MemGPTConfig, bool): + """_summary_ + + Args: + new_config (dict): Dict of new config values + + Returns: + new_config MemGPTConfig, modified (bool): Returns the new config and a boolean indicating if the config was modified + """ from memgpt.utils import printd old_config = MemGPTConfig.load() @@ -93,32 +100,7 @@ def set_config_with_dict(new_config: dict) -> bool: else: printd(f"Skipping new config {k}: {v} == {new_config[k]}") - if modified: - printd(f"Saving new config file.") - old_config.save() - typer.secho(f"📖 MemGPT configuration file updated!", fg=typer.colors.GREEN) - typer.secho( - "\n".join( - [ - f"🧠 model\t-> {old_config.default_llm_config.model}", - f"🖥️ endpoint\t-> {old_config.default_llm_config.model_endpoint}", - ] - ), - fg=typer.colors.GREEN, - ) - return True - else: - typer.secho(f"📖 MemGPT configuration file unchanged.", fg=typer.colors.WHITE) - typer.secho( - "\n".join( - [ - f"🧠 model\t-> {old_config.default_llm_config.model}", - f"🖥️ endpoint\t-> {old_config.default_llm_config.model_endpoint}", - ] - ), - fg=typer.colors.WHITE, - ) - return False + return (old_config, modified) def quickstart( @@ -127,7 +109,10 @@ def quickstart( debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False, terminal: bool = True, ): - """Set the base config file with a single command""" + """Set the base config file with a single command + + This function and `configure` should be the ONLY places where MemGPTConfig.save() is called. + """ # setup logger utils.DEBUG = debug @@ -154,7 +139,7 @@ def quickstart( config = response.json() # Output a success message and the first few items in the dictionary as a sample printd("JSON config file downloaded successfully.") - config_was_modified = set_config_with_dict(config) + new_config, config_was_modified = set_config_with_dict(config) else: typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED) @@ -165,7 +150,7 @@ def quickstart( with open(backup_config_path, "r", encoding="utf-8") as file: backup_config = json.load(file) printd("Loaded backup config file successfully.") - config_was_modified = set_config_with_dict(backup_config) + new_config, config_was_modified = set_config_with_dict(backup_config) except FileNotFoundError: typer.secho(f"Backup config file not found at {backup_config_path}", fg=typer.colors.RED) return @@ -177,7 +162,7 @@ def quickstart( with open(backup_config_path, "r", encoding="utf-8") as file: backup_config = json.load(file) printd("Loaded config file successfully.") - config_was_modified = set_config_with_dict(backup_config) + new_config, config_was_modified = set_config_with_dict(backup_config) except FileNotFoundError: typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED) return @@ -203,7 +188,7 @@ def quickstart( config = response.json() # Output a success message and the first few items in the dictionary as a sample print("JSON config file downloaded successfully.") - config_was_modified = set_config_with_dict(config) + new_config, config_was_modified = set_config_with_dict(config) else: typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED) @@ -214,7 +199,7 @@ def quickstart( with open(backup_config_path, "r", encoding="utf-8") as file: backup_config = json.load(file) printd("Loaded backup config file successfully.") - config_was_modified = set_config_with_dict(backup_config) + new_config, config_was_modified = set_config_with_dict(backup_config) except FileNotFoundError: typer.secho(f"Backup config file not found at {backup_config_path}", fg=typer.colors.RED) return @@ -226,7 +211,7 @@ def quickstart( with open(backup_config_path, "r", encoding="utf-8") as file: backup_config = json.load(file) printd("Loaded config file successfully.") - config_was_modified = set_config_with_dict(backup_config) + new_config, config_was_modified = set_config_with_dict(backup_config) except FileNotFoundError: typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED) return @@ -234,6 +219,31 @@ def quickstart( else: raise NotImplementedError(backend) + if config_was_modified: + printd(f"Saving new config file.") + new_config.save() + typer.secho(f"📖 MemGPT configuration file updated!", fg=typer.colors.GREEN) + typer.secho( + "\n".join( + [ + f"🧠 model\t-> {new_config.default_llm_config.model}", + f"🖥️ endpoint\t-> {new_config.default_llm_config.model_endpoint}", + ] + ), + fg=typer.colors.GREEN, + ) + else: + typer.secho(f"📖 MemGPT configuration file unchanged.", fg=typer.colors.WHITE) + typer.secho( + "\n".join( + [ + f"🧠 model\t-> {new_config.default_llm_config.model}", + f"🖥️ endpoint\t-> {new_config.default_llm_config.model_endpoint}", + ] + ), + fg=typer.colors.WHITE, + ) + # 'terminal' = quickstart was run alone, in which case we should guide the user on the next command if terminal: if config_was_modified: diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 22b3f81ed8..2965795a70 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -119,7 +119,6 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials) credentials.azure_key = azure_creds["azure_key"] credentials.azure_endpoint = azure_creds["azure_endpoint"] credentials.azure_version = azure_creds["azure_version"] - config.save() model_endpoint_type = "azure" model_endpoint = azure_creds["azure_endpoint"] @@ -563,7 +562,10 @@ def configure_recall_storage(config: MemGPTConfig, credentials: MemGPTCredential @app.command() def configure(): - """Updates default MemGPT configurations""" + """Updates default MemGPT configurations + + This function and quickstart should be the ONLY place where MemGPTConfig.save() is called + """ # check credentials credentials = MemGPTCredentials.load() diff --git a/memgpt/config.py b/memgpt/config.py index f950fa6df3..c259fbb419 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -185,7 +185,6 @@ def load(cls) -> "MemGPTConfig": anon_clientid = MemGPTConfig.generate_uuid() config = cls(anon_clientid=anon_clientid, config_path=config_path) config.create_config_dir() # create dirs - config.save() # save updated config return config