Fix max tokens constant (letta-ai#374)
* stripped the LLM_MAX_TOKENS constant; it is now a dictionary keyed by model name, and context_window is set via the config (defaults to 8k)

* pass context window in the calls to local llm APIs

* safety check

* remove dead imports

* context_length -> context_window

* add default for agent.load

* in configure, ask for the model context window if it is not specified in the dictionary

* fix default, also make message about OPENAI_API_BASE missing more informative

* make openai default embedding if openai is default llm

* put openai at the top of the list

* typo

* also make local the default for embeddings if you're using a local LLM endpoint (localllm) rather than OpenAI/Azure

* provide --context_window flag to memgpt run

* fix runtime error

* stray comments

* stray comment

cpacker authored Nov 10, 2023

1 parent 17c5e3a · commit cb50308
Showing 12 changed files with 140 additions and 45 deletions.
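
Taken together, the commit replaces the single LLM_MAX_TOKENS integer with a per-model dictionary and threads a config-driven context_window from MemGPTConfig through the agent, the summarizer, and the local-LLM backends. A minimal sketch of the intended lookup, assuming the model names and DEFAULT fallback shown in the constants.py diff below (the resolve_context_window helper itself is illustrative, not part of the codebase):

    from typing import Optional

    # Subset of the new per-model table from memgpt/constants.py
    LLM_MAX_TOKENS = {
        "DEFAULT": 8192,
        "gpt-4": 8192,
        "gpt-4-32k": 32768,
        "gpt-3.5-turbo": 4096,
    }

    def resolve_context_window(model: str, override: Optional[int] = None) -> int:
        # Prefer an explicit override (e.g. the new --context_window flag),
        # then the model table, then the DEFAULT entry.
        if override is not None:
            return int(override)
        return LLM_MAX_TOKENS.get(model, LLM_MAX_TOKENS["DEFAULT"])
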
34 changes: 23 additions & 11 deletions memgpt/agent.py
@@ -1,28 +1,24 @@
import inspect
import datetime
import glob
import pickle
import math
import os
import requests
import json
import threading
import traceback

import openai
from memgpt.persistence_manager import LocalStateManager
from memgpt.config import AgentConfig
from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from .memory import CoreMemory as Memory, summarize_messages
from .openai_tools import completions_with_backoff as create
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
from .constants import (
MEMGPT_DIR,
FIRST_MESSAGE_ATTEMPTS,
MAX_PAUSE_HEARTBEATS,
MESSAGE_CHATGPT_FUNCTION_MODEL,
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
MESSAGE_SUMMARY_WARNING_TOKENS,
MESSAGE_SUMMARY_WARNING_FRAC,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
CORE_MEMORY_HUMAN_CHAR_LIMIT,
@@ -108,10 +104,12 @@ def get_ai_reply(
message_sequence,
functions,
function_call="auto",
context_window=None,
):
try:
response = create(
model=model,
context_window=context_window,
messages=message_sequence,
functions=functions,
function_call=function_call,
@@ -582,7 +580,12 @@ def step(self, user_message, first_message=False, first_message_retry_limit=FIRS
printd(f"This is the first message. Running extra verifier on AI response.")
counter = 0
while True:
response = get_ai_reply(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
response = get_ai_reply(
model=self.model,
message_sequence=input_message_sequence,
functions=self.functions,
context_window=self.config.context_window,
)
if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break

@@ -591,7 +594,12 @@ def step(self, user_message, first_message=False, first_message_retry_limit=FIRS
raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")

else:
response = get_ai_reply(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
response = get_ai_reply(
model=self.model,
message_sequence=input_message_sequence,
functions=self.functions,
context_window=self.config.context_window,
)

# Step 2: check if LLM wanted to call a function
# (if yes) Step 3: call the function
@@ -620,14 +628,16 @@ def step(self, user_message, first_message=False, first_message_retry_limit=FIRS
# Check the memory pressure and potentially issue a memory pressure warning
current_total_tokens = response["usage"]["total_tokens"]
active_memory_warning = False
if current_total_tokens > MESSAGE_SUMMARY_WARNING_TOKENS:
printd(f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_TOKENS}")
if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window:
printd(
f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window}"
)
# Only deliver the alert if we haven't already (this period)
if not self.agent_alerted_about_memory_pressure:
active_memory_warning = True
self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this
else:
printd(f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_TOKENS}")
printd(f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window}")

self.append_to_messages(all_new_messages)
return all_new_messages, heartbeat_request, function_failed, active_memory_warning
@@ -698,7 +708,9 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True)
message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")

summary = summarize_messages(self.model, message_sequence_to_summarize)
summary = summarize_messages(
model=self.model, context_window=self.config.context_window, message_sequence_to_summarize=message_sequence_to_summarize
)
printd(f"Got summary: {summary}")

# Metadata that's useful for the agent to see
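
The memory-pressure check in Agent.step now scales with the configured window instead of a hard-coded count: previously the warning fired above MESSAGE_SUMMARY_WARNING_TOKENS = int(0.75 * 8000) = 6000 tokens regardless of model; now it fires above 0.75 of whatever context_window the config holds (6144 for an 8k window, 96000 for the 128k gpt-4-1106-preview window). A rough sketch of the new arithmetic, with the fraction taken from constants.py and the helper name being illustrative:

    MESSAGE_SUMMARY_WARNING_FRAC = 0.75  # from memgpt/constants.py in this commit

    def memory_pressure_warning(total_tokens: int, context_window: int) -> bool:
        # Mirrors the check in Agent.step: warn once the last response's token usage
        # crosses 75% of the active model's context window.
        return total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * context_window
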
9 changes: 9 additions & 0 deletions memgpt/cli/cli.py
@@ -42,6 +42,9 @@ def run(
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
context_window: int = typer.Option(
None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
),
):
"""Start chatting with an MemGPT agent
@@ -96,6 +99,11 @@ def run(
set_global_service_context(service_context)
sys.stdout = original_stdout

# overwrite the context_window if specified
if context_window is not None and int(context_window) != config.context_window:
typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW)
config.context_window = context_window

# create agent config
if agent and AgentConfig.exists(agent): # use existing agent
typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
@@ -129,6 +137,7 @@ def run(
persona=persona if persona else config.default_persona,
human=human if human else config.default_human,
model=model if model else config.model,
context_window=context_window if context_window else config.context_window,
preset=preset if preset else config.preset,
)
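
In practice the new flag lets a single run override whatever context_window the saved MemGPT config holds, with a warning printed when the two disagree (for example, memgpt run --context_window 8192 against a config that stored 4096). A hedged sketch of the equivalent check outside of typer, where the config object and field name follow the diff and the wrapper function is illustrative:

    def apply_context_window_override(config, context_window=None):
        # Same condition as in memgpt/cli/cli.py: only override when a value was
        # passed on the command line and it differs from the stored config.
        if context_window is not None and int(context_window) != config.context_window:
            print(f"Warning: overriding existing context window {config.context_window} with {context_window}")
            config.context_window = context_window
        return config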

54 changes: 49 additions & 5 deletions memgpt/cli/cli_config.py
@@ -14,6 +14,7 @@
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.connectors.storage import StorageConnector
from memgpt.constants import LLM_MAX_TOKENS

app = typer.Typer()

@@ -76,7 +77,9 @@ def configure():
model_endpoint_options += ["openai"]
if use_azure:
model_endpoint_options += ["azure"]
assert len(model_endpoint_options) > 0, "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE."
assert (
len(model_endpoint_options) > 0
), "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE to point at the IP address of your LLM server."
valid_default_model = config.model_endpoint in model_endpoint_options
default_endpoint = questionary.select(
"Select default inference endpoint:",
@@ -85,16 +88,24 @@ def configure():
).ask()

# configure embedding provider
embedding_endpoint_options = ["local"] # cannot configure custom endpoint (too confusing)
embedding_endpoint_options = []
if use_azure:
embedding_endpoint_options += ["azure"]
if use_openai:
embedding_endpoint_options += ["openai"]
embedding_endpoint_options += ["local"]
valid_default_embedding = config.embedding_model in embedding_endpoint_options
# determine the default selection in a smart way
if "openai" in embedding_endpoint_options and default_endpoint == "openai":
# openai llm -> openai embeddings
default_embedding_endpoint_default = "openai"
elif default_endpoint not in ["openai", "azure"]: # is local
# local llm -> local embeddings
default_embedding_endpoint_default = "local"
else:
default_embedding_endpoint_default = config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1]
default_embedding_endpoint = questionary.select(
"Select default embedding endpoint:",
embedding_endpoint_options,
default=config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1],
"Select default embedding endpoint:", embedding_endpoint_options, default=default_embedding_endpoint_default
).ask()

# configure embedding dimensions
@@ -117,6 +128,38 @@ def configure():
else:
default_model = "local" # TODO: figure out if this is ok? this is for local endpoint

# get the max tokens (context window) for the model
if default_model == "local" or str(default_model) not in LLM_MAX_TOKENS:
# Ask the user to specify the context length
context_length_options = [
str(2**12), # 4096
str(2**13), # 8192
str(2**14), # 16384
str(2**15), # 32768
str(2**18), # 262144
"custom", # enter yourself
]
default_model_context_window = questionary.select(
"Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
choices=context_length_options,
default=str(LLM_MAX_TOKENS["DEFAULT"]),
).ask()

# If custom, ask for input
if default_model_context_window == "custom":
while True:
default_model_context_window = questionary.text("Enter context window (e.g. 8192)").ask()
try:
default_model_context_window = int(default_model_context_window)
break
except ValueError:
print(f"Context window must be a valid integer")
else:
default_model_context_window = int(default_model_context_window)
else:
# Pull the context length from the models
default_model_context_window = LLM_MAX_TOKENS[default_model]

# defaults
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
# print(personas)
@@ -152,6 +195,7 @@ def configure():

config = MemGPTConfig(
model=default_model,
context_window=default_model_context_window,
preset=default_preset,
model_endpoint=default_endpoint,
embedding_model=default_embedding_endpoint,
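
The configure flow only prompts for a window when the chosen model is "local" or absent from LLM_MAX_TOKENS; known OpenAI models are resolved from the table directly. A condensed sketch of that branch, assuming the questionary prompts already used in this file (the ask_context_window wrapper and the shortened prompt text are illustrative):

    import questionary

    def ask_context_window(default_model: str, llm_max_tokens: dict) -> int:
        if default_model != "local" and str(default_model) in llm_max_tokens:
            return llm_max_tokens[default_model]  # known model: no prompt needed
        choice = questionary.select(
            "Select your model's context window:",
            choices=["4096", "8192", "16384", "32768", "custom"],
            default=str(llm_max_tokens["DEFAULT"]),
        ).ask()
        # Fall back to free-form entry until the user supplies a valid integer
        while choice == "custom" or not str(choice).isdigit():
            choice = questionary.text("Enter context window (e.g. 8192)").ask()
        return int(choice)
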
9 changes: 8 additions & 1 deletion memgpt/config.py
@@ -19,7 +19,7 @@
import memgpt.interface as interface
from memgpt.personas.personas import get_persona_text
from memgpt.humans.humans import get_human_text
from memgpt.constants import MEMGPT_DIR
from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS
import memgpt.constants as constants
import memgpt.personas.personas as personas
import memgpt.humans.humans as humans
@@ -51,6 +51,7 @@ class MemGPTConfig:
# provider: str = "openai" # openai, azure, local (TODO)
model_endpoint: str = "openai"
model: str = "gpt-4" # gpt-4, gpt-3.5-turbo, local
context_window: int = LLM_MAX_TOKENS[model] if model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]

# model parameters: openai
openai_key: str = None
@@ -106,6 +107,9 @@ def load(cls) -> "MemGPTConfig":

# read config values
model = config.get("defaults", "model")
context_window = (
config.get("defaults", "context_window") if config.has_option("defaults", "context_window") else LLM_MAX_TOKENS["DEFAULT"]
)
preset = config.get("defaults", "preset")
model_endpoint = config.get("defaults", "model_endpoint")
default_persona = config.get("defaults", "persona")
@@ -141,6 +145,7 @@ def load(cls) -> "MemGPTConfig":

return cls(
model=model,
context_window=context_window,
preset=preset,
model_endpoint=model_endpoint,
default_persona=default_persona,
@@ -254,6 +259,7 @@ def __init__(
persona,
human,
model,
context_window,
preset=DEFAULT_PRESET,
name=None,
data_sources=[],
@@ -268,6 +274,7 @@ def __init__(
self.persona = persona
self.human = human
self.model = model
self.context_window = context_window
self.preset = preset
self.data_sources = data_sources
self.create_time = create_time if create_time is not None else utils.get_local_time()
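
Older config files won't contain a context_window entry, so MemGPTConfig.load falls back to LLM_MAX_TOKENS["DEFAULT"]. One thing to keep in mind is that configparser's get() returns a string; a sketch of the same fallback using getint, which sidesteps that (section and option names follow the diff, the helper is illustrative):

    from configparser import ConfigParser

    def read_context_window(path: str, llm_max_tokens: dict) -> int:
        parser = ConfigParser()
        parser.read(path)
        if parser.has_option("defaults", "context_window"):
            # getint coerces the stored value, unlike a plain get() call
            return parser.getint("defaults", "context_window")
        return llm_max_tokens["DEFAULT"]
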
22 changes: 20 additions & 2 deletions memgpt/constants.py
@@ -17,9 +17,27 @@

# Constants to do with summarization / conversation length window
# The max amount of tokens supported by the underlying model (eg 8k for gpt-4 and Mistral 7B)
LLM_MAX_TOKENS = 8000 # change this depending on your model
LLM_MAX_TOKENS = {
"DEFAULT": 8192,
## OpenAI models: https://platform.openai.com/docs/models/overview
# gpt-4
"gpt-4-1106-preview": 128000,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-0613": 8192,
"gpt-4-32k-0613": 32768,
"gpt-4-0314": 8192, # legacy
"gpt-4-32k-0314": 32768, # legacy
# gpt-3.5
"gpt-3.5-turbo-1106": 16385,
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16385,
"gpt-3.5-turbo-0613": 4096, # legacy
"gpt-3.5-turbo-16k-0613": 16385, # legacy
"gpt-3.5-turbo-0301": 4096, # legacy
}
# The number of tokens before a system warning about upcoming truncation is sent to MemGPT
MESSAGE_SUMMARY_WARNING_TOKENS = int(0.75 * LLM_MAX_TOKENS)
MESSAGE_SUMMARY_WARNING_FRAC = 0.75
# The error message that MemGPT will receive
MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
# The fraction of tokens we truncate down to
12 changes: 7 additions & 5 deletions memgpt/local_llm/chat_completion_proxy.py
@@ -27,7 +27,9 @@ def get_chat_completion(
messages,
functions=None,
function_call="auto",
context_window=None,
):
assert context_window is not None, "Local LLM calls need the context length to be explicitly set"
global has_shown_warning
grammar_name = None

@@ -90,15 +92,15 @@ def get_chat_completion(

try:
if HOST_TYPE == "webui":
result = get_webui_completion(prompt, grammar=grammar_name)
result = get_webui_completion(prompt, context_window, grammar=grammar_name)
elif HOST_TYPE == "lmstudio":
result = get_lmstudio_completion(prompt)
result = get_lmstudio_completion(prompt, context_window)
elif HOST_TYPE == "llamacpp":
result = get_llamacpp_completion(prompt, grammar=grammar_name)
result = get_llamacpp_completion(prompt, context_window, grammar=grammar_name)
elif HOST_TYPE == "koboldcpp":
result = get_koboldcpp_completion(prompt, grammar=grammar_name)
result = get_koboldcpp_completion(prompt, context_window, grammar=grammar_name)
elif HOST_TYPE == "ollama":
result = get_ollama_completion(prompt)
result = get_ollama_completion(prompt, context_window)
else:
raise LocalLLMError(
f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
7 changes: 3 additions & 4 deletions memgpt/local_llm/koboldcpp/api.py
@@ -4,7 +4,6 @@

from .settings import SIMPLE
from ..utils import load_grammar_file, count_tokens
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
@@ -13,11 +12,11 @@
DEBUG = True


def get_koboldcpp_completion(prompt, grammar=None, settings=SIMPLE):
def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
"""See https://lite.koboldai.net/koboldcpp_api for API spec"""
prompt_tokens = count_tokens(prompt)
if prompt_tokens > LLM_MAX_TOKENS:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
7 changes: 3 additions & 4 deletions memgpt/local_llm/llamacpp/api.py
@@ -4,7 +4,6 @@

from .settings import SIMPLE
from ..utils import load_grammar_file, count_tokens
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
@@ -13,11 +12,11 @@
DEBUG = True


def get_llamacpp_completion(prompt, grammar=None, settings=SIMPLE):
def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
"""See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
prompt_tokens = count_tokens(prompt)
if prompt_tokens > LLM_MAX_TOKENS:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
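
Each local backend now applies the same guard against the caller-supplied window before issuing the request, rather than consulting the removed module-level constant. A minimal sketch of the shared pattern, assuming the count_tokens helper imported from memgpt.local_llm.utils in both files (the standalone guard function is illustrative):

    def check_prompt_fits(prompt: str, context_window: int, count_tokens) -> None:
        # Same check performed by get_koboldcpp_completion and get_llamacpp_completion
        # before sending the prompt to the local server.
        prompt_tokens = count_tokens(prompt)
        if prompt_tokens > context_window:
            raise Exception(
                f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)"
            )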