diff --git a/CHANGELOG.md b/CHANGELOG.md index f9b3f0c5..3d5e6098 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.7.42 +* fix: fix missing source after cleaning layout elements * Remove chipper model ## 0.7.41 diff --git a/test_unstructured_inference/test_elements.py b/test_unstructured_inference/test_elements.py index 6627e205..7b70d35d 100644 --- a/test_unstructured_inference/test_elements.py +++ b/test_unstructured_inference/test_elements.py @@ -61,6 +61,7 @@ def test_layoutelements(): element_coords=coords, element_class_ids=element_class_ids, element_class_id_map=class_map, + source="yolox", ) @@ -345,6 +346,7 @@ def test_clean_layoutelements(test_layoutelements): elements[1].bbox.x2, elements[1].bbox.x2, ) == (2, 2, 3, 3) + assert elements[0].source == elements[1].source == "yolox" @pytest.mark.parametrize( diff --git a/unstructured_inference/inference/layoutelement.py b/unstructured_inference/inference/layoutelement.py index 49482171..332df603 100644 --- a/unstructured_inference/inference/layoutelement.py +++ b/unstructured_inference/inference/layoutelement.py @@ -73,13 +73,15 @@ def slice(self, indices) -> LayoutElements: @classmethod def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements: """concatenate a sequence of LayoutElements in order as one LayoutElements""" - coords, texts, probs, class_ids = [], [], [], [] + coords, texts, probs, class_ids, sources = [], [], [], [], [] class_id_map = {} for group in groups: coords.append(group.element_coords) texts.append(group.texts) probs.append(group.element_probs) class_ids.append(group.element_class_ids) + if group.source: + sources.append(group.source) if group.element_class_id_map: class_id_map.update(group.element_class_id_map) return cls( @@ -88,7 +90,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements: element_probs=np.concatenate(probs), element_class_ids=np.concatenate(class_ids), element_class_id_map=class_id_map, - source=group.source, + source=sources[0] if sources else None, ) def as_list(self): @@ -439,7 +441,10 @@ def clean_layoutelements(elements: LayoutElements, subregion_threshold: float = final_coords = sorted_coords[mask] sorted_by_y1 = np.argsort(final_coords[:, 1]) - final_attrs: dict[str, Any] = {"element_class_id_map": elements.element_class_id_map} + final_attrs: dict[str, Any] = { + "element_class_id_map": elements.element_class_id_map, + "source": elements.source, + } for attr in ("element_class_ids", "element_probs", "texts"): if (original_attr := getattr(elements, attr)) is None: continue