Skip to content

Commit

Permalink
Fix bug in MegatronParallel
Browse files Browse the repository at this point in the history
  • Loading branch information
marcromeyn committed Jun 10, 2024
1 parent 8c522c3 commit cd8d333
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nemo/lightning/megatron_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ def __init__(
_model.configure_model()
_pipeline.append(_model)

if self.ddp_config is not None:
if isinstance(ddp_config, DistributedDataParallelConfig):
from megatron.core.distributed import DistributedDataParallel as McoreDDP

_pipeline = [
McoreDDP(
model_chunk.config,
self.ddp_config,
ddp_config,
model_chunk,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
Expand Down

0 comments on commit cd8d333

Please sign in to comment.