Skip to content

Commit

Permalink
feat: return source metadata with list sources route (#1164)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Mar 19, 2024
1 parent ebfe949 commit 1c25a80
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 16 deletions.
6 changes: 4 additions & 2 deletions memgpt/models/pydantic_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Dict, Literal
from typing import List, Optional, Dict, Literal, Type
from pydantic import BaseModel, Field, Json, ConfigDict
import uuid
import base64
Expand Down Expand Up @@ -96,7 +96,7 @@ class PersonaModel(SQLModel, table=True):

class SourceModel(SQLModel, table=True):
name: str = Field(..., description="The name of the source.")
description: str = Field(None, description="The description of the source.")
description: Optional[str] = Field(None, description="The description of the source.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.")
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the source was created.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True)
Expand All @@ -105,6 +105,8 @@ class SourceModel(SQLModel, table=True):
embedding_config: Optional[EmbeddingConfigModel] = Field(
None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
)
# NOTE: .metadata is a reserved attribute on SQLModel
metadata_: Optional[dict] = Field(None, sa_column=Column(JSON), description="Metadata associated with the source.")


class PassageModel(BaseModel):
Expand Down
43 changes: 29 additions & 14 deletions memgpt/server/rest_api/sources/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


class ListSourcesResponse(BaseModel):
    # Response payload of the list-sources route: one SourceModel per source.
    # Declared exactly once; a second, superseded declaration of the same field
    # (differing only in the description's trailing period) has been dropped.
    sources: List[SourceModel] = Field(..., description="List of available sources.")


class CreateSourceRequest(BaseModel):
Expand Down Expand Up @@ -65,34 +65,42 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
async def list_source(
    user_id: uuid.UUID = Depends(get_current_user_with_server),
):
    """List all data sources created by a user.

    Returns a ``ListSourcesResponse`` wrapping whatever
    ``server.list_all_sources`` returns for ``user_id``.
    """
    # Clear the interface
    interface.clear()

    # Query once, through the server-level helper. The earlier direct
    # `server.ms.list_sources` call is gone: its result was overwritten
    # before ever being used.
    sources = server.list_all_sources(user_id=user_id)
    return ListSourcesResponse(sources=sources)

@router.post("/sources", tags=["sources"], response_model=SourceModel)
async def create_source(
    request: CreateSourceRequest = Body(...),
    user_id: uuid.UUID = Depends(get_current_user_with_server),
):
    """Create a new data source.

    Returns:
        A ``SourceModel`` built from the source created by
        ``server.create_source``.

    Raises:
        HTTPException: re-raised unchanged when raised while creating the
            source; any other exception is converted into a 500 response
            whose detail is the exception text.
    """
    interface.clear()
    try:
        # TODO: don't use Source and just use SourceModel once pydantic migration is complete
        source = server.create_source(name=request.name, user_id=user_id)
        return SourceModel(
            name=source.name,
            description=None,  # TODO: actually store descriptions
            user_id=source.user_id,
            id=source.id,
            embedding_config=server.server_embedding_config,
            # NOTE(review): a float timestamp is passed here, while the server-side
            # list_all_sources passes the datetime itself -- confirm which one
            # SourceModel.created_at is meant to hold.
            created_at=source.created_at.timestamp(),
        )
    except HTTPException:
        # Already an HTTP error: let FastAPI handle it as-is.
        raise
    except Exception as e:
        # Keep the original exception attached as the cause for server-side tracebacks.
        raise HTTPException(status_code=500, detail=f"{e}") from e

