From db81e612267655b9af02e49714ba488f555cb577 Mon Sep 17 00:00:00 2001 From: Kahtaf Alam Date: Wed, 28 Feb 2024 21:41:15 -0500 Subject: [PATCH] feat: Refactor GET /documents to return minimal fields (#21) --- selfie/api/documents.py | 11 ++-- selfie/connectors/chatgpt/connector.py | 9 ++-- selfie/connectors/whatsapp/connector.py | 4 +- selfie/database/__init__.py | 67 ++++++++++++------------- 4 files changed, 44 insertions(+), 47 deletions(-) diff --git a/selfie/api/documents.py b/selfie/api/documents.py index 8703fe2..c7675ca 100644 --- a/selfie/api/documents.py +++ b/selfie/api/documents.py @@ -1,11 +1,11 @@ -from typing import Optional, List +from typing import List from fastapi import APIRouter from pydantic import BaseModel from selfie.database import DataManager -from selfie.parsers.chat import ChatFileParser from selfie.embeddings import DataIndex +from selfie.parsers.chat import ChatFileParser router = APIRouter() @@ -20,8 +20,8 @@ class IndexDocumentsRequest(BaseModel): @router.get("/documents") -async def get_documents(source_id: Optional[int] = None): - return DataManager().get_documents(source_id) +async def get_documents(): + return DataManager().get_documents() @router.delete("/documents/{document_id}") @@ -57,13 +57,12 @@ async def index_documents(request: IndexDocumentsRequest): False, document.id ).conversations, - #source=document.source.name, + # source=document.source.name, source_document_id=document.id ) if is_chat else None) for document_id in document_ids ] - # @app.delete("/documents/{document-id}") # async def delete_data_source(document_id: int): # DataSourceManager().remove_document(document_id) diff --git a/selfie/connectors/chatgpt/connector.py b/selfie/connectors/chatgpt/connector.py index b0fb1c6..d77c988 100644 --- a/selfie/connectors/chatgpt/connector.py +++ b/selfie/connectors/chatgpt/connector.py @@ -24,10 +24,10 @@ def load_document(self, configuration: dict[str, Any]) -> List[DocumentDTO]: return [ DocumentDTO( - content=data_uri_to_string(data_uri), - content_type="text/plain", + content=(content := data_uri_to_string(data_uri)), + content_type="application/json", name="todo", - size=len(data_uri_to_string(data_uri).encode('utf-8')) + size=len(content.encode('utf-8')) ) for data_uri in config.files ] @@ -36,7 +36,8 @@ def validate_configuration(self, configuration: dict[str, Any]): # TODO: check if file can be read from path pass - def transform_for_embedding(self, configuration: dict[str, Any], documents: List[DocumentDTO]) -> List[EmbeddingDocumentModel]: + def transform_for_embedding(self, configuration: dict[str, Any], documents: List[DocumentDTO]) -> List[ + EmbeddingDocumentModel]: return [ embeddingDocumentModel for document in documents diff --git a/selfie/connectors/whatsapp/connector.py b/selfie/connectors/whatsapp/connector.py index a1b3a5e..84f5246 100644 --- a/selfie/connectors/whatsapp/connector.py +++ b/selfie/connectors/whatsapp/connector.py @@ -24,10 +24,10 @@ def load_document(self, configuration: dict[str, Any]) -> List[DocumentDTO]: return [ DocumentDTO( - content=data_uri_to_string(data_uri), + content=(content := data_uri_to_string(data_uri)), content_type="text/plain", name="todo", - size=len(data_uri_to_string(data_uri).encode('utf-8')) + size=len(content.encode('utf-8')) ) for data_uri in config.files ] diff --git a/selfie/database/__init__.py b/selfie/database/__init__.py index 37b15d6..a2ed78e 100644 --- a/selfie/database/__init__.py +++ b/selfie/database/__init__.py @@ -1,5 +1,9 @@ +import importlib +import json +import logging import os from datetime import datetime +from typing import List, Dict, Any, Callable from llama_index.core.node_parser import SentenceSplitter from peewee import ( @@ -9,23 +13,15 @@ TextField, ForeignKeyField, AutoField, - DoesNotExist, Proxy, IntegerField, DateTimeField, ) -import json -import importlib -from typing import List, Dict, Any, Optional, Callable - from playhouse.shortcuts import model_to_dict from selfie.config import get_app_config from selfie.embeddings import DataIndex from selfie.embeddings.document_types import EmbeddingDocumentModel - -import logging - # TODO: This module should not be aware of DocumentDTO. Refactor its usage out of this module. from selfie.types.documents import DocumentDTO @@ -62,6 +58,7 @@ class DocumentConnectionModel(BaseModel): # name = CharField() connector_name = CharField() configuration = TextField() + # last_loaded_timestamp = CharField(null=True) class Meta: @@ -107,9 +104,9 @@ def __init__(self, storage_path: str = config.database_storage_root): self.db.create_tables([DocumentConnectionModel, DocumentModel]) def add_document_connection( - self, - connector_name: str, - configuration: Dict[str, Any], + self, + connector_name: str, + configuration: Dict[str, Any], ) -> int: return DocumentConnectionModel.create( connector_name=connector_name, @@ -131,12 +128,14 @@ async def remove_document(self, document_id: int, delete_indexed_data: bool = Tr document.delete_instance() - async def remove_document_connection(self, document_connection_id: int, delete_documents: bool = True, delete_indexed_data: bool = True): + async def remove_document_connection(self, document_connection_id: int, delete_documents: bool = True, + delete_indexed_data: bool = True): if self.get_document_connection(document_connection_id) is None: raise ValueError(f"No document connection found with ID {document_connection_id}") if delete_indexed_data: - source_document_ids = [doc.id for doc in DocumentModel.select().where(DocumentModel.document_connection == document_connection_id)] + source_document_ids = [doc.id for doc in DocumentModel.select().where( + DocumentModel.document_connection == document_connection_id)] await DataIndex("n/a").delete_documents_with_source_documents(source_document_ids) with self.db.atomic(): @@ -205,7 +204,8 @@ async def index_documents(self, document_connection: DocumentConnectionModel): documents = self._fetch_documents(json.loads(document_connection.configuration)) documents = [ - document for doc in documents for document in self._map_selfie_documents_to_index_documents(selfie_document=doc) + document for doc in documents for document in + self._map_selfie_documents_to_index_documents(selfie_document=doc) ] await DataIndex("n/a").index(documents, extract_importance=False) @@ -214,7 +214,8 @@ async def index_documents(self, document_connection: DocumentConnectionModel): return {"message": f"{len(documents)} documents indexed successfully"} - async def index_document(self, document: DocumentDTO, selfie_documents_to_index_documents: Callable[[DocumentDTO], List[EmbeddingDocumentModel]] = None): + async def index_document(self, document: DocumentDTO, selfie_documents_to_index_documents: Callable[ + [DocumentDTO], List[EmbeddingDocumentModel]] = None): print("Indexing document") if selfie_documents_to_index_documents is None: @@ -260,26 +261,22 @@ def get_document_connections(self): for source in DocumentConnectionModel.select() ] - def get_documents(self, document_connection_id: Optional[int] = None): - if document_connection_id: - documents = DocumentModel.select().where(DocumentModel.document_connection == document_connection_id) - doc_ids = [str(document.id) for document in documents] - else: - documents = DocumentModel.select() - doc_ids = None - - one_embedding_document_per_document = DataIndex("n/a").get_one_document_per_source_document(doc_ids) - indexed_documents = list(set([doc['source_document_id'] for doc in one_embedding_document_per_document])) - - return [ - { - **model_to_dict(doc), - "is_indexed": doc.id in indexed_documents, - # TODO: for some reason, initializing Embeddings in DataIndex with the SQLAlchemy driver returns indexed_documents as strings, not ints (requires str(doc.id)). - "num_index_documents": DataIndex("n/a").get_document_count([str(doc.id)]) - } - for doc in documents - ] + def get_documents(self): + documents = DocumentModel.select(DocumentModel.id, DocumentModel.name, DocumentModel.size, + DocumentModel.created_at, DocumentModel.updated_at, + DocumentModel.content_type, DocumentConnectionModel.connector_name).join( + DocumentConnectionModel) + + result = [] + for doc in documents: + doc_dict = model_to_dict(doc, backrefs=True, only=[ + DocumentModel.id, DocumentModel.name, DocumentModel.size, + DocumentModel.created_at, DocumentModel.updated_at, + DocumentModel.content_type, DocumentConnectionModel.connector_name + ]) + doc_dict['connector_name'] = doc.document_connection.connector_name + result.append(doc_dict) + return result def get_document(self, document_id: str): return DocumentModel.get_by_id(document_id)