-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
116 lines (101 loc) · 3.53 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import os
import torch
from omegaconf import OmegaConf
from src.dataset.eyepacs import EyePACS
from src.dataset.utils import compute_label_dims
from src.generative_model.stylegan import StyleGAN2Model
from src.generative_model.trainer import create_trainer
from src.utils.utils import get_labels, load_yaml_config, make_exp_folder
def parse_args(argv=None):
    """Parse command-line options for the training script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse falls back to ``sys.argv[1:]`` —
            backward-compatible with the original zero-argument call,
            while making the parser testable in isolation.

    Returns:
        argparse.Namespace with a single ``train_config`` attribute:
        the path of the YAML training-config file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--train_config",
        type=str,
        help="name of yaml config file",
        default="configs/configs_train/test.yaml",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Load the YAML config, create the experiment directory, and wrap the
    # config in an OmegaConf object for attribute-style access.
    args = parse_args()
    config = load_yaml_config(config_filename=args.train_config)
    experiment_folder = make_exp_folder(config)
    config = OmegaConf.create(config)
    labels = get_labels(config)

    # Arguments shared verbatim between the train and validation datasets;
    # only split and subset differ between the two.
    shared_dataset_kwargs = dict(
        image_root_dir=config.data.image_root_dir,
        meta_factorized_path=config.data.meta_factorized_path,
        columns_mapping_path=config.data.columns_mapping_path,
        splits_dir=config.data.splits_dir,
        image_size=config.data.image_size,
        input_preprocessing=config.data.input_preprocessing,
        labels=labels,
        onehot_enc=False,
        filter_meta=config.data.filter_meta,
        ram=config.data.ram,
    )
    train_set = EyePACS(
        split="train",
        subset=config.data.train_subset,
        **shared_dataset_kwargs,
    )
    val_set = EyePACS(
        split="val",
        subset=config.data.val_subset,
        **shared_dataset_kwargs,
    )

    # Loader settings common to both splits.
    shared_loader_kwargs = dict(
        batch_size=config.data.batch_size,
        num_workers=config.data.num_workers,
        prefetch_factor=config.data.prefetch_factor,
        drop_last=True,
    )
    # Training loader shuffles and pins host memory for faster GPU transfer.
    train_dataloader = torch.utils.data.DataLoader(
        train_set,
        shuffle=True,
        pin_memory=True,
        **shared_loader_kwargs,
    )
    # NOTE(review): as in the original, the validation loader has no
    # pin_memory and uses drop_last=True, so trailing samples are skipped.
    val_dataloader = torch.utils.data.DataLoader(
        val_set,
        shuffle=False,
        **shared_loader_kwargs,
    )

    # Gradient-penalty weight — heuristic formula from original implementation,
    # scaled by image resolution and the global (per-step) batch size.
    lambda_gp = (
        0.0002
        * (config.data.image_size**2)
        / (config.data.batch_size * len(config.gpus))
    )

    cond_dims = compute_label_dims(train_set, config.data.conditional_labels)
    class_dims = compute_label_dims(train_set, config.data.classifier_labels)

    if len(cond_dims) > 0:
        # Build an empirical categorical distribution over the first metadata
        # column, used to sample conditioning labels.
        # NOTE(review): column 0 is hard-coded here regardless of which
        # conditional labels are configured — confirm this is intended.
        _, counts = torch.unique(train_set._meta[:, 0], return_counts=True)
        cond_distribution = torch.distributions.categorical.Categorical(
            probs=counts / counts.sum()
        )
    else:
        cond_distribution = None

    trainer, checkpoint_callback = create_trainer(config, experiment_folder)
    model = StyleGAN2Model(
        config,
        experiment_folder,
        lambda_gp,
        cond_dims,
        class_dims,
        cond_distribution,
    )
    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        ckpt_path=config.resume,
    )

    # Persist the best-checkpoint path for downstream evaluation scripts.
    best_ckpt_path = os.path.join(experiment_folder, "best_ckpt.txt")
    with open(best_ckpt_path, "w") as text_file:
        text_file.write(checkpoint_callback.best_model_path)