-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Automatic Mixed Precision with GradScaler to Address NaN Loss Issues #13
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Merge AMP via Gradscalar into ECT.
@@ -78,6 +78,7 @@ def convert(self, value, param, ctx): | |||
@click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True) | |||
@click.option('--tf32', help='Enable tf32 for A100/H100 training speed', metavar='BOOL', type=bool, default=False, show_default=True) | |||
@click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) | |||
@click.option('--enable_gradscaler', help='Enable torch.cuda.amp.GradScaler, NOTE overwritting loss_scale set by --ls', metavar='BOOL', type=bool, default=False, show_default=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Zixiang @aiihn ,
Thanks for your neat PR!
Would it be better to use a short abbreviation like amp
as the option name? AMP already stands for Automatic Mixed Precision.
if enable_gradscaler: | ||
if 'gradscaler_state' in data: | ||
dist.print0(f'Loading GradScaler state from "{resume_state_dump}"...') | ||
# Although not loading the state_dict of the GradScaler works well, loading it can improve reproducibility. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha. Thanks for the comments!
scaler.step(optimizer) | ||
scaler.update() | ||
else: | ||
# Update weights. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO is also unclear to me either. It seems still useful and compatible per Claude.
It's fine to remove my commented code for lr rampup.
Hi @aiihn , Thank you again for your PR! I had another AMP implementation that could also be helpful for ECT. I’ll check it out later and test Links for reference: Cheers, |
Description
This pull request addresses the issue of NaN losses occurring during mixed-precision training with
--fp16
enabled (#12).Key Changes
torch.cuda.amp.GradScaler
to dynamically adjust loss scaling.GradScaler
will override theloss_scale
set manually by--ls
.Usage
Use
--fp16=True
along with--enable_gradscaler=True
. For example, below is the mixed-training command modified from run_ecm_1hour.sh.The FID records obtained using the above command are shown in the following images: