Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ability to add tags to agents #1984

Merged
merged 16 commits into from
Nov 7, 2024
52 changes: 52 additions & 0 deletions alembic/versions/b6d7ca024aa9_add_agents_tags_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Add agents tags table

Revision ID: b6d7ca024aa9
Revises: d14ae606614c
Create Date: 2024-11-06 10:48:08.424108

"""

from typing import Sequence, Union

import sqlalchemy as sa

Check failure on line 11 in alembic/versions/b6d7ca024aa9_add_agents_tags_table.py

View workflow job for this annotation

GitHub Actions / style-checks (3.12)

Import "sqlalchemy" could not be resolved (reportMissingImports)

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "b6d7ca024aa9"
down_revision: Union[str, None] = "d14ae606614c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"agents_tags",
sa.Column("agent_id", sa.String(), nullable=False),
sa.Column("tag", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["agent_id"],
["agents.id"],
),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("agent_id", "id"),
sa.UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("agents_tags")
# ### end Alembic commands ###
21 changes: 13 additions & 8 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,12 @@ def __init__(
self._default_llm_config = default_llm_config
self._default_embedding_config = default_embedding_config

def list_agents(self) -> List[AgentState]:
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers)
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
params = {}
if tags:
params["tags"] = tags

response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params)
return [AgentState(**agent) for agent in response.json()]

def agent_exists(self, agent_id: str) -> bool:
Expand Down Expand Up @@ -480,6 +484,7 @@ def update_agent(
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
Expand Down Expand Up @@ -509,6 +514,7 @@ def update_agent(
name=name,
system=system,
tools=tools,
tags=tags,
description=description,
metadata_=metadata,
llm_config=llm_config,
Expand Down Expand Up @@ -1617,13 +1623,10 @@ def __init__(
self.organization = self.server.get_organization_or_default(self.org_id)

# agents
def list_agents(self) -> List[AgentState]:
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
self.interface.clear()

# TODO: fix the server function
# return self.server.list_agents(user_id=self.user_id)

return self.server.ms.list_agents(user_id=self.user_id)
return self.server.list_agents(user_id=self.user_id, tags=tags)

def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
"""
Expand Down Expand Up @@ -1757,6 +1760,7 @@ def update_agent(
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
Expand Down Expand Up @@ -1788,6 +1792,7 @@ def update_agent(
name=name,
system=system,
tools=tools,
tags=tags,
description=description,
metadata_=metadata,
llm_config=llm_config,
Expand Down Expand Up @@ -1872,7 +1877,7 @@ def get_agent_by_name(self, agent_name: str) -> AgentState:
agent_state (AgentState): State of the agent
"""
self.interface.clear()
return self.server.get_agent(agent_name=agent_name, user_id=self.user_id, agent_id=None)
return self.server.get_agent_state(agent_name=agent_name, user_id=self.user_id, agent_id=None)

def get_agent(self, agent_id: str) -> AgentState:
"""
Expand Down
2 changes: 2 additions & 0 deletions letta/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def create_agent(self, agent: AgentState):
fields = vars(agent)
fields["memory"] = agent.memory.to_dict()
del fields["_internal_memory"]
del fields["tags"]
session.add(AgentModel(**fields))
session.commit()

Expand Down Expand Up @@ -531,6 +532,7 @@ def update_agent(self, agent: AgentState):
if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever
fields["memory"] = agent.memory.to_dict()
del fields["_internal_memory"]
del fields["tags"]
session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
session.commit()

Expand Down
28 changes: 28 additions & 0 deletions letta/orm/agents_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING

from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags

if TYPE_CHECKING:
from letta.orm.organization import Organization


class AgentsTags(SqlalchemyBase, OrganizationMixin):
"""Associates tags with agents, allowing agents to have multiple tags and supporting tag-based filtering."""

__tablename__ = "agents_tags"
__pydantic_model__ = PydanticAgentsTags
__table_args__ = (UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),)

# The agent associated with this tag
agent_id = mapped_column(String, ForeignKey("agents.id"), primary_key=True)

# The name of the tag
tag: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the tag associated with the agent.")

# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents_tags")
1 change: 1 addition & 0 deletions letta/orm/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class Organization(SqlalchemyBase):

users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan")

# TODO: Map these relationships later when we actually make these models
# below is just a suggestion
Expand Down
16 changes: 16 additions & 0 deletions letta/orm/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,22 @@ def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type[
self.is_deleted = True
return self.update(db_session)

def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
"""Permanently removes the record from the database."""
if actor:
logger.info(f"User {actor.id} requested hard deletion of {self.__class__.__name__} with ID {self.id}")

with db_session as session:
try:
session.delete(self)
session.commit()
except Exception as e:
session.rollback()
logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}")
raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}")
else:
logger.info(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")

def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
if actor:
self._set_created_and_updated_by_fields(actor.id)
Expand Down
5 changes: 5 additions & 0 deletions letta/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class AgentState(BaseAgent, validate_assignment=True):
# tool rules
tool_rules: Optional[List[BaseToolRule]] = Field(default=None, description="The list of tool rules.")

# tags
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")

# system prompt
system: str = Field(..., description="The system prompt used by the agent.")

Expand Down Expand Up @@ -108,6 +111,7 @@ class CreateAgent(BaseAgent):
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.")
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
Expand Down Expand Up @@ -148,6 +152,7 @@ class UpdateAgentState(BaseAgent):
id: str = Field(..., description="The id of the agent.")
name: Optional[str] = Field(None, description="The name of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
Expand Down
33 changes: 33 additions & 0 deletions letta/schemas/agents_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from datetime import datetime
from typing import Optional

from pydantic import Field

from letta.schemas.letta_base import LettaBase


class AgentsTagsBase(LettaBase):
__id_prefix__ = "agents_tags"


class AgentsTags(AgentsTagsBase):
"""
Schema representing the relationship between tags and agents.

Parameters:
agent_id (str): The ID of the associated agent.
tag_id (str): The ID of the associated tag.
tag_name (str): The name of the tag.
created_at (datetime): The date this relationship was created.
"""

id: str = AgentsTagsBase.generate_id_field()
agent_id: str = Field(..., description="The ID of the associated agent.")
tag: str = Field(..., description="The name of the tag.")
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
updated_at: Optional[datetime] = Field(None, description="The update date of the tag.")
is_deleted: bool = Field(False, description="Whether this tag is deleted or not.")


class AgentsTagsCreate(AgentsTagsBase):
tag: str = Field(..., description="The tag name.")
Loading
Loading