-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathyolox.py
249 lines (207 loc) · 8.22 KB
/
yolox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
# Copyright (c) Megvii, Inc. and its affiliates.
# Unstructured modified the original source code found at:
# https://github.com/Megvii-BaseDetection/YOLOX/blob/237e943ac64aa32eb32f875faa93ebb18512d41d/yolox/data/data_augment.py
# https://github.com/Megvii-BaseDetection/YOLOX/blob/ac379df3c97d1835ebd319afad0c031c36d03f36/yolox/utils/demo_utils.py
import cv2
import numpy as np
import onnxruntime
from onnxruntime.capi import _pybind_state as C
from PIL import Image as PILImage
from unstructured_inference.constants import ElementType, Source
from unstructured_inference.inference.layoutelement import LayoutElements
from unstructured_inference.models.unstructuredmodel import (
UnstructuredObjectDetectionModel,
)
from unstructured_inference.utils import (
LazyDict,
LazyEvaluateInfo,
download_if_needed_and_get_local_path,
)
# Integer class id -> ElementType. The position of each entry in the tuple is
# its class id (0 through 10), so the order of the entries must not change.
YOLOX_LABEL_MAP = dict(
    enumerate(
        (
            ElementType.CAPTION,
            ElementType.FOOTNOTE,
            ElementType.FORMULA,
            ElementType.LIST_ITEM,
            ElementType.PAGE_FOOTER,
            ElementType.PAGE_HEADER,
            ElementType.PICTURE,
            ElementType.SECTION_HEADER,
            ElementType.TABLE,
            ElementType.TEXT,
            ElementType.TITLE,
        ),
    ),
)
# Registry of the available YoloX variants, keyed by model name. Every entry
# shares the same label map and differs only in the ONNX file it refers to.
# The file location is wrapped in LazyEvaluateInfo around
# download_if_needed_and_get_local_path (presumably so that nothing is fetched
# until an entry is actually used -- confirm in unstructured_inference.utils).
MODEL_TYPES = {
    model_name: LazyDict(
        model_path=LazyEvaluateInfo(
            download_if_needed_and_get_local_path,
            "unstructuredio/yolo_x_layout",
            onnx_filename,
        ),
        label_map=YOLOX_LABEL_MAP,
    )
    for model_name, onnx_filename in (
        ("yolox", "yolox_l0.05.onnx"),
        ("yolox_tiny", "yolox_tiny.onnx"),
        ("yolox_quantized", "yolox_l0.05_quantized.onnx"),
    )
}
class UnstructuredYoloXModel(UnstructuredObjectDetectionModel):
    """Object-detection model that runs a YoloX ONNX export with onnxruntime."""

    def predict(self, x: PILImage.Image):
        """Predict using YoloX model.

        Calls the parent class ``predict`` first and then returns the result of
        ``image_processing`` for the same image.
        """
        super().predict(x)
        return self.image_processing(x)

    def initialize(self, model_path: str, label_map: dict):
        """Start inference session for YoloX model.

        Parameters
        ----------
        model_path
            Path of the ONNX file handed to ``onnxruntime.InferenceSession``.
            It is also kept on the instance because ``image_processing`` looks
            for the substring "quantized" in it to choose its thresholds.
        label_map
            Mapping stored as ``layout_classes`` and passed on as the
            ``element_class_id_map`` of the returned ``LayoutElements``.
        """
        self.model_path = model_path

        # Keep only the providers this onnxruntime build reports as available,
        # preserving the preference order listed below.
        available_providers = C.get_available_providers()
        ordered_providers = [
            "TensorrtExecutionProvider",
            "CUDAExecutionProvider",
            "CPUExecutionProvider",
        ]
        providers = [provider for provider in ordered_providers if provider in available_providers]

        self.model = onnxruntime.InferenceSession(
            model_path,
            providers=providers,
        )

        self.layout_classes = label_map

    def image_processing(
        self,
        image: PILImage.Image,
    ) -> LayoutElements:
        """Run YoloX layout detection on an image and return the detections.

        Parameters
        ----------
        image
            Image to analyze. When it does not decode to a height x width x 3
            array (for example grayscale, palette or RGBA images) it is
            converted to RGB first, because ``preprocess`` only handles
            three-channel input.

        Returns
        -------
        LayoutElements
            Detections that survived score filtering and NMS, ordered by the
            y1 coordinate of their boxes. Box coordinates are divided by the
            resize ratio, i.e. they refer to the original image.
        """
        # The model was trained and exported with this shape
        # TODO (benjamin): check other shapes for inference
        input_shape = (1024, 768)
        origin_img = np.array(image)
        # preprocess pads into a three-channel canvas and transposes three
        # axes, so any other layout would raise there; only those inputs are
        # converted, everything that already worked is left untouched.
        if origin_img.ndim != 3 or origin_img.shape[2] != 3:
            origin_img = np.array(image.convert("RGB"))
        img, ratio = preprocess(origin_img, input_shape)
        session = self.model

        # Add the leading batch axis expected by the session input.
        ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
        output = session.run(None, ort_inputs)
        # TODO(benjamin): check for p6
        predictions = demo_postprocess(output[0], input_shape, p6=False)[0]

        # Columns 0-3 hold the box, column 4 scales the per-class columns 5+
        # (presumably the objectness score -- confirm against the export).
        boxes = predictions[:, :4]
        scores = predictions[:, 4:5] * predictions[:, 5:]

        # Convert (center x, center y, width, height) to corner coordinates
        # and undo the resize applied by preprocess.
        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
        boxes_xyxy /= ratio

        # Note (Benjamin): Distinct models (quantized and original) require
        # distinct threshold levels
        if "quantized" in self.model_path:
            dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.0, score_thr=0.07)
        else:
            dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.1, score_thr=0.25)

        # Order detections by the y1 coordinate of their boxes.
        order = np.argsort(dets[:, 1])
        sorted_dets = dets[order]

        return LayoutElements(
            element_coords=sorted_dets[:, :4].astype(float),
            element_probs=sorted_dets[:, 4].astype(float),
            element_class_ids=sorted_dets[:, 5].astype(int),
            element_class_id_map=self.layout_classes,
            source=Source.YOLOX,
        )
