From 34e6371a2fa5655796b96f175d3b329617e8657c Mon Sep 17 00:00:00 2001 From: Vivian Fang Date: Thu, 9 Nov 2023 14:51:12 -0800 Subject: [PATCH] Remove AsyncAgent and async from cli (#400) * Remove AsyncAgent and async from cli Refactor agent.py memory.py Refactor interface.py Refactor main.py Refactor openai_tools.py Refactor cli/cli.py stray asyncs save make legacy embeddings not use async Refactor presets Remove deleted function from import * remove stray prints * typo * another stray print * patch test --------- Co-authored-by: cpacker --- memgpt/agent.py | 406 +-------------------------------- memgpt/agent_base.py | 7 - memgpt/autogen/memgpt_agent.py | 10 +- memgpt/cli/cli.py | 8 +- memgpt/cli/cli_config.py | 6 +- memgpt/config.py | 50 ++-- memgpt/interface.py | 38 +-- memgpt/main.py | 91 +++----- memgpt/memory.py | 116 +--------- memgpt/openai_tools.py | 84 ------- memgpt/presets.py | 40 +--- memgpt/utils.py | 41 ++-- tests/utils.py | 5 +- 13 files changed, 128 insertions(+), 774 deletions(-) delete mode 100644 memgpt/agent_base.py diff --git a/memgpt/agent.py b/memgpt/agent.py index 7e58741973..6de293c47c 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -1,4 +1,3 @@ -import asyncio import inspect import datetime import glob @@ -14,8 +13,8 @@ from memgpt.persistence_manager import LocalStateManager from memgpt.config import AgentConfig from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages -from .memory import CoreMemory as Memory, summarize_messages, a_summarize_messages -from .openai_tools import acompletions_with_backoff as acreate, completions_with_backoff as create +from .memory import CoreMemory as Memory, summarize_messages +from .openai_tools import completions_with_backoff as create from .utils import get_local_time, parse_json, united_diff, printd, count_tokens from .constants import ( MEMGPT_DIR, @@ -133,45 +132,6 @@ def get_ai_reply( raise e -async def get_ai_reply_async( - model, - message_sequence, - functions, - function_call="auto", -): - """Base call to GPT API w/ functions""" - - try: - response = await acreate( - model=model, - messages=message_sequence, - functions=functions, - function_call=function_call, - ) - - # special case for 'length' - if response.choices[0].finish_reason == "length": - raise Exception("Finish reason was length (maximum context length)") - - # catches for soft errors - if response.choices[0].finish_reason not in ["stop", "function_call"]: - raise Exception(f"API call finish with bad finish reason: {response}") - - # unpack with response.choices[0].message.content - return response - - except Exception as e: - raise e - - -# Assuming function_to_call is either sync or async -async def call_function(function_to_call, **function_args): - if inspect.iscoroutinefunction(function_to_call): - return await function_to_call(**function_args) - else: - return function_to_call(**function_args) - - class Agent(object): def __init__( self, @@ -207,7 +167,7 @@ def __init__( self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) # self.messages_total_init = self.messages_total self.messages_total_init = len(self._messages) - 1 - printd(f"AgentAsync initialized, self.messages_total={self.messages_total}") + printd(f"Agent initialized, self.messages_total={self.messages_total}") # Interface must implement: # - internal_monologue @@ -922,363 +882,3 @@ def heartbeat_is_paused(self): # Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60 - - -class AgentAsync(Agent): - """Core logic for an async MemGPT agent""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.init_avail_functions() - - async def handle_ai_response(self, response_message): - """Handles parsing and function execution""" - messages = [] # append these to the history when done - - # Step 2: check if LLM wanted to call a function - if response_message.get("function_call"): - # The content if then internal monologue, not chat - await self.interface.internal_monologue(response_message.content) - messages.append(response_message) # extend conversation with assistant's reply - - # Step 3: call the function - # Note: the JSON response may not always be valid; be sure to handle errors - - # Failure case 1: function name is wrong - function_name = response_message["function_call"]["name"] - try: - function_to_call = self.available_functions[function_name] - except KeyError as e: - error_msg = f"No function named {function_name}" - function_response = package_function_response(False, error_msg) - messages.append( - { - "role": "function", - "name": function_name, - "content": function_response, - } - ) # extend conversation with function response - await self.interface.function_message(f"Error: {error_msg}") - return messages, None, True # force a heartbeat to allow agent to handle error - - # Failure case 2: function name is OK, but function args are bad JSON - try: - raw_function_args = response_message["function_call"]["arguments"] - function_args = parse_json(raw_function_args) - except Exception as e: - error_msg = f"Error parsing JSON for function '{function_name}' arguments: {raw_function_args}" - function_response = package_function_response(False, error_msg) - messages.append( - { - "role": "function", - "name": function_name, - "content": function_response, - } - ) # extend conversation with function response - await self.interface.function_message(f"Error: {error_msg}") - return messages, None, True # force a heartbeat to allow agent to handle error - - # (Still parsing function args) - # Handle requests for immediate heartbeat - heartbeat_request = function_args.pop("request_heartbeat", None) - if not (isinstance(heartbeat_request, bool) or heartbeat_request is None): - printd( - f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}" - ) - heartbeat_request = None - - # Failure case 3: function failed during execution - await self.interface.function_message(f"Running {function_name}({function_args})") - try: - function_response_string = await call_function(function_to_call, **function_args) - function_response = package_function_response(True, function_response_string) - function_failed = False - except Exception as e: - error_msg = f"Error calling function {function_name} with args {function_args}: {str(e)}" - error_msg_user = f"{error_msg}\n{traceback.format_exc()}" - printd(error_msg_user) - function_response = package_function_response(False, error_msg) - messages.append( - { - "role": "function", - "name": function_name, - "content": function_response, - } - ) # extend conversation with function response - await self.interface.function_message(f"Error: {error_msg}") - return messages, None, True # force a heartbeat to allow agent to handle error - - # If no failures happened along the way: ... - # Step 4: send the info on the function call and function response to GPT - if function_response_string: - await self.interface.function_message(f"Success: {function_response_string}") - else: - await self.interface.function_message(f"Success") - messages.append( - { - "role": "function", - "name": function_name, - "content": function_response, - } - ) # extend conversation with function response - - else: - # Standard non-function reply - await self.interface.internal_monologue(response_message.content) - messages.append(response_message) # extend conversation with assistant's reply - heartbeat_request = None - function_failed = None - - return messages, heartbeat_request, function_failed - - async def step(self, user_message, first_message=False, first_message_retry_limit=FIRST_MESSAGE_ATTEMPTS, skip_verify=False): - """Top-level event message handler for the MemGPT agent""" - - try: - # Step 0: add user message - if user_message is not None: - await self.interface.user_message(user_message) - packed_user_message = {"role": "user", "content": user_message} - input_message_sequence = self.messages + [packed_user_message] - else: - input_message_sequence = self.messages - - if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user": - printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue") - from pprint import pprint - - pprint(input_message_sequence[-1]) - - # Step 1: send the conversation and available functions to GPT - if not skip_verify and (first_message or self.messages_total == self.messages_total_init): - printd(f"This is the first message. Running extra verifier on AI response.") - counter = 0 - while True: - response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions) - if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono): - break - - counter += 1 - if counter > first_message_retry_limit: - raise Exception(f"Hit first message retry limit ({first_message_retry_limit})") - - else: - response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions) - - # Step 2: check if LLM wanted to call a function - # (if yes) Step 3: call the function - # (if yes) Step 4: send the info on the function call and function response to LLM - response_message = response.choices[0].message - response_message_copy = response_message.copy() - all_response_messages, heartbeat_request, function_failed = await self.handle_ai_response(response_message) - - # Add the extra metadata to the assistant response - # (e.g. enough metadata to enable recreating the API call) - assert "api_response" not in all_response_messages[0], f"api_response already in {all_response_messages[0]}" - all_response_messages[0]["api_response"] = response_message_copy - assert "api_args" not in all_response_messages[0], f"api_args already in {all_response_messages[0]}" - all_response_messages[0]["api_args"] = { - "model": self.model, - "messages": input_message_sequence, - "functions": self.functions, - } - - # Step 4: extend the message history - if user_message is not None: - all_new_messages = [packed_user_message] + all_response_messages - else: - all_new_messages = all_response_messages - - # Check the memory pressure and potentially issue a memory pressure warning - current_total_tokens = response["usage"]["total_tokens"] - active_memory_warning = False - if current_total_tokens > MESSAGE_SUMMARY_WARNING_TOKENS: - printd(f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_TOKENS}") - # Only deliver the alert if we haven't already (this period) - if not self.agent_alerted_about_memory_pressure: - active_memory_warning = True - self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this - else: - printd(f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_TOKENS}") - - self.append_to_messages(all_new_messages) - return all_new_messages, heartbeat_request, function_failed, active_memory_warning - - except Exception as e: - printd(f"step() failed\nuser_message = {user_message}\nerror = {e}") - print(f"step() failed\nuser_message = {user_message}\nerror = {e}") - - # If we got a context alert, try trimming the messages length, then try again - if "maximum context length" in str(e): - # A separate API call to run a summarizer - await self.summarize_messages_inplace() - - # Try step again - return await self.step(user_message, first_message=first_message) - else: - printd(f"step() failed with openai.InvalidRequestError, but didn't recognize the error message: '{str(e)}'") - print(e) - raise e - - async def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True): - assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})" - - # Start at index 1 (past the system message), - # and collect messages for summarization until we reach the desired truncation token fraction (eg 50%) - # Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling - token_counts = [count_tokens(str(msg)) for msg in self.messages] - message_buffer_token_count = sum(token_counts[1:]) # no system message - token_counts = token_counts[1:] - desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC) - candidate_messages_to_summarize = self.messages[1:] - if preserve_last_N_messages: - candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST] - token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST] - printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}") - printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}") - printd(f"token_counts={token_counts}") - printd(f"message_buffer_token_count={message_buffer_token_count}") - printd(f"desired_token_count_to_summarize={desired_token_count_to_summarize}") - printd(f"len(candidate_messages_to_summarize)={len(candidate_messages_to_summarize)}") - - # If at this point there's nothing to summarize, throw an error - if len(candidate_messages_to_summarize) == 0: - raise LLMError( - f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, preserve_N={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}]" - ) - - # Walk down the message buffer (front-to-back) until we hit the target token count - tokens_so_far = 0 - cutoff = 0 - for i, msg in enumerate(candidate_messages_to_summarize): - cutoff = i - tokens_so_far += token_counts[i] - if tokens_so_far > desired_token_count_to_summarize: - break - # Account for system message - cutoff += 1 - - # Try to make an assistant message come after the cutoff - try: - printd(f"Selected cutoff {cutoff} was a 'user', shifting one...") - if self.messages[cutoff]["role"] == "user": - new_cutoff = cutoff + 1 - if self.messages[new_cutoff]["role"] == "user": - printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...") - cutoff = new_cutoff - except IndexError: - pass - - message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message - if len(message_sequence_to_summarize) == 0: - printd(f"message_sequence_to_summarize is len 0, skipping summarize") - raise LLMError( - f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, cutoff={cutoff}]" - ) - - printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}") - summary = await a_summarize_messages(self.model, message_sequence_to_summarize) - printd(f"Got summary: {summary}") - - # Metadata that's useful for the agent to see - all_time_message_count = self.messages_total - remaining_message_count = len(self.messages[cutoff:]) - hidden_message_count = all_time_message_count - remaining_message_count - summary_message_count = len(message_sequence_to_summarize) - summary_message = package_summarize_message(summary, summary_message_count, hidden_message_count, all_time_message_count) - printd(f"Packaged into message: {summary_message}") - - prior_len = len(self.messages) - self.trim_messages(cutoff) - packed_summary_message = {"role": "user", "content": summary_message} - self.prepend_to_messages([packed_summary_message]) - - # reset alert - self.agent_alerted_about_memory_pressure = False - - printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}") - - async def free_step(self, user_message, limit=None): - """Allow agent to manage its own control flow (past a single LLM call). - Not currently used, instead this is handled in the CLI main.py logic - """ - - new_messages, heartbeat_request, function_failed = self.step(user_message) - step_count = 1 - - while limit is None or step_count < limit: - if function_failed: - user_message = get_heartbeat("Function call failed") - new_messages, heartbeat_request, function_failed = await self.step(user_message) - step_count += 1 - elif heartbeat_request: - user_message = get_heartbeat("AI requested") - new_messages, heartbeat_request, function_failed = await self.step(user_message) - step_count += 1 - else: - break - - return new_messages, heartbeat_request, function_failed - - ### Functions / tools the agent can use - # All functions should return a response string (or None) - # If the function fails, throw an exception - - async def send_ai_message(self, message): - """AI wanted to send a message""" - await self.interface.assistant_message(message) - return None - - async def recall_memory_search(self, query, count=5, page=0): - results, total = await self.persistence_manager.recall_memory.a_text_search(query, count=count, start=page * count) - num_pages = math.ceil(total / count) - 1 # 0 index - if len(results) == 0: - results_str = f"No results found." - else: - results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" - results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results] - results_str = f"{results_pref} {json.dumps(results_formatted)}" - return results_str - - async def recall_memory_search_date(self, start_date, end_date, count=5, page=0): - results, total = await self.persistence_manager.recall_memory.a_date_search(start_date, end_date, count=count, start=page * count) - num_pages = math.ceil(total / count) - 1 # 0 index - if len(results) == 0: - results_str = f"No results found." - else: - results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" - results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results] - results_str = f"{results_pref} {json.dumps(results_formatted)}" - return results_str - - async def archival_memory_insert(self, content): - await self.persistence_manager.archival_memory.a_insert(content) - return None - - async def archival_memory_search(self, query, count=5, page=0): - results, total = await self.persistence_manager.archival_memory.a_search(query, count=count, start=page * count) - num_pages = math.ceil(total / count) - 1 # 0 index - if len(results) == 0: - results_str = f"No results found." - else: - results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" - results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results] - results_str = f"{results_pref} {json.dumps(results_formatted)}" - return results_str - - async def message_chatgpt(self, message): - """Base call to GPT API w/ functions""" - - message_sequence = [ - {"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE}, - {"role": "user", "content": str(message)}, - ] - response = await acreate( - model=MESSAGE_CHATGPT_FUNCTION_MODEL, - messages=message_sequence, - # functions=functions, - # function_call=function_call, - ) - - reply = response.choices[0].message.content - return reply diff --git a/memgpt/agent_base.py b/memgpt/agent_base.py deleted file mode 100644 index 7f132e49c2..0000000000 --- a/memgpt/agent_base.py +++ /dev/null @@ -1,7 +0,0 @@ -from abc import ABC, abstractmethod - - -class AgentAsyncBase(ABC): - @abstractmethod - async def step(self, user_message): - pass diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py index 27125db0a3..c0bf57cd0a 100644 --- a/memgpt/autogen/memgpt_agent.py +++ b/memgpt/autogen/memgpt_agent.py @@ -46,7 +46,7 @@ def create_memgpt_autogen_agent_from_config( autogen_memgpt_agent = create_autogen_memgpt_agent( name, - preset=presets.SYNC_CHAT, + preset=presets.DEFAULT_PRESET, model=model, persona_description=persona_desc, user_description=user_desc, @@ -57,7 +57,7 @@ def create_memgpt_autogen_agent_from_config( if human_input_mode != "ALWAYS": coop_agent1 = create_autogen_memgpt_agent( name, - preset=presets.SYNC_CHAT, + preset=presets.DEFAULT_PRESET, model=model, persona_description=persona_desc, user_description=user_desc, @@ -73,7 +73,7 @@ def create_memgpt_autogen_agent_from_config( else: coop_agent2 = create_autogen_memgpt_agent( name, - preset=presets.SYNC_CHAT, + preset=presets.DEFAULT_PRESET, model=model, persona_description=persona_desc, user_description=user_desc, @@ -95,7 +95,7 @@ def create_memgpt_autogen_agent_from_config( def create_autogen_memgpt_agent( autogen_name, - preset=presets.SYNC_CHAT, + preset=presets.DEFAULT_PRESET, model=constants.DEFAULT_MEMGPT_MODEL, persona_description=personas.DEFAULT, user_description=humans.DEFAULT, @@ -126,7 +126,7 @@ def create_autogen_memgpt_agent( persona=persona_description, human=user_description, model=model, - preset=presets.SYNC_CHAT, + preset=presets.DEFAULT_PRESET, ) interface = AutoGenInterface(**interface_kwargs) if interface is None else interface diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 0fe73a579d..9af142cde8 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -2,7 +2,6 @@ import sys import io import logging -import asyncio import os from prettytable import PrettyTable import questionary @@ -24,7 +23,7 @@ from memgpt.persistence_manager import LocalStateManager from memgpt.config import MemGPTConfig, AgentConfig from memgpt.constants import MEMGPT_DIR -from memgpt.agent import AgentAsync +from memgpt.agent import Agent from memgpt.embeddings import embedding_model from memgpt.openai_tools import ( configure_azure_support, @@ -121,7 +120,7 @@ def run( agent_config.save() # load existing agent - memgpt_agent = AgentAsync.load_agent(memgpt.interface, agent_config) + memgpt_agent = Agent.load_agent(memgpt.interface, agent_config) else: # create new agent # create new agent config: override defaults with args if provided typer.secho("Creating new agent...", fg=typer.colors.GREEN) @@ -162,8 +161,7 @@ def run( if config.model_endpoint == "azure": configure_azure_support() - loop = asyncio.get_event_loop() - loop.run_until_complete(run_agent_loop(memgpt_agent, first, no_verify, config)) # TODO: add back no_verify + run_agent_loop(memgpt_agent, first, no_verify, config) # TODO: add back no_verify def attach( diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index c68dbbb416..de889aa04f 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -30,7 +30,7 @@ def configure(): config = MemGPTConfig.load() # openai credentials - use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?", default=True).ask() + use_openai = questionary.confirm("Do you want to enable MemGPT with OpenAI?", default=True).ask() if use_openai: # search for key in enviornment openai_key = os.getenv("OPENAI_API_KEY") @@ -119,10 +119,10 @@ def configure(): # defaults personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()] - print(personas) + # print(personas) default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask() humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()] - print(humans) + # print(humans) default_human = questionary.select("Select default human:", humans, default=config.default_human).ask() # TODO: figure out if we should set a default agent or not diff --git a/memgpt/config.py b/memgpt/config.py index 4dfe1640aa..517b54b130 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -203,7 +203,7 @@ def save(self): # archival storage config.add_section("archival_storage") - print("archival storage", self.archival_storage_type) + # print("archival storage", self.archival_storage_type) config.set("archival_storage", "type", self.archival_storage_type) if self.archival_storage_path: config.set("archival_storage", "path", self.archival_storage_path) @@ -350,7 +350,7 @@ def __init__(self): self.preload_archival = False @classmethod - async def legacy_flags_init( + def legacy_flags_init( cls: Type["Config"], model: str, memgpt_persona: str, @@ -372,11 +372,11 @@ async def legacy_flags_init( if self.archival_storage_index: recompute_embeddings = False # TODO Legacy support -- can't recompute embeddings on a path that's not specified. if self.archival_storage_files: - await self.configure_archival_storage(recompute_embeddings) + self.configure_archival_storage(recompute_embeddings) return self @classmethod - async def config_init(cls: Type["Config"], config_file: str = None): + def config_init(cls: Type["Config"], config_file: str = None): self = cls() self.config_file = config_file if self.config_file is None: @@ -384,7 +384,7 @@ async def config_init(cls: Type["Config"], config_file: str = None): use_cfg = False if cfg: print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}") - use_cfg = await questionary.confirm(f"Use most recent config file '{cfg}'?").ask_async() + use_cfg = questionary.confirm(f"Use most recent config file '{cfg}'?").ask() if use_cfg: self.config_file = cfg @@ -393,74 +393,74 @@ async def config_init(cls: Type["Config"], config_file: str = None): recompute_embeddings = False if self.compute_embeddings: if self.archival_storage_index: - recompute_embeddings = await questionary.confirm( + recompute_embeddings = questionary.confirm( f"Would you like to recompute embeddings? Do this if your files have changed.\n Files: {self.archival_storage_files}", default=False, - ).ask_async() + ).ask() else: recompute_embeddings = True if self.load_type: - await self.configure_archival_storage(recompute_embeddings) + self.configure_archival_storage(recompute_embeddings) self.write_config() return self # print("No settings file found, configuring MemGPT...") print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}") - self.model = await questionary.select( + self.model = questionary.select( "Which model would you like to use?", model_choices, default=model_choices[0], - ).ask_async() + ).ask() - self.memgpt_persona = await questionary.select( + self.memgpt_persona = questionary.select( "Which persona would you like MemGPT to use?", Config.get_memgpt_personas(), - ).ask_async() + ).ask() print(self.memgpt_persona) - self.human_persona = await questionary.select( + self.human_persona = questionary.select( "Which user would you like to use?", Config.get_user_personas(), - ).ask_async() + ).ask() self.archival_storage_index = None - self.preload_archival = await questionary.confirm( + self.preload_archival = questionary.confirm( "Would you like to preload anything into MemGPT's archival memory?", default=False - ).ask_async() + ).ask() if self.preload_archival: - self.load_type = await questionary.select( + self.load_type = questionary.select( "What would you like to load?", choices=[ questionary.Choice("A folder or file", value="folder"), questionary.Choice("A SQL database", value="sql"), questionary.Choice("A glob pattern", value="glob"), ], - ).ask_async() + ).ask() if self.load_type == "folder" or self.load_type == "sql": - archival_storage_path = await questionary.path("Please enter the folder or file (tab for autocomplete):").ask_async() + archival_storage_path = questionary.path("Please enter the folder or file (tab for autocomplete):").ask() if os.path.isdir(archival_storage_path): self.archival_storage_files = os.path.join(archival_storage_path, "*") else: self.archival_storage_files = archival_storage_path else: - self.archival_storage_files = await questionary.path("Please enter the glob pattern (tab for autocomplete):").ask_async() - self.compute_embeddings = await questionary.confirm( + self.archival_storage_files = questionary.path("Please enter the glob pattern (tab for autocomplete):").ask() + self.compute_embeddings = questionary.confirm( "Would you like to compute embeddings over these files to enable embeddings search?" - ).ask_async() - await self.configure_archival_storage(self.compute_embeddings) + ).ask() + self.configure_archival_storage(self.compute_embeddings) self.write_config() return self - async def configure_archival_storage(self, recompute_embeddings): + def configure_archival_storage(self, recompute_embeddings): if recompute_embeddings: if self.host: interface.warning_message( "⛔️ Embeddings on a non-OpenAI endpoint are not yet supported, falling back to substring matching search." ) else: - self.archival_storage_index = await utils.prepare_archival_index_from_files_compute_embeddings(self.archival_storage_files) + self.archival_storage_index = utils.prepare_archival_index_from_files_compute_embeddings(self.archival_storage_files) if self.compute_embeddings and self.archival_storage_index: self.index, self.archival_database = utils.prepare_archival_index(self.archival_storage_index) else: diff --git a/memgpt/interface.py b/memgpt/interface.py index efaefe7fa6..c78f7d736e 100644 --- a/memgpt/interface.py +++ b/memgpt/interface.py @@ -28,7 +28,7 @@ def warning_message(msg): print(fstr.format(msg=msg)) -async def internal_monologue(msg): +def internal_monologue(msg): # ANSI escape code for italic is '\x1B[3m' fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {{msg}}{Style.RESET_ALL}" if STRIP_UI: @@ -36,28 +36,28 @@ async def internal_monologue(msg): print(fstr.format(msg=msg)) -async def assistant_message(msg): +def assistant_message(msg): fstr = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{{msg}}{Style.RESET_ALL}" if STRIP_UI: fstr = "{msg}" print(fstr.format(msg=msg)) -async def memory_message(msg): +def memory_message(msg): fstr = f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{{msg}}{Style.RESET_ALL}" if STRIP_UI: fstr = "{msg}" print(fstr.format(msg=msg)) -async def system_message(msg): +def system_message(msg): fstr = f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}" if STRIP_UI: fstr = "{msg}" print(fstr.format(msg=msg)) -async def user_message(msg, raw=False, dump=False, debug=DEBUG): +def user_message(msg, raw=False, dump=False, debug=DEBUG): def print_user_message(icon, msg, printf=print): if STRIP_UI: printf(f"{icon} {msg}") @@ -103,7 +103,7 @@ def printd_user_message(icon, msg): printd_user_message("🧑", msg_json) -async def function_message(msg, debug=DEBUG): +def function_message(msg, debug=DEBUG): def print_function_message(icon, msg, color=Fore.RED, printf=print): if STRIP_UI: printf(f"⚡{icon} [function] {msg}") @@ -171,7 +171,7 @@ def printd_function_message(icon, msg, color=Fore.RED): printd_function_message("", msg) -async def print_messages(message_sequence, dump=False): +def print_messages(message_sequence, dump=False): idx = len(message_sequence) for msg in message_sequence: if dump: @@ -181,42 +181,42 @@ async def print_messages(message_sequence, dump=False): content = msg["content"] if role == "system": - await system_message(content) + system_message(content) elif role == "assistant": # Differentiate between internal monologue, function calls, and messages if msg.get("function_call"): if content is not None: - await internal_monologue(content) + internal_monologue(content) # I think the next one is not up to date - # await function_message(msg["function_call"]) + # function_message(msg["function_call"]) args = json.loads(msg["function_call"].get("arguments")) - await assistant_message(args.get("message")) + assistant_message(args.get("message")) # assistant_message(content) else: - await internal_monologue(content) + internal_monologue(content) elif role == "user": - await user_message(content, dump=dump) + user_message(content, dump=dump) elif role == "function": - await function_message(content, debug=dump) + function_message(content, debug=dump) else: print(f"Unknown role: {content}") -async def print_messages_simple(message_sequence): +def print_messages_simple(message_sequence): for msg in message_sequence: role = msg["role"] content = msg["content"] if role == "system": - await system_message(content) + system_message(content) elif role == "assistant": - await assistant_message(content) + assistant_message(content) elif role == "user": - await user_message(content, raw=True) + user_message(content, raw=True) else: print(f"Unknown role: {content}") -async def print_messages_raw(message_sequence): +def print_messages_raw(message_sequence): for msg in message_sequence: print(msg) diff --git a/memgpt/main.py b/memgpt/main.py index a2a7691754..25e209e6ae 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -1,4 +1,3 @@ -import asyncio import shutil import configparser import uuid @@ -38,14 +37,13 @@ from memgpt.cli.cli_load import app as load_app from memgpt.config import Config, MemGPTConfig, AgentConfig from memgpt.constants import MEMGPT_DIR -from memgpt.agent import AgentAsync +from memgpt.agent import Agent from memgpt.openai_tools import ( configure_azure_support, check_azure_embeddings, get_set_azure_env_vars, ) from memgpt.connectors.storage import StorageConnector -import asyncio app = typer.Typer(pretty_exceptions_enable=False) app.command(name="run")(run) @@ -180,26 +178,23 @@ def legacy_run( if not questionary.confirm("Continue with legacy CLI?", default=False).ask(): return - loop = asyncio.get_event_loop() - loop.run_until_complete( - main( - persona, - human, - model, - first, - debug, - no_verify, - archival_storage_faiss_path, - archival_storage_files, - archival_storage_files_compute_embeddings, - archival_storage_sqldb, - use_azure_openai, - strip_ui, - ) + main( + persona, + human, + model, + first, + debug, + no_verify, + archival_storage_faiss_path, + archival_storage_files, + archival_storage_files_compute_embeddings, + archival_storage_sqldb, + use_azure_openai, + strip_ui, ) -async def main( +def main( persona, human, model, @@ -271,7 +266,7 @@ async def main( print(persona, model, memgpt_persona) if archival_storage_files: - cfg = await Config.legacy_flags_init( + cfg = Config.legacy_flags_init( model, memgpt_persona, human_persona, @@ -280,7 +275,7 @@ async def main( compute_embeddings=False, ) elif archival_storage_faiss_path: - cfg = await Config.legacy_flags_init( + cfg = Config.legacy_flags_init( model, memgpt_persona, human_persona, @@ -293,7 +288,7 @@ async def main( print(model) print(memgpt_persona) print(human_persona) - cfg = await Config.legacy_flags_init( + cfg = Config.legacy_flags_init( model, memgpt_persona, human_persona, @@ -302,7 +297,7 @@ async def main( compute_embeddings=True, ) elif archival_storage_sqldb: - cfg = await Config.legacy_flags_init( + cfg = Config.legacy_flags_init( model, memgpt_persona, human_persona, @@ -311,13 +306,13 @@ async def main( compute_embeddings=False, ) else: - cfg = await Config.legacy_flags_init( + cfg = Config.legacy_flags_init( model, memgpt_persona, human_persona, ) else: - cfg = await Config.config_init() + cfg = Config.config_init() memgpt.interface.important_message("Running... [exit by typing '/exit', list available commands with '/help']") if cfg.model != constants.DEFAULT_MEMGPT_MODEL: @@ -352,7 +347,7 @@ async def main( persistence_manager, ) print_messages = memgpt.interface.print_messages - await print_messages(memgpt_agent.messages) + print_messages(memgpt_agent.messages) if cfg.load_type == "sql": # TODO: move this into config.py in a clean manner if not os.path.exists(cfg.archival_storage_files): @@ -364,19 +359,19 @@ async def main( data_list = utils.read_database_as_list(cfg.archival_storage_files) user_message = f"Your archival memory has been loaded with a SQL database called {data_list[0]}, which contains schema {data_list[1]}. Remember to refer to this first while answering any user questions!" for row in data_list: - await memgpt_agent.persistence_manager.archival_memory.insert(row) + memgpt_agent.persistence_manager.archival_memory.insert(row) print(f"Database loaded into archival memory.") if cfg.agent_save_file: - load_save_file = await questionary.confirm(f"Load in saved agent '{cfg.agent_save_file}'?").ask_async() + load_save_file = questionary.confirm(f"Load in saved agent '{cfg.agent_save_file}'?").ask() if load_save_file: load(memgpt_agent, cfg.agent_save_file) # run agent loop - await run_agent_loop(memgpt_agent, first, no_verify, cfg, strip_ui, legacy=True) + run_agent_loop(memgpt_agent, first, no_verify, cfg, strip_ui, legacy=True) -async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=False, legacy=False): +def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=False, legacy=False): counter = 0 user_input = None skip_next_user_input = False @@ -392,11 +387,11 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u while True: if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST): # Ask for user input - user_input = await questionary.text( + user_input = questionary.text( "Enter your message:", multiline=multiline_input, qmark=">", - ).ask_async() + ).ask() clear_line(strip_ui) # Gracefully exit on Ctrl-C/D @@ -462,7 +457,7 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u # TODO: check if agent already has it data_source_options = StorageConnector.list_loaded_data() - data_source = await questionary.select("Select data source", choices=data_source_options).ask_async() + data_source = questionary.select("Select data source", choices=data_source_options).ask() # attach new data attach(memgpt_agent.config.name, data_source) @@ -482,13 +477,13 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u command = user_input.strip().split() amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0 if amount == 0: - await memgpt.interface.print_messages(memgpt_agent.messages, dump=True) + memgpt.interface.print_messages(memgpt_agent.messages, dump=True) else: - await memgpt.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) + memgpt.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) continue elif user_input.lower() == "/dumpraw": - await memgpt.interface.print_messages_raw(memgpt_agent.messages) + memgpt.interface.print_messages_raw(memgpt_agent.messages) continue elif user_input.lower() == "/memory": @@ -554,7 +549,7 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u # No skip options elif user_input.lower() == "/wipe": - memgpt_agent = agent.AgentAsync(memgpt.interface) + memgpt_agent = agent.Agent(memgpt.interface) user_message = None elif user_input.lower() == "/heartbeat": @@ -585,8 +580,8 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u skip_next_user_input = False - async def process_agent_step(user_message, no_verify): - new_messages, heartbeat_request, function_failed, token_warning = await memgpt_agent.step( + def process_agent_step(user_message, no_verify): + new_messages, heartbeat_request, function_failed, token_warning = memgpt_agent.step( user_message, first_message=False, skip_verify=no_verify ) @@ -606,16 +601,16 @@ async def process_agent_step(user_message, no_verify): while True: try: if strip_ui: - new_messages, user_message, skip_next_user_input = await process_agent_step(user_message, no_verify) + new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) break else: with console.status("[bold cyan]Thinking...") as status: - new_messages, user_message, skip_next_user_input = await process_agent_step(user_message, no_verify) + new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) break except Exception as e: print("An exception ocurred when running agent.step(): ") traceback.print_exc() - retry = await questionary.confirm("Retry agent.step()?").ask_async() + retry = questionary.confirm("Retry agent.step()?").ask() if not retry: break @@ -639,13 +634,3 @@ async def process_agent_step(user_message, no_verify): ("/memorywarning", "send a memory warning system message to the agent"), ("/attach", "attach data source to agent"), ] -# if __name__ == "__main__": -# -# app() -# #typer.run(run) -# -# #def run(argv): -# # loop = asyncio.get_event_loop() -# # loop.run_until_complete(main()) -# -# #app.run(run) diff --git a/memgpt/memory.py b/memgpt/memory.py index 83cf3a6e10..2cae46d847 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -9,8 +9,6 @@ from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM from memgpt import utils from .openai_tools import ( - acompletions_with_backoff as acreate, - async_get_embedding_with_backoff, get_embedding_with_backoff, completions_with_backoff as create, ) @@ -148,36 +146,6 @@ def summarize_messages( return reply -async def a_summarize_messages( - model, - message_sequence_to_summarize, -): - """Summarize a message sequence using GPT""" - - summary_prompt = SUMMARY_PROMPT_SYSTEM - summary_input = str(message_sequence_to_summarize) - summary_input_tkns = count_tokens(summary_input) - if summary_input_tkns > MESSAGE_SUMMARY_WARNING_TOKENS: - trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure... - cutoff = int(len(message_sequence_to_summarize) * trunc_ratio) - summary_input = str( - [await a_summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:] - ) - message_sequence = [ - {"role": "system", "content": summary_prompt}, - {"role": "user", "content": summary_input}, - ] - - response = await acreate( - model=model, - messages=message_sequence, - ) - - printd(f"summarize_messages gpt reply: {response.choices[0]}") - reply = response.choices[0].message.content - return reply - - class ArchivalMemory(ABC): @abstractmethod def insert(self, memory_string): @@ -238,9 +206,6 @@ def insert(self, memory_string): } ) - async def a_insert(self, memory_string): - return self.insert(memory_string) - def search(self, query_string, count=None, start=None): """Simple text-based search""" # in the dummy version, run an (inefficient) case-insensitive match search @@ -261,9 +226,6 @@ def search(self, query_string, count=None, start=None): else: return matches, len(matches) - async def a_search(self, query_string, count=None, start=None): - return self.search(query_string, count=None, start=None) - class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory): """Same as dummy in-memory archival memory, but with bare-bones embedding support""" @@ -293,13 +255,10 @@ def insert(self, memory_string): embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model) return self._insert(memory_string, embedding) - async def a_insert(self, memory_string): - embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model) - return self._insert(memory_string, embedding) - - def _search(self, query_embedding, query_string, count, start): + def search(self, query_string, count, start): """Simple embedding-based search (inefficient, no caching)""" # see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb + query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model) # query_embedding = get_embedding(query_string, model=self.embedding_model) # our wrapped version supports backoff/rate-limits @@ -328,14 +287,6 @@ def _search(self, query_embedding, query_string, count, start): else: return matches, len(matches) - def search(self, query_string, count=None, start=None): - query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model) - return self._search(self, query_embedding, query_string, count, start) - - async def a_search(self, query_string, count=None, start=None): - query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model) - return await self._search(self, query_embedding, query_string, count, start) - class DummyArchivalMemoryWithFaiss(DummyArchivalMemory): """Dummy in-memory version of an archival memory database, using a FAISS @@ -365,9 +316,12 @@ def __init__(self, index=None, archival_memory_database=None, embedding_model="t def __len__(self): return len(self._archive) - def _insert(self, memory_string, embedding): + def insert(self, memory_string): import numpy as np + # Get the embedding + embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model) + print(f"Got an embedding, type {type(embedding)}, len {len(embedding)}") self._archive.append( @@ -380,17 +334,7 @@ def _insert(self, memory_string, embedding): embedding = np.array([embedding]).astype("float32") self.index.add(embedding) - def insert(self, memory_string): - # Get the embedding - embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model) - return self._insert(memory_string, embedding) - - async def a_insert(self, memory_string): - # Get the embedding - embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model) - return self._insert(memory_string, embedding) - - def _search(self, query_embedding, query_string, count=None, start=None): + def search(self, query_string, count=None, start=None): """Simple embedding-based search (inefficient, no caching)""" # see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb @@ -401,6 +345,7 @@ def _search(self, query_embedding, query_string, count=None, start=None): if query_string in self.embeddings_dict: search_result = self.search_results[query_string] else: + query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model) _, indices = self.index.search(np.array([np.array(query_embedding, dtype=np.float32)]), self.k) search_result = [self._archive[idx] if idx < len(self._archive) else "" for idx in indices[0]] self.embeddings_dict[query_string] = query_embedding @@ -430,38 +375,16 @@ def _search(self, query_embedding, query_string, count=None, start=None): else: return matches, len(matches) - def search(self, query_string, count=None, start=None): - if query_string in self.embeddings_dict: - query_embedding = self.embeddings_dict[query_string] - else: - query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model) - return self._search(query_embedding, query_string, count, start) - - async def a_search(self, query_string, count=None, start=None): - if query_string in self.embeddings_dict: - query_embedding = self.embeddings_dict[query_string] - else: - query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model) - return self._search(query_embedding, query_string, count, start) - class RecallMemory(ABC): @abstractmethod def text_search(self, query_string, count=None, start=None): pass - @abstractmethod - async def a_text_search(self, query_string, count=None, start=None): - pass - @abstractmethod def date_search(self, query_string, count=None, start=None): pass - @abstractmethod - async def a_date_search(self, query_string, count=None, start=None): - pass - @abstractmethod def __repr__(self) -> str: pass @@ -513,7 +436,7 @@ def __repr__(self) -> str: ) return f"\n### RECALL MEMORY ###" + f"\n{memory_str}" - async def insert(self, message): + def insert(self, message): raise NotImplementedError("This should be handled by the PersistenceManager, recall memory is just a search layer on top") def text_search(self, query_string, count=None, start=None): @@ -538,9 +461,6 @@ def text_search(self, query_string, count=None, start=None): else: return matches, len(matches) - async def a_text_search(self, query_string, count=None, start=None): - return self.text_search(query_string, count, start) - def _validate_date_format(self, date_str): """Validate the given date string in the format 'YYYY-MM-DD'.""" try: @@ -583,9 +503,6 @@ def date_search(self, start_date, end_date, count=None, start=None): else: return matches, len(matches) - async def a_date_search(self, start_date, end_date, count=None, start=None): - return self.date_search(start_date, end_date, count, start) - class DummyRecallMemoryWithEmbeddings(DummyRecallMemory): """Lazily manage embeddings by keeping a string->embed dict""" @@ -641,9 +558,6 @@ def text_search(self, query_string, count, start): else: return matches, len(matches) - async def a_text_search(self, query_string, count=None, start=None): - return self.text_search(query_string, count, start) - class LocalArchivalMemory(ArchivalMemory): """Archival memory built on top of Llama Index""" @@ -707,9 +621,6 @@ def insert(self, memory_string): similarity_top_k=self.top_k, ) - async def a_insert(self, memory_string): - return self.insert(memory_string) - def search(self, query_string, count=None, start=None): print("searching with local") if self.retriever is None: @@ -729,9 +640,6 @@ def search(self, query_string, count=None, start=None): # pprint(results) return results, len(results) - async def a_search(self, query_string, count=None, start=None): - return self.search(query_string, count, start) - def __repr__(self) -> str: if isinstance(self.index, EmptyIndex): memory_str = "" @@ -809,12 +717,6 @@ def search(self, query_string, count=None, start=None): print("Archival search error", e) raise e - async def a_search(self, query_string, count=None, start=None): - return self.search(query_string, count, start) - - async def a_insert(self, memory_string): - return self.insert(memory_string) - def __repr__(self) -> str: limit = 10 passages = [] diff --git a/memgpt/openai_tools.py b/memgpt/openai_tools.py index 29f789ac14..60eeb331a4 100644 --- a/memgpt/openai_tools.py +++ b/memgpt/openai_tools.py @@ -1,4 +1,3 @@ -import asyncio import random import os import time @@ -74,89 +73,6 @@ def completions_with_backoff(**kwargs): return openai.ChatCompletion.create(**kwargs) -def aretry_with_exponential_backoff( - func, - initial_delay: float = 1, - exponential_base: float = 2, - jitter: bool = True, - max_retries: int = 20, - errors: tuple = (openai.error.RateLimitError,), -): - """Retry a function with exponential backoff.""" - - async def wrapper(*args, **kwargs): - # Initialize variables - num_retries = 0 - delay = initial_delay - - # Loop until a successful response or max_retries is hit or an exception is raised - while True: - try: - return await func(*args, **kwargs) - - # Retry on specified errors - except errors as e: - print(f"acreate (backoff): caught error: {e}") - # Increment retries - num_retries += 1 - - # Check if max retries has been reached - if num_retries > max_retries: - raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") - - # Increment the delay - delay *= exponential_base * (1 + jitter * random.random()) - - # Sleep for the delay - await asyncio.sleep(delay) - - # Raise exceptions for any errors not specified - except Exception as e: - raise e - - return wrapper - - -@aretry_with_exponential_backoff -async def acompletions_with_backoff(**kwargs): - # Local model - if HOST_TYPE is not None: - return get_chat_completion(**kwargs) - - # OpenAI / Azure model - else: - if using_azure(): - azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") - if azure_openai_deployment is not None: - kwargs["deployment_id"] = azure_openai_deployment - else: - kwargs["engine"] = MODEL_TO_AZURE_ENGINE[kwargs["model"]] - kwargs.pop("model") - return await openai.ChatCompletion.acreate(**kwargs) - - -@aretry_with_exponential_backoff -async def acreate_embedding_with_backoff(**kwargs): - """Wrapper around Embedding.acreate w/ backoff""" - if using_azure(): - azure_openai_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT") - if azure_openai_deployment is not None: - kwargs["deployment_id"] = azure_openai_deployment - else: - kwargs["engine"] = kwargs["model"] - kwargs.pop("model") - return await openai.Embedding.acreate(**kwargs) - - -async def async_get_embedding_with_backoff(text, model="text-embedding-ada-002"): - """To get text embeddings, import/call this function - It specifies defaults + handles rate-limiting + is async""" - text = text.replace("\n", " ") - response = await acreate_embedding_with_backoff(input=[text], model=model) - embedding = response["data"][0]["embedding"] - return embedding - - @retry_with_exponential_backoff def create_embedding_with_backoff(**kwargs): if using_azure(): diff --git a/memgpt/presets.py b/memgpt/presets.py index bbd96ee6f7..b8fb55b64c 100644 --- a/memgpt/presets.py +++ b/memgpt/presets.py @@ -4,13 +4,11 @@ DEFAULT_PRESET = "memgpt_chat" preset_options = [DEFAULT_PRESET] -SYNC_CHAT = "memgpt_chat_sync" # TODO: remove me after we move the CLI to AgentSync - def use_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager): """Storing combinations of SYSTEM + FUNCTION prompts""" - from memgpt.agent import AgentAsync, Agent + from memgpt.agent import Agent from memgpt.utils import printd if preset_name == DEFAULT_PRESET: @@ -28,38 +26,6 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers printd(f"Available functions:\n", [x["name"] for x in available_functions]) assert len(functions) == len(available_functions) - if "gpt-3.5" in model: - # use a different system message for gpt-3.5 - preset_name = "memgpt_gpt35_extralong" - - return AgentAsync( - config=agent_config, - model=model, - system=gpt_system.get_system_text(preset_name), - functions=available_functions, - interface=interface, - persistence_manager=persistence_manager, - persona_notes=persona, - human_notes=human, - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=True if "gpt-4" in model else False, - ) - - elif preset_name == "memgpt_chat_sync": # TODO: remove me after we move the CLI to AgentSync - functions = [ - "send_message", - "pause_heartbeats", - "core_memory_append", - "core_memory_replace", - "conversation_search", - "conversation_search_date", - "archival_memory_insert", - "archival_memory_search", - ] - available_functions = [v for k, v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions] - printd(f"Available functions:\n", [x["name"] for x in available_functions]) - assert len(functions) == len(available_functions) - if "gpt-3.5" in model: # use a different system message for gpt-3.5 preset_name = "memgpt_gpt35_extralong" @@ -67,7 +33,7 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers return Agent( config=agent_config, model=model, - system=gpt_system.get_system_text(DEFAULT_PRESET), + system=gpt_system.get_system_text(preset_name), functions=available_functions, interface=interface, persistence_manager=persistence_manager, @@ -101,7 +67,7 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers # use a different system message for gpt-3.5 preset_name = "memgpt_gpt35_extralong" - return AgentAsync( + return Agent( model=model, system=gpt_system.get_system_text("memgpt_chat"), functions=available_functions, diff --git a/memgpt/utils.py b/memgpt/utils.py index e96c15de1e..955d86b94e 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -1,5 +1,4 @@ from datetime import datetime -import asyncio import csv import difflib import demjson3 as demjson @@ -14,11 +13,13 @@ from tqdm import tqdm import typer import memgpt -from memgpt.openai_tools import async_get_embedding_with_backoff +from memgpt.openai_tools import get_embedding_with_backoff from memgpt.constants import MEMGPT_DIR from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext from llama_index.embeddings import OpenAIEmbedding +from concurrent.futures import ThreadPoolExecutor, as_completed + def count_tokens(s: str, model: str = "gpt-4") -> int: encoding = tiktoken.encoding_for_model(model) @@ -242,38 +243,28 @@ def chunk_files_for_jsonl(files, tkns_per_chunk=300, model="gpt-4"): return ret -async def process_chunk(i, chunk, model): +def process_chunk(i, chunk, model): try: - return i, await async_get_embedding_with_backoff(chunk["content"], model=model) + return i, get_embedding_with_backoff(chunk["content"], model=model) except Exception as e: print(chunk) raise e -async def process_concurrently(archival_database, model, concurrency=10): - # Create a semaphore to limit the number of concurrent tasks - semaphore = asyncio.Semaphore(concurrency) - - async def bounded_process_chunk(i, chunk): - async with semaphore: - return await process_chunk(i, chunk, model) - - # Create a list of tasks for chunks +def process_concurrently(archival_database, model, concurrency=10): embedding_data = [0 for _ in archival_database] - tasks = [bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)] - - for future in tqdm( - asyncio.as_completed(tasks), - total=len(archival_database), - desc="Processing file chunks", - ): - i, result = await future - embedding_data[i] = result - + with ThreadPoolExecutor(max_workers=concurrency) as executor: + # Submit tasks to the executor + future_to_chunk = {executor.submit(process_chunk, i, chunk, model): i for i, chunk in enumerate(archival_database)} + + # As each task completes, process the results + for future in tqdm(as_completed(future_to_chunk), total=len(archival_database), desc="Processing file chunks"): + i, result = future.result() + embedding_data[i] = result return embedding_data -async def prepare_archival_index_from_files_compute_embeddings( +def prepare_archival_index_from_files_compute_embeddings( glob_pattern, tkns_per_chunk=300, model="gpt-4", @@ -293,7 +284,7 @@ async def prepare_archival_index_from_files_compute_embeddings( # chunk the files, make embeddings archival_database = chunk_files(files, tkns_per_chunk, model) - embedding_data = await process_concurrently(archival_database, embeddings_model) + embedding_data = process_concurrently(archival_database, embeddings_model) embeddings_file = os.path.join(save_dir, "embeddings.json") with open(embeddings_file, "w") as f: print(f"Saving embeddings to {embeddings_file}") diff --git a/tests/utils.py b/tests/utils.py index 47f92b12cd..271f70b54c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,7 @@ def configure_memgpt(enable_openai=True, enable_azure=False): child = pexpect.spawn("memgpt configure") - child.expect("Do you want to enable MemGPT with Open AI?", timeout=TIMEOUT) + child.expect("Do you want to enable MemGPT with OpenAI?", timeout=TIMEOUT) if enable_openai: child.sendline("y") else: @@ -27,6 +27,9 @@ def configure_memgpt(enable_openai=True, enable_azure=False): child.expect("Select default preset:", timeout=TIMEOUT) child.sendline() + child.expect("Select default model", timeout=TIMEOUT) + child.sendline() + child.expect("Select default persona:", timeout=TIMEOUT) child.sendline()