feat: Add DocAI Toolbox Batch Entity Extraction Notebook (#623)

holtskinner authored Sep 12, 2023
1 parent 39977a4 commit b1dbb4e
Showing 6 changed files with 394 additions and 65 deletions.
44 changes: 21 additions & 23 deletions document_ai_warehouse/common/src/common/utils/document_ai_utils.py
@@ -3,9 +3,9 @@
 import time
 from typing import Any, Dict, List, Optional
 
+from common.utils.helper import split_uri_2_bucket_prefix
 from common.utils.logging_handler import Logger
 from common.utils.storage_utils import read_binary_object
-from common.utils.helper import split_uri_2_bucket_prefix
 from google.api_core.client_options import ClientOptions
 from google.api_core.exceptions import InternalServerError
 from google.api_core.exceptions import RetryError
@@ -54,11 +54,11 @@ def get_processor(self, processor_id: str):
         return client.get_processor(request=request)
 
     def process_file_from_gcs(
-            self,
-            processor_id: str,
-            bucket_name: str,
-            file_path: str,
-            mime_type: str = "application/pdf",
+        self,
+        processor_id: str,
+        bucket_name: str,
+        file_path: str,
+        mime_type: str = "application/pdf",
     ) -> documentai.Document:
         client = self.get_docai_client()
         parent = self.get_parent()
@@ -67,12 +67,8 @@ def process_file_from_gcs(
 
         document_content = read_binary_object(bucket_name, file_path)
 
-        document = documentai.RawDocument(
-            content=document_content, mime_type=mime_type
-        )
-        request = documentai.ProcessRequest(
-            raw_document=document, name=processor_name
-        )
+        document = documentai.RawDocument(content=document_content, mime_type=mime_type)
+        request = documentai.ProcessRequest(raw_document=document, name=processor_name)
 
         response = client.process_document(request)
@@ -103,11 +99,11 @@ def get_entity_key_value_pairs(docai_document):
         return fields
 
     def batch_extraction(
-            self,
-            processor_id: str,
-            input_uris: List[str],
-            gcs_output_bucket: str,
-            timeout=600,
+        self,
+        processor_id: str,
+        input_uris: List[str],
+        gcs_output_bucket: str,
+        timeout=600,
     ):
         if len(input_uris) == 0:
             return []
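Because of the early return above, callers may pass an empty URI list safely. A hedged sketch of a call, reusing the docai_utils instance from the previous sketch; all values are placeholders, and `timeout` keeps its 600-second default.

    results = docai_utils.batch_extraction(
        processor_id="my-processor-id",                               # placeholder
        input_uris=["gs://my-bucket/a.pdf", "gs://my-bucket/b.pdf"],  # placeholders
        gcs_output_bucket="gs://my-output-bucket",                    # placeholder
    )
    # Per the `documents` comment later in this diff, the result is keyed
    # by the path to each original input document.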
@@ -176,7 +172,9 @@ def batch_extraction(
                 f"batch_extraction - Batch Process Failed: {metadata.state_message}"
             )
 
-        documents: Dict[str, Any] = {}  # Contains per processed document, keys are path to original document
+        documents: Dict[
+            str, Any
+        ] = {}  # Contains per processed document, keys are path to original document
 
         # One process per Input Document
         for process in metadata.individual_process_statuses:
@@ -258,9 +256,9 @@ def merge_json_files(files):
 
 # Handling Nested labels for CDE processor
 def get_key_values_dic(
-        entity: documentai.Document.Entity,
-        document_entities: Dict[str, List[Any]],
-        parent_key: Optional[str] = None,
+    entity: documentai.Document.Entity,
+    document_entities: Dict[str, List[Any]],
+    parent_key: Optional[str] = None,
 ) -> None:
     # Fields detected. For a full list of fields for each processor see
     # the processor documentation:
@@ -272,8 +270,8 @@
 
     if normalized_value:
         if (
-                isinstance(normalized_value, dict)
-                and "booleanValue" in normalized_value.keys()
+            isinstance(normalized_value, dict)
+            and "booleanValue" in normalized_value.keys()
         ):
             normalized_value = normalized_value.get("booleanValue")
         else:
@@ -248,13 +248,10 @@ def set_raw_document_file_type_from_mimetype(
 
     mime_to_dw_mime_enum = {
         "application/pdf": document.raw_document_file_type.RAW_DOCUMENT_FILE_TYPE_PDF,
-        "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
-            document.raw_document_file_type.RAW_DOCUMENT_FILE_TYPE_DOCX,
+        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": document.raw_document_file_type.RAW_DOCUMENT_FILE_TYPE_DOCX,
         "text/plain": document.raw_document_file_type.RAW_DOCUMENT_FILE_TYPE_TEXT,
-        "application/vnd.openxmlformats-officedocument.presentationml.presentation":
-            document.raw_document_file_type.RAW_DOCUMENT_FILE_TYPE_PPTX,
-        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
-            document.raw_document_file_type.RAW_DOCUMENT_FILE_TYPE_XLSX,
+        "application/vnd.openxmlformats-officedocument.presentationml.presentation": document.raw_document_file_type.RAW_DOCUMENT_FILE_TYPE_PPTX,
+        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": document.raw_document_file_type.RAW_DOCUMENT_FILE_TYPE_XLSX,
     }
     if mime_type.lower() in mime_to_dw_mime_enum:
         document.raw_document_file_type = mime_to_dw_mime_enum[mime_type.lower()]
@@ -13,8 +13,9 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-import os
 import logging
+import os
 
 import google.cloud.logging_v2
 
 """class and methods for logs handling."""
@@ -22,7 +22,12 @@ def file_exists(bucket_name: str, file_name: str):
     return stats
 
 
-def write_gcs_blob(bucket_name: str, file_name: str, content_as_str: str, content_type: str = "text/plain"):
+def write_gcs_blob(
+    bucket_name: str,
+    file_name: str,
+    content_as_str: str,
+    content_type: str = "text/plain",
+):
     bucket = storage_client.get_bucket(bucket_name)
     gcs_file = bucket.blob(file_name)
     gcs_file.upload_from_string(content_as_str, content_type=content_type)
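The exploded write_gcs_blob signature is call-compatible with the old one-liner. A small usage sketch with placeholder names:

    write_gcs_blob(
        bucket_name="my-output-bucket",       # placeholder
        file_name="results/entities.json",    # placeholder
        content_as_str='{"invoice_id": "42"}',
        content_type="application/json",      # overrides the "text/plain" default
    )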
80 changes: 46 additions & 34 deletions document_ai_warehouse/document_ai_warehouse_batch_ingestion/main.py
@@ -2,7 +2,16 @@
 import json
 import os
 import time
-from typing import List, Dict, Any, Set, Tuple, Optional
+from typing import Any, Dict, List, Optional, Set, Tuple
 
+from common.utils import helper
+from common.utils import storage_utils
+from common.utils.docai_warehouse_helper import get_key_value_pairs
+from common.utils.docai_warehouse_helper import get_metadata_properties
+from common.utils.document_ai_utils import DocumentaiUtils
+from common.utils.document_warehouse_utils import DocumentWarehouseUtils
+from common.utils.helper import is_date
+from common.utils.logging_handler import Logger
 from config import API_LOCATION
 from config import CALLER_USER
 from config import DOCAI_PROJECT_NUMBER
@@ -13,14 +13,6 @@
 from google.api_core.exceptions import NotFound
 from google.cloud import contentwarehouse_v1
 from google.cloud import storage
-from common.utils import helper
-from common.utils import storage_utils
-from common.utils.docai_warehouse_helper import get_key_value_pairs
-from common.utils.docai_warehouse_helper import get_metadata_properties
-from common.utils.document_ai_utils import DocumentaiUtils
-from common.utils.document_warehouse_utils import DocumentWarehouseUtils
-from common.utils.helper import is_date
-from common.utils.logging_handler import Logger
 
 dw_utils = DocumentWarehouseUtils(
     project_number=DOCAI_WH_PROJECT_NUMBER, api_location=API_LOCATION
@@ -45,8 +46,10 @@ def get_schema(args: argparse.Namespace):
         f"CALLER_USER={CALLER_USER}"
     )
 
-    assert processor_id, "processor_id is not set as PROCESSOR_ID env variable and " \
-                         "is not provided as an input parameter (-p)"
+    assert processor_id, (
+        "processor_id is not set as PROCESSOR_ID env variable and "
+        "is not provided as an input parameter (-p)"
+    )
     assert GCS_OUTPUT_BUCKET, "GCS_OUTPUT_BUCKET not set"
     assert DOCAI_PROJECT_NUMBER, "DOCAI_PROJECT_NUMBER not set"
 
@@ -112,18 +115,27 @@ def batch_ingest(args: argparse.Namespace) -> None:
         f"CALLER_USER={CALLER_USER}"
     )
 
-    assert processor_id, "processor_id is not set as PROCESSOR_ID env variable and " \
-                         "is not provided as an input parameter (-p)"
+    assert processor_id, (
+        "processor_id is not set as PROCESSOR_ID env variable and "
+        "is not provided as an input parameter (-p)"
+    )
     assert GCS_OUTPUT_BUCKET, "GCS_OUTPUT_BUCKET not set"
     assert DOCAI_PROJECT_NUMBER, "DOCAI_PROJECT_NUMBER not set"
     assert DOCAI_WH_PROJECT_NUMBER, "DOCAI_WH_PROJECT_NUMBER not set"
 
     initial_start_time = time.time()
 
-    created_folders, files_to_parse, processed_files, processed_dirs, error_files = \
-        prepare_file_structure(dir_uri, folder_name, overwrite, flatten)
+    (
+        created_folders,
+        files_to_parse,
+        processed_files,
+        processed_dirs,
+        error_files,
+    ) = prepare_file_structure(dir_uri, folder_name, overwrite, flatten)
 
-    created_schemas, document_id_list = proces_documents(files_to_parse, schema_id, schema_name, processor_id, options)
+    created_schemas, document_id_list = proces_documents(
+        files_to_parse, schema_id, schema_name, processor_id, options
+    )
 
     process_time = time.time() - initial_start_time
     time_elapsed = round(process_time)
@@ -147,11 +159,12 @@ def batch_ingest(args: argparse.Namespace) -> None:
     )
 
 
-FUNCTION_MAP = {'batch_ingest': batch_ingest,
-                'get_schema': get_schema,
-                'upload_schema': upload_schema,
-                'delete_schema': delete_schema,
-                }
+FUNCTION_MAP = {
+    "batch_ingest": batch_ingest,
+    "get_schema": get_schema,
+    "upload_schema": upload_schema,
+    "delete_schema": delete_schema,
+}
 
 
 def main():
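main() itself is collapsed in this view; given the `choices=FUNCTION_MAP.keys()` constraint added in get_args() below, its dispatch is presumably a direct lookup. A hypothetical sketch of that body, not the actual code:

    args = get_args()
    FUNCTION_MAP[args.command](args)  # args.command is pre-validated by argparse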
@@ -186,19 +199,17 @@ def get_args():
         """,
     )
 
-    args_parser.add_argument('command', choices=FUNCTION_MAP.keys())
+    args_parser.add_argument("command", choices=FUNCTION_MAP.keys())
     args_parser.add_argument(
         "-d",
         dest="dir_uri",
         help="Path to gs directory uri, containing data with PDF documents to be loaded. "
-             "All original structure of sub-folders will be preserved.",
+        "All original structure of sub-folders will be preserved.",
     )
     args_parser.add_argument(
         "-s", dest="schema_id", help="Optional existing schema_id."
     )
-    args_parser.add_argument(
-        "-p", dest="processor_id", help="Processor_ID."
-    )
+    args_parser.add_argument("-p", dest="processor_id", help="Processor_ID.")
     args_parser.add_argument(
         "-sn",
         dest="schema_name",
@@ -235,7 +246,7 @@ def get_args():
         "-n",
         dest="root_name",
         help="Name of the root folder inside DW for batch ingestion."
-             " When skipped, will use the same name of the folder being loaded from.",
+        " When skipped, will use the same name of the folder being loaded from.",
     )
     args_parser.add_argument(
         "-sns",
@@ -255,11 +266,11 @@
 
 
 def proces_documents(
-        files_to_parse: Dict[str, Any],
-        schema_id: str,
-        schema_name: str,
-        processor_id: str,
-        options: bool
+    files_to_parse: Dict[str, Any],
+    schema_id: str,
+    schema_name: str,
+    processor_id: str,
+    options: bool,
 ) -> Tuple[Set[str], List[str]]:
     created_schemas: Set[str] = set()
     document_id_list: List[str] = []
@@ -334,7 +345,6 @@ def prepare_file_structure(
     overwrite: bool,
     flatten: bool,
 ):
-
     created_folders = []
     files_to_parse = {}
     processed_files = []
@@ -541,7 +551,9 @@ def create_folder_schema(schema_path: str) -> str:
     return folder_schema_id
 
 
-def create_folder(folder_schema_id: str, display_name: str, reference_id: str) -> Optional[str]:
+def create_folder(
+    folder_schema_id: str, display_name: str, reference_id: str
+) -> Optional[str]:
     reference_path = f"referenceId/{reference_id}"
     try:
         document = dw_utils.get_document(reference_path, CALLER_USER)
(The sixth changed file, the new batch entity extraction notebook, did not render in this view.)

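None of the main.py changes alter the CLI surface, so existing invocations keep working. A hypothetical example of the batch path, with placeholder values for every flag: python main.py batch_ingest -d gs://my-bucket/documents -p my-processor-id -n my-root-folder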