Custom FSDP path added to megatron parallel.
Signed-off-by: mingyuanm <[email protected]>
Victor49152 committed Nov 12, 2024
1 parent 3b62be0 commit b8044db
Showing 6 changed files with 77 additions and 23 deletions.
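Taken together, the diffs below follow one pattern: each training script builds a DistributedDataParallelConfig with use_custom_fsdp=True and hands it to nl.MegatronStrategy, and megatron_parallel.py wraps each model chunk in megatron.core's FullyShardedDataParallel whenever that flag is set. Below is a minimal sketch of that configuration, assuming nl is the usual nemo.lightning alias and that the trainer and model are set up as in the scripts in this commit:

import torch
import nemo.lightning as nl
from megatron.core.distributed import DistributedDataParallelConfig

# Shard model and optimizer states via the custom FSDP path introduced in this commit.
ddp = DistributedDataParallelConfig(
    use_custom_fsdp=True,
    data_parallel_sharding_strategy='MODEL_AND_OPTIMIZER_STATES',
    overlap_param_gather=True,
    overlap_grad_reduce=True,
)

strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    pipeline_dtype=torch.bfloat16,
    ddp=ddp,  # init_ddp() in megatron_parallel.py selects FullyShardedDataParallel when use_custom_fsdp is set
)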
18 changes: 14 additions & 4 deletions nemo/collections/diffusion/flux_controlnet_training.py
@@ -29,6 +29,8 @@
from nemo.collections.diffusion.models.flux_controlnet.model import MegatronFluxControlNetModel, FluxControlNetConfig
from nemo.collections.diffusion.utils.flux_pipeline_utils import configs
from nemo.collections.diffusion.utils.mcore_parallel_utils import Utils
from megatron.core.distributed import DistributedDataParallelConfig



def main(args):
@@ -50,31 +52,39 @@ def main(args):
lr=1.0e-04,
adam_beta1=0.9,
adam_beta2=0.999,
use_distributed_optimizer=False,
use_distributed_optimizer=True,
bf16=True,
)

model_params = configs['dev']
model_params.t5_params['version'] = '/ckpts/text_encoder_2'
model_params.clip_params['version'] = '/ckpts/text_encoder'
model_params.vae_params.ckpt = '/ckpts/ae.safetensors'
model_params.flux_params.num_joint_layers=args.num_joint_layers
model_params.flux_params.num_single_layers=args.num_single_layers
# model_params.flux_params.num_joint_layers=args.num_joint_layers
# model_params.flux_params.num_single_layers=args.num_single_layers

if args.image_precached:
model_params.vae_params = None
if args.text_precached:
model_params.t5_params = None
model_params.clip_params = None

flux_controlnet_config = FluxControlNetConfig(guidance_embed=True,num_joint_layers=1,num_single_layers=1)
flux_controlnet_config = FluxControlNetConfig(guidance_embed=True, num_joint_layers=args.num_joint_layers, num_single_layers=args.num_single_layers)

model = MegatronFluxControlNetModel(model_params, flux_controlnet_config)

ddp = DistributedDataParallelConfig(
use_custom_fsdp=True,
data_parallel_sharding_strategy='MODEL_AND_OPTIMIZER_STATES',
overlap_param_gather=True,
overlap_grad_reduce=True,
)

strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
pipeline_dtype=torch.bfloat16,
ddp=ddp
)

# def find_frozen_submodules(model):
13 changes: 12 additions & 1 deletion nemo/collections/diffusion/flux_training.py
@@ -30,6 +30,8 @@
from nemo.collections.diffusion.utils.flux_pipeline_utils import configs
from nemo.collections.diffusion.utils.mcore_parallel_utils import Utils

from megatron.core.distributed import DistributedDataParallelConfig


def main(args):

@@ -50,7 +52,7 @@ def main(args):
lr=1.0e-04,
adam_beta1=0.9,
adam_beta2=0.999,
use_distributed_optimizer=False,
use_distributed_optimizer=True,
bf16=True,
)

@@ -68,12 +70,21 @@

model = MegatronFluxModel(model_params)

ddp = DistributedDataParallelConfig(
use_custom_fsdp=True,
data_parallel_sharding_strategy='MODEL_AND_OPTIMIZER_STATES',
overlap_param_gather=True,
overlap_grad_reduce=True,
)

strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
pipeline_dtype=torch.bfloat16,
ddp=ddp,
)


# def find_frozen_submodules(model):
# frozen_submodules = []
# frozen_submodule_names = []
7 changes: 4 additions & 3 deletions nemo/collections/diffusion/models/flux/pipeline.py
@@ -29,6 +29,7 @@
from nemo.collections.diffusion.utils.flux_ckpt_converter import flux_transformer_converter
from nemo.collections.diffusion.utils.flux_pipeline_utils import FluxModelParams
from nemo.collections.diffusion.vae.autoencoder import AutoEncoder
from nemo.utils import logging


class FluxInferencePipeline(nn.Module):
@@ -52,18 +53,18 @@ def load_from_pretrained(self, ckpt_path, do_convert_from_hf=True, save_converted_model=None):
if save_converted_model:
save_path = os.path.join(ckpt_path, 'nemo_flux_transformer.safetensors')
save_safetensors(ckpt, save_path)
print(f'saving converted transformer checkpoint to {save_path}')
logging.info(f'saving converted transformer checkpoint to {save_path}')
else:
ckpt = load_safetensors(ckpt_path)
missing, unexpected = self.transformer.load_state_dict(ckpt, strict=False)
missing = [
k for k in missing if not k.endswith('_extra_state')
] # These keys are mcore specific and should not affect the model performance
if len(missing) > 0:
print(
logging.info(
f"The folloing keys are missing during checkpoint loading, please check the ckpt provided or the image quality may be compromised.\n {missing}"
)
print(f"Found unexepected keys: \n {unexpected}")
logging.info(f"Found unexepected keys: \n {unexpected}")

def encoder_prompt(
self,
20 changes: 20 additions & 0 deletions nemo/collections/diffusion/models/flux_controlnet/pipeline.py
@@ -1,6 +1,8 @@
import torch
import torch.nn as nn
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from safetensors.torch import load_file as load_safetensors
from safetensors.torch import save_file as save_safetensors
from PIL.Image import Image

from nemo.collections.diffusion.models.flux.model import Flux, FluxConfig
@@ -30,6 +32,24 @@ def __init__(

self.vae_scale_factor = 2 ** (len(self.vae.params.ch_mult))

def load_from_pretrained(self, ckpt_path, do_convert_from_hf=True, save_converted_model=None):
if do_convert_from_hf:
ckpt = flux_transformer_converter(ckpt_path, self.transformer.config)
if save_converted_model:
save_path = os.path.join(ckpt_path, 'nemo_flux_controlnet_transformer.safetensors')
save_safetensors(ckpt, save_path)
logging.info(f'saving converted transformer checkpoint to {save_path}')
else:
ckpt = load_safetensors(ckpt_path)
missing, unexpected = self.transformer.load_state_dict(ckpt, strict=False)
missing = [
k for k in missing if not k.endswith('_extra_state')
] # These keys are mcore specific and should not affect the model performance
if len(missing) > 0:
logging.info(
f"The folloing keys are missing during checkpoint loading, please check the ckpt provided or the image quality may be compromised.\n {missing}"
)
logging.info(f"Found unexepected keys: \n {unexpected}")

def configure_modules(self):
if isinstance(self.flux_transformer, FluxConfig):
13 changes: 9 additions & 4 deletions nemo/collections/diffusion/utils/flux_ckpt_converter.py
@@ -79,7 +79,7 @@ def _import_qkv(transformer_config, q, k, v):
return qkv_weights


key_mapping = {
flux_key_mapping = {
'double_blocks': {
'norm1.linear.weight': 'adaln.adaLN_modulation.1.weight',
'norm1.linear.bias': 'adaln.adaLN_modulation.1.bias',
@@ -159,14 +159,18 @@ def flux_transformer_converter(ckpt_path=None, transformer_config=None):
temp = key.split('.')
idx, k = temp[1], '.'.join(temp[2:])
num_double_blocks = max(int(idx), num_double_blocks)
new_key = '.'.join(['double_blocks', idx, key_mapping['double_blocks'][k]])
new_key = '.'.join(['double_blocks', idx, flux_key_mapping['double_blocks'][k]])
elif key.startswith('single_transformer_blocks'):
temp = key.split('.')
idx, k = temp[1], '.'.join(temp[2:])
num_single_blocks = max(int(idx), num_single_blocks)
new_key = '.'.join(['single_blocks', idx, key_mapping['single_blocks'][k]])
new_key = '.'.join(['single_blocks', idx, flux_key_mapping['single_blocks'][k]])
elif key.startswith('controlnet_blocks'):
new_key = '.'.join(['controlnet_double_blocks'] + key.split('.')[1:])
elif key.startswith('x_embedder'):
new_key = '.'.join(['controlnet_x_embedder'] + key.split('.')[1:])
else:
new_key = key_mapping[key]
new_key = flux_key_mapping[key]
new_state_dict[new_key] = value

for i in range(num_double_blocks + 1):
@@ -204,3 +208,4 @@
)

return new_state_dict
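
The converter above is what the load_from_pretrained helpers in both pipelines call into. A condensed sketch of that call pattern follows; the helper name load_hf_flux_checkpoint and the transformer argument are placeholders for illustration, standing in for the Flux/FluxControlNet transformer modules shown above:

import os
from safetensors.torch import save_file as save_safetensors
from nemo.collections.diffusion.utils.flux_ckpt_converter import flux_transformer_converter

def load_hf_flux_checkpoint(transformer, ckpt_path, save_converted_model=False):
    # Remap Hugging Face checkpoint keys to the Mcore layout using flux_key_mapping.
    ckpt = flux_transformer_converter(ckpt_path, transformer.config)
    if save_converted_model:
        save_safetensors(ckpt, os.path.join(ckpt_path, 'nemo_flux_transformer.safetensors'))
    missing, unexpected = transformer.load_state_dict(ckpt, strict=False)
    # '_extra_state' keys are Mcore-specific and safe to ignore when reporting missing keys.
    missing = [k for k in missing if not k.endswith('_extra_state')]
    return missing, unexpected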

29 changes: 18 additions & 11 deletions nemo/lightning/megatron_parallel.py
@@ -44,6 +44,7 @@
import torch.distributed
from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel as McoreDDP
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -557,10 +558,8 @@ def init_ddp(self):
return

from megatron.core import parallel_state

for model_chunk_idx, model_chunk in enumerate(self):
module = model_chunk.module

# Mcore DistributedDataParallel has to be called with grad. Normally this call is redundant, but for
# PEFT with num_sanity_val_steps > 0 this is necessary.
init_ddp_context = nullcontext if all(x.requires_grad for x in module.parameters()) else torch.enable_grad
@@ -574,15 +573,23 @@ def init_ddp(self):
disable_bucketing = (model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step

with init_ddp_context():
ddp = DDP(
module.config,
self.ddp_config,
module,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
disable_bucketing=disable_bucketing,
)

if self.ddp_config.use_custom_fsdp:
DDP = FullyShardedDataParallel
ddp = DDP(
module.config,
self.ddp_config,
module,
disable_bucketing=disable_bucketing,
)
else:
ddp = DDP(
module.config,
self.ddp_config,
module,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
disable_bucketing=disable_bucketing,
)
model_chunk.module = ddp
model_chunk.buffers = ddp.buffers  # We need to do this explicitly since this is an attr that PyTorch uses
model_chunk.__class__.__getattr__ = getattr_proxy # type: ignore
