Skip to content

Commit

Permalink
Add option to change batch size if needed
Browse files Browse the repository at this point in the history
  • Loading branch information
BoxiangW committed Nov 13, 2024
1 parent 085e957 commit e5e2dc2
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def initialize_model_parallel_for_nemo(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand All @@ -201,6 +202,7 @@ def initialize_model_parallel_for_nemo(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand Down
2 changes: 2 additions & 0 deletions nemo/lightning/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def setup_microbatch_calculator(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand All @@ -121,6 +122,7 @@ def setup_microbatch_calculator(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand Down
2 changes: 2 additions & 0 deletions nemo/lightning/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def initialize_model_parallel_for_nemo(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand All @@ -201,6 +202,7 @@ def initialize_model_parallel_for_nemo(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand Down

0 comments on commit e5e2dc2

Please sign in to comment.