# test_feat_logit_ood.py
"""
Test the OOD detection on both feature and logit space.
Suitable for methods such as VIM
"""
import pdb
import numpy as np
import matplotlib.pyplot as plt
import os
import argparse
import torch
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from ood_scores.get_scorers import get_scorer
from utils.argparser import OODArgs
from dataset.dataloaders import get_loaders_for_ood
from models.get_models import get_model
from utils.metrics import get_measures
from utils.utils import print_measures, load_features, get_feat_dims, _to_np, _to_tensor
np.random.seed(10)
# ==================== Prepare
# args
argparser = OODArgs()
args = argparser.get_args()
print(args)
# scorer
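# get_scorer returns a scoring object selected by args.score (e.g., "VIM", "Residual");
# it is expected to expose append_features / fit / cal_score, as used below.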
scorer = get_scorer(args.score, args)
# dataloaders
id_train_loader, id_test_loader, ood_loaders = get_loaders_for_ood(args)
# model and feature dims
net = get_model(arch=args.arch, args=args, load_file=args.load_file, strict=True)
net.eval()
if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
elif args.ngpu > 0:
    net.cuda()
# torch.cuda.manual_seed(1)
device = "cuda" if torch.cuda.is_available() else "cpu"
cudnn.benchmark = True # fire on all cylinders
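# embed_mode presumably flags architectures whose features come from a contrastive /
# CLIP-style projection head rather than the plain backbone penultimate layer.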
embed_mode = "supcon" in args.arch or "clip" in args.arch
proj_dim, featdims = get_feat_dims(args, net, embed_mode=embed_mode)
last_start_idx = featdims[:-1].sum()  # start index of the last feature block (to slice out the final feature only)
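# Illustrative example (hypothetical dims): with featdims = np.array([64, 128, 256, 512]),
# the concatenated feature is 960-D and last_start_idx = 448, so the final block's
# feature occupies feat[..., 448:448 + 512].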
# ==================== Load id_train features and prepare the scorer
# Load id-train feature and append
id_train_feat, id_train_logit, id_train_label = load_features(args, id=True, split="train", last_idx=last_start_idx,
                                                               featdim=featdims[-1], feat_space=0)
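# The scorer caches the ID training features / logits / labels and derives N (sample
# count), D (feature dim), and num_class from them, as printed below.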
scorer.append_features(id_train_feat, id_train_logit, id_train_label)
print(f"The ID data number: {scorer.N}; Feature dim: {scorer.D}; Class number: {scorer.num_class}")
if args.score in ["VIM", "Residual"]:
    # fetch model classifier params (unwrap DataParallel if needed)
    base_net = net.module if isinstance(net, torch.nn.DataParallel) else net
    if "resnet" in args.arch or ("vit" in args.arch and "clip" not in args.arch):
        params = list(base_net.fc.parameters())
        w = _to_np(params[0])
        b = _to_np(params[1])
    else:
        raise NotImplementedError(f"Classifier parameter extraction not implemented for arch {args.arch}")
    scorer.fit(w, b)
else:
    scorer.fit()
print("Scorer fitting done.")
# ==================== Load id test features and cal ood score
print("Calculating the in distribution OOD scores...")
id_test_feat, id_test_logit, id_test_label = load_features(args, id=True, split="test", last_idx=last_start_idx,
                                                            featdim=featdims[-1], feat_space=args.feat_space)
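# run_together scores the whole test split in one vectorized call; otherwise samples are
# scored one at a time under tqdm (slower, presumably to bound memory for large feature
# matrices or for scorers that only accept single samples).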
if args.run_together:
    id_scores = scorer.cal_score(id_test_feat, id_test_logit)
else:
    N = id_test_feat.shape[0]
    id_scores = []
    for i in tqdm(range(N)):
        id_score_this = scorer.cal_score(id_test_feat[i, :], id_test_logit[i, :])
        id_scores.append(id_score_this)
        # if i > 100:
        #     print(f"Max and min score: {max(id_scores)}, {min(id_scores)}")
        #     exit()
    id_scores = np.array(id_scores)
# ==================== Load ood test features, cal ood score, and evaluate ood_performance
aurocs, fprs = [], []
ood_names = [ood_loader.dataset.name for ood_loader in ood_loaders]
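# For each OOD set: load its cached features, score them, and compare against the ID
# scores. get_measures is assumed to return (AUROC, AUPR, FPR@95); the middle value is
# unused here.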
for ood_name in ood_names:
    print(f"\n\n{ood_name} OOD Detection")
    # load features
    ood_feat, ood_logit = load_features(args, id=False, ood_name=ood_name, last_idx=last_start_idx,
                                        featdim=featdims[-1], feat_space=args.feat_space)
    # calculate ood_scores
    if args.run_together:
        ood_scores = scorer.cal_score(ood_feat, ood_logit)
    else:
        N = ood_feat.shape[0]
        ood_scores = []
        for i in tqdm(range(N)):
            ood_score_this = scorer.cal_score(ood_feat[i, :], ood_logit[i, :])
            ood_scores.append(ood_score_this.squeeze())
        ood_scores = np.array(ood_scores)
    # evaluate - use all the OOD samples
    auroc_this, _, fpr_this = get_measures(id_scores, ood_scores)
    aurocs.append(auroc_this)
    fprs.append(fpr_this)
    # print per-dataset results
    print_measures(auroc=auroc_this, fpr=fpr_this, method_name=args.score)
# Mean performance: unweighted average over all OOD datasets
print("\n\n")
print("Mean results")
auroc = np.array(aurocs).mean()
fpr = np.array(fprs).mean()
print_measures(auroc=auroc, fpr=fpr, method_name=args.score)