-
Notifications
You must be signed in to change notification settings - Fork 0
/
Torch_Custom_CNNs2.2.1.py
executable file
·164 lines (137 loc) · 7.64 KB
/
Torch_Custom_CNNs2.2.1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import argparse
import sys
import yaml
import os
import wandb
import pprint
import pickle
parser = argparse.ArgumentParser('encoder decoder examiner')
parser.add_argument('--model_name', type=str, default='test',
help='save name for model')
parser.add_argument('--project_name', type=str, default='test',
help='Name for wandb project')
parser.add_argument('--run_name', type=str, default=None,
help='Name for wandb run')
parser.add_argument('--sweep_id', type=str, default=None,
help='sweep if for weights and biases')
parser.add_argument('--WANDB_MODE', type=str, default='online',
help='WANDB_MODE for running offline')
parser.add_argument('--entity', type=str, default=None,
help='WANDB entity')
parser.add_argument('--sweep_config', type=str, default=None,
help='.yml sweep configuration file')
parser.add_argument('--model_config', type=str, default=None,
help='.yml model configuration file')
parser.add_argument('--sweep_count', type=int, default=100,
help='Number of models to train in sweep')
parser.add_argument('--root', type=str, default='/local/scratch/jrs596/dat/',
help='location of all data')
parser.add_argument('--data_dir', type=str, default='test',
help='location of all data')
parser.add_argument('--save', type=str, default=None,
help='save "model", "weights" or "both" ?')
parser.add_argument('--weights', type=str, default=None,
help='location of pre-trained weights')
parser.add_argument('--quantise', action='store_true', default=False,
help='Train with Quantization Aware Training?')
parser.add_argument('--ema_beta', type=float, default=None,
help='beta value for exponential moving average of weights')
parser.add_argument('--batch_size', type=int, default=21,
help='Initial batch size')
parser.add_argument('--max_epochs', type=int, default=200,
help='n epochs before early stopping')
parser.add_argument('--min_epochs', type=int, default=10,
help='n epochs before loss is assesed for early stopping')
parser.add_argument('--patience', type=int, default=1,
help='n epochs to run without improvment in loss')
parser.add_argument('--beta', type=float, default=1.00,
help='minimum required per cent improvment in validation loss')
parser.add_argument('--learning_rate', type=float, default=1e-3,
help='Learning rate, Default:1e-5')
parser.add_argument('--l1_lambda', type=float, default=1e-5,
help='l1_lambda for regularization, Default:1e-5')
parser.add_argument('--weight_decay', type=float, default=1e-4,
help='Learning rate, Default:1e-5')
parser.add_argument('--eps', type=float, default=1e-6,
help='eps, Default:1e-8')
parser.add_argument('--batchnorm_momentum', type=float, default=1e-1,
help='Batch norm momentum hyperparameter for resnets, Default:1e-1')
parser.add_argument('--input_size', type=int, default=224,
help='image input size')
parser.add_argument('--delta', type=float, default=1.4,
help='delta for dynamic focal loss')
parser.add_argument('--arch', type=str, default='resnet18',
help='Model architecture. resnet18, resnet50, resnext50, resnext101 or convnext_tiny')
parser.add_argument('--cont_train', action='store_true', default=False,
help='Continue training from previous checkpoint?')
parser.add_argument('--remove_batch_norm', action='store_true', default=False,
help='Deactivate all batchnorm layers?')
parser.add_argument('--split_image', action='store_true', default=False,
help='Split image into smaller chunks?')
parser.add_argument('--n_tokens', type=int, default=4,
help='Sqrt of number of tokens to split image into')
parser.add_argument('--criterion', type=str, default='crossentropy',
help='Loss function to use. DFLOSS or crossentropy')
parser.add_argument('--GPU', type=str, default='0',
help='Which GPU device to use')
args = parser.parse_args()
print(args)
sys.path.append(os.path.join(os.getcwd(), 'CocoaReader/utils'))
import toolbox
from training_loop import train_model
from collections import OrderedDict
def train():
toolbox.SetSeeds(42)
wandb.init(project=args.project_name)
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
wandb.save(os.path.join(script_dir, '*'))
data_dir, num_classes, initial_bias, _ = toolbox.setup(args)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = toolbox.build_model(num_classes=num_classes, arch=args.arch, config=None).to(device)
if args.weights is not None:
print("Loading pretrained weights from: ", args.weights, "\n")
# Load the pretrained weights using pickle
with open(args.weights, 'rb') as f:
pretrained_weights = pickle.load(f)
# If CUDA is not available, map the tensors to CPU
if not torch.cuda.is_available():
pretrained_weights['model'] = {k: v.cpu() if torch.is_tensor(v) else v for k, v in pretrained_weights['model'].items()}
new_state_dict = OrderedDict()
# Loop through the original state dictionary and remove 'module.'
for k, v in pretrained_weights['model'].items():
name = k[7:] # Remove 'module.' prefix
new_state_dict[name] = v
# Update the model weights
model.load_state_dict(new_state_dict)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate,
weight_decay=args.weight_decay, eps=args.eps)
image_datasets = toolbox.build_datasets(data_dir=data_dir, input_size=args.input_size) #If images are pre compressed, use input_size=None, else use input_size=args.input_size
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batch_size, shuffle=True, num_workers=6, worker_init_fn=toolbox.worker_init_fn, drop_last=False) for x in ['train', 'val']}
criterion = nn.CrossEntropyLoss()
trained_model, best_f1, best_f1_loss, best_train_f1, run_name, _, _ = train_model(args=args, model=model, optimizer=optimizer, device=device, dataloaders_dict=dataloaders_dict, criterion=criterion, patience=args.patience, initial_bias=initial_bias, input_size=None, batch_size=args.batch_size)
return trained_model, best_f1, best_f1_loss, best_train_f1
if args.sweep_config != None:
with open(args.sweep_config) as file:
config = yaml.load(file, Loader=yaml.FullLoader)
sweep_config = config['sweep_config']
sweep_config['metric'] = config['metric']
sweep_config['parameters'] = config['parameters']
print('Sweep config:')
pprint.pprint(sweep_config)
if args.sweep_id is None:
sweep_id = wandb.sweep(sweep=sweep_config, project=args.project_name, entity=args.entity)
else:
sweep_id = args.sweep_id
print("Sweep ID: ", sweep_id)
print()
wandb.agent(sweep_id,
project=args.project_name,
function=train,
count=args.sweep_count)
else:
train()