load ckpt with dist_checkpointing
yashaswikarnati committed Nov 12, 2024
1 parent 62e41b0 commit 6aabe99
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions nemo/collections/multimodal/mimo/model/base.py
@@ -18,6 +18,7 @@
 import torch
 import torch.nn.functional as F
 import wandb
+from megatron.core import dist_checkpointing
 from megatron.core.inference_params import InferenceParams
 from megatron.core.models.multimodal.llava_model import LLaVAModel as MCoreLLaVAModel
 from megatron.core.models.vision.multimodal_projector import MultimodalProjector as MCoreMultimodalProjector
@@ -461,6 +462,24 @@ def configure_model(self, tokenizer) -> "CustomMimoModel":
             img_w=self.vision_transformer_config.img_w,
             patch_dim=self.vision_transformer_config.patch_dim,
         )
+        from megatron.core.dist_checkpointing.validation import StrictHandling
+
+        sharded_state_dict = dict(state_dict=model.language_model.sharded_state_dict(prefix="module."))
+        if torch.distributed.get_rank() == 0:  # or other ranks
+            breakpoint()
+        torch.distributed.barrier()
+        strict = StrictHandling.LOG_UNEXPECTED
+        loaded_state_dict = dist_checkpointing.load(
+            sharded_state_dict=sharded_state_dict,
+            checkpoint_dir='/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5/weights',
+            strict=strict,
+        )
+        loaded_state_dict = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()}
+        if torch.distributed.get_rank() == 0:  # or other ranks
+            breakpoint()
+        torch.distributed.barrier()
+        model.language_model.load_state_dict(loaded_state_dict)
+
         model.freeze(
             freeze_language_model=self.freeze_language_model,
             freeze_vision_model=self.freeze_vision_model,
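For reference, below is a minimal standalone sketch of the loading pattern this commit introduces; it is not part of the diff. The debug breakpoint()/barrier() calls are dropped, the checkpoint directory is passed in as a parameter rather than hard-coded to the Vicuna weights path, and the helper name load_language_model_weights is purely illustrative. Everything else mirrors the calls shown above.

# Illustrative sketch (not in the commit): load a sharded language-model
# state dict with megatron.core dist_checkpointing, as in the diff above.
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.validation import StrictHandling


def load_language_model_weights(model, checkpoint_dir):
    # Hypothetical helper; `model` is expected to expose `model.language_model`.
    # Build the sharded state dict under the same "module." prefix the
    # checkpoint was saved with.
    sharded_state_dict = dict(
        state_dict=model.language_model.sharded_state_dict(prefix="module.")
    )

    # Load the distributed checkpoint; LOG_UNEXPECTED logs keys that appear
    # in the checkpoint but not in the sharded state dict instead of raising.
    loaded = dist_checkpointing.load(
        sharded_state_dict=sharded_state_dict,
        checkpoint_dir=checkpoint_dir,
        strict=StrictHandling.LOG_UNEXPECTED,
    )

    # Strip the "module." prefix before handing the tensors to load_state_dict.
    weights = {k.removeprefix("module."): v for k, v in loaded["state_dict"].items()}
    model.language_model.load_state_dict(weights)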
