diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 10b939d4aecb..57328fbaa189 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -181,6 +181,7 @@ def initialize_model_parallel_for_nemo( micro_batch_size=micro_batch_size, data_parallel_size=app_state.data_parallel_size, rampup_batch_size=rampup_batch_size, + decrease_batch_size_if_needed=False, ) else: if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator): @@ -201,6 +202,7 @@ def initialize_model_parallel_for_nemo( micro_batch_size=micro_batch_size, data_parallel_size=app_state.data_parallel_size, rampup_batch_size=rampup_batch_size, + decrease_batch_size_if_needed=False, ) else: if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator): diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py index 9cf686464417..c0db0286e26d 100644 --- a/nemo/lightning/data.py +++ b/nemo/lightning/data.py @@ -103,6 +103,7 @@ def setup_microbatch_calculator( micro_batch_size=micro_batch_size, data_parallel_size=app_state.data_parallel_size, rampup_batch_size=rampup_batch_size, + decrease_batch_size_if_needed=False, ) else: if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator): @@ -121,6 +122,7 @@ def setup_microbatch_calculator( micro_batch_size=micro_batch_size, data_parallel_size=app_state.data_parallel_size, rampup_batch_size=rampup_batch_size, + decrease_batch_size_if_needed=False, ) else: if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator): diff --git a/nemo/lightning/megatron_init.py b/nemo/lightning/megatron_init.py index 10b939d4aecb..74298de34dd5 100644 --- a/nemo/lightning/megatron_init.py +++ b/nemo/lightning/megatron_init.py @@ -181,6 +181,7 @@ def initialize_model_parallel_for_nemo( 
micro_batch_size=micro_batch_size, data_parallel_size=app_state.data_parallel_size, rampup_batch_size=rampup_batch_size, + decrease_batch_size_if_needed=False, ) else: if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator): @@ -201,6 +202,7 @@ def initialize_model_parallel_for_nemo( micro_batch_size=micro_batch_size, data_parallel_size=app_state.data_parallel_size, rampup_batch_size=rampup_batch_size, + decrease_batch_size_if_needed=False, ) else: if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):