diff --git a/memgpt/agent.py b/memgpt/agent.py
index 6de293c47c..8e5c17a8d9 100644
--- a/memgpt/agent.py
+++ b/memgpt/agent.py
@@ -1,15 +1,12 @@
 import inspect
 import datetime
 import glob
-import pickle
 import math
 import os
 import requests
 import json
-import threading
 import traceback

-import openai
 from memgpt.persistence_manager import LocalStateManager
 from memgpt.config import AgentConfig
 from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
@@ -17,12 +14,11 @@
 from .openai_tools import completions_with_backoff as create
 from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
 from .constants import (
-    MEMGPT_DIR,
     FIRST_MESSAGE_ATTEMPTS,
     MAX_PAUSE_HEARTBEATS,
     MESSAGE_CHATGPT_FUNCTION_MODEL,
     MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
-    MESSAGE_SUMMARY_WARNING_TOKENS,
+    MESSAGE_SUMMARY_WARNING_FRAC,
     MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
     MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
     CORE_MEMORY_HUMAN_CHAR_LIMIT,
@@ -108,10 +104,12 @@ def get_ai_reply(
     message_sequence,
     functions,
     function_call="auto",
+    context_window=None,
 ):
     try:
         response = create(
             model=model,
+            context_window=context_window,
             messages=message_sequence,
             functions=functions,
             function_call=function_call,
@@ -582,7 +580,12 @@ def step(self, user_message, first_message=False, first_message_retry_limit=FIRS
             printd(f"This is the first message. Running extra verifier on AI response.")
             counter = 0
             while True:
-                response = get_ai_reply(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
+                response = get_ai_reply(
+                    model=self.model,
+                    message_sequence=input_message_sequence,
+                    functions=self.functions,
+                    context_window=self.config.context_window,
+                )
                 if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
                     break

@@ -591,7 +594,12 @@ def step(self, user_message, first_message=False, first_message_retry_limit=FIRS
                     raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")

         else:
-            response = get_ai_reply(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
+            response = get_ai_reply(
+                model=self.model,
+                message_sequence=input_message_sequence,
+                functions=self.functions,
+                context_window=self.config.context_window,
+            )

         # Step 2: check if LLM wanted to call a function
         # (if yes) Step 3: call the function
@@ -620,14 +628,16 @@ def step(self, user_message, first_message=False, first_message_retry_limit=FIRS
         # Check the memory pressure and potentially issue a memory pressure warning
         current_total_tokens = response["usage"]["total_tokens"]
         active_memory_warning = False
-        if current_total_tokens > MESSAGE_SUMMARY_WARNING_TOKENS:
-            printd(f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_TOKENS}")
+        if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window:
+            printd(
+                f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window}"
+            )
             # Only deliver the alert if we haven't already (this period)
             if not self.agent_alerted_about_memory_pressure:
                 active_memory_warning = True
                 self.agent_alerted_about_memory_pressure = True  # it's up to the outer loop to handle this
         else:
-            printd(f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_TOKENS}")
+            printd(f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window}")

         self.append_to_messages(all_new_messages)
         return all_new_messages, heartbeat_request, function_failed, active_memory_warning

@@ -698,7 +708,9 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True)
         message_sequence_to_summarize = self.messages[1:cutoff]  # do NOT get rid of the system message

         printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")
-        summary = summarize_messages(self.model, message_sequence_to_summarize)
+        summary = summarize_messages(
+            model=self.model, context_window=self.config.context_window, message_sequence_to_summarize=message_sequence_to_summarize
+        )
         printd(f"Got summary: {summary}")

         # Metadata that's useful for the agent to see
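Note on the agent.py change above: the memory-pressure check now scales with the configured context window instead of a fixed token constant. A minimal sketch of the new condition, with an illustrative helper name and numbers that are not part of the patch:

MESSAGE_SUMMARY_WARNING_FRAC = 0.75

def under_memory_pressure(total_tokens, context_window):
    # warn once the last response reports usage above 75% of the model's window
    return total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * context_window

under_memory_pressure(7000, 8192)   # True  (7000 > 6144)
under_memory_pressure(7000, 16385)  # False (7000 < 12288.75)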
diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py
index 9af142cde8..d512dc4f1b 100644
--- a/memgpt/cli/cli.py
+++ b/memgpt/cli/cli.py
@@ -42,6 +42,9 @@ def run(
     debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
     no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
     yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
+    context_window: int = typer.Option(
+        None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
+    ),
 ):
     """Start chatting with an MemGPT agent

@@ -96,6 +99,11 @@ def run(
         set_global_service_context(service_context)
         sys.stdout = original_stdout

+    # overwrite the context_window if specified
+    if context_window is not None and int(context_window) != config.context_window:
+        typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW)
+        config.context_window = context_window
+
     # create agent config
     if agent and AgentConfig.exists(agent):  # use existing agent
         typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
@@ -129,6 +137,7 @@ def run(
             persona=persona if persona else config.default_persona,
             human=human if human else config.default_human,
             model=model if model else config.model,
+            context_window=context_window if context_window else config.context_window,
             preset=preset if preset else config.preset,
         )
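The --context_window flag added to run() above only overrides the stored config value when it is set and differs from the saved one. A minimal sketch of that resolution logic (the helper name is illustrative, not part of the patch):

def resolve_context_window(cli_value, config_value):
    # an explicit CLI value wins; otherwise fall back to the saved MemGPT config
    if cli_value is not None and int(cli_value) != config_value:
        return int(cli_value)
    return config_value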
diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py
index de889aa04f..b6175bcece 100644
--- a/memgpt/cli/cli_config.py
+++ b/memgpt/cli/cli_config.py
@@ -14,6 +14,7 @@
 from memgpt.config import MemGPTConfig, AgentConfig
 from memgpt.constants import MEMGPT_DIR
 from memgpt.connectors.storage import StorageConnector
+from memgpt.constants import LLM_MAX_TOKENS

 app = typer.Typer()

@@ -76,7 +77,9 @@ def configure():
         model_endpoint_options += ["openai"]
     if use_azure:
         model_endpoint_options += ["azure"]
-    assert len(model_endpoint_options) > 0, "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE."
+    assert (
+        len(model_endpoint_options) > 0
+    ), "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE to point at the IP address of your LLM server."
     valid_default_model = config.model_endpoint in model_endpoint_options
     default_endpoint = questionary.select(
         "Select default inference endpoint:",
@@ -85,16 +88,24 @@ def configure():
     ).ask()

     # configure embedding provider
-    embedding_endpoint_options = ["local"]  # cannot configure custom endpoint (too confusing)
+    embedding_endpoint_options = []
     if use_azure:
         embedding_endpoint_options += ["azure"]
     if use_openai:
         embedding_endpoint_options += ["openai"]
+    embedding_endpoint_options += ["local"]
     valid_default_embedding = config.embedding_model in embedding_endpoint_options
+    # determine the default selection in a smart way
+    if "openai" in embedding_endpoint_options and default_endpoint == "openai":
+        # openai llm -> openai embeddings
+        default_embedding_endpoint_default = "openai"
+    elif default_endpoint not in ["openai", "azure"]:  # is local
+        # local llm -> local embeddings
+        default_embedding_endpoint_default = "local"
+    else:
+        default_embedding_endpoint_default = config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1]
     default_embedding_endpoint = questionary.select(
-        "Select default embedding endpoint:",
-        embedding_endpoint_options,
-        default=config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1],
+        "Select default embedding endpoint:", embedding_endpoint_options, default=default_embedding_endpoint_default
     ).ask()

     # configure embedding dimentions
@@ -117,6 +128,38 @@ def configure():
     else:
         default_model = "local"  # TODO: figure out if this is ok? this is for local endpoint

+    # get the max tokens (context window) for the model
+    if default_model == "local" or str(default_model) not in LLM_MAX_TOKENS:
+        # Ask the user to specify the context length
+        context_length_options = [
+            str(2**12),  # 4096
+            str(2**13),  # 8192
+            str(2**14),  # 16384
+            str(2**15),  # 32768
+            str(2**18),  # 262144
+            "custom",  # enter yourself
+        ]
+        default_model_context_window = questionary.select(
+            "Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
+            choices=context_length_options,
+            default=str(LLM_MAX_TOKENS["DEFAULT"]),
+        ).ask()
+
+        # If custom, ask for input
+        if default_model_context_window == "custom":
+            while True:
+                default_model_context_window = questionary.text("Enter context window (e.g. 8192)").ask()
+                try:
+                    default_model_context_window = int(default_model_context_window)
+                    break
+                except ValueError:
+                    print(f"Context window must be a valid integer")
+        else:
+            default_model_context_window = int(default_model_context_window)
+    else:
+        # Pull the context length from the models
+        default_model_context_window = LLM_MAX_TOKENS[default_model]
+
     # defaults
     personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
     # print(personas)
@@ -152,6 +195,7 @@ def configure():

     config = MemGPTConfig(
         model=default_model,
+        context_window=default_model_context_window,
         preset=default_preset,
         model_endpoint=default_endpoint,
         embedding_model=default_embedding_endpoint,
diff --git a/memgpt/config.py b/memgpt/config.py
index 517b54b130..c982c90f8a 100644
--- a/memgpt/config.py
+++ b/memgpt/config.py
@@ -19,7 +19,7 @@
 import memgpt.interface as interface
 from memgpt.personas.personas import get_persona_text
 from memgpt.humans.humans import get_human_text
-from memgpt.constants import MEMGPT_DIR
+from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS
 import memgpt.constants as constants
 import memgpt.personas.personas as personas
 import memgpt.humans.humans as humans
@@ -51,6 +51,7 @@ class MemGPTConfig:
     # provider: str = "openai"  # openai, azure, local (TODO)
     model_endpoint: str = "openai"
     model: str = "gpt-4"  # gpt-4, gpt-3.5-turbo, local
+    context_window: int = LLM_MAX_TOKENS[model] if model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]

     # model parameters: openai
     openai_key: str = None
@@ -106,6 +107,9 @@ def load(cls) -> "MemGPTConfig":

         # read config values
         model = config.get("defaults", "model")
+        context_window = (
+            config.get("defaults", "context_window") if config.has_option("defaults", "context_window") else LLM_MAX_TOKENS["DEFAULT"]
+        )
         preset = config.get("defaults", "preset")
         model_endpoint = config.get("defaults", "model_endpoint")
         default_persona = config.get("defaults", "persona")
@@ -141,6 +145,7 @@ def load(cls) -> "MemGPTConfig":

         return cls(
             model=model,
+            context_window=context_window,
             preset=preset,
             model_endpoint=model_endpoint,
             default_persona=default_persona,
@@ -254,6 +259,7 @@ def __init__(
         persona,
         human,
         model,
+        context_window,
         preset=DEFAULT_PRESET,
         name=None,
         data_sources=[],
@@ -268,6 +274,7 @@ def __init__(
         self.persona = persona
         self.human = human
         self.model = model
+        self.context_window = context_window
         self.preset = preset
         self.data_sources = data_sources
         self.create_time = create_time if create_time is not None else utils.get_local_time()
diff --git a/memgpt/constants.py b/memgpt/constants.py
index 1326760e32..fc3454048f 100644
--- a/memgpt/constants.py
+++ b/memgpt/constants.py
@@ -17,9 +17,27 @@

 # Constants to do with summarization / conversation length window
 # The max amount of tokens supported by the underlying model (eg 8k for gpt-4 and Mistral 7B)
-LLM_MAX_TOKENS = 8000  # change this depending on your model
+LLM_MAX_TOKENS = {
+    "DEFAULT": 8192,
+    ## OpenAI models: https://platform.openai.com/docs/models/overview
+    # gpt-4
+    "gpt-4-1106-preview": 128000,
+    "gpt-4": 8192,
+    "gpt-4-32k": 32768,
+    "gpt-4-0613": 8192,
+    "gpt-4-32k-0613": 32768,
+    "gpt-4-0314": 8192,  # legacy
+    "gpt-4-32k-0314": 32768,  # legacy
+    # gpt-3.5
+    "gpt-3.5-turbo-1106": 16385,
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-16k": 16385,
+    "gpt-3.5-turbo-0613": 4096,  # legacy
+    "gpt-3.5-turbo-16k-0613": 16385,  # legacy
+    "gpt-3.5-turbo-0301": 4096,  # legacy
+}
 # The amount of tokens before a sytem warning about upcoming truncation is sent to MemGPT
-MESSAGE_SUMMARY_WARNING_TOKENS = int(0.75 * LLM_MAX_TOKENS)
+MESSAGE_SUMMARY_WARNING_FRAC = 0.75
 # The error message that MemGPT will receive
 MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
 # The fraction of tokens we truncate down to
diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py
index 44b3c81fa5..9a59dce268 100644
--- a/memgpt/local_llm/chat_completion_proxy.py
+++ b/memgpt/local_llm/chat_completion_proxy.py
@@ -27,7 +27,9 @@ def get_chat_completion(
     messages,
     functions=None,
     function_call="auto",
+    context_window=None,
 ):
+    assert context_window is not None, "Local LLM calls need the context length to be explicitly set"
     global has_shown_warning
     grammar_name = None

@@ -90,15 +92,15 @@ def get_chat_completion(
     try:
         if HOST_TYPE == "webui":
-            result = get_webui_completion(prompt, grammar=grammar_name)
+            result = get_webui_completion(prompt, context_window, grammar=grammar_name)
         elif HOST_TYPE == "lmstudio":
-            result = get_lmstudio_completion(prompt)
+            result = get_lmstudio_completion(prompt, context_window)
         elif HOST_TYPE == "llamacpp":
-            result = get_llamacpp_completion(prompt, grammar=grammar_name)
+            result = get_llamacpp_completion(prompt, context_window, grammar=grammar_name)
         elif HOST_TYPE == "koboldcpp":
-            result = get_koboldcpp_completion(prompt, grammar=grammar_name)
+            result = get_koboldcpp_completion(prompt, context_window, grammar=grammar_name)
         elif HOST_TYPE == "ollama":
-            result = get_ollama_completion(prompt)
+            result = get_ollama_completion(prompt, context_window)
         else:
             raise LocalLLMError(
                 f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
diff --git a/memgpt/local_llm/koboldcpp/api.py b/memgpt/local_llm/koboldcpp/api.py
index d345217928..41e5484d00 100644
--- a/memgpt/local_llm/koboldcpp/api.py
+++ b/memgpt/local_llm/koboldcpp/api.py
@@ -4,7 +4,6 @@

 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens
-from ...constants import LLM_MAX_TOKENS

 HOST = os.getenv("OPENAI_API_BASE")
 HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
@@ -13,11 +12,11 @@
 DEBUG = True


-def get_koboldcpp_completion(prompt, grammar=None, settings=SIMPLE):
+def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
     """See https://lite.koboldai.net/koboldcpp_api for API spec"""
     prompt_tokens = count_tokens(prompt)
-    if prompt_tokens > LLM_MAX_TOKENS:
-        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")
+    if prompt_tokens > context_window:
+        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

     # Settings for the generation, includes the prompt + stop tokens, max length, etc
     request = settings
diff --git a/memgpt/local_llm/llamacpp/api.py b/memgpt/local_llm/llamacpp/api.py
index 5e2218e55e..ce91ad6179 100644
--- a/memgpt/local_llm/llamacpp/api.py
+++ b/memgpt/local_llm/llamacpp/api.py
@@ -4,7 +4,6 @@

 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens
-from ...constants import LLM_MAX_TOKENS

 HOST = os.getenv("OPENAI_API_BASE")
 HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
@@ -13,11 +12,11 @@
 DEBUG = True


-def get_llamacpp_completion(prompt, grammar=None, settings=SIMPLE):
+def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
     """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
-    if prompt_tokens > LLM_MAX_TOKENS:
-        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")
+    if prompt_tokens > context_window:
+        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

     # Settings for the generation, includes the prompt + stop tokens, max length, etc
     request = settings
diff --git a/memgpt/local_llm/lmstudio/api.py b/memgpt/local_llm/lmstudio/api.py
index b867774b2c..e5440799f9 100644
--- a/memgpt/local_llm/lmstudio/api.py
+++ b/memgpt/local_llm/lmstudio/api.py
@@ -3,6 +3,7 @@
 import requests

 from .settings import SIMPLE
+from ..utils import count_tokens

 HOST = os.getenv("OPENAI_API_BASE")
 HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
@@ -11,8 +12,11 @@
 DEBUG = False


-def get_lmstudio_completion(prompt, settings=SIMPLE, api="chat"):
+def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat"):
     """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
+    prompt_tokens = count_tokens(prompt)
+    if prompt_tokens > context_window:
+        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

     # Settings for the generation, includes the prompt + stop tokens, max length, etc
     request = settings
diff --git a/memgpt/local_llm/ollama/api.py b/memgpt/local_llm/ollama/api.py
index f13af5f383..934ba1bf38 100644
--- a/memgpt/local_llm/ollama/api.py
+++ b/memgpt/local_llm/ollama/api.py
@@ -4,7 +4,6 @@

 from .settings import SIMPLE
 from ..utils import count_tokens
-from ...constants import LLM_MAX_TOKENS
 from ...errors import LocalLLMError

 HOST = os.getenv("OPENAI_API_BASE")
@@ -14,11 +13,11 @@
 DEBUG = False


-def get_ollama_completion(prompt, settings=SIMPLE, grammar=None):
+def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None):
     """See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
-    if prompt_tokens > LLM_MAX_TOKENS:
-        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")
+    if prompt_tokens > context_window:
+        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

     if MODEL_NAME is None:
         raise LocalLLMError(f"Error: OLLAMA_MODEL not specified. Set OLLAMA_MODEL to the model you want to run (e.g. 'dolphin2.2-mistral')")
diff --git a/memgpt/local_llm/webui/api.py b/memgpt/local_llm/webui/api.py
index 163c403516..211100a376 100644
--- a/memgpt/local_llm/webui/api.py
+++ b/memgpt/local_llm/webui/api.py
@@ -4,7 +4,6 @@

 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens
-from ...constants import LLM_MAX_TOKENS

 HOST = os.getenv("OPENAI_API_BASE")
 HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
@@ -12,11 +11,11 @@
 DEBUG = False


-def get_webui_completion(prompt, settings=SIMPLE, grammar=None):
+def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
     """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
-    if prompt_tokens > LLM_MAX_TOKENS:
-        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")
+    if prompt_tokens > context_window:
+        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

     # Settings for the generation, includes the prompt + stop tokens, max length, etc
     request = settings
diff --git a/memgpt/memory.py b/memgpt/memory.py
index 2cae46d847..7860c39c31 100644
--- a/memgpt/memory.py
+++ b/memgpt/memory.py
@@ -4,7 +4,7 @@
 import re
 from typing import Optional, List, Tuple

-from .constants import MESSAGE_SUMMARY_WARNING_TOKENS, MEMGPT_DIR
+from .constants import MESSAGE_SUMMARY_WARNING_FRAC, MEMGPT_DIR
 from .utils import cosine_similarity, get_local_time, printd, count_tokens
 from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
 from memgpt import utils
@@ -120,6 +120,7 @@ def edit_replace(self, field, old_content, new_content):

 def summarize_messages(
     model,
+    context_window,
     message_sequence_to_summarize,
 ):
     """Summarize a message sequence using GPT"""
@@ -127,10 +128,12 @@ def summarize_messages(
     summary_prompt = SUMMARY_PROMPT_SYSTEM
     summary_input = str(message_sequence_to_summarize)
     summary_input_tkns = count_tokens(summary_input)
-    if summary_input_tkns > MESSAGE_SUMMARY_WARNING_TOKENS:
-        trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8  # For good measure...
+    if summary_input_tkns > MESSAGE_SUMMARY_WARNING_FRAC * context_window:
+        trunc_ratio = (MESSAGE_SUMMARY_WARNING_FRAC * context_window / summary_input_tkns) * 0.8  # For good measure...
         cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
-        summary_input = str([summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:])
+        summary_input = str(
+            [summarize_messages(model, context_window, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:]
+        )
     message_sequence = [
         {"role": "system", "content": summary_prompt},
         {"role": "user", "content": summary_input},