Flux training with DDP tested on 1 GPU
Signed-off-by: mingyuanm <[email protected]>
Victor49152 committed Nov 1, 2024
1 parent 83456df commit e0de704
Showing 8 changed files with 666 additions and 37 deletions.
@@ -105,6 +105,7 @@ def __init__(
        self.W = image_W
        self.image_key = image_key
        self.txt_key = txt_key
        self.hint_key = hint_key
        self.precached_mode = precached_mode
        if precached_mode:
            # TODO implement this
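
The new hint_key follows the same pattern as image_key and txt_key: it names the batch entry that carries the ControlNet conditioning image. As a rough illustration only (the dataloader step is not part of this diff, and the key names below are placeholders), a batch could be unpacked like this:

# Illustrative only: pulling ControlNet inputs out of a batch dict using the
# configured keys. Key names are placeholders, not taken from this commit.
def unpack_batch(batch: dict, image_key: str = "images", txt_key: str = "txt", hint_key: str = "hint"):
    images = batch[image_key]      # target images / latents
    captions = batch[txt_key]      # text conditioning
    hints = batch.get(hint_key)    # ControlNet hint (e.g. canny edges); may be absent
    return images, captions, hints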
176 changes: 176 additions & 0 deletions nemo/collections/diffusion/flux_controlnet_training.py
@@ -0,0 +1,176 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch
import torch.nn as nn
from megatron.core.optimizer import OptimizerConfig
from pytorch_lightning.loggers import WandbLogger
from transformers import AutoProcessor

from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch.optim import WarmupHoldPolicyScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback

from nemo.collections.diffusion.models.flux_controlnet.model import MegatronFluxControlNetModel, FluxControlNetConfig
from nemo.collections.diffusion.utils.flux_pipeline_utils import configs
from nemo.collections.diffusion.utils.mcore_parallel_utils import Utils


def main(args):

    from nemo.collections.diffusion.data.diffusion_mock_datamodule import MockDataModule

    data = MockDataModule(
        image_h=1024,
        image_w=1024,
        micro_batch_size=args.mbs,
        global_batch_size=args.gbs,
    )

    # Optimizer and scheduler setup
    opt_config = OptimizerConfig(
        optimizer='adam',
        lr=1.0e-04,
        adam_beta1=0.9,
        adam_beta2=0.999,
        use_distributed_optimizer=False,
        bf16=True,
    )

    model_params = configs['dev']
    model_params.t5_params['version'] = '/ckpts/text_encoder_2'
    model_params.clip_params['version'] = '/ckpts/text_encoder'
    model_params.vae_params.ckpt = '/ckpts/ae.safetensors'
    model_params.flux_params.num_joint_layers = 1
    model_params.flux_params.num_single_layers = 1

    flux_controlnet_config = FluxControlNetConfig(guidance_embed=True, num_joint_layers=1, num_single_layers=1)

    model = MegatronFluxControlNetModel(model_params, flux_controlnet_config)

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        pipeline_dtype=torch.bfloat16,
    )

    # def find_frozen_submodules(model):
    #     frozen_submodules = []
    #     frozen_submodule_names = []
    #     for name, module in model.named_modules():
    #         if (
    #             isinstance(module, nn.Module)
    #             and list(module.parameters())
    #             and all(not param.requires_grad for param in module.parameters())
    #         ):
    #             frozen_submodule_names.append(name)
    #             frozen_submodules.append(module)
    #     return frozen_submodule_names, frozen_submodules
    #
    # frozen_submodule_names, frozen_submodules = find_frozen_submodules(model)
    #
    # # Training strategy setup
    #
    # strategy = nl.FSDPStrategy(
    #     ignored_states = frozen_submodules
    # )

    # Checkpoint callback setup
    checkpoint_callback = nl.ModelCheckpoint(
        save_last=True,
        monitor="reduced_train_loss",
        save_top_k=2,
        every_n_train_steps=1000,
        dirpath=args.log_dir,
    )

    # Trainer setup
    trainer = nl.Trainer(
        num_nodes=args.num_nodes,
        devices=args.devices,
        max_steps=args.max_steps,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        callbacks=[checkpoint_callback, TimingCallback()],
        val_check_interval=1000,
        limit_val_batches=0,
        log_every_n_steps=1,
        num_sanity_val_steps=0,
    )

    # Logger setup
    nemo_logger = nl.NeMoLogger(
        explicit_log_dir=args.log_dir,
        name=args.name,
        wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None,
    )

    # Auto resume setup
    resume = nl.AutoResume(
        resume_if_exists=False,
        resume_ignore_no_checkpoint=True,
        resume_from_directory=args.log_dir,
        restore_config=nl.RestoreConfig(path=args.restore_path) if args.restore_path is not None else None,
    )

    # Learning rate scheduler and Megatron optimizer setup
    sched = WarmupHoldPolicyScheduler(
        max_steps=trainer.max_steps,
        warmup_steps=1000,
        hold_steps=1000000000000,
    )
    opt = MegatronOptimizerModule(opt_config, sched)

    llm.train(
        model=model,
        data=data,
        trainer=trainer,
        log=nemo_logger,
        resume=resume,
        optim=opt,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")

    parser.add_argument(
        "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint"
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        required=False,
        default="./nemo_experiments",
        help="Directory for logging and checkpoints",
    )
    parser.add_argument("--devices", type=int, required=False, default=1)
    parser.add_argument("--num_nodes", type=int, required=False, default=1)
    parser.add_argument("--max_steps", type=int, required=False, default=5190)
    parser.add_argument("--tp_size", type=int, required=False, default=1)
    parser.add_argument("--pp_size", type=int, required=False, default=1)
    parser.add_argument("--name", type=str, required=False, default="neva_pretrain")
    parser.add_argument("--wandb_project", type=str, required=False, default=None)
    parser.add_argument("--mbs", type=int, required=False, default=1)
    parser.add_argument("--gbs", type=int, required=False, default=1)

    args = parser.parse_args()
    main(args)
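
For reference, the script can also be driven without the CLI. A minimal sketch, assuming the file is importable as nemo.collections.diffusion.flux_controlnet_training (matching its path above) and that the hard-coded checkpoint paths under /ckpts/ exist locally; all values mirror the argparse defaults:

# Sketch: programmatic equivalent of the CLI defaults above.
from argparse import Namespace

from nemo.collections.diffusion.flux_controlnet_training import main

args = Namespace(
    restore_path=None,
    log_dir="./nemo_experiments",
    devices=1,
    num_nodes=1,
    max_steps=5190,
    tp_size=1,
    pp_size=1,
    name="neva_pretrain",
    wandb_project=None,
    mbs=1,
    gbs=1,
)
main(args)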
3 changes: 3 additions & 0 deletions nemo/collections/diffusion/models/dit/dit_attention.py
@@ -233,6 +233,7 @@ def forward(
        query, key, value = self.get_query_key_value_tensors(hidden_states)
        added_query, added_key, added_value = self.get_added_query_key_value_tensors(additional_hidden_states)


        query = torch.cat([added_query, query], dim=0)
        key = torch.cat([added_key, key], dim=0)
        value = torch.cat([added_value, value], dim=0)
@@ -281,6 +282,8 @@ def forward(
        # ==================================
        # core attention computation
        # ==================================
        if query.dtype != key.dtype or value.dtype != query.dtype:
            raise ValueError("query, key, and value must share the same dtype before core attention")
        if self.checkpoint_core_attention and self.training:
            core_attn_out = self._checkpointed_attention_forward(
                query,
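
The guard added above only detects a precision mismatch between the sample and context projections. If one preferred to repair rather than reject such a mismatch, a helper along these lines could cast everything to a single dtype before the concatenation; this is a hypothetical alternative, not what the commit does:

import torch

def match_dtypes(reference: torch.Tensor, *tensors: torch.Tensor) -> tuple:
    """Cast every tensor to the reference tensor's dtype so the attention kernel never sees mixed precision."""
    return tuple(t.to(dtype=reference.dtype) for t in tensors)

# e.g. key, value = match_dtypes(query, key, value) before the torch.cat calls above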
7 changes: 4 additions & 3 deletions nemo/collections/diffusion/models/flux/model.py
@@ -48,6 +48,7 @@
from nemo.lightning import megatron_parallel as mp
from megatron.core.transformer.enums import ModelType
from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
import numpy as np



@@ -200,7 +201,7 @@ def forward(
            if controlnet_double_block_samples is not None:
                interval_control = len(self.double_blocks) / len(controlnet_double_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states = hidden_states + controlnet_single_block_samples[id_block // interval_control]
                hidden_states = hidden_states + controlnet_double_block_samples[id_block // interval_control]

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)

@@ -214,8 +215,8 @@ def forward(
            if controlnet_single_block_samples is not None:
                interval_control = len(self.double_blocks) / len(controlnet_double_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states[encoder_hidden_states.shape[0] :, ...] = (
                    hidden_states[encoder_hidden_states.shape[0] :, ...]
                hidden_states[encoder_hidden_states.shape[0]: , ...] = (
                    hidden_states[encoder_hidden_states.shape[0]: , ...]
                    + controlnet_single_block_samples[id_block // interval_control]
                )

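In both hunks above, interval_control maps each transformer block index onto one of the (fewer) ControlNet residuals via ceil-division, so consecutive blocks reuse the same residual. A standalone sketch of that indexing, for illustration only:

import numpy as np

def controlnet_sample_index(num_blocks: int, num_samples: int, block_id: int) -> int:
    """Map a transformer block index to the ControlNet residual it reuses.

    With fewer ControlNet samples than blocks, interval = ceil(num_blocks / num_samples)
    consecutive blocks share one sample.
    """
    interval = int(np.ceil(num_blocks / num_samples))
    return block_id // interval

# e.g. 19 double blocks with 2 ControlNet samples: blocks 0-9 reuse sample 0, blocks 10-18 reuse sample 1.
assert controlnet_sample_index(19, 2, 9) == 0
assert controlnet_sample_index(19, 2, 10) == 1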
2 changes: 1 addition & 1 deletion nemo/collections/diffusion/models/flux/pipeline.py
@@ -244,7 +244,7 @@ def __call__(
        device: torch.device = 'cuda',
        dtype: torch.dtype = torch.float32,
        save_to_disk: bool = True,
        offload: bool = True,
        offload: bool = False,
    ):

        assert device == 'cuda', 'Transformer blocks in Mcore must run on cuda devices'
@@ -14,6 +14,7 @@

import torch
import torch.nn as nn
from typing import Tuple



