Skip to content

Commit

Permalink
feat: added ability to disable the initial message sequence during agent creation (#1978)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Nov 5, 2024
1 parent edebfc1 commit a5e9f7d
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 11 deletions.
38 changes: 33 additions & 5 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def __init__(
# extras
messages_total: Optional[int] = None, # TODO remove?
first_message_verify_mono: bool = True, # TODO move to config?
initial_message_sequence: Optional[List[Message]] = None,
):
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
# Hold a copy of the state that was used to init the agent
Expand Down Expand Up @@ -294,6 +295,7 @@ def __init__(

else:
printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")
assert self.agent_state.id is not None and self.agent_state.user_id is not None

# Generate a sequence of initial messages to put in the buffer
init_messages = initialize_message_sequence(
Expand All @@ -306,14 +308,40 @@ def __init__(
include_initial_boot_message=True,
)

# Cast the messages to actual Message objects to be synced to the DB
init_messages_objs = []
for msg in init_messages:
init_messages_objs.append(
if initial_message_sequence is not None:
# We always need the system prompt up front
system_message_obj = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict=init_messages[0],
)
# Don't use anything else in the pregen sequence, instead use the provided sequence
init_messages = [system_message_obj] + initial_message_sequence

else:
# Basic "more human than human" initial message sequence
init_messages = initialize_message_sequence(
model=self.model,
system=self.system,
memory=self.memory,
archival_memory=None,
recall_memory=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
)
# Cast to Message objects
init_messages = [
Message.dict_to_message(
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=msg
)
)
for msg in init_messages
]

# Cast the messages to actual Message objects to be synced to the DB
init_messages_objs = []
for msg in init_messages:
init_messages_objs.append(msg)
assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages)

# Put the messages inside the message buffer
Expand Down
14 changes: 13 additions & 1 deletion letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def create_agent(
# metadata
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
initial_message_sequence: Optional[List[Message]] = None,
) -> AgentState:
"""Create an agent
Expand Down Expand Up @@ -428,9 +429,18 @@ def create_agent(
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
)

# Use model_dump_json() instead of model_dump()
# If we use model_dump(), the datetime objects will not be serialized correctly
# response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers)
response = requests.post(
f"{self.base_url}/{self.api_prefix}/agents",
data=request.model_dump_json(), # Use model_dump_json() instead of json=model_dump()
headers={"Content-Type": "application/json", **self.headers},
)

response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}")
return AgentState(**response.json())
Expand Down Expand Up @@ -1648,6 +1658,7 @@ def create_agent(
# metadata
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
initial_message_sequence: Optional[List[Message]] = None,
) -> AgentState:
"""Create an agent
Expand Down Expand Up @@ -1702,6 +1713,7 @@ def create_agent(
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
),
actor=self.user,
)
Expand Down
8 changes: 6 additions & 2 deletions letta/schemas/agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
Expand Down Expand Up @@ -105,14 +104,19 @@ class Config:
class CreateAgent(BaseAgent):
# all optional as server can generate defaults
name: Optional[str] = Field(None, description="The name of the agent.")
message_ids: Optional[List[uuid.UUID]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
# Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence
# If the client wants to make this empty, then the client can set the arg to an empty list
initial_message_sequence: Optional[List[Message]] = Field(
None, description="The initial set of messages to put in the agent's in-context memory."
)

@field_validator("name")
@classmethod
Expand Down
2 changes: 2 additions & 0 deletions letta/schemas/letta_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class LettaBase(BaseModel):
from_attributes=True,
# throw errors if attributes are given that don't belong
extra="forbid",
# handle datetime serialization consistently across all models
# json_encoders={datetime: lambda dt: (dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt).isoformat()},
)

