"""Compare two database-state JSON files and report document differences."""

import json
import logging
import sys
from typing import List, Mapping, Optional

from pydantic import BaseModel, ValidationError

# Setup logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DocumentMetadata(BaseModel):
    """Model for document metadata."""

    family_topic: List[str] = []
    family_hazard: List[str] = []
    family_sector: List[str] = []
    family_keyword: List[str] = []
    family_framework: List[str] = []
    family_instrument: List[str] = []
    document_role: List[str] = []
    document_type: List[str] = []


class Document(BaseModel):
    """Model for a DocumentParserInput document."""

    name: str
    document_title: str
    description: str
    import_id: str
    slug: str
    family_import_id: str
    family_slug: str
    publication_ts: str
    date: Optional[str] = None
    source_url: Optional[str] = None
    download_url: Optional[str] = None
    corpus_import_id: str
    corpus_type_name: str
    collection_title: Optional[str] = None
    collection_summary: Optional[str] = None
    type: str
    source: str
    category: str
    geography: str
    geographies: List[str] = []
    languages: List[str] = []
    metadata: DocumentMetadata


class DBState(BaseModel):
    """Model for the database state.

    :param documents: Mapping of documents where their import ID is key
    :type documents: Mapping[str, Document]
    """

    documents: Mapping[str, Document]


def load_json(file_path: str) -> DBState:
    """Load JSON from a file and validate it against the DBState model.

    :param file_path: Path to the JSON file
    :type file_path: str
    :raises SystemExit: If the file cannot be read, is not valid JSON, or
        does not validate against the DBState model
    :return: The validated database state
    :rtype: DBState
    """
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            data = json.load(file)
        # model_validate reports a non-object JSON root (e.g. a list) as a
        # ValidationError, whereas DBState(**data) would raise a TypeError.
        return DBState.model_validate(data)
    except (
        OSError,
        UnicodeDecodeError,
        json.JSONDecodeError,
        ValidationError,
    ) as e:
        logger.error(f"💥 Error loading JSON from {file_path}: {e}")
        sys.exit(1)


def find_differing_doc_import_ids(
    main_sorted: List[Document], branch_sorted: List[Document]
) -> bool:
    """Log import IDs that are present in only one of the two document lists.

    :param main_sorted: List of documents from the main database state
    :type main_sorted: List[Document]
    :param branch_sorted: List of documents from the branch database state
    :type branch_sorted: List[Document]
    :return: True if the two sets of import IDs differ, False otherwise
    :rtype: bool
    """
    main_set = {doc.import_id for doc in main_sorted}
    branch_set = {doc.import_id for doc in branch_sorted}
    missing_in_branch = main_set - branch_set
    extra_in_branch = branch_set - main_set

    if missing_in_branch:
        logger.info(f"🔍 Missing doc IDs in branch: {missing_in_branch}")
    if extra_in_branch:
        logger.info(
            f"🔍 Extra doc IDs in branch compared with main: {extra_in_branch}"
        )
    return bool(missing_in_branch or extra_in_branch)


def find_document_differences(
    main_sorted: List[Document], branch_sorted: List[Document]
) -> bool:
    """Compare each document in two sorted lists and log differences.

    Documents are paired positionally, so both lists are expected to be
    sorted by import ID and to contain the same import IDs.

    :param main_sorted: List of documents from the main database state
    :type main_sorted: List[Document]
    :param branch_sorted: List of documents from the branch database state
    :type branch_sorted: List[Document]
    :return: True if differences are found, False otherwise
    :rtype: bool
    """
    differences_found = False

    for main_doc, branch_doc in zip(main_sorted, branch_sorted):
        if main_doc.import_id != branch_doc.import_id:
            logger.info(
                f"❌ Import ID difference found {main_doc.import_id} "
                f"vs {branch_doc.import_id}"
            )

        if main_doc == branch_doc:
            continue

        differences_found = True
        logger.info(f"❌ Difference(s) found in document {main_doc.import_id} ")
        # Walk every declared field rather than only those explicitly set on
        # the main document, so a field present only in the branch document
        # is still reported. Declaration order also keeps the log stable.
        for field in type(main_doc).model_fields:
            main_value = getattr(main_doc, field)
            branch_value = getattr(branch_doc, field)
            if main_value != branch_value:
                logger.info(
                    f"🔍 Field '{field}' differs: main '{main_value}' "
                    f"vs branch '{branch_value}'"
                )

    return differences_found


def compare_db_states(main_db: DBState, branch_db: DBState) -> None:
    """Compare two DB state files and log differences.

    :param main_db: The main database state (source of truth)
    :type main_db: DBState
    :param branch_db: The branch database state under test
    :type branch_db: DBState
    :raises SystemExit: If there are differences in the document lengths
        or contents
    """
    # Sort documents by import_id for order-insensitive comparison
    main_sorted = sorted(main_db.documents.values(), key=lambda doc: doc.import_id)
    branch_sorted = sorted(branch_db.documents.values(), key=lambda doc: doc.import_id)

    lengths_differ = len(main_sorted) != len(branch_sorted)
    if lengths_differ:
        logger.info(
            f"🔍 Document list lengths differ: main {len(main_sorted)}, "
            f"branch {len(branch_sorted)}"
        )

    # Report the missing/extra import IDs even when the lengths differ: that
    # is when knowing which documents are affected is most useful.
    ids_differ = find_differing_doc_import_ids(main_sorted, branch_sorted)
    if lengths_differ or ids_differ:
        sys.exit(1)

    if find_document_differences(main_sorted, branch_sorted):
        sys.exit(1)

    logger.info("🎉 DB states are equivalent!")


def main(
    main_path: str = "main_db_state.json", branch_path: str = "branch.json"
) -> None:
    """Main function to load and compare database states.

    :param main_path: Path to the main database state JSON file
    :type main_path: str
    :param branch_path: Path to the branch database state JSON file
    :type branch_path: str
    :raises SystemExit: If either file fails to load or the states differ
    """
    main_db_state = load_json(main_path)
    branch_db_state = load_json(branch_path)

    compare_db_states(main_db_state, branch_db_state)


if __name__ == "__main__":
    main()