Skip to content

Commit

Permalink
fix: fix bugs in data structure
Browse files Browse the repository at this point in the history
- fix bug when an empty list is passed into TextRegions.from_list
- fix bug where concatenating a list of `LayoutElements` does not update
  the class id map correctly
  • Loading branch information
badGarnet committed Jan 14, 2025
1 parent 4309e9e commit 41c6a7d
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 6 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
## 0.8.2

* fix: fix bug where passing an empty list into `TextRegions.from_list` triggers `IndexError`
* fix: fix bug where concatenating a list of `LayoutElements` does not properly update the
class id mapping

## 0.8.1

* fix: fix list index out of range error caused by calling LayoutElements.from_list() with empty list

## 0.8.0
Expand Down
28 changes: 28 additions & 0 deletions test_unstructured_inference/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,31 @@ def test_layoutelements_to_list_and_back(test_layoutelements):
def test_layoutelements_from_list_no_elements():
    """Building LayoutElements from an empty list gives no source and no coordinates."""
    empty_layout = LayoutElements.from_list(elements=[])
    # Nothing to take a source from, so it must fall back to None.
    assert empty_layout.source is None
    assert empty_layout.element_coords.size == 0


def test_textregions_from_list_no_elements():
    """Building TextRegions from an empty list gives no source and no coordinates."""
    empty_regions = TextRegions.from_list(regions=[])
    # Nothing to take a source from, so it must fall back to None.
    assert empty_regions.source is None
    assert empty_regions.element_coords.size == 0


def test_layoutelements_concatenate():
    """Concatenation keeps element order and merges class id maps by class name.

    The two inputs use overlapping ids for different names ("type1" is id 1 in
    the first layout but id 0 in the second), so the joint ids must be renumbered.
    """
    first = LayoutElements(
        element_coords=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]),
        texts=np.array(["a", "two"]),
        source=None,
        element_class_ids=np.array([0, 1]),
        element_class_id_map={0: "type0", 1: "type1"},
    )
    second = LayoutElements(
        element_coords=np.array([[10, 10, 2, 2], [20, 20, 1, 1]]),
        texts=np.array(["three", "4"]),
        source=None,
        element_class_ids=np.array([0, 1]),
        element_class_id_map={0: "type1", 1: "type2"},
    )

    merged = LayoutElements.concatenate([first, second])

    expected_texts = ["a", "two", "three", "4"]
    expected_class_ids = [0, 1, 1, 2]
    expected_id_map = {0: "type0", 1: "type1", 2: "type2"}
    assert merged.texts.tolist() == expected_texts
    assert merged.element_class_ids.tolist() == expected_class_ids
    assert merged.element_class_id_map == expected_id_map
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.8.1" # pragma: no cover
__version__ = "0.8.2" # pragma: no cover
3 changes: 2 additions & 1 deletion unstructured_inference/inference/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ def from_list(cls, regions: list):
for region in regions:
coords.append((region.bbox.x1, region.bbox.y1, region.bbox.x2, region.bbox.y2))
texts.append(region.text)
return cls(element_coords=np.array(coords), texts=np.array(texts), source=regions[0].source)
source = regions[0].source if regions else None
return cls(element_coords=np.array(coords), texts=np.array(texts), source=source)

def __len__(self):
    """Return the number of rows in ``element_coords`` (its first dimension)."""
    row_count = self.element_coords.shape[0]
    return row_count
Expand Down
17 changes: 13 additions & 4 deletions unstructured_inference/inference/layoutelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,31 @@ def slice(self, indices) -> LayoutElements:
def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
    """Concatenate a sequence of LayoutElements, in order, into one LayoutElements.

    Class ids are renumbered while joining: every distinct class name found in
    the groups' ``element_class_id_map`` gets one id in the result, assigned in
    first-seen order starting at 0. Each group's ``element_class_ids`` are
    rewritten to those joint ids, so the same name always maps to the same id
    even when the input groups numbered it differently.

    Args:
        groups: the LayoutElements to join; iterated exactly once.

    Returns:
        A new object built with ``cls(...)`` holding the concatenated coords,
        texts, probs and remapped class ids, the joint id -> name map, and the
        first truthy ``source`` among the groups (``None`` if there is none).

    Raises:
        ValueError: from ``np.concatenate`` when ``groups`` is empty.
    """
    coords, texts, probs, class_ids, sources = [], [], [], [], []
    # class name -> id in the joint result
    class_id_reverse_map = {}
    for group in groups:
        coords.append(group.element_coords)
        texts.append(group.texts)
        probs.append(group.element_probs)
        if group.source:
            sources.append(group.source)

        # Remap on a copy so the input group is left untouched; the masks below
        # are computed from the group's original ids, so an id that was already
        # rewritten can never be rewritten a second time.
        remapped_ids = group.element_class_ids.copy()
        if group.element_class_id_map:
            for class_id, class_name in group.element_class_id_map.items():
                # Reuse the id of a known name, otherwise hand out the next free one.
                new_id = class_id_reverse_map.setdefault(
                    class_name, len(class_id_reverse_map)
                )
                remapped_ids[group.element_class_ids == class_id] = new_id
        # Append exactly once per group so ids stay aligned with coords/texts.
        class_ids.append(remapped_ids)

    return cls(
        element_coords=np.concatenate(coords),
        texts=np.concatenate(texts),
        element_probs=np.concatenate(probs),
        element_class_ids=np.concatenate(class_ids),
        element_class_id_map={v: k for k, v in class_id_reverse_map.items()},
        source=sources[0] if sources else None,
    )

Expand Down

0 comments on commit 41c6a7d

Please sign in to comment.