Chore (refactor): support table extraction with pre-computed OCR data (#1801)

### Summary

Table OCR refactor: move the OCR step for the table model from the inference repo into the unstructured (unst) repo.
* Before this PR, the table model ran its own OCR, extracting tokens (text plus bounding boxes) and filling them into the table structure inside the inference repo. That meant an extra OCR pass just for tables.
* After this PR, we reuse the OCR data from the full-page OCR and pass those tokens to the inference repo, so only one OCR pass is run for the entire document (a minimal sketch of the reused token data follows below).
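
For context, here is a minimal sketch of the kind of per-page OCR data that gets reused for tables. The field names and helper below are illustrative assumptions for this description, not the exact schema used in either repo; the point is that a single OCR pass yields word-level text plus bounding boxes that the table model can consume directly.

```python
# Illustrative only: word-level tokens from one full-page OCR pass.
ocr_tokens = [
    {"text": "Revenue", "bbox": [110, 42, 188, 60]},  # bbox = [x1, y1, x2, y2] in page pixels
    {"text": "2021", "bbox": [250, 42, 290, 60]},
]


def tokens_in_region(tokens, region):
    """Keep only the tokens whose boxes fall inside a table region (x1, y1, x2, y2)."""
    rx1, ry1, rx2, ry2 = region
    return [
        t
        for t in tokens
        if t["bbox"][0] >= rx1
        and t["bbox"][1] >= ry1
        and t["bbox"][2] <= rx2
        and t["bbox"][3] <= ry2
    ]
```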

**Tech details:**
* Combined the env vars `ENTIRE_PAGE_OCR` and `TABLE_OCR` into a single `OCR_AGENT`; since we only run OCR once, the same OCR agent is used for both the entire page and its tables (see the example after this list).
* Bumped the inference repo to `0.7.9`, which allows the table model in inference to consume pre-computed OCR data from the unst repo. See the companion [PR](Unstructured-IO/unstructured-inference#256).
* All notebook lint changes were produced by `make tidy`.
* This PR also fixes [issue](#1564); a test for it was added in `test_pdf.py::test_partition_pdf_hi_table_extraction_with_languages`.
* Added the same image scaling logic as the [previous table OCR](https://github.com/Unstructured-IO/unstructured-inference/blob/main/unstructured_inference/models/tables.py#L109C1-L113), but the scaling is now applied to the entire page image (an illustrative sketch follows after this list).
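
For example, selecting the shared OCR agent is now a single setting. This is just a sketch; `"tesseract"` and `"paddle"` are the values referenced in this PR's CHANGELOG entry and CI config.

```python
import os

# One OCR agent for both the full page and its tables; this replaces the
# former ENTIRE_PAGE_OCR / TABLE_OCR environment variables.
os.environ["OCR_AGENT"] = "tesseract"  # or "paddle"
```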

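And a rough sketch of the pre-OCR zoom now applied to the whole page image. The helper name, zoom factor, and resampling filter below are assumptions for illustration, not the exact implementation in the repo.

```python
from PIL import Image


def zoom_page_image(image: Image.Image, zoom: float = 3.0) -> Image.Image:
    """Upscale the full page image before OCR so small table text is easier to read.

    Sketch only: the real code may pick the zoom factor and interpolation differently.
    """
    if zoom <= 0:
        zoom = 1.0
    new_size = (int(image.width * zoom), int(image.height * zoom))
    return image.resize(new_size, resample=Image.LANCZOS)
```
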
### Test
* Not much to test manually beyond confirming that table extraction still works (a minimal invocation sketch appears after the screenshots below).
* Because of the scaling change and the reuse of pre-computed OCR data from the entire page, the table output changes slightly (for the better). Here is a comparison of outputs from the same test `test_partition_image_with_table_extraction`:

Screenshot of the table in `layout-parser-paper-with-table.jpg`:
<img width="343" alt="expected"
src="https://github.com/Unstructured-IO/unstructured/assets/63475068/278d7665-d212-433d-9a05-872c4502725c">
before refactor:
<img width="709" alt="before"
src="https://github.com/Unstructured-IO/unstructured/assets/63475068/347fbc3b-f52b-45b5-97e9-6f633eaa0d5e">
after refactor:
<img width="705" alt="after"
src="https://github.com/Unstructured-IO/unstructured/assets/63475068/b3cbd809-cf67-4e75-945a-5cbd06b33b2d">

### TODO
(Tracked as a ticket.) There is still some cleanup to do in the inference repo, since the unst repo now has duplicate logic, but we can keep it as a fallback plan. If we want to remove everything OCR-related from inference, these deprecated items can be removed:
* [`get_tokens`](https://github.com/Unstructured-IO/unstructured-inference/blob/main/unstructured_inference/models/tables.py#L77) (already noted in code)
* parameter `extract_tables` in inference
* [`interpret_table_block`](https://github.com/Unstructured-IO/unstructured-inference/blob/main/unstructured_inference/inference/layoutelement.py#L88)
* [`load_agent`](https://github.com/Unstructured-IO/unstructured-inference/blob/main/unstructured_inference/models/tables.py#L197)
* env `TABLE_OCR` 

### Note
If we want to fall back to an additional table-only OCR pass (we may need this to use Paddle for tables), we need to:
* pass `infer_table_structure` to inference via the `extract_tables` parameter
* stop passing `infer_table_structure` to `ocr.py`

---------

Co-authored-by: Yao You <[email protected]>
yuming-long and badGarnet authored Oct 21, 2023
1 parent 3437a23 commit ce40cdc
Showing 36 changed files with 843 additions and 399 deletions.
2 changes: 1 addition & 1 deletion .coveragerc
@@ -1,5 +1,5 @@
[run]
omit =
unstructured/ingest/*
- # TODO(yuming): please remove this line after adding tests for paddle (CORE-1886)
+ # TODO(yuming): please remove this line after adding tests for paddle
unstructured/partition/utils/ocr_models/paddle_ocr.py
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -305,7 +305,7 @@ jobs:
AZURE_SEARCH_API_KEY: ${{ secrets.AZURE_SEARCH_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
TABLE_OCR: "tesseract"
- ENTIRE_PAGE_OCR: "tesseract"
+ OCR_AGENT: "tesseract"
CI: "true"
run: |
source .venv/bin/activate
2 changes: 1 addition & 1 deletion .github/workflows/ingest-test-fixtures-update-pr.yml
@@ -97,7 +97,7 @@ jobs:
AZURE_SEARCH_API_KEY: ${{ secrets.AZURE_SEARCH_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
TABLE_OCR: "tesseract"
- ENTIRE_PAGE_OCR: "tesseract"
+ OCR_AGENT: "tesseract"
OVERWRITE_FIXTURES: "true"
CI: "true"
run: |
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -1,11 +1,13 @@
- ## 0.10.25-dev8
+ ## 0.10.25-dev9

### Enhancements

* **Duplicate CLI param check** Given that many of the options associated with the `Click` based cli ingest commands are added dynamically from a number of configs, a check was incorporated to make sure there were no duplicate entries to prevent new configs from overwriting already added options.

### Features

* **Table OCR refactor** Support table OCR with pre-computed OCR data to ensure we only do one OCR pass for the entire document. Users can specify
the OCR agent (tesseract/paddle) in the environment variable `OCR_AGENT` for OCRing the entire document.
* **Adds accuracy function** The accuracy scoring was originally an option under `calculate_edit_distance`. For easy function call, it is now a wrapper around the original function that calls edit_distance and return as "score".
* **Adds HuggingFaceEmbeddingEncoder** The HuggingFace Embedding Encoder uses a local embedding model as opposed to using an API.
* **Add AWS bedrock embedding connector** `unstructured.embed.bedrock` now provides a connector to use AWS bedrock's `titan-embed-text` model to generate embeddings for elements. This feature requires a valid AWS bedrock setup and an internet connection to run.
Binary file added example-docs/korean-text-with-tables.pdf
Binary file not shown.
42 changes: 25 additions & 17 deletions examples/argilla-summarization/isw-summarization.ipynb
@@ -43,7 +43,7 @@
"source": [
"from IPython.display import Image\n",
"\n",
"Image(filename=\"img/isw.png\", width=800) "
"Image(filename=\"img/isw.png\", width=800)"
]
},
{
@@ -94,6 +94,7 @@
"source": [
"ISW_BASE_URL = \"https://www.understandingwar.org/backgrounder/russian-offensive-campaign-assessment\"\n",
"\n",
"\n",
"def datetime_to_url(dt):\n",
" month = dt.strftime(\"%B\").lower()\n",
" return f\"{ISW_BASE_URL}-{month}-{dt.day}\""
@@ -134,8 +135,8 @@
" r = requests.get(url)\n",
" if r.status_code != 200:\n",
" return None\n",
" \n",
" elements = partition_html(text=r.text) \n",
"\n",
" elements = partition_html(text=r.text)\n",
" return elements"
]
},
@@ -170,7 +171,7 @@
}
],
"source": [
"Image(filename=\"img/isw-key-takeaways.png\", width=500) "
"Image(filename=\"img/isw-key-takeaways.png\", width=500)"
]
},
{
@@ -185,13 +186,14 @@
" if element.text == \"Key Takeaways\":\n",
" return idx\n",
"\n",
"\n",
"def get_key_takeaways(elements):\n",
" key_takeaways_idx = _find_key_takeaways_idx(elements)\n",
" if not key_takeaways_idx:\n",
" return None\n",
" \n",
"\n",
" takeaways = []\n",
" for element in elements[key_takeaways_idx + 1:]:\n",
" for element in elements[key_takeaways_idx + 1 :]:\n",
" if not isinstance(element, ListItem):\n",
" break\n",
" takeaways.append(element)\n",
@@ -245,12 +247,12 @@
"source": [
"def get_narrative(elements):\n",
" narrative_text = \"\"\n",
" for element in elements: \n",
" for element in elements:\n",
" if isinstance(element, NarrativeText) and len(element.text) > 500:\n",
" # NOTE: Removes citations like [3] from the text\n",
" element_text = re.sub(\"\\[\\d{1,3}\\]\", \"\", element.text)\n",
" narrative_text += f\"\\n\\n{element_text}\"\n",
" \n",
"\n",
" return NarrativeText(text=narrative_text.strip())"
]
},
@@ -337,10 +339,10 @@
" elements = url_to_elements(url)\n",
" if url is None or not elements:\n",
" continue\n",
" \n",
"\n",
" text = get_narrative(elements)\n",
" annotation = get_key_takeaways(elements)\n",
" \n",
"\n",
" if text and annotation:\n",
" inputs.append(text)\n",
" annotations.append(annotation.text)\n",
@@ -600,7 +602,7 @@
}
],
"source": [
"Image(filename=\"img/argilla-dataset.png\", width=800) "
"Image(filename=\"img/argilla-dataset.png\", width=800)"
]
},
{
@@ -634,7 +636,7 @@
}
],
"source": [
"Image(filename=\"img/argilla-annotation.png\", width=800) "
"Image(filename=\"img/argilla-annotation.png\", width=800)"
]
},
{
@@ -688,7 +690,7 @@
],
"source": [
"from transformers import AutoTokenizer\n",
" \n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
]
},
@@ -702,6 +704,7 @@
"max_input_length = 1024\n",
"max_target_length = 128\n",
"\n",
"\n",
"def preprocess_function(examples):\n",
" inputs = [doc for doc in examples[\"text\"]]\n",
" model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)\n",
@@ -754,7 +757,12 @@
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
"from transformers import (\n",
" AutoModelForSeq2SeqLM,\n",
" DataCollatorForSeq2Seq,\n",
" Seq2SeqTrainingArguments,\n",
" Seq2SeqTrainer,\n",
")\n",
"\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"
]
@@ -770,7 +778,7 @@
"model_name = model_checkpoint.split(\"/\")[-1]\n",
"args = Seq2SeqTrainingArguments(\n",
" \"t5-small-isw-summaries\",\n",
" evaluation_strategy = \"epoch\",\n",
" evaluation_strategy=\"epoch\",\n",
" learning_rate=2e-5,\n",
" per_device_train_batch_size=batch_size,\n",
" per_device_eval_batch_size=batch_size,\n",
@@ -1068,8 +1076,8 @@
],
"source": [
"summarization_model = pipeline(\n",
"task=\"summarization\",\n",
"model=\"./t5-small-isw-summaries\",\n",
" task=\"summarization\",\n",
" model=\"./t5-small-isw-summaries\",\n",
")"
]
},
118 changes: 71 additions & 47 deletions examples/arxiv-topic-modelling/topic_model.ipynb
@@ -20,13 +20,15 @@
"metadata": {},
"outputs": [],
"source": [
"import arxiv # Interact with arXiv api to scrape papers\n",
"from sentence_transformers import SentenceTransformer # Use Hugging Face Embedding for Topic Modelling\n",
"from bertopic import BERTopic # Package for Topic Modelling\n",
"from tqdm import tqdm #Progress Bar When Iterating\n",
"import glob #Identify Files in Directory\n",
"import os #Delete Files in Directory\n",
"import pandas as pd #Dataframe Manipulation"
"import arxiv # Interact with arXiv api to scrape papers\n",
"from sentence_transformers import (\n",
" SentenceTransformer,\n",
") # Use Hugging Face Embedding for Topic Modelling\n",
"from bertopic import BERTopic # Package for Topic Modelling\n",
"from tqdm import tqdm # Progress Bar When Iterating\n",
"import glob # Identify Files in Directory\n",
"import os # Delete Files in Directory\n",
"import pandas as pd # Dataframe Manipulation"
]
},
{
@@ -42,13 +44,19 @@
"metadata": {},
"outputs": [],
"source": [
"from unstructured.partition.auto import partition #Base Function to Partition PDF\n",
"from unstructured.staging.base import convert_to_dict #Convert List Unstructured Elements Into List of Dicts for Easy Parsing\n",
"from unstructured.cleaners.core import clean, remove_punctuation, clean_non_ascii_chars #Cleaning Bricks\n",
"import re #Create Custom Cleaning Brick\n",
"import nltk #Toolkit for more advanced pre-processing\n",
"from nltk.corpus import stopwords #list of stopwords to remove\n",
"from typing import List #Type Hinting"
"from unstructured.partition.auto import partition # Base Function to Partition PDF\n",
"from unstructured.staging.base import (\n",
" convert_to_dict,\n",
") # Convert List Unstructured Elements Into List of Dicts for Easy Parsing\n",
"from unstructured.cleaners.core import (\n",
" clean,\n",
" remove_punctuation,\n",
" clean_non_ascii_chars,\n",
") # Cleaning Bricks\n",
"import re # Create Custom Cleaning Brick\n",
"import nltk # Toolkit for more advanced pre-processing\n",
"from nltk.corpus import stopwords # list of stopwords to remove\n",
"from typing import List # Type Hinting"
]
},
{
@@ -84,7 +92,7 @@
}
],
"source": [
"nltk.download('stopwords')"
"nltk.download(\"stopwords\")"
]
},
{
@@ -110,28 +118,29 @@
" Returns:\n",
" paper_texts (list[str]): Return list of narrative texts for each paper\n",
" \"\"\"\n",
" #Get List of Arxiv Papers Matching Our Query\n",
" # Get List of Arxiv Papers Matching Our Query\n",
" arxiv_papers = list(\n",
" arxiv.Search(\n",
" query = query,\n",
" max_results = max_results,\n",
" sort_by = arxiv.SortCriterion.Relevance,\n",
" sort_order = arxiv.SortOrder.Descending\n",
" )\n",
" .results()\n",
" query=query,\n",
" max_results=max_results,\n",
" sort_by=arxiv.SortCriterion.Relevance,\n",
" sort_order=arxiv.SortOrder.Descending,\n",
" ).results()\n",
" )\n",
"\n",
" #Loop Through PDFs, Download and Pre-Process and Then Delete\n",
" # Loop Through PDFs, Download and Pre-Process and Then Delete\n",
" paper_texts = []\n",
" for paper in tqdm(arxiv_papers):\n",
" paper.download_pdf()\n",
" pdf_file = glob.glob('*.pdf')[0]\n",
" elements = partition(pdf_file) #Partition PDF Using Unstructured\n",
" isd = convert_to_dict(elements) #Convert List of Elements to List of Dictionaries\n",
" narrative_texts = [element['text'] for element in isd if element['type'] == 'NarrativeText'] #Only Keep Narrative Text and Combine Into One String\n",
" os.remove(pdf_file) #Delete PDF\n",
" pdf_file = glob.glob(\"*.pdf\")[0]\n",
" elements = partition(pdf_file) # Partition PDF Using Unstructured\n",
" isd = convert_to_dict(elements) # Convert List of Elements to List of Dictionaries\n",
" narrative_texts = [\n",
" element[\"text\"] for element in isd if element[\"type\"] == \"NarrativeText\"\n",
" ] # Only Keep Narrative Text and Combine Into One String\n",
" os.remove(pdf_file) # Delete PDF\n",
" paper_texts += narrative_texts\n",
" return paper_texts\n"
" return paper_texts"
]
},
{
@@ -155,7 +164,7 @@
}
],
"source": [
"paper_texts = get_arxiv_paper_texts(query='natural language processing', max_results=10)"
"paper_texts = get_arxiv_paper_texts(query=\"natural language processing\", max_results=10)"
]
},
{
@@ -179,10 +188,11 @@
}
],
"source": [
"#Stopwords to Remove\n",
"stop_words = set(stopwords.words('english'))\n",
"# Stopwords to Remove\n",
"stop_words = set(stopwords.words(\"english\"))\n",
"\n",
"#Function to Apply Whatever Cleaning Brick Functionality to Each Narrative Text Element\n",
"\n",
"# Function to Apply Whatever Cleaning Brick Functionality to Each Narrative Text Element\n",
"def custom_clean_brick(narrative_text: str) -> str:\n",
" \"\"\"Apply Mix of Unstructured Cleaning Bricks With Some Custom Functionality to Pre-Process Narrative Text\n",
"\n",
Expand All @@ -192,18 +202,32 @@
" Returns:\n",
" cleaned_text (str): Text after going through all the cleaning procedures\n",
" \"\"\"\n",
" remove_numbers = lambda text: re.sub(r'\\d+', \"\", text) #lambda function to remove all punctuation\n",
" cleaned_text = remove_numbers(narrative_text) #Apply Custom Lambda\n",
" cleaned_text = clean(cleaned_text, extra_whitespace=True, dashes=True, bullets=True, trailing_punctuation=True, lowercase=True) #Apply Basic Clean Brick With All the Options\n",
" cleaned_text = remove_punctuation(cleaned_text) #Remove all punctuation\n",
" cleaned_text = ' '.join([word for word in cleaned_text.split() if word not in stop_words]) #remove stop words\n",
" remove_numbers = lambda text: re.sub(\n",
" r\"\\d+\", \"\", text\n",
" ) # lambda function to remove all punctuation\n",
" cleaned_text = remove_numbers(narrative_text) # Apply Custom Lambda\n",
" cleaned_text = clean(\n",
" cleaned_text,\n",
" extra_whitespace=True,\n",
" dashes=True,\n",
" bullets=True,\n",
" trailing_punctuation=True,\n",
" lowercase=True,\n",
" ) # Apply Basic Clean Brick With All the Options\n",
" cleaned_text = remove_punctuation(cleaned_text) # Remove all punctuation\n",
" cleaned_text = \" \".join(\n",
" [word for word in cleaned_text.split() if word not in stop_words]\n",
" ) # remove stop words\n",
" return cleaned_text\n",
"\n",
"#Apply Function to Paper Texts\n",
"\n",
"# Apply Function to Paper Texts\n",
"cleaned_paper_texts = [custom_clean_brick(text) for text in paper_texts]\n",
"\n",
"#Count Narratve Texts\n",
"print(\"Number of Narrative Texts to Run Through Topic Modelling: {}\".format(len(cleaned_paper_texts)))"
"# Count Narratve Texts\n",
"print(\n",
" \"Number of Narrative Texts to Run Through Topic Modelling: {}\".format(len(cleaned_paper_texts))\n",
")"
]
},
{
@@ -219,10 +243,10 @@
"metadata": {},
"outputs": [],
"source": [
"#Choose Which Hugging Face Model You Want to Use\n",
"# Choose Which Hugging Face Model You Want to Use\n",
"sentence_model = SentenceTransformer(\"all-MiniLM-L6-v2\")\n",
"\n",
"#Initialize Model\n",
"# Initialize Model\n",
"topic_model = BERTopic(embedding_model=sentence_model, top_n_words=10, nr_topics=10, verbose=True)"
]
},
@@ -264,16 +288,16 @@
}
],
"source": [
"#Fit Topic Model and Transform List of Paper Narrative Texts Into Topic and Probabilities\n",
"# Fit Topic Model and Transform List of Paper Narrative Texts Into Topic and Probabilities\n",
"topic_model.fit(cleaned_paper_texts)\n",
"\n",
"#Store Document-Topic Info\n",
"# Store Document-Topic Info\n",
"doc_topic_info = topic_model.get_document_info(cleaned_paper_texts)\n",
"\n",
"#Store Topic Info\n",
"# Store Topic Info\n",
"topic_info = pd.DataFrame(topic_model.get_topics())\n",
"topic_info = topic_info.applymap(lambda x: x[0])\n",
"topic_info.columns = ['topic_{}'.format(col+1) for col in topic_info.columns]"
"topic_info.columns = [\"topic_{}\".format(col + 1) for col in topic_info.columns]"
]
},
{