This repository has been archived by the owner on Apr 28, 2024. It is now read-only.
forked from microsoft/esvit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference_utils.py
188 lines (167 loc) · 5.81 KB
/
inference_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import argparse
import os
import time
import torch
# from torchvision import models as torchvision_models
import utils
from models.vision_transformer import DINOHead
from models import build_model
from config import config
from config import update_config
from datasets import build_dataloader
import cv2
import json
def get_args_parser():
    """Build the command-line parser for EsViT inference.

    The parser is created with ``add_help=False``; nothing else in this
    function adds a ``-h/--help`` option.

    Returns:
        argparse.ArgumentParser: parser exposing the model, dataset and
        checkpoint options read by ``eval_esvit``, plus a trailing ``opts``
        remainder used to override config entries from the command line.
    """
    parser = argparse.ArgumentParser("EsViT", add_help=False)
    parser.add_argument(
        "--cfg",
        default="experiments/imagenet/swin/swin_small_patch4_window14_224.yaml",
        help="experiment configure file name",
        type=str,
    )

    # Model parameters
    parser.add_argument(
        "--arch",
        default="swin_small",
        type=str,
        choices=["swin_tiny", "swin_small", "swin_base", "swin_large", "swin"],
        # The previous text recommended deit_* values that `choices` rejects.
        help="Name of the architecture to evaluate.",
    )
    parser.add_argument(
        "--out_dim",
        default=65536,
        type=int,
        help=(
            "Dimensionality of\n"
            "the DINO head output. For complex and large datasets large values (like 65k) work well."
        ),
    )
    parser.add_argument(
        "--norm_last_layer",
        default=True,
        type=utils.bool_flag,
        help=(
            "Whether or not to weight normalize the last layer of the DINO head.\n"
            "Not normalizing leads to better performance but can make the training unstable.\n"
            "In our experiments, we typically set this parameter to False with deit_small "
            "and True with vit_base."
        ),
    )
    parser.add_argument(
        "--use_bn_in_head",
        default=False,
        type=utils.bool_flag,
        help="Whether to use batch normalizations in projection head (Default: False)",
    )
    parser.add_argument(
        "--use_dense_prediction",
        default=True,
        type=utils.bool_flag,
        # Help text now agrees with the actual default above.
        help="Whether to use dense prediction in projection head (Default: True)",
    )

    # Training/Optimization parameters
    parser.add_argument(
        "--batch_size_per_gpu",
        default=1,
        type=int,
        help="Per-GPU batch-size : number of distinct images loaded on one GPU.",
    )

    # Dataset
    parser.add_argument(
        "--dataset", default="imagenet1k", type=str, help="Pre-training dataset."
    )
    parser.add_argument(
        "--zip_mode",
        type=utils.bool_flag,
        default=False,
        help="""Whether or not to use zip file.""",
    )
    parser.add_argument(
        "--tsv_mode",
        type=utils.bool_flag,
        default=False,
        help="""Whether or not to use tsv file.""",
    )

    # Misc
    parser.add_argument(
        "--data_path",
        default="./tmp",
        type=str,
        help="Please specify path to the test data.",
    )
    parser.add_argument(
        "--pretrained_weights_ckpt",
        default="params/checkpoint_best.pth",
        type=str,
        help="Path to pretrained weights to evaluate.",
    )
    parser.add_argument("--seed", default=0, type=int, help="Random seed.")
    parser.add_argument(
        "--num_workers",
        default=0,
        type=int,
        help="Number of data loading workers per GPU.",
    )
    # Everything left on the command line is collected here as config overrides.
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
def eval_esvit(args):
    """Run the EsViT teacher network over the data loader and collect its outputs.

    Builds student and teacher networks from the global ``config``, attaches
    DINO heads to both, optionally restores them from
    ``args.pretrained_weights_ckpt``, and then feeds every batch produced by
    ``build_dataloader(args)`` through the teacher.

    Args:
        args: Parsed argument namespace. The attributes read here are
            ``seed``, ``arch``, ``out_dim``, ``use_bn_in_head``,
            ``norm_last_layer``, ``use_dense_prediction`` and
            ``pretrained_weights_ckpt``; ``update_config`` and
            ``build_dataloader`` receive the whole namespace.

    Returns:
        tuple: ``(out, outs)`` where ``out`` is the teacher's raw output for
        the last batch and ``outs`` is the last element of every batch's
        output concatenated along dimension 0.

    Raises:
        ValueError: If ``args.arch`` does not name a swin architecture.
        RuntimeError: If the data loader yields no batches.
    """
    print("Evaluating esvit")
    utils.fix_random_seeds(args.seed)
    print("git:\n {}\n".format(utils.get_sha()))
    print(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
    )

    # ============ preparing data ... ============
    data_loader = build_dataloader(args)

    # ============ building student and teacher networks ... ============
    if "swin" not in args.arch:
        # Only the swin branch constructs the networks; fail with a clear
        # message instead of an unbound-name error further down.
        raise ValueError(f"Unsupported architecture: {args.arch}")

    update_config(config, args)
    student = build_model(config, use_dense_prediction=args.use_dense_prediction)
    teacher = build_model(
        config, is_teacher=True, use_dense_prediction=args.use_dense_prediction
    )
    print(args.norm_last_layer)
    student.head = DINOHead(
        student.num_features,
        args.out_dim,
        use_bn=args.use_bn_in_head,
        norm_last_layer=args.norm_last_layer,
    )
    teacher.head = DINOHead(teacher.num_features, args.out_dim, args.use_bn_in_head)

    if args.use_dense_prediction:
        student.head_dense = DINOHead(
            student.num_features,
            args.out_dim,
            use_bn=args.use_bn_in_head,
            norm_last_layer=args.norm_last_layer,
        )
        teacher.head_dense = DINOHead(
            teacher.num_features, args.out_dim, args.use_bn_in_head
        )

    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ optionally restore weights from a checkpoint ... ============
    if args.pretrained_weights_ckpt:
        utils.restart_from_checkpoint(
            args.pretrained_weights_ckpt,
            student=student,
            teacher=teacher,
        )
        print(f"Resumed from {args.pretrained_weights_ckpt}")

    # Inference only: switch layers such as dropout / batch norm to their
    # evaluation behaviour before running the teacher.
    teacher.eval()

    out = None
    outs = []
    # perf_counter_ns is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments.
    t0 = time.perf_counter_ns()
    print(len(data_loader))
    with torch.no_grad():  # no autograd bookkeeping for stored outputs
        for img, _label in data_loader:
            out = teacher(img)
            for idx, val in enumerate(out):
                print(f"Type @ idx {idx}: {type(val)}")
            outs.append(out[-1])
    tf = time.perf_counter_ns()
    print(f"Time spend ns: {tf - t0}")

    if not outs:
        raise RuntimeError("Data loader produced no batches; nothing to evaluate.")

    outs = torch.cat(outs, dim=0)
    return out, outs