Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zero2: avoid graph breaks in torch.compile by using param_idx #6803

Merged
merged 12 commits into from
Dec 20, 2024
8 changes: 5 additions & 3 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def __init__(self,
for param in param_group['params']:
if param.requires_grad:
param.grad_accum = None
param.param_idx_in_group = len(trainable_parameters)
trainable_parameters.append(param)
self.bit16_groups.append(trainable_parameters)

Expand Down Expand Up @@ -961,7 +962,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

self.grads_in_ipg_bucket.append(grad_reduc)
self.params_in_ipg_bucket.append((i, param, param_id))
self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))

#make sure the average tensor function knows how to average the gradients
if is_moe_param(param):
Expand Down Expand Up @@ -1067,7 +1068,7 @@ def average_tensor(self, tensor):

process_group = self.dp_process_group
# count = 0
for i, param, param_id in self.params_in_ipg_bucket:
for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
nelyahu marked this conversation as resolved.
Show resolved Hide resolved

process_group = self.dp_process_group

Expand Down Expand Up @@ -1383,7 +1384,8 @@ def reduce_ipg_grads(self):
stream = get_accelerator().current_stream()

with get_accelerator().stream(stream):
for _, param, param_id in self.params_in_ipg_bucket:
for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket:
param = self.bit16_groups[group_idx][param_idx_in_group]

assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def strict_average_tensor(tensor):
process_group = optimizer.dp_process_group
curr_size = 0
pg_offsets = []
for i, param, param_id in optimizer.params_in_ipg_bucket:
for i, param_idx, param_id in optimizer.params_in_ipg_bucket:
param = optimizer.bit16_groups[i][param_idx]
process_group = optimizer.dp_process_group
if optimizer.ipg_bucket_has_moe_params:
process_group = optimizer.expert_dp_process_group[param.group_name] if is_moe_param(
Expand Down