Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NaN] Fix nan print issue when running Megatron-Deepspeed with DeepSpeed #434

Merged
merged 1 commit into from
Aug 24, 2024

Conversation

ys950902
Copy link

@ys950902 ys950902 commented Aug 5, 2024

When we running megatron-deepspeed with deepspeed met nan issue, the only way we can judge this issue can see below is no lm loss print and the number of nan iterations is still 0 which is not correct:
iteration 9/ 10 | consumed samples: 108 | consumed tokens: 442368 | elapsed time per iteration (ms): 1979.2 | learning rate: 4.219E-07 | global batch size: 12 | loss scale: 1.0 | actual seqlen: 4096 | number of skipped iterations: 0 | number of nan iterations: 0 | samples per second: 6.063 | tokens per gpu per second (tgs): 2069.506 | TFLOPs: 127.00 |

This pr is to fix this issue, whether is skipped iter we should do the nan check.

Copy link

@abhilash1910 abhilash1910 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@tjruwase
Copy link

tjruwase commented Aug 7, 2024

@ys950902, can you please share a bit more details about why skipped_iter is False in this case?

@ys950902
Copy link
Author

ys950902 commented Aug 7, 2024

@ys950902, can you please share a bit more details about why skipped_iter is False in this case?

Hi @tjruwase, thanks for your reply, when you running Megatron-DeepSpeed with DeepSpeed for 3D parallelism:
https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/training.py#L674
or running for zero2/3
https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/training.py#L762
the skipped_iter is set to 0 by default, and DeepSpeed won't update this flag, so is false here.

@tjruwase
Copy link

tjruwase commented Aug 7, 2024

@ys950902, thanks for the explanation. I think the correct solution is to use the was_step_applied() API of DeepSpeed. And I noticed that for the non-3D parallelism case, it is already used to set update_successful.

update_successful = model[0].was_step_applied()

The problem is that update_successful is not used to appropriately set skipped_iter unlike the non-deepspeed code path.

if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0

Can you try setting update_successful and skipped_iter for both deepspeed cases in a consistent fashion to the megatron case? Thanks

@ys950902
Copy link
Author

ys950902 commented Aug 7, 2024

@ys950902, thanks for the explanation. I think the correct solution is to use the was_step_applied() API of DeepSpeed. And I noticed that for the non-3D parallelism case, it is already used to set update_successful.

update_successful = model[0].was_step_applied()

The problem is that update_successful is not used to appropriately set skipped_iter unlike the non-deepspeed code path.

if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0

Can you try setting update_successful and skipped_iter for both deepspeed cases in a consistent fashion to the megatron case? Thanks

Got it, I will fix it as you suggested!

@ys950902
Copy link
Author

ys950902 commented Aug 8, 2024

Hi @tjruwase, could you please take a look on this pr and with the modify in deepspeed to support bfloat16 microsoft/DeepSpeed#5879.

@ys950902
Copy link
Author

Hi @tjruwase, will you merge this pr?

@tjruwase tjruwase merged commit 4f9f1f6 into microsoft:main Aug 24, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants