feat: Move Source to ORM model (#1979)
mattzh72 authored Nov 12, 2024
1 parent 781508f commit 5221bf6
Showing 18 changed files with 507 additions and 234 deletions.
64 changes: 64 additions & 0 deletions alembic/versions/cda66b6cb0d6_move_sources_to_orm.py
@@ -0,0 +1,64 @@
"""Move sources to orm
Revision ID: cda66b6cb0d6
Revises: b6d7ca024aa9
Create Date: 2024-11-07 13:29:57.186107
"""

from typing import Sequence, Union

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "cda66b6cb0d6"
down_revision: Union[str, None] = "b6d7ca024aa9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("sources", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
op.add_column("sources", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
op.add_column("sources", sa.Column("_created_by_id", sa.String(), nullable=True))
op.add_column("sources", sa.Column("_last_updated_by_id", sa.String(), nullable=True))

# Data migration step:
op.add_column("sources", sa.Column("organization_id", sa.String(), nullable=True))
# Populate `organization_id` based on `user_id`
# Use a raw SQL query to update the organization_id
op.execute(
"""
UPDATE sources
SET organization_id = users.organization_id
FROM users
WHERE sources.user_id = users.id
"""
)

# Set `organization_id` as non-nullable after population
op.alter_column("sources", "organization_id", nullable=False)

op.alter_column("sources", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
op.drop_index("sources_idx_user", table_name="sources")
op.create_foreign_key(None, "sources", "organizations", ["organization_id"], ["id"])
op.drop_column("sources", "user_id")
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("sources", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
op.drop_constraint(None, "sources", type_="foreignkey")
op.create_index("sources_idx_user", "sources", ["user_id"], unique=False)
op.alter_column("sources", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
op.drop_column("sources", "organization_id")
op.drop_column("sources", "_last_updated_by_id")
op.drop_column("sources", "_created_by_id")
op.drop_column("sources", "is_deleted")
op.drop_column("sources", "updated_at")
# ### end Alembic commands ###
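The upgrade adds audit columns (updated_at, is_deleted, _created_by_id, _last_updated_by_id), backfills sources.organization_id from the owning user's organization via a raw UPDATE ... FROM, and only then makes the column non-nullable and drops user_id. Below is a minimal sketch, not part of the commit, of a guard one might call in upgrade() just before the alter_column step; it assumes only the Alembic/SQLAlchemy APIs already imported in this migration (op, sa).

def _assert_backfill_complete() -> None:
    # Runs on the same connection Alembic is using for this migration.
    bind = op.get_bind()
    orphaned = bind.execute(
        sa.text("SELECT COUNT(*) FROM sources WHERE organization_id IS NULL")
    ).scalar()
    if orphaned:
        # Rows whose user_id has no matching users entry would otherwise fail
        # the NOT NULL constraint with a less descriptive error.
        raise RuntimeError(f"{orphaned} sources rows still have a NULL organization_id")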
6 changes: 4 additions & 2 deletions letta/agent.py
@@ -46,6 +46,8 @@
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.source_manager import SourceManager
from letta.services.user_manager import UserManager
from letta.system import (
get_heartbeat,
get_initial_boot_messages,
@@ -1311,7 +1313,7 @@ def migrate_embedding(self, embedding_config: EmbeddingConfig):
def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore):
"""Attach data with name `source_name` to the agent from source_connector."""
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory

user = UserManager().get_user_by_id(self.agent_state.user_id)
filters = {"user_id": self.agent_state.user_id, "source_id": source_id}
size = source_connector.size(filters)
page_size = 100
@@ -1339,7 +1341,7 @@ def attach_source(self, source_id: str, source_connector: StorageConnector, ms:
self.persistence_manager.archival_memory.storage.save()

# attach to agent
source = ms.get_source(source_id=source_id)
source = SourceManager().get_source_by_id(source_id=source_id, actor=user)
assert source is not None, f"Source {source_id} not found in metadata store"
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)

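In attach_source, the source is no longer read from the metadata store (ms.get_source); it is fetched through the new SourceManager, which scopes reads by the acting user. A minimal sketch of the new lookup pattern, using only names that appear in this diff (agent_state and source_id are assumed to be in scope as in Agent.attach_source):

from letta.services.source_manager import SourceManager
from letta.services.user_manager import UserManager

# Resolve the acting user first; the ORM-backed manager requires it as `actor`.
user = UserManager().get_user_by_id(agent_state.user_id)
source = SourceManager().get_source_by_id(source_id=source_id, actor=user)
assert source is not None, f"Source {source_id} not found"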
29 changes: 16 additions & 13 deletions letta/client/client.py
@@ -238,7 +238,7 @@ def load_file_to_source(self, filename: str, source_id: str, blocking=True) -> J
def delete_file_from_source(self, source_id: str, file_id: str) -> None:
raise NotImplementedError

def create_source(self, name: str) -> Source:
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
raise NotImplementedError

def delete_source(self, source_id: str):
@@ -1188,7 +1188,7 @@ def delete_file_from_source(self, source_id: str, file_id: str) -> None:
if response.status_code not in [200, 204]:
raise ValueError(f"Failed to delete tool: {response.text}")

def create_source(self, name: str) -> Source:
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
"""
Create a source
@@ -1198,7 +1198,8 @@ def create_source(self, name: str) -> Source:
Returns:
source (Source): Created source
"""
payload = {"name": name}
source_create = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config)
payload = source_create.model_dump()
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers)
response_json = response.json()
return Source(**response_json)
@@ -1253,7 +1254,7 @@ def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
Returns:
source (Source): Updated source
"""
request = SourceUpdate(id=source_id, name=name)
request = SourceUpdate(name=name)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update source: {response.text}")
@@ -2453,7 +2454,7 @@ def list_jobs(self):
def list_active_jobs(self):
return self.server.list_active_jobs(user_id=self.user_id)

def create_source(self, name: str) -> Source:
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
"""
Create a source
@@ -2463,8 +2464,10 @@ def create_source(self, name: str) -> Source:
Returns:
source (Source): Created source
"""
request = SourceCreate(name=name)
return self.server.create_source(request=request, user_id=self.user_id)
source = Source(
name=name, embedding_config=embedding_config or self._default_embedding_config, organization_id=self.user.organization_id
)
return self.server.source_manager.create_source(source=source, actor=self.user)

def delete_source(self, source_id: str):
"""
@@ -2475,7 +2478,7 @@ def delete_source(self, source_id: str):
"""

# TODO: delete source data
self.server.delete_source(source_id=source_id, user_id=self.user_id)
self.server.delete_source(source_id=source_id, actor=self.user)

def get_source(self, source_id: str) -> Source:
"""
@@ -2487,7 +2490,7 @@ def get_source(self, source_id: str) -> Source:
Returns:
source (Source): Source
"""
return self.server.get_source(source_id=source_id, user_id=self.user_id)
return self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)

def get_source_id(self, source_name: str) -> str:
"""
@@ -2499,7 +2502,7 @@ def get_source_id(self, source_name: str) -> str:
Returns:
source_id (str): ID of the source
"""
return self.server.get_source_id(source_name=source_name, user_id=self.user_id)
return self.server.source_manager.get_source_by_name(source_name=source_name, actor=self.user).id

def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
"""
@@ -2532,7 +2535,7 @@ def list_sources(self) -> List[Source]:
sources (List[Source]): List of sources
"""

return self.server.list_all_sources(user_id=self.user_id)
return self.server.list_all_sources(actor=self.user)

def list_attached_sources(self, agent_id: str) -> List[Source]:
"""
@@ -2572,8 +2575,8 @@ def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
source (Source): Updated source
"""
# TODO should the arg here just be "source_update: Source"?
request = SourceUpdate(id=source_id, name=name)
return self.server.update_source(request=request, user_id=self.user_id)
request = SourceUpdate(name=name)
return self.server.source_manager.update_source(source_id=source_id, source_update=request, actor=self.user)

# archival memory

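With these changes, create_source accepts an optional embedding_config (falling back to the client default), and SourceUpdate no longer carries the source id, which is passed separately. A short usage sketch against the updated client surface shown above; `client` is assumed to be an already-constructed Letta client and "my_docs" is a hypothetical source name.

# embedding_config is optional; the client's default embedding config is used when omitted.
source = client.create_source(name="my_docs")

# Renaming goes through SourceUpdate(name=...) under the hood; the id is passed separately.
source = client.update_source(source_id=source.id, name="my_docs_v2")

# Lookups are now routed through the server's SourceManager.
assert client.get_source_id(source_name="my_docs_v2") == source.id
client.delete_source(source_id=source.id)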
6 changes: 3 additions & 3 deletions letta/data_sources/connectors.py
@@ -47,7 +47,7 @@ def load_data(
passage_store: StorageConnector,
file_metadata_store: StorageConnector,
):
"""Load data from a connector (generates file and passages) into a specified source_id, associatedw with a user_id."""
"""Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id."""
embedding_config = source.embedding_config

# embedding model
@@ -88,7 +88,7 @@ def load_data(
file_id=file_metadata.id,
source_id=source.id,
metadata_=passage_metadata,
user_id=source.user_id,
user_id=source.created_by_id,
embedding_config=source.embedding_config,
embedding=embedding,
)
@@ -155,7 +155,7 @@ def find_files(self, source: Source) -> Iterator[FileMetadata]:

for metadata in extract_metadata_from_files(files):
yield FileMetadata(
user_id=source.user_id,
user_id=source.created_by_id,
source_id=source.id,
file_name=metadata.get("file_name"),
file_path=metadata.get("file_path"),
2 changes: 0 additions & 2 deletions letta/llm_api/google_ai.py
@@ -95,10 +95,8 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool =

try:
response = requests.get(url, headers=headers)
printd(f"response = {response}")
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response.json = {response}")

# Grab the models out
model_list = response["models"]
95 changes: 3 additions & 92 deletions letta/metadata.py
@@ -29,7 +29,6 @@
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.source import Source
from letta.schemas.tool_rule import (
BaseToolRule,
InitToolRule,
@@ -292,40 +291,6 @@ def to_record(self) -> AgentState:
return agent_state


class SourceModel(Base):
"""Defines data model for storing Passages (consisting of text, embedding)"""

__tablename__ = "sources"
__table_args__ = {"extend_existing": True}

# Assuming passage_id is the primary key
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
embedding_config = Column(EmbeddingConfigColumn)
description = Column(String)
metadata_ = Column(JSON)
Index(__tablename__ + "_idx_user", user_id),

# TODO: add num passages

def __repr__(self) -> str:
return f"<Source(passage_id='{self.id}', name='{self.name}')>"

def to_record(self) -> Source:
return Source(
id=self.id,
user_id=self.user_id,
name=self.name,
created_at=self.created_at,
embedding_config=self.embedding_config,
description=self.description,
metadata_=self.metadata_,
)


class AgentSourceMappingModel(Base):
"""Stores mapping between agent -> source"""

@@ -497,14 +462,6 @@ def create_agent(self, agent: AgentState):
session.add(AgentModel(**fields))
session.commit()

@enforce_types
def create_source(self, source: Source):
with self.session_maker() as session:
if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
session.add(SourceModel(**vars(source)))
session.commit()

@enforce_types
def create_block(self, block: Block):
with self.session_maker() as session:
@@ -522,6 +479,7 @@ def create_block(self, block: Block):
):

raise ValueError(f"Block with name {block.template_name} already exists")

session.add(BlockModel(**vars(block)))
session.commit()

@@ -536,12 +494,6 @@ def update_agent(self, agent: AgentState):
session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
session.commit()

@enforce_types
def update_source(self, source: Source):
with self.session_maker() as session:
session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source))
session.commit()

@enforce_types
def update_block(self, block: Block):
with self.session_maker() as session:
@@ -591,29 +543,12 @@ def delete_agent(self, agent_id: str):

session.commit()

@enforce_types
def delete_source(self, source_id: str):
with self.session_maker() as session:
# delete from sources table
session.query(SourceModel).filter(SourceModel.id == source_id).delete()

# delete any mappings
session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete()

session.commit()

@enforce_types
def list_agents(self, user_id: str) -> List[AgentState]:
with self.session_maker() as session:
results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
return [r.to_record() for r in results]

@enforce_types
def list_sources(self, user_id: str) -> List[Source]:
with self.session_maker() as session:
results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
return [r.to_record() for r in results]

@enforce_types
def get_agent(
self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None
@@ -630,21 +565,6 @@ def get_agent(
assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result
return results[0].to_record()

@enforce_types
def get_source(
self, source_id: Optional[str] = None, user_id: Optional[str] = None, source_name: Optional[str] = None
) -> Optional[Source]:
with self.session_maker() as session:
if source_id:
results = session.query(SourceModel).filter(SourceModel.id == source_id).all()
else:
assert user_id is not None and source_name is not None
results = session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()

@enforce_types
def get_block(self, block_id: str) -> Optional[Block]:
with self.session_maker() as session:
@@ -699,19 +619,10 @@ def attach_source(self, user_id: str, agent_id: str, source_id: str):
session.commit()

@enforce_types
def list_attached_sources(self, agent_id: str) -> List[Source]:
def list_attached_source_ids(self, agent_id: str) -> List[str]:
with self.session_maker() as session:
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()

sources = []
# make sure source exists
for r in results:
source = self.get_source(source_id=r.source_id)
if source:
sources.append(source)
else:
printd(f"Warning: source {r.source_id} does not exist but exists in mapping database. This should never happen.")
return sources
return [r.source_id for r in results]

@enforce_types
def list_attached_agents(self, source_id: str) -> List[str]:
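MetadataStore no longer owns SourceModel or returns Source objects: list_attached_sources becomes list_attached_source_ids, and callers resolve the ids through SourceManager. A minimal sketch of that caller-side resolution, assuming `ms` (a MetadataStore), `agent_id`, and `user` are already in scope:

from letta.services.source_manager import SourceManager

source_ids = ms.list_attached_source_ids(agent_id=agent_id)
# get_source_by_id is actor-scoped, so the requesting user must be passed along.
sources = [SourceManager().get_source_by_id(source_id=sid, actor=user) for sid in source_ids]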