feat: add functions to get context window overview (#1903)
cpacker authored Oct 18, 2024
1 parent fc3d4e1 commit 180bbfe
Showing 7 changed files with 125 additions and 16 deletions.
28 changes: 26 additions & 2 deletions letta/agent.py
@@ -23,7 +23,7 @@
from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.local_llm.utils import num_tokens_from_messages
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
from letta.metadata import MetadataStore
from letta.persistence_manager import LocalStateManager
@@ -33,6 +33,9 @@
from letta.schemas.enums import MessageRole
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, UpdateMessage
from letta.schemas.openai.chat_completion_request import (
Tool as ChatCompletionRequestTool,
)
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import (
Message as ChatCompletionMessage,
@@ -1458,6 +1461,24 @@ def get_context_window(self) -> ContextWindowOverview:
)
num_tokens_external_memory_summary = count_tokens(external_memory_summary)

# tokens taken up by function definitions
if self.functions:
available_functions_definitions = [ChatCompletionRequestTool(type="function", function=f) for f in self.functions]
num_tokens_available_functions_definitions = num_tokens_from_functions(functions=self.functions, model=self.model)
else:
available_functions_definitions = []
num_tokens_available_functions_definitions = 0

num_tokens_used_total = (
num_tokens_system # system prompt
+ num_tokens_available_functions_definitions # function definitions
+ num_tokens_core_memory # core memory
+ num_tokens_external_memory_summary # metadata (statistics) about recall/archival
+ num_tokens_summary_memory # summary of ongoing conversation
+ num_tokens_messages # tokens taken by messages
)
assert isinstance(num_tokens_used_total, int)

return ContextWindowOverview(
# context window breakdown (in messages)
num_messages=len(self._messages),
@@ -1466,7 +1487,7 @@ def get_context_window(self) -> ContextWindowOverview:
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
# top-level information
context_window_size_max=self.agent_state.llm_config.context_window,
context_window_size_current=num_tokens_system + num_tokens_core_memory + num_tokens_summary_memory + num_tokens_messages,
context_window_size_current=num_tokens_used_total,
# context window breakdown (in tokens)
num_tokens_system=num_tokens_system,
system_prompt=system_prompt,
@@ -1476,6 +1497,9 @@ def get_context_window(self) -> ContextWindowOverview:
summary_memory=summary_memory,
num_tokens_messages=num_tokens_messages,
messages=self._messages,
# related to functions
num_tokens_functions_definitions=num_tokens_available_functions_definitions,
functions_definitions=available_functions_definitions,
)


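With the function definitions folded in, `context_window_size_current` is now the sum of every token component the overview reports. A minimal sketch of how a caller might sanity-check that invariant (assuming an already-instantiated `Agent` named `agent`; the field names are the ones defined on `ContextWindowOverview` below):

```python
# Sketch: recompute the reported total from its parts.
# Assumes `agent` is a loaded letta Agent instance.
overview = agent.get_context_window()

component_sum = (
    overview.num_tokens_system
    + overview.num_tokens_functions_definitions
    + overview.num_tokens_core_memory
    + overview.num_tokens_external_memory_summary
    + overview.num_tokens_summary_memory
    + overview.num_tokens_messages
)
assert overview.context_window_size_current == component_sum
print(f"{overview.context_window_size_current} / {overview.context_window_size_max} tokens in use")
```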
21 changes: 16 additions & 5 deletions letta/llm_api/openai.py
@@ -18,8 +18,13 @@
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as _Message
from letta.schemas.message import MessageRole as _MessageRole
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import (
ChatCompletionRequest,
FunctionCall as ToolFunctionChoiceFunctionCall,
)
from letta.schemas.openai.chat_completion_request import (
Tool,
ToolFunctionChoice,
cast_message_to_subtype,
)
from letta.schemas.openai.chat_completion_response import (
@@ -100,10 +105,10 @@ def openai_get_model_list(

def build_openai_chat_completions_request(
llm_config: LLMConfig,
messages: List[Message],
messages: List[_Message],
user_id: Optional[str],
functions: Optional[list],
function_call: str,
function_call: Optional[str],
use_tool_naming: bool,
max_tokens: Optional[int],
) -> ChatCompletionRequest:
@@ -124,11 +129,17 @@
model = None

if use_tool_naming:
if function_call is None:
tool_choice = None
elif function_call not in ["none", "auto", "required"]:
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=function_call))
else:
tool_choice = function_call
data = ChatCompletionRequest(
model=model,
messages=openai_message_list,
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
tool_choice=function_call,
tools=[Tool(type="function", function=f) for f in functions] if functions else None,
tool_choice=tool_choice,
user=str(user_id),
max_tokens=max_tokens,
)
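The new branch converts the legacy `function_call` string into a typed `tool_choice`: `None` and the OpenAI sentinels (`"none"`, `"auto"`, `"required"`) pass through unchanged, while any other string is treated as the name of a specific function to force. The same mapping as a standalone sketch (the helper name is illustrative, not part of the commit):

```python
from typing import Optional, Union

from letta.schemas.openai.chat_completion_request import (
    FunctionCall as ToolFunctionChoiceFunctionCall,
    ToolFunctionChoice,
)


def resolve_tool_choice(function_call: Optional[str]) -> Optional[Union[str, ToolFunctionChoice]]:
    """Map a legacy function_call string onto the tool_choice field."""
    if function_call is None:
        return None  # omit tool_choice from the request entirely
    if function_call in ("none", "auto", "required"):
        return function_call  # OpenAI sentinel values pass through
    # any other string names a specific function the model must call
    return ToolFunctionChoice(
        type="function",
        function=ToolFunctionChoiceFunctionCall(name=function_call),
    )
```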
28 changes: 22 additions & 6 deletions letta/local_llm/utils.py
@@ -1,6 +1,6 @@
import os
import warnings
from typing import List
from typing import List, Union

import requests
import tiktoken
@@ -11,6 +11,7 @@
import letta.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
import letta.local_llm.llm_chat_completion_wrappers.llama3 as llama3
import letta.local_llm.llm_chat_completion_wrappers.zephyr as zephyr
from letta.schemas.openai.chat_completion_request import Tool, ToolCall


def post_json_auth_request(uri, json_payload, auth_type, auth_key):
@@ -123,7 +124,7 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
return num_tokens


def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"):
def num_tokens_from_tool_calls(tool_calls: Union[List[dict], List[ToolCall]], model: str = "gpt-4"):
"""Based on above code (num_tokens_from_functions).
Example to encode:
@@ -144,10 +145,25 @@ def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"):

num_tokens = 0
for tool_call in tool_calls:
function_tokens = len(encoding.encode(tool_call["id"]))
function_tokens += 2 + len(encoding.encode(tool_call["type"]))
function_tokens += 2 + len(encoding.encode(tool_call["function"]["name"]))
function_tokens += 2 + len(encoding.encode(tool_call["function"]["arguments"]))
if isinstance(tool_call, dict):
tool_call_id = tool_call["id"]
tool_call_type = tool_call["type"]
tool_call_function = tool_call["function"]
tool_call_function_name = tool_call_function["name"]
tool_call_function_arguments = tool_call_function["arguments"]
elif isinstance(tool_call, ToolCall):
tool_call_id = tool_call.id
tool_call_type = tool_call.type
tool_call_function = tool_call.function
tool_call_function_name = tool_call_function.name
tool_call_function_arguments = tool_call_function.arguments
else:
raise ValueError(f"Unknown tool call type: {type(tool_call)}")

function_tokens = len(encoding.encode(tool_call_id))
function_tokens += 2 + len(encoding.encode(tool_call_type))
function_tokens += 2 + len(encoding.encode(tool_call_function_name))
function_tokens += 2 + len(encoding.encode(tool_call_function_arguments))

num_tokens += function_tokens

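With the normalization above, the counter accepts either raw OpenAI-style dicts or the pydantic `ToolCall` schema and should produce the same count for both. A usage sketch (the tool-call payload is invented for illustration, and it assumes `ToolCall` coerces the nested function dict the way pydantic models normally do):

```python
from letta.local_llm.utils import num_tokens_from_tool_calls
from letta.schemas.openai.chat_completion_request import ToolCall

# An invented tool call in the raw OpenAI wire format.
call_as_dict = {
    "id": "call_abc123",
    "type": "function",
    "function": {"name": "send_message", "arguments": '{"message": "hi"}'},
}

n_from_dict = num_tokens_from_tool_calls([call_as_dict], model="gpt-4")
n_from_model = num_tokens_from_tool_calls([ToolCall(**call_as_dict)], model="gpt-4")
assert n_from_dict == n_from_model
```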
4 changes: 4 additions & 0 deletions letta/schemas/memory.py
@@ -9,6 +9,7 @@

from letta.schemas.block import Block
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import Tool


class ContextWindowOverview(BaseModel):
@@ -41,6 +42,9 @@ class ContextWindowOverview(BaseModel):
num_tokens_summary_memory: int = Field(..., description="The number of tokens in the summary memory.")
summary_memory: Optional[str] = Field(None, description="The content of the summary memory.")

num_tokens_functions_definitions: int = Field(..., description="The number of tokens in the function definitions.")
functions_definitions: Optional[List[Tool]] = Field(..., description="The content of the function definitions.")

num_tokens_messages: int = Field(..., description="The number of tokens in the messages list.")
# TODO make list of messages?
# messages: List[dict] = Field(..., description="The messages in the context window.")
4 changes: 2 additions & 2 deletions letta/schemas/openai/chat_completion_request.py
@@ -74,7 +74,7 @@ class ToolFunctionChoice(BaseModel):
function: FunctionCall


ToolChoice = Union[Literal["none", "auto"], ToolFunctionChoice]
ToolChoice = Union[Literal["none", "auto", "required"], ToolFunctionChoice]


## tools ##
@@ -117,7 +117,7 @@ class ChatCompletionRequest(BaseModel):

# function-calling related
tools: Optional[List[Tool]] = None
tool_choice: Optional[ToolChoice] = "none"
tool_choice: Optional[ToolChoice] = None # "none" means don't call a tool
# deprecated scheme
functions: Optional[List[FunctionSchema]] = None
function_call: Optional[FunctionCallChoice] = None
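`ToolChoice` now admits OpenAI's `"required"` sentinel alongside the existing string values and the typed per-function choice, and the request default moves from `"none"` to `None` so the field is simply omitted when unset. A quick sketch of the accepted shapes (values are illustrative):

```python
from letta.schemas.openai.chat_completion_request import (
    ChatCompletionRequest,
    FunctionCall,
    ToolFunctionChoice,
)

# String sentinels validate...
req = ChatCompletionRequest(model="gpt-4", messages=[], tool_choice="required")

# ...as does forcing one specific function.
req = ChatCompletionRequest(
    model="gpt-4",
    messages=[],
    tool_choice=ToolFunctionChoice(type="function", function=FunctionCall(name="send_message")),
)

# Left unset, tool_choice now defaults to None rather than "none".
assert ChatCompletionRequest(model="gpt-4", messages=[]).tool_choice is None
```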
16 changes: 15 additions & 1 deletion letta/server/server.py
@@ -73,7 +73,12 @@
from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from letta.schemas.memory import (
ArchivalMemorySummary,
ContextWindowOverview,
Memory,
RecallMemorySummary,
)
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
from letta.schemas.organization import Organization, OrganizationCreate
from letta.schemas.passage import Passage
@@ -2177,3 +2182,12 @@ def add_llm_model(self, request: LLMConfig) -> LLMConfig:

def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
"""Add a new embedding model"""

def get_agent_context_window(
self,
user_id: str,
agent_id: str,
) -> ContextWindowOverview:
# Get the current context window overview from the agent
letta_agent = self._get_or_load_agent(agent_id=agent_id)
return letta_agent.get_context_window()
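The server method is a thin pass-through: it loads (or reuses) the agent and delegates to `Agent.get_context_window`. A usage sketch (assumes a configured `SyncServer`; the IDs are placeholders):

```python
from letta.server.server import SyncServer

server = SyncServer()  # assumes a locally configured letta setup
overview = server.get_agent_context_window(
    user_id="user-00000000",    # placeholder ID
    agent_id="agent-00000000",  # placeholder ID
)
print(overview.num_tokens_functions_definitions, overview.context_window_size_current)
```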
40 changes: 40 additions & 0 deletions tests/test_server.py
@@ -507,3 +507,43 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
args_json = json.loads(last_agent_message.tool_calls[0].function.arguments)
print(args_json)
assert "message" in args_json and args_json["message"] is not None and args_json["message"] != new_text


def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str):
"""Test that the context window overview fetch works"""

overview = server.get_agent_context_window(user_id=user_id, agent_id=agent_id)
assert overview is not None

# Run some basic checks
assert overview.context_window_size_max is not None
assert overview.context_window_size_current is not None
assert overview.num_archival_memory is not None
assert overview.num_recall_memory is not None
assert overview.num_tokens_external_memory_summary is not None
assert overview.num_tokens_system is not None
assert overview.system_prompt is not None
assert overview.num_tokens_core_memory is not None
assert overview.core_memory is not None
assert overview.num_tokens_summary_memory is not None
if overview.num_tokens_summary_memory > 0:
assert overview.summary_memory is not None
else:
assert overview.summary_memory is None
assert overview.num_tokens_functions_definitions is not None
if overview.num_tokens_functions_definitions > 0:
assert overview.functions_definitions is not None
else:
assert overview.functions_definitions is None
assert overview.num_tokens_messages is not None
assert overview.messages is not None

assert overview.context_window_size_max >= overview.context_window_size_current
assert overview.context_window_size_current == (
overview.num_tokens_system
+ overview.num_tokens_core_memory
+ overview.num_tokens_summary_memory
+ overview.num_tokens_messages
+ overview.num_tokens_functions_definitions
+ overview.num_tokens_external_memory_summary
)
