efold_training.py
import os
import sys

sys.path.append(os.path.abspath("."))

import torch
import wandb
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies import DDPStrategy

from efold import DataModule, create_model
from efold.config import device
from efold.core.callbacks import ModelCheckpoint  # , WandbTestLogger

# use float32 only
torch.set_default_dtype(torch.float32)
# Train loop
if __name__ == "__main__":
    USE_WANDB = True
    STRATEGY = "random"  # set to "ddp" for multi-GPU training
    n_gpu = 1

    print(f"Running on device: {device}")

    if USE_WANDB:
        wandb_logger = WandbLogger(project="family_split", name="eFold_GroupI_test")

    # fit loop
    batch_size = 1
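    # With batch_size=1 and accumulate_grad_batches=32 (set on the Trainer
    # below), the effective batch size is 32 sequences per optimizer step.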
    dm = DataModule(
        name=[
            # "RNAStralign_Group_I_intron",  # held out, see external_valid below
            "RNAStralign_5S",
            "RNAStralign_telomerase",
            "RNAStralign_SRP",
            "RNAStralign_tmRNA",
            "RNAStralign_RNaseP",
            "RNAStralign",
            "RNAStralign_16S",
            "RNAStralign_tRNA",
        ],
        strategy=STRATEGY,
        shuffle_train=STRATEGY != "ddp",
        data_type=["structure"],
        force_download=False,
        batch_size=batch_size,
        max_len=1000,
        min_len=1,
        structure_padding_value=0,
        train_split=None,
        external_valid=[
            "RNAStralign_Group_I_intron",
            "RNAStralign_validation",
        ],
    )
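    # The Group I intron family is excluded from the training sets above and
    # evaluated only through external_valid, so validation presumably measures
    # generalization to an unseen RNA family (hence the "family_split" project).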
    model = create_model(
        model="efold",
        ntoken=5,
        d_model=64,
        c_z=32,
        d_cnn=64,
        num_blocks=4,
        no_recycles=0,
        dropout=0,
        lr=1e-3,
        weight_decay=0,
        gamma=0.995,
        wandb=USE_WANDB,
    )
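    # no_recycles=0 presumably runs a single forward pass per example with no
    # recycling of intermediate representations (assumption based on the
    # parameter name; check efold's create_model docs for the exact meaning).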
    # Optionally resume from a pretrained checkpoint:
    # model.load_state_dict(
    #     torch.load(
    #         "/root/eFold/models_eFold_UFold/eFold_V2_PT10-15+FT_epoch5.pt",
    #         map_location=torch.device(device),
    #     )
    # )

    if USE_WANDB:
        wandb_logger.watch(model, log="all")
    trainer = Trainer(
        accelerator=device,
        devices=n_gpu if STRATEGY == "ddp" else 1,
        strategy=DDPStrategy(find_unused_parameters=False)
        if STRATEGY == "ddp"
        else "auto",
        # precision="16-mixed",
        max_epochs=1000,
        log_every_n_steps=1,
        accumulate_grad_batches=32,
        use_distributed_sampler=STRATEGY != "ddp",
        logger=wandb_logger if USE_WANDB else None,
        callbacks=[
            LearningRateMonitor(logging_interval="epoch"),
            ModelCheckpoint(every_n_epoch=1),
        ]
        if USE_WANDB
        else [],
        enable_checkpointing=False,
    )
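    # Lightning's built-in checkpointing is disabled (enable_checkpointing=False);
    # saving is handled by efold's own ModelCheckpoint callback instead, and only
    # when wandb logging is on.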
    trainer.fit(model, datamodule=dm)
    # trainer.test(model, datamodule=dm)

    if USE_WANDB:
        wandb.finish()
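# Example launch (single process, single device):
#   python efold_training.py
# For multi-GPU training, set STRATEGY = "ddp" and n_gpu to the device count;
# the script then disables Lightning's distributed sampler and train shuffling,
# presumably so the DataModule's own "ddp" strategy can shard the data itself.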