
Commit

mimo overfit to single example
yashaswikarnati committed Nov 8, 2024
1 parent 6793f63 commit 62e41b0
Showing 6 changed files with 272 additions and 100 deletions.
65 changes: 37 additions & 28 deletions examples/multimodal/mimo/caption_generation_example.py
@@ -1,7 +1,8 @@
import argparse
import logging
import os
import sys
import logging

import torch
from megatron.core.optimizer import OptimizerConfig
from transformers import AutoProcessor
@@ -12,11 +13,13 @@
from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule
from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
from nemo.collections.multimodal.mimo.data.captioning import MimoCaptioningTaskEncoder
from nemo.collections.multimodal.mimo.data.mock import MockDataModule
from nemo.collections.multimodal.mimo.model.base import BaseMimoConfig, BaseMimoModel, CustomMimoConfig
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback


def main(args):
# Global and micro batch sizes
gbs = 1
@@ -31,7 +34,7 @@ def main(args):
max_steps=1000,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="16-mixed"),
plugins=nl.MegatronMixedPrecision(precision="32"),
val_check_interval=1000,
limit_val_batches=50,
)
@@ -61,9 +64,13 @@
task_encoder=task_encoder,
)

# data = MockDataModule(
# tokenizer=tokenizer, vocab_size=tokenizer.vocab_size, micro_batch_size=mbs, global_batch_size=gbs
# )

train_loader = data.train_dataloader()
one_batch = next(iter(train_loader))

fabric = trainer.to_fabric()

custom_config = CustomMimoConfig(
@@ -74,22 +81,20 @@
# base_config = BaseMimoConfig(vocab_size = tokenizer.vocab_size)
model = BaseMimoModel(config=custom_config, tokenizer=tokenizer)
model = fabric.load_model(args.local_model_path, model)

model = model.module.cuda()
model.eval()


images = one_batch["images"].cuda()
input_ids = one_batch["tokens"].cuda()
position_ids = one_batch["position_ids"].cuda()
input_text = one_batch['input_text']

output_images = one_batch['output_images'].cuda()

all_hidden_states = []


input_ids = input_ids[:,:-7]
position_ids = position_ids[:,:-7]

input_ids = input_ids[:, :-7]
position_ids = position_ids[:, :-7]
# if torch.distributed.get_rank() == 0: #or other ranks
# breakpoint()
# torch.distributed.barrier()
@@ -98,11 +103,12 @@ def main(args):
with torch.no_grad():

output_dict = model(
input_ids = input_ids,
input_ids=input_ids,
images=images,
input_text=input_text,
position_ids=position_ids,
attention_mask=None,
output_images=output_images,
# labels = labels,
# loss_mask = loss_mask
# num_media_tiles=num_media_tiles,
@@ -123,40 +129,43 @@ def main(args):
.expand_as(input_ids)
)
# if torch.distributed.get_rank() == 0: #or other ranks
# breakpoint()
breakpoint()
# torch.distributed.barrier()
all_hidden_states.append(hiden_states[-1,:,:])
all_hidden_states.append(hiden_states[-1, :, :])

# If the generated token is the end of sequence token, stop generating
# if next_token_ids.item() == tokenizer.eos_token_id:
# break
# if torch.distributed.get_rank() == 0: #or other ranks
# breakpoint()
# torch.distributed.barrier()
hidden_states_concat = torch.cat(all_hidden_states, dim = 0).unsqueeze(0)
vis_proj_out = model.module.module.module.vision_output_projection_module(hidden_states_concat)
actual_image_caption_embeddings = model.module.module.module.get_image_caption_embeddings(one_batch['input_text'])
mse_loss = torch.nn.functional.mse_loss(actual_image_caption_embeddings.to(vis_proj_out.device, dtype = vis_proj_out.dtype), vis_proj_out)



hidden_states_concat = torch.cat(all_hidden_states, dim=0).unsqueeze(0)
# breakpoint()
vis_proj_out = model.module.module.vision_output_projection_module(hidden_states_concat)
actual_image_caption_embeddings = model.module.module.get_image_caption_embeddings(one_batch['input_text'])
mse_loss = torch.nn.functional.mse_loss(
actual_image_caption_embeddings.to(vis_proj_out.device, dtype=vis_proj_out.dtype), vis_proj_out
)

device = vis_proj_out.device
image_decode_device = model.module.module.module.image_decoder.to(device)
gen_image =image_decode_device(prompt_embeds=actual_image_caption_embeddings.to(device)).images[0]

image_decode_device = model.module.module.image_decoder.to(device)
gen_image = image_decode_device(prompt_embeds=actual_image_caption_embeddings.to(device)).images[0]
gen_image.save('debug_image_gt.png')

gen_image = image_decode_device(prompt_embeds=vis_proj_out).images[0]
gen_image.save('debug_image_generated.png')

logging.info(f"MSE loss for embeddings {mse_loss}")
breakpoint()
generated_ids[generated_ids == -200] = 0
generated_texts = tokenizer.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
logging.info("======== GENERATED TEXT OUTPUT ========")
logging.info(f"{generated_texts}")
logging.info("=======================================")


# Optimizer and scheduler setup



if __name__ == "__main__":
@@ -170,4 +179,4 @@ def main(args):
)

args = parser.parse_args()
main(args)
main(args)
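
For orientation, the new tail of this script boils down to: collect the last-layer hidden state at each greedy-decode step, project the stacked states through the vision output projection, regress them against the caption-text embeddings with an MSE loss, and push both embedding sets through the image decoder to dump a ground-truth and a generated image. A condensed sketch of that flow, assuming `model`, `tokenizer`, `one_batch`, and `all_hidden_states` are prepared as in the diff (the `model.module.module.*` nesting depends on how the checkpoint was wrapped when loaded):

```python
import torch

# `all_hidden_states` holds the final-layer hidden state kept at each
# greedy-decode step, as collected in the loop above.
hidden_states_concat = torch.cat(all_hidden_states, dim=0).unsqueeze(0)

# Project the language-model hidden states into the image-embedding space.
vis_proj_out = model.module.module.vision_output_projection_module(hidden_states_concat)

# Embeddings of the reference caption text serve as the regression target.
target_embeddings = model.module.module.get_image_caption_embeddings(one_batch['input_text'])
mse_loss = torch.nn.functional.mse_loss(
    target_embeddings.to(vis_proj_out.device, dtype=vis_proj_out.dtype), vis_proj_out
)

# Decode both embedding sets with the image decoder for a visual sanity check.
decoder = model.module.module.image_decoder.to(vis_proj_out.device)
decoder(prompt_embeds=target_embeddings.to(vis_proj_out.device)).images[0].save('debug_image_gt.png')
decoder(prompt_embeds=vis_proj_out).images[0].save('debug_image_generated.png')
```
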
17 changes: 12 additions & 5 deletions examples/multimodal/mimo/train_captioning_data.py
@@ -1,4 +1,6 @@
import argparse
import faulthandler
import logging
import os
import sys

@@ -17,11 +19,16 @@
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback

logging.basicConfig(level=logging.DEBUG)
faulthandler.enable()
# Optionally, redirect output to a specific file
# faulthandler.enable(file=open("faulthandler.log", "w"))


def main(args):
# Global and micro batch sizes
gbs = 16
mbs = 8
gbs = 128
mbs = 32
seq_length = 256
data_path = '/lustre/fsw/coreai_dlalgo_genai/ykarnati/datasets/cc3m-wds'
tokenizer = AutoTokenizer("llava-hf/llava-v1.6-vicuna-7b-hf")
@@ -70,9 +77,9 @@ def main(args):
max_steps=3500,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="16-mixed"),
plugins=nl.MegatronMixedPrecision(precision="32"),
callbacks=[checkpoint_callback, TimingCallback()],
val_check_interval=50,
val_check_interval=30,
limit_val_batches=gbs,
log_every_n_steps=1,
num_sanity_val_steps=0,
@@ -137,7 +144,7 @@ def main(args):
"--log_dir", type=str, required=False, default="./", help="Directory for logging and checkpoints"
)
parser.add_argument("--devices", type=int, required=False, default=8)
parser.add_argument("--tp_size", type=int, required=False, default=4)
parser.add_argument("--tp_size", type=int, required=False, default=2)
parser.add_argument("--pp_size", type=int, required=False, default=1)
parser.add_argument("--name", type=str, required=False, default="mimo_first_light")
parser.add_argument("--wandb_project", type=str, required=False, default=None)
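
The `faulthandler` lines added near the top of this script are a lightweight way to get Python tracebacks out of hangs and hard crashes in multi-GPU runs. A minimal sketch of the pattern with per-rank log files (the filename scheme and the `RANK` environment variable are illustrative assumptions, not part of the diff):

```python
import faulthandler
import os

# Dump tracebacks of all threads on SIGSEGV, SIGFPE, SIGABRT, etc.
# Writing to a per-rank file keeps output from crashing workers separate.
rank = int(os.environ.get("RANK", "0"))  # set by torchrun/most launchers; assumed here
faulthandler.enable(file=open(f"faulthandler_rank{rank}.log", "w"), all_threads=True)
```
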
23 changes: 15 additions & 8 deletions examples/multimodal/mimo/train_example.py
@@ -3,6 +3,7 @@
import sys

import torch
import wandb
from megatron.core.optimizer import OptimizerConfig
from transformers import AutoProcessor

@@ -13,13 +14,16 @@
from nemo.collections.multimodal.mimo.model.base import BaseMimoConfig, BaseMimoModel, CustomMimoConfig
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.lightning.pytorch.strategies.utils import RestoreConfig
from nemo.utils.exp_manager import TimingCallback


def main(args):

# wandb.init(project=args.wandb_project, name=args.name)
# Global and micro batch sizes
gbs = 8
mbs = 2
gbs = 4
mbs = 1
seq_length = 256

tokenizer = AutoTokenizer("llava-hf/llava-v1.6-vicuna-7b-hf")
@@ -44,7 +48,7 @@ def main(args):
save_last=True,
monitor="reduced_train_loss",
save_top_k=2,
every_n_train_steps=500,
every_n_train_steps=200,
dirpath=args.log_dir,
)

@@ -53,7 +57,7 @@ def main(args):
max_steps=10000,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="16-mixed"),
plugins=nl.MegatronMixedPrecision(precision="32"),
callbacks=[checkpoint_callback, TimingCallback()],
val_check_interval=50,
limit_val_batches=gbs,
@@ -65,6 +69,7 @@
vocab_size=tokenizer.vocab_size,
image_special_token_indices=image_special_token_indices,
image_special_tokens=image_special_tokens,
freeze_language_model=False,
)
# base_config = BaseMimoConfig(vocab_size = tokenizer.vocab_size)
model = BaseMimoModel(config=custom_config, tokenizer=tokenizer)
@@ -87,14 +92,15 @@ def main(args):
resume_if_exists=True,
resume_ignore_no_checkpoint=True,
resume_from_directory=args.log_dir,
restore_config=None,
# restore_config=None,
restore_config=RestoreConfig(path=args.restore_path) if args.restore_path else None,
)
resume.setup(trainer, model)

# Optimizer and scheduler setup
opt_config = OptimizerConfig(
optimizer='adam',
lr=0.001,
lr=2.0e-4,
adam_beta1=0.9,
adam_beta2=0.95,
use_distributed_optimizer=False,
@@ -104,7 +110,7 @@ def main(args):
max_steps=trainer.max_steps,
warmup_steps=70,
constant_steps=0,
min_lr=2.0e-05,
min_lr=2.0e-4,
)
opt = MegatronOptimizerModule(opt_config, sched)
opt.connect(model)
@@ -122,8 +128,9 @@
parser.add_argument("--devices", type=int, required=False, default=8)
parser.add_argument("--tp_size", type=int, required=False, default=2)
parser.add_argument("--pp_size", type=int, required=False, default=1)
parser.add_argument("--name", type=str, required=False, default="mimo_first_light")
parser.add_argument("--name", type=str, required=False, default="mimo_decoder_align")
parser.add_argument("--wandb_project", type=str, required=False, default=None)
parser.add_argument("--restore_path", type=str, required=False, default=None)

args = parser.parse_args()
main(args)
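
One detail worth calling out in the optimizer changes above: with `lr=2.0e-4` and `min_lr=2.0e-4`, the cosine annealing has nothing to decay, so after the 70 warmup steps the run trains at an effectively constant learning rate, which fits an overfit/debug run. A minimal sketch of that configuration, mirroring the imports and values shown in the diff:

```python
from megatron.core.optimizer import OptimizerConfig
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule

# Adam at a flat 2e-4: min_lr == lr turns the cosine decay into a no-op,
# leaving only the 70-step warmup ramp.
opt_config = OptimizerConfig(
    optimizer='adam',
    lr=2.0e-4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    use_distributed_optimizer=False,
)
sched = CosineAnnealingScheduler(
    max_steps=10000,  # trainer.max_steps in the script
    warmup_steps=70,
    constant_steps=0,
    min_lr=2.0e-4,
)
opt = MegatronOptimizerModule(opt_config, sched)
opt.connect(model)  # `model` is the BaseMimoModel built earlier in the script
```
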
2 changes: 1 addition & 1 deletion nemo/collections/multimodal/mimo/data/captioning.py
@@ -103,7 +103,7 @@ def encode(self, input_sample: CaptioningSample, output_sample: MimoCaptioningSa

output_sample.__key__ = input_sample.__key__
output_sample.__restore_key__ = input_sample.__restore_key__
output_sample.input_image = torch.zeros((3, 336, 336), dtype=torch.float16)
output_sample.input_image = torch.zeros((3, 336, 336))
output_sample.tokens = tokens
output_sample.labels = labels
output_sample.loss_mask = loss_mask
