-
Notifications
You must be signed in to change notification settings - Fork 6
/
unsup_cls.py
165 lines (149 loc) · 7.08 KB
/
unsup_cls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import copy
import torch
import utils
import models
import argparse
import numpy as np
from sklearn import metrics
from munkres import Munkres
import torch.backends.cudnn as cudnn
from torchvision import transforms as pth_transforms
from loader import ImageFolder
from models.head import DINOHead
def eval_pred(label, pred, calc_acc=False):
    """Score a clustering ``pred`` against ground-truth ``label``.

    Returns a 4-tuple ``(nmi, ari, f, acc)``: normalized mutual information,
    adjusted Rand index, Fowlkes-Mallows score and, only when ``calc_acc`` is
    true, the accuracy obtained after matching clusters to labels with
    ``get_y_preds``. When ``calc_acc`` is false the accuracy slot is -1.
    """
    scores = (
        metrics.normalized_mutual_info_score(label, pred),
        metrics.adjusted_rand_score(label, pred),
        metrics.fowlkes_mallows_score(label, pred),
    )
    if calc_acc:
        matched = get_y_preds(label, pred, len(set(label)))
        return (*scores, metrics.accuracy_score(matched, label))
    return (*scores, -1)
def calculate_cost_matrix(C, n_clusters):
    """Build the assignment cost matrix from a confusion matrix.

    ``C[i, j]`` counts samples of true label ``i`` that fell in cluster ``j``
    (the layout produced by ``metrics.confusion_matrix(y_true, clusters)``).
    The returned float array has ``cost[j, i] = C[:, j].sum() - C[i, j]``:
    the number of samples in cluster ``j`` that would be wrong if cluster
    ``j`` were given label ``i``. Only the leading ``n_clusters`` rows and
    columns of ``C`` index the output, but each cluster total sums the whole
    column.
    """
    cost = np.zeros((n_clusters, n_clusters))
    for cluster in range(n_clusters):
        cluster_size = np.sum(C[:, cluster])
        # Row `cluster` holds the cost of mapping this cluster to each label.
        cost[cluster] = [cluster_size - C[label, cluster] for label in range(n_clusters)]
    return cost
def get_cluster_labels_from_indices(indices):
    """Collect the second element of every ``(row, col)`` pair in ``indices``.

    ``indices`` is the assignment list returned by ``Munkres().compute``. The
    result is a float array whose i-th entry is ``indices[i][1]``.
    """
    return np.array([pair[1] for pair in indices], dtype=float)
def get_y_preds(y_true, cluster_assignments, n_clusters):
    """
    Map every cluster id to a ground-truth label via an optimal 1:1 matching.

    A square confusion matrix is built over the sorted union of the values
    that occur in ``y_true`` and ``cluster_assignments``. The same position
    therefore denotes the same value on both axes. The Hungarian algorithm
    (``scipy.optimize.linear_sum_assignment``) then picks the one-to-one
    cluster -> label matching that maximizes the number of agreements.

    Args:
        y_true: 1-D array-like of ground-truth labels.
        cluster_assignments: 1-D array-like of cluster ids (e.g. arg-max of a
            classification head), same length as ``y_true``.
        n_clusters: kept for backward compatibility only. The matching is
            sized from the values that actually occur, so cluster ids outside
            ``range(n_clusters)`` and non-contiguous labels are handled.

    Returns:
        np.ndarray of floats, same length as ``cluster_assignments``, giving
        the true label each prediction was mapped to. Clusters that end up
        matched to a value absent from ``y_true`` always count as errors.
    """
    from scipy.optimize import linear_sum_assignment

    y_true = np.asarray(y_true)
    cluster_assignments = np.asarray(cluster_assignments)
    if cluster_assignments.size == 0:
        return np.zeros(0)

    # Index both axes by position in the shared sorted value set. This keeps
    # every cluster aligned with its own confusion-matrix column; shifting the
    # predictions by their minimum would misalign them whenever id 0 is never
    # predicted.
    values = np.union1d(y_true, cluster_assignments)
    size = len(values)
    label_pos = np.searchsorted(values, y_true)
    cluster_pos = np.searchsorted(values, cluster_assignments)

    confusion = np.zeros((size, size), dtype=np.int64)  # [label, cluster]
    np.add.at(confusion, (label_pos, cluster_pos), 1)

    # Maximizing agreements == minimizing the negated counts; rows are clusters.
    cluster_rows, label_cols = linear_sum_assignment(-confusion.T)
    cluster_to_label = np.zeros(size)
    cluster_to_label[cluster_rows] = values[label_cols]
    return cluster_to_label[cluster_pos]
