Skip to content

Commit

Permalink
feat: Add ability to add tags to agents (#1984)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzh72 authored Nov 7, 2024
1 parent 204a570 commit dca47fc
Show file tree
Hide file tree
Showing 14 changed files with 465 additions and 197 deletions.
52 changes: 52 additions & 0 deletions alembic/versions/b6d7ca024aa9_add_agents_tags_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Add agents tags table
Revision ID: b6d7ca024aa9
Revises: d14ae606614c
Create Date: 2024-11-06 10:48:08.424108
"""

from typing import Sequence, Union

import sqlalchemy as sa

Check failure on line 11 in alembic/versions/b6d7ca024aa9_add_agents_tags_table.py

View workflow job for this annotation

GitHub Actions / style-checks (3.12)

Import "sqlalchemy" could not be resolved (reportMissingImports)

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "b6d7ca024aa9"
down_revision: Union[str, None] = "d14ae606614c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"agents_tags",
sa.Column("agent_id", sa.String(), nullable=False),
sa.Column("tag", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["agent_id"],
["agents.id"],
),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("agent_id", "id"),
sa.UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("agents_tags")
# ### end Alembic commands ###
21 changes: 13 additions & 8 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,12 @@ def __init__(
self._default_llm_config = default_llm_config
self._default_embedding_config = default_embedding_config

def list_agents(self) -> List[AgentState]:
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers)
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
params = {}
if tags:
params["tags"] = tags

response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params)
return [AgentState(**agent) for agent in response.json()]

def agent_exists(self, agent_id: str) -> bool:
Expand Down Expand Up @@ -480,6 +484,7 @@ def update_agent(
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
Expand Down Expand Up @@ -509,6 +514,7 @@ def update_agent(
name=name,
system=system,
tools=tools,
tags=tags,
description=description,
metadata_=metadata,
llm_config=llm_config,
Expand Down Expand Up @@ -1617,13 +1623,10 @@ def __init__(
self.organization = self.server.get_organization_or_default(self.org_id)

# agents
def list_agents(self) -> List[AgentState]:
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
self.interface.clear()

# TODO: fix the server function
# return self.server.list_agents(user_id=self.user_id)

return self.server.ms.list_agents(user_id=self.user_id)
return self.server.list_agents(user_id=self.user_id, tags=tags)

def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
"""
Expand Down Expand Up @@ -1757,6 +1760,7 @@ def update_agent(
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
Expand Down Expand Up @@ -1788,6 +1792,7 @@ def update_agent(
name=name,
system=system,
tools=tools,
tags=tags,
description=description,
metadata_=metadata,
llm_config=llm_config,
Expand Down Expand Up @@ -1872,7 +1877,7 @@ def get_agent_by_name(self, agent_name: str) -> AgentState:
agent_state (AgentState): State of the agent
"""
self.interface.clear()
return self.server.get_agent(agent_name=agent_name, user_id=self.user_id, agent_id=None)
return self.server.get_agent_state(agent_name=agent_name, user_id=self.user_id, agent_id=None)

def get_agent(self, agent_id: str) -> AgentState:
"""
Expand Down
2 changes: 2 additions & 0 deletions letta/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def create_agent(self, agent: AgentState):
fields = vars(agent)
fields["memory"] = agent.memory.to_dict()
del fields["_internal_memory"]
del fields["tags"]
session.add(AgentModel(**fields))
session.commit()

Expand Down Expand Up @@ -531,6 +532,7 @@ def update_agent(self, agent: AgentState):
if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever
fields["memory"] = agent.memory.to_dict()
del fields["_internal_memory"]
del fields["tags"]
session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
session.commit()

Expand Down
28 changes: 28 additions & 0 deletions letta/orm/agents_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING

from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags

if TYPE_CHECKING:
from letta.orm.organization import Organization


class AgentsTags(SqlalchemyBase, OrganizationMixin):
"""Associates tags with agents, allowing agents to have multiple tags and supporting tag-based filtering."""

__tablename__ = "agents_tags"
__pydantic_model__ = PydanticAgentsTags
__table_args__ = (UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),)

# The agent associated with this tag
agent_id = mapped_column(String, ForeignKey("agents.id"), primary_key=True)

# The name of the tag
tag: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the tag associated with the agent.")

# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents_tags")
1 change: 1 addition & 0 deletions letta/orm/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class Organization(SqlalchemyBase):

users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan")

# TODO: Map these relationships later when we actually make these models
# below is just a suggestion
Expand Down
16 changes: 16 additions & 0 deletions letta/orm/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,22 @@ def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type[
self.is_deleted = True
return self.update(db_session)

def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
"""Permanently removes the record from the database."""
if actor:
logger.info(f"User {actor.id} requested hard deletion of {self.__class__.__name__} with ID {self.id}")

with db_session as session:
try:
session.delete(self)
session.commit()
except Exception as e:
session.rollback()
logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}")
raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}")
else:
logger.info(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")

def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
if actor:
self._set_created_and_updated_by_fields(actor.id)
Expand Down
5 changes: 5 additions & 0 deletions letta/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class AgentState(BaseAgent, validate_assignment=True):
# tool rules
tool_rules: Optional[List[BaseToolRule]] = Field(default=None, description="The list of tool rules.")

# tags
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")

# system prompt
system: str = Field(..., description="The system prompt used by the agent.")

Expand Down Expand Up @@ -108,6 +111,7 @@ class CreateAgent(BaseAgent):
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.")
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
Expand Down Expand Up @@ -148,6 +152,7 @@ class UpdateAgentState(BaseAgent):
id: str = Field(..., description="The id of the agent.")
name: Optional[str] = Field(None, description="The name of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
Expand Down
33 changes: 33 additions & 0 deletions letta/schemas/agents_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from datetime import datetime
from typing import Optional

from pydantic import Field

from letta.schemas.letta_base import LettaBase


class AgentsTagsBase(LettaBase):
__id_prefix__ = "agents_tags"


class AgentsTags(AgentsTagsBase):
"""
Schema representing the relationship between tags and agents.
Parameters:
agent_id (str): The ID of the associated agent.
tag_id (str): The ID of the associated tag.
tag_name (str): The name of the tag.
created_at (datetime): The date this relationship was created.
"""

id: str = AgentsTagsBase.generate_id_field()
agent_id: str = Field(..., description="The ID of the associated agent.")
tag: str = Field(..., description="The name of the tag.")
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
updated_at: Optional[datetime] = Field(None, description="The update date of the tag.")
is_deleted: bool = Field(False, description="Whether this tag is deleted or not.")


class AgentsTagsCreate(AgentsTagsBase):
tag: str = Field(..., description="The tag name.")
Loading

0 comments on commit dca47fc

Please sign in to comment.