forked from GoogleCloudPlatform/document-ai-samples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: docai_pipeline.py
112 lines (93 loc) · 3.97 KB
/
docai_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Document AI End to End Pipeline"""
from os.path import basename as path_basename
from typing import List, Tuple
from consts import DOCAI_PROCESSOR_LOCATION
from consts import DOCAI_PROJECT_ID
from consts import FIRESTORE_PROJECT_ID
from docai_utils import classify_document_bytes
from docai_utils import extract_document_entities
from docai_utils import process_document_bytes
from docai_utils import select_processor_from_classification
from firestore_utils import save_to_firestore
from google.api_core.exceptions import GoogleAPICallError
def run_docai_pipeline(
    local_files: List[Tuple[str, str]], firestore_collection: str
) -> List[str]:
    """Run the end-to-end Document AI pipeline over a set of local files.

    For each ``(file_path, mime_type)`` pair:
      1. Classify the document type.
      2. Select the appropriate specialized parser processor.
      3. Extract entities with that processor.
      4. Save the extracted entities to Firestore.

    Args:
        local_files: ``(file_path, mime_type)`` tuples for the files to process.
        firestore_collection: Name of the Firestore collection to write to.

    Returns:
        Human-readable progress/status messages, in the order they were emitted.
    """
    status_messages: List[str] = []

    def progress_update(message: str) -> None:
        """Print a progress update to stdout and record it for the caller."""
        print(message)
        status_messages.append(message)

    for file_path, mime_type in local_files:
        file_name = path_basename(file_path)
        # Read the whole file into memory
        with open(file_path, "rb") as file:
            file_content = file.read()
        progress_update(f"Processing {file_name}")
        document_classification = classify_document_bytes(file_content, mime_type)
        progress_update(f"\tClassification: {document_classification}")
        # Skip documents that could not be classified into a known type
        if document_classification == "other":
            progress_update(f"\tSkipping file: {file_name}")
            continue
        # Get the specialized processor matching this classification
        (
            processor_type,
            processor_id,
        ) = select_processor_from_classification(document_classification)
        progress_update(f"\tUsing Processor {processor_type}: {processor_id}")
        # Run the parser; skip files the Document AI API rejects
        try:
            document_proto = process_document_bytes(
                DOCAI_PROJECT_ID,
                DOCAI_PROCESSOR_LOCATION,
                processor_id,
                file_content,
                mime_type,
            )
        except GoogleAPICallError:
            # Fix: previously a bare print() of the full path, so the skip was
            # never recorded in status_messages. Use progress_update and the
            # file name, consistent with the "other" classification skip above.
            progress_update(f"\tSkipping file: {file_name}")
            continue
        # Extract entities from the parsed document
        document_entities = extract_document_entities(document_proto)
        # Specific classification, e.g. w2_2020, 1099int_2020, 1099div_2020
        document_entities["classification"] = document_classification
        # Processor type corresponds to a broad category,
        # e.g. multiple W2 years map to the same processor type
        document_entities["broad_classification"] = processor_type.removesuffix(
            "_PROCESSOR"
        )
        document_entities["source_file"] = file_name
        # NOTE(review): the document ID is the broad classification alone, so a
        # later document of the same broad type overwrites an earlier one in
        # Firestore — appears intentional for this sample; confirm if not.
        document_id = document_entities["broad_classification"]
        # Save extracted entities to Firestore
        progress_update(f"\tWriting to Firestore Collection {firestore_collection}")
        progress_update(f"\tDocument ID: {document_id}")
        save_to_firestore(
            project_id=FIRESTORE_PROJECT_ID,
            collection=firestore_collection,
            document_id=document_id,
            data=document_entities,
        )
    return status_messages