From 6aabe998e1f0848876b6891211a4970c765b8cf4 Mon Sep 17 00:00:00 2001
From: ykarnati
Date: Mon, 11 Nov 2024 18:29:38 -0800
Subject: [PATCH] load ckpt with dist_checkpointing

---
 .../collections/multimodal/mimo/model/base.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/nemo/collections/multimodal/mimo/model/base.py b/nemo/collections/multimodal/mimo/model/base.py
index ef2a7b54d3a3..ea5926e60d64 100644
--- a/nemo/collections/multimodal/mimo/model/base.py
+++ b/nemo/collections/multimodal/mimo/model/base.py
@@ -18,6 +18,7 @@
 import torch
 import torch.nn.functional as F
 import wandb
+from megatron.core import dist_checkpointing
 from megatron.core.inference_params import InferenceParams
 from megatron.core.models.multimodal.llava_model import LLaVAModel as MCoreLLaVAModel
 from megatron.core.models.vision.multimodal_projector import MultimodalProjector as MCoreMultimodalProjector
@@ -461,6 +462,24 @@ def configure_model(self, tokenizer) -> "CustomMimoModel":
             img_w=self.vision_transformer_config.img_w,
             patch_dim=self.vision_transformer_config.patch_dim,
         )
+        from megatron.core.dist_checkpointing.validation import StrictHandling
+
+        sharded_state_dict = dict(state_dict=model.language_model.sharded_state_dict(prefix="module."))
+        if torch.distributed.get_rank() == 0:  # or other ranks
+            breakpoint()
+        torch.distributed.barrier()
+        strict = StrictHandling.LOG_UNEXPECTED
+        loaded_state_dict = dist_checkpointing.load(
+            sharded_state_dict=sharded_state_dict,
+            checkpoint_dir='/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5/weights',
+            strict=strict,
+        )
+        loaded_state_dict = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()}
+        if torch.distributed.get_rank() == 0:  # or other ranks
+            breakpoint()
+        torch.distributed.barrier()
+        model.language_model.load_state_dict(loaded_state_dict)
+
         model.freeze(
             freeze_language_model=self.freeze_language_model,
             freeze_vision_model=self.freeze_vision_model,
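
Note on the load flow above: the two breakpoint()/barrier pairs are interactive
debugging scaffolding (rank 0 drops into the debugger while the other ranks wait
at the barrier). A minimal sketch of the same flow with that scaffolding stripped
out follows; the helper name and the checkpoint_dir parameter are illustrative,
everything else mirrors the diff:

    from megatron.core import dist_checkpointing
    from megatron.core.dist_checkpointing.validation import StrictHandling

    def load_language_model_weights(language_model, checkpoint_dir):
        # Hypothetical helper; mirrors the logic added to configure_model above.
        # Build a sharded state dict so each rank loads only its own shards;
        # the "module." prefix matches how the checkpoint was saved.
        sharded_state_dict = dict(
            state_dict=language_model.sharded_state_dict(prefix="module.")
        )
        # LOG_UNEXPECTED logs, rather than errors on, checkpoint keys that
        # have no counterpart in the sharded state dict.
        loaded = dist_checkpointing.load(
            sharded_state_dict=sharded_state_dict,
            checkpoint_dir=checkpoint_dir,
            strict=StrictHandling.LOG_UNEXPECTED,
        )
        # Drop the "module." prefix before handing the plain state dict
        # to load_state_dict.
        state_dict = {
            k.removeprefix("module."): v for k, v in loaded["state_dict"].items()
        }
        language_model.load_state_dict(state_dict)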