Skip to content

Commit

Permalink
fix missing source (#396)
Browse files Browse the repository at this point in the history
This PR fixes a bug that was found when working with `unstructured`:
https://github.com/Unstructured-IO/unstructured/actions/runs/11403980075/job/31732726752#step:6:1778

## test

- this PR should fix the failing test above in `unstructured` ci
- this PR expands the test on `clean_layoutelements` to test `source` is
kept

## note for release

Since we have not tagged a release for 0.7.42 this PR opt to not
increase version number and include itself as part of 0.7.42
  • Loading branch information
badGarnet authored Oct 18, 2024
1 parent 42eebd3 commit 7f82e52
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## 0.7.42

* fix: fix missing source after cleaning layout elements
* Remove chipper model

## 0.7.41
Expand Down
2 changes: 2 additions & 0 deletions test_unstructured_inference/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_layoutelements():
element_coords=coords,
element_class_ids=element_class_ids,
element_class_id_map=class_map,
source="yolox",
)


Expand Down Expand Up @@ -345,6 +346,7 @@ def test_clean_layoutelements(test_layoutelements):
elements[1].bbox.x2,
elements[1].bbox.x2,
) == (2, 2, 3, 3)
assert elements[0].source == elements[1].source == "yolox"


@pytest.mark.parametrize(
Expand Down
11 changes: 8 additions & 3 deletions unstructured_inference/inference/layoutelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,15 @@ def slice(self, indices) -> LayoutElements:
@classmethod
def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
"""concatenate a sequence of LayoutElements in order as one LayoutElements"""
coords, texts, probs, class_ids = [], [], [], []
coords, texts, probs, class_ids, sources = [], [], [], [], []
class_id_map = {}
for group in groups:
coords.append(group.element_coords)
texts.append(group.texts)
probs.append(group.element_probs)
class_ids.append(group.element_class_ids)
if group.source:
sources.append(group.source)
if group.element_class_id_map:
class_id_map.update(group.element_class_id_map)
return cls(
Expand All @@ -88,7 +90,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
element_probs=np.concatenate(probs),
element_class_ids=np.concatenate(class_ids),
element_class_id_map=class_id_map,
source=group.source,
source=sources[0] if sources else None,
)

def as_list(self):
Expand Down Expand Up @@ -439,7 +441,10 @@ def clean_layoutelements(elements: LayoutElements, subregion_threshold: float =
final_coords = sorted_coords[mask]
sorted_by_y1 = np.argsort(final_coords[:, 1])

final_attrs: dict[str, Any] = {"element_class_id_map": elements.element_class_id_map}
final_attrs: dict[str, Any] = {
"element_class_id_map": elements.element_class_id_map,
"source": elements.source,
}
for attr in ("element_class_ids", "element_probs", "texts"):
if (original_attr := getattr(elements, attr)) is None:
continue
Expand Down

0 comments on commit 7f82e52

Please sign in to comment.