run_inference.py
146 lines (126 loc) · 5.31 KB
from inference.classification import classifier_multi_models, classifier
from inference.detection import run_detection
# from inference.ocr import run_ocr
from inference.utils import crop_bbox_images
from typing import Dict
import yaml
import torch
import json
import pandas as pd


# Map an OCR'd drug-name string to its VAIPE label(s).
def find_vaipe_label(label_drugname: Dict, text):
    return label_drugname[text]

# Pass the ocr_output_dict to this function to map every OCR'd text to VAIPE labels.
def text_to_vaipe_label(label_drugname: Dict, ocr):
    new_ocr = {}
    for key, value in ocr.items():
        labels = []
        image_name = key.split('/')[-1]
        for text in value:
            vaipe_label = find_vaipe_label(label_drugname, text)
            for t in vaipe_label:
                labels.append(t)
        new_ocr[image_name] = labels
    return new_ocr
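
# Illustrative (assumed) shapes for the two arguments above, inferred from how the code
# indexes them -- the real JSON files may carry extra fields:
#   label_drugname: {"Paracetamol 500mg": [12, 37], ...}          # drug-name text -> label ids
#   ocr:            {"path/to/pres_001.png": ["Paracetamol 500mg", ...], ...}
# text_to_vaipe_label(label_drugname, ocr) would then return
#   {"pres_001.png": [12, 37, ...]}   # one flat list of label ids per prescription image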

# Reshape the pill/prescription mapping into a form that is easier to process.
def change_form(pill_pres_map):
    pres_pill = {}
    for dic in pill_pres_map:
        pres_name = dic['pres'].split('.')[0] + '.png'
        pill_names = []
        for pill in dic['pill']:
            pill_name = pill.split('.')[0] + '.jpg'
            pill_names.append(pill_name)
        pres_pill[pres_name] = pill_names
    return pres_pill
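
# Example of the reshaping done by change_form (file names are made up):
#   input : [{"pres": "pres_001.json", "pill": ["pill_001.json", "pill_002.json"]}, ...]
#   output: {"pres_001.png": ["pill_001.jpg", "pill_002.jpg"], ...}
# i.e. each prescription image (renamed to .png) maps to the list of its pill images
# (renamed to .jpg).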

def map_to_final_result(od_results, ocr_result, pill_pres_map):
    # od_results is the result of step 1 (pill detection + classification),
    # ocr_result is the result of step 2 (prescription OCR mapped to VAIPE labels)
    fin_res = {}
    pres_pill = change_form(pill_pres_map)
    for key, value in ocr_result.items():
        for pill in pres_pill[key]:
            labels = []
            if pill not in od_results:
                print(f'No annotations found for {pill}')
                continue
            for label in od_results[pill]:
                if label['class_id'] not in value:
                    class_id = 107
                else:
                    class_id = label['class_id']
                labels.append({
                    'class_id': class_id,
                    'x_min': label['x_min'], 'y_min': label['y_min'],
                    'x_max': label['x_max'], 'y_max': label['y_max'],
                    'confidence_score': label['confidence_score']
                })
            fin_res[pill] = labels
    return fin_res
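
# Matching rule used above: a detected pill keeps its predicted class only if that class
# also appears among the labels OCR'd from its prescription; otherwise it is re-assigned
# class_id 107, which this code appears to reserve for "pill not in the prescription"
# (note the `assert prediction < 107` in the main block below).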

# ==================== MAIN =======================
if __name__ == '__main__':
    # Load config file
    cfg = yaml.safe_load(open('configs/config_inference.yaml'))
    device = torch.device(cfg['device'])
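
    # A minimal sketch of configs/config_inference.yaml, covering only the keys this script
    # reads; the paths and values below are placeholders, not the project's real config:
    #   device: cuda:0
    #   pill_image_dir: data/pill/images
    #   pres_image_dir: data/prescription/images
    #   multi_models: false
    #   detection: {...}                 # passed straight to run_detection
    #   crop:
    #     crop_detection_map: outputs/crop_detection_map.json
    #   classifier:
    #     output: outputs/classifier.csv
    #     threshold: 0.5
    #   classifier_multi_models: {...}   # used only when multi_models is true
    #   ocr:
    #     output: outputs/ocr.json
    #   label_drugnames_path: data/label_drugname.json
    #   pill_pres_map_path: data/pill_pres_map.json
    #   submit_file: outputs/submission.csv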

    # ocr -> run detection -> crop bbox images -> run classification
    # Uncomment the line below to run OCR (if necessary)
    # run_ocr(cfg['pres_image_dir'], output_dir=cfg['ocr']['output'])
    detection_results = run_detection(cfg['pill_image_dir'], cfg['detection'])
    crop_bbox_images(detection_results, cfg['crop'])
    if cfg['multi_models']:
        classifier_multi_models(cfg['classifier_multi_models'], device)
    else:
        classifier(cfg['classifier'], device)

    # generate submission
    print('Generating submit file ...', end=' ')
    # load all necessary files
    with open(cfg['ocr']['output'], 'r') as f:
        ocr_output_dict = json.load(f)
    with open(cfg['label_drugnames_path'], 'r') as f:
        label_drugname = json.load(f)
    with open(cfg['pill_pres_map_path'], 'r') as f:
        pill_pres_map = json.load(f)
    with open(cfg['crop']['crop_detection_map'], 'r') as f:
        crop_detection_map = json.load(f)
    classifier_df = pd.read_csv(cfg['classifier']['output'])

    ocr_result = text_to_vaipe_label(label_drugname, ocr_output_dict)
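
    # crop_detection_map is assumed (from the lookups below) to map each cropped-image id
    # written by crop_bbox_images back to its source pill image and box, e.g.
    #   {"crop_000001.jpg": {"image_id": "pill_001.jpg",
    #                        "x_min": ..., "y_min": ..., "x_max": ..., "y_max": ...}, ...}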

    od_results = {}
    for i in range(len(classifier_df)):
        image_id = classifier_df['image_id'][i]
        prediction = classifier_df['prediction'][i]
        confidence = classifier_df['confidence'][i]
        if image_id not in crop_detection_map:
            print('No annotations found for', image_id)
            continue
        annotation = crop_detection_map[image_id]
        assert prediction < 107
        item = {
            'x_min': annotation['x_min'],
            'y_min': annotation['y_min'],
            'x_max': annotation['x_max'],
            'y_max': annotation['y_max'],
            'class_id': prediction if confidence > cfg['classifier']['threshold'] else 107,
            'confidence_score': confidence
        }
        if annotation['image_id'] not in od_results:
            od_results[annotation['image_id']] = []
        # print(annotation['image_id'])
        od_results[annotation['image_id']].append(item)
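
    # At this point od_results is expected to look like (fields inferred from the code above):
    #   {"pill_001.jpg": [{"class_id": 12, "x_min": ..., "y_min": ..., "x_max": ..., "y_max": ...,
    #                      "confidence_score": 0.93}, ...], ...}
    # i.e. detections grouped back under their original pill-image name, with low-confidence
    # predictions already demoted to class_id 107.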

    fin_res = map_to_final_result(od_results, ocr_result, pill_pres_map)

    class_id = []
    x_min, x_max, y_min, y_max = [], [], [], []
    confidence_score = []
    image_name = []
    for key, value in fin_res.items():
        for dic in value:
            image_name.append(key)
            class_id.append(int(dic['class_id']))
            confidence_score.append(dic['confidence_score'])
            x_min.append(dic['x_min'])
            y_min.append(dic['y_min'])
            x_max.append(dic['x_max'])
            y_max.append(dic['y_max'])
    results_1 = {
        'image_name': image_name,
        'class_id': class_id,
        'confidence_score': confidence_score,
        'x_min': x_min, 'y_min': y_min, 'x_max': x_max, 'y_max': y_max
    }
    df = pd.DataFrame(data=results_1)
    df.to_csv(cfg['submit_file'], index=False)
    print('Success!')
    print('Submission has been saved at', cfg['submit_file'])
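
# Usage note (an assumption based on the relative config path above): run this script from
# the repository root so that configs/config_inference.yaml resolves, e.g.
#   python run_inference.py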