# def __id_prefix__(self):
Expand Down
9 changes: 7 additions & 2 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,15 +857,20 @@ def create_agent(
agent_state=agent_state,
tools=tool_objs,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
first_message_verify_mono=(
True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False
),
initial_message_sequence=request.initial_message_sequence,
)
elif request.agent_type == AgentType.o1_agent:
agent = O1Agent(
interface=interface,
agent_state=agent_state,
tools=tool_objs,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
first_message_verify_mono=(
True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False
),
)
# rebuilding agent memory on agent create in case shared memory blocks
# were specified in the new agent's memory config. we're doing this for two reasons:
Expand Down
76 changes: 75 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from dotenv import load_dotenv

from letta import create_client
from letta.agent import initialize_message_sequence
from letta.client.client import LocalClient, RESTClient
from letta.constants import DEFAULT_PRESET
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCallMessage,
Expand All @@ -28,6 +29,7 @@
from letta.schemas.usage import LettaUsageStatistics
from letta.services.tool_manager import ToolManager
from letta.settings import model_settings
from letta.utils import get_utc_time
from tests.helpers.client_helper import upload_file_using_client

# from tests.utils import create_config
Expand Down Expand Up @@ -598,3 +600,75 @@ def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState
# cleanup
client.delete_agent(agent_state1.id)
client.delete_agent(agent_state2.id)


@pytest.fixture
def cleanup_agents(client):
    """Collect agent IDs created during a test and delete them on teardown.

    Yields a list that the test body appends created agent IDs to. After the
    test finishes (pass or fail), each collected agent is deleted via the
    ``client`` fixture.

    Note: the original version referenced the module-level name ``client``,
    which is the pytest fixture *function*, not a client instance — so
    ``client.delete_agent`` raised AttributeError and the bare except silently
    swallowed it, leaking every agent. Requesting the ``client`` fixture as a
    parameter gives us the actual client instance.
    """
    created_agents = []
    yield created_agents
    # Cleanup runs even if the test body fails
    for agent_id in created_agents:
        try:
            client.delete_agent(agent_id)
        except Exception as e:
            # Best-effort cleanup: report but never fail teardown
            print(f"Failed to delete agent {agent_id}: {e}")


def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]):
    """Test that we can set an initial message sequence

    If we pass in None, we should get the "default" message sequence
    If we pass in a non-empty list, we should get that sequence (prefixed by the system message)
    If we pass in an empty list, we should get only the system message
    """

    # The reference initial message sequence:
    reference_init_messages = initialize_message_sequence(
        model=agent.llm_config.model,
        system=agent.system,
        memory=agent.memory,
        archival_memory=None,
        recall_memory=None,
        memory_edit_timestamp=get_utc_time(),
        include_initial_boot_message=True,
    )

    # system, login message, send_message test, send_message receipt
    assert len(reference_init_messages) > 0
    assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}"

    # Test with default sequence (None -> server populates the standard sequence)
    default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None)
    cleanup_agents.append(default_agent_state.id)
    assert default_agent_state.message_ids is not None
    assert len(default_agent_state.message_ids) > 0
    assert len(default_agent_state.message_ids) == len(
        reference_init_messages
    ), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}"

    # Test with empty sequence: the system prompt is always kept, so exactly one message remains
    empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[])
    cleanup_agents.append(empty_agent_state.id)
    assert empty_agent_state.message_ids is not None
    # NOTE: fixed the failure message — the assertion expects 1 (system message), not 0
    assert len(empty_agent_state.message_ids) == 1, f"Expected 1 message, got {len(empty_agent_state.message_ids)}"

    # Test with custom sequence: provided messages follow the system message
    custom_sequence = [
        Message(
            role=MessageRole.user,
            text="Hello, how are you?",
            user_id=agent.user_id,
            agent_id=agent.id,
            model=agent.llm_config.model,
            name=None,
            tool_calls=None,
            tool_call_id=None,
        ),
    ]
    custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence)
    cleanup_agents.append(custom_agent_state.id)
    assert custom_agent_state.message_ids is not None
    assert (
        len(custom_agent_state.message_ids) == len(custom_sequence) + 1
    ), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
    # The provided messages (after the system message at index 0) keep their IDs and order
    assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]

0 comments on commit a5e9f7d

Please sign in to comment.