train.py
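"""Train an image classifier (part of a Knowledge Distillation pipeline).

Example invocation (a sketch; the dataset is assumed to live under
./data/train and ./data/test, organised as one sub-folder per class):

    python train.py --model efficientnet_v2_l --batch-size 16 --lr 1e-3 \
        --data-root ./data --save-dir ./weights --log-dir ./logs
"""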
import argparse
import os

import torch
from torchvision import transforms

from dataset import *
from engine import *
from models.model import build_model


def get_args_parser():
    parser = argparse.ArgumentParser('Set parameters for Knowledge Distillation training', add_help=False)
    parser.add_argument('--model', default='efficientnet_v2_l', type=str,
                        help='name of model to use')
    parser.add_argument('--lr', default=1e-3, type=float)
    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--batch-size', default=16, type=int)
    parser.add_argument('--num-epochs', default=1000, type=int)
    parser.add_argument('--data-root', default='./data', type=str)
    parser.add_argument('--step-eval-epoch', default=10, type=int)
    parser.add_argument('--save-dir', default='./weights', type=str)
    parser.add_argument('--log-dir', default='./logs', type=str)
    parser.add_argument('--resume', default=None, type=str)
    return parser


def main(args):
    print(args)
    # Class names are inferred from the sub-folder names under data_root/train.
    CLASS_TO_INDEX = class_to_index(os.path.join(args.data_root, 'train'))
    n_classes = len(CLASS_TO_INDEX)
    device = torch.device(args.device)
    if args.resume is not None:
        # Resume from a checkpoint instead of starting from pretrained weights.
        model = build_model(model_name=args.model, n_classes=n_classes, pretrained=False)
        state_dict = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(state_dict)
    else:
        model = build_model(model_name=args.model, n_classes=n_classes)
    # transform = model.get_transform()
    # ImageNet-style preprocessing: resize the shorter side first, then
    # center-crop (the crop size must not exceed the resized size).
    transform_train = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    transform_test = transforms.Compose([
        transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(380),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    dataset_train = CustomDataset(os.path.join(args.data_root, 'train'), transform=transform_train, mapping=CLASS_TO_INDEX)
    train_dataloader = get_dataloader(dataset_train, batch_size=args.batch_size)
    dataset_test = CustomDataset(os.path.join(args.data_root, 'test'), transform=transform_test, mapping=CLASS_TO_INDEX)
    test_dataloader = get_dataloader(dataset_test, batch_size=args.batch_size, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # Decay the learning rate by 10x every 30 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    model.to(device)

    min_loss = torch.inf
    best = None
    epochs = args.num_epochs
    criterion = torch.nn.CrossEntropyLoss()
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.save_dir, exist_ok=True)
    log_path = os.path.join(args.log_dir, 'log.txt')
    with open(log_path, 'w') as f:
        for epoch in range(epochs):
            # Training
            train_loss = train_one_epoch(model,
                                         train_dataloader,
                                         criterion,
                                         optimizer,
                                         device)
            scheduler.step()
            print('Epoch: {} - Train loss: {:.4f}'.format(epoch, train_loss))
            f.write('Epoch: {} - Train loss: {:.4f}\n'.format(epoch, train_loss))

            # Evaluation every step_eval_epoch epochs
            if epoch > 0 and epoch % args.step_eval_epoch == 0:
                eval_loss = eval(model,
                                 test_dataloader,
                                 criterion,
                                 device)
                print('Epoch: {} - Eval loss: {:.4f}'.format(epoch, eval_loss))
                f.write('Epoch: {} - Eval loss: {:.4f}\n'.format(epoch, eval_loss))
                # Keep the checkpoint with the lowest evaluation loss seen so far.
                if eval_loss < min_loss:
                    min_loss = eval_loss
                    best = model.state_dict()
                    save_path = os.path.join(args.save_dir, 'best.pth')
                    torch.save(best, save_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Training with Knowledge Distillation script', parents=[get_args_parser()])
    args = parser.parse_args()
    main(args)
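
# Loading the saved best checkpoint for inference (a sketch; n_classes must
# match the number of class folders the model was trained on):
#   model = build_model(model_name='efficientnet_v2_l', n_classes=n_classes, pretrained=False)
#   model.load_state_dict(torch.load('./weights/best.pth', map_location='cpu'))
#   model.eval()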