Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[incubate/new_frl] Support release_grads when non_overlap #58621

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def _build_comm_buffers(self, group_size=256 * 1024 * 1024):
parameters,
comm_group,
act=HOOK_ACTION.REDUCE_SCATTER,
release_grads=self.pp_release_grads,
)
self._comm_buffer_list.append(buffer)

Expand All @@ -501,6 +502,10 @@ def clear_grad_func(p):
for p in self._parameter_list:
clear_grad_func(p)

if self.pp_release_grads and not self.pp_overlap:
for comm_buffer in self._comm_buffer_list:
comm_buffer._clear_grad_storage()

def filter_parameters(self, parameter_list, hcg):

parameter_list = [
Expand All @@ -515,6 +520,10 @@ def reduce_gradients(self, parameter_list, hcg):
# TODO merge grad / nrank with dp
with framework.no_grad():
for comm_buffer in self._comm_buffer_list:
if self.pp_release_grads and comm_buffer.grad_storage is None:
for param in comm_buffer.params:
comm_buffer._copy_grad_to_buffer(param)

comm_buffer._comm_grads()
comm_buffer.scale_and_split_grads()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,6 @@ def __init__(self, layers, hcg, strategy):
self._dp_comm_overlap or self._sharding_comm_overlap
)

assert (
self._comm_overlap or not self._release_gradients
), "Cannot use release_gradients without comm_overlap."

if self._enable_timer:
if not timer.is_timer_initialized():
timer.set_timers()
Expand Down
6 changes: 2 additions & 4 deletions python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def _copy_grad_to_buffer(self, param):
if self.grad_storage is None:
assert self._params_step_dict[param.name] == 0

self.grad_storage = paddle.empty(
self.grad_storage = paddle.zeros(
[self.buffer_size], dtype=self._dtype
)

Expand All @@ -374,9 +374,7 @@ def _copy_grad_to_buffer(self, param):
grad_var.stop_gradient = True
grad_var.flatten_()

if self._params_step_dict[param.name] == 0:
paddle.assign(grad_var, tmp_var)

tmp_var.add_(grad_var)
tmp_var.get_tensor()._set_dims(param.shape)

if self.use_main_grad:
Expand Down