-
Notifications
You must be signed in to change notification settings - Fork 64
/
engine_for_pretraining.py
178 lines (157 loc) · 7.39 KB
/
engine_for_pretraining.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
import math
import sys
from typing import Iterable
import torch
from einops import rearrange
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import utils
def _build_reconstruction_targets(images, mean, std, patch_size, tubelet_size,
                                  normalize):
    """Turn a normalized video batch into per-patch pixel regression targets.

    ``images`` is un-normalized with ``mean``/``std`` and cut into
    ``tubelet_size x patch_size x patch_size`` cubes, giving a tensor of shape
    ``(B, num_patches, tubelet_size * patch_size**2 * C)``. When ``normalize``
    is true, each cube is standardized per channel (mean and unbiased std over
    the cube's pixels) before being flattened.
    """
    # Assumes the loader normalized images with these ImageNet statistics, so
    # this maps back to pixel space (presumably [0, 1]) — TODO confirm.
    unnorm_images = images * std + mean
    if normalize:
        patches = rearrange(
            unnorm_images,
            'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c',
            p0=tubelet_size,
            p1=patch_size,
            p2=patch_size)
        patches = (patches - patches.mean(dim=-2, keepdim=True)) / (
            patches.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
        return rearrange(patches, 'b n p c -> b n (p c)')
    return rearrange(
        unnorm_images,
        'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)',
        p0=tubelet_size,
        p1=patch_size,
        p2=patch_size)


def _masked_reconstruction_loss(model, images, bool_masked_pos,
                                decode_masked_pos, labels):
    """Run the model and return the MSE over masked, decoded tokens.

    The model is called as ``model(images, bool_masked_pos, decode_masked_pos)``
    and its output is compared element-wise with ``labels``. The per-token
    squared error (mean over the last dim) is averaged only over decoded
    tokens (``~decode_masked_pos``) that were also masked for the encoder
    (``bool_masked_pos``).
    """
    outputs = model(images, bool_masked_pos, decode_masked_pos)
    per_token = ((outputs - labels) ** 2).mean(dim=-1)
    batch_size = labels.shape[0]
    # Among the tokens the decoder reconstructs, only encoder-masked ones count.
    loss_mask = bool_masked_pos[~decode_masked_pos].reshape(batch_size, -1)
    return (per_token * loss_mask).sum() / loss_mask.sum()


def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    loss_scaler,
                    max_norm: float = 0,
                    patch_size: int = 16,
                    normlize_target: bool = True,
                    log_writer=None,
                    lr_scheduler=None,
                    start_steps=None,
                    lr_schedule_values=None,
                    wd_schedule_values=None,
                    tubelet_size: int = 2):
    """Run one epoch of masked video reconstruction pre-training.

    Each batch from ``data_loader`` is unpacked as
    ``(images, bool_masked_pos, decode_masked_pos)``. ``images`` must be a
    5-D ``(B, C, T, H, W)`` tensor (the layout the ``rearrange`` patterns
    require), assumed normalized with ImageNet mean/std. The model regresses
    the pixels of the tokens selected by ``~decode_masked_pos``. The MSE is
    averaged only over those tokens that are also set in ``bool_masked_pos``.

    Args:
        model: Network to train; switched to train mode on entry.
        data_loader: Iterable of ``(images, bool_masked_pos,
            decode_masked_pos)`` tuples.
        optimizer: Optimizer whose param groups carry ``lr`` and
            ``weight_decay`` and optionally ``lr_scale`` (default 1.0).
        device: Device the batch tensors are moved to.
        epoch: Epoch index, used only in the log header.
        loss_scaler: ``None`` for a plain fp32 step. Otherwise the forward
            pass runs under ``torch.cuda.amp.autocast``. The scaler is then
            called as ``loss_scaler(loss, optimizer, clip_grad=...,
            parameters=..., create_graph=...)``; its return value is logged
            as the grad norm and ``state_dict()["scale"]`` is logged.
        max_norm: Gradient-clipping threshold. ``None`` or a value ``<= 0``
            disables clipping. On the fp32 path the norm is then only
            measured via ``utils.get_grad_norm_``.
        patch_size: Spatial patch side length in pixels.
        normlize_target: If true, standardize each patch per channel before
            using it as a target (misspelled name kept for compatibility).
        log_writer: Optional writer with ``update(..., head=...)`` and
            ``set_step()``, called once per step.
        lr_scheduler: Optional object whose ``step_update(global_step)`` is
            called after every step.
        start_steps: Global step of this epoch's first batch; ``None`` is 0.
        lr_schedule_values: Optional sequence indexed by global step giving
            the base LR, multiplied by each group's ``lr_scale``.
        wd_schedule_values: Optional sequence indexed by global step giving
            the weight decay. It is applied only to groups whose current
            weight decay is > 0.
        tubelet_size: Temporal patch length in frames (formerly fixed at 2).

    Returns:
        Dict mapping each metric name to its ``global_avg``, read after
        ``metric_logger.synchronize_between_processes()``.

    Note:
        Calls ``sys.exit(2)`` if the loss becomes non-finite.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    if start_steps is None:
        start_steps = 0
    # clip_grad_norm_(params, 0) would scale every gradient to zero (and a
    # negative threshold would flip them), so non-positive means "no clipping".
    clip_grad = max_norm if max_norm is not None and max_norm > 0 else None

    # ImageNet statistics broadcastable over (B, C, T, H, W); constant for the
    # whole epoch, so built once instead of per step.
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[
        None, :, None, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[
        None, :, None, None, None]

    for step, batch in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        it = start_steps + step  # global training iteration

        # Assign learning rate & weight decay for this step.
        if lr_schedule_values is not None or wd_schedule_values is not None:
            for param_group in optimizer.param_groups:
                if lr_schedule_values is not None:
                    param_group["lr"] = (lr_schedule_values[it] *
                                         param_group.get("lr_scale", 1.0))
                if (wd_schedule_values is not None and
                        param_group["weight_decay"] > 0):
                    param_group["weight_decay"] = wd_schedule_values[it]

        # NOTE: When the decoder mask ratio is 0,
        # in other words, when decoder masking is not used,
        # decode_masked_pos = ~bool_masked_pos
        images, bool_masked_pos, decode_masked_pos = batch
        images = images.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(
            device, non_blocking=True).flatten(1).to(torch.bool)
        decode_masked_pos = decode_masked_pos.to(
            device, non_blocking=True).flatten(1).to(torch.bool)

        with torch.no_grad():
            # Build the pixel targets for the tokens the decoder predicts.
            images_patch = _build_reconstruction_targets(
                images, mean, std, patch_size, tubelet_size, normlize_target)
            batch_size, _, target_dim = images_patch.shape
            labels = images_patch[~decode_masked_pos].reshape(
                batch_size, -1, target_dim)

        if loss_scaler is None:
            loss = _masked_reconstruction_loss(model, images, bool_masked_pos,
                                               decode_masked_pos, labels)
        else:
            with torch.cuda.amp.autocast():
                loss = _masked_reconstruction_loss(
                    model, images, bool_masked_pos, decode_masked_pos, labels)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(2)

        optimizer.zero_grad()
        if loss_scaler is None:
            loss.backward()
            if clip_grad is None:
                grad_norm = utils.get_grad_norm_(model.parameters())
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), clip_grad)
            optimizer.step()
            loss_scale_value = 0
        else:
            # this attribute is added by timm on one optimizer (adahessian)
            is_second_order = getattr(optimizer, 'is_second_order', False)
            grad_norm = loss_scaler(
                loss,
                optimizer,
                clip_grad=clip_grad,
                parameters=model.parameters(),
                create_graph=is_second_order)
            loss_scale_value = loss_scaler.state_dict()["scale"]

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        metric_logger.update(loss_scale=loss_scale_value)
        group_lrs = [group["lr"] for group in optimizer.param_groups]
        min_lr, max_lr = min(group_lrs), max(group_lrs)
        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        # Report the weight decay of the last group that actually decays.
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")
            log_writer.set_step()

        if lr_scheduler is not None:
            lr_scheduler.step_update(it)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}