diff --git a/letta/agent.py b/letta/agent.py index ee5bc01921..c9b55dbef7 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -39,6 +39,7 @@ get_login_event, package_function_response, package_summarize_message, + package_user_message, ) from letta.utils import ( count_tokens, @@ -200,16 +201,7 @@ class BaseAgent(ABC): @abstractmethod def step( self, - messages: Union[Message, List[Message], str], # TODO deprecate str inputs - first_message: bool = False, - first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, - skip_verify: bool = False, - return_dicts: bool = True, # if True, return dicts, if False, return Message objects - recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field - stream: bool = False, # TODO move to config? - timestamp: Optional[datetime.datetime] = None, - inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT, - ms: Optional[MetadataStore] = None, + messages: Union[Message, List[Message]], ) -> AgentStepResponse: """ Top-level event message handler for the agent. @@ -730,14 +722,13 @@ def _handle_ai_response( def step( self, - user_message: Union[Message, None, str], # NOTE: should be json.dump(dict) + messages: Union[Message, List[Message]], first_message: bool = False, first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, skip_verify: bool = False, return_dicts: bool = True, - recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field + # recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field stream: bool = False, # TODO move to config? - timestamp: Optional[datetime.datetime] = None, inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT, ms: Optional[MetadataStore] = None, ) -> AgentStepResponse: @@ -760,50 +751,13 @@ def step( self.rebuild_memory(force=True, ms=ms) # Step 1: add user message - if user_message is not None: - if isinstance(user_message, Message): - assert user_message.text is not None - - # Validate JSON via save/load - user_message_text = validate_json(user_message.text) - cleaned_user_message_text, name = strip_name_field_from_user_message(user_message_text) - - if name is not None: - # Update Message object - user_message.text = cleaned_user_message_text - user_message.name = name + if isinstance(messages, Message): + messages = [messages] - # Recreate timestamp - if recreate_message_timestamp: - user_message.created_at = get_utc_time() + if not all(isinstance(m, Message) for m in messages): + raise ValueError(f"messages should be a Message or a list of Message, got {type(messages)}") - elif isinstance(user_message, str): - # Validate JSON via save/load - user_message = validate_json(user_message) - cleaned_user_message_text, name = strip_name_field_from_user_message(user_message) - - # If user_message['name'] is not None, it will be handled properly by dict_to_message - # So no need to run strip_name_field_from_user_message - - # Create the associated Message object (in the database) - user_message = Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.user_id, - model=self.model, - openai_message_dict={"role": "user", "content": cleaned_user_message_text, "name": name}, - created_at=timestamp, - ) - - else: - raise ValueError(f"Bad type for user_message: {type(user_message)}") - - self.interface.user_message(user_message.text, msg_obj=user_message) - - input_message_sequence = self._messages + [user_message] - - 
# Alternatively, the requestor can send an empty user message - else: - input_message_sequence = self._messages + input_message_sequence = self._messages + messages if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user": printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue") @@ -846,11 +800,8 @@ def step( ) # Step 6: extend the message history - if user_message is not None: - if isinstance(user_message, Message): - all_new_messages = [user_message] + all_response_messages - else: - raise ValueError(type(user_message)) + if len(messages) > 0: + all_new_messages = messages + all_response_messages else: all_new_messages = all_response_messages @@ -897,7 +848,7 @@ def step( ) except Exception as e: - printd(f"step() failed\nuser_message = {user_message}\nerror = {e}") + printd(f"step() failed\nmessages = {messages}\nerror = {e}") # If we got a context alert, try trimming the messages length, then try again if is_context_overflow_error(e): @@ -906,14 +857,14 @@ def step( # Try step again return self.step( - user_message, + messages=messages, first_message=first_message, first_message_retry_limit=first_message_retry_limit, skip_verify=skip_verify, return_dicts=return_dicts, - recreate_message_timestamp=recreate_message_timestamp, + # recreate_message_timestamp=recreate_message_timestamp, stream=stream, - timestamp=timestamp, + # timestamp=timestamp, inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option, ms=ms, ) @@ -922,6 +873,40 @@ def step( printd(f"step() failed with an unrecognized exception: '{str(e)}'") raise e + def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse: + """Takes a basic user message string, turns it into a stringified JSON with extra metadata, then sends it to the agent + + Example: + -> user_message_str = 'hi' + -> {'message': 'hi', 'type': 'user_message', ...} + -> json.dumps(...) 
+ -> agent.step(messages=[Message(role='user', text=...)]) + """ + # Wrap with metadata, dumps to JSON + assert user_message_str and isinstance( + user_message_str, str + ), f"user_message_str should be a non-empty string, got {type(user_message_str)}" + user_message_json_str = package_user_message(user_message_str) + + # Validate JSON via save/load + user_message = validate_json(user_message_json_str) + cleaned_user_message_text, name = strip_name_field_from_user_message(user_message) + + # Turn into a dict + openai_message_dict = {"role": "user", "content": cleaned_user_message_text, "name": name} + + # Create the associated Message object (in the database) + assert self.agent_state.user_id is not None, "User ID is not set" + user_message = Message.dict_to_message( + agent_id=self.agent_state.id, + user_id=self.agent_state.user_id, + model=self.model, + openai_message_dict=openai_message_dict, + # created_at=timestamp, + ) + + return self.step(messages=[user_message], **kwargs) + def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True): assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})" @@ -1340,7 +1325,8 @@ def retry_message(self) -> List[Message]: self.pop_until_user() user_message = self.pop_message(count=1)[0] - step_response = self.step(user_message=user_message.text, return_dicts=False) + assert user_message.text is not None, "User message text is None" + step_response = self.step_user_message(user_message_str=user_message.text, return_dicts=False) messages = step_response.messages assert messages is not None diff --git a/letta/main.py b/letta/main.py index f663148de3..b084333caf 100644 --- a/letta/main.py +++ b/letta/main.py @@ -356,19 +356,29 @@ def run_agent_loop( else: # If message did not begin with command prefix, pass inputs to Letta # Handle user message and append to messages - user_message = system.package_user_message(user_input) + user_message = str(user_input) skip_next_user_input = False def process_agent_step(user_message, no_verify): - step_response = letta_agent.step( - user_message, - first_message=False, - skip_verify=no_verify, - stream=stream, - inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs, - ms=ms, - ) + if user_message is None: + step_response = letta_agent.step( + messages=[], + first_message=False, + skip_verify=no_verify, + stream=stream, + inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs, + ms=ms, + ) + else: + step_response = letta_agent.step_user_message( + user_message_str=user_message, + first_message=False, + skip_verify=no_verify, + stream=stream, + inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs, + ms=ms, + ) new_messages = step_response.messages heartbeat_request = step_response.heartbeat_request function_failed = step_response.function_failed diff --git a/letta/server/server.py b/letta/server/server.py index fcb00962ff..494f2cd51c 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -383,9 +383,22 @@ def _get_or_load_agent(self, agent_id: str) -> Agent: letta_agent = self._load_agent(user_id=user_id, agent_id=agent_id) return letta_agent - def _step(self, user_id: str, agent_id: str, input_message: Union[str, Message], timestamp: Optional[datetime]) -> LettaUsageStatistics: + def _step( + self, + user_id: str, + agent_id: str, + input_messages: Union[Message, List[Message]], + # timestamp: Optional[datetime], + ) -> LettaUsageStatistics: """Send the input message through the agent""" - 
logger.debug(f"Got input message: {input_message}") + + # Input validation + if isinstance(input_messages, Message): + input_messages = [input_messages] + if not all(isinstance(m, Message) for m in input_messages): + raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}") + + logger.debug(f"Got input messages: {input_messages}") try: # Get the agent object (loaded in memory) @@ -398,18 +411,18 @@ def _step(self, user_id: str, agent_id: str, input_message: Union[str, Message], logger.debug(f"Starting agent step") no_verify = True - next_input_message = input_message + next_input_message = input_messages counter = 0 total_usage = UsageStatistics() step_count = 0 while True: step_response = letta_agent.step( - next_input_message, + messages=next_input_message, first_message=False, skip_verify=no_verify, return_dicts=False, stream=token_streaming, - timestamp=timestamp, + # timestamp=timestamp, ms=self.ms, ) step_response.messages @@ -436,13 +449,40 @@ def _step(self, user_id: str, agent_id: str, input_message: Union[str, Message], break # Chain handlers elif token_warning: - next_input_message = system.get_token_limit_warning() + assert letta_agent.agent_state.user_id is not None + next_input_message = Message.dict_to_message( + agent_id=letta_agent.agent_state.id, + user_id=letta_agent.agent_state.user_id, + model=letta_agent.model, + openai_message_dict={ + "role": "user", # TODO: change to system? + "content": system.get_token_limit_warning(), + }, + ) continue # always chain elif function_failed: - next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE) + assert letta_agent.agent_state.user_id is not None + next_input_message = Message.dict_to_message( + agent_id=letta_agent.agent_state.id, + user_id=letta_agent.agent_state.user_id, + model=letta_agent.model, + openai_message_dict={ + "role": "user", # TODO: change to system? + "content": system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE), + }, + ) continue # always chain elif heartbeat_request: - next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE) + assert letta_agent.agent_state.user_id is not None + next_input_message = Message.dict_to_message( + agent_id=letta_agent.agent_state.id, + user_id=letta_agent.agent_state.user_id, + model=letta_agent.model, + openai_message_dict={ + "role": "user", # TODO: change to system? 
+ "content": system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE), + }, + ) continue # always chain # Letta no-op / yield else: @@ -621,7 +661,7 @@ def user_message( ) # Run the agent state forward - usage = self._step(user_id=user_id, agent_id=agent_id, input_message=message, timestamp=timestamp) + usage = self._step(user_id=user_id, agent_id=agent_id, input_messages=message) return usage def system_message( @@ -669,7 +709,7 @@ def system_message( if isinstance(message, Message): # Can't have a null text field - if len(message.text) == 0 or message.text is None: + if message.text is None or len(message.text) == 0: raise ValueError(f"Invalid input: '{message.text}'") # If the input begins with a command prefix, reject elif message.text.startswith("/"): @@ -683,7 +723,7 @@ def system_message( message.created_at = timestamp # Run the agent state forward - return self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message, timestamp=timestamp) + return self._step(user_id=user_id, agent_id=agent_id, input_messages=message) # @LockingServer.agent_lock_decorator def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
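
Reviewer note — call-site migration sketch (not part of the patch). The snippet below shows how a caller might use the Message-based API introduced above: `step_user_message()` for raw user text and `step(messages=[...])` for pre-built `Message` objects. The `agent`/`ms` instances and the exact import paths are assumptions for illustration; the method names, keyword arguments, and `Message.dict_to_message(...)` usage mirror the hunks in this diff.

```python
# Sketch only: assumes an existing Agent instance `agent` and MetadataStore `ms`.
# Import paths below are assumptions, not taken from this diff.
from letta import constants, system
from letta.schemas.message import Message

# Before this diff: callers packaged the user input into a JSON string themselves.
#   packaged_str = system.package_user_message("hi")
#   agent.step(packaged_str, return_dicts=False, ms=ms)

# After this diff: raw user text goes through the new helper, which packages the
# string, validates the JSON, builds a Message row, and forwards it to step().
step_response = agent.step_user_message(user_message_str="hi", return_dicts=False, ms=ms)

# Callers that already hold Message objects (e.g. the server's heartbeat chaining)
# construct them explicitly and pass a list to step().
heartbeat = Message.dict_to_message(
    agent_id=agent.agent_state.id,
    user_id=agent.agent_state.user_id,
    model=agent.model,
    openai_message_dict={
        "role": "user",
        "content": system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE),
    },
)
step_response = agent.step(messages=[heartbeat], return_dicts=False, ms=ms)
```

The net effect of the refactor is that `step()` now operates purely on `Message` objects, while string packaging and JSON validation live in `step_user_message()` on the agent and in the callers (`letta/main.py`, `letta/server/server.py`).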