Skip to content

Commit

Permalink
feat: add agent types (#1831)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vivek Verma authored Oct 8, 2024
1 parent 4a01ca3 commit 90a1e3b
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docker-compose-vllm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ services:
ports:
- "8000:8000"
command: >
--model ${LETTA_LLM_MODEL} --max_model_len=8000
--model ${LETTA_LLM_MODEL} --max_model_len=8000
# Replace with your model
ipc: host
ipc: host
9 changes: 8 additions & 1 deletion letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code
from letta.memory import get_memory_functions
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
from letta.schemas.block import (
Block,
CreateBlock,
Expand Down Expand Up @@ -68,6 +68,7 @@ def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str]
def create_agent(
self,
name: Optional[str] = None,
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
Expand Down Expand Up @@ -319,6 +320,8 @@ def agent_exists(self, agent_id: str) -> bool:
def create_agent(
self,
name: Optional[str] = None,
# agent config
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
# model configs
embedding_config: EmbeddingConfig = None,
llm_config: LLMConfig = None,
Expand Down Expand Up @@ -381,6 +384,7 @@ def create_agent(
memory=memory,
tools=tool_names,
system=system,
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
)
Expand Down Expand Up @@ -1462,6 +1466,8 @@ def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str]
def create_agent(
self,
name: Optional[str] = None,
# agent config
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
# model configs
embedding_config: EmbeddingConfig = None,
llm_config: LLMConfig = None,
Expand Down Expand Up @@ -1524,6 +1530,7 @@ def create_agent(
memory=memory,
tools=tool_names,
system=system,
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
),
Expand Down
2 changes: 2 additions & 0 deletions letta/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class AgentModel(Base):
tools = Column(JSON)

# configs
agent_type = Column(String)
llm_config = Column(LLMConfigColumn)
embedding_config = Column(EmbeddingConfigColumn)

Expand All @@ -243,6 +244,7 @@ def to_record(self) -> AgentState:
memory=Memory.load(self.memory), # load dictionary
system=self.system,
tools=self.tools,
agent_type=self.agent_type,
llm_config=self.llm_config,
embedding_config=self.embedding_config,
metadata_=self.metadata_,
Expand Down
14 changes: 14 additions & 0 deletions letta/schemas/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field, field_validator
Expand All @@ -21,6 +22,15 @@ class BaseAgent(LettaBase, validate_assignment=True):
user_id: Optional[str] = Field(None, description="The user id of the agent.")


class AgentType(str, Enum):
"""
Enum to represent the type of agent.
"""

memgpt_agent = "memgpt_agent"
split_thread_agent = "split_thread_agent"


class AgentState(BaseAgent):
"""
Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
Expand Down Expand Up @@ -52,6 +62,9 @@ class AgentState(BaseAgent):
# system prompt
system: str = Field(..., description="The system prompt used by the agent.")

# agent configuration
agent_type: AgentType = Field(..., description="The type of agent.")

# llm information
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
Expand All @@ -64,6 +77,7 @@ class CreateAgent(BaseAgent):
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")

Expand Down
8 changes: 6 additions & 2 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
OpenAIProvider,
VLLMProvider,
)
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
from letta.schemas.api_key import APIKey, APIKeyCreate
from letta.schemas.block import (
Block,
Expand Down Expand Up @@ -335,7 +335,10 @@ def _load_agent(self, user_id: str, agent_id: str, interface: Union[AgentInterfa
# Make sure the memory is a memory object
assert isinstance(agent_state.memory, Memory)

letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
if agent_state.agent_type == AgentType.memgpt_agent:
letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
else:
raise NotImplementedError("Only base agents are supported as of right now!")

# Add the agent to the in-memory store and return its reference
logger.debug(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}")
Expand Down Expand Up @@ -787,6 +790,7 @@ def create_agent(
name=request.name,
user_id=user_id,
tools=request.tools if request.tools else [],
agent_type=request.agent_type or AgentType.memgpt_agent,
llm_config=llm_config,
embedding_config=embedding_config,
system=request.system,
Expand Down

0 comments on commit 90a1e3b

Please sign in to comment.