Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added ability to disable the initial message sequence during agent creation #1978

Merged
merged 5 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def __init__(
# extras
messages_total: Optional[int] = None, # TODO remove?
first_message_verify_mono: bool = True, # TODO move to config?
initial_message_sequence: Optional[List[Message]] = None,
):
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
# Hold a copy of the state that was used to init the agent
Expand Down Expand Up @@ -294,6 +295,7 @@ def __init__(

else:
printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")
assert self.agent_state.id is not None and self.agent_state.user_id is not None

# Generate a sequence of initial messages to put in the buffer
init_messages = initialize_message_sequence(
Expand All @@ -306,14 +308,40 @@ def __init__(
include_initial_boot_message=True,
)

# Cast the messages to actual Message objects to be synced to the DB
init_messages_objs = []
for msg in init_messages:
init_messages_objs.append(
if initial_message_sequence is not None:
# We always need the system prompt up front
system_message_obj = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict=init_messages[0],
)
# Don't use anything else in the pregen sequence, instead use the provided sequence
init_messages = [system_message_obj] + initial_message_sequence

else:
# Basic "more human than human" initial message sequence
init_messages = initialize_message_sequence(
model=self.model,
system=self.system,
memory=self.memory,
archival_memory=None,
recall_memory=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
)
# Cast to Message objects
init_messages = [
Message.dict_to_message(
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=msg
)
)
for msg in init_messages
]

# Cast the messages to actual Message objects to be synced to the DB
init_messages_objs = []
for msg in init_messages:
init_messages_objs.append(msg)
assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages)

# Put the messages inside the message buffer
Expand Down
14 changes: 13 additions & 1 deletion letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def create_agent(
# metadata
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
initial_message_sequence: Optional[List[Message]] = None,
) -> AgentState:
"""Create an agent

Expand Down Expand Up @@ -428,9 +429,18 @@ def create_agent(
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
)

# Use model_dump_json() instead of model_dump()
sarahwooders marked this conversation as resolved.
Show resolved Hide resolved
# If we use model_dump(), the datetime objects will not be serialized correctly
# response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers)
response = requests.post(
f"{self.base_url}/{self.api_prefix}/agents",
data=request.model_dump_json(), # Use model_dump_json() instead of json=model_dump()
headers={"Content-Type": "application/json", **self.headers},
)

response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}")
return AgentState(**response.json())
Expand Down Expand Up @@ -1648,6 +1658,7 @@ def create_agent(
# metadata
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
initial_message_sequence: Optional[List[Message]] = None,
) -> AgentState:
"""Create an agent

Expand Down Expand Up @@ -1702,6 +1713,7 @@ def create_agent(
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
),
actor=self.user,
)
Expand Down
8 changes: 6 additions & 2 deletions letta/schemas/agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
Expand Down Expand Up @@ -105,14 +104,19 @@ class Config:
class CreateAgent(BaseAgent):
# all optional as server can generate defaults
name: Optional[str] = Field(None, description="The name of the agent.")
message_ids: Optional[List[uuid.UUID]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
sarahwooders marked this conversation as resolved.
Show resolved Hide resolved
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
# Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence
# If the client wants to make this empty, then the client can set the arg to an empty list
initial_message_sequence: Optional[List[Message]] = Field(
None, description="The initial set of messages to put in the agent's in-context memory."
)

@field_validator("name")
@classmethod
Expand Down
3 changes: 3 additions & 0 deletions letta/schemas/letta_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from datetime import datetime, timezone
from logging import getLogger
from typing import Optional
from uuid import UUID
Expand All @@ -21,6 +22,8 @@ class LettaBase(BaseModel):
from_attributes=True,
# throw errors if attributes are given that don't belong
extra="forbid",
# handle datetime serialization consistently across all models
json_encoders={datetime: lambda dt: (dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt).isoformat()},
)

# def __id_prefix__(self):
Expand Down
9 changes: 7 additions & 2 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,15 +857,20 @@ def create_agent(
agent_state=agent_state,
tools=tool_objs,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
first_message_verify_mono=(
True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False
),
initial_message_sequence=request.initial_message_sequence,
)
elif request.agent_type == AgentType.o1_agent:
agent = O1Agent(
interface=interface,
agent_state=agent_state,
tools=tool_objs,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
first_message_verify_mono=(
True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False
),
)
# rebuilding agent memory on agent create in case shared memory blocks
# were specified in the new agent's memory config. we're doing this for two reasons:
Expand Down
76 changes: 75 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from dotenv import load_dotenv

from letta import create_client
from letta.agent import initialize_message_sequence
from letta.client.client import LocalClient, RESTClient
from letta.constants import DEFAULT_PRESET
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCallMessage,
Expand All @@ -28,6 +29,7 @@
from letta.schemas.usage import LettaUsageStatistics
from letta.services.tool_manager import ToolManager
from letta.settings import model_settings
from letta.utils import get_utc_time
from tests.helpers.client_helper import upload_file_using_client

# from tests.utils import create_config
Expand Down Expand Up @@ -598,3 +600,75 @@ def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState
# cleanup
client.delete_agent(agent_state1.id)
client.delete_agent(agent_state2.id)


@pytest.fixture
def cleanup_agents():
created_agents = []
yield created_agents
# Cleanup will run even if test fails
for agent_id in created_agents:
try:
client.delete_agent(agent_id)
except Exception as e:
print(f"Failed to delete agent {agent_id}: {e}")


def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]):
sarahwooders marked this conversation as resolved.
Show resolved Hide resolved
"""Test that we can set an initial message sequence

If we pass in None, we should get a "default" message sequence
If we pass in a non-empty list, we should get that sequence
If we pass in an empty list, we should get an empty sequence
"""

# The reference initial message sequence:
reference_init_messages = initialize_message_sequence(
model=agent.llm_config.model,
system=agent.system,
memory=agent.memory,
archival_memory=None,
recall_memory=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
)

# system, login message, send_message test, send_message receipt
assert len(reference_init_messages) > 0
assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}"

# Test with default sequence
default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None)
cleanup_agents.append(default_agent_state.id)
assert default_agent_state.message_ids is not None
assert len(default_agent_state.message_ids) > 0
assert len(default_agent_state.message_ids) == len(
reference_init_messages
), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}"

# Test with empty sequence
empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[])
cleanup_agents.append(empty_agent_state.id)
assert empty_agent_state.message_ids is not None
assert len(empty_agent_state.message_ids) == 1, f"Expected 0 messages, got {len(empty_agent_state.message_ids)}"

# Test with custom sequence
custom_sequence = [
Message(
role=MessageRole.user,
text="Hello, how are you?",
user_id=agent.user_id,
agent_id=agent.id,
model=agent.llm_config.model,
name=None,
tool_calls=None,
tool_call_id=None,
),
]
custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence)
cleanup_agents.append(custom_agent_state.id)
assert custom_agent_state.message_ids is not None
assert (
len(custom_agent_state.message_ids) == len(custom_sequence) + 1
), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]
Loading