From ffb1f0bcdcaae85e40da1146462d8de2881d57db Mon Sep 17 00:00:00 2001 From: Yuming Long <63475068+yuming-long@users.noreply.github.com> Date: Thu, 5 Oct 2023 14:23:35 -0400 Subject: [PATCH] Refactor: Remove OCR related code for entire page OCR (#231) ## Summary This is one part of the OCR refactor, which moves OCR from the inference repo to the unstructured repo. This PR removes all OCR related code for entire page OCR; all table related OCR remains the same for now (it will be moved after the table refactor so tables can accept preprocessed OCR data). ## Test Please see the test description in https://github.com/Unstructured-IO/unstructured/pull/1579, since the two PRs need to work together. ## Note The ingest test won't pass until we merge the unstructured refactor PR. --------- Co-authored-by: christinestraub --- CHANGELOG.md | 4 + Dockerfile | 1 - examples/layout_analysis/visualization.py | 4 +- test_unstructured_inference/conftest.py | 25 -- .../inference/test_layout.py | 331 +----------------- .../inference/test_layout_element.py | 105 +----- .../models/test_model.py | 2 - .../models/test_tesseract.py | 26 -- test_unstructured_inference/test_elements.py | 15 - unstructured_inference/__version__.py | 2 +- unstructured_inference/constants.py | 11 +- unstructured_inference/inference/elements.py | 97 +---- unstructured_inference/inference/layout.py | 204 +---------- .../inference/layoutelement.py | 178 +--------- unstructured_inference/models/paddle_ocr.py | 1 + unstructured_inference/models/tables.py | 6 +- unstructured_inference/models/tesseract.py | 42 --- 17 files changed, 30 insertions(+), 1024 deletions(-) delete mode 100644 test_unstructured_inference/models/test_tesseract.py delete mode 100644 unstructured_inference/models/tesseract.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2277974d..4fe3117f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.7.0 + +* Remove all OCR related code except the table OCR code + ## 0.6.6 * Stop passing ocr_languages parameter into paddle to avoid invalid paddle language code error, this will be fixed until diff --git a/Dockerfile b/Dockerfile index ebcf0da7..366cffc3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,7 +20,6 @@ RUN python3.8 -m pip install pip==${PIP_VERSION} && \ pip install --no-cache -r requirements/base.txt && \ pip install --no-cache -r requirements/test.txt && \ pip install --no-cache -r requirements/dev.txt && \ - pip install "unstructured.PaddleOCR" && \ dnf -y groupremove "Development Tools" && \ dnf clean all diff --git a/examples/layout_analysis/visualization.py b/examples/layout_analysis/visualization.py index fe4b497e..221d5301 100644 --- a/examples/layout_analysis/visualization.py +++ b/examples/layout_analysis/visualization.py @@ -14,7 +14,6 @@ def run(f_path, scope): "final": None, "extracted": {"layout": {"color": "green", "width": 2}}, "inferred": {"inferred_layout": {"color": "blue", "width": 2}}, - "ocr": {"ocr_layout": {"color": "yellow", "width": 2}}, } f_basename = os.path.splitext(os.path.basename(f_path))[0] @@ -47,8 +46,7 @@ def run(f_path, scope): write_image(img, output_f_path) print(f"page_num: {idx+1} - n_total_elements: {len(page.elements)} - n_extracted_elements: " - f"{len(page.layout)} - n_inferred_elements: {len(page.inferred_layout)} - " - f"n_ocr_elements: {len(page.ocr_layout)}") + f"{len(page.layout)} - n_inferred_elements: {len(page.inferred_layout)}") if __name__ == '__main__': diff --git a/test_unstructured_inference/conftest.py b/test_unstructured_inference/conftest.py index c20caece..097464fa 100644 --- 
a/test_unstructured_inference/conftest.py +++ b/test_unstructured_inference/conftest.py @@ -107,15 +107,6 @@ def mock_embedded_text_regions(): ] -@pytest.fixture() -def mock_ocr_regions(): - return [ - EmbeddedTextRegion(10, 10, 90, 90, text="0", source=None), - EmbeddedTextRegion(200, 200, 300, 300, text="1", source=None), - EmbeddedTextRegion(500, 320, 600, 350, text="3", source=None), - ] - - # TODO(alan): Make a better test layout @pytest.fixture() def mock_layout(mock_embedded_text_regions): @@ -130,19 +121,3 @@ def mock_layout(mock_embedded_text_regions): ) for r in mock_embedded_text_regions ] - - -@pytest.fixture() -def mock_inferred_layout(mock_embedded_text_regions): - return [ - LayoutElement( - r.x1, - r.y1, - r.x2, - r.y2, - text=None, - source=None, - type="Text", - ) - for r in mock_embedded_text_regions - ] diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py index b2b665bd..17d29bde 100644 --- a/test_unstructured_inference/inference/test_layout.py +++ b/test_unstructured_inference/inference/test_layout.py @@ -2,7 +2,6 @@ import os.path import tempfile from functools import partial -from itertools import product from unittest.mock import mock_open, patch import numpy as np @@ -10,10 +9,9 @@ from PIL import Image import unstructured_inference.models.base as models -from unstructured_inference.constants import OCRMode, Source +from unstructured_inference.constants import Source from unstructured_inference.inference import elements, layout, layoutelement -from unstructured_inference.models import chipper, detectron2, tesseract -from unstructured_inference.models.base import get_model +from unstructured_inference.models import detectron2 from unstructured_inference.models.unstructuredmodel import ( UnstructuredElementExtractionModel, UnstructuredObjectDetectionModel, @@ -87,50 +85,6 @@ def verify_image_array(): verify_image_array() -def test_ocr(monkeypatch): - mock_text = "The parrot flies high in the air!" 
- - class MockOCRAgent: - def detect(self, *args): - return mock_text - - monkeypatch.setattr(tesseract, "ocr_agents", {"eng": MockOCRAgent}) - monkeypatch.setattr(tesseract, "is_pytesseract_available", lambda *args: True) - - image = Image.fromarray(np.random.randint(12, 24, (40, 40)), mode="RGB") - text_block = layout.TextRegion(1, 2, 3, 4, text=None) - - assert elements.ocr(text_block, image=image) == mock_text - - -def test_ocr_with_error(monkeypatch): - class MockOCRAgent: - def detect(self, *args): - # We sometimes get this error on very small images - raise tesseract.TesseractError(-8, "Estimating resolution as 1023") - - monkeypatch.setattr(tesseract, "ocr_agents", {"eng": MockOCRAgent}) - monkeypatch.setattr(tesseract, "is_pytesseract_available", lambda *args: True) - - image = Image.fromarray(np.random.randint(12, 24, (40, 40)), mode="RGB") - text_block = layout.TextRegion(1, 2, 3, 4, text=None) - - assert elements.ocr(text_block, image=image) == "" - - -def test_ocr_source(): - file = "sample-docs/loremipsum-flat.pdf" - model = get_model("yolox_tiny") - doc = layout.DocumentLayout.from_file( - file, - model, - ocr_mode=OCRMode.FULL_PAGE.value, - supplement_with_ocr_elements=True, - ocr_strategy="force", - ) - assert Source.OCR_TESSERACT in {e.source for e in doc.pages[0].elements} - - class MockLayoutModel: def __init__(self, layout): self.layout_return = layout @@ -160,26 +114,6 @@ def test_get_page_elements(monkeypatch, mock_final_layout): assert elements == page.elements -def test_get_page_elements_with_tesseract_error(monkeypatch, mock_final_layout): - def mock_image_to_data(*args, **kwargs): - raise tesseract.TesseractError(-2, "Estimating resolution as 1023") - - monkeypatch.setattr(layout.pytesseract, "image_to_data", mock_image_to_data) - - image = Image.fromarray(np.random.randint(12, 14, size=(40, 10, 3)), mode="RGB") - page = layout.PageLayout( - number=0, - image=image, - layout=mock_final_layout, - detection_model=MockLayoutModel(mock_final_layout), - ) - - elements = page.get_elements_with_detection_model(inplace=False) - - assert str(elements[0]) == "A Catchy Title" - assert str(elements[1]).startswith("A very repetitive narrative.") - - class MockPool: def map(self, f, xs): return [f(x) for x in xs] @@ -191,102 +125,6 @@ def join(self): pass -@pytest.mark.skipif(skip_outside_ci, reason="Skipping paddle test run outside of CI") -def test_get_page_elements_with_paddle_ocr(monkeypatch): - monkeypatch.setenv("ENTIRE_PAGE_OCR", "paddle") - text_block = layout.TextRegion(2, 4, 6, 8, text=None) - image_block = layout.ImageTextRegion(8, 14, 16, 18) - doc_initial_layout = [text_block, image_block] - text_layoutelement = layoutelement.LayoutElement( - 2, - 4, - 6, - 8, - text=None, - type="UncategorizedText", - ) - image_layoutelement = layoutelement.LayoutElement(8, 14, 16, 18, text=None, type="Image") - doc_final_layout = [text_layoutelement, image_layoutelement] - - monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True) - monkeypatch.setattr(elements, "ocr", lambda *args, **kwargs: "An Even Catchier Title") - - image = Image.fromarray(np.random.randint(12, 14, size=(40, 10, 3)), mode="RGB") - page = layout.PageLayout( - number=0, - image=image, - layout=doc_initial_layout, - detection_model=MockLayoutModel(doc_final_layout), - # Note(yuming): there are differnt language codes for same language - # between paddle and tesseract - ocr_languages="en", - ) - page.get_elements_with_detection_model() - - assert str(page) == "\n\nAn Even Catchier Title" 
- - -def test_get_page_elements_with_tesseract_ocr(monkeypatch): - monkeypatch.setenv("ENTIRE_PAGE_OCR", "tesseract") - text_block = layout.TextRegion(2, 4, 6, 8, text=None) - image_block = layout.ImageTextRegion(8, 14, 16, 18) - doc_initial_layout = [text_block, image_block] - text_layoutelement = layoutelement.LayoutElement( - 2, - 4, - 6, - 8, - text=None, - type="UncategorizedText", - ) - image_layoutelement = layoutelement.LayoutElement(8, 14, 16, 18, text=None, type="Image") - doc_final_layout = [text_layoutelement, image_layoutelement] - - monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True) - monkeypatch.setattr(elements, "ocr", lambda *args, **kwargs: "An Even Catchier Title") - - image = Image.fromarray(np.random.randint(12, 14, size=(40, 10, 3)), mode="RGB") - page = layout.PageLayout( - number=0, - image=image, - layout=doc_initial_layout, - detection_model=MockLayoutModel(doc_final_layout), - ) - page.get_elements_with_detection_model() - - assert str(page) == "\n\nAn Even Catchier Title" - - -def test_get_page_elements_with_ocr_invalid_entrie_page_ocr(monkeypatch): - monkeypatch.setenv("ENTIRE_PAGE_OCR", "invalid_entire_page_ocr") - text_block = layout.TextRegion(2, 4, 6, 8, text=None) - image_block = layout.ImageTextRegion(8, 14, 16, 18) - doc_initial_layout = [text_block, image_block] - text_layoutelement = layoutelement.LayoutElement( - 2, - 4, - 6, - 8, - text=None, - type="UncategorizedText", - ) - image_layoutelement = layoutelement.LayoutElement(8, 14, 16, 18, text=None, type="Image") - doc_final_layout = [text_layoutelement, image_layoutelement] - - monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True) - monkeypatch.setattr(elements, "ocr", lambda *args, **kwargs: "An Even Catchier Title") - - image = Image.fromarray(np.random.randint(12, 14, size=(40, 10, 3)), mode="RGB") - page = layout.PageLayout( - number=0, - image=image, - layout=doc_initial_layout, - detection_model=MockLayoutModel(doc_final_layout), - ) - with pytest.raises(ValueError): - page.get_elements_with_detection_model() - - def test_read_pdf(monkeypatch, mock_initial_layout, mock_final_layout, mock_image): with tempfile.TemporaryDirectory() as tmpdir: image_path1 = os.path.join(tmpdir, "mock1.jpg") @@ -373,10 +211,9 @@ def tolist(self): class MockEmbeddedTextRegion(layout.EmbeddedTextRegion): - def __init__(self, type=None, text=None, ocr_text=None): + def __init__(self, type=None, text=None): self.type = type self.text = text - self.ocr_text = ocr_text @property def points(self): @@ -390,21 +227,14 @@ def __init__( image=None, layout=None, model=None, - ocr_strategy="auto", - ocr_languages="eng", extract_tables=False, ): self.image = image self.layout = layout self.model = model - self.ocr_strategy = ocr_strategy - self.ocr_languages = ocr_languages self.extract_tables = extract_tables self.number = number - def ocr(self, text_block: MockEmbeddedTextRegion): - return text_block.ocr_text - @pytest.mark.parametrize( ("text", "expected"), @@ -442,31 +272,6 @@ def filter_by(self, *args, **kwargs): return MockLayout() -@pytest.mark.parametrize( - ("block_text", "layout_texts", "expected_text"), - [ - ("no ocr", ["pieced", "together", "group"], "no ocr"), - (None, ["pieced", "together", "group"], "pieced together group"), - ], -) -def test_get_element_from_block(block_text, layout_texts, mock_image, expected_text): - with patch("unstructured_inference.inference.elements.ocr", return_value="ocr"): - block = layout.TextRegion(0, 0, 10, 10, text=block_text) - 
captured_layout = [ - layout.TextRegion(i + 1, i + 1, i + 2, i + 2, text=text) - for i, text in enumerate(layout_texts) - ] - assert ( - layout.get_element_from_block(block, mock_image, captured_layout).text == expected_text - ) - - -def test_get_elements_from_block_raises(): - with pytest.raises(ValueError): - block = layout.TextRegion(0, 0, 10, 10, text=None) - layout.get_element_from_block(block, None, None) - - @pytest.mark.parametrize("filetype", ["png", "jpg", "tiff"]) def test_from_image_file(monkeypatch, mock_final_layout, filetype): def mock_get_elements(self, *args, **kwargs): @@ -574,11 +379,6 @@ def test_from_file_fixed_layout(fixed_layouts, called_method, not_called_method) getattr(layout.PageLayout, not_called_method).assert_not_called() -def test_invalid_ocr_strategy_raises(mock_image): - with pytest.raises(ValueError): - layout.PageLayout(0, mock_image, MockLayout(), ocr_strategy="fake_strategy") - - @pytest.mark.parametrize( ("text", "expected"), [("a\ts\x0cd\nfas\fd\rf\b", "asdfasdf"), ("\"'\\", "\"'\\")], @@ -602,93 +402,6 @@ def test_remove_control_characters(text, expected): unpopulated_text_region = layout.EmbeddedTextRegion(50, 50, 60, 60, text=None) -@pytest.mark.parametrize( - ("region", "objects", "ocr_strategy", "expected"), - [ - (no_text_region, [nonoverlapping_rect], "auto", False), - (no_text_region, [overlapping_rect], "auto", True), - (no_text_region, [], "auto", False), - (no_text_region, [populated_text_region, nonoverlapping_rect], "auto", False), - (no_text_region, [populated_text_region, overlapping_rect], "auto", False), - (no_text_region, [populated_text_region], "auto", False), - (no_text_region, [unpopulated_text_region, nonoverlapping_rect], "auto", False), - (no_text_region, [unpopulated_text_region, overlapping_rect], "auto", True), - (no_text_region, [unpopulated_text_region], "auto", False), - *list( - product( - [text_region], - [ - [], - [populated_text_region], - [unpopulated_text_region], - [nonoverlapping_rect], - [overlapping_rect], - [populated_text_region, nonoverlapping_rect], - [populated_text_region, overlapping_rect], - [unpopulated_text_region, nonoverlapping_rect], - [unpopulated_text_region, overlapping_rect], - ], - ["auto"], - [False], - ), - ), - *list( - product( - [cid_text_region], - [ - [], - [populated_text_region], - [unpopulated_text_region], - [overlapping_rect], - [populated_text_region, overlapping_rect], - [unpopulated_text_region, overlapping_rect], - ], - ["auto"], - [True], - ), - ), - *list( - product( - [no_text_region, text_region, cid_text_region], - [ - [], - [populated_text_region], - [unpopulated_text_region], - [nonoverlapping_rect], - [overlapping_rect], - [populated_text_region, nonoverlapping_rect], - [populated_text_region, overlapping_rect], - [unpopulated_text_region, nonoverlapping_rect], - [unpopulated_text_region, overlapping_rect], - ], - ["force"], - [True], - ), - ), - *list( - product( - [no_text_region, text_region, cid_text_region], - [ - [], - [populated_text_region], - [unpopulated_text_region], - [nonoverlapping_rect], - [overlapping_rect], - [populated_text_region, nonoverlapping_rect], - [populated_text_region, overlapping_rect], - [unpopulated_text_region, nonoverlapping_rect], - [unpopulated_text_region, overlapping_rect], - ], - ["never"], - [False], - ), - ), - ], -) -def test_ocr_image(region, objects, ocr_strategy, expected): - assert elements.needs_ocr(region, objects, ocr_strategy) is expected - - @pytest.mark.parametrize("filename", ["loremipsum.pdf", "IRS-form-1987.pdf"]) 
def test_load_pdf(filename): layouts, images = layout.load_pdf(f"sample-docs/{filename}") @@ -725,7 +438,7 @@ def test_load_pdf_raises_with_path_only_no_output_folder(): @pytest.mark.skip("Temporarily removed multicolumn to fix ordering") -def test_load_pdf_with_multicolumn_layout_and_ocr(filename="sample-docs/design-thinking.pdf"): +def test_load_pdf_with_multicolumn_layout(filename="sample-docs/design-thinking.pdf"): layouts, images = layout.load_pdf(filename) doc = layout.process_file_with_model(filename=filename, model_name=None) test_snippets = ["Key to design thinking", "Design thinking also", "But in recent years"] @@ -784,34 +497,12 @@ def check_annotated_image(): check_annotated_image() -def test_textregion_returns_empty_ocr_never(mock_image): - tr = elements.TextRegion(0, 0, 24, 24) - assert tr.extract_text(objects=None, image=mock_image, ocr_strategy="never") == "" - - @pytest.mark.parametrize(("text", "expected"), [("asdf", "asdf"), (None, "")]) def test_embedded_text_region(text, expected): etr = elements.EmbeddedTextRegion(0, 0, 24, 24, text=text) assert etr.extract_text(objects=None) == expected -@pytest.mark.parametrize( - ("text", "ocr_strategy", "expected"), - [ - (None, "never", ""), - (None, "always", "asdf"), - ("i have text", "never", "i have text"), - ("i have text", "always", "i have text"), - ], -) -def test_image_text_region(text, ocr_strategy, expected, mock_image): - itr = elements.ImageTextRegion(0, 0, 24, 24, text=text) - with patch.object(elements, "ocr", return_value="asdf"): - assert ( - itr.extract_text(objects=None, image=mock_image, ocr_strategy=ocr_strategy) == expected - ) - - class MockDetectionModel(layout.UnstructuredObjectDetectionModel): def initialize(self, *args, **kwargs): pass @@ -970,9 +661,6 @@ def test_process_file_with_model_routing(monkeypatch, model_type, is_detection_m "asdf", detection_model=detection_model, element_extraction_model=element_extraction_model, - ocr_strategy="auto", - ocr_languages="eng", - ocr_mode=OCRMode.FULL_PAGE.value, fixed_layouts=None, extract_tables=False, pdf_image_dpi=200, @@ -986,17 +674,6 @@ def test_exposed_pdf_image_dpi(pdf_image_dpi, expected, monkeypatch): assert mock_from_image.call_args[0][0].height == expected -def test_warning_if_chipper_and_low_dpi(caplog): - with patch.object(layout.DocumentLayout, "from_file") as mock_from_file, patch.object( - chipper.UnstructuredChipperModel, - "initialize", - ): - layout.process_file_with_model("asdf", model_name="chipper", pdf_image_dpi=299) - mock_from_file.assert_called_once() - assert caplog.records[0].levelname == "WARNING" - assert "DPI >= 300" in caplog.records[0].msg - - @pytest.mark.parametrize( ("filename", "img_num", "should_complete"), [("sample-docs/empty-document.pdf", 0, True), ("sample-docs/empty-document.pdf", 10, False)], diff --git a/test_unstructured_inference/inference/test_layout_element.py b/test_unstructured_inference/inference/test_layout_element.py index 0991a364..c037b4ad 100644 --- a/test_unstructured_inference/inference/test_layout_element.py +++ b/test_unstructured_inference/inference/test_layout_element.py @@ -2,115 +2,12 @@ from layoutparser.elements import TextBlock from layoutparser.elements.layout_elements import Rectangle as LPRectangle -from unstructured_inference.constants import SUBREGION_THRESHOLD_FOR_OCR, Source -from unstructured_inference.inference.elements import TextRegion +from unstructured_inference.constants import Source from unstructured_inference.inference.layoutelement import ( LayoutElement, - 
aggregate_ocr_text_by_block, - get_elements_from_ocr_regions, - merge_inferred_layout_with_ocr_layout, - merge_text_regions, - supplement_layout_with_ocr_elements, ) -def test_aggregate_ocr_text_by_block(): - expected = "A Unified Toolkit" - ocr_layout = [ - TextRegion(0, 0, 20, 20, source="OCR", text="A"), - TextRegion(50, 50, 150, 150, source="OCR", text="Unified"), - TextRegion(150, 150, 300, 250, source="OCR", text="Toolkit"), - TextRegion(200, 250, 300, 350, source="OCR", text="Deep"), - ] - region = TextRegion(0, 0, 250, 350, text="") - - text = aggregate_ocr_text_by_block(ocr_layout, region, 0.5) - assert text == expected - - -def test_merge_text_regions(mock_embedded_text_regions): - expected = TextRegion( - x1=437.83888888888885, - y1=317.319341111111, - x2=1256.334784222222, - y2=406.9837855555556, - text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image", - ) - - merged_text_region = merge_text_regions(mock_embedded_text_regions) - assert merged_text_region == expected - - -def test_get_elements_from_ocr_regions(mock_embedded_text_regions): - expected = [ - LayoutElement( - x1=437.83888888888885, - y1=317.319341111111, - x2=1256.334784222222, - y2=406.9837855555556, - text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image", - type="UncategorizedText", - ), - ] - - elements = get_elements_from_ocr_regions(mock_embedded_text_regions) - assert elements == expected - - -def test_supplement_layout_with_ocr_elements(mock_layout, mock_ocr_regions): - ocr_elements = [ - LayoutElement( - r.x1, - r.y1, - r.x2, - r.y2, - text=r.text, - source=None, - type="UncategorizedText", - ) - for r in mock_ocr_regions - ] - - final_layout = supplement_layout_with_ocr_elements(mock_layout, mock_ocr_regions) - - # Check if the final layout contains the original layout elements - for element in mock_layout: - assert element in final_layout - - # Check if the final layout contains the OCR-derived elements - assert any(ocr_element in final_layout for ocr_element in ocr_elements) - - # Check if the OCR-derived elements that are subregions of layout elements are removed - for element in mock_layout: - for ocr_element in ocr_elements: - if ocr_element.is_almost_subregion_of(element, SUBREGION_THRESHOLD_FOR_OCR): - assert ocr_element not in final_layout - - -def test_merge_inferred_layout_with_ocr_layout(mock_inferred_layout, mock_ocr_regions): - ocr_elements = [ - LayoutElement( - r.x1, - r.y1, - r.x2, - r.y2, - text=r.text, - source=None, - type="UncategorizedText", - ) - for r in mock_ocr_regions - ] - - final_layout = merge_inferred_layout_with_ocr_layout(mock_inferred_layout, mock_ocr_regions) - - # Check if the inferred layout's text attribute is updated with aggregated OCR text - assert final_layout[0].text == mock_ocr_regions[2].text - - # Check if the final layout contains both original elements and OCR-derived elements - assert all(element in final_layout for element in mock_inferred_layout) - assert any(element in final_layout for element in ocr_elements) - - @pytest.mark.parametrize("is_table", [False, True]) def test_layout_element_extract_text( mock_layout_element, diff --git a/test_unstructured_inference/models/test_model.py b/test_unstructured_inference/models/test_model.py index 4ae6c08a..e05fbf5b 100644 --- a/test_unstructured_inference/models/test_model.py +++ b/test_unstructured_inference/models/test_model.py @@ -87,8 +87,6 @@ def test_deduplicate_detected_elements(): doc = DocumentLayout.from_image_file( file, model, - ocr_strategy="never", - 
supplement_with_ocr_elements=False, ) known_elements = [e for e in doc.pages[0].elements if e.type != "UncategorizedText"] # Compute intersection matrix diff --git a/test_unstructured_inference/models/test_tesseract.py b/test_unstructured_inference/models/test_tesseract.py deleted file mode 100644 index 475cba08..00000000 --- a/test_unstructured_inference/models/test_tesseract.py +++ /dev/null @@ -1,26 +0,0 @@ -from unittest.mock import patch - -import pytest - -from unstructured_inference.models import tesseract - - -class MockTesseractAgent: - def __init__(self, languages): - pass - - -def test_load_agent(monkeypatch): - monkeypatch.setattr(tesseract, "TesseractAgent", MockTesseractAgent) - monkeypatch.setattr(tesseract, "ocr_agents", {}) - - with patch.object(tesseract, "is_pytesseract_available", return_value=True): - tesseract.load_agent(languages="eng+swe") - - assert isinstance(tesseract.ocr_agents["eng+swe"], MockTesseractAgent) - - -def test_load_agent_raises_when_not_available(): - with patch.object(tesseract, "is_pytesseract_available", return_value=False): - with pytest.raises(ImportError): - tesseract.load_agent() diff --git a/test_unstructured_inference/test_elements.py b/test_unstructured_inference/test_elements.py index 1c1be08c..1a68fa57 100644 --- a/test_unstructured_inference/test_elements.py +++ b/test_unstructured_inference/test_elements.py @@ -1,10 +1,8 @@ -import logging import os from random import randint from unittest.mock import PropertyMock, patch import pytest -from PIL import Image from unstructured_inference.inference import elements from unstructured_inference.inference.layoutelement import ( @@ -240,16 +238,3 @@ def test_separate(rect1, rect2): separate(rect1, rect2) # assert not rect1.intersects(rect2) #TODO: fix this test - - -@pytest.mark.skipif(skip_outside_ci, reason="Skipping paddle test run outside of CI") -def test_ocr_paddle(monkeypatch, caplog): - monkeypatch.setenv("ENTIRE_PAGE_OCR", "paddle") - image = Image.new("RGB", (100, 100), (255, 255, 255)) - text_block = elements.TextRegion(0, 0, 50, 50) - # Note(yuming): paddle result is currently non-deterministic on ci - # so don't check result like `assert result == ""` - # use logger info to confirm we are using paddle instead - with caplog.at_level(logging.INFO): - _ = elements.ocr(text_block, image, languages="en") - assert "paddle" in caplog.text diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 37b46218..8909b1e7 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.6.6" # pragma: no cover +__version__ = "0.7.0" # pragma: no cover diff --git a/unstructured_inference/constants.py b/unstructured_inference/constants.py index 78c46379..8fa622df 100644 --- a/unstructured_inference/constants.py +++ b/unstructured_inference/constants.py @@ -1,11 +1,6 @@ from enum import Enum -class OCRMode(Enum): - INDIVIDUAL_BLOCKS = "individual_blocks" - FULL_PAGE = "entire_page" - - class AnnotationResult(Enum): IMAGE = "image" PLOT = "plot" @@ -15,11 +10,11 @@ class Source(Enum): YOLOX = "yolox" DETECTRON2_ONNX = "detectron2_onnx" DETECTRON2_LP = "detectron2_lp" - OCR_TESSERACT = "OCR-tesseract" - OCR_PADDLE = "OCR-paddle" PDFMINER = "pdfminer" MERGED = "merged" -SUBREGION_THRESHOLD_FOR_OCR = 0.5 FULL_PAGE_REGION_THRESHOLD = 0.99 + +# this field is defined by pytesseract/unstructured.pytesseract +TESSERACT_TEXT_HEIGHT = "height" diff --git a/unstructured_inference/inference/elements.py 
b/unstructured_inference/inference/elements.py index 67f78216..8ca22415 100644 --- a/unstructured_inference/inference/elements.py +++ b/unstructured_inference/inference/elements.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import re import unicodedata from copy import deepcopy @@ -13,9 +12,7 @@ from unstructured_inference.config import inference_config from unstructured_inference.constants import Source -from unstructured_inference.logger import logger from unstructured_inference.math import safe_division -from unstructured_inference.models import tesseract @dataclass @@ -208,23 +205,15 @@ def extract_text( objects: Optional[Collection[TextRegion]], image: Optional[Image.Image] = None, extract_tables: bool = False, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", ) -> str: """Extracts text contained in region.""" if self.text is not None: # If block text is already populated, we'll assume it's correct text = self.text elif objects is not None: - text = aggregate_by_block(self, image, objects, ocr_strategy) - elif image is not None: - # We don't have anything to go on but the image itself, so we use OCR - text = ocr(self, image, languages=ocr_languages) if ocr_strategy != "never" else "" + text = aggregate_by_block(self, objects) else: - raise ValueError( - "Got arguments image and layout as None, at least one must be populated to use for " - "text extraction.", - ) + text = "" return text @@ -234,8 +223,6 @@ def extract_text( objects: Optional[Collection[TextRegion]], image: Optional[Image.Image] = None, extract_tables: bool = False, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", ) -> str: """Extracts text contained in region.""" if self.text is None: @@ -250,96 +237,22 @@ def extract_text( objects: Optional[Collection[TextRegion]], image: Optional[Image.Image] = None, extract_tables: bool = False, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", ) -> str: """Extracts text contained in region.""" if self.text is None: - if ocr_strategy == "never" or image is None: - return "" - else: - return ocr(self, image, languages=ocr_languages) - else: - return super().extract_text(objects, image, extract_tables, ocr_strategy) - - -def ocr(text_block: TextRegion, image: Image.Image, languages: str = "eng") -> str: - """Runs a cropped text block image through and OCR agent.""" - logger.debug("Running OCR on text block ...") - tesseract.load_agent(languages=languages) - padded_block = text_block.pad(12) - cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2)) - entrie_page_ocr = os.getenv("ENTIRE_PAGE_OCR", "tesseract").lower() - if entrie_page_ocr == "paddle": - from unstructured_inference.models import paddle_ocr - - paddle_result = paddle_ocr.load_agent().ocr(np.array(cropped_image), cls=True) - recognized_text = "" - for idx in range(len(paddle_result)): - res = paddle_result[idx] - for line in res: - recognized_text += line[1][0] - return recognized_text - else: - agent = tesseract.ocr_agents.get(languages) - if agent is None: - raise RuntimeError("OCR agent is not loaded for {languages}.") - - try: - return agent.detect(cropped_image) - except tesseract.TesseractError: - logger.warning("TesseractError: Skipping region", exc_info=True) return "" - - -def needs_ocr( - region: TextRegion, - pdf_objects: Collection[TextRegion], - ocr_strategy: str, -) -> bool: - """Logic to determine whether ocr is needed to extract text from given region.""" - if ocr_strategy == "force": - return True - elif ocr_strategy 
== "auto": - image_objects = [obj for obj in pdf_objects if isinstance(obj, ImageTextRegion)] - word_objects = [obj for obj in pdf_objects if isinstance(obj, EmbeddedTextRegion)] - # If any image object overlaps with the region of interest, we have hope of getting some - # text from OCR. Otherwise, there's nothing there to find, no need to waste our time with - # OCR. - image_intersects = any(region.intersects(img_obj) for img_obj in image_objects) - if region.text is None: - # If the region has no text check if any images overlap with the region that might - # contain text. - if any(obj.is_in(region) and obj.text is not None for obj in word_objects): - # If there are word objects in the region, we defer to that rather than OCR - return False - else: - return image_intersects else: - # If the region has text, we should only have to OCR if too much of the text is - # uninterpretable. - return cid_ratio(region.text) > 0.5 - else: - return False + return super().extract_text(objects, extract_tables) def aggregate_by_block( text_region: TextRegion, - image: Optional[Image.Image], pdf_objects: Collection[TextRegion], - ocr_strategy: str = "auto", - ocr_languages: str = "eng", ) -> str: """Extracts the text aggregated from the elements of the given layout that lie within the given block.""" - if image is not None and needs_ocr(text_region, pdf_objects, ocr_strategy): - text = ocr(text_region, image, languages=ocr_languages) - else: - filtered_blocks = [obj for obj in pdf_objects if obj.is_in(text_region, error_margin=5)] - for little_block in filtered_blocks: - if image is not None and needs_ocr(little_block, pdf_objects, ocr_strategy): - little_block.text = ocr(little_block, image, languages=ocr_languages) - text = " ".join([x.text for x in filtered_blocks if x.text]) + filtered_blocks = [obj for obj in pdf_objects if obj.is_in(text_region, error_margin=5)] + text = " ".join([x.text for x in filtered_blocks if x.text]) text = remove_control_characters(text) return text diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py index 447e7154..7e6811c6 100644 --- a/unstructured_inference/inference/layout.py +++ b/unstructured_inference/inference/layout.py @@ -7,13 +7,11 @@ import numpy as np import pdf2image -import pytesseract from pdfminer import psparser from pdfminer.high_level import extract_pages from PIL import Image, ImageSequence -from pytesseract import Output -from unstructured_inference.constants import OCRMode, Source +from unstructured_inference.constants import Source from unstructured_inference.inference.elements import ( EmbeddedTextRegion, ImageTextRegion, @@ -24,7 +22,6 @@ LayoutElement, LocationlessLayoutElement, merge_inferred_layout_with_extracted_layout, - merge_inferred_layout_with_ocr_layout, ) from unstructured_inference.inference.ordering import order_layout from unstructured_inference.inference.pdf import get_images_from_pdf_element @@ -47,12 +44,6 @@ import pdfplumber # noqa -VALID_OCR_STRATEGIES = ( - "auto", # Use OCR when it looks like other methods have failed - "force", # Always use OCR - "never", # Never use OCR -) - class DocumentLayout: """Class for handling documents that are saved as .pdf files. 
For .pdf files, a @@ -84,9 +75,6 @@ def from_file( detection_model: Optional[UnstructuredObjectDetectionModel] = None, element_extraction_model: Optional[UnstructuredElementExtractionModel] = None, fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", - ocr_mode: str = OCRMode.FULL_PAGE.value, extract_tables: bool = False, pdf_image_dpi: int = 200, **kwargs, @@ -124,9 +112,6 @@ def from_file( detection_model=detection_model, element_extraction_model=element_extraction_model, layout=layout, - ocr_strategy=ocr_strategy, - ocr_languages=ocr_languages, - ocr_mode=ocr_mode, fixed_layout=fixed_layout, extract_tables=extract_tables, **kwargs, @@ -140,9 +125,6 @@ def from_image_file( filename: str, detection_model: Optional[UnstructuredObjectDetectionModel] = None, element_extraction_model: Optional[UnstructuredElementExtractionModel] = None, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", - ocr_mode: str = OCRMode.FULL_PAGE.value, fixed_layout: Optional[List[TextRegion]] = None, extract_tables: bool = False, **kwargs, @@ -171,9 +153,6 @@ def from_image_file( detection_model=detection_model, element_extraction_model=element_extraction_model, layout=None, - ocr_strategy=ocr_strategy, - ocr_languages=ocr_languages, - ocr_mode=ocr_mode, fixed_layout=fixed_layout, extract_tables=extract_tables, **kwargs, @@ -195,12 +174,8 @@ def __init__( document_filename: Optional[Union[str, PurePath]] = None, detection_model: Optional[UnstructuredObjectDetectionModel] = None, element_extraction_model: Optional[UnstructuredElementExtractionModel] = None, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", - ocr_mode: str = OCRMode.FULL_PAGE.value, extract_tables: bool = False, analysis: bool = False, - supplement_with_ocr_elements: bool = True, ): if detection_model is not None and element_extraction_model is not None: raise ValueError("Only one of detection_model and extraction_model should be passed.") @@ -216,16 +191,9 @@ def __init__( self.detection_model = detection_model self.element_extraction_model = element_extraction_model self.elements: Collection[Union[LayoutElement, LocationlessLayoutElement]] = [] - if ocr_strategy not in VALID_OCR_STRATEGIES: - raise ValueError(f"ocr_strategy must be one of {VALID_OCR_STRATEGIES}.") - self.ocr_strategy = ocr_strategy - self.ocr_languages = ocr_languages - self.ocr_mode = ocr_mode self.extract_tables = extract_tables self.analysis = analysis self.inferred_layout: Optional[List[LayoutElement]] = None - self.ocr_layout: Optional[List[TextRegion]] = None - self.supplement_with_ocr_elements = supplement_with_ocr_elements def __str__(self) -> str: return "\n\n".join([str(element) for element in self.elements]) @@ -264,37 +232,6 @@ def get_elements_with_detection_model( inferred_layout = UnstructuredObjectDetectionModel.deduplicate_detected_elements( inferred_layout, ) - if self.ocr_mode == OCRMode.INDIVIDUAL_BLOCKS.value: - ocr_layout = None - elif self.ocr_mode == OCRMode.FULL_PAGE.value: - ocr_layout = None - entrie_page_ocr = os.getenv("ENTIRE_PAGE_OCR", "tesseract").lower() - if entrie_page_ocr not in ["paddle", "tesseract"]: - raise ValueError( - "Environment variable ENTIRE_PAGE_OCR must be set to 'tesseract' or 'paddle'.", - ) - - if entrie_page_ocr == "paddle": - logger.info("Processing entire page OCR with paddle...") - from unstructured_inference.models import paddle_ocr - - # TODO(yuming): pass ocr language to paddle when we have language mapping for paddle - ocr_data = 
paddle_ocr.load_agent().ocr( - np.array(self.image), - cls=True, - ) - ocr_layout = parse_ocr_data_paddle(ocr_data) - else: - logger.info("Processing entrie page OCR with tesseract...") - try: - ocr_data = pytesseract.image_to_data( - self.image, - lang=self.ocr_languages, - output_type=Output.DICT, - ) - ocr_layout = parse_ocr_data_tesseract(ocr_data) - except pytesseract.pytesseract.TesseractError: - logger.warning("TesseractError: Skipping page", exc_info=True) if self.layout is not None: threshold_kwargs = {} @@ -309,25 +246,9 @@ def get_elements_with_detection_model( inferred_layout=inferred_layout, extracted_layout=self.layout, page_image_size=self.image.size, - ocr_layout=ocr_layout, - supplement_with_ocr_elements=self.supplement_with_ocr_elements, - **threshold_kwargs, - ) - elif ocr_layout is not None: - threshold_kwargs = {} - # NOTE(Benjamin): With this the thresholds are only changed for detextron2_mask_rcnn - # In other case the default values for the functions are used - if ( - isinstance(self.detection_model, UnstructuredDetectronONNXModel) - and "R_50" not in self.detection_model.model_path - ): - threshold_kwargs = {"subregion_threshold": 0.3} - merged_layout = merge_inferred_layout_with_ocr_layout( - inferred_layout=inferred_layout, - ocr_layout=ocr_layout, - supplement_with_ocr_elements=self.supplement_with_ocr_elements, **threshold_kwargs, ) + else: merged_layout = inferred_layout @@ -335,7 +256,6 @@ if self.analysis: self.inferred_layout = inferred_layout - self.ocr_layout = ocr_layout if inplace: self.elements = elements @@ -345,15 +265,13 @@ def get_elements_from_layout(self, layout: List[TextRegion]) -> List[LayoutElement]: - """Uses the given Layout to separate the page text into elements, either extracting the - text from the discovered layout blocks or from the image using OCR.""" + """Uses the given Layout to separate the page text into elements, extracting the + text from the discovered layout blocks.""" layout = order_layout(layout) elements = [ get_element_from_block( block=e, image=self.image, pdf_objects=self.layout, - ocr_strategy=self.ocr_strategy, - ocr_languages=self.ocr_languages, extract_tables=self.extract_tables, ) for e in layout @@ -481,12 +399,8 @@ def from_image( detection_model: Optional[UnstructuredObjectDetectionModel] = None, element_extraction_model: Optional[UnstructuredElementExtractionModel] = None, layout: Optional[List[TextRegion]] = None, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", - ocr_mode: str = OCRMode.FULL_PAGE.value, extract_tables: bool = False, fixed_layout: Optional[List[TextRegion]] = None, - supplement_with_ocr_elements: bool = True, extract_images_in_pdf: bool = False, image_output_dir_path: Optional[str] = None, analysis: bool = False, @@ -499,12 +413,8 @@ layout=layout, detection_model=detection_model, element_extraction_model=element_extraction_model, - ocr_strategy=ocr_strategy, - ocr_languages=ocr_languages, - ocr_mode=ocr_mode, extract_tables=extract_tables, analysis=analysis, - supplement_with_ocr_elements=supplement_with_ocr_elements, ) if page.element_extraction_model is not None: page.get_elements_using_image_extraction() @@ -535,12 +445,9 @@ def process_data_with_model( data: BinaryIO, model_name: Optional[str], is_image: bool = False, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", - ocr_mode: str = OCRMode.FULL_PAGE.value, fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None, extract_tables: bool = False, - pdf_image_dpi: Optional[int] = None, + pdf_image_dpi: 
int = 200, **kwargs, ) -> DocumentLayout: """Processes pdf file in the form of a file handler (supporting a read method) into a @@ -552,9 +459,6 @@ def process_data_with_model( tmp_file.name, model_name, is_image=is_image, - ocr_strategy=ocr_strategy, - ocr_languages=ocr_languages, - ocr_mode=ocr_mode, fixed_layouts=fixed_layouts, extract_tables=extract_tables, pdf_image_dpi=pdf_image_dpi, @@ -568,25 +472,14 @@ def process_file_with_model( filename: str, model_name: Optional[str], is_image: bool = False, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", - ocr_mode: str = OCRMode.FULL_PAGE.value, fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None, extract_tables: bool = False, - pdf_image_dpi: Optional[int] = None, + pdf_image_dpi: int = 200, **kwargs, ) -> DocumentLayout: """Processes pdf file with name filename into a DocumentLayout by using a model identified by model_name.""" - if pdf_image_dpi is None: - pdf_image_dpi = 300 if model_name == "chipper" else 200 - if (pdf_image_dpi < 300) and (model_name == "chipper"): - logger.warning( - "The Chipper model performs better when images are rendered with DPI >= 300 " - f"(currently {pdf_image_dpi}).", - ) - model = get_model(model_name) if isinstance(model, UnstructuredObjectDetectionModel): detection_model = model @@ -601,9 +494,6 @@ def process_file_with_model( filename, detection_model=detection_model, element_extraction_model=element_extraction_model, - ocr_strategy=ocr_strategy, - ocr_languages=ocr_languages, - ocr_mode=ocr_mode, extract_tables=extract_tables, **kwargs, ) @@ -612,9 +502,6 @@ def process_file_with_model( filename, detection_model=detection_model, element_extraction_model=element_extraction_model, - ocr_strategy=ocr_strategy, - ocr_languages=ocr_languages, - ocr_mode=ocr_mode, fixed_layouts=fixed_layouts, extract_tables=extract_tables, pdf_image_dpi=pdf_image_dpi, @@ -628,8 +515,6 @@ def get_element_from_block( block: TextRegion, image: Optional[Image.Image] = None, pdf_objects: Optional[List[TextRegion]] = None, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", extract_tables: bool = False, ) -> LayoutElement: """Creates a LayoutElement from a given layout or image by finding all the text that lies within @@ -639,8 +524,6 @@ def get_element_from_block( objects=pdf_objects, image=image, extract_tables=extract_tables, - ocr_strategy=ocr_strategy, - ocr_languages=ocr_languages, ) return element @@ -707,80 +590,3 @@ def load_pdf( ) return layouts, images - - -def parse_ocr_data_tesseract(ocr_data: dict) -> List[TextRegion]: - """ - Parse the OCR result data to extract a list of TextRegion objects from - tesseract. - - The function processes the OCR result dictionary, looking for bounding - box information and associated text to create instances of the TextRegion - class, which are then appended to a list. - - Parameters: - - ocr_data (dict): A dictionary containing the OCR result data, expected - to have keys like "level", "left", "top", "width", - "height", and "text". - - Returns: - - List[TextRegion]: A list of TextRegion objects, each representing a - detected text region within the OCR-ed image. - - Note: - - An empty string or a None value for the 'text' key in the input - dictionary will result in its associated bounding box being ignored. 
- """ - - levels = ocr_data["level"] - text_regions = [] - for i, level in enumerate(levels): - (l, t, w, h) = ( - ocr_data["left"][i], - ocr_data["top"][i], - ocr_data["width"][i], - ocr_data["height"][i], - ) - (x1, y1, x2, y2) = l, t, l + w, t + h - text = ocr_data["text"][i] - if text: - text_region = TextRegion(x1, y1, x2, y2, text=text, source=Source.OCR_TESSERACT) - text_regions.append(text_region) - - return text_regions - - -def parse_ocr_data_paddle(ocr_data: list) -> List[TextRegion]: - """ - Parse the OCR result data to extract a list of TextRegion objects from - paddle. - - The function processes the OCR result dictionary, looking for bounding - box information and associated text to create instances of the TextRegion - class, which are then appended to a list. - - Parameters: - - ocr_data (list): A list containing the OCR result data - - Returns: - - List[TextRegion]: A list of TextRegion objects, each representing a - detected text region within the OCR-ed image. - - Note: - - An empty string or a None value for the 'text' key in the input - dictionary will result in its associated bounding box being ignored. - """ - text_regions = [] - for idx in range(len(ocr_data)): - res = ocr_data[idx] - for line in res: - x1 = min([i[0] for i in line[0]]) - y1 = min([i[1] for i in line[0]]) - x2 = max([i[0] for i in line[0]]) - y2 = max([i[1] for i in line[0]]) - text = line[1][0] - if text: - text_region = TextRegion(x1, y1, x2, y2, text, source=Source.OCR_PADDLE) - text_regions.append(text_region) - - return text_regions diff --git a/unstructured_inference/inference/layoutelement.py b/unstructured_inference/inference/layoutelement.py index f3f3343f..b1fbef3e 100644 --- a/unstructured_inference/inference/layoutelement.py +++ b/unstructured_inference/inference/layoutelement.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Collection, List, Optional, Union, cast +from typing import Collection, List, Optional, Union import numpy as np from layoutparser.elements.layout import TextBlock @@ -11,7 +11,6 @@ from unstructured_inference.config import inference_config from unstructured_inference.constants import ( FULL_PAGE_REGION_THRESHOLD, - SUBREGION_THRESHOLD_FOR_OCR, Source, ) from unstructured_inference.inference.elements import ( @@ -19,7 +18,6 @@ Rectangle, TextRegion, grow_region_to_match_region, - partition_groups_from_regions, region_bounding_boxes_are_almost_the_same, ) @@ -35,16 +33,11 @@ def extract_text( objects: Optional[Collection[TextRegion]], image: Optional[Image.Image] = None, extract_tables: bool = False, - ocr_strategy: str = "auto", - ocr_languages: str = "eng", ): """Extracts text contained in region""" text = super().extract_text( objects=objects, - image=image, extract_tables=extract_tables, - ocr_strategy=ocr_strategy, - ocr_languages=ocr_languages, ) if extract_tables and self.type == "Table": self.text_as_html = interpret_table_block(self, image) @@ -97,8 +90,6 @@ def merge_inferred_layout_with_extracted_layout( inferred_layout: Collection[LayoutElement], extracted_layout: Collection[TextRegion], page_image_size: tuple, - ocr_layout: Optional[List[TextRegion]] = None, - supplement_with_ocr_elements: bool = True, same_region_threshold: float = inference_config.LAYOUT_SAME_REGION_THRESHOLD, subregion_threshold: float = inference_config.LAYOUT_SUBREGION_THRESHOLD, ) -> List[LayoutElement]: @@ -185,177 +176,12 @@ def merge_inferred_layout_with_extracted_layout( inferred_regions_to_add = [ region for region in 
inferred_layout if region not in inferred_regions_to_remove ] - inferred_regions_to_add_without_text = [ - region for region in inferred_regions_to_add if not region.text - ] - if ocr_layout is not None: - for inferred_region in inferred_regions_to_add_without_text: - inferred_region.text = aggregate_ocr_text_by_block( - ocr_layout, - inferred_region, - SUBREGION_THRESHOLD_FOR_OCR, - ) - out_layout = categorized_extracted_elements_to_add + inferred_regions_to_add - final_layout = ( - supplement_layout_with_ocr_elements(out_layout, ocr_layout) - if supplement_with_ocr_elements - else out_layout - ) - else: - final_layout = categorized_extracted_elements_to_add + inferred_regions_to_add - - return final_layout - - -def merge_inferred_layout_with_ocr_layout( - inferred_layout: List[LayoutElement], - ocr_layout: List[TextRegion], - supplement_with_ocr_elements: bool = True, -) -> List[LayoutElement]: - """ - Merge the inferred layout with the OCR-detected text regions. - - This function iterates over each inferred layout element and aggregates the - associated text from the OCR layout using the specified threshold. The inferred - layout's text attribute is then updated with this aggregated text. - """ - - for inferred_region in inferred_layout: - inferred_region.text = aggregate_ocr_text_by_block( - ocr_layout, - inferred_region, - SUBREGION_THRESHOLD_FOR_OCR, - ) - final_layout = ( - supplement_layout_with_ocr_elements(inferred_layout, ocr_layout) - if supplement_with_ocr_elements - else inferred_layout - ) + final_layout = categorized_extracted_elements_to_add + inferred_regions_to_add return final_layout -def aggregate_ocr_text_by_block( - ocr_layout: List[TextRegion], - region: TextRegion, - subregion_threshold: float, -) -> Optional[str]: - """Extracts the text aggregated from the regions of the ocr layout that lie within the given - block.""" - - extracted_texts = [] - - for ocr_region in ocr_layout: - ocr_region_is_subregion_of_given_region = ocr_region.is_almost_subregion_of( - region, - subregion_threshold=subregion_threshold, - ) - if ocr_region_is_subregion_of_given_region and ocr_region.text: - extracted_texts.append(ocr_region.text) - - return " ".join(extracted_texts) if extracted_texts else None - - -def supplement_layout_with_ocr_elements( - layout: List[LayoutElement], - ocr_layout: List[TextRegion], -) -> List[LayoutElement]: - """ - Supplement the existing layout with additional OCR-derived elements. - - This function takes two lists: one list of pre-existing layout elements (`layout`) - and another list of OCR-detected text regions (`ocr_layout`). It identifies OCR regions - that are subregions of the elements in the existing layout and removes them from the - OCR-derived list. Then, it appends the remaining OCR-derived regions to the existing layout. - - Parameters: - - layout (List[LayoutElement]): A list of existing layout elements, each of which is - an instance of `LayoutElement`. - - ocr_layout (List[TextRegion]): A list of OCR-derived text regions, each of which is - an instance of `TextRegion`. - - Returns: - - List[LayoutElement]: The final combined layout consisting of both the original layout - elements and the new OCR-derived elements. - - Note: - - The function relies on `is_almost_subregion_of()` method to determine if an OCR region - is a subregion of an existing layout element. - - It also relies on `get_elements_from_ocr_regions()` to convert OCR regions to layout elements. 
- - The `SUBREGION_THRESHOLD_FOR_OCR` constant is used to specify the subregion matching - threshold. - """ - - ocr_regions_to_remove = [] - for ocr_region in ocr_layout: - for el in layout: - ocr_region_is_subregion_of_out_el = ocr_region.is_almost_subregion_of( - cast(Rectangle, el), - SUBREGION_THRESHOLD_FOR_OCR, - ) - if ocr_region_is_subregion_of_out_el: - ocr_regions_to_remove.append(ocr_region) - break - - ocr_regions_to_add = [region for region in ocr_layout if region not in ocr_regions_to_remove] - if ocr_regions_to_add: - ocr_elements_to_add = get_elements_from_ocr_regions(ocr_regions_to_add) - final_layout = layout + ocr_elements_to_add - else: - final_layout = layout - - return final_layout - - -def merge_text_regions(regions: List[TextRegion]) -> TextRegion: - """ - Merge a list of TextRegion objects into a single TextRegion. - - Parameters: - - group (List[TextRegion]): A list of TextRegion objects to be merged. - - Returns: - - TextRegion: A single merged TextRegion object. - """ - - min_x1 = min([tr.x1 for tr in regions]) - min_y1 = min([tr.y1 for tr in regions]) - max_x2 = max([tr.x2 for tr in regions]) - max_y2 = max([tr.y2 for tr in regions]) - - merged_text = " ".join([tr.text for tr in regions if tr.text]) - sources = [*{tr.source for tr in regions}] - source = sources.pop() if len(sources) == 1 else Source.MERGED - element = TextRegion(min_x1, min_y1, max_x2, max_y2, source=source, text=merged_text) - setattr(element, "merged_sources", sources) - return element - - -def get_elements_from_ocr_regions(ocr_regions: List[TextRegion]) -> List[LayoutElement]: - """ - Get layout elements from OCR regions - """ - - grouped_regions = cast( - List[List[TextRegion]], - partition_groups_from_regions(ocr_regions), - ) - merged_regions = [merge_text_regions(group) for group in grouped_regions] - return [ - LayoutElement( - r.x1, - r.y1, - r.x2, - r.y2, - text=r.text, - source=r.source, - type="UncategorizedText", - ) - for r in merged_regions - ] - - def separate(region_a: Union[LayoutElement, Rectangle], region_b: Union[LayoutElement, Rectangle]): """Reduce leftmost rectangle to don't overlap with the other""" diff --git a/unstructured_inference/models/paddle_ocr.py b/unstructured_inference/models/paddle_ocr.py index b4d6d38c..03d2d5cd 100644 --- a/unstructured_inference/models/paddle_ocr.py +++ b/unstructured_inference/models/paddle_ocr.py @@ -1,3 +1,4 @@ +"""This OCR module is used in table models only and will be removed after table OCR refactoring""" import functools import paddle diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index 4a68e3d2..6b29fe78 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -16,11 +16,11 @@ from transformers import DetrImageProcessor, TableTransformerForObjectDetection from unstructured_inference.config import inference_config -from unstructured_inference.logger import logger -from unstructured_inference.models.table_postprocess import Rect -from unstructured_inference.models.tesseract import ( +from unstructured_inference.constants import ( TESSERACT_TEXT_HEIGHT, ) +from unstructured_inference.logger import logger +from unstructured_inference.models.table_postprocess import Rect from unstructured_inference.models.unstructuredmodel import UnstructuredModel from unstructured_inference.utils import pad_image_with_background_color diff --git a/unstructured_inference/models/tesseract.py b/unstructured_inference/models/tesseract.py deleted file mode 100644 
index e6f599cc..00000000 --- a/unstructured_inference/models/tesseract.py +++ /dev/null @@ -1,42 +0,0 @@ -import os -from typing import Dict - -import pytesseract -from layoutparser.ocr.tesseract_agent import TesseractAgent, is_pytesseract_available - -from unstructured_inference.logger import logger - -ocr_agents: Dict[str, TesseractAgent] = {} - -TesseractError = pytesseract.pytesseract.TesseractError - -# Force tesseract to be single threaded, -# otherwise we see major performance problems -if "OMP_THREAD_LIMIT" not in os.environ: - os.environ["OMP_THREAD_LIMIT"] = "1" - - -# this field is defined by pytesseract/unstructured.pytesseract -TESSERACT_TEXT_HEIGHT = "height" - - -def load_agent(languages: str = "eng"): - """Loads the Tesseract OCR agent as a global variable to ensure that we only load it once. - - Parameters - ---------- - languages - The languages to use for the Tesseract agent. To use a langauge, you'll first need - to isntall the appropriate Tesseract language pack. - """ - global ocr_agents - - if not is_pytesseract_available(): - raise ImportError( - "Failed to load Tesseract. Ensure that Tesseract is installed. Example command: \n" - " >>> sudo apt install -y tesseract-ocr", - ) - - if languages not in ocr_agents: - logger.info(f"Loading the Tesseract OCR agent for {languages} ...") - ocr_agents[languages] = TesseractAgent(languages=languages)
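
Reviewer note: as a quick sanity check of the surface area after this patch, here is a minimal usage sketch of the layout entry point. The OCR keyword arguments (`ocr_strategy`, `ocr_languages`, `ocr_mode`, `supplement_with_ocr_elements`) no longer exist in this API; full-page OCR moves downstream to the unstructured repo, and only the table OCR path remains here. The file path and model name below are illustrative values taken from this repo's tests, not part of the patch itself.

```python
from unstructured_inference.inference.layout import process_file_with_model

# Layout detection only: no OCR kwargs remain in this signature.
# Table OCR is unchanged and still runs when extract_tables=True.
doc = process_file_with_model(
    filename="sample-docs/loremipsum.pdf",  # sample file shipped with this repo
    model_name="yolox_tiny",
    extract_tables=False,
    pdf_image_dpi=200,  # now a plain int default; the chipper-specific DPI logic is removed
)

for page in doc.pages:
    print(f"page {page.number}: {len(page.elements)} elements")
```

Any element whose text cannot be recovered from the extracted PDF objects now comes back empty instead of triggering OCR, which is the behavior the companion unstructured PR compensates for.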