-
Notifications
You must be signed in to change notification settings - Fork 1
/
text_rec.py
129 lines (103 loc) · 3.81 KB
/
text_rec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import os.path as osp
import math
import cv2
import onnxruntime
import numpy as np
class TextRecognizer:
    """Recognize text in cropped text-line images with an ONNX model.

    Each image is resized to a fixed height, normalized, right-padded to a
    fixed width, run through the model as one batch, and the per-time-step
    class scores are decoded to a string by dropping index 0 and collapsing
    consecutive repeats.
    """

    def __init__(
        self,
        model_path: str = "checkpoints/ch_PP-OCRv4_rec_infer.onnx",
        use_gpu: bool = False,
        word_dict_path: str = "checkpoints/rec_word_dict.txt",
    ):
        """
        Create the inference session and load the character dictionary.

        params:
            model_path (str): path to the ONNX recognition model
            use_gpu (bool): use the CUDA execution provider instead of CPU
            word_dict_path (str): path to the dictionary file (one character
                per line) used to map class indices to characters
        """
        so = onnxruntime.SessionOptions()
        so.log_severity_level = 3
        providers = ['CUDAExecutionProvider'] if use_gpu else ['CPUExecutionProvider']
        self.session = onnxruntime.InferenceSession(model_path, so, providers=providers)
        self.word_dict_path = word_dict_path
        # Model input layout: [channels, height, width].
        self.rec_image_shape = [3, 48, 320]
        self.alphabet = self.load_alphabet()

    def load_alphabet(self) -> list:
        """
        Load the character dictionary from ``self.word_dict_path``.

        Only line terminators are stripped, so an entry that is itself
        whitespace is preserved. A single space is appended as the last entry.

        return:
            list: characters in file order, followed by ' '
        """
        with open(self.word_dict_path, "r", encoding="utf8") as f:
            alphabet = [line.strip("\r\n") for line in f]
        alphabet.append(' ')
        return alphabet

    def decode(self, t: np.ndarray, length: int, raw: bool = False) -> str:
        """
        Convert predicted class indices into text.

        Index ``k`` maps to ``self.alphabet[k - 1]``.

        params:
            t (np.ndarray): predicted class index per time step
            length (int): number of leading time steps to decode; values
                larger than ``len(t)`` are treated as ``len(t)``
            raw (bool): if True, map every index to a character without
                skipping index 0 or collapsing repeats (index 0 then maps
                to ``self.alphabet[-1]``, the trailing space)
        return:
            str: decoded text
        """
        t = t[:length]
        if raw:
            return ''.join(self.alphabet[i - 1] for i in t)

        # Never iterate past the end of the (possibly shorter) sequence;
        # the original looped over range(length) and raised IndexError.
        steps = min(length, len(t))
        char_list = []
        for i in range(steps):
            current = t[i]
            if current == 0:
                # Index 0 never produces a character.
                continue
            if i > 0 and t[i - 1] == current:
                # Consecutive repeats of the same index emit one character.
                continue
            char_list.append(self.alphabet[current - 1])
        return ''.join(char_list)

    def resize_norm_img(self, img: np.ndarray) -> np.ndarray:
        """
        Resize an image to the model height, normalize it, and pad its width.

        The aspect ratio is kept unless the resulting width would exceed the
        model width, in which case the width is clamped. Pixel values are
        mapped with ``(x / 255 - 0.5) / 0.5`` (i.e. to [-1, 1] for 0..255
        input — assumes 8-bit pixel range, confirm against callers). The
        result is left-aligned in a zero-filled canvas.

        params:
            img (np.ndarray): input image, H x W x C; single-channel images
                (H x W or H x W x 1) are broadcast across all channels
        return:
            np.ndarray: float32 array of shape ``self.rec_image_shape``
        """
        imgC, imgH, imgW = self.rec_image_shape
        h, w = img.shape[:2]
        ratio = w / float(h)
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))

        resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
        if resized_image.ndim == 2:
            # cv2.resize returns a 2-D array for single-channel input; add a
            # channel axis so the transpose below works and the assignment
            # into the canvas broadcasts it over every channel.
            resized_image = resized_image[:, :, np.newaxis]
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5

        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, images: list) -> list:
        """
        Recognize the text in each image of a list.

        params:
            images (list of np.ndarray): images to process
        return:
            list: recognized text for each image, in input order
        """
        # len() rather than truthiness so an ndarray of images also works.
        if len(images) == 0:
            # Nothing to run; avoids feeding a shapeless array to the model.
            return []

        batch_images = np.array([self.resize_norm_img(im) for im in images])
        # The same batch is fed to every declared model input.
        ort_inputs = {i.name: batch_images for i in self.session.get_inputs()}
        ort_outputs = self.session.run(None, ort_inputs)
        # Assumes the first output is (batch, time_steps, num_classes) — TODO confirm.
        batch_preds = ort_outputs[0]

        texts = []
        for pred in batch_preds:
            length = pred.shape[0]
            # Best-scoring class at every time step.
            indices = np.argmax(pred.reshape(length, -1), axis=1)
            texts.append(self.decode(indices, length, raw=False))
        return texts