@router.delete("/sources/{source_id}", tags=["sources"])
async def delete_source(
source_id,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Delete a data source."""
interface.clear()
try:
server.delete_source(source_id=uuid.UUID(source_id), user_id=user_id)
Expand All @@ -108,6 +116,7 @@ async def attach_source_to_agent(
source_name: str = Query(..., description="The name of the source to attach."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Attach a data source to an existing agent."""
interface.clear()
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
Expand All @@ -127,6 +136,7 @@ async def detach_source_from_agent(
source_name: str = Query(..., description="The name of the source to detach."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Detach a data source from an existing agent."""
server.detach_source_from_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)

@router.post("/sources/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
Expand All @@ -136,6 +146,7 @@ async def upload_file_to_source(
source_id: uuid.UUID = Query(..., description="The unique identifier of the source to attach."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Upload a file to a data source."""
interface.clear()
source = server.ms.get_source(source_id=source_id, user_id=user_id)

Expand All @@ -153,13 +164,17 @@ async def list_passages(
source_id: uuid.UUID = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
raise NotImplementedError
"""List all passages associated with a data source."""
passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
return GetSourcePassagesResponse(passages=passages)

@router.get("/sources/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
async def list_documents(
    source_id: uuid.UUID = Body(...),
    user_id: uuid.UUID = Depends(get_current_user_with_server),
):
    """List all documents associated with a data source.

    Returns a ``GetSourceDocumentsResponse`` wrapping the result of
    ``server.list_data_source_documents`` for the given user and source.
    """
    # The leading `raise NotImplementedError` is removed; it made the
    # implementation below unreachable.
    # NOTE(review): `source_id` is declared as a Body parameter on a GET route;
    # confirm that every client of this endpoint can send a body with GET.
    documents = server.list_data_source_documents(user_id=user_id, source_id=source_id)
    return GetSourceDocumentsResponse(documents=documents)

return router
46 changes: 46 additions & 0 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functools import wraps
from threading import Lock
from typing import Union, Callable, Optional, List
import warnings

from fastapi import HTTPException
import uvicorn
Expand Down Expand Up @@ -35,6 +36,8 @@
Token,
Preset,
)

from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel
from memgpt.interface import AgentInterface # abstract

# TODO use custom interface
Expand Down Expand Up @@ -1292,3 +1295,46 @@ def detach_source_from_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, sour
def list_attached_sources(self, agent_id: uuid.UUID):
    """Return whatever ``self.ms.list_attached_sources`` reports for ``agent_id``."""
    # Pure delegation: the lookup of sources attached to the agent is done by `self.ms`.
    attached = self.ms.list_attached_sources(agent_id)
    return attached

def list_data_source_passages(self, user_id: uuid.UUID, source_id: uuid.UUID) -> List[PassageModel]:
    """Placeholder: emit a ``UserWarning`` and return an empty list.

    Neither ``user_id`` nor ``source_id`` is consulted yet.
    """
    message = "list_data_source_passages is not yet implemented, returning empty list."
    warnings.warn(message, category=UserWarning)
    return list()

def list_data_source_documents(self, user_id: uuid.UUID, source_id: uuid.UUID) -> List[DocumentModel]:
    """Placeholder: emit a ``UserWarning`` and return an empty list.

    Neither ``user_id`` nor ``source_id`` is consulted yet.
    """
    message = "list_data_source_documents is not yet implemented, returning empty list."
    warnings.warn(message, category=UserWarning)
    return list()

def list_all_sources(self, user_id: uuid.UUID) -> List[SourceModel]:
    """List all sources (w/ extra metadata) belonging to a user.

    Every record returned by ``self.ms.list_sources`` is repacked into a
    ``SourceModel`` whose ``metadata_`` holds two counts:

    * ``num_documents`` -- length of ``list_data_source_documents`` for the source
    * ``num_passages`` -- length of ``list_data_source_passages`` for the source

    Returns:
        The list of ``SourceModel`` objects, in the order the records were returned.
    """
    sources_with_metadata = []

    # TODO don't unpack here, instead list_sources should return a SourceModel
    for record in self.ms.list_sources(user_id=user_id):
        source = SourceModel(
            name=record.name,
            description=None,  # TODO: actually store descriptions
            user_id=record.user_id,
            id=record.id,
            embedding_config=self.server_embedding_config,
            created_at=record.created_at,
        )

        passages = self.list_data_source_passages(user_id=user_id, source_id=source.id)
        documents = self.list_data_source_documents(user_id=user_id, source_id=source.id)

        # Overwrite metadata field, should be empty anyways.
        # Each count is taken from its own collection (documents -> num_documents,
        # passages -> num_passages); these two were previously crossed.
        source.metadata_ = {
            "num_documents": len(documents),
            "num_passages": len(passages),
        }

        sources_with_metadata.append(source)

    return sources_with_metadata

0 comments on commit 1c25a80

Please sign in to comment.