diff --git a/applications/DeepSpeed-Chat/training/utils/utils.py b/applications/DeepSpeed-Chat/training/utils/utils.py index 4ad7d8709..39f659fb6 100644 --- a/applications/DeepSpeed-Chat/training/utils/utils.py +++ b/applications/DeepSpeed-Chat/training/utils/utils.py @@ -209,9 +209,12 @@ def get_optimizer_grouped_parameters( 0.0, }, ] - if not optimizer_grouped_parameters[1]["params"]: - optimizer_grouped_parameters.pop(1) - return optimizer_grouped_parameters + + non_empty_groups = [] + for group in optimizer_grouped_parameters: + if group["params"]: + non_empty_groups.append(group) + return non_empty_groups def _z3_params_to_fetch(param_list):