-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
64 lines (55 loc) · 2.12 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import hydra
import hydra.utils as hu
import pytorch_lightning as pl
@hydra.main(config_path="configs/", config_name="beats")
def main(cfg):
    """Build every training component from the Hydra config and run ``fit``.

    Config keys read:
      - ``seed`` (optional): passed to ``pl.seed_everything``.
      - ``features``, ``fe_model``, ``net``, ``optim``, ``criterion``,
        ``datamodule``, ``model``, ``model_checkpoint``, ``trainer``:
        each is instantiated with ``hydra.utils.instantiate``.
      - ``lr_scheduler`` (optional): instantiated with the optimizer as its
        first positional argument; ``None`` when absent.
      - ``trainer.profiler`` (optional): when truthy, an ``AdvancedProfiler``
        is created and replaces the config value.
      - ``logger`` (optional): instantiated logger; ``[]`` when absent.
      - ``callbacks`` (optional): mapping of callback configs; every value is
        instantiated, keys are ignored.
      - ``resume.ckpt_path`` (optional): checkpoint to resume training from.
      - ``experiment``: used as the profiler output filename when profiling.
    """
    if "seed" in cfg:
        pl.seed_everything(cfg.seed)

    feature_extractor = hu.instantiate(cfg.features)
    fe_model = hu.instantiate(cfg.fe_model)
    net = hu.instantiate(cfg.net, fe_model=fe_model)
    optimizer = hu.instantiate(cfg.optim, params=net.parameters())
    lr_scheduler = (
        hu.instantiate(cfg.lr_scheduler, optimizer)
        if "lr_scheduler" in cfg
        else None
    )
    criterion = hu.instantiate(cfg.criterion)
    datamodule = hu.instantiate(cfg.datamodule)
    model = hu.instantiate(
        cfg.model,
        net=net,
        feature_extractor=feature_extractor,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        criterion=criterion,
        datamodule=datamodule,
    )
    model_ckpt = hu.instantiate(cfg.model_checkpoint)

    # The config's ``trainer.profiler`` value only acts as an on/off switch:
    # the profiler object built here is passed as an override below.
    profiler = None
    if "profiler" in cfg.trainer and cfg.trainer.profiler:
        # The logger section is optional, so don't assume it exists. With
        # dirpath=None the profiler presumably falls back to the trainer's
        # log dir -- verify for the pinned Lightning version.
        dirpath = cfg.logger.save_dir if "logger" in cfg else None
        profiler = pl.profiler.AdvancedProfiler(dirpath=dirpath,
                                                filename=cfg.experiment)

    logger = hu.instantiate(cfg.logger) if "logger" in cfg else []
    callbacks = (
        [hu.instantiate(cb_cfg) for cb_cfg in cfg.callbacks.values()]
        if "callbacks" in cfg
        else []
    )

    # NOTE(review): passing a ModelCheckpoint instance as
    # ``checkpoint_callback`` and using ``resume_from_checkpoint`` are
    # older Lightning Trainer APIs; confirm against the pinned PL version.
    trainer_kwargs = {
        "checkpoint_callback": model_ckpt,
        "callbacks": callbacks,
        "logger": logger,
        "profiler": profiler,
    }
    resuming = "resume" in cfg
    if resuming:
        trainer_kwargs["resume_from_checkpoint"] = cfg.resume.ckpt_path
    trainer = hu.instantiate(cfg.trainer, **trainer_kwargs)
    if resuming:
        print("Resuming model checkpoint..")

    trainer.fit(model=model, datamodule=datamodule)


if __name__ == "__main__":
    # Hydra's decorator composes the config and supplies ``cfg``.
    main()