Rebase
Signed-off-by: root <[email protected]>
root authored and jiemingz committed Nov 9, 2024
1 parent 43ba11a commit bdf246a
Showing 5 changed files with 21 additions and 4 deletions.
7 changes: 7 additions & 0 deletions nemo/collections/llm/gpt/model/base.py
@@ -180,6 +180,13 @@ class GPTConfig(TransformerConfig, io.IOMixin):
     data_step_fn: Callable = gpt_data_step
 
     def configure_model(self, tokenizer) -> "MCoreGPTModel":
+        if self.enable_cuda_graph:
+            assert HAVE_TE, "Transformer Engine is required for CUDA graphs."
+            assert getattr(self, 'use_te_rng_tracker', False), (
+                "Transformer Engine's RNG tracker is required for CUDA graphs; it can be "
+                "enabled with use_te_rng_tracker=True."
+            )
+
         vp_size = self.virtual_pipeline_model_parallel_size
         if vp_size:
             p_size = self.pipeline_model_parallel_size
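
For orientation (not part of the diff itself), here is a minimal sketch of what the new guard expects from callers: a config that enables CUDA graphs must also opt into Transformer Engine's RNG tracker, or configure_model() now fails fast. The model sizes below are placeholder assumptions, and since whether the two flags are declared dataclass fields depends on the installed Megatron-Core version, they are set as attributes here.

    # Illustrative sketch only: enabling CUDA graphs together with the TE RNG tracker.
    from nemo.collections import llm

    config = llm.GPTConfig(
        num_layers=12,          # placeholder sizes, not from this commit
        hidden_size=768,
        num_attention_heads=12,
        seq_length=2048,
    )
    config.enable_cuda_graph = True    # requires Transformer Engine (HAVE_TE)
    config.use_te_rng_tracker = True   # without this, the new assert fires
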
1 change: 1 addition & 0 deletions nemo/lightning/_strategy_lib.py
@@ -90,6 +90,7 @@ def init_parallel_ranks(
         pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None),
         use_fp8=fp8,
         init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False),
+        use_te_rng_tracker=getattr(parallel_config, "use_te_rng_tracker", False),
         # apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
     )
 
4 changes: 4 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -99,6 +99,7 @@ class ParallelismConfig:
     pipeline_dtype: torch.dtype
     encoder_tensor_model_parallel_size: int = 0
     encoder_pipeline_model_parallel_size: int = 0
+    use_te_rng_tracker: bool
 
 
 class MegatronStrategy(DDPStrategy, io.IOMixin):
@@ -195,6 +196,7 @@ def __init__(
         ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
         lazy_init: bool = False,
         pipeline_dtype: Optional[torch.dtype] = None,
+        use_te_rng_tracker: bool = False,
         save_ckpt_format: str = "torch_dist",
         ckpt_async_save: bool = True,
         ckpt_torch_dist_multiproc: int = None,  ## TODO(ashors): put elsewhere?
@@ -239,6 +241,7 @@ def __init__(
         self.ckpt_load_optimizer = ckpt_load_optimizer
         self.ckpt_save_optimizer = ckpt_save_optimizer
         self.pipeline_dtype = pipeline_dtype
+        self.use_te_rng_tracker = use_te_rng_tracker
         self._setup_optimizers = setup_optimizers
         self._init_model_parallel = init_model_parallel
         self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
@@ -863,6 +866,7 @@ def parallelism(self) -> ParallelismConfig:
             encoder_tensor_model_parallel_size=self.encoder_tensor_model_parallel_size,
             encoder_pipeline_model_parallel_size=self.encoder_pipeline_model_parallel_size,
             pipeline_dtype=self.pipeline_dtype,
+            use_te_rng_tracker=self.use_te_rng_tracker,
         )
 
     @contextmanager
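
Taken together, the strategy changes simply thread one boolean through to ParallelismConfig and init_parallel_ranks, so opting in is a single constructor argument. A hedged usage sketch follows; the parallel sizes and trainer wiring are illustrative assumptions, not from this commit.

    # Illustrative sketch only: forwarding the new flag via MegatronStrategy.
    from nemo import lightning as nl

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=2,     # placeholder parallelism settings
        pipeline_model_parallel_size=1,
        use_te_rng_tracker=True,          # new in this commit; defaults to False
    )
    trainer = nl.Trainer(devices=2, accelerator="gpu", strategy=strategy)
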
11 changes: 7 additions & 4 deletions nemo/utils/get_rank.py
@@ -20,6 +20,13 @@
 def is_global_rank_zero():
     """ Helper function to determine if the current process is global_rank 0 (the main process)
     """
+
+
+    # Try to get the MPI global rank env var
+    mpi_rank = get_envint("OMPI_COMM_WORLD_RANK", None)
+    if mpi_rank is not None:
+        return mpi_rank == 0
+
     # Try to get the pytorch RANK env var
     # RANK is set by torch.distributed.launch
     rank = get_envint("RANK", None)
@@ -32,10 +39,6 @@ def is_global_rank_zero():
     if slurm_rank is not None:
         return slurm_rank == 0
 
-    # Try to get the MPI global rank env var
-    mpi_rank = get_envint("OMPI_COMM_WORLD_RANK", None)
-    if mpi_rank is not None:
-        return mpi_rank == 0
 
     # if neither pytorch, SLURM nor MPI env vars are set
     # check NODE_RANK/GROUP_RANK and LOCAL_RANK env vars
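
With the block moved up, the effective precedence is OMPI_COMM_WORLD_RANK, then RANK, then the SLURM rank, then the NODE_RANK/GROUP_RANK/LOCAL_RANK fallback. Below is a self-contained sketch of that ordering; the SLURM_PROCID name and the fallback details are assumptions based on the rest of the file and are not shown in this diff.

    # Standalone sketch of the rank-detection order after this change
    # (illustrative; the real helper is nemo.utils.get_rank.is_global_rank_zero).
    import os

    def _envint(name):
        value = os.environ.get(name)
        return int(value) if value is not None else None

    def looks_like_global_rank_zero():
        # MPI is now consulted first, then torch.distributed, then SLURM (assumed SLURM_PROCID).
        for var in ("OMPI_COMM_WORLD_RANK", "RANK", "SLURM_PROCID"):
            rank = _envint(var)
            if rank is not None:
                return rank == 0
        # Fall back to node/local rank variables, treating "unset" as rank 0.
        node_rank = _envint("NODE_RANK")
        if node_rank is None:
            node_rank = _envint("GROUP_RANK")
        local_rank = _envint("LOCAL_RANK")
        return (node_rank or 0) == 0 and (local_rank or 0) == 0
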
2 changes: 2 additions & 0 deletions tests/lightning/test_strategy_lib.py
@@ -82,6 +82,7 @@ def test_init_parallel_ranks() -> None:
     mock_parallel_config.encoder_pipeline_model_parallel_size = 0
     mock_parallel_config.tp_comm_overlap = False
     mock_parallel_config.pipeline_model_parallel_split_rank = None
+    mock_parallel_config.use_te_rng_tracker = False
 
     _strategy_lib.init_parallel_ranks(
         world_size=24,
@@ -105,6 +106,7 @@ def test_init_parallel_ranks() -> None:
         "encoder_tensor_model_parallel_size": 0,
         "use_fp8": False,
         "init_mpi_proc_group": False,
+        "use_te_rng_tracker": False,
     }
     for k, v in expected_app_state.items():
         assert hasattr(app_state, k), f"Expected to find {k} in AppState"
