Advanced DiT Training #11226

Closed · wants to merge 1 commit
58 changes: 46 additions & 12 deletions nemo/collections/diffusion/data/diffusion_energon_datamodule.py
@@ -15,7 +15,8 @@
import logging
from typing import Any, Dict, Literal

from megatron.core import parallel_state
from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset
from pytorch_lightning.utilities.types import EVAL_DATALOADERS

from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
@@ -56,6 +57,9 @@ def __init__(
pin_memory: bool = True,
task_encoder: DefaultTaskEncoder = None,
use_train_split_for_val: bool = False,
virtual_epoch_length: int = 1_000_000_000, # a hack to avoid energon end of epoch warning
packing_buffer_size: int | None = None,
max_samples_per_sequence: int | None = None,
) -> None:
"""
Initialize the SimpleMultiModalDataModule.
@@ -82,6 +86,10 @@ def __init__(
task_encoder=task_encoder,
)
self.use_train_split_for_val = use_train_split_for_val
self.virtual_epoch_length = virtual_epoch_length
self.num_workers_val = 1
self.packing_buffer_size = packing_buffer_size
self.max_samples_per_sequence = max_samples_per_sequence

def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'):
"""
@@ -106,29 +114,55 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'):
batch_size=self.micro_batch_size,
task_encoder=self.task_encoder,
worker_config=worker_config,
max_samples_per_sequence=self.max_samples_per_sequence,
shuffle_buffer_size=None,
split_part=split,
batch_drop_last=True,
virtual_epoch_length=self.virtual_epoch_length,
packing_buffer_size=self.packing_buffer_size,
)
return _dataset

def val_dataloader(self) -> EVAL_DATALOADERS:
"""
Initialize and return the validation DataLoader.

This method initializes the DataLoader for the validation dataset. It ensures that the parallel state
is initialized correctly for distributed training and returns a configured DataLoader object.

Returns:
EVAL_DATALOADERS: The DataLoader for the validation dataset.
"""
if self.use_train_split_for_val:
return self.train_dataloader()
if self.val_dataloader_object:
return self.val_dataloader_object

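# Without an initialized Megatron parallel state (e.g. single-process debugging), fall back to energon's default worker config.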
if not parallel_state.is_initialized():
message = (
"Multimodal val data loader: parallel state is not initialized, "
f"using default worker config with num_workers {self.num_workers_val}"
)
logging.info(message)

worker_config = WorkerConfig.default_worker_config(self.num_workers_val)
else:
rank = parallel_state.get_data_parallel_rank()
world_size = parallel_state.get_data_parallel_world_size()
data_parallel_group = parallel_state.get_data_parallel_group()

logging.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}")
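# Shard the validation split across data-parallel ranks: energon uses rank/world_size to give each rank a distinct slice.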
worker_config = WorkerConfig(
rank=rank,
world_size=world_size,
num_workers=self.num_workers_val,
data_parallel_group=data_parallel_group,
worker_debug_path=None,
worker_log_level=0,
)
val_dataset = self.datasets_provider(worker_config, split='val')
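# get_savable_loader wraps the dataset in an energon loader whose iteration state can be saved and restored with checkpoints.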
energon_loader = get_savable_loader(val_dataset, worker_config=worker_config)
self.val_dataloader_object = energon_loader
return self.val_dataloader_object

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
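For orientation, the sketch below shows how the options added in this diff might be used when constructing the energon datamodule. It is a minimal illustration, not code from this PR: the class name DiffusionDataModule and the `path` argument are assumptions based on the surrounding file; only the keyword arguments visible in the diff above are taken from it.

# Hypothetical usage sketch (not part of this PR). The class name and the
# `path` argument are assumed; the remaining keyword arguments mirror this diff.
from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule

data_module = DiffusionDataModule(
    path="/data/energon_dataset",        # assumed: energon-formatted dataset location
    num_workers=8,
    use_train_split_for_val=False,       # when True, val_dataloader() reuses the train split
    virtual_epoch_length=1_000_000_000,  # avoids energon's end-of-epoch warning
    packing_buffer_size=1000,            # forwarded to get_train_dataset for sequence packing
    max_samples_per_sequence=None,       # forwarded to get_train_dataset
)
# val_dataloader() now builds a data-parallel WorkerConfig and returns a savable
# energon loader, so validation iteration state can be checkpointed.
val_loader = data_module.val_dataloader()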
218 changes: 218 additions & 0 deletions nemo/collections/diffusion/data/diffusion_fake_datamodule.py
@@ -0,0 +1,218 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader

from nemo.collections.diffusion.models.model import DiTConfig
from nemo.lightning.pytorch.plugins import MegatronDataSampler

from .diffusion_taskencoder import pos_id_3d


class PosEmb3D:
"""Generates and provides 3D positional embeddings for video data."""

def __init__(self, *, max_t=96, max_h=960, max_w=960):
self.max_t = max_t
self.max_h = max_h
self.max_w = max_w
self.generate_pos_id()

def generate_pos_id(self):
"""Generates the positional ID grid based on max_t, max_h, and max_w."""
self.grid = torch.stack(
torch.meshgrid(
torch.arange(self.max_t, device='cpu'),
torch.arange(self.max_h, device='cpu'),
torch.arange(self.max_w, device='cpu'),
),
dim=-1,
)

def get_pos_id_3d(self, *, t, h, w):
"""Retrieves a subset of the positional IDs for the specified dimensions.
Parameters:
t (int): Number of time frames.
h (int): Height dimension.
w (int): Width dimension.
Returns:
torch.Tensor: The positional IDs tensor with shape (t, h, w, 3).
"""
if t > self.max_t or h > self.max_h or w > self.max_w:
self.max_t = max(self.max_t, t)
self.max_h = max(self.max_h, h)
self.max_w = max(self.max_w, w)
self.generate_pos_id()
return self.grid[:t, :h, :w]


class DiTVideoLatentFakeDataset(torch.utils.data.Dataset):
"""A fake dataset for generating synthetic video latent data."""

def __init__(
self,
n_frames,
max_h,
max_w,
patch_size,
in_channels,
crossattn_emb_size,
max_text_seqlen=512,
seq_length=8192,
):
self.max_t = n_frames
self.max_height = max_h
self.max_width = max_w
self.patch_size = patch_size
self.in_channels = in_channels
self.text_dim = crossattn_emb_size
self.text_seqlen = max_text_seqlen
self.seq_length = seq_length

def __len__(self):
"""Returns the total number of samples."""
return 100000000

def __getitem__(self, idx):
"""Generates a single sample of data.
Parameters:
idx (int): Index of the data sample.
Returns:
dict: A dictionary containing video latent data and related information.
"""
t = self.max_t
h = self.max_height
w = self.max_width
p = self.patch_size
c = self.in_channels

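# Dummy sample: constant video latents and random text embeddings shaped like real DiT inputs.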
video_latent = torch.ones(self.seq_length, c * p**2, dtype=torch.bfloat16) * 0.5
text_embedding = torch.randn(self.text_seqlen, self.text_dim, dtype=torch.bfloat16)
pos_emb = pos_id_3d.get_pos_id_3d(t=t, h=h // p, w=w // p).reshape(-1, 3)

return {
'video': video_latent,
't5_text_embeddings': text_embedding,
'seq_len_q': torch.tensor([video_latent.shape[0]], dtype=torch.int32).squeeze(),
'seq_len_kv': torch.tensor([self.text_seqlen], dtype=torch.int32).squeeze(),
'pos_ids': torch.zeros((self.seq_length, 3), dtype=torch.int32),
'loss_mask': torch.ones(video_latent.shape[0], dtype=torch.bfloat16),
}

def _collate_fn(self, batch):
"""A default implementation of a collation function.
Users should override this method to define custom collation behavior.
"""
return torch.utils.data.dataloader.default_collate(batch)

def collate_fn(self, batch):
Method that the user passes as the ``collate_fn`` argument to DataLoader.
The method optionally performs neural type checking and adds types to the outputs.
Please note, subclasses of Dataset should not implement `input_types`.
Usage:
dataloader = torch.utils.data.DataLoader(
....,
collate_fn=dataset.collate_fn,
....
)
Returns:
Collated batch, with or without types.
"""
return self._collate_fn(batch)


class VideoLatentFakeDataModule(pl.LightningDataModule):
"""A LightningDataModule for generating fake video latent data for training."""

def __init__(
self,
model_config: DiTConfig,
seq_length: int = 2048,
micro_batch_size: int = 1,
global_batch_size: int = 8,
num_workers: int = 1,
pin_memory: bool = True,
task_encoder=None,
use_train_split_for_val: bool = False,
) -> None:
super().__init__()
self.seq_length = seq_length
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.num_workers = num_workers
self.model_config = model_config

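# MegatronDataSampler keeps micro/global batch sizes in sync with the Megatron data pipeline.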
self.data_sampler = MegatronDataSampler(
seq_len=self.seq_length,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
)

def setup(self, stage: str = "") -> None:
"""Sets up the dataset for training and validation.
Parameters:
stage (str): Optional stage argument (unused).
"""
self._train_ds = DiTVideoLatentFakeDataset(
n_frames=self.model_config.max_frames,
max_h=self.model_config.max_img_h,
max_w=self.model_config.max_img_w,
patch_size=self.model_config.patch_spatial,
in_channels=self.model_config.in_channels,
crossattn_emb_size=self.model_config.crossattn_emb_size,
)

def train_dataloader(self) -> TRAIN_DATALOADERS:
"""Returns the training DataLoader."""
if not hasattr(self, "_train_ds"):
self.setup()
return self._create_dataloader(self._train_ds)

def val_dataloader(self) -> EVAL_DATALOADERS:
"""Returns the validation DataLoader."""
if not hasattr(self, "_train_ds"):
self.setup()
return self._create_dataloader(self._train_ds)

def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
"""Creates a DataLoader for the given dataset.
Parameters:
dataset (Dataset): The dataset to load.
**kwargs: Additional arguments for DataLoader.
Returns:
DataLoader: The DataLoader instance.
"""
return DataLoader(
dataset,
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True,
collate_fn=dataset.collate_fn,
**kwargs,
)
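To show how the pieces in this new file fit together, here is a small smoke-test sketch. It is illustrative only and assumes DiTConfig can be instantiated with its defaults and exposes the attributes read in setup() (max_frames, max_img_h, max_img_w, patch_spatial, in_channels, crossattn_emb_size).

# Hypothetical smoke test (not part of this PR); assumes DiTConfig() works with defaults.
from nemo.collections.diffusion.data.diffusion_fake_datamodule import PosEmb3D, VideoLatentFakeDataModule
from nemo.collections.diffusion.models.model import DiTConfig

# The PosEmb3D helper can be queried directly for a (t, h, w, 3) grid of position ids.
pos_emb = PosEmb3D(max_t=8, max_h=64, max_w=64)
ids = pos_emb.get_pos_id_3d(t=4, h=32, w=32)
print(ids.shape)               # torch.Size([4, 32, 32, 3])

# The fake datamodule yields constant latents and random text embeddings with
# DiT-shaped tensors, which is enough to exercise the training loop.
data_module = VideoLatentFakeDataModule(model_config=DiTConfig(), micro_batch_size=1, global_batch_size=8)
data_module.setup()
batch = next(iter(data_module.train_dataloader()))
print(batch['video'].shape)    # (batch, seq_length, in_channels * patch_spatial**2)
print(batch['pos_ids'].shape)  # (batch, seq_length, 3)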