Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Refactor GET /documents to return minimal fields #21

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions selfie/api/documents.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Optional, List
from typing import List

from fastapi import APIRouter
from pydantic import BaseModel

from selfie.database import DataManager
from selfie.parsers.chat import ChatFileParser
from selfie.embeddings import DataIndex
from selfie.parsers.chat import ChatFileParser

router = APIRouter()

Expand All @@ -20,8 +20,8 @@ class IndexDocumentsRequest(BaseModel):


@router.get("/documents")
async def get_documents(source_id: Optional[int] = None):
return DataManager().get_documents(source_id)
async def get_documents():
return DataManager().get_documents()


@router.delete("/documents/{document_id}")
Expand Down Expand Up @@ -57,13 +57,12 @@ async def index_documents(request: IndexDocumentsRequest):
False,
document.id
).conversations,
#source=document.source.name,
# source=document.source.name,
source_document_id=document.id
) if is_chat else None)
for document_id in document_ids
]


# @app.delete("/documents/{document-id}")
# async def delete_data_source(document_id: int):
# DataSourceManager().remove_document(document_id)
Expand Down
9 changes: 5 additions & 4 deletions selfie/connectors/chatgpt/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def load_document(self, configuration: dict[str, Any]) -> List[DocumentDTO]:

return [
DocumentDTO(
content=data_uri_to_string(data_uri),
content_type="text/plain",
content=(content := data_uri_to_string(data_uri)),
content_type="application/json",
name="todo",
size=len(data_uri_to_string(data_uri).encode('utf-8'))
size=len(content.encode('utf-8'))
)
for data_uri in config.files
]
Expand All @@ -36,7 +36,8 @@ def validate_configuration(self, configuration: dict[str, Any]):
# TODO: check if file can be read from path
pass

