From abb0174181d31343cbc47dcef2bdaebb54c2af9e Mon Sep 17 00:00:00 2001 From: Dimitri Lozeve Date: Tue, 23 Apr 2024 23:11:39 +0200 Subject: [PATCH] Integration with the Google Cloud Vision API (#2902) This PR adds a third OCR provider, alongside Tesseract and Paddle: the [Google Cloud Vision API](https://cloud.google.com/vision). It can be used similarly to other OCR methods: set the `OCR_AGENT` environment variable to the path to the OCR module (`unstructured.partition.utils.ocr_models.google_vision_ocr.OCRAgentGoogleVision`). You also need to set the credentials to use Google APIs, for instance by setting the `GOOGLE_APPLICATION_CREDENTIALS` environment variable. --------- Co-authored-by: christinestraub --- CHANGELOG.md | 4 +- requirements/base.txt | 2 +- requirements/dev.txt | 2 +- requirements/extra-paddleocr.txt | 2 +- requirements/extra-pdf-image.in | 1 + requirements/extra-pdf-image.txt | 38 +++++++ requirements/ingest/astra.txt | 4 +- requirements/ingest/azure.txt | 4 +- requirements/ingest/box.txt | 4 +- requirements/ingest/chroma.txt | 4 +- requirements/ingest/delta-table.txt | 2 +- requirements/ingest/embed-aws-bedrock.txt | 2 + requirements/ingest/embed-huggingface.txt | 2 + requirements/ingest/embed-openai.txt | 2 + requirements/ingest/embed-vertexai.txt | 3 +- requirements/ingest/github.txt | 4 +- requirements/ingest/onedrive.txt | 4 +- requirements/ingest/outlook.txt | 4 +- requirements/ingest/qdrant.txt | 6 +- requirements/ingest/sharepoint.txt | 4 +- .../partition/pdf_image/test_ocr.py | 86 +++++++++++++++ unstructured/__version__.py | 2 +- unstructured/partition/utils/constants.py | 7 +- .../utils/ocr_models/google_vision_ocr.py | 104 ++++++++++++++++++ 24 files changed, 261 insertions(+), 36 deletions(-) create mode 100644 unstructured/partition/utils/ocr_models/google_vision_ocr.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c484e9eca1..0254d09d04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,11 @@ -## 0.13.4-dev0 +## 0.13.4-dev1 ### Enhancements ### Features +* **Add integration with the Google Cloud Vision API**. Adds a third OCR provider, alongside Tesseract and Paddle: the Google Cloud Vision API. + ### Fixes * **Remove ElementMetadata.section field.**. This field was unused, not populated by any partitioners. diff --git a/requirements/base.txt b/requirements/base.txt index 87ae0d05fa..c47a33314b 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -25,7 +25,7 @@ dataclasses-json==0.6.4 # via -r ./base.in dataclasses-json-speakeasy==0.5.11 # via unstructured-client -emoji==2.11.0 +emoji==2.11.1 # via -r ./base.in filetype==1.2.0 # via -r ./base.in diff --git a/requirements/dev.txt b/requirements/dev.txt index 477172e78f..54f006385e 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -185,7 +185,7 @@ jupyterlab==4.1.6 # via notebook jupyterlab-pygments==0.3.0 # via nbconvert -jupyterlab-server==2.26.0 +jupyterlab-server==2.27.0 # via # jupyterlab # notebook diff --git a/requirements/extra-paddleocr.txt b/requirements/extra-paddleocr.txt index 114316e73f..3557635dd7 100644 --- a/requirements/extra-paddleocr.txt +++ b/requirements/extra-paddleocr.txt @@ -53,7 +53,7 @@ idna==3.7 # via # -c ./base.txt # requests -imageio==2.34.0 +imageio==2.34.1 # via # imgaug # scikit-image diff --git a/requirements/extra-pdf-image.in b/requirements/extra-pdf-image.in index f6e003d1a3..b4e3f3b8ec 100644 --- a/requirements/extra-pdf-image.in +++ b/requirements/extra-pdf-image.in @@ -13,3 +13,4 @@ unstructured-inference==0.7.27 # unstructured fork of pytesseract that provides an interface to allow for multiple output formats # from one tesseract call unstructured.pytesseract>=0.3.12 +google-cloud-vision diff --git a/requirements/extra-pdf-image.txt b/requirements/extra-pdf-image.txt index 2d902f0194..2acc11122a 100644 --- a/requirements/extra-pdf-image.txt +++ b/requirements/extra-pdf-image.txt @@ -6,6 +6,8 @@ # antlr4-python3-runtime==4.9.3 # via omegaconf +cachetools==5.3.3 + # via google-auth certifi==2024.2.2 # via # -c ././deps/constraints.txt @@ -43,6 +45,24 @@ fsspec==2024.3.1 # via # huggingface-hub # torch +google-api-core[grpc]==2.18.0 + # via google-cloud-vision +google-auth==2.29.0 + # via + # google-api-core + # google-cloud-vision +google-cloud-vision==3.7.2 + # via -r ./extra-pdf-image.in +googleapis-common-protos==1.63.0 + # via + # google-api-core + # grpcio-status +grpcio==1.62.2 + # via + # google-api-core + # grpcio-status +grpcio-status==1.62.2 + # via google-api-core huggingface-hub==0.22.2 # via # timm @@ -147,11 +167,26 @@ pillow-heif==0.16.0 # via -r ./extra-pdf-image.in portalocker==2.8.2 # via iopath +proto-plus==1.23.0 + # via + # google-api-core + # google-cloud-vision protobuf==4.23.4 # via # -c ././deps/constraints.txt + # google-api-core + # google-cloud-vision + # googleapis-common-protos + # grpcio-status # onnx # onnxruntime + # proto-plus +pyasn1==0.6.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.0 + # via google-auth pycocotools==2.0.7 # via # -c ././deps/constraints.txt @@ -195,8 +230,11 @@ regex==2024.4.16 requests==2.31.0 # via # -c ./base.txt + # google-api-core # huggingface-hub # transformers +rsa==4.9 + # via google-auth safetensors==0.4.3 # via # timm diff --git a/requirements/ingest/astra.txt b/requirements/ingest/astra.txt index 0e8c50605b..8e7f91969c 100644 --- a/requirements/ingest/astra.txt +++ b/requirements/ingest/astra.txt @@ -46,9 +46,7 @@ hpack==4.0.0 httpcore==1.0.5 # via httpx httpx[http2]==0.27.0 - # via - # astrapy - # httpx + # via astrapy hyperframe==6.0.1 # via h2 idna==3.7 diff --git a/requirements/ingest/azure.txt b/requirements/ingest/azure.txt index 2c48b6950e..e60c4536c1 100644 --- a/requirements/ingest/azure.txt +++ b/requirements/ingest/azure.txt @@ -80,9 +80,7 @@ portalocker==2.8.2 pycparser==2.22 # via cffi pyjwt[crypto]==2.8.0 - # via - # msal - # pyjwt + # via msal requests==2.31.0 # via # -c ./ingest/../base.txt diff --git a/requirements/ingest/box.txt b/requirements/ingest/box.txt index 2f3c8980ab..80244e0885 100644 --- a/requirements/ingest/box.txt +++ b/requirements/ingest/box.txt @@ -9,9 +9,7 @@ attrs==23.2.0 boxfs==0.3.0 # via -r ./ingest/box.in boxsdk[jwt]==3.9.2 - # via - # boxfs - # boxsdk + # via boxfs certifi==2024.2.2 # via # -c ./ingest/../base.txt diff --git a/requirements/ingest/chroma.txt b/requirements/ingest/chroma.txt index d4acacdc3a..8cef371aff 100644 --- a/requirements/ingest/chroma.txt +++ b/requirements/ingest/chroma.txt @@ -214,9 +214,7 @@ urllib3==1.26.18 # kubernetes # requests uvicorn[standard]==0.29.0 - # via - # chromadb - # uvicorn + # via chromadb uvloop==0.19.0 # via uvicorn watchfiles==0.21.0 diff --git a/requirements/ingest/delta-table.txt b/requirements/ingest/delta-table.txt index 1053728f11..31d05c7105 100644 --- a/requirements/ingest/delta-table.txt +++ b/requirements/ingest/delta-table.txt @@ -4,7 +4,7 @@ # # pip-compile ./ingest/delta-table.in # -deltalake==0.16.4 +deltalake==0.17.0 # via -r ./ingest/delta-table.in fsspec==2024.3.1 # via -r ./ingest/delta-table.in diff --git a/requirements/ingest/embed-aws-bedrock.txt b/requirements/ingest/embed-aws-bedrock.txt index 31052577e9..53ee3b9b0d 100644 --- a/requirements/ingest/embed-aws-bedrock.txt +++ b/requirements/ingest/embed-aws-bedrock.txt @@ -38,6 +38,8 @@ frozenlist==1.4.1 # via # aiohttp # aiosignal +greenlet==3.0.3 + # via sqlalchemy idna==3.7 # via # -c ./ingest/../base.txt diff --git a/requirements/ingest/embed-huggingface.txt b/requirements/ingest/embed-huggingface.txt index ce967ba576..4360feb567 100644 --- a/requirements/ingest/embed-huggingface.txt +++ b/requirements/ingest/embed-huggingface.txt @@ -40,6 +40,8 @@ fsspec==2024.3.1 # via # huggingface-hub # torch +greenlet==3.0.3 + # via sqlalchemy huggingface==0.0.1 # via -r ./ingest/embed-huggingface.in huggingface-hub==0.22.2 diff --git a/requirements/ingest/embed-openai.txt b/requirements/ingest/embed-openai.txt index dae330e673..91a79c4411 100644 --- a/requirements/ingest/embed-openai.txt +++ b/requirements/ingest/embed-openai.txt @@ -42,6 +42,8 @@ frozenlist==1.4.1 # via # aiohttp # aiosignal +greenlet==3.0.3 + # via sqlalchemy h11==0.14.0 # via httpcore httpcore==1.0.5 diff --git a/requirements/ingest/embed-vertexai.txt b/requirements/ingest/embed-vertexai.txt index 39aad94d13..8ed678c259 100644 --- a/requirements/ingest/embed-vertexai.txt +++ b/requirements/ingest/embed-vertexai.txt @@ -42,7 +42,6 @@ frozenlist==1.4.1 # aiosignal google-api-core[grpc]==2.18.0 # via - # google-api-core # google-cloud-aiplatform # google-cloud-bigquery # google-cloud-core @@ -83,6 +82,8 @@ googleapis-common-protos[grpc]==1.63.0 # google-api-core # grpc-google-iam-v1 # grpcio-status +greenlet==3.0.3 + # via sqlalchemy grpc-google-iam-v1==0.13.0 # via google-cloud-resource-manager grpcio==1.62.2 diff --git a/requirements/ingest/github.txt b/requirements/ingest/github.txt index 18e29fc3c4..ed8ec5fdb1 100644 --- a/requirements/ingest/github.txt +++ b/requirements/ingest/github.txt @@ -30,9 +30,7 @@ pycparser==2.22 pygithub==2.3.0 # via -r ./ingest/github.in pyjwt[crypto]==2.8.0 - # via - # pygithub - # pyjwt + # via pygithub pynacl==1.5.0 # via pygithub requests==2.31.0 diff --git a/requirements/ingest/onedrive.txt b/requirements/ingest/onedrive.txt index 8922ec418c..ced2374b7d 100644 --- a/requirements/ingest/onedrive.txt +++ b/requirements/ingest/onedrive.txt @@ -40,9 +40,7 @@ office365-rest-python-client==2.4.2 pycparser==2.22 # via cffi pyjwt[crypto]==2.8.0 - # via - # msal - # pyjwt + # via msal pytz==2024.1 # via office365-rest-python-client requests==2.31.0 diff --git a/requirements/ingest/outlook.txt b/requirements/ingest/outlook.txt index 2129b31be5..9a6ecbe3e4 100644 --- a/requirements/ingest/outlook.txt +++ b/requirements/ingest/outlook.txt @@ -34,9 +34,7 @@ office365-rest-python-client==2.4.2 pycparser==2.22 # via cffi pyjwt[crypto]==2.8.0 - # via - # msal - # pyjwt + # via msal pytz==2024.1 # via office365-rest-python-client requests==2.31.0 diff --git a/requirements/ingest/qdrant.txt b/requirements/ingest/qdrant.txt index 41b0c25d26..da0967d862 100644 --- a/requirements/ingest/qdrant.txt +++ b/requirements/ingest/qdrant.txt @@ -33,9 +33,7 @@ hpack==4.0.0 httpcore==1.0.5 # via httpx httpx[http2]==0.27.0 - # via - # httpx - # qdrant-client + # via qdrant-client hyperframe==6.0.1 # via h2 idna==3.7 @@ -57,7 +55,7 @@ pydantic==2.7.0 # via qdrant-client pydantic-core==2.18.1 # via pydantic -qdrant-client==1.8.2 +qdrant-client==1.9.0 # via -r ./ingest/qdrant.in sniffio==1.3.1 # via diff --git a/requirements/ingest/sharepoint.txt b/requirements/ingest/sharepoint.txt index 9167a159ed..4eb8e6b15e 100644 --- a/requirements/ingest/sharepoint.txt +++ b/requirements/ingest/sharepoint.txt @@ -34,9 +34,7 @@ office365-rest-python-client==2.4.2 pycparser==2.22 # via cffi pyjwt[crypto]==2.8.0 - # via - # msal - # pyjwt + # via msal pytz==2024.1 # via office365-rest-python-client requests==2.31.0 diff --git a/test_unstructured/partition/pdf_image/test_ocr.py b/test_unstructured/partition/pdf_image/test_ocr.py index 1c8ec23fa9..175682156a 100644 --- a/test_unstructured/partition/pdf_image/test_ocr.py +++ b/test_unstructured/partition/pdf_image/test_ocr.py @@ -1,3 +1,4 @@ +from collections import namedtuple from unittest.mock import patch import numpy as np @@ -19,6 +20,7 @@ from unstructured.partition.utils.constants import ( Source, ) +from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle from unstructured.partition.utils.ocr_models.tesseract_ocr import ( OCRAgentTesseract, @@ -192,6 +194,90 @@ def test_get_ocr_text_from_image_paddle(monkeypatch): assert ocr_text == "Hello\n\nWorld\n\n!" +@pytest.fixture() +def google_vision_text_annotation(): + from google.cloud.vision import ( + Block, + BoundingPoly, + Page, + Paragraph, + Symbol, + TextAnnotation, + Vertex, + Word, + ) + + breaks = TextAnnotation.DetectedBreak.BreakType + symbols_hello = [Symbol(text=c) for c in "Hello"] + [ + Symbol( + property=TextAnnotation.TextProperty( + detected_break=TextAnnotation.DetectedBreak(type_=breaks.SPACE) + ) + ) + ] + symbols_world = [Symbol(text=c) for c in "World!"] + [ + Symbol( + property=TextAnnotation.TextProperty( + detected_break=TextAnnotation.DetectedBreak(type_=breaks.LINE_BREAK) + ) + ) + ] + words = [Word(symbols=symbols_hello), Word(symbols=symbols_world)] + bounding_box = BoundingPoly( + vertices=[Vertex(x=0, y=0), Vertex(x=0, y=10), Vertex(x=10, y=10), Vertex(x=10, y=0)] + ) + paragraphs = [Paragraph(words=words, bounding_box=bounding_box)] + blocks = [Block(paragraphs=paragraphs)] + pages = [Page(blocks=blocks)] + return TextAnnotation(text="Hello World!", pages=pages) + + +@pytest.fixture() +def google_vision_client(google_vision_text_annotation): + Response = namedtuple("Response", "full_text_annotation") + + class FakeGoogleVisionClient: + def document_text_detection(self, image): + return Response(full_text_annotation=google_vision_text_annotation) + + class OCRAgentFakeGoogleVision(OCRAgentGoogleVision): + def __init__(self): + self.client = FakeGoogleVisionClient() + + return OCRAgentFakeGoogleVision() + + +def test_get_ocr_from_image_google_vision(google_vision_client): + image = Image.new("RGB", (100, 100)) + + ocr_agent = google_vision_client + ocr_text = ocr_agent.get_text_from_image(image, ocr_languages="eng") + + assert ocr_text == "Hello World!" + + +def test_get_layout_from_image_google_vision(google_vision_client): + image = Image.new("RGB", (100, 100)) + + ocr_agent = google_vision_client + regions = ocr_agent.get_layout_from_image(image, ocr_languages="eng") + assert len(regions) == 1 + assert regions[0].text == "Hello World!" + assert regions[0].source == Source.OCR_GOOGLEVISION + assert regions[0].bbox.x1 == 0 + assert regions[0].bbox.y1 == 0 + assert regions[0].bbox.x2 == 10 + assert regions[0].bbox.y2 == 10 + + +def test_get_layout_elements_from_image_google_vision(google_vision_client): + image = Image.new("RGB", (100, 100)) + + ocr_agent = google_vision_client + layout_elements = ocr_agent.get_layout_elements_from_image(image, ocr_languages="eng") + assert len(layout_elements) == 1 + + @pytest.fixture() def mock_ocr_regions(): return [ diff --git a/unstructured/__version__.py b/unstructured/__version__.py index 927d266d2c..8c8a0a7592 100644 --- a/unstructured/__version__.py +++ b/unstructured/__version__.py @@ -1 +1 @@ -__version__ = "0.13.4-dev0" # pragma: no cover +__version__ = "0.13.4-dev1" # pragma: no cover diff --git a/unstructured/partition/utils/constants.py b/unstructured/partition/utils/constants.py index 6645dce079..c1864a9e56 100644 --- a/unstructured/partition/utils/constants.py +++ b/unstructured/partition/utils/constants.py @@ -6,6 +6,7 @@ class Source(Enum): PDFMINER = "pdfminer" OCR_TESSERACT = "ocr_tesseract" OCR_PADDLE = "ocr_paddle" + OCR_GOOGLEVISION = "ocr_googlevision" class OCRMode(Enum): @@ -29,11 +30,15 @@ class PartitionStrategy: OCR_AGENT_TESSERACT = "unstructured.partition.utils.ocr_models.tesseract_ocr.OCRAgentTesseract" OCR_AGENT_PADDLE = "unstructured.partition.utils.ocr_models.paddle_ocr.OCRAgentPaddle" +OCR_AGENT_GOOGLEVISION = ( + "unstructured.partition.utils.ocr_models.google_vision_ocr.OCRAgentGoogleVision" +) OCR_AGENT_MODULES_WHITELIST = os.getenv( "OCR_AGENT_MODULES_WHITELIST", "unstructured.partition.utils.ocr_models.tesseract_ocr," - "unstructured.partition.utils.ocr_models.paddle_ocr", + "unstructured.partition.utils.ocr_models.paddle_ocr," + "unstructured.partition.utils.ocr_models.google_vision_ocr", ).split(",") UNSTRUCTURED_INCLUDE_DEBUG_METADATA = os.getenv("UNSTRUCTURED_INCLUDE_DEBUG_METADATA", False) diff --git a/unstructured/partition/utils/ocr_models/google_vision_ocr.py b/unstructured/partition/utils/ocr_models/google_vision_ocr.py new file mode 100644 index 0000000000..231a10904c --- /dev/null +++ b/unstructured/partition/utils/ocr_models/google_vision_ocr.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from io import BytesIO +from typing import TYPE_CHECKING + +from google.cloud.vision import Image, ImageAnnotatorClient, Paragraph, TextAnnotation + +from unstructured.partition.utils.constants import Source +from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent + +if TYPE_CHECKING: + from PIL import Image as PILImage + from unstructured_inference.inference.elements import TextRegion + from unstructured_inference.inference.layoutelement import LayoutElement + + +class OCRAgentGoogleVision(OCRAgent): + """OCR service implementation for Google Vision API.""" + + def __init__(self) -> None: + self.client = ImageAnnotatorClient() + + def is_text_sorted(self) -> bool: + return True + + def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str: + with BytesIO() as buffer: + image.save(buffer, format="PNG") + response = self.client.document_text_detection(image=Image(content=buffer.getvalue())) + document = response.full_text_annotation + assert isinstance(document, TextAnnotation) + return document.text + + def get_layout_from_image( + self, image: PILImage.Image, ocr_languages: str = "eng" + ) -> list[TextRegion]: + with BytesIO() as buffer: + image.save(buffer, format="PNG") + response = self.client.document_text_detection(image=Image(content=buffer.getvalue())) + document = response.full_text_annotation + assert isinstance(document, TextAnnotation) + regions = self._parse_regions(document) + return regions + + def get_layout_elements_from_image( + self, image: PILImage.Image, ocr_languages: str = "eng" + ) -> list[LayoutElement]: + from unstructured.partition.pdf_image.inference_utils import ( + build_layout_elements_from_ocr_regions, + ) + + ocr_regions = self.get_layout_from_image( + image, + ocr_languages=ocr_languages, + ) + ocr_text = self.get_text_from_image( + image, + ocr_languages=ocr_languages, + ) + layout_elements = build_layout_elements_from_ocr_regions( + ocr_regions=ocr_regions, + ocr_text=ocr_text, + group_by_ocr_text=False, + ) + return layout_elements + + def _parse_regions(self, ocr_data: TextAnnotation) -> list[TextRegion]: + from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords + + text_regions: list[TextRegion] = [] + for page_idx, page in enumerate(ocr_data.pages): + for block in page.blocks: + for paragraph in block.paragraphs: + vertices = paragraph.bounding_box.vertices + x1, y1 = vertices[0].x, vertices[0].y + x2, y2 = vertices[2].x, vertices[2].y + text_region = build_text_region_from_coords( + x1, + y1, + x2, + y2, + text=self._get_text_from_paragraph(paragraph), + source=Source.OCR_GOOGLEVISION, + ) + text_regions.append(text_region) + return text_regions + + def _get_text_from_paragraph(self, paragraph: Paragraph) -> str: + breaks = TextAnnotation.DetectedBreak.BreakType + para = "" + line = "" + for word in paragraph.words: + for symbol in word.symbols: + line += symbol.text + if symbol.property.detected_break.type_ == breaks.SPACE: + line += " " + if symbol.property.detected_break.type_ == breaks.EOL_SURE_SPACE: + line += " " + para += line + line = "" + if symbol.property.detected_break.type_ == breaks.LINE_BREAK: + para += line + line = "" + return para