From 7efe3a497ec355f022923cbb91374fc7c2028ee4 Mon Sep 17 00:00:00 2001
From: Michal Martyniak
Date: Thu, 28 Mar 2024 11:12:38 +0100
Subject: [PATCH] refactor ids recalculation by moving it to process_metadata decorator

---
 unstructured/documents/elements.py  | 53 ++++++++++++++++++++++++++++-
 unstructured/file_utils/filetype.py |  5 ---
 unstructured/partition/common.py    | 10 ------
 3 files changed, 52 insertions(+), 16 deletions(-)

diff --git a/unstructured/documents/elements.py b/unstructured/documents/elements.py
index d834cd7611..4af1a64aff 100644
--- a/unstructured/documents/elements.py
+++ b/unstructured/documents/elements.py
@@ -510,6 +510,55 @@ def field_consolidation_strategies(cls) -> dict[str, ConsolidationStrategy]:

 _P = ParamSpec("_P")

+
+def calculate_hash(text: str, page_number: int, index_in_sequence: int) -> str:
+    """Calculate a deterministic hash for a given text, page number, and index in sequence.
+
+    Args:
+        text: The text of the element.
+        page_number: The page number where the element is found.
+        index_in_sequence: The index of the element in the sequence of elements.
+
+    Returns:
+        The first 32 characters of the SHA256 hash of the concatenated input parameters.
+    """
+    # NOTE(review): the fields are joined without a separator, so distinct inputs such as
+    # ("a1", 2, 3) and ("a", 12, 3) hash identically.
+    data = f"{text}{page_number}{index_in_sequence}"
+    return hashlib.sha256(data.encode()).hexdigest()[:32]
+
+
+def recalculate_ids(elements: list[Element]) -> list[Element]:
+    """Update the `id` (and `parent_id`) attributes of each element in place.
+
+    Each new id is computed from the element's text, page number, and index in the sequence.
+    Ids are assigned by position because old ids may repeat (`_calculate_hash` depends on the
+    text alone); keying the assignment on the old id would give all duplicates one shared id.
+
+    Args:
+        elements: The list of elements whose IDs are to be recalculated.
+
+    Returns:
+        The same list of elements, with updated IDs.
+    """
+    new_ids = [
+        calculate_hash(e.text, e.metadata.page_number, idx_in_seq)
+        for idx_in_seq, e in enumerate(elements)
+    ]
+    # -- fallback for a `parent_id` that points at an element later in the sequence --
+    old_to_new_id_mapping = {e.id: new_id for e, new_id in zip(elements, new_ids)}
+    # -- old id -> new id for the elements visited so far; consulted first so a duplicated
+    # -- old id resolves to the nearest preceding element that carried it --
+    preceding_id_mapping: dict[object, str] = {}
+    for element, new_id in zip(elements, new_ids):
+        old_parent_id = element.metadata.parent_id
+        element.metadata.parent_id = preceding_id_mapping.get(
+            old_parent_id, old_to_new_id_mapping.get(old_parent_id)
+        )
+        preceding_id_mapping[element.id] = new_id
+        element.id = new_id
+    return elements
+

 def process_metadata() -> Callable[[Callable[_P, list[Element]]], Callable[_P, list[Element]]]:
     """Post-process element-metadata for this document.
@@ -559,6 +608,8 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> list[Element]:
             if unique_element_ids:
                 for element in elements:
                     element.id_to_uuid()
+            else:
+                elements = recalculate_ids(elements)

             return elements

@@ -803,7 +854,7 @@ def _calculate_hash(self, index_in_sequence: int = 0) -> HashValue:
         Returns:
             HashValue - 128-bit hash value of the element.
         """
-        data = f"{self.text}{index_in_sequence}"
+        data = f"{self.text}"
         return HashValue(hashlib.sha256(data.encode()).hexdigest()[:32])

     def __eq__(self, other: object):
diff --git a/unstructured/file_utils/filetype.py b/unstructured/file_utils/filetype.py
index 44ac860807..589c9835e7 100644
--- a/unstructured/file_utils/filetype.py
+++ b/unstructured/file_utils/filetype.py
@@ -17,7 +17,6 @@
 from unstructured.partition.common import (
     add_element_metadata,
     exactly_one,
-    recalculate_ids,
     remove_element_metadata,
     set_element_hierarchy,
 )
@@ -595,10 +594,6 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> List[Element]:
                 kwarg: params.get(kwarg) for kwarg in ("filename", "url", "text_as_html")
             }

-            # NOTE(mike) must recalculate before calling `set_element_hierarchy`
-            # otherwise `parent_id` won't be assigned correctly
-            elements = recalculate_ids(elements)
-
             # NOTE (yao): do not use cast here as cast(None) still is None
             if not str(kwargs.get("model_name", "")).startswith("chipper"):
                 # NOTE(alan): Skip hierarchy if using chipper, as it should take care of that
diff --git a/unstructured/partition/common.py b/unstructured/partition/common.py
index 63aa0426ea..698bbcd7d3 100644
--- a/unstructured/partition/common.py
+++ b/unstructured/partition/common.py
@@ -223,16 +223,6 @@ def layout_list_to_list_items(
     return list_items


-def recalculate_ids(elements: List[Element]) -> List[Element]:
-    """Updates the id of each element in the list of elements
-    based on the element's attributes and its index in sequence
-    """
-    for idx_in_seq, element in enumerate(elements):
-        if isinstance(element.id, (NoID, HashValue)):
-            element.id = str(element._calculate_hash(idx_in_seq))
-    return elements
-
-
 def set_element_hierarchy(
     elements: List[Element], ruleset: dict[str, list[str]] = HIERARCHY_RULE_SET
 ) -> list[Element]: