-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e2dde3c
commit a03f693
Showing
1 changed file
with
175 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import json | ||
import logging | ||
import sys | ||
from typing import List, Mapping, Optional | ||
|
||
from pydantic import BaseModel, ValidationError | ||
|
||
# Setup logger | ||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class DocumentMetadata(BaseModel): | ||
"""Model for document metadata.""" | ||
|
||
family_topic: List[str] = [] | ||
family_hazard: List[str] = [] | ||
family_sector: List[str] = [] | ||
family_keyword: List[str] = [] | ||
family_framework: List[str] = [] | ||
family_instrument: List[str] = [] | ||
document_role: List[str] = [] | ||
document_type: List[str] = [] | ||
|
||
|
||
class Document(BaseModel): | ||
"""Model for a DocumentParserInput document.""" | ||
|
||
name: str | ||
document_title: str | ||
description: str | ||
import_id: str | ||
slug: str | ||
family_import_id: str | ||
family_slug: str | ||
publication_ts: str | ||
date: Optional[str] = None | ||
source_url: Optional[str] = None | ||
download_url: Optional[str] = None | ||
corpus_import_id: str | ||
corpus_type_name: str | ||
collection_title: Optional[str] = None | ||
collection_summary: Optional[str] = None | ||
type: str | ||
source: str | ||
category: str | ||
geography: str | ||
geographies: List[str] = [] | ||
languages: List[str] = [] | ||
metadata: DocumentMetadata | ||
|
||
|
||
class DBState(BaseModel): | ||
"""Model for the database state. | ||
:param documents: Mapping of documents where their import ID is key | ||
:type documents: Mapping[str, Document] | ||
""" | ||
|
||
documents: Mapping[str, Document] | ||
|
||
|
||
def load_json(file_path: str) -> DBState: | ||
"""Load JSON from a file and validate it against the DBState model. | ||
:param file_path: Path to the JSON file | ||
:type file_path: str | ||
:raises SystemExit: If there's an error loading/validating the JSON | ||
:return: The validated database state | ||
:rtype: DBState | ||
""" | ||
try: | ||
with open(file_path, "r") as file: | ||
data = json.load(file) | ||
return DBState(**data) | ||
except (json.JSONDecodeError, ValidationError) as e: | ||
logger.error(f"π₯ Error loading JSON from {file_path}: {e}") | ||
sys.exit(1) | ||
|
||
|
||
def find_differing_doc_import_ids( | ||
main_sorted: List[Document], branch_sorted: List[Document] | ||
) -> bool: | ||
main_set = {doc.import_id for doc in main_sorted} | ||
branch_set = {doc.import_id for doc in branch_sorted} | ||
missing_in_branch = main_set - branch_set | ||
extra_in_branch = branch_set - main_set | ||
|
||
if missing_in_branch or extra_in_branch: | ||
if missing_in_branch: | ||
logger.info(f"π Missing doc IDs in branch: {missing_in_branch}") | ||
if extra_in_branch: | ||
logger.info( | ||
f"π Extra doc IDs in branch compared with main: {extra_in_branch}" | ||
) | ||
return True | ||
return False | ||
|
||
|
||
def find_document_differences( | ||
main_sorted: List[Document], branch_sorted: List[Document] | ||
) -> bool: | ||
"""Compare each document in two sorted lists and log differences. | ||
:param main_sorted: List of documents from the main database state | ||
:type main_sorted: List[Document] | ||
:param branch_sorted: List of documents from the branch database state | ||
:type branch_sorted: List[Document] | ||
:return: True if differences are found, False otherwise | ||
:rtype: bool | ||
""" | ||
differences_found = False | ||
|
||
for main_doc, branch_doc in zip(main_sorted, branch_sorted): | ||
if main_doc.import_id != branch_doc.import_id: | ||
logger.info( | ||
f"β Import ID difference found {main_doc.import_id} " | ||
f"vs {branch_doc.import_id}" | ||
) | ||
|
||
if main_doc != branch_doc: | ||
logger.info(f"β Difference(s) found in document {main_doc.import_id} ") | ||
for field in main_doc.model_fields_set: | ||
main_value = getattr(main_doc, field) | ||
branch_value = getattr(branch_doc, field) | ||
if main_value != branch_value: | ||
logger.info( | ||
f"π Field '{field}' differs: main '{main_value}' " | ||
f"vs branch '{branch_value}'" | ||
) | ||
differences_found = True | ||
|
||
return differences_found | ||
|
||
|
||
def compare_db_states(main_db: DBState, branch_db: DBState): | ||
"""Compare two DB state files and log differences. | ||
:param main_db: The main database state (source of truth) | ||
:type main_db: DBState | ||
:param branch_db: The branch database state under test | ||
:type branch_db: DBState | ||
:raises SystemExit: If there are differences in the document lengths | ||
or contents | ||
""" | ||
# Sort documents by import_id for order-insensitive comparison | ||
main_sorted = sorted(main_db.documents.values(), key=lambda doc: doc.import_id) | ||
branch_sorted = sorted(branch_db.documents.values(), key=lambda doc: doc.import_id) | ||
|
||
if len(main_sorted) != len(branch_sorted): | ||
logger.info( | ||
f"π Document list lengths differ: main {len(main_sorted)}, " | ||
f"branch {len(branch_sorted)}" | ||
) | ||
sys.exit(1) | ||
|
||
if find_differing_doc_import_ids(main_sorted, branch_sorted): | ||
sys.exit(1) | ||
|
||
if find_document_differences(main_sorted, branch_sorted): | ||
sys.exit(1) | ||
|
||
logger.info("π DB states are equivalent!") | ||
|
||
|
||
def main(): | ||
"""Main function to load and compare database states.""" | ||
main_db_state = load_json("main_db_state.json") | ||
branch_db_state = load_json("branch.json") | ||
|
||
compare_db_states(main_db_state, branch_db_state) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |