diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py
index 6b158a33b226..47035b070c05 100644
--- a/nemo/collections/llm/gpt/model/base.py
+++ b/nemo/collections/llm/gpt/model/base.py
@@ -180,6 +180,13 @@ class GPTConfig(TransformerConfig, io.IOMixin):
     data_step_fn: Callable = gpt_data_step
 
     def configure_model(self, tokenizer) -> "MCoreGPTModel":
+        if self.enable_cuda_graph:
+            assert HAVE_TE, "Transformer Engine is required for CUDA graphs."
+            assert getattr(self, 'use_te_rng_tracker', False), (
+                "Transformer Engine's RNG tracker is required for CUDA graphs; it can be "
+                "enabled with use_te_rng_tracker=True."
+            )
+
         vp_size = self.virtual_pipeline_model_parallel_size
         if vp_size:
             p_size = self.pipeline_model_parallel_size
diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py
index 1bee71e26e17..8ef8b8a1a3ad 100644
--- a/nemo/lightning/_strategy_lib.py
+++ b/nemo/lightning/_strategy_lib.py
@@ -90,6 +90,7 @@ def init_parallel_ranks(
         pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None),
         use_fp8=fp8,
         init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False),
+        use_te_rng_tracker=getattr(parallel_config, "use_te_rng_tracker", False),
         # apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
     )
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index c62a90313b45..d61acd84b9de 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -99,6 +99,7 @@ class ParallelismConfig:
     pipeline_dtype: torch.dtype
     encoder_tensor_model_parallel_size: int = 0
     encoder_pipeline_model_parallel_size: int = 0
+    use_te_rng_tracker: bool = False
 
 
 class MegatronStrategy(DDPStrategy, io.IOMixin):
@@ -195,6 +196,7 @@ def __init__(
         ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
         lazy_init: bool = False,
         pipeline_dtype: Optional[torch.dtype] = None,
+        use_te_rng_tracker: bool = False,
         save_ckpt_format: str = "torch_dist",
         ckpt_async_save: bool = True,
         ckpt_torch_dist_multiproc: int = None,  ## TODO(ashors): put elsewhere?
@@ -239,6 +241,7 @@ def __init__(
         self.ckpt_load_optimizer = ckpt_load_optimizer
         self.ckpt_save_optimizer = ckpt_save_optimizer
         self.pipeline_dtype = pipeline_dtype
+        self.use_te_rng_tracker = use_te_rng_tracker
         self._setup_optimizers = setup_optimizers
         self._init_model_parallel = init_model_parallel
         self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
@@ -863,6 +866,7 @@ def parallelism(self) -> ParallelismConfig:
             encoder_tensor_model_parallel_size=self.encoder_tensor_model_parallel_size,
             encoder_pipeline_model_parallel_size=self.encoder_pipeline_model_parallel_size,
             pipeline_dtype=self.pipeline_dtype,
+            use_te_rng_tracker=self.use_te_rng_tracker,
         )
 
     @contextmanager
diff --git a/nemo/utils/get_rank.py b/nemo/utils/get_rank.py
index 37d3906760e7..401b6996a438 100644
--- a/nemo/utils/get_rank.py
+++ b/nemo/utils/get_rank.py
@@ -20,6 +20,13 @@ def is_global_rank_zero():
     """ Helper function to determine if the current process is global_rank 0 (the main process)
     """
+
+    # Try to get the MPI global rank env var
+    mpi_rank = get_envint("OMPI_COMM_WORLD_RANK", None)
+    if mpi_rank is not None:
+        return mpi_rank == 0
+
     # Try to get the pytorch RANK env var
     # RANK is set by torch.distributed.launch
     rank = get_envint("RANK", None)
@@ -32,10 +39,6 @@ def is_global_rank_zero():
     if slurm_rank is not None:
         return slurm_rank == 0
 
-    # Try to get the MPI global rank env var
-    mpi_rank = get_envint("OMPI_COMM_WORLD_RANK", None)
-    if mpi_rank is not None:
-        return mpi_rank == 0
-
     # if neither pytorch, SLURM nor MPI env vars are set
     # check NODE_RANK/GROUP_RANK and LOCAL_RANK env vars
diff --git a/tests/lightning/test_strategy_lib.py b/tests/lightning/test_strategy_lib.py
index 241debd16316..497ce72558c2 100644
--- a/tests/lightning/test_strategy_lib.py
+++ b/tests/lightning/test_strategy_lib.py
@@ -82,6 +82,7 @@ def test_init_parallel_ranks() -> None:
     mock_parallel_config.encoder_pipeline_model_parallel_size = 0
     mock_parallel_config.tp_comm_overlap = False
     mock_parallel_config.pipeline_model_parallel_split_rank = None
+    mock_parallel_config.use_te_rng_tracker = False
 
     _strategy_lib.init_parallel_ranks(
         world_size=24,
@@ -105,6 +106,7 @@ def test_init_parallel_ranks() -> None:
         "encoder_tensor_model_parallel_size": 0,
         "use_fp8": False,
         "init_mpi_proc_group": False,
+        "use_te_rng_tracker": False,
     }
     for k, v in expected_app_state.items():
         assert hasattr(app_state, k), f"Expected to find {k} in AppState"
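
Usage note (not part of the patch): a minimal sketch of how the new flag fits together, assuming NeMo 2.0's `nemo.lightning` entry points and a Megatron-core build whose `TransformerConfig` accepts `enable_cuda_graph` and `use_te_rng_tracker`; model sizes and parallelism values below are placeholders.

```python
# Sketch only: with this change, enabling CUDA graphs on the model config
# requires Transformer Engine's RNG tracker. The flag must be set both on the
# config (read via getattr() in the new configure_model() assertion) and on
# the strategy (forwarded through ParallelismConfig into init_parallel_ranks).
from nemo import lightning as nl
from nemo.collections import llm

config = llm.GPTConfig(
    num_layers=2,             # placeholder model sizes
    hidden_size=256,
    num_attention_heads=4,
    enable_cuda_graph=True,   # triggers the new assertions in configure_model()
    use_te_rng_tracker=True,  # without this, configure_model() now raises
)

strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=1,
    use_te_rng_tracker=True,  # new kwarg added by this patch (defaults to False)
)
```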