@torch.no_grad()
def main_eval(args):
    """Build the model described by ``args``, load its weights and evaluate it.

    Prints the git revision and all arguments. Loads ``<args.data_path>/val``
    with a resize(256) / center-crop(224) / normalize transform. Builds the
    backbone named by ``args.arch``, wrapped with a ``DINOHead`` of
    ``args.out_dim`` outputs. Restores ``args.checkpoint_key`` from
    ``args.pretrained_weights``. Finally runs ``eval_unsup`` on the val loader.
    """
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    transform = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    valdir = os.path.join(args.data_path, "val")
    dataset_val = ImageFolder(valdir, transform=transform)
    # Shard the val set across ranks. eval_unsup passes predictions and labels
    # through utils.concat_all_gather (assumed to concatenate the tensor from
    # every rank). Without a sampler, each rank would process the full set,
    # and the gathered arrays would hold world_size copies of every image.
    # NOTE: DistributedSampler pads the last shard with a few repeated samples
    # when len(dataset_val) is not divisible by the world size.
    sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False)
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(f"Data loaded with {len(dataset_val)} val imgs.")

    # ============ building network ... ============
    if 'swin' in args.arch:
        # This script forces 4x4 patches for Swin, overriding --patch_size.
        args.patch_size = 4
        model = models.__dict__[args.arch](
            window_size=args.window_size,
            patch_size=args.patch_size,
            num_classes=0)
        embed_dim = model.num_features
    else:
        model = models.__dict__[args.arch](
            patch_size=args.patch_size,
            num_classes=0)
        embed_dim = model.embed_dim
    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    model = utils.MultiCropWrapper(model, DINOHead(
        embed_dim,
        args.out_dim,
        act='gelu'))
    model.cuda(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
    utils.restart_from_checkpoint(args.pretrained_weights, **{args.checkpoint_key: model})
    model.eval()

    # ============ evaluate unsup cls ... ============
    print("Evaluating unsupervised classification for val set...")
    eval_unsup(model, data_loader_val)
@torch.no_grad()
def eval_unsup(model, data_loader):
    """Print NMI / ARI / F / ACC of ``model``'s arg-max predictions on ``data_loader``.

    Each batch is moved to the GPU. The index of the largest model output
    along dim 1 is taken as the predicted cluster. Predictions and
    ground-truth labels are both passed through ``utils.concat_all_gather``
    and then scored with ``eval_pred(..., calc_acc=True)``.
    """
    logger = utils.MetricLogger(delimiter=" ")
    cluster_chunks = []
    target_chunks = []
    for images, targets in logger.log_every(data_loader, 10):
        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        _, cluster_ids = model(images).max(dim=1)
        # Gather predictions before labels on every rank, so the two stay in
        # the same order if concat_all_gather is a collective, as its name
        # suggests.
        cluster_chunks.append(utils.concat_all_gather(cluster_ids))
        target_chunks.append(utils.concat_all_gather(targets))

    def _to_numpy(chunks):
        # Stack per-batch tensors and bring them to host memory as numpy.
        return torch.cat(chunks).cpu().detach().numpy()

    pred_labels = _to_numpy(cluster_chunks)
    real_labels = _to_numpy(target_chunks)
    nmi, ari, fscore, adjacc = eval_pred(real_labels, pred_labels, calc_acc=True)
    print("NMI: {}, ARI: {}, F: {}, ACC: {}".format(nmi, ari, fscore, adjacc))
if __name__ == '__main__':
    parser = argparse.ArgumentParser('Unsupervised classification (clustering) evaluation on ImageNet')
    parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
    parser.add_argument('--pretrained_weights', default='', type=str, help="""Path to pretrained
        weights to evaluate. Set to `download` to automatically load the pretrained DINO from url.
        Otherwise the model is randomly initialized""")
    parser.add_argument('--arch', default='vit_small', type=str, choices=['vit_tiny', 'vit_small', 'vit_base',
        'vit_large', 'swin_tiny', 'swin_small', 'swin_base', 'swin_large', 'resnet50', 'resnet101'],
        help='Architecture.')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument('--window_size', default=7, type=int, help='Window size of the model.')
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key(s) to use in the checkpoint; pass a comma-separated list to '
             'evaluate several in turn (example: "teacher")')
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    parser.add_argument('--data_path', default='/path/to/imagenet/', type=str,
        help='Please specify path to the ImageNet data.')
    parser.add_argument("--out_dim", type=int, default=1000, help="out_dim")
    args = parser.parse_args()

    utils.init_distributed_mode(args)
    # Evaluate each requested checkpoint key on its own deep copy of args, so
    # in-place tweaks made by main_eval (e.g. patch_size for Swin) do not leak
    # into the next run.
    for checkpoint_key in args.checkpoint_key.split(','):
        print("Starting evaluating {}.".format(checkpoint_key))
        args_copy = copy.deepcopy(args)
        args_copy.checkpoint_key = checkpoint_key
        main_eval(args_copy)