# Note: preprocess function was named preproc on original source
def preprocess(img, input_size, swap=(2, 0, 1)):
    """Resize and pad an image array so it can be fed to the YoloX model.

    The image is scaled by one common factor so that it fits inside
    ``input_size`` (height, width), placed in the top-left corner of a canvas
    filled with the value 114, transposed with ``swap`` and returned as a
    contiguous float32 array.

    Returns
    -------
    tuple
        ``(padded_image, scale)`` where ``scale`` is the resize factor that was
        applied to the input image.
    """
    target_h, target_w = input_size[0], input_size[1]
    if len(img.shape) == 3:
        canvas = np.full((target_h, target_w, 3), 114, dtype=np.uint8)
    else:
        canvas = np.full(input_size, 114, dtype=np.uint8)

    # One factor for both axes keeps the aspect ratio of the input.
    scale = min(target_h / img.shape[0], target_w / img.shape[1])
    scaled_h = int(img.shape[0] * scale)
    scaled_w = int(img.shape[1] * scale)

    # cv2.resize takes the target size as (width, height).
    shrunk = cv2.resize(
        img,
        (scaled_w, scaled_h),
        interpolation=cv2.INTER_LINEAR,
    ).astype(np.uint8)
    canvas[:scaled_h, :scaled_w] = shrunk

    reordered = np.ascontiguousarray(canvas.transpose(swap), dtype=np.float32)
    return reordered, scale
def demo_postprocess(outputs, img_size, p6=False):
    """Decode raw YoloX outputs into input-image coordinates.

    For every stride a grid of ``img_size // stride`` cells is built. The
    first two values of each prediction get the (x, y) index of their cell
    added and are multiplied by the stride; the next two values are
    exponentiated and multiplied by the stride.

    ``outputs`` is modified in place and the same array is returned.
    Strides are 8, 16, 32, plus 64 when ``p6`` is true.
    """
    height, width = img_size[0], img_size[1]
    strides = [8, 16, 32, 64] if p6 else [8, 16, 32]

    cell_offsets = []
    cell_strides = []
    for stride in strides:
        rows, cols = height // stride, width // stride
        col_idx, row_idx = np.meshgrid(np.arange(cols), np.arange(rows))
        # One (x, y) pair per cell, flattened row by row.
        offsets = np.stack((col_idx, row_idx), 2).reshape(1, -1, 2)
        cell_offsets.append(offsets)
        cell_strides.append(np.full((1, offsets.shape[1], 1), stride))

    all_offsets = np.concatenate(cell_offsets, 1)
    all_strides = np.concatenate(cell_strides, 1)

    outputs[..., :2] = (outputs[..., :2] + all_offsets) * all_strides
    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * all_strides
    return outputs
def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
    """Multiclass NMS implemented in Numpy.

    Only the class-agnostic variant is available: ``class_agnostic`` is
    accepted but not consulted, and every call is forwarded to
    ``multiclass_nms_class_agnostic``.
    """
    # TODO(benjamin): check for non-class agnostic
    return multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr)
def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
    """Multiclass NMS implemented in Numpy. Class-agnostic version.

    Each box is represented by its best-scoring class. Boxes whose best score
    is not above ``score_thr`` are dropped, and the rest go through a single
    NMS pass regardless of class.

    Returns an array with one row per kept box: the four box values, the
    score and the class index.
    """
    best_class = scores.argmax(1)
    best_score = scores.max(axis=1)

    passes = best_score > score_thr
    cand_boxes = boxes[passes]
    cand_scores = best_score[passes]
    cand_classes = best_class[passes]

    kept = nms(cand_boxes, cand_scores, nms_thr)
    return np.concatenate(
        (
            cand_boxes[kept],
            cand_scores[kept][:, None],
            cand_classes[kept][:, None],
        ),
        axis=1,
    )
def nms(boxes, scores, nms_thr):
    """Single class NMS implemented in Numpy.

    Boxes are visited from the highest score to the lowest. Each visited box
    is kept, and every remaining box whose overlap ratio with it is greater
    than ``nms_thr`` is discarded. Widths and heights are computed as
    ``max - min + 1``.

    Returns the list of indices into ``boxes`` that were kept, in the order
    they were selected.
    """
    left, top, right, bottom = (boxes[:, col] for col in range(4))
    box_areas = (right - left + 1) * (bottom - top + 1)

    remaining = scores.argsort()[::-1]
    kept = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        kept.append(best)

        # Intersection rectangle between the selected box and all the others.
        overlap_w = np.maximum(
            0.0,
            np.minimum(right[best], right[rest]) - np.maximum(left[best], left[rest]) + 1,
        )
        overlap_h = np.maximum(
            0.0,
            np.minimum(bottom[best], bottom[rest]) - np.maximum(top[best], top[rest]) + 1,
        )
        intersection = overlap_w * overlap_h
        iou = intersection / (box_areas[best] + box_areas[rest] - intersection)

        # Only boxes that do not overlap too much stay in the queue.
        remaining = rest[iou <= nms_thr]

    return kept