def transform_for_embedding(self, configuration: dict[str, Any], documents: List[DocumentDTO]) -> List[EmbeddingDocumentModel]:
def transform_for_embedding(self, configuration: dict[str, Any], documents: List[DocumentDTO]) -> List[
EmbeddingDocumentModel]:
return [
embeddingDocumentModel
for document in documents
Expand Down
4 changes: 2 additions & 2 deletions selfie/connectors/whatsapp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def load_document(self, configuration: dict[str, Any]) -> List[DocumentDTO]:

return [
DocumentDTO(
content=data_uri_to_string(data_uri),
content=(content := data_uri_to_string(data_uri)),
content_type="text/plain",
name="todo",
size=len(data_uri_to_string(data_uri).encode('utf-8'))
size=len(content.encode('utf-8'))
)
for data_uri in config.files
]
Expand Down
67 changes: 32 additions & 35 deletions selfie/database/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import importlib
import json
import logging
import os
from datetime import datetime
from typing import List, Dict, Any, Callable

from llama_index.core.node_parser import SentenceSplitter
from peewee import (
Expand All @@ -9,23 +13,15 @@
TextField,
ForeignKeyField,
AutoField,
DoesNotExist,
Proxy,
IntegerField,
DateTimeField,
)
import json
import importlib
from typing import List, Dict, Any, Optional, Callable

from playhouse.shortcuts import model_to_dict

from selfie.config import get_app_config
from selfie.embeddings import DataIndex
from selfie.embeddings.document_types import EmbeddingDocumentModel

import logging

# TODO: This module should not be aware of DocumentDTO. Refactor its usage out of this module.
from selfie.types.documents import DocumentDTO

Expand Down Expand Up @@ -62,6 +58,7 @@ class DocumentConnectionModel(BaseModel):
# name = CharField()
connector_name = CharField()
configuration = TextField()

# last_loaded_timestamp = CharField(null=True)

class Meta:
Expand Down Expand Up @@ -107,9 +104,9 @@ def __init__(self, storage_path: str = config.database_storage_root):
self.db.create_tables([DocumentConnectionModel, DocumentModel])

def add_document_connection(
self,
connector_name: str,
configuration: Dict[str, Any],
self,
connector_name: str,
configuration: Dict[str, Any],
) -> int:
return DocumentConnectionModel.create(
connector_name=connector_name,
Expand All @@ -131,12 +128,14 @@ async def remove_document(self, document_id: int, delete_indexed_data: bool = Tr

document.delete_instance()

async def remove_document_connection(self, document_connection_id: int, delete_documents: bool = True, delete_indexed_data: bool = True):
async def remove_document_connection(self, document_connection_id: int, delete_documents: bool = True,
delete_indexed_data: bool = True):
if self.get_document_connection(document_connection_id) is None:
raise ValueError(f"No document connection found with ID {document_connection_id}")

if delete_indexed_data:
source_document_ids = [doc.id for doc in DocumentModel.select().where(DocumentModel.document_connection == document_connection_id)]
source_document_ids = [doc.id for doc in DocumentModel.select().where(
DocumentModel.document_connection == document_connection_id)]
await DataIndex("n/a").delete_documents_with_source_documents(source_document_ids)

with self.db.atomic():
Expand Down Expand Up @@ -205,7 +204,8 @@ async def index_documents(self, document_connection: DocumentConnectionModel):

documents = self._fetch_documents(json.loads(document_connection.configuration))
documents = [
document for doc in documents for document in self._map_selfie_documents_to_index_documents(selfie_document=doc)
document for doc in documents for document in
self._map_selfie_documents_to_index_documents(selfie_document=doc)
]

await DataIndex("n/a").index(documents, extract_importance=False)
Expand All @@ -214,7 +214,8 @@ async def index_documents(self, document_connection: DocumentConnectionModel):

return {"message": f"{len(documents)} documents indexed successfully"}

async def index_document(self, document: DocumentDTO, selfie_documents_to_index_documents: Callable[[DocumentDTO], List[EmbeddingDocumentModel]] = None):
async def index_document(self, document: DocumentDTO, selfie_documents_to_index_documents: Callable[
[DocumentDTO], List[EmbeddingDocumentModel]] = None):
print("Indexing document")

if selfie_documents_to_index_documents is None:
Expand Down Expand Up @@ -260,26 +261,22 @@ def get_document_connections(self):
for source in DocumentConnectionModel.select()
]

def get_documents(self, document_connection_id: Optional[int] = None):
if document_connection_id:
documents = DocumentModel.select().where(DocumentModel.document_connection == document_connection_id)
doc_ids = [str(document.id) for document in documents]
else:
documents = DocumentModel.select()
doc_ids = None

one_embedding_document_per_document = DataIndex("n/a").get_one_document_per_source_document(doc_ids)
indexed_documents = list(set([doc['source_document_id'] for doc in one_embedding_document_per_document]))

return [
{
**model_to_dict(doc),
"is_indexed": doc.id in indexed_documents,
# TODO: for some reason, initializing Embeddings in DataIndex with the SQLAlchemy driver returns indexed_documents as strings, not ints (requires str(doc.id)).
"num_index_documents": DataIndex("n/a").get_document_count([str(doc.id)])
}
for doc in documents
]
def get_documents(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only changes made to this file are in get_documents, rest are formatting changes

documents = DocumentModel.select(DocumentModel.id, DocumentModel.name, DocumentModel.size,
DocumentModel.created_at, DocumentModel.updated_at,
DocumentModel.content_type, DocumentConnectionModel.connector_name).join(
DocumentConnectionModel)

result = []
for doc in documents:
doc_dict = model_to_dict(doc, backrefs=True, only=[
DocumentModel.id, DocumentModel.name, DocumentModel.size,
DocumentModel.created_at, DocumentModel.updated_at,
DocumentModel.content_type, DocumentConnectionModel.connector_name
])
doc_dict['connector_name'] = doc.document_connection.connector_name
result.append(doc_dict)
return result

def get_document(self, document_id: str):
return DocumentModel.get_by_id(document_id)
Expand Down