Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: allow table model to accept optional OCR data #256

Merged
11 commits were merged on Oct 17, 2023
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.7.9

* Allow table model to accept optional OCR tokens

## 0.7.8

* Fix: include onnx as base dependency.
Expand Down
15 changes: 15 additions & 0 deletions test_unstructured_inference/models/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,21 @@ def test_table_prediction_tesseract(table_transformer, example_image):
) in prediction


def test_table_prediction_tesseract_with_ocr_tokens(table_transformer, example_image):
    """Passing one OCR token to predict() should yield a one-cell HTML table with its text."""
    # bounding box should match table structure
    blind_token = {
        "bbox": [70.0, 245.0, 127.0, 266.0],
        "block_num": 0,
        "line_num": 0,
        "span_num": 0,
        "text": "Blind",
    }
    expected_html = "<table><tr><td>Blind</td></tr></table>"

    # Supplying ocr_tokens explicitly, rather than leaving the argument at its default of None.
    html = table_transformer.predict(example_image, ocr_tokens=[blind_token])

    assert html == expected_html


@pytest.mark.skipif(skip_outside_ci, reason="Skipping paddle test run outside of CI")
def test_table_prediction_paddle(monkeypatch, example_image):
monkeypatch.setenv("TABLE_OCR", "paddle")
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.8" # pragma: no cover
__version__ = "0.7.9" # pragma: no cover
32 changes: 26 additions & 6 deletions unstructured_inference/models/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import cv2
import numpy as np
Expand Down Expand Up @@ -33,10 +33,24 @@ class UnstructuredTableTransformerModel(UnstructuredModel):
def __init__(self):
pass

def predict(self, x: Image):
"""Predict table structure deferring to run_prediction"""
def predict(self, x: Image, ocr_tokens: Optional[List[Dict]] = None):
"""Predict table structure deferring to run_prediction with ocr tokens

Note:
`ocr_tokens` is a list of dictionaries representing OCR tokens,
where each dictionary has the following format:
{
"bbox": [int, int, int, int], # Bounding box coordinates of the token
"block_num": int, # Block number
"line_num": int, # Line number
"span_num": int, # Span number
"text": str, # Text content of the token
}
The bounding box coordinates should match the table structure.
FIXME: refactor token data into a dataclass so we have clear expectations of the fields
"""
super().predict(x)
return self.run_prediction(x)
return self.run_prediction(x, ocr_tokens=ocr_tokens)

def initialize(
self,
Expand Down Expand Up @@ -161,12 +175,18 @@ def run_prediction(
self,
x: Image,
pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
ocr_tokens: Optional[List[Dict]] = None,
):
"""Predict table structure"""
outputs_structure = self.get_structure(x, pad_for_structure_detection)
tokens = self.get_tokens(x=x)
if ocr_tokens is None:
logger.warning(
"Table OCR from get_tokens method will be deprecated. "
"In the future the OCR tokens are expected to be passed in.",
)
ocr_tokens = self.get_tokens(x=x)

html = recognize(outputs_structure, x, tokens=tokens, out_html=True)["html"]
html = recognize(outputs_structure, x, tokens=ocr_tokens, out_html=True)["html"]
prediction = html[0] if html else ""
return prediction

Expand Down