pytorch-lightning/hydra-zen integration (#105)
* Updating trainer in line with Fabric convention
* Added new folders to .gitignore
* Added LightningDataModule to dataloader and updated base LoadingData to be more readable
* Created new sampling callback; EMA callback is still WIP
* Fixed line length in pre-commit
* Split networks.py into the building modules and the main UNet
* Created hydra-zen configs.py file
* Created UNet LightningModule
* Separated sampling and validation metrics
* Fixed an error where the timestep wasn't on the correct device during inference
* Removed the Fabric trainer script; may revisit in the future
* Resolved a GimmeMotifs error that occurs during sampling; old method commented out
* Added wandb outputs folder to .gitignore
* Added main Lightning trainer script that uses the hydra-zen config
* Changed step_ema epoch from 2000 to 100
* Added new encode_data.pkl file for reference
* Fixed pre-commit-config rebase bug
* Removed old training script and updated to the new Lightning setup
Showing 16 changed files with 1,058 additions and 46 deletions.
New EMA callback (still WIP per the commit message):
@@ -0,0 +1,57 @@
from typing import Any

import lightning as L
import torch

# Reference implementation:
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py


class EMA(L.Callback):
    def __init__(
        self,
        beta: float = 0.995,
    ) -> None:
        super().__init__()
        self.beta = beta
        self.step = 0
        self.step_start_ema = 100

    def on_train_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self.step_ema(pl_module.ema_model, pl_module.model)

    def update_model_average(
        self,
        ma_model: L.LightningModule,
        current_model: L.LightningModule,
    ) -> None:
        for current_params, ma_params in zip(
            current_model.parameters(), ma_model.parameters()
        ):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
        if old is None:
            return new
        # Exponential moving average: keep beta of the old weight,
        # blend in (1 - beta) of the new weight.
        return old * self.beta + (1 - self.beta) * new

    def step_ema(
        self,
        ema_model: L.LightningModule,
        model: L.LightningModule,
    ) -> None:
        # Before the warmup threshold, copy weights verbatim instead of averaging.
        if self.step < self.step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model) -> None:
        ema_model.load_state_dict(model.state_dict())
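For context, a minimal sketch of how this callback could be attached to a trainer. It assumes the LightningModule exposes `model` and `ema_model` attributes, as `on_train_batch_end` expects; the toy module and the `dnadiffusion.callbacks.ema` import path are illustrative, not part of this commit:

import copy

import lightning as L
import torch
from torch.utils.data import DataLoader, TensorDataset

from dnadiffusion.callbacks.ema import EMA  # assumed module path


class ToyModule(L.LightningModule):
    # Minimal module satisfying the callback's contract: it must
    # expose `model` and a deep-copied `ema_model`.
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Linear(4, 1)
        self.ema_model = copy.deepcopy(self.model)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.model(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=1e-2)


data = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8)
trainer = L.Trainer(max_epochs=1, callbacks=[EMA(beta=0.995)], logger=False)
trainer.fit(ToyModule(), data)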
New sampling callback (dnadiffusion/callbacks/sampling.py, per the import in configs.py below):
@@ -0,0 +1,66 @@
import lightning as L
from lightning.pytorch.utilities import rank_zero_only

from dnadiffusion.metrics.sampling_metrics import (
    compare_motif_list,
    generate_similarity_using_train,
    sampling_to_metric,
)


class Sample(L.Callback):
    def __init__(
        self,
        data_module: L.LightningDataModule,
        image_size: int,
        num_sampling_to_compare_cells: int,
    ) -> None:
        super().__init__()
        self.data_module = data_module
        self.image_size = image_size
        self.num_sampling_to_compare_cells = num_sampling_to_compare_cells

    def on_train_start(self, *args, **kwargs) -> None:
        # Cache the reference sequences and motif tables once training starts.
        self.X_train = self.data_module.X_train
        self.train_motifs = self.data_module.train_motifs
        self.test_motifs = self.data_module.test_motifs
        self.shuffle_motifs = self.data_module.shuffle_motifs
        self.cell_types = self.data_module.cell_types
        self.numeric_to_tag = self.data_module.numeric_to_tag

    @rank_zero_only
    def on_train_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        # Every 15 epochs, sample sequences and compare their motif
        # distributions against the train/test/shuffle references.
        if (trainer.current_epoch + 1) % 15 == 0:
            pl_module.eval()
            additional_variables = {
                "model": pl_module.model,
                "timesteps": pl_module.timesteps,
                "device": pl_module.device,
                "betas": pl_module.betas,
                "sqrt_one_minus_alphas_cumprod": pl_module.sqrt_one_minus_alphas_cumprod,
                "sqrt_recip_alphas": pl_module.sqrt_recip_alphas,
                "posterior_variance": pl_module.posterior_variance,
                "image_size": self.image_size,
            }

            synt_df = sampling_to_metric(
                self.cell_types,
                self.numeric_to_tag,
                additional_variables,
                int(self.num_sampling_to_compare_cells / 10),
            )
            seq_similarity = generate_similarity_using_train(self.X_train)
            train_kl = compare_motif_list(synt_df, self.train_motifs)
            test_kl = compare_motif_list(synt_df, self.test_motifs)
            shuffle_kl = compare_motif_list(synt_df, self.shuffle_motifs)
            pl_module.train()

            trainer.logger.log_metrics(
                {
                    "train_kl": train_kl,
                    "test_kl": test_kl,
                    "shuffle_kl": shuffle_kl,
                    "seq_similarity": seq_similarity,
                },
                step=trainer.global_step,
            )
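A sketch of wiring this callback up by hand, with constructor arguments mirroring the values registered in the configs file below (the training entrypoint itself is not part of this diff):

import lightning as L
import torchvision.transforms as T

from dnadiffusion.callbacks.sampling import Sample
from dnadiffusion.data.dataloader import LoadingDataModule

# Arguments mirror the hydra-zen LoadingData config below.
data_module = LoadingDataModule(
    input_csv="./dnadiffusion/data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt",
    subset_components=[
        "GM12878_ENCLB441ZZZ",
        "hESCT0_ENCLB449ZZZ",
        "K562_ENCLB843GMH",
        "HepG2_ENCLB029COU",
    ],
    load_saved_data=False,
    change_component_index=True,
    number_of_sequences_to_motif_creation=1000,
    transform=T.ToTensor(),
)
sample_cb = Sample(
    data_module=data_module,
    image_size=200,
    num_sampling_to_compare_cells=1000,
)
trainer = L.Trainer(callbacks=[sample_cb])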
New hydra-zen configs.py:
@@ -0,0 +1,101 @@
import lightning as L
import torch
import torchvision.transforms as T
from hydra.core.config_store import ConfigStore
from hydra_zen import MISSING, builds, make_custom_builds_fn
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

from dnadiffusion.callbacks.sampling import Sample
from dnadiffusion.data.dataloader import LoadingDataModule
from dnadiffusion.models.training_modules import UnetDiffusion
from dnadiffusion.models.unet import Unet as UnetBase

# Custom builds functions
sbuilds = make_custom_builds_fn(populate_full_signature=True)
pbuilds = make_custom_builds_fn(zen_partial=True, populate_full_signature=True)

# Transforms config; wrap in T.Compose if more transforms are added later:
# transforms = builds(T.Compose, [builds(T.ToTensor)])
transforms = builds(T.ToTensor)

# Loading data
LoadingData = builds(
    LoadingDataModule,
    input_csv="./dnadiffusion/data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt",
    subset_components=[
        "GM12878_ENCLB441ZZZ",
        "hESCT0_ENCLB449ZZZ",
        "K562_ENCLB843GMH",
        "HepG2_ENCLB029COU",
    ],
    load_saved_data=False,
    change_component_index=True,
    number_of_sequences_to_motif_creation=1000,
    transform=transforms,
    populate_full_signature=True,
)

# Diffusion model
Unet = builds(
    UnetBase,
    dim=200,
    init_dim=None,
    dim_mults=(1, 2, 4),
    channels=1,
    resnet_block_groups=4,
    learned_sinusoidal_dim=10,
    num_classes=10,
    output_attention=False,
)

# Optimizers (partial config, completed with model parameters at instantiation)
Adam = pbuilds(torch.optim.Adam, lr=1e-3)

# Lightning module
UnetConfig = builds(
    UnetDiffusion,
    model=Unet,
    lr=1e-3,
    timesteps=50,
    beta=0.995,
    optimizer=Adam,
)

# Callbacks
sample = builds(
    Sample,
    data_module=MISSING,
    image_size=200,
    num_sampling_to_compare_cells=1000,
)

wandb = builds(
    WandbLogger,
    project="dnadiffusion",
    notes="lightning",
)

checkpoint = builds(
    ModelCheckpoint,
    dirpath="dnadiffusion/checkpoints/",
    every_n_epochs=500,
)

# Lightning trainer
LightningTrainer = builds(
    L.Trainer,
    accelerator="cuda",
    strategy="ddp_find_unused_parameters_true",
    num_nodes=1,
    devices=8,
    max_epochs=10000,
    logger=wandb,
    callbacks=[checkpoint],
)

# Registering the builds
cs = ConfigStore.instance()

cs.store(group="data", name="LoadingData", node=LoadingData)
cs.store(group="model", name="Unet", node=UnetConfig)
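The commit message mentions a main Lightning trainer script that consumes this config; that script is not shown in this diff. As an illustration only, an entrypoint along these lines could instantiate the stored groups (the top-level config shape and the `dnadiffusion.configs` import path are assumptions):

import hydra
from hydra.core.config_store import ConfigStore
from hydra_zen import MISSING, instantiate, make_config

from dnadiffusion.configs import LightningTrainer  # assumed module path

# Top-level config pulling in the registered "data" and "model" groups.
Config = make_config(
    defaults=["_self_", {"data": "LoadingData"}, {"model": "Unet"}],
    data=MISSING,
    model=MISSING,
    trainer=LightningTrainer,
)

cs = ConfigStore.instance()
cs.store(name="config", node=Config)


@hydra.main(config_path=None, config_name="config", version_base="1.3")
def main(cfg) -> None:
    # hydra-zen recursively instantiates the nested builds() configs.
    data_module = instantiate(cfg.data)
    model = instantiate(cfg.model)
    trainer = instantiate(cfg.trainer)
    trainer.fit(model, data_module)


if __name__ == "__main__":
    main()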