From 151ce6f655cbb1e7a629667f6363703956c36d9e Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 13 Nov 2024 10:59:46 -0800 Subject: [PATCH] chore: Migrate FileMetadata to ORM (#2028) --- .../c85a3d07c028_move_files_to_orm.py | 56 ++++++++++++ letta/agent_store/db.py | 3 +- letta/client/client.py | 4 +- letta/data_sources/connectors.py | 11 +-- letta/metadata.py | 73 ---------------- letta/orm/__init__.py | 1 + letta/orm/file.py | 29 +++++++ letta/orm/mixins.py | 8 ++ letta/orm/organization.py | 4 +- letta/orm/source.py | 3 +- letta/orm/tool.py | 1 - letta/orm/user.py | 2 - letta/schemas/file.py | 10 +-- letta/server/rest_api/routers/v1/sources.py | 6 +- letta/server/server.py | 12 +-- letta/services/source_manager.py | 45 ++++++++++ tests/test_client.py | 3 +- tests/test_managers.py | 86 ++++++++++++++++++- tests/utils.py | 3 +- 19 files changed, 247 insertions(+), 113 deletions(-) create mode 100644 alembic/versions/c85a3d07c028_move_files_to_orm.py create mode 100644 letta/orm/file.py diff --git a/alembic/versions/c85a3d07c028_move_files_to_orm.py b/alembic/versions/c85a3d07c028_move_files_to_orm.py new file mode 100644 index 0000000000..b05d793031 --- /dev/null +++ b/alembic/versions/c85a3d07c028_move_files_to_orm.py @@ -0,0 +1,56 @@ +"""Move files to orm + +Revision ID: c85a3d07c028 +Revises: cda66b6cb0d6 +Create Date: 2024-11-12 13:58:57.221081 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c85a3d07c028" +down_revision: Union[str, None] = "cda66b6cb0d6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("files", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("files", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) + op.add_column("files", sa.Column("_created_by_id", sa.String(), nullable=True)) + op.add_column("files", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) + op.add_column("files", sa.Column("organization_id", sa.String(), nullable=True)) + # Populate `organization_id` based on `user_id` + # Use a raw SQL query to update the organization_id + op.execute( + """ + UPDATE files + SET organization_id = users.organization_id + FROM users + WHERE files.user_id = users.id + """ + ) + op.alter_column("files", "organization_id", nullable=False) + op.create_foreign_key(None, "files", "organizations", ["organization_id"], ["id"]) + op.create_foreign_key(None, "files", "sources", ["source_id"], ["id"]) + op.drop_column("files", "user_id") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("files", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.drop_constraint(None, "files", type_="foreignkey") + op.drop_constraint(None, "files", type_="foreignkey") + op.drop_column("files", "organization_id") + op.drop_column("files", "_last_updated_by_id") + op.drop_column("files", "_created_by_id") + op.drop_column("files", "is_deleted") + op.drop_column("files", "updated_at") + # ### end Alembic commands ### diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index 531641d64c..b9f3b40e4e 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -27,8 +27,9 @@ from letta.agent_store.storage import StorageConnector, TableType from letta.config import LettaConfig from letta.constants import MAX_EMBEDDING_DIM -from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn +from letta.metadata import EmbeddingConfigColumn, ToolCallColumn from letta.orm.base import Base +from letta.orm.file import FileMetadata as FileMetadataModel # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall from letta.schemas.message import Message diff --git a/letta/client/client.py b/letta/client/client.py index 519902ba47..f2d2450a88 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2440,7 +2440,7 @@ def load_file_to_source(self, filename: str, source_id: str, blocking=True): return job def delete_file_from_source(self, source_id: str, file_id: str): - self.server.delete_file_from_source(source_id, file_id, user_id=self.user_id) + self.server.source_manager.delete_file(file_id, actor=self.user) def get_job(self, job_id: str): return self.server.get_job(job_id=job_id) @@ -2561,7 +2561,7 @@ def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Opti Returns: files (List[FileMetadata]): List of files """ - return self.server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor) + return 
self.server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=self.user) def update_source(self, source_id: str, name: Optional[str] = None) -> Source: """ diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index f9fb3d2af1..2b88ac12f7 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -12,6 +12,7 @@ from letta.schemas.file import FileMetadata from letta.schemas.passage import Passage from letta.schemas.source import Source +from letta.services.source_manager import SourceManager from letta.utils import create_uuid_from_string @@ -41,12 +42,7 @@ def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Itera """ -def load_data( - connector: DataConnector, - source: Source, - passage_store: StorageConnector, - file_metadata_store: StorageConnector, -): +def load_data(connector: DataConnector, source: Source, passage_store: StorageConnector, source_manager: SourceManager, actor: "User"): """Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id.""" embedding_config = source.embedding_config @@ -60,7 +56,7 @@ def load_data( file_count = 0 for file_metadata in connector.find_files(source): file_count += 1 - file_metadata_store.insert(file_metadata) + source_manager.create_file(file_metadata, actor) # generate passages for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size): @@ -155,7 +151,6 @@ def find_files(self, source: Source) -> Iterator[FileMetadata]: for metadata in extract_metadata_from_files(files): yield FileMetadata( - user_id=source.created_by_id, source_id=source.id, file_name=metadata.get("file_name"), file_path=metadata.get("file_path"), diff --git a/letta/metadata.py b/letta/metadata.py index 9ffc81c6ed..dc87d03206 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -11,7 +11,6 @@ Column, DateTime, Index, - 
Integer, String, TypeDecorator, ) @@ -24,7 +23,6 @@ from letta.schemas.block import Block, Human, Persona from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus -from letta.schemas.file import FileMetadata from letta.schemas.job import Job from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory @@ -40,41 +38,6 @@ from letta.utils import enforce_types, get_utc_time, printd -class FileMetadataModel(Base): - __tablename__ = "files" - __table_args__ = {"extend_existing": True} - - id = Column(String, primary_key=True, nullable=False) - user_id = Column(String, nullable=False) - # TODO: Investigate why this breaks during table creation due to FK - # source_id = Column(String, ForeignKey("sources.id"), nullable=False) - source_id = Column(String, nullable=False) - file_name = Column(String, nullable=True) - file_path = Column(String, nullable=True) - file_type = Column(String, nullable=True) - file_size = Column(Integer, nullable=True) - file_creation_date = Column(String, nullable=True) - file_last_modified_date = Column(String, nullable=True) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - - def __repr__(self): - return f"" - - def to_record(self): - return FileMetadata( - id=self.id, - user_id=self.user_id, - source_id=self.source_id, - file_name=self.file_name, - file_path=self.file_path, - file_type=self.file_type, - file_size=self.file_size, - file_creation_date=self.file_creation_date, - file_last_modified_date=self.file_last_modified_date, - created_at=self.created_at, - ) - - class LLMConfigColumn(TypeDecorator): """Custom type for storing LLMConfig as JSON""" @@ -510,21 +473,6 @@ def update_or_create_block(self, block: Block): session.add(BlockModel(**vars(block))) session.commit() - @enforce_types - def delete_file_from_source(self, source_id: str, file_id: str, user_id: Optional[str]): - with self.session_maker() as session: - file_metadata = ( - 
session.query(FileMetadataModel) - .filter(FileMetadataModel.source_id == source_id, FileMetadataModel.id == file_id, FileMetadataModel.user_id == user_id) - .first() - ) - - if file_metadata: - session.delete(file_metadata) - session.commit() - - return file_metadata - @enforce_types def delete_block(self, block_id: str): with self.session_maker() as session: @@ -653,27 +601,6 @@ def create_job(self, job: Job): session.add(JobModel(**vars(job))) session.commit() - @enforce_types - def list_files_from_source(self, source_id: str, limit: int, cursor: Optional[str]): - with self.session_maker() as session: - # Start with the basic query filtered by source_id - query = session.query(FileMetadataModel).filter(FileMetadataModel.source_id == source_id) - - if cursor: - # Assuming cursor is the ID of the last file in the previous page - query = query.filter(FileMetadataModel.id > cursor) - - # Order by ID or other ordering criteria to ensure correct pagination - query = query.order_by(FileMetadataModel.id) - - # Limit the number of results returned - results = query.limit(limit).all() - - # Convert the results to the required FileMetadata objects - files = [r.to_record() for r in results] - - return files - def delete_job(self, job_id: str): with self.session_maker() as session: session.query(JobModel).filter(JobModel.id == job_id).delete() diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index b69737ac65..733ce81613 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -1,4 +1,5 @@ from letta.orm.base import Base +from letta.orm.file import FileMetadata from letta.orm.organization import Organization from letta.orm.source import Source from letta.orm.tool import Tool diff --git a/letta/orm/file.py b/letta/orm/file.py new file mode 100644 index 0000000000..aec468819e --- /dev/null +++ b/letta/orm/file.py @@ -0,0 +1,29 @@ +from typing import TYPE_CHECKING, Optional + +from sqlalchemy import Integer, String +from sqlalchemy.orm import Mapped, 
mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin, SourceMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.file import FileMetadata as PydanticFileMetadata + +if TYPE_CHECKING: + from letta.orm.organization import Organization + + +class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin): + """Represents metadata for an uploaded file.""" + + __tablename__ = "files" + __pydantic_model__ = PydanticFileMetadata + + file_name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The name of the file.") + file_path: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The file path on the system.") + file_type: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The type of the file.") + file_size: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="The size of the file in bytes.") + file_creation_date: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The creation date of the file.") + file_last_modified_date: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The last modified date of the file.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin") + source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin") diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py index 57145475ff..d49b868b19 100644 --- a/letta/orm/mixins.py +++ b/letta/orm/mixins.py @@ -29,3 +29,11 @@ class UserMixin(Base): __abstract__ = True user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id")) + + +class SourceMixin(Base): + """Mixin for models (e.g. 
file) that belong to a source.""" + + __abstract__ = True + + source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id")) diff --git a/letta/orm/organization.py b/letta/orm/organization.py index a6b05ee6f9..9cfdfb9212 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING, List -from sqlalchemy import String from sqlalchemy.orm import Mapped, mapped_column, relationship +from letta.orm.file import FileMetadata from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.organization import Organization as PydanticOrganization @@ -18,7 +18,6 @@ class Organization(SqlalchemyBase): __tablename__ = "organizations" __pydantic_model__ = PydanticOrganization - id: Mapped[str] = mapped_column(String, primary_key=True) name: Mapped[str] = mapped_column(doc="The display name of the organization.") # relationships @@ -26,6 +25,7 @@ class Organization(SqlalchemyBase): tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan") + files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan") # TODO: Map these relationships later when we actually make these models # below is just a suggestion # agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/orm/source.py b/letta/orm/source.py index e8a7ed47a7..4b2262f09b 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional from sqlalchemy import JSON, TypeDecorator from sqlalchemy.orm import Mapped, 
mapped_column, relationship @@ -47,4 +47,5 @@ class Source(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") + files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan") # agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources") diff --git a/letta/orm/tool.py b/letta/orm/tool.py index 5e0ec0d047..d86fffa2ee 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -28,7 +28,6 @@ class Tool(SqlalchemyBase, OrganizationMixin): # An organization should not have multiple tools with the same name __table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),) - id: Mapped[str] = mapped_column(String, primary_key=True) name: Mapped[str] = mapped_column(doc="The display name of the tool.") description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.") tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.") diff --git a/letta/orm/user.py b/letta/orm/user.py index 05f6910254..6e41456254 100644 --- a/letta/orm/user.py +++ b/letta/orm/user.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING -from sqlalchemy import String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin @@ -17,7 +16,6 @@ class User(SqlalchemyBase, OrganizationMixin): __tablename__ = "users" __pydantic_model__ = PydanticUser - id: Mapped[str] = mapped_column(String, primary_key=True) name: Mapped[str] = mapped_column(nullable=False, doc="The display name of the user.") # relationships diff --git a/letta/schemas/file.py b/letta/schemas/file.py index 37e392545f..b43eb64c38 100644 --- a/letta/schemas/file.py +++ b/letta/schemas/file.py @@ -4,7 +4,6 @@ from pydantic import Field from letta.schemas.letta_base import LettaBase -from letta.utils import 
get_utc_time class FileMetadataBase(LettaBase): @@ -17,7 +16,7 @@ class FileMetadata(FileMetadataBase): """Representation of a single FileMetadata""" id: str = FileMetadataBase.generate_id_field() - user_id: str = Field(description="The unique identifier of the user associated with the document.") + organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the document.") source_id: str = Field(..., description="The unique identifier of the source associated with the document.") file_name: Optional[str] = Field(None, description="The name of the file.") file_path: Optional[str] = Field(None, description="The path to the file.") @@ -25,7 +24,8 @@ class FileMetadata(FileMetadataBase): file_size: Optional[int] = Field(None, description="The size of the file in bytes.") file_creation_date: Optional[str] = Field(None, description="The creation date of the file.") file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.") - created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of this file metadata object.") - class Config: - extra = "allow" + # orm metadata, optional fields + created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.") + updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The update date of the file.") + is_deleted: bool = Field(False, description="Whether this file is deleted or not.") diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 58047f1296..c0558b0ca5 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -198,11 +198,13 @@ def list_files_from_source( limit: int = Query(1000, description="Number of files to return"), cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of 
results"), server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List paginated files associated with a data source. """ - return server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor) + actor = server.get_user_or_default(user_id=user_id) + return server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=actor) # it's redundant to include /delete in the URL path. The HTTP verb DELETE already implies that action. @@ -219,7 +221,7 @@ def delete_file_from_source( """ actor = server.get_user_or_default(user_id=user_id) - deleted_file = server.delete_file_from_source(source_id=source_id, file_id=file_id, user_id=actor.id) + deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor) if deleted_file is None: raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.") diff --git a/letta/server/server.py b/letta/server/server.py index 023fffadf8..25305f62f1 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -65,7 +65,6 @@ # openai schemas from letta.schemas.enums import JobStatus -from letta.schemas.file import FileMetadata from letta.schemas.job import Job from letta.schemas.letta_message import LettaMessage from letta.schemas.llm_config import LLMConfig @@ -1632,9 +1631,6 @@ def load_file_to_source(self, source_id: str, file_path: str, job_id: str) -> Jo return job - def delete_file_from_source(self, source_id: str, file_id: str, user_id: Optional[str]) -> Optional[FileMetadata]: - return self.ms.delete_file_from_source(source_id=source_id, file_id=file_id, user_id=user_id) - def load_data( self, user_id: str, @@ -1652,10 +1648,9 @@ def load_data( # get the data connectors passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) - file_store = 
StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id) # load data into the document store - passage_count, document_count = load_data(connector, source, passage_store, file_store) + passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user) return passage_count, document_count def attach_source_to_agent( @@ -1674,7 +1669,6 @@ def attach_source_to_agent( data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) else: raise ValueError(f"Need to provide at least source_id or source_name to find the source.") - # get connection to data source storage source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) @@ -1720,10 +1714,6 @@ def list_attached_sources(self, agent_id: str) -> List[Source]: return [self.source_manager.get_source_by_id(source_id=id) for id in source_ids] - def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]: - # list all attached sources to an agent - return self.ms.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor) - def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]: warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning) return [] diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index e09bddd959..03684d2367 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -1,7 +1,9 @@ from typing import List, Optional from letta.orm.errors import NoResultFound +from letta.orm.file import FileMetadata as FileMetadataModel from letta.orm.source import Source as SourceModel +from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceUpdate from 
letta.schemas.user import User as PydanticUser @@ -98,3 +100,46 @@ def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[ return None else: return sources[0].to_pydantic() + + @enforce_types + def create_file(self, file_metadata: PydanticFileMetadata, actor: PydanticUser) -> PydanticFileMetadata: + """Create a new file based on the PydanticFileMetadata schema.""" + db_file = self.get_file_by_id(file_metadata.id, actor=actor) + if db_file: + return db_file + else: + with self.session_maker() as session: + file_metadata.organization_id = actor.organization_id + file_metadata = FileMetadataModel(**file_metadata.model_dump(exclude_none=True)) + file_metadata.create(session, actor=actor) + return file_metadata.to_pydantic() + + # TODO: We make actor optional for now, but should most likely be enforced due to security reasons + @enforce_types + def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]: + """Retrieve a file by its ID.""" + with self.session_maker() as session: + try: + file = FileMetadataModel.read(db_session=session, identifier=file_id, actor=actor) + return file.to_pydantic() + except NoResultFound: + return None + + @enforce_types + def list_files( + self, source_id: str, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50 + ) -> List[PydanticFileMetadata]: + """List all files with optional pagination.""" + with self.session_maker() as session: + files = FileMetadataModel.list( + db_session=session, cursor=cursor, limit=limit, organization_id=actor.organization_id, source_id=source_id + ) + return [file.to_pydantic() for file in files] + + @enforce_types + def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata: + """Delete a file by its ID.""" + with self.session_maker() as session: + file = FileMetadataModel.read(db_session=session, identifier=file_id) + file.delete(db_session=session, actor=actor) + return 
file.to_pydantic() diff --git a/tests/test_client.py b/tests/test_client.py index ffdd27bf4b..f7cfaaedbb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,7 +12,7 @@ from letta.agent import initialize_message_sequence from letta.client.client import LocalClient, RESTClient from letta.constants import DEFAULT_PRESET -from letta.orm import Source +from letta.orm import FileMetadata, Source from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole, MessageStreamStatus @@ -91,6 +91,7 @@ def clear_tables(): from letta.server.server import db_context with db_context() as session: + session.execute(delete(FileMetadata)) session.execute(delete(Source)) session.commit() diff --git a/tests/test_managers.py b/tests/test_managers.py index 8436d7ba6d..a5d8528a0c 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3,9 +3,10 @@ import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code -from letta.orm import Organization, Source, Tool, User +from letta.orm import FileMetadata, Organization, Source, Tool, User from letta.schemas.agent import CreateAgent from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory from letta.schemas.organization import Organization as PydanticOrganization @@ -37,6 +38,7 @@ def clear_tables(server: SyncServer): """Fixture to clear the organization table before each test.""" with server.organization_manager.session_maker() as session: + session.execute(delete(FileMetadata)) session.execute(delete(Source)) session.execute(delete(Tool)) # Clear all records from the Tool table session.execute(delete(User)) # Clear all records from the user table @@ -65,6 +67,18 @@ def other_user(server: SyncServer, default_organization): yield 
user +@pytest.fixture +def default_source(server: SyncServer, default_user): + source_pydantic = PydanticSource( + name="Test Source", + description="This is a test source.", + metadata_={"type": "test"}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + yield source + + @pytest.fixture def sarah_agent(server: SyncServer, default_user, default_organization): """Fixture to create and return a sample agent within the default organization.""" @@ -420,7 +434,7 @@ def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user): # ====================================================================================================================== -# Source Manager Tests +# Source Manager Tests - Sources # ====================================================================================================================== @@ -562,6 +576,74 @@ def test_update_source_no_changes(server: SyncServer, default_user): assert updated_source.description == source.description +# ====================================================================================================================== +# Source Manager Tests - Files +# ====================================================================================================================== +def test_get_file_by_id(server: SyncServer, default_user, default_source): + """Test retrieving a file by ID.""" + file_metadata = PydanticFileMetadata( + file_name="Retrieve File", + file_path="/path/to/retrieve_file.txt", + file_type="text/plain", + file_size=2048, + source_id=default_source.id, + ) + created_file = server.source_manager.create_file(file_metadata=file_metadata, actor=default_user) + + # Retrieve the file by ID + retrieved_file = server.source_manager.get_file_by_id(file_id=created_file.id, actor=default_user) + + # Assertions to verify the retrieved file matches the created one + assert retrieved_file.id == 
created_file.id + assert retrieved_file.file_name == created_file.file_name + assert retrieved_file.file_path == created_file.file_path + assert retrieved_file.file_type == created_file.file_type + + +def test_list_files(server: SyncServer, default_user, default_source): + """Test listing files with pagination.""" + # Create multiple files + server.source_manager.create_file( + PydanticFileMetadata(file_name="File 1", file_path="/path/to/file1.txt", file_type="text/plain", source_id=default_source.id), + actor=default_user, + ) + server.source_manager.create_file( + PydanticFileMetadata(file_name="File 2", file_path="/path/to/file2.txt", file_type="text/plain", source_id=default_source.id), + actor=default_user, + ) + + # List files without pagination + files = server.source_manager.list_files(source_id=default_source.id, actor=default_user) + assert len(files) == 2 + + # List files with pagination + paginated_files = server.source_manager.list_files(source_id=default_source.id, actor=default_user, limit=1) + assert len(paginated_files) == 1 + + # Ensure cursor-based pagination works + next_page = server.source_manager.list_files(source_id=default_source.id, actor=default_user, cursor=paginated_files[-1].id, limit=1) + assert len(next_page) == 1 + assert next_page[0].file_name != paginated_files[0].file_name + + +def test_delete_file(server: SyncServer, default_user, default_source): + """Test deleting a file.""" + file_metadata = PydanticFileMetadata( + file_name="Delete File", file_path="/path/to/delete_file.txt", file_type="text/plain", source_id=default_source.id + ) + created_file = server.source_manager.create_file(file_metadata=file_metadata, actor=default_user) + + # Delete the file + deleted_file = server.source_manager.delete_file(file_id=created_file.id, actor=default_user) + + # Assertions to verify deletion + assert deleted_file.id == created_file.id + + # Verify that the file no longer appears in list_files + files = 
server.source_manager.list_files(source_id=default_source.id, actor=default_user) + assert len(files) == 0 + + # ====================================================================================================================== # AgentsTagsManager Tests # ====================================================================================================================== diff --git a/tests/utils.py b/tests/utils.py index 2168e2e387..19a05a090a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,8 +24,7 @@ def __init__(self, texts: List[str]): def find_files(self, source) -> Iterator[FileMetadata]: for text in self.texts: file_metadata = FileMetadata( - user_id="", - source_id="", + source_id=source.id, file_name="", file_path="", file_type="",