chore: Migrate FileMetadata to ORM (#2028)
mattzh72 authored Nov 13, 2024
1 parent 266df8b commit 151ce6f
Showing 19 changed files with 247 additions and 113 deletions.
56 changes: 56 additions & 0 deletions alembic/versions/c85a3d07c028_move_files_to_orm.py
@@ -0,0 +1,56 @@
"""Move files to orm
Revision ID: c85a3d07c028
Revises: cda66b6cb0d6
Create Date: 2024-11-12 13:58:57.221081
"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "c85a3d07c028"
down_revision: Union[str, None] = "cda66b6cb0d6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("files", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("files", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
    op.add_column("files", sa.Column("_created_by_id", sa.String(), nullable=True))
    op.add_column("files", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
    op.add_column("files", sa.Column("organization_id", sa.String(), nullable=True))
    # Populate `organization_id` based on `user_id`
    # Use a raw SQL query to update the organization_id
    op.execute(
        """
        UPDATE files
        SET organization_id = users.organization_id
        FROM users
        WHERE files.user_id = users.id
        """
    )
    op.alter_column("files", "organization_id", nullable=False)
    op.create_foreign_key(None, "files", "organizations", ["organization_id"], ["id"])
    op.create_foreign_key(None, "files", "sources", ["source_id"], ["id"])
    op.drop_column("files", "user_id")
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("files", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
    op.drop_constraint(None, "files", type_="foreignkey")
    op.drop_constraint(None, "files", type_="foreignkey")
    op.drop_column("files", "organization_id")
    op.drop_column("files", "_last_updated_by_id")
    op.drop_column("files", "_created_by_id")
    op.drop_column("files", "is_deleted")
    op.drop_column("files", "updated_at")
    # ### end Alembic commands ###
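
As a hedged sketch (not part of this diff), the revision can be applied or reverted programmatically through Alembic's command API, assuming the repository's alembic.ini points at the target database:

from alembic import command
from alembic.config import Config

config = Config("alembic.ini")  # assumes the project's Alembic config at the repo root
command.upgrade(config, "c85a3d07c028")      # apply the files -> ORM migration
# command.downgrade(config, "cda66b6cb0d6")  # roll back to the previous revision
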
3 changes: 2 additions & 1 deletion letta/agent_store/db.py
@@ -27,8 +27,9 @@
from letta.agent_store.storage import StorageConnector, TableType
from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn
from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
from letta.orm.base import Base
from letta.orm.file import FileMetadata as FileMetadataModel

# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
from letta.schemas.message import Message
4 changes: 2 additions & 2 deletions letta/client/client.py
@@ -2440,7 +2440,7 @@ def load_file_to_source(self, filename: str, source_id: str, blocking=True):
        return job

    def delete_file_from_source(self, source_id: str, file_id: str):
        self.server.delete_file_from_source(source_id, file_id, user_id=self.user_id)
        self.server.source_manager.delete_file(file_id, actor=self.user)

    def get_job(self, job_id: str):
        return self.server.get_job(job_id=job_id)
@@ -2561,7 +2561,7 @@ def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Opti
        Returns:
            files (List[FileMetadata]): List of files
        """
        return self.server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
        return self.server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=self.user)

    def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
        """
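
A rough usage sketch of the migrated client methods; the create_client entry point and the source id below are assumptions for illustration rather than part of this diff:

from letta import create_client  # assumed entry point; adjust to your setup

client = create_client()
source_id = "source-123"  # hypothetical id

files = client.list_files_from_source(source_id=source_id, limit=100)
if files:
    client.delete_file_from_source(source_id=source_id, file_id=files[0].id)
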
11 changes: 3 additions & 8 deletions letta/data_sources/connectors.py
@@ -12,6 +12,7 @@
from letta.schemas.file import FileMetadata
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.services.source_manager import SourceManager
from letta.utils import create_uuid_from_string


@@ -41,12 +42,7 @@ def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Itera
        """


def load_data(
    connector: DataConnector,
    source: Source,
    passage_store: StorageConnector,
    file_metadata_store: StorageConnector,
):
def load_data(connector: DataConnector, source: Source, passage_store: StorageConnector, source_manager: SourceManager, actor: "User"):
    """Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id."""
    embedding_config = source.embedding_config

@@ -60,7 +56,7 @@ def load_data(
    file_count = 0
    for file_metadata in connector.find_files(source):
        file_count += 1
        file_metadata_store.insert(file_metadata)
        source_manager.create_file(file_metadata, actor)

        # generate passages
        for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size):
@@ -155,7 +151,6 @@ def find_files(self, source: Source) -> Iterator[FileMetadata]:

        for metadata in extract_metadata_from_files(files):
            yield FileMetadata(
                user_id=source.created_by_id,
                source_id=source.id,
                file_name=metadata.get("file_name"),
                file_path=metadata.get("file_path"),
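
Connectors now hand file metadata to SourceManager instead of a StorageConnector, and FileMetadata no longer carries a user_id. A minimal sketch of a custom connector compatible with the new shape (the class below is illustrative, not part of this commit):

from typing import Dict, Iterator, Tuple

from letta.data_sources.connectors import DataConnector
from letta.schemas.file import FileMetadata
from letta.schemas.source import Source


class SingleStringConnector(DataConnector):
    """Toy connector that yields one in-memory 'file' and one passage."""

    def __init__(self, text: str):
        self.text = text

    def find_files(self, source: Source) -> Iterator[FileMetadata]:
        # no user_id field anymore; organization scoping happens in the ORM layer
        yield FileMetadata(source_id=source.id, file_name="in_memory.txt", file_type="text/plain")

    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
        yield self.text, {"file_id": file.id}
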
73 changes: 0 additions & 73 deletions letta/metadata.py
@@ -11,7 +11,6 @@
    Column,
    DateTime,
    Index,
    Integer,
    String,
    TypeDecorator,
)
@@ -24,7 +23,6 @@
from letta.schemas.block import Block, Human, Persona
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
@@ -40,41 +38,6 @@
from letta.utils import enforce_types, get_utc_time, printd


class FileMetadataModel(Base):
    __tablename__ = "files"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True, nullable=False)
    user_id = Column(String, nullable=False)
    # TODO: Investigate why this breaks during table creation due to FK
    # source_id = Column(String, ForeignKey("sources.id"), nullable=False)
    source_id = Column(String, nullable=False)
    file_name = Column(String, nullable=True)
    file_path = Column(String, nullable=True)
    file_type = Column(String, nullable=True)
    file_size = Column(Integer, nullable=True)
    file_creation_date = Column(String, nullable=True)
    file_last_modified_date = Column(String, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    def __repr__(self):
        return f"<FileMetadata(id='{self.id}', source_id='{self.source_id}', file_name='{self.file_name}')>"

    def to_record(self):
        return FileMetadata(
            id=self.id,
            user_id=self.user_id,
            source_id=self.source_id,
            file_name=self.file_name,
            file_path=self.file_path,
            file_type=self.file_type,
            file_size=self.file_size,
            file_creation_date=self.file_creation_date,
            file_last_modified_date=self.file_last_modified_date,
            created_at=self.created_at,
        )


class LLMConfigColumn(TypeDecorator):
    """Custom type for storing LLMConfig as JSON"""

@@ -510,21 +473,6 @@ def update_or_create_block(self, block: Block):
                session.add(BlockModel(**vars(block)))
            session.commit()

    @enforce_types
    def delete_file_from_source(self, source_id: str, file_id: str, user_id: Optional[str]):
        with self.session_maker() as session:
            file_metadata = (
                session.query(FileMetadataModel)
                .filter(FileMetadataModel.source_id == source_id, FileMetadataModel.id == file_id, FileMetadataModel.user_id == user_id)
                .first()
            )

            if file_metadata:
                session.delete(file_metadata)
                session.commit()

            return file_metadata

    @enforce_types
    def delete_block(self, block_id: str):
        with self.session_maker() as session:
@@ -653,27 +601,6 @@ def create_job(self, job: Job):
            session.add(JobModel(**vars(job)))
            session.commit()

    @enforce_types
    def list_files_from_source(self, source_id: str, limit: int, cursor: Optional[str]):
        with self.session_maker() as session:
            # Start with the basic query filtered by source_id
            query = session.query(FileMetadataModel).filter(FileMetadataModel.source_id == source_id)

            if cursor:
                # Assuming cursor is the ID of the last file in the previous page
                query = query.filter(FileMetadataModel.id > cursor)

            # Order by ID or other ordering criteria to ensure correct pagination
            query = query.order_by(FileMetadataModel.id)

            # Limit the number of results returned
            results = query.limit(limit).all()

            # Convert the results to the required FileMetadata objects
            files = [r.to_record() for r in results]

            return files

    def delete_job(self, job_id: str):
        with self.session_maker() as session:
            session.query(JobModel).filter(JobModel.id == job_id).delete()
1 change: 1 addition & 0 deletions letta/orm/__init__.py
@@ -1,4 +1,5 @@
from letta.orm.base import Base
from letta.orm.file import FileMetadata
from letta.orm.organization import Organization
from letta.orm.source import Source
from letta.orm.tool import Tool
29 changes: 29 additions & 0 deletions letta/orm/file.py
@@ -0,0 +1,29 @@
from typing import TYPE_CHECKING, Optional

from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.mixins import OrganizationMixin, SourceMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.file import FileMetadata as PydanticFileMetadata

if TYPE_CHECKING:
    from letta.orm.organization import Organization
    from letta.orm.source import Source


class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
    """Represents metadata for an uploaded file."""

    __tablename__ = "files"
    __pydantic_model__ = PydanticFileMetadata

    file_name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The name of the file.")
    file_path: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The file path on the system.")
    file_type: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The type of the file.")
    file_size: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="The size of the file in bytes.")
    file_creation_date: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The creation date of the file.")
    file_last_modified_date: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The last modified date of the file.")

    # relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
    source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
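
A hedged sketch of reading the new model through a plain SQLAlchemy session; the in-memory SQLite engine and source id are stand-ins for the configured database, and in the codebase this path normally goes through SourceManager:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

from letta.orm import Base, FileMetadata

engine = create_engine("sqlite://")  # illustrative; real deployments use Letta's configured engine
Base.metadata.create_all(engine)     # assumes the mapped tables are SQLite-compatible

with Session(engine) as session:
    stmt = select(FileMetadata).where(FileMetadata.source_id == "source-456")
    for f in session.scalars(stmt):
        # organization is loaded eagerly via the selectin relationship declared above
        print(f.file_name, f.organization.name)
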
8 changes: 8 additions & 0 deletions letta/orm/mixins.py
@@ -29,3 +29,11 @@ class UserMixin(Base):
    __abstract__ = True

    user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))


class SourceMixin(Base):
    """Mixin for models (e.g. file) that belong to a source."""

    __abstract__ = True

    source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id"))
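
These mixins compose freely; a hypothetical model scoped to both an organization and a source would be declared like this (illustrative only, not part of this commit):

from typing import Optional

from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column

from letta.orm.mixins import OrganizationMixin, SourceMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase


class FileAnnotation(SqlalchemyBase, OrganizationMixin, SourceMixin):
    """Hypothetical table: a free-form note attached to a source, scoped to an organization."""

    __tablename__ = "file_annotations"

    text: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Annotation text.")
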
4 changes: 2 additions & 2 deletions letta/orm/organization.py
@@ -1,8 +1,8 @@
from typing import TYPE_CHECKING, List

from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.file import FileMetadata
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.organization import Organization as PydanticOrganization

@@ -18,14 +18,14 @@ class Organization(SqlalchemyBase):
    __tablename__ = "organizations"
    __pydantic_model__ = PydanticOrganization

    id: Mapped[str] = mapped_column(String, primary_key=True)
    name: Mapped[str] = mapped_column(doc="The display name of the organization.")

    # relationships
    users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
    tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
    sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
    agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan")
    files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan")
    # TODO: Map these relationships later when we actually make these models
    # below is just a suggestion
    # agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
3 changes: 2 additions & 1 deletion letta/orm/source.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -47,4 +47,5 @@ class Source(SqlalchemyBase, OrganizationMixin):

    # relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
    files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
    # agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
1 change: 0 additions & 1 deletion letta/orm/tool.py
@@ -28,7 +28,6 @@ class Tool(SqlalchemyBase, OrganizationMixin):
    # An organization should not have multiple tools with the same name
    __table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)

    id: Mapped[str] = mapped_column(String, primary_key=True)
    name: Mapped[str] = mapped_column(doc="The display name of the tool.")
    description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
    tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
2 changes: 0 additions & 2 deletions letta/orm/user.py
@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING

from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.mixins import OrganizationMixin
@@ -17,7 +16,6 @@ class User(SqlalchemyBase, OrganizationMixin):
    __tablename__ = "users"
    __pydantic_model__ = PydanticUser

    id: Mapped[str] = mapped_column(String, primary_key=True)
    name: Mapped[str] = mapped_column(nullable=False, doc="The display name of the user.")

    # relationships
10 changes: 5 additions & 5 deletions letta/schemas/file.py
@@ -4,7 +4,6 @@
from pydantic import Field

from letta.schemas.letta_base import LettaBase
from letta.utils import get_utc_time


class FileMetadataBase(LettaBase):
@@ -17,15 +16,16 @@ class FileMetadata(FileMetadataBase):
    """Representation of a single FileMetadata"""

    id: str = FileMetadataBase.generate_id_field()
    user_id: str = Field(description="The unique identifier of the user associated with the document.")
    organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the document.")
    source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
    file_name: Optional[str] = Field(None, description="The name of the file.")
    file_path: Optional[str] = Field(None, description="The path to the file.")
    file_type: Optional[str] = Field(None, description="The type of the file (MIME type).")
    file_size: Optional[int] = Field(None, description="The size of the file in bytes.")
    file_creation_date: Optional[str] = Field(None, description="The creation date of the file.")
    file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.")
    created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of this file metadata object.")

    class Config:
        extra = "allow"
    # orm metadata, optional fields
    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.")
    updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The update date of the file.")
    is_deleted: bool = Field(False, description="Whether this file is deleted or not.")
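
A quick sketch of constructing the updated Pydantic schema; user_id is gone, organization_id is optional, and the ids below are made up:

from letta.schemas.file import FileMetadata

file_meta = FileMetadata(
    source_id="source-456",      # still required
    organization_id="org-123",   # optional; normally filled in when the ORM row is created
    file_name="notes.txt",
    file_type="text/plain",
    file_size=2048,
)
print(file_meta.id)          # auto-generated via generate_id_field()
print(file_meta.is_deleted)  # defaults to False
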
6 changes: 4 additions & 2 deletions letta/server/rest_api/routers/v1/sources.py
@@ -198,11 +198,13 @@ def list_files_from_source(
    limit: int = Query(1000, description="Number of files to return"),
    cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
    server: "SyncServer" = Depends(get_letta_server),
    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    List paginated files associated with a data source.
    """
    return server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
    actor = server.get_user_or_default(user_id=user_id)
    return server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=actor)


# it's redundant to include /delete in the URL path. The HTTP verb DELETE already implies that action.
@@ -219,7 +221,7 @@ def delete_file_from_source(
    """
    actor = server.get_user_or_default(user_id=user_id)

    deleted_file = server.delete_file_from_source(source_id=source_id, file_id=file_id, user_id=actor.id)
    deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
    if deleted_file is None:
        raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")

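
An end-to-end sketch of calling the updated routes over HTTP; the base URL, route paths, and ids are assumptions for illustration, while the user_id header matches the handler above:

import requests

BASE_URL = "http://localhost:8283"   # assumed local server address
headers = {"user_id": "user-123"}    # omit to fall back to the server's default user

resp = requests.get(
    f"{BASE_URL}/v1/sources/source-456/files",  # assumed route shape for list_files_from_source
    params={"limit": 100},
    headers=headers,
)
resp.raise_for_status()
for f in resp.json():
    print(f["id"], f["file_name"])

# assumed route shape for delete_file_from_source
requests.delete(f"{BASE_URL}/v1/sources/source-456/files/file-789", headers=headers)
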