diff --git a/examples/multimodal/mimo/caption_generation_example.py b/examples/multimodal/mimo/caption_generation_example.py index a4fd0c2902f8..fd3f5a96bac6 100644 --- a/examples/multimodal/mimo/caption_generation_example.py +++ b/examples/multimodal/mimo/caption_generation_example.py @@ -1,7 +1,8 @@ import argparse +import logging import os import sys -import logging + import torch from megatron.core.optimizer import OptimizerConfig from transformers import AutoProcessor @@ -12,11 +13,13 @@ from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig from nemo.collections.multimodal.mimo.data.captioning import MimoCaptioningTaskEncoder +from nemo.collections.multimodal.mimo.data.mock import MockDataModule from nemo.collections.multimodal.mimo.model.base import BaseMimoConfig, BaseMimoModel, CustomMimoConfig from nemo.lightning.pytorch.optim import CosineAnnealingScheduler from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule from nemo.utils.exp_manager import TimingCallback + def main(args): # Global and micro batch sizes gbs = 1 @@ -31,7 +34,7 @@ def main(args): max_steps=1000, accelerator="gpu", strategy=strategy, - plugins=nl.MegatronMixedPrecision(precision="16-mixed"), + plugins=nl.MegatronMixedPrecision(precision="32"), val_check_interval=1000, limit_val_batches=50, ) @@ -61,9 +64,13 @@ def main(args): task_encoder=task_encoder, ) + # data = MockDataModule( + # tokenizer=tokenizer, vocab_size=tokenizer.vocab_size, micro_batch_size=mbs, global_batch_size=gbs + # ) + train_loader = data.train_dataloader() one_batch = next(iter(train_loader)) - + fabric = trainer.to_fabric() custom_config = CustomMimoConfig( @@ -74,22 +81,20 @@ def main(args): # base_config = BaseMimoConfig(vocab_size = tokenizer.vocab_size) model = BaseMimoModel(config=custom_config, tokenizer=tokenizer) model = fabric.load_model(args.local_model_path, model) - + model = model.module.cuda() model.eval() - images = one_batch["images"].cuda() input_ids = one_batch["tokens"].cuda() position_ids = one_batch["position_ids"].cuda() input_text = one_batch['input_text'] - - + output_images = one_batch['output_images'].cuda() + all_hidden_states = [] - - - input_ids = input_ids[:,:-7] - position_ids = position_ids[:,:-7] + + input_ids = input_ids[:, :-7] + position_ids = position_ids[:, :-7] # if torch.distributed.get_rank() == 0: #or other ranks # breakpoint() # torch.distributed.barrier() @@ -98,11 +103,12 @@ def main(args): with torch.no_grad(): output_dict = model( - input_ids = input_ids, + input_ids=input_ids, images=images, input_text=input_text, position_ids=position_ids, attention_mask=None, + output_images=output_images, # labels = labels, # loss_mask = loss_mask # num_media_tiles=num_media_tiles, @@ -123,40 +129,43 @@ def main(args): .expand_as(input_ids) ) # if torch.distributed.get_rank() == 0: #or other ranks - # breakpoint() + breakpoint() # torch.distributed.barrier() - all_hidden_states.append(hiden_states[-1,:,:]) - + all_hidden_states.append(hiden_states[-1, :, :]) + # If the generated token is the end of sequence token, stop generating # if next_token_ids.item() == tokenizer.eos_token_id: # break # if torch.distributed.get_rank() == 0: #or other ranks # breakpoint() # torch.distributed.barrier() - hidden_states_concat = torch.cat(all_hidden_states, dim = 0).unsqueeze(0) - vis_proj_out = model.module.module.module.vision_output_projection_module(hidden_states_concat) - 
actual_image_caption_embeddings = model.module.module.module.get_image_caption_embeddings(one_batch['input_text']) - mse_loss = torch.nn.functional.mse_loss(actual_image_caption_embeddings.to(vis_proj_out.device, dtype = vis_proj_out.dtype), vis_proj_out) - - + + hidden_states_concat = torch.cat(all_hidden_states, dim=0).unsqueeze(0) + # breakpoint() + vis_proj_out = model.module.module.vision_output_projection_module(hidden_states_concat) + actual_image_caption_embeddings = model.module.module.get_image_caption_embeddings(one_batch['input_text']) + mse_loss = torch.nn.functional.mse_loss( + actual_image_caption_embeddings.to(vis_proj_out.device, dtype=vis_proj_out.dtype), vis_proj_out + ) + device = vis_proj_out.device - image_decode_device = model.module.module.module.image_decoder.to(device) - gen_image =image_decode_device(prompt_embeds=actual_image_caption_embeddings.to(device)).images[0] + + image_decode_device = model.module.module.image_decoder.to(device) + gen_image = image_decode_device(prompt_embeds=actual_image_caption_embeddings.to(device)).images[0] gen_image.save('debug_image_gt.png') - + gen_image = image_decode_device(prompt_embeds=vis_proj_out).images[0] gen_image.save('debug_image_generated.png') - + logging.info(f"MSE loss for embeddings {mse_loss}") + breakpoint() generated_ids[generated_ids == -200] = 0 generated_texts = tokenizer.tokenizer.batch_decode(generated_ids, skip_special_tokens=False) logging.info("======== GENERATED TEXT OUTPUT ========") logging.info(f"{generated_texts}") logging.info("=======================================") - # Optimizer and scheduler setup - if __name__ == "__main__": @@ -170,4 +179,4 @@ def main(args): ) args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/examples/multimodal/mimo/train_captioning_data.py b/examples/multimodal/mimo/train_captioning_data.py index 5c15cd3e488e..987158cf93ed 100644 --- a/examples/multimodal/mimo/train_captioning_data.py +++ b/examples/multimodal/mimo/train_captioning_data.py @@ -1,4 +1,6 @@ import argparse +import faulthandler +import logging import os import sys @@ -17,11 +19,16 @@ from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule from nemo.utils.exp_manager import TimingCallback +logging.basicConfig(level=logging.DEBUG) +faulthandler.enable() +# Optionally, redirect output to a specific file +# faulthandler.enable(file=open("faulthandler.log", "w")) + def main(args): # Global and micro batch sizes - gbs = 16 - mbs = 8 + gbs = 128 + mbs = 32 seq_length = 256 data_path = '/lustre/fsw/coreai_dlalgo_genai/ykarnati/datasets/cc3m-wds' tokenizer = AutoTokenizer("llava-hf/llava-v1.6-vicuna-7b-hf") @@ -70,9 +77,9 @@ def main(args): max_steps=3500, accelerator="gpu", strategy=strategy, - plugins=nl.MegatronMixedPrecision(precision="16-mixed"), + plugins=nl.MegatronMixedPrecision(precision="32"), callbacks=[checkpoint_callback, TimingCallback()], - val_check_interval=50, + val_check_interval=30, limit_val_batches=gbs, log_every_n_steps=1, num_sanity_val_steps=0, @@ -137,7 +144,7 @@ def main(args): "--log_dir", type=str, required=False, default="./", help="Directory for logging and checkpoints" ) parser.add_argument("--devices", type=int, required=False, default=8) - parser.add_argument("--tp_size", type=int, required=False, default=4) + parser.add_argument("--tp_size", type=int, required=False, default=2) parser.add_argument("--pp_size", type=int, required=False, default=1) parser.add_argument("--name", type=str, required=False, 
default="mimo_first_light") parser.add_argument("--wandb_project", type=str, required=False, default=None) diff --git a/examples/multimodal/mimo/train_example.py b/examples/multimodal/mimo/train_example.py index a13a533ac2ae..e2f89c54ff09 100644 --- a/examples/multimodal/mimo/train_example.py +++ b/examples/multimodal/mimo/train_example.py @@ -3,6 +3,7 @@ import sys import torch +import wandb from megatron.core.optimizer import OptimizerConfig from transformers import AutoProcessor @@ -13,13 +14,16 @@ from nemo.collections.multimodal.mimo.model.base import BaseMimoConfig, BaseMimoModel, CustomMimoConfig from nemo.lightning.pytorch.optim import CosineAnnealingScheduler from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.lightning.pytorch.strategies.utils import RestoreConfig from nemo.utils.exp_manager import TimingCallback def main(args): + + # wandb.init(project=args.wandb_project, name=args.name) # Global and micro batch sizes - gbs = 8 - mbs = 2 + gbs = 4 + mbs = 1 seq_length = 256 tokenizer = AutoTokenizer("llava-hf/llava-v1.6-vicuna-7b-hf") @@ -44,7 +48,7 @@ def main(args): save_last=True, monitor="reduced_train_loss", save_top_k=2, - every_n_train_steps=500, + every_n_train_steps=200, dirpath=args.log_dir, ) @@ -53,7 +57,7 @@ def main(args): max_steps=10000, accelerator="gpu", strategy=strategy, - plugins=nl.MegatronMixedPrecision(precision="16-mixed"), + plugins=nl.MegatronMixedPrecision(precision="32"), callbacks=[checkpoint_callback, TimingCallback()], val_check_interval=50, limit_val_batches=gbs, @@ -65,6 +69,7 @@ def main(args): vocab_size=tokenizer.vocab_size, image_special_token_indices=image_special_token_indices, image_special_tokens=image_special_tokens, + freeze_language_model=False, ) # base_config = BaseMimoConfig(vocab_size = tokenizer.vocab_size) model = BaseMimoModel(config=custom_config, tokenizer=tokenizer) @@ -87,14 +92,15 @@ def main(args): resume_if_exists=True, resume_ignore_no_checkpoint=True, resume_from_directory=args.log_dir, - restore_config=None, + # restore_config=None, + restore_config=RestoreConfig(path=args.restore_path) if args.restore_path else None, ) resume.setup(trainer, model) # Optimizer and scheduler setup opt_config = OptimizerConfig( optimizer='adam', - lr=0.001, + lr=2.0e-4, adam_beta1=0.9, adam_beta2=0.95, use_distributed_optimizer=False, @@ -104,7 +110,7 @@ def main(args): max_steps=trainer.max_steps, warmup_steps=70, constant_steps=0, - min_lr=2.0e-05, + min_lr=2.0e-4, ) opt = MegatronOptimizerModule(opt_config, sched) opt.connect(model) @@ -122,8 +128,9 @@ def main(args): parser.add_argument("--devices", type=int, required=False, default=8) parser.add_argument("--tp_size", type=int, required=False, default=2) parser.add_argument("--pp_size", type=int, required=False, default=1) - parser.add_argument("--name", type=str, required=False, default="mimo_first_light") + parser.add_argument("--name", type=str, required=False, default="mimo_decoder_align") parser.add_argument("--wandb_project", type=str, required=False, default=None) + parser.add_argument("--restore_path", type=str, required=False, default=None) args = parser.parse_args() main(args) diff --git a/nemo/collections/multimodal/mimo/data/captioning.py b/nemo/collections/multimodal/mimo/data/captioning.py index 952740494232..36d412980d81 100644 --- a/nemo/collections/multimodal/mimo/data/captioning.py +++ b/nemo/collections/multimodal/mimo/data/captioning.py @@ -103,7 +103,7 @@ def encode(self, input_sample: CaptioningSample, output_sample: 
MimoCaptioningSa output_sample.__key__ = input_sample.__key__ output_sample.__restore_key__ = input_sample.__restore_key__ - output_sample.input_image = torch.zeros((3, 336, 336), dtype=torch.float16) + output_sample.input_image = torch.zeros((3, 336, 336)) output_sample.tokens = tokens output_sample.labels = labels output_sample.loss_mask = loss_mask diff --git a/nemo/collections/multimodal/mimo/data/mock.py b/nemo/collections/multimodal/mimo/data/mock.py index d0f1785e8276..e55dee3c3771 100644 --- a/nemo/collections/multimodal/mimo/data/mock.py +++ b/nemo/collections/multimodal/mimo/data/mock.py @@ -13,14 +13,17 @@ # limitations under the License. from typing import Dict, List, Optional, Tuple -from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + import numpy as np import pytorch_lightning as pl import torch +from PIL import Image from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.lightning.pytorch.plugins import MegatronDataSampler @@ -43,7 +46,7 @@ def __init__( persistent_workers: bool = False, ): super().__init__() - self.tokenizer=tokenizer + self.tokenizer = tokenizer self.seq_length = seq_length self.decoder_seq_length = decoder_seq_length self.num_train_samples = num_train_samples @@ -71,7 +74,7 @@ def setup(self, stage: str = "") -> None: self.vocab_size, self.tokenizer, self.crop_size, "valid", self.num_val_samples, self.decoder_seq_length ) self._test_ds = _MockMimoDataset( - self.vocab_size,self.tokenizer, self.crop_size, "test", self.num_test_samples, self.decoder_seq_length + self.vocab_size, self.tokenizer, self.crop_size, "test", self.num_test_samples, self.decoder_seq_length ) def train_dataloader(self) -> TRAIN_DATALOADERS: @@ -92,7 +95,7 @@ def test_dataloader(self) -> EVAL_DATALOADERS: def _create_dataloader(self, dataset, **kwargs) -> DataLoader: return DataLoader( dataset, - batch_size = self.micro_batch_size, + batch_size=self.micro_batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, @@ -100,6 +103,7 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader: **kwargs, ) + class _MockMimoDataset(Dataset): def __init__( self, @@ -126,17 +130,32 @@ def __init__( self.input_text = "Generate image of dog." 
self.label_text = "Here is the image of dog" self.special_tokens = [f"IMG_{i}" for i in range(8)] + + resolution = 768 + transform = transforms.Compose( + [ + transforms.Resize((resolution, resolution)), # Resize to target resolution + transforms.ToTensor(), # Convert to tensor, scales to [0, 1] + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), # Normalize to [-1, 1] + ] + ) + + import os + + current_file_dir = os.path.dirname(os.path.abspath(__file__)) + image_path = os.path.join(current_file_dir, "dog.png") + image = Image.open(image_path).convert("RGB") + self.image_tensor = transform(image) + def __len__(self) -> int: return self.length def tokenize_text(self, text: str) -> torch.Tensor: """Tokenize the input text using the provided tokenizer.""" - tokens = self.tokenizer.tokenizer( - text,return_tensors="pt" - ) + tokens = self.tokenizer.tokenizer(text, return_tensors="pt") return tokens["input_ids"].squeeze(0) # Return as 1D tensor - def find_pattern_indices(self,template, pattern, search_start_index=0, allow_first_token_mismatch=False): + def find_pattern_indices(self, template, pattern, search_start_index=0, allow_first_token_mismatch=False): template_len = len(template) pattern_len = len(pattern) for i in range(search_start_index, template_len - pattern_len + 1): @@ -148,7 +167,8 @@ def find_pattern_indices(self,template, pattern, search_start_index=0, allow_fir def __getitem__(self, idx) -> Dict[str, torch.Tensor]: # Generate images with normal distribution np_gen = np.random.default_rng(seed=(self.seed + idx)) - images = torch.zeros((3, self.image_height, self.image_width), dtype=torch.float16) + # images = torch.zeros((3, self.image_height, self.image_width), dtype=torch.float16) + images = torch.zeros((3, self.image_height, self.image_width)) # Tokenize input text and label text input_tokens = self.tokenize_text(self.input_text) @@ -156,7 +176,7 @@ def __getitem__(self, idx) -> Dict[str, torch.Tensor]: special_token_ids = [self.tokenizer.tokenizer.convert_tokens_to_ids(token) for token in self.special_tokens] label_tokens = torch.cat([label_tokens, torch.tensor(special_token_ids, dtype=torch.long)]) - + combined_tokens = torch.cat([input_tokens, label_tokens]) labels = torch.ones_like(combined_tokens) * self.ignore_placeholder answer_start = len(input_tokens) @@ -172,10 +192,10 @@ def __getitem__(self, idx) -> Dict[str, torch.Tensor]: "position_ids": torch.arange(len(tokens), dtype=torch.int64), "labels": labels, "loss_mask": loss_mask, - "input_text": self.input_text + "input_text": self.input_text, + "output_images": self.image_tensor, } - def _collate_fn(self, batch): """Default collation function for the dataloader.""" collated_batch = {} @@ -186,12 +206,13 @@ def _collate_fn(self, batch): def collate_fn(self, batch): """Method to use as the `collate_fn` in DataLoader.""" return self._collate_fn(batch) - + + if __name__ == "__main__": tokenizer = AutoTokenizer("llava-hf/llava-v1.6-vicuna-7b-hf") special_tokens = [f"IMG_{i}" for i in range(8)] tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) - data_module = MockDataModule(tokenizer = tokenizer,vocab_size = tokenizer.vocab_size, micro_batch_size=1 ) + data_module = MockDataModule(tokenizer=tokenizer, vocab_size=tokenizer.vocab_size, micro_batch_size=1) data_module.setup() dataloader = data_module.test_dataloader() batch = next(iter(dataloader)) @@ -202,4 +223,4 @@ def collate_fn(self, batch): print("Position IDs:", batch["position_ids"]) print("Labels:", batch["labels"]) print("Loss Mask:", 
batch["loss_mask"]) - print("Input text:", batch["input_text"]) \ No newline at end of file + print("Input text:", batch["input_text"]) diff --git a/nemo/collections/multimodal/mimo/model/base.py b/nemo/collections/multimodal/mimo/model/base.py index 53ec61855882..ef2a7b54d3a3 100644 --- a/nemo/collections/multimodal/mimo/model/base.py +++ b/nemo/collections/multimodal/mimo/model/base.py @@ -11,49 +11,72 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from nemo.lightning import io -from nemo.utils import logging -from torch import Tensor -from megatron.core.transformer.transformer_config import TransformerConfig -from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union, List, Tuple -import torch from dataclasses import dataclass, field -from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule -from megatron.core.optimizer import OptimizerConfig -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from torch import nn -from megatron.core.transformer.spec_utils import ModuleSpec -from nemo.lightning import get_vocab_size, io -from dataclasses import dataclass -from nemo.collections.llm import fn +from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Union + +import pytorch_lightning as L +import torch +import torch.nn.functional as F +import wandb from megatron.core.inference_params import InferenceParams -from nemo.lightning.megatron_parallel import MaskedTokenLossReduction from megatron.core.models.multimodal.llava_model import LLaVAModel as MCoreLLaVAModel -from nemo.collections.llm import Llama2Config7B, Llama2Config13B, LlamaConfig -import torch.nn.functional as F -import pytorch_lightning as L +from megatron.core.models.vision.multimodal_projector import MultimodalProjector as MCoreMultimodalProjector +from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.custom_layers.transformer_engine import ( TEColumnParallelLinear, TENorm, TERowParallelLinear, ) -from nemo.lightning import get_vocab_size, io -from nemo.lightning.megatron_parallel import MaskedTokenLossReductionWithLossMask -from nemo.collections.llm.gpt.model import local_layer_spec, transformer_engine_layer_spec -from megatron.core.models.vision.multimodal_projector import MultimodalProjector as MCoreMultimodalProjector from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import Tensor, nn + +from nemo.collections.llm import Llama2Config7B, Llama2Config13B, LlamaConfig, fn +from nemo.collections.llm.gpt.model import local_layer_spec, transformer_engine_layer_spec +from nemo.lightning import get_vocab_size, io +from nemo.lightning.megatron_parallel import MaskedTokenLossReduction, MaskedTokenLossReductionWithLossMask +from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule +from nemo.utils import logging if TYPE_CHECKING: from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank, get_packed_seq_params # from nemo.collections.multimodal.mimo.model.gpt import MimoGPTModel +def 
compute_snr(timesteps, noise_scheduler): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + class MimoLossReduction(MaskedTokenLossReduction): - def __init__(self, validation_step: bool = False, val_drop_last: bool = True, l2_weight: float = 10.0) -> None: + def __init__(self, validation_step: bool = False, val_drop_last: bool = True, l2_weight: float = 1.0) -> None: super().__init__(validation_step, val_drop_last) self.l2_weight = l2_weight @@ -78,18 +101,42 @@ def forward( l2_loss = self._calculate_l2_loss(output_projection_embeddings, image_caption_embeddings) l2_loss = self.l2_weight * l2_loss - logging.info(f"Yash loss debug single rank token loss {token_loss} l2 loss {l2_loss}") from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group + reduced_l2_loss = average_losses_across_data_parallel_group([l2_loss]) total_loss = token_loss + l2_loss logging.info(f"Yash loss debug total_loss {total_loss}") - token_loss_info['avg'] = token_loss_info['avg'] + reduced_l2_loss + token_loss_info['avg'] = token_loss_info['avg'] + reduced_l2_loss token_loss_info.update({"l2_loss": reduced_l2_loss}) - - logging.info(f"Yash loss debug full loss {token_loss_info['avg'] } reduced l2 loss {reduced_l2_loss}") - + # denoise l2 loss + + # mse_loss_weights = output_dict['denoise_mse_loss_weights'] + model_pred = output_dict['denoise_model_pred'] + target = output_dict['denoise_target'] + + # gen_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + # gen_loss = gen_loss.mean(dim=[]) * mse_loss_weights + # gen_loss = gen_loss.mean() + gen_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + reduced_gen_l2_loss = average_losses_across_data_parallel_group([gen_loss]) + + total_loss = total_loss + gen_loss + token_loss_info['avg'] = token_loss_info['avg'] + reduced_gen_l2_loss + + logging.info( + f"Yash loss debug full loss {token_loss_info['avg'] } token_loss {token_loss} embedding l2 loss {reduced_l2_loss} denoise l2 loss {reduced_gen_l2_loss}" + ) + # if torch.distributed.get_rank() == 0: + # wandb.log( + # { + # "full_loss": token_loss_info['avg'], + # "token_loss": token_loss, + # "embedding_l2_loss": reduced_l2_loss, + # "denoise_l2_loss": reduced_gen_l2_loss, + # } + # ) return total_loss, token_loss_info def _calculate_l2_loss(self, embeddings1: torch.Tensor, embeddings2: torch.Tensor) -> torch.Tensor: 
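The hunks above add compute_snr() (the Min-SNR weighting from the Min-SNR-Diffusion-Training reference) to this file, but MimoLossReduction.forward currently applies a plain, unweighted MSE to the UNet prediction and leaves the weighted branch commented out. Below is a minimal sketch of how those weights are typically applied to a per-sample denoising loss; it reuses compute_snr from the hunk above, while min_snr_weighted_mse and snr_gamma are illustrative names and are not part of this PR.

import torch
import torch.nn.functional as F

def min_snr_weighted_mse(model_pred, target, timesteps, noise_scheduler, snr_gamma=5.0):
    # Per-sample squared error, averaged over all non-batch dimensions -> shape (batch,).
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, loss.dim())))
    # Min-SNR-gamma weights: min(SNR, gamma) / SNR, using compute_snr defined above.
    snr = compute_snr(timesteps, noise_scheduler)
    weights = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
    # Weighted mean over the batch.
    return (loss * weights).mean()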
@@ -100,6 +147,7 @@ def _calculate_l2_loss(self, embeddings1: torch.Tensor, embeddings2: torch.Tenso def mimo_forward_step(model, batch) -> torch.Tensor: forward_args = { "images": batch["images"], + "output_images": batch["output_images"], "input_ids": batch["tokens"], "position_ids": batch["position_ids"], "attention_mask": batch.get("attention_mask", None), @@ -127,7 +175,7 @@ def mimo_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: required_keys = set() required_keys.add("attention_mask") if parallel_state.is_pipeline_first_stage(): - required_keys.update(("images", "tokens", "position_ids", "input_text")) + required_keys.update(("images", "tokens", "position_ids", "input_text", "output_images")) if parallel_state.is_pipeline_last_stage(): required_keys.update(("labels", "loss_mask", "input_text")) @@ -220,7 +268,9 @@ def forward( if labels is None: return logits.transpose(0, 1).contiguous(), hidden_states - + # if torch.distributed.get_rank() == 0: # or other ranks + # breakpoint() + # torch.distributed.barrier() loss = self.compute_language_model_loss(labels, logits) return loss, hidden_states @@ -378,6 +428,9 @@ class CustomMimoConfig(TransformerConfig, io.IOMixin): image_special_tokens: Optional[List[str]] = None image_special_token_indices: Optional[List[int]] = None make_vocab_size_divisible_by: int = 128 + freeze_language_model: bool = True + freeze_vision_model: bool = True + freeze_vision_projection: bool = True def configure_model(self, tokenizer) -> "CustomMimoModel": @@ -408,6 +461,11 @@ def configure_model(self, tokenizer) -> "CustomMimoModel": img_w=self.vision_transformer_config.img_w, patch_dim=self.vision_transformer_config.patch_dim, ) + model.freeze( + freeze_language_model=self.freeze_language_model, + freeze_vision_model=self.freeze_vision_model, + freeze_vision_projection=self.freeze_vision_projection, + ) return model @@ -496,6 +554,7 @@ def __init__( self.image_decoder_name = "stabilityai/stable-diffusion-2" self.scheduler = EulerDiscreteScheduler.from_pretrained(self.image_decoder_name, subfolder="scheduler") self.image_decoder = StableDiffusionPipeline.from_pretrained(self.image_decoder_name, scheduler=self.scheduler) + self.image_decoder.vae.requires_grad_(False) self.image_decoder.unet.requires_grad_(False) self.image_decoder.text_encoder.requires_grad_(False) @@ -517,6 +576,10 @@ def get_image_caption_embeddings(self, text_input): text_inputs = self.image_decoder.tokenizer( text_input, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=True ) + # if torch.distributed.get_rank() == 0: # or other ranks + # breakpoint() + # torch.distributed.barrier() + text_inputs = text_inputs.to(self.image_decoder.device) image_caption_embeddings = self.image_decoder.text_encoder(**text_inputs)[0] # b,77,1024 return image_caption_embeddings @@ -524,6 +587,7 @@ def get_image_caption_embeddings(self, text_input): def forward( self, images: torch.Tensor, + output_images: torch.Tensor, input_ids: torch.Tensor, input_text: str, position_ids: torch.Tensor, @@ -609,6 +673,10 @@ def forward( if num_image_tiles is None: num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device) + # if torch.distributed.get_rank() == 0: # or other ranks + # breakpoint() + # torch.distributed.barrier() + + # Preprocess input, labels and loss mask. 
combined_embeddings, new_labels, new_loss_mask, attention_mask = self._preprocess_data( image_embeddings, @@ -622,8 +690,7 @@ def forward( attention_mask, ) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len] # TODO: Yash return this hidden state for computing loss - - + output, hidden_states = self.language_model( input_ids=None, position_ids=None, @@ -636,38 +703,95 @@ def forward( # if labels is None output is logits (b,s,vocab_size) or its loss (b,s) # send hidden_state for special tokens to output_projection module. + device = output_images.device + image_decoder = self.image_decoder.to(device) image_caption_embeddings = self.get_image_caption_embeddings(input_text) # (bs, 77, 1024) - if(new_labels is not None): + if new_labels is not None: special_token_mask = torch.zeros_like(new_labels, dtype=torch.bool) for idx in self.model_config.image_special_token_indices: special_token_mask |= new_labels == idx + + nonzero_indices = torch.nonzero(special_token_mask, as_tuple=False) + special_token_positions = nonzero_indices[:, 1] + special_token_indices = new_labels[special_token_mask] + + special_token_positions = special_token_positions.view( + new_labels.size(0), -1 + ) # batch_size, no_special_tokens + special_token_indices = special_token_indices.view(new_labels.size(0), -1) + + # if torch.distributed.get_rank() == 0: # or other ranks + # breakpoint() + # torch.distributed.barrier() special_token_mask = special_token_mask.transpose(0, 1).unsqueeze(-1) special_token_mask = special_token_mask.expand_as(hidden_states) selected_hidden_states = hidden_states[special_token_mask].view( hidden_states.size(1), -1, hidden_states.size(-1) ) + special_token_embeddings = self.language_model.embedding( + input_ids=special_token_indices, position_ids=special_token_positions + ) + special_token_embeddings = special_token_embeddings.transpose(0, 1) # change to b,s,h + output_projection_embeddings = self.vision_output_projection_module( - selected_hidden_states + selected_hidden_states + special_token_embeddings ) # (bs, no_special_tokens, 1024) # Image caption embeddings image_caption_embeddings = image_caption_embeddings.to( output_projection_embeddings.device, dtype=output_projection_embeddings.dtype ) + if labels is None or loss_mask is None: # return output return { - 'output': output, - # 'output_projection_embeddings': output_projection_embeddings, - # 'image_caption_embeddings': image_caption_embeddings, - 'hidden_states': hidden_states - } + 'output': output, + # 'output_projection_embeddings': output_projection_embeddings, + # 'image_caption_embeddings': image_caption_embeddings, + 'hidden_states': hidden_states, + } + + # for calcualating denoising loss + # with torch.no_grad(): + + # if torch.distributed.get_rank() == 0: # or other ranks + # breakpoint() + # torch.distributed.barrier() + + latents = image_decoder.to(device).vae.encode(output_images).latent_dist.sample() + latents = latents * image_decoder.vae.config.scaling_factor + + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each sample in the batch + timesteps = torch.randint(0, image_decoder.scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # add noise to latents using timesteps + noisy_latents = image_decoder.scheduler.add_noise(latents, noise, timesteps) + # make added noise target + target = noise + # predict the added noise + model_pred = image_decoder.unet(noisy_latents, timesteps, 
output_projection_embeddings).sample + # model_pred = image_decoder.unet( + # noisy_latents, timesteps, image_caption_embeddings.to(dtype=noisy_latents.dtype) + # # ).sample + # snr = compute_snr(timesteps, image_decoder.scheduler) + # mse_loss_weights = torch.stack([snr, 5 * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + + # if torch.distributed.get_rank() == 0: # or other ranks + # breakpoint() + # torch.distributed.barrier() + return { 'output': output, 'new_loss_mask': new_loss_mask, 'output_projection_embeddings': output_projection_embeddings, 'image_caption_embeddings': image_caption_embeddings, - 'hidden_states': hidden_states + 'hidden_states': hidden_states, + # 'denoise_mse_loss_weights': mse_loss_weights, + 'denoise_model_pred': model_pred, + 'denoise_target': target, } # return (output,output_projection_embeddings, image_caption_embeddings), new_loss_mask @@ -702,12 +826,14 @@ def forward( loss_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, images: Optional[torch.Tensor] = None, + output_images: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, inference_params: InferenceParams = None, ) -> torch.Tensor: output_tensor = self.module( images=images, + output_images=output_images, input_ids=input_ids, position_ids=position_ids, loss_mask=loss_mask, @@ -750,11 +876,13 @@ def validation_loss_reduction(self) -> MimoLossReduction: return self._validation_loss_reduction -from nemo.lightning import OptimizerModule, io, teardown -from nemo.collections.multimodal.mimo.model.base import BaseMimoConfig, BaseMimoModel -from transformers import LlavaForConditionalGeneration from pathlib import Path + +from transformers import LlavaForConditionalGeneration + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.multimodal.mimo.model.base import BaseMimoConfig, BaseMimoModel +from nemo.lightning import OptimizerModule, io, teardown @io.model_importer(BaseMimoModel, "hf")
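For reference, the denoising objective that CustomMimoModel.forward now computes on output_images follows the standard epsilon-prediction recipe: VAE-encode the target image, noise the latents at a randomly sampled timestep, and have the Stable Diffusion UNet predict that noise conditioned on the projected hidden states. Below is a minimal standalone sketch of that step against a diffusers StableDiffusionPipeline like the one built in __init__; the helper name denoising_targets is illustrative and not part of this PR.

import torch

def denoising_targets(pipe, output_images, conditioning_embeddings):
    # Encode target images into the VAE latent space and apply the SD scaling factor.
    latents = pipe.vae.encode(output_images).latent_dist.sample()
    latents = latents * pipe.vae.config.scaling_factor
    # Sample Gaussian noise and one random training timestep per sample.
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    ).long()
    # Noise the latents and predict the added noise, conditioned on the projected embeddings.
    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
    model_pred = pipe.unet(noisy_latents, timesteps, conditioning_embeddings).sample
    return model_pred, noise, timesteps

Training then regresses model_pred against noise with an MSE, optionally Min-SNR weighted as sketched earlier.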