
Commit dd6e66c

[develop] Support release_grads in Pipeline Parallel and Sharding stage1 v1/v2
haohongxiang committed Dec 6, 2023
1 parent b5ebcae commit dd6e66c
Showing 4 changed files with 179 additions and 37 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -80,6 +80,7 @@ message PpConfig {
optional bool enable_timer = 3 [ default = false ];
optional bool sharding_comm_overlap = 4 [ default = false ];
optional bool profiling = 5 [ default = false ];
optional bool release_gradients = 6 [ default = false ];
}

message DygraphShardingConfig {
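As a usage sketch only (not part of this commit): the new flag would presumably be switched on through fleet.DistributedStrategy like the other pp_configs options. The parallel degrees below are placeholder values and the nested pp_configs dict form of hybrid_configs is assumed; note that a later hunk in pipeline_parallel.py requires a comm-overlap mode to be enabled together with release_gradients.

# Hypothetical usage sketch; the degrees are placeholders, not taken from this commit.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,
    "sharding_degree": 2,
    "pp_configs": {
        # release_gradients is asserted to require a comm-overlap mode
        # (see the check added in pipeline_parallel.py below).
        "sharding_comm_overlap": True,
        "release_gradients": True,
    },
}
fleet.init(is_collective=True, strategy=strategy)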
@@ -533,6 +533,9 @@ def __init__(self, optimizer, hcg):
self.pp_overlap = strategy.hybrid_configs[
'pp_configs'
].sharding_comm_overlap
self.pp_release_grads = strategy.hybrid_configs[
'pp_configs'
].release_gradients

# TODO(liuzhenhai):support it latter
assert not self.comm_overlap, "not supported yet"
@@ -553,14 +556,16 @@ def _build_comm_buffers(self, group_size=256 * 1024 * 1024):
parameters,
comm_group,
act=HOOK_ACTION.REDUCE_SCATTER,
release_grads=self.pp_release_grads,
)
self._comm_buffer_list.append(buffer)

def clear_grad(self, set_to_zero=True):
"""
should clear grad for all parameters in model
"""
assert set_to_zero, "should not erase grad buffer"
if not self.pp_release_grads:
assert set_to_zero, "should not erase grad buffer"

def clear_grad_func(p):
if hasattr(p, "main_grad") and p.main_grad is not None:
@@ -583,6 +588,10 @@ def clear_grad_func(p):
for p in self._parameter_list:
clear_grad_func(p)

if self.pp_release_grads and not self.pp_overlap:
for comm_buffer in self._comm_buffer_list:
comm_buffer._clear_grad_storage()

def filter_parameters(self, parameter_list, hcg):
parameter_list = [
self._slice_params[param.name] for param in parameter_list
@@ -597,6 +606,9 @@ def reduce_gradients(self, parameter_list, hcg):
logger.debug("sharding start gradients sync")
with framework.no_grad():
for comm_buffer in self._comm_buffer_list:
if self.pp_release_grads and comm_buffer.grad_storage is None:
for param in comm_buffer.params:
comm_buffer._copy_grad_to_buffer(param)
comm_buffer._comm_grads()
comm_buffer.scale_grads()
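To make the lifecycle above concrete: with release_grads enabled, reduce_gradients re-materializes the fused gradient storage on demand (grad_storage is None after a release) before communicating, and clear_grad frees it again after the optimizer step. A toy stand-in is sketched below; only the method and attribute names mirror this diff, while the real FusedCommBuffer packs each parameter's gradient into a slice of one fused tensor.

# Toy illustration of the release/rebuild cycle; not paddle's actual FusedCommBuffer.
import paddle

class _ToyCommBuffer:
    def __init__(self, params):
        self.params = params
        self.grad_storage = None  # dropped between steps when release_grads is on

    def _copy_grad_to_buffer(self, param):
        # Lazily (re)allocate the fused storage after a release; the real buffer
        # would then copy param.grad / param.main_grad into its slice.
        if self.grad_storage is None:
            total = sum(int(p.numel()) for p in self.params)
            self.grad_storage = paddle.zeros([total], dtype=param.dtype)

    def _clear_grad_storage(self):
        # Release the fused gradient tensor so its memory returns to the allocator.
        self.grad_storage = None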

@@ -660,6 +672,9 @@ def _assign_slice_grad(self):
for param in comm_buffer.params:
assert param.name in self._slice_params
slice_param = self._slice_params[param.name]
if self.pp_release_grads and hasattr(slice_param, "main_grad"):
assert not slice_param.main_grad._is_initialized()
del slice_param.main_grad
comm_buffer.assign_slice_grad(param, slice_param)

assert param_num == len(self._parameter_list)
28 changes: 22 additions & 6 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -204,6 +204,9 @@ def __init__(self, layers, hcg, strategy):
self._enable_timer = self._strategy.hybrid_configs[
"pp_configs"
].enable_timer
self._release_gradients = self._strategy.hybrid_configs[
"pp_configs"
].release_gradients

self._sharding_split_param = self._strategy.hybrid_configs[
"sharding_configs"
@@ -246,15 +249,15 @@ def __init__(self, layers, hcg, strategy):
if self._sharding_comm_overlap:
assert self.use_sharding_parallel and self.num_stages > 1

assert not (
self._dp_comm_overlap and self._sharding_comm_overlap
), "Cannot use dp pp overlap and sharding pp overlap at the same time."

self._chunk_2_comm_buffers = defaultdict(list)
self._comm_overlap = (
self._dp_comm_overlap or self._sharding_comm_overlap
)

assert (
self._comm_overlap or not self._release_gradients
), "Cannot use release_gradients without comm_overlap."

if self._enable_timer:
if not timer.is_timer_initialized():
timer.set_timers()
@@ -379,7 +382,13 @@ def register_allreduce_overlap_hook(

for group_idx, parameters in var_groups.items():
buffer = FusedCommBuffer(
group_idx, parameters, comm_group, acc_steps, act, dst
group_idx,
parameters,
comm_group,
acc_steps,
act,
dst,
release_grads=self._release_gradients,
)
self._chunk_2_comm_buffers[chunk_idx].append(buffer)
for param in parameters:
@@ -852,7 +861,14 @@ def _optimizer_step(self):
else:
self.optimizer.step()

self.optimizer.clear_grad()
if self._release_gradients:
self.optimizer.clear_grad(set_to_zero=False)
for _, buffers in self._chunk_2_comm_buffers.items():
for buffer in buffers:
buffer._clear_grad_storage()
else:
self.optimizer.clear_grad()

if self.lr_scheduler:
self.lr_scheduler.step()
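For intuition about the set_to_zero switch used above, a stand-in sketch at the level of a single plain parameter (the optimizer-level clear_grad in this commit additionally handles main_grad and the fused comm buffers):

# Stand-in contrast of the two clearing modes; not the optimizer path above.
import paddle

w = paddle.create_parameter(shape=[4, 4], dtype="float32")
(w * w).sum().backward()
# Default path: keep the gradient tensor allocated and zero it in place (set_to_zero=True).
w.clear_gradient(True)

(w * w).sum().backward()
# release_gradients path: drop the gradient tensor entirely; it is rebuilt by
# the next backward (or by _copy_grad_to_buffer for fused buffers).
w.clear_gradient(False)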

