forked from quangminhdinh/TrafficVLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
108 lines (89 loc) · 2.74 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
import os
import torch
import gc
from torch.utils.data import (
RandomSampler,
SequentialSampler,
DataLoader
)
from args import get_args_parser
from config import (
get_cfg_defaults,
convert_to_dict,
get_sig
)
from utils import fix_seed
from models import get_tokenizer, TrafficVLM
from dataset import (
WTSTrainDataset,
WTSValDataset,
wts_base_collate_fn
)
from solver import (
get_optimizer,
get_solver,
setup_logging
)
# Process-wide environment settings applied when this module is loaded.
# NOTE(review): these are assigned after the imports above have already run;
# any library that reads them only at import time would not see the values --
# confirm the ordering is intentional.
os.environ.update({
    'OPENBLAS_NUM_THREADS': '1',
    'TF_CPP_MIN_LOG_LEVEL': '3',
})
def main(args, cfg):
    """Build the data pipeline, model, optimizer and solver, then run the solver.

    Args:
        args: Parsed command-line arguments. Only ``args.experiment`` is read
            here; it names the experiment and its output sub-directory.
        cfg: Experiment configuration exposing the ``GLOB``, ``DATA``,
            ``MODEL`` and ``SOLVER`` sections accessed below.
    """
    # Release cached CUDA memory and collect garbage before allocating anything.
    torch.cuda.empty_cache()
    gc.collect()

    experiment_dir = os.path.join(cfg.GLOB.EXP_PARENT_DIR, args.experiment)
    setup_logging(experiment_dir)
    fix_seed(cfg.GLOB.SEED)
    device = torch.device(cfg.GLOB.DEVICE)

    tokenizer = get_tokenizer(cfg.MODEL.T5_PATH, cfg.DATA.NUM_BINS)

    train_set = WTSTrainDataset(
        cfg.DATA,
        tokenizer,
        cfg.MODEL.FEATURE_BRANCHES,
        cfg.SOLVER.TRAIN.DENOISING,
        cfg.SOLVER.TRAIN.PHASE_NOISE_DENSITY,
        cfg.MODEL.USE_LOCAL,
        cfg.MODEL.MAX_PHASES,
    )
    val_set = WTSValDataset(
        cfg.DATA, tokenizer, cfg.MODEL.USE_LOCAL, cfg.MODEL.MAX_PHASES
    )

    train_sampler = RandomSampler(train_set)
    val_sampler = SequentialSampler(val_set)

    # os.cpu_count() may return None when the CPU count cannot be determined,
    # and DataLoader requires an integer; fall back to 0 (load in the main
    # process) in that case.
    num_workers = os.cpu_count() or 0

    train_loader = DataLoader(
        train_set,
        batch_size=cfg.SOLVER.TRAIN.BATCH_SIZE,
        sampler=train_sampler,
        collate_fn=wts_base_collate_fn,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_set,
        # A falsy validation batch size falls back to the training batch size.
        batch_size=cfg.SOLVER.VAL.BATCH_SIZE or cfg.SOLVER.TRAIN.BATCH_SIZE,
        sampler=val_sampler,
        collate_fn=wts_base_collate_fn,
        num_workers=num_workers,
    )

    model = TrafficVLM(
        cfg.MODEL,
        tokenizer,
        cfg.DATA.NUM_BINS,
        cfg.DATA.MAX_FEATS,
        # True whenever a sub-feature is configured.
        cfg.DATA.SUB_FEATURE is not None,
    )
    model.to(device)

    # Set up optimizer: only parameters that require gradients are optimized.
    params_for_optimization = [p for p in model.parameters() if p.requires_grad]
    optimizer = get_optimizer(cfg.SOLVER.TRAIN.OPTIMIZER, params_for_optimization)

    # The config is converted to a dict and a signature is derived from it;
    # both are handed to the solver.
    hparams = convert_to_dict(cfg)
    signature = get_sig(hparams)

    solver = get_solver(
        cfg.SOLVER,
        args.experiment,
        signature=signature,
        local_dir=experiment_dir,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optim=optimizer,
        hparams=hparams,
        device=device,
    )
    solver.run()
if __name__ == "__main__":
    # Command-line interface: extend the project's shared argument parser.
    cli = argparse.ArgumentParser(parents=[get_args_parser()])
    args = cli.parse_args()

    # Experiment configuration: start from the defaults, overlay the YAML file
    # named after the experiment, then freeze before handing it to main().
    cfg = get_cfg_defaults()
    experiment_config = f"experiments/{args.experiment}.yml"
    cfg.merge_from_file(experiment_config)
    cfg.freeze()

    main(args, cfg)