Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: simplify Agent.step inputs to Message or List[Message] only #1879

Merged
merged 3 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 51 additions & 65 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
get_login_event,
package_function_response,
package_summarize_message,
package_user_message,
)
from letta.utils import (
count_tokens,
Expand Down Expand Up @@ -200,16 +201,7 @@ class BaseAgent(ABC):
@abstractmethod
def step(
self,
messages: Union[Message, List[Message], str], # TODO deprecate str inputs
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
return_dicts: bool = True, # if True, return dicts, if False, return Message objects
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
stream: bool = False, # TODO move to config?
timestamp: Optional[datetime.datetime] = None,
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
messages: Union[Message, List[Message]],
) -> AgentStepResponse:
"""
Top-level event message handler for the agent.
Expand Down Expand Up @@ -730,14 +722,13 @@ def _handle_ai_response(

def step(
self,
user_message: Union[Message, None, str], # NOTE: should be json.dump(dict)
messages: Union[Message, List[Message]],
sarahwooders marked this conversation as resolved.
Show resolved Hide resolved
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
return_dicts: bool = True,
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
# recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
stream: bool = False, # TODO move to config?
timestamp: Optional[datetime.datetime] = None,
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
) -> AgentStepResponse:
Expand All @@ -760,50 +751,13 @@ def step(
self.rebuild_memory(force=True, ms=ms)

# Step 1: add user message
if user_message is not None:
if isinstance(user_message, Message):
assert user_message.text is not None

# Validate JSON via save/load
user_message_text = validate_json(user_message.text)
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message_text)

if name is not None:
# Update Message object
user_message.text = cleaned_user_message_text
user_message.name = name
if isinstance(messages, Message):
messages = [messages]

# Recreate timestamp
if recreate_message_timestamp:
user_message.created_at = get_utc_time()
if not all(isinstance(m, Message) for m in messages):
raise ValueError(f"messages should be a Message or a list of Message, got {type(messages)}")

elif isinstance(user_message, str):
# Validate JSON via save/load
user_message = validate_json(user_message)
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message)

# If user_message['name'] is not None, it will be handled properly by dict_to_message
# So no need to run strip_name_field_from_user_message

# Create the associated Message object (in the database)
user_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={"role": "user", "content": cleaned_user_message_text, "name": name},
created_at=timestamp,
)

else:
raise ValueError(f"Bad type for user_message: {type(user_message)}")

self.interface.user_message(user_message.text, msg_obj=user_message)

input_message_sequence = self._messages + [user_message]

# Alternatively, the requestor can send an empty user message
else:
input_message_sequence = self._messages
input_message_sequence = self._messages + messages

if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
Expand Down Expand Up @@ -846,11 +800,8 @@ def step(
)

# Step 6: extend the message history
if user_message is not None:
if isinstance(user_message, Message):
all_new_messages = [user_message] + all_response_messages
else:
raise ValueError(type(user_message))
if len(messages) > 0:
all_new_messages = messages + all_response_messages
else:
all_new_messages = all_response_messages

Expand Down Expand Up @@ -897,7 +848,7 @@ def step(
)

except Exception as e:
printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
printd(f"step() failed\nmessages = {messages}\nerror = {e}")

# If we got a context alert, try trimming the messages length, then try again
if is_context_overflow_error(e):
Expand All @@ -906,14 +857,14 @@ def step(

# Try step again
return self.step(
user_message,
messages=messages,
first_message=first_message,
first_message_retry_limit=first_message_retry_limit,
skip_verify=skip_verify,
return_dicts=return_dicts,
recreate_message_timestamp=recreate_message_timestamp,
# recreate_message_timestamp=recreate_message_timestamp,
stream=stream,
timestamp=timestamp,
# timestamp=timestamp,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
ms=ms,
)
Expand All @@ -922,6 +873,40 @@ def step(
printd(f"step() failed with an unrecognized exception: '{str(e)}'")
raise e

def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse:
    """Run a single agent step starting from a plain user string.

    The raw string is packaged with ``package_user_message``, passed through
    ``validate_json`` and ``strip_name_field_from_user_message``, wrapped in a
    user-role ``Message`` and finally handed to ``step``.

    Example:
        -> user_message_str = 'hi'
        -> {'message': 'hi', 'type': 'user_message', ...}
        -> json.dumps(...)
        -> agent.step(messages=[Message(role='user', text=...)])

    Args:
        user_message_str: Non-empty plain-text user input.
        **kwargs: Forwarded unchanged to ``step``.

    Returns:
        The ``AgentStepResponse`` produced by ``step`` for the single
        constructed message.

    Raises:
        AssertionError: If ``user_message_str`` is empty or not a ``str``, or
            if the agent state has no user ID.
    """
    # Guard the input before any packaging work is attempted.
    is_nonempty_str = user_message_str and isinstance(user_message_str, str)
    assert is_nonempty_str, f"user_message_str should be a non-empty string, got {type(user_message_str)}"

    # Wrap with metadata (JSON string), then validate via save/load.
    packaged_json = package_user_message(user_message_str)
    validated_json = validate_json(packaged_json)

    # Separate the optional sender name from the message content.
    content, sender_name = strip_name_field_from_user_message(validated_json)

    # A Message cannot be attributed without an owning user.
    assert self.agent_state.user_id is not None, "User ID is not set"

    # NOTE(review): no explicit created_at is passed here, so the timestamp is
    # whatever Message.dict_to_message defaults to -- confirm this is intended.
    message_obj = Message.dict_to_message(
        agent_id=self.agent_state.id,
        user_id=self.agent_state.user_id,
        model=self.model,
        openai_message_dict={"role": "user", "content": content, "name": sender_name},
    )

    return self.step(messages=[message_obj], **kwargs)

def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"

Expand Down Expand Up @@ -1340,7 +1325,8 @@ def retry_message(self) -> List[Message]:

self.pop_until_user()
user_message = self.pop_message(count=1)[0]
step_response = self.step(user_message=user_message.text, return_dicts=False)
assert user_message.text is not None, "User message text is None"
step_response = self.step_user_message(user_message_str=user_message.text, return_dicts=False)
messages = step_response.messages

assert messages is not None
Expand Down
28 changes: 19 additions & 9 deletions letta/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,19 +356,29 @@ def run_agent_loop(
else:
# If message did not begin with command prefix, pass inputs to Letta
# Handle user message and append to messages
user_message = system.package_user_message(user_input)
user_message = str(user_input)

skip_next_user_input = False

def process_agent_step(user_message, no_verify):
step_response = letta_agent.step(
user_message,
first_message=False,
skip_verify=no_verify,
stream=stream,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs,
ms=ms,
)
if user_message is None:
step_response = letta_agent.step(
messages=[],
first_message=False,
skip_verify=no_verify,
stream=stream,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs,
ms=ms,
)
else:
step_response = letta_agent.step_user_message(
user_message_str=user_message,
first_message=False,
skip_verify=no_verify,
stream=stream,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs,
ms=ms,
)
new_messages = step_response.messages
heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
Expand Down
62 changes: 51 additions & 11 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,22 @@ def _get_or_load_agent(self, agent_id: str) -> Agent:
letta_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
return letta_agent

def _step(self, user_id: str, agent_id: str, input_message: Union[str, Message], timestamp: Optional[datetime]) -> LettaUsageStatistics:
def _step(
self,
user_id: str,
agent_id: str,
input_messages: Union[Message, List[Message]],
# timestamp: Optional[datetime],
) -> LettaUsageStatistics:
"""Send the input message through the agent"""
logger.debug(f"Got input message: {input_message}")

# Input validation
if isinstance(input_messages, Message):
input_messages = [input_messages]
if not all(isinstance(m, Message) for m in input_messages):
raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}")

logger.debug(f"Got input messages: {input_messages}")
try:

# Get the agent object (loaded in memory)
Expand All @@ -398,18 +411,18 @@ def _step(self, user_id: str, agent_id: str, input_message: Union[str, Message],

logger.debug(f"Starting agent step")
no_verify = True
next_input_message = input_message
next_input_message = input_messages
counter = 0
total_usage = UsageStatistics()
step_count = 0
while True:
step_response = letta_agent.step(
next_input_message,
messages=next_input_message,
first_message=False,
skip_verify=no_verify,
return_dicts=False,
stream=token_streaming,
timestamp=timestamp,
# timestamp=timestamp,
ms=self.ms,
)
step_response.messages
Expand All @@ -436,13 +449,40 @@ def _step(self, user_id: str, agent_id: str, input_message: Union[str, Message],
break
# Chain handlers
elif token_warning:
next_input_message = system.get_token_limit_warning()
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_token_limit_warning(),
},
)
continue # always chain
elif function_failed:
next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE),
},
)
continue # always chain
elif heartbeat_request:
next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE),
},
)
continue # always chain
# Letta no-op / yield
else:
Expand Down Expand Up @@ -621,7 +661,7 @@ def user_message(
)

# Run the agent state forward
usage = self._step(user_id=user_id, agent_id=agent_id, input_message=message, timestamp=timestamp)
usage = self._step(user_id=user_id, agent_id=agent_id, input_messages=message)
return usage

def system_message(
Expand Down Expand Up @@ -669,7 +709,7 @@ def system_message(

if isinstance(message, Message):
# Can't have a null text field
if len(message.text) == 0 or message.text is None:
if message.text is None or len(message.text) == 0:
raise ValueError(f"Invalid input: '{message.text}'")
# If the input begins with a command prefix, reject
elif message.text.startswith("/"):
Expand All @@ -683,7 +723,7 @@ def system_message(
message.created_at = timestamp

# Run the agent state forward
return self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message, timestamp=timestamp)
return self._step(user_id=user_id, agent_id=agent_id, input_messages=message)

# @LockingServer.agent_lock_decorator
def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
Expand Down
Loading