Skip to content

Commit

Permalink
fix: refactor refresh chroma
Browse files (browse the repository at this point in the history)
  • Loading branch information
charlesponti committed Sep 25, 2024
1 parent d69911c commit 41896c0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 36 deletions.
68 changes: 33 additions & 35 deletions src/crons/focus_crons.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,6 @@
from langchain_core.documents import Document

from src.data.db import SessionLocal
from src.data.focus_repository import get_focus_item_by_id
from src.data.models.focus import Focus
from src.services import chroma_service
from src.services.keywords.keywords_service import get_query_keywords
@@ -37,46 +36,45 @@ def refresh_focus_from_chroma():
try:
focus_items = session.query(Focus).filter(Focus.in_vector_store.is_(False)).all()
if not focus_items:
logger.info("No new focus items to add to Chroma")
return

logger.info(f"Adding {len(focus_items)} focus items to Chroma")
docs = chroma_service.vector_store.get(ids=[str(focus_item.id) for focus_item in focus_items])
if len(docs["documents"]) == 0:
chroma_service.clear_focus_items_from_vector_store()
for focus_item in focus_items:
keyword_str = ",".join(get_query_keywords(focus_item.text))
chroma_service.vector_store.add_documents(
documents=[
Document(
page_content=f"{focus_item.text} \n\n {keyword_str}",
metadata={"keywords": keyword_str},
)
],
ids=[str(focus_item.id)],
)
logger.info(f"Added {focus_item.id} to Chroma")
session.query(Focus).filter(Focus.id.in_([focus_item.id for focus_item in focus_items])).update(
{"in_vector_store": True}
)
session.commit()
return
logger.info(f"Processing {len(focus_items)} focus items for Chroma")
focus_ids = [str(focus_item.id) for focus_item in focus_items]
existing_docs = chroma_service.vector_store.get(ids=focus_ids)
existing_ids = set(existing_docs["ids"])

logger.info(f"Docs: {len(docs['documents'])}")
completed_ids = []
for doc, doc_id, metadata in zip(docs["documents"], docs["ids"], docs["metadatas"]):
keywords = get_query_keywords(doc)
keyword_str = ",".join(keywords)
focus_item = get_focus_item_by_id(focus_items=focus_items, id=doc_id)
page_content = focus_item.text
chroma_service.vector_store.update_document(
document_id=doc_id,
document=Document(metadata=metadata, page_content=f"{page_content} \n\n {keyword_str}"),
for focus_item in focus_items:
focus_id = str(focus_item.id)
keyword_str = ",".join(get_query_keywords(focus_item.text))
page_content = f"{focus_item.text} \n\n {keyword_str}"
document = Document(
page_content=page_content,
metadata={"keywords": keyword_str},
)
completed_ids.append(doc_id)
logger.info(f"Updated {doc_id} with keywords: {keyword_str}")

session.query(Focus).filter(Focus.id.in_(completed_ids)).update({"in_vector_store": True})
if focus_id in existing_ids:
# Update existing document
chroma_service.vector_store.update_document(
document_id=focus_id,
document=document,
)
logger.info(f"Updated focus item {focus_id} in Chroma")
else:
# Add new document
chroma_service.vector_store.add_documents(
documents=[document],
ids=[focus_id],
)
logger.info(f"Added new focus item {focus_id} to Chroma")

# Mark all processed items as in_vector_store
session.query(Focus).filter(Focus.id.in_([focus_item.id for focus_item in focus_items])).update(
{"in_vector_store": True}
)
session.commit()
logger.info(f"Marked {len(focus_items)} focus items as in_vector_store")

except Exception as e:
traceback.print_exc()
logger.error(f"Error refreshing focus items from vector store: {e}")
Expand Down
3 changes: 2 additions & 1 deletion src/utils/logger.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import inspect
import json
import logging
import sys
from typing import Any, Dict, Optional


class StructuredLogger:
def __init__(self):
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s - %(_caller_filename)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
self.logger.addHandler(handler)
Expand Down

0 comments on commit 41896c0

Please sign in to comment.