# train.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import time
import warnings
import click
import yaml
from glob import glob
import torch
import torch.distributed as dist
from utils.util import init_random_seed, set_random_seed
from utils.dist_util import get_dist_info, init_dist
from utils.logging import get_root_logger
import configs.ViTPose_base_coco_256x192 as b_cfg
import configs.ViTPose_large_coco_256x192 as l_cfg
import configs.ViTPose_huge_coco_256x192 as h_cfg
from models.model import ViTPose
from datasets.COCO import COCODataset
from utils.train_valid_fn import train_model
# Absolute directory containing this script, used as the base of the default
# work directory. `__file__` can be a relative path (e.g. `python train.py` on
# Python < 3.9); without abspath() dirname() would then return '' and any path
# built as f"{CUR_PATH}/..." would silently point at the filesystem root.
CUR_PATH = osp.dirname(osp.abspath(__file__))
def _next_session_dir(work_dir):
    """Return the path of the next numbered session directory in ``work_dir``.

    Entries whose names consist only of decimal digits are treated as earlier
    sessions; every other entry is ignored. Numbers are compared numerically
    (not as strings), so session 1000 correctly follows session 999. The first
    session is 1 and names are zero-padded to at least three characters.
    """
    previous = [int(name) for name in os.listdir(work_dir) if name.isdecimal()]
    session = max(previous, default=0) + 1
    return osp.join(work_dir, str(session).zfill(3))


@click.command()
@click.option('--config-path', type=click.Path(exists=True),
              default='config.yaml', required=True,
              help='train config file path')
@click.option('--model-name', type=str, default='b', required=True,
              help='[b: ViT-B, l: ViT-L, h: ViT-H]')
def main(config_path, model_name):
    """Train a ViTPose model selected by --model-name.

    The settings of the chosen model config module are extended with the keys
    of the YAML file given by --config-path, a new numbered session directory
    is created under the work directory, and training is started.

    Raises click.BadParameter for an unknown model name and ValueError when a
    YAML key is already defined by the model config.
    """
    model_configs = {'b': b_cfg, 'l': l_cfg, 'h': h_cfg}
    cfg = model_configs.get(model_name.lower())
    if cfg is None:
        # Fail here with a usage error instead of an AttributeError on None
        # further down.
        raise click.BadParameter(
            f"unknown model '{model_name}', expected one of "
            f"{sorted(model_configs)}",
            param_hint='--model-name')

    # Load config.yaml; an empty file parses to None and means "no overrides".
    with open(config_path, 'r') as f:
        cfg_yaml = yaml.load(f, Loader=yaml.SafeLoader) or {}

    # The YAML file may only add settings, never override those already
    # defined by the model config module.
    for k, v in cfg_yaml.items():
        if hasattr(cfg, k):
            raise ValueError(f"Already exsist {k} in config")
        setattr(cfg, k, v)

    # set cudnn_benchmark
    if cfg.cudnn_benchmark:
        torch.backends.cudnn.benchmark = True

    # Set work directory (session-level)
    if not hasattr(cfg, 'work_dir'):
        # osp.join (rather than string formatting) keeps the default relative
        # to the script even if CUR_PATH is an empty string.
        cfg.work_dir = osp.join(CUR_PATH, 'runs', 'train')
    os.makedirs(cfg.work_dir, exist_ok=True)

    session_dir = _next_session_dir(cfg.work_dir)
    os.makedirs(session_dir)
    cfg.work_dir = session_dir

    if cfg.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # init distributed env first, since logger depends on the dist info.
    if cfg.launcher == 'none':
        distributed = False
        if len(cfg.gpu_ids) > 1:
            # cfg is a module, so its settings are attributes; subscripting it
            # (cfg['gpu_ids']) would raise TypeError instead of warning.
            warnings.warn(
                f"We treat {cfg.gpu_ids} as gpu-ids, and reset to "
                f"{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in "
                "non-distribute training time.")
            cfg.gpu_ids = cfg.gpu_ids[0:1]
    else:
        distributed = True
        init_dist(cfg.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(session_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = {}

    # log some basic info
    logger.info(f'Distributed training: {distributed}')

    # set random seeds
    seed = init_random_seed(cfg.seed)
    logger.info(f"Set random seed to {seed}, "
                f"deterministic: {cfg.deterministic}")
    set_random_seed(seed, deterministic=cfg.deterministic)
    meta['seed'] = seed

    # Set model
    model = ViTPose(cfg.model)
    if cfg.resume_from:
        # NOTE: torch.load unpickles the file - only resume from trusted
        # checkpoints. map_location='cpu' lets a checkpoint saved on a GPU be
        # loaded on a host without that device; load_state_dict then copies
        # the values into the model's own parameters.
        checkpoint = torch.load(cfg.resume_from, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])

    # Set dataset: both splits share every option except the data version,
    # the train flag and whether scale augmentation is enabled.
    dataset_kwargs = dict(
        root_path=cfg.data_root,
        use_gt_bboxes=True,
        image_width=192,
        image_height=256,
        scale_factor=0.35,
        flip_prob=0.5,
        rotate_prob=0.5,
        rotation_factor=45.,
        half_body_prob=0.3,
        use_different_joints_weight=True,
        heatmap_sigma=3,
        soft_nms=False,
    )
    datasets_train = COCODataset(
        data_version="train_custom",
        is_train=True,
        scale=True,
        **dataset_kwargs
    )
    datasets_valid = COCODataset(
        data_version="valid_custom",
        is_train=False,
        scale=False,
        **dataset_kwargs
    )

    train_model(
        model=model,
        datasets_train=datasets_train,
        datasets_valid=datasets_valid,
        cfg=cfg,
        distributed=distributed,
        validate=cfg.validate,
        timestamp=timestamp,
        meta=meta
    )
# Run the click command only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()