# train.py (forked from ServiceNow/LCFCN)
import os
import torch
import losses
import numpy as np
import utils as ut
from torchvision import transforms
from datasets import dataset_dict
from models import model_dict


def train(dataset_name, model_name, metric_name, path_history, path_model,
          path_opt, path_best_model, reset=False):
    # SET SEED
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)

    # Train datasets
    transformer = ut.ComposeJoint(
        [ut.RandomHorizontalFlipJoint(),
         [transforms.ToTensor(), None],
         [transforms.Normalize(*ut.mean_std), None],
         [None, ut.ToLong()]])

    train_set = dataset_dict[dataset_name](split="train",
                                           transform_function=transformer)

    trainloader = torch.utils.data.DataLoader(train_set, batch_size=1,
                                              num_workers=0,
                                              drop_last=False,
                                              sampler=ut.RandomSampler(train_set))
    # Val datasets
    transformer = ut.ComposeJoint(
        [[transforms.ToTensor(), None],
         [transforms.Normalize(*ut.mean_std), None],
         [None, ut.ToLong()]])

    val_set = dataset_dict[dataset_name](split="val",
                                         transform_function=transformer)
    test_set = dataset_dict[dataset_name](split="test",
                                          transform_function=transformer)

    # Model
    model = model_dict[model_name](train_set.n_classes).cuda()
    opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=1e-5, weight_decay=0.0005)
    # Resume from a saved checkpoint if one exists, otherwise start fresh
    if os.path.exists(path_history) and not reset:
        history = ut.load_json(path_history)
        model.load_state_dict(torch.load(path_model))
        opt.load_state_dict(torch.load(path_opt))
        s_epoch = history["train"][-1]["epoch"]
        print("Resuming epoch...{}".format(s_epoch))
    else:
        history = {"train": [], "val": [], "test": [],
                   "model_name": model_name,
                   "dataset_name": dataset_name,
                   "path_model": path_model,
                   "path_opt": path_opt,
                   "path_best_model": path_best_model,
                   "best_val_epoch": -1, "best_val_mae": np.inf}
        s_epoch = 0
        print("Starting from scratch...")
    for epoch in range(s_epoch + 1, 1000):
        # %%%%%%%%%%% 1. TRAINING PHASE %%%%%%%%%%%%
        train_dict = ut.fit(model, trainloader, opt,
                            loss_function=losses.lc_loss,
                            epoch=epoch)

        # Update history
        history["trained_images"] = list(model.trained_images)
        history["train"] += [train_dict]

        # Save model, opt and history
        torch.save(model.state_dict(), path_model)
        torch.save(opt.state_dict(), path_opt)
        ut.save_json(path_history, history)

        # %%%%%%%%%%% 2. VALIDATION PHASE %%%%%%%%%%%%
        with torch.no_grad():
            val_dict = ut.val(model=model, dataset=val_set, epoch=epoch,
                              metric_name=metric_name)

            # Update history
            history["val"] += [val_dict]

            # Lower is better: keep the checkpoint with the best validation score
            if val_dict[metric_name] <= history["best_val_mae"]:
                history["best_val_epoch"] = epoch
                history["best_val_mae"] = val_dict[metric_name]
                torch.save(model.state_dict(), path_best_model)

            # Test model (for "penguins", skip testing during the first 50 epochs)
            if not (dataset_name == "penguins" and epoch < 50):
                test_dict = ut.val(model=model,
                                   dataset=test_set,
                                   epoch=epoch, metric_name=metric_name)
                history["test"] += [test_dict]

            ut.save_json(path_history, history)
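

if __name__ == "__main__":
    # Minimal usage sketch, not prescribed by this repo: the model/metric names
    # and output paths below are illustrative assumptions (valid keys live in
    # datasets.dataset_dict and models.model_dict); "penguins" is one of the
    # dataset names this script already refers to.
    os.makedirs("checkpoints", exist_ok=True)
    train(dataset_name="penguins",
          model_name="ResFCN",   # assumed key in model_dict
          metric_name="MAE",     # assumed metric key returned by ut.val
          path_history="checkpoints/history.json",
          path_model="checkpoints/model.pth",
          path_opt="checkpoints/opt.pth",
          path_best_model="checkpoints/best_model.pth",
          reset=False)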