forked from fire717/movenet.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
executable file
·34 lines (29 loc) · 1.04 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""
@Fire
https://github.com/fire717
"""
from lib import init, Data, MoveNet, Task
from config import cfg
from lib.utils.utils import arg_parser, update_tuner_cfg
import os, json
def _run_dir(cfg):
    """Return the per-run output directory ``<save_dir>/<label>`` for ``cfg``."""
    return os.path.join(cfg["save_dir"], cfg["label"])


def main(cfg):
    """Train MoveNet with ``cfg``, resuming from a previous run when possible.

    If ``<save_dir>/<label>/cfg.txt`` exists (this function writes that file
    on every run), its JSON contents are passed to ``update_tuner_cfg``
    together with ``cfg`` (presumably merging the saved settings back in),
    and ``cfg["ckpt"]`` is pointed at ``<save_dir>/<label>/newest.json`` so
    ``Task.modelLoad`` continues from it. The final ``cfg`` is then written
    to ``cfg.txt`` as JSON before training starts.

    Args:
        cfg: Mutable configuration mapping (as returned by ``arg_parser``).
            Must provide ``save_dir``, ``label``, ``num_classes`` and
            ``width_mult``; it must also provide ``ckpt`` unless a previous
            run's ``cfg.txt`` is found (then ``ckpt`` is set here). The
            mapping may be modified in place.
    """
    init(cfg)

    # Resume: a cfg.txt left by an earlier run restores that run's settings
    # and redirects weight loading to its most recent checkpoint.
    cfg_record = os.path.join(_run_dir(cfg), "cfg.txt")
    if os.path.exists(cfg_record):
        with open(cfg_record, "r", encoding="utf-8") as fh:
            saved_cfg = json.load(fh)
        update_tuner_cfg(cfg, saved_cfg)
        # The run directory is re-derived after the merge, as the original
        # code did, in case the saved settings changed save_dir/label.
        # NOTE(review): the checkpoint is named "newest.json"; confirm this is
        # the file name Task actually saves weights under.
        cfg["ckpt"] = os.path.join(_run_dir(cfg), "newest.json")

    model = MoveNet(num_classes=cfg["num_classes"],
                    width_mult=cfg["width_mult"],
                    mode='train')

    data = Data(cfg)
    train_loader, val_loader = data.getTrainValDataloader()

    run_task = Task(cfg, model)
    run_task.modelLoad(model_path=cfg["ckpt"])

    # Persist the effective configuration so a later invocation can resume.
    # The directory is created if missing, since open(..., "w") would fail.
    run_dir = _run_dir(cfg)
    os.makedirs(run_dir, exist_ok=True)
    with open(os.path.join(run_dir, "cfg.txt"), "w", encoding="utf-8") as fh:
        fh.write(json.dumps(cfg))

    run_task.train(train_loader, val_loader)
if __name__ == '__main__':
    # Command-line flags are layered over the defaults imported from
    # ``config``; the module-level ``cfg`` is rebound to that parsed result
    # and handed to the training entry point.
    cfg = arg_parser(cfg)
    main(cfg=cfg)