Skip to content

Commit

Permalink
Set default grad_accum_dtype to None
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoxigua999 committed Dec 17, 2024
1 parent f42f536 commit 47f7cd8
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion openrlhf/utils/deepspeed/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self.max_norm = max_norm
self.adam_offload = getattr(args, "adam_offload", False)
self.zpg = getattr(args, "zpg", 1)
self.grad_accum_dtype = getattr(args, "grad_accum_dtype", "fp32")
self.grad_accum_dtype = getattr(args, "grad_accum_dtype", None)
# disable_trace_cache
self.disable_trace_cache = getattr(args, "disable_trace_cache", False)

Expand Down
2 changes: 1 addition & 1 deletion openrlhf/utils/deepspeed/deepspeed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_train_ds_config(
"gradient_clipping": max_norm,
"prescale_gradients": False,
"wall_clock_breakdown": False,
"data_types": {"grad_accum_dtype": grad_accum_dtype if grad_accum_dtype else "fp32"},
"data_types": {"grad_accum_dtype": grad_accum_dtype},
}


Expand Down

0 comments on commit 47f7cd8

Please sign in to comment.