# train.py
import argparse
import pandas as pd
import numpy as np
import os
from functools import partial
from datetime import datetime
import torch
from utils import generate_csv
import config
from trainer import feature_train
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
import random
import torch.backends.cudnn as cudnn
os.environ['WANDB_API_KEY'] = config.wandb_api_key
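
# The local `config` module is expected (inferred from the attribute accesses
# in this script) to define: wandb_api_key, model_dir, merge_feature,
# merge_feature_name, features, model_name, ray_tune, loss, NUM_EPOCH,
# with_system_path_csv, test_csv, train_csv, and train_dir.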


def seed_everything(seed=4242):
    """Fix random seeds for reproducible runs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    cudnn.deterministic = True
    cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


def train_worker(tune_config, train_df, test_df):
    """Train a single merged model or one model per feature.

    `tune_config` receives the hyperparameters sampled by Ray Tune when tuning
    is enabled (None for a plain run); it is not forwarded to feature_train here.
    """
    date = datetime.now().isoformat().replace(':', '-')
    model_dir = os.path.join(config.model_dir, date)
    os.makedirs(model_dir)
    if config.merge_feature:
        # Train one model over all 18 merged classes.
        feature_train(train_df, test_df, config.merge_feature_name, config.model_name, model_dir)
    else:
        # Train one model per feature (3 features -> 3 models).
        for feature in config.features:
            feature_train(train_df, test_df, feature, config.model_name, model_dir)


def main():
    seed_everything()
    train_df = pd.read_csv(config.with_system_path_csv)
    test_df = pd.read_csv(config.test_csv)
    if config.ray_tune:
        # Hyperparameter search space.
        ray_config = {
            "batch_size": tune.choice([2, 4, 8, 16]),
            "loss": tune.choice(config.loss),
        }
        # ASHA early-stopping scheduler.
        scheduler = ASHAScheduler(
            metric="f1_score",       # metric used to compare trials
            mode="max",              # higher f1_score is better
            max_t=config.NUM_EPOCH,  # maximum training iterations per trial
        )
        # Console reporter for trial progress.
        reporter = CLIReporter(
            # parameter_columns=["l1", "l2", "lr", "batch_size"],
            metric_columns=[
                "training_iteration",
                "loss",
                "accuracy",
                "f1_score",
            ]
        )
        # Run the hyperparameter search.
        result = tune.run(
            partial(train_worker, train_df=train_df, test_df=test_df),
            resources_per_trial={"cpu": 4, "gpu": 1},
            config=ray_config,
            num_samples=10,
            scheduler=scheduler,
            progress_reporter=reporter,
        )
        best_trial = result.get_best_trial("loss", "min", "last")
        print(f"Best trial config: {best_trial.config}")
    else:
        train_worker(None, train_df, test_df)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-g-path",
        dest="generate_path",
        action="store_true",
        default=False,
        required=False,
        help="Generate csv file with system path",
    )
    parser.add_argument(
        "-split-train",
        dest="split_train",
        action="store_true",
        default=True,
        required=False,
        help="Train with split features",
    )
    args = parser.parse_args()
    # Generate csv file for training
    if args.generate_path:
        generate_csv(config.train_csv, config.train_dir, config.with_system_path_csv)
    elif args.split_train:
        main()
    print('End Train!')
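
# Example invocations (assumed usage, inferred from the flags above):
#   python train.py -g-path   # regenerate the training csv with system paths
#   python train.py           # run training (split_train defaults to True)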