-
Notifications
You must be signed in to change notification settings - Fork 25
/
inference.py
77 lines (63 loc) · 3.13 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import numpy as np
import random
import os
import pandas as pd
from argparse import ArgumentParser
from model.resnet import wide_resnet50_2
from model.de_resnet import de_wide_resnet50_2
from utils.utils_test import evaluation_multi_proj
from utils.utils_train import MultiProjectionLayer
from dataset.dataset import MVTecDataset_test, get_data_transforms
def setup_seed(seed):
    """Seed every RNG in play (torch CPU/GPU, numpy, random) for reproducible runs."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    # Force deterministic cuDNN kernels and disable autotuning, trading
    # speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def get_args():
    """Parse command-line options for inference.

    Returns an argparse.Namespace with:
      checkpoint_folder -- directory holding per-class checkpoint subfolders
      image_size        -- square input resolution fed to the transforms
      classes           -- list of MVTec class names to evaluate
    """
    parser = ArgumentParser()
    parser.add_argument('--checkpoint_folder', type=str, default='./your_checkpoint_folder')
    parser.add_argument('--image_size', type=int, default=256)
    parser.add_argument('--classes', nargs="+", default=["carpet", "leather"])
    return parser.parse_args()
def inference(_class_, pars):
    """Evaluate a trained model on one MVTec class and report anomaly metrics.

    Args:
        _class_: MVTec category name (e.g. 'carpet'). Test images are expected
                 under /content/<class> and the trained weights under
                 <checkpoint_folder>/<class>/wres50_<class>.pth.
        pars:    parsed CLI arguments (uses checkpoint_folder and image_size).

    Returns:
        (auroc_sp, auroc_px, aupro_px): sample-level AUROC, pixel-level AUROC
        and pixel-level AUPRO as computed by evaluation_multi_proj.
    """
    if not os.path.exists(pars.checkpoint_folder):
        os.makedirs(pars.checkpoint_folder)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    data_transform, gt_transform = get_data_transforms(pars.image_size, pars.image_size)
    test_path = '/content/' + _class_
    test_data = MVTecDataset_test(root=test_path, transform=data_transform, gt_transform=gt_transform)
    # batch_size=1: evaluation_multi_proj scores images one at a time.
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
    # Pretrained wide_resnet50 as the (frozen) teacher encoder; bn is the OCBE bottleneck.
    encoder, bn = wide_resnet50_2(pretrained=True)
    encoder = encoder.to(device)
    bn = bn.to(device)
    decoder = de_wide_resnet50_2(pretrained=False)
    decoder = decoder.to(device)
    proj_layer = MultiProjectionLayer(base=64).to(device)
    # Load trained weights for projection layer, bn (OCBE), decoder (student).
    # Fix: the original computed this path twice (once before the dataset was
    # built and again here); build it once, with os.path.join.
    checkpoint_class = os.path.join(pars.checkpoint_folder, _class_, 'wres50_' + _class_ + '.pth')
    ckp = torch.load(checkpoint_class, map_location='cpu')
    proj_layer.load_state_dict(ckp['proj'])
    bn.load_state_dict(ckp['bn'])
    decoder.load_state_dict(ckp['decoder'])
    auroc_px, auroc_sp, aupro_px = evaluation_multi_proj(encoder, proj_layer, bn, decoder, test_dataloader, device)
    print('{}: Sample Auroc: {:.4f}, Pixel Auroc:{:.4f}, Pixel Aupro: {:.4f}'.format(_class_, auroc_sp, auroc_px, aupro_px))
    return auroc_sp, auroc_px, aupro_px
if __name__ == '__main__':
    pars = get_args()
    # Full MVTec-AD category list, for reference when choosing --classes:
    # carpet, grid, leather, tile, wood, bottle, cable, capsule, hazelnut,
    # metal_nut, pill, screw, toothbrush, transistor, zipper.
    # (Fix: this used to be an unused `item_list` variable — dead code.)
    setup_seed(111)  # fixed seed so any inference-time randomness is reproducible
    # Accumulate one row of metrics per evaluated class, then dump to CSV.
    metrics = {'class': [], 'AUROC_sample': [], 'AUROC_pixel': [], 'AUPRO_pixel': []}
    for c in pars.classes:
        auroc_sp, auroc_px, aupro_px = inference(c, pars)
        metrics['class'].append(c)
        metrics['AUROC_sample'].append(auroc_sp)
        metrics['AUROC_pixel'].append(auroc_px)
        metrics['AUPRO_pixel'].append(aupro_px)
    metrics_df = pd.DataFrame(metrics)
    metrics_df.to_csv(f'{pars.checkpoint_folder}/metrics_checkpoints.csv', index=False)