Commit 3b6c6eb

pytorch-lightning/hydra-zen integration (#105)
* Updating trainer in line with fabric convention

* Added new folders to gitignore

* Added LightningDataModule to dataloader and updated base LoadingData to be more readable

* created new sampling callback, ema callback is still wip

* fixed line-length in precommit

* split networks.py into the building modules and the main unet

* create hydra-zen configs.py file

* unet lightningmodule created

* separated sampling and validation metrics

* fixed an error where timestep wasn't on correct device during inference

* removing fabric trainer script, may revisit in the future

* resolving a gimmemotif error that occurs during sampling, old method commented out

* added wandb outputs folder to .gitignore

* added main lightning trainer script that uses hydra-zen config

* changing step_ema epoch from 2000 to 100

* adding new encode_data.pkl file for reference

* fix pre-commit-config rebase bug

* removing old training script/updating to new lightning setup
ssenan authored Mar 21, 2023
1 parent 86004a4 commit 3b6c6eb
Showing 16 changed files with 1,058 additions and 46 deletions.
10 changes: 9 additions & 1 deletion .gitignore
@@ -130,4 +130,12 @@ dmypy.json

# Adding logs folder
logs/*
!logs/.gitkeep

# Hydra outputs folder
outputs/*

# Checkpoints folder
dnadiffusion/checkpoints/*

# Wandb folder
wandb/*
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -15,7 +15,7 @@ repos:
     rev: 5.12.0
     hooks:
       - id: isort
-        args: ["--profile", "black", "line-length=88"]
+        args: ["--profile", "black", "line_length=88"]
   - repo: https://github.com/asottile/pyupgrade
     rev: v3.3.1
     hooks:
57 changes: 57 additions & 0 deletions dnadiffusion/callbacks/ema.py
@@ -0,0 +1,57 @@
from typing import Any

import lightning as L
import torch
from lightning.pytorch.utilities import rank_zero_only

# Check this link
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py


class EMA(L.Callback):
    def __init__(
        self,
        beta: float = 0.995,
    ) -> None:
        self.beta = beta
        self.step = 0
        self.step_start_ema = 100

    def on_train_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # Lightning also passes outputs, batch, and batch_idx to this hook;
        # they are absorbed by *args/**kwargs since only the module is needed.
        self.step_ema(pl_module.ema_model, pl_module.model)

    def update_model_average(
        self,
        ma_model: L.LightningModule,
        current_model: L.LightningModule,
    ) -> None:
        for current_params, ma_params in zip(
            current_model.parameters(), ma_model.parameters()
        ):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(
        self,
        ema_model: L.LightningModule,
        model: L.LightningModule,
    ) -> None:
        # Copy weights verbatim until the warmup threshold, then start averaging
        if self.step < self.step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())
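
The callback above maintains an exponential moving average of the model weights, updating theta_ema <- beta * theta_ema + (1 - beta) * theta after every training batch once a 100-step warmup has passed (weights are copied verbatim until then). A minimal usage sketch, not part of this diff, assuming the LightningModule exposes the `model` and `ema_model` attributes that the hook reads:

    import lightning as L

    from dnadiffusion.callbacks.ema import EMA

    # Sketch only: the trainer settings are placeholders, and
    # `unet_diffusion_module` stands in for an instantiated UnetDiffusion.
    trainer = L.Trainer(max_epochs=10000, callbacks=[EMA(beta=0.995)])
    # trainer.fit(unet_diffusion_module, datamodule=data_module)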
66 changes: 66 additions & 0 deletions dnadiffusion/callbacks/sampling.py
@@ -0,0 +1,66 @@
import lightning as L
import pandas as pd
from lightning.pytorch.utilities import rank_zero_only

from dnadiffusion.metrics.sampling_metrics import (
    compare_motif_list,
    generate_similarity_using_train,
    sampling_to_metric,
)


class Sample(L.Callback):
    def __init__(
        self,
        data_module: L.LightningDataModule,
        image_size: int,
        num_sampling_to_compare_cells: int,
    ) -> None:
        self.data_module = data_module
        self.image_size = image_size
        self.number_sampling_to_compare_cells = num_sampling_to_compare_cells

    def on_train_start(self, *args, **kwargs) -> None:
        self.X_train = self.data_module.X_train
        self.train_motifs = self.data_module.train_motifs
        self.test_motifs = self.data_module.test_motifs
        self.shuffle_motifs = self.data_module.shuffle_motifs
        self.cell_types = self.data_module.cell_types
        self.numeric_to_tag = self.data_module.numeric_to_tag

    @rank_zero_only
    def on_train_epoch_end(self, trainer: L.Trainer, L_module: L.LightningModule):
        if (trainer.current_epoch + 1) % 15 == 0:
            L_module.eval()
            additional_variables = {
                "model": L_module.model,
                "timesteps": L_module.timesteps,
                "device": L_module.device,
                "betas": L_module.betas,
                "sqrt_one_minus_alphas_cumprod": L_module.sqrt_one_minus_alphas_cumprod,
                "sqrt_recip_alphas": L_module.sqrt_recip_alphas,
                "posterior_variance": L_module.posterior_variance,
                "image_size": self.image_size,
            }

            synt_df = sampling_to_metric(
                self.cell_types,
                self.numeric_to_tag,
                additional_variables,
                int(self.number_sampling_to_compare_cells / 10),
            )
            seq_similarity = generate_similarity_using_train(self.X_train)
            train_kl = compare_motif_list(synt_df, self.train_motifs)
            test_kl = compare_motif_list(synt_df, self.test_motifs)
            shuffle_kl = compare_motif_list(synt_df, self.shuffle_motifs)
            L_module.train()

            trainer.logger.log_metrics(
                {
                    "train_kl": train_kl,
                    "test_kl": test_kl,
                    "shuffle_kl": shuffle_kl,
                    "seq_similarity": seq_similarity,
                },
                step=trainer.global_step,
            )
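
Every 15 epochs this callback samples a reduced set of synthetic sequences (num_sampling_to_compare_cells / 10), then logs the KL divergence between their motif counts and those of the train, test, and shuffled sets, along with a train-sequence similarity score. A construction sketch, with argument values mirroring the `LoadingData` and `sample` configs in dnadiffusion/configs.py below; `transform=None` is a simplifying assumption:

    from dnadiffusion.callbacks.sampling import Sample
    from dnadiffusion.data.dataloader import LoadingDataModule

    # Assumes LoadingDataModule exposes X_train, the *_motifs attributes,
    # cell_types, and numeric_to_tag, as on_train_start requires.
    data_module = LoadingDataModule(
        input_csv="./dnadiffusion/data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt",
        subset_components=[
            "GM12878_ENCLB441ZZZ",
            "hESCT0_ENCLB449ZZZ",
            "K562_ENCLB843GMH",
            "HepG2_ENCLB029COU",
        ],
        load_saved_data=False,
        change_component_index=True,
        number_of_sequences_to_motif_creation=1000,
        transform=None,  # assumption; the config uses builds(T.ToTensor)
    )
    sample_cb = Sample(
        data_module=data_module,
        image_size=200,
        num_sampling_to_compare_cells=1000,
    )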
101 changes: 101 additions & 0 deletions dnadiffusion/configs.py
@@ -0,0 +1,101 @@
import lightning as L
import torch
import torchvision.transforms as T
from hydra.core.config_store import ConfigStore
from hydra_zen import MISSING, builds, make_custom_builds_fn
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

from dnadiffusion.callbacks.sampling import Sample
from dnadiffusion.data.dataloader import LoadingDataModule
from dnadiffusion.models.training_modules import UnetDiffusion
from dnadiffusion.models.unet import Unet as UnetBase

# Custom Builds Functions
sbuilds = make_custom_builds_fn(populate_full_signature=True)
pbuilds = make_custom_builds_fn(zen_partial=True, populate_full_signature=True)

# Transforms config if we need to add more
# transforms = builds(T.Compose, [builds(T.ToTensor)])
transforms = builds(T.ToTensor)

# Loading data
LoadingData = builds(
    LoadingDataModule,
    input_csv="./dnadiffusion/data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt",
    subset_components=[
        "GM12878_ENCLB441ZZZ",
        "hESCT0_ENCLB449ZZZ",
        "K562_ENCLB843GMH",
        "HepG2_ENCLB029COU",
    ],
    load_saved_data=False,
    change_component_index=True,
    number_of_sequences_to_motif_creation=1000,
    transform=transforms,
    populate_full_signature=True,
)

# Diffusion Model
Unet = builds(
    UnetBase,
    dim=200,
    init_dim=None,
    dim_mults=(1, 2, 4),
    channels=1,
    resnet_block_groups=4,
    learned_sinusoidal_dim=10,
    num_classes=10,
    output_attention=False,
)

# Optimizers
Adam = pbuilds(torch.optim.Adam, lr=1e-3)

# Lightning Module
UnetConfig = builds(
    UnetDiffusion,
    model=Unet,
    lr=1e-3,
    timesteps=50,
    beta=0.995,
    optimizer=Adam,
)

# Callbacks
sample = builds(
    Sample,
    data_module=MISSING,
    image_size=200,
    num_sampling_to_compare_cells=1000,
)

wandb = builds(
    WandbLogger,
    project="dnadiffusion",
    notes="lightning",
)

checkpoint = builds(
    ModelCheckpoint,
    dirpath="dnadiffusion/checkpoints/",
    every_n_epochs=500,
)

# Lightning Trainer
LightningTrainer = builds(
    L.Trainer,
    accelerator="cuda",
    strategy="ddp_find_unused_parameters_true",
    num_nodes=1,
    devices=8,
    max_epochs=10000,
    logger=wandb,
    callbacks=[checkpoint],
)

# Registering the builds
cs = ConfigStore.instance()

cs.store(group="data", name="LoadingData", node=LoadingData)
cs.store(group="model", name="Unet", node=UnetConfig)