From 463c5a9908ae3d3f816743f9accc239005b7409b Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Sat, 21 Oct 2023 15:58:54 -0700 Subject: [PATCH] autosave on /exit --- main.py | 112 ++++++++++++++++++++++++++++++-------------------------- 1 file changed, 60 insertions(+), 52 deletions(-) diff --git a/main.py b/main.py index 2c3bfa002b..40cfa22ee2 100644 --- a/main.py +++ b/main.py @@ -43,6 +43,62 @@ def clear_line(): sys.stdout.flush() +def save(memgpt_agent): + filename = utils.get_local_time().replace(' ', '_').replace(':', '_') + filename = f"{filename}.json" + filename = os.path.join('saved_state', filename) + try: + if not os.path.exists("saved_state"): + os.makedirs("saved_state") + memgpt_agent.save_to_json_file(filename) + print(f"Saved checkpoint to: {filename}") + except Exception as e: + print(f"Saving state to {filename} failed with: {e}") + + # save the persistence manager too + filename = filename.replace('.json', '.persistence.pickle') + try: + memgpt_agent.persistence_manager.save(filename) + print(f"Saved persistence manager to: {filename}") + except Exception as e: + print(f"Saving persistence manager to {filename} failed with: {e}") + + +def load(memgpt_agent, filename): + if filename is not None: + if filename[-5:] != '.json': + filename += '.json' + try: + memgpt_agent.load_from_json_file_inplace(filename) + print(f"Loaded checkpoint {filename}") + except Exception as e: + print(f"Loading {filename} failed with: {e}") + else: + # Load the latest file + print(f"/load warning: no checkpoint specified, loading most recent checkpoint instead") + json_files = glob.glob("saved_state/*.json") # This will list all .json files in the current directory. + + # Check if there are any json files. + if not json_files: + print(f"/load error: no .json checkpoint files found") + else: + # Sort files based on modified timestamp, with the latest file being the first. + filename = max(json_files, key=os.path.getmtime) + try: + memgpt_agent.load_from_json_file_inplace(filename) + print(f"Loaded checkpoint {filename}") + except Exception as e: + print(f"Loading {filename} failed with: {e}") + + # need to load persistence manager too + filename = filename.replace('.json', '.persistence.pickle') + try: + memgpt_agent.persistence_manager = InMemoryStateManager.load(filename) # TODO(fixme):for different types of persistence managers that require different load/save methods + print(f"Loaded persistence manager from {filename}") + except Exception as e: + print(f"/load warning: loading persistence manager from {filename} failed with: {e}") + + async def main(): utils.DEBUG = FLAGS.debug logging.getLogger().setLevel(logging.CRITICAL) @@ -162,6 +218,8 @@ async def main(): user_message = system.package_user_message("\n".join(user_input_list)) elif user_input.lower() == "/exit": + # autosave + save(memgpt_agent=memgpt_agent) break elif user_input.lower() == "/savechat": @@ -178,63 +236,13 @@ async def main(): continue elif user_input.lower() == "/save": - filename = utils.get_local_time().replace(' ', '_').replace(':', '_') - filename = f"{filename}.json" - filename = os.path.join('saved_state', filename) - try: - if not os.path.exists("saved_state"): - os.makedirs("saved_state") - memgpt_agent.save_to_json_file(filename) - print(f"Saved checkpoint to: {filename}") - except Exception as e: - print(f"Saving state to {filename} failed with: {e}") - - # save the persistence manager too - filename = filename.replace('.json', '.persistence.pickle') - try: - memgpt_agent.persistence_manager.save(filename) - print(f"Saved persistence manager to: {filename}") - except Exception as e: - print(f"Saving persistence manager to {filename} failed with: {e}") - + save(memgpt_agent=memgpt_agent) continue elif user_input.lower() == "/load" or user_input.lower().startswith("/load "): command = user_input.strip().split() filename = command[1] if len(command) > 1 else None - if filename is not None: - if filename[-5:] != '.json': - filename += '.json' - try: - memgpt_agent.load_from_json_file_inplace(filename) - print(f"Loaded checkpoint {filename}") - except Exception as e: - print(f"Loading {filename} failed with: {e}") - else: - # Load the latest file - print(f"/load warning: no checkpoint specified, loading most recent checkpoint instead") - json_files = glob.glob("saved_state/*.json") # This will list all .json files in the current directory. - - # Check if there are any json files. - if not json_files: - print(f"/load error: no .json checkpoint files found") - else: - # Sort files based on modified timestamp, with the latest file being the first. - filename = max(json_files, key=os.path.getmtime) - try: - memgpt_agent.load_from_json_file_inplace(filename) - print(f"Loaded checkpoint {filename}") - except Exception as e: - print(f"Loading {filename} failed with: {e}") - - # need to load persistence manager too - filename = filename.replace('.json', '.persistence.pickle') - try: - memgpt_agent.persistence_manager = InMemoryStateManager.load(filename) # TODO(fixme):for different types of persistence managers that require different load/save methods - print(f"Loaded persistence manager from {filename}") - except Exception as e: - print(f"/load warning: loading persistence manager from {filename} failed with: {e}") - + load(memgpt_agent=memgpt_agent, filename=filename) continue elif user_input.lower() == "/dump":