-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
93 lines (76 loc) · 3.05 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
from copy import copy
import sys
import importlib
import os
import random
import numpy as np
import torch
import timm
# Print tensors in fixed-point rather than scientific notation; this is a
# global torch print setting for the whole process.
torch.set_printoptions(sci_mode=False)
def set_seed(seed=1234, *, deterministic=False):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Integer seed applied to ``random``, NumPy and PyTorch.
        deterministic: When False (the default, preserving the original
            behaviour) cuDNN autotuning is enabled for speed, so GPU results
            are not guaranteed to be bit-for-bit reproducible even with a
            fixed seed. When True, cuDNN is put in deterministic mode and
            autotuning is disabled.
    """
    random.seed(seed)
    # Hash randomisation of the *running* interpreter is fixed at startup;
    # this only affects Python child processes launched after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
def _coerce_override(current, raw):
    """Convert the command-line string ``raw`` to the type of ``current``.

    Booleans are parsed textually (only the exact string ``'True'`` yields
    True) because ``bool('False')`` would be True. A ``None`` default carries
    no type information, so the raw string is kept unchanged. Any other type
    is built by calling it on the string, e.g. ``int('3')``.
    """
    cfg_type = type(current)
    if cfg_type == bool:
        return raw == 'True'
    if cfg_type == type(None):
        return raw
    return cfg_type(raw)


def _apply_overrides(cfg, overrides):
    """Apply ``{dotted_key: raw_string}`` overrides to ``cfg`` in place.

    A dotted key such as ``"a.b"`` walks nested config objects through their
    ``__dict__``; every parent and the final attribute must already exist
    (a missing one raises KeyError). Each nested object on the path is
    shallow-copied before modification so the object owned by the imported
    config module is not mutated. Unknown top-level keys are skipped with a
    warning.
    """
    for key, raw in overrides.items():
        *parents, leaf = key.split(".")
        target = cfg
        for part in parents:
            # Copy-on-write: cfg itself is only a shallow copy, so its nested
            # objects are still shared with the config module.
            child = copy(target.__dict__[part])
            target.__dict__[part] = child
            target = child
        if not parents and leaf not in target.__dict__:
            print(f'WARNING: cfg has no attribute {key!r}; override ignored')
            continue
        print(f'overwriting cfg.{key}: {target.__dict__[leaf]} -> {raw}')
        target.__dict__[leaf] = _coerce_override(target.__dict__[leaf], raw)
        print(target.__dict__[leaf])


def parse_args():
    """Build the run configuration from the command line.

    Recognised options:
        -C/--config  name of a module under ``src.configs`` exposing ``cfg``
        -G/--gpu_id  value written to ``CUDA_VISIBLE_DEVICES`` (all GPUs if empty)

    Every other argument must be a ``key=value`` override; leading dashes on
    the key are stripped, so ``--lr=1e-4`` and ``lr=1e-4`` are equivalent.
    Values are cast to the type of the existing config entry, and dotted keys
    reach into nested config objects. A malformed argument exits via
    ``argparse``'s usage error.

    Side effects: may set ``CUDA_VISIBLE_DEVICES``, seeds RNGs through
    ``set_seed``, and prints the config name, applied overrides and seed.

    Returns:
        A shallow copy of ``src.configs.<config>.cfg`` with overrides applied.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-C", "--config", help="config filename", default="cfg_stage2_s2_sp1")
    parser.add_argument("-G", "--gpu_id", default="", help="GPU ID")
    # argparse parses sys.argv[1:] by default; anything it does not recognise
    # is treated as a config override below.
    parser_args, other_args = parser.parse_known_args()

    # Use all GPUs unless specified
    if parser_args.gpu_id != "":
        os.environ['CUDA_VISIBLE_DEVICES'] = str(parser_args.gpu_id)

    # Load CFG
    cfg = copy(importlib.import_module('src.configs.{}'.format(parser_args.config)).cfg)
    cfg.config_file = parser_args.config
    print("config ->", cfg.config_file)

    # Overwrite other arguments. Split on the first '=' only so that values
    # may themselves contain '='.
    overrides = {}
    for item in other_args:
        name, sep, value = item.partition("=")
        if not sep:
            parser.error(
                f"unrecognized argument {item!r}; config overrides must look like key=value"
            )
        overrides[name.lstrip("-")] = value
    _apply_overrides(cfg, overrides)

    # Set seed (a negative seed requests a random one)
    if cfg.seed < 0:
        cfg.seed = np.random.randint(1_000_000)
    print("seed", cfg.seed)
    set_seed(cfg.seed)

    # Quick development run
    if cfg.fast_dev_run:
        cfg.epochs = 1
        # NOTE(review): None is an unusual value for a flag named no_wandb
        # (True would suggest "disable W&B") — confirm against the trainer.
        cfg.no_wandb = None
    return cfg
if __name__ == "__main__":
    cfg = parse_args()
    # Import only the training module for the selected project.
    if cfg.project == "rsna":
        from src.modules.train_stage2 import train
    elif cfg.project == "rsna_localizer":
        from src.modules.train_stage1 import train
    else:
        # Previously an unknown project silently did nothing and exited 0.
        raise ValueError(
            f"Unknown cfg.project {cfg.project!r}; expected 'rsna' or 'rsna_localizer'"
        )
    train(cfg)