refactor sharding stage1, split param [develop] (PaddlePaddle#57715)
* sharding stage1 refactor

* add config log

* remove FLAGS_shard_split_param

* add assert

* polish

* polish

* follow comment

* follow comment
liuzhenhai93 authored Nov 8, 2023
1 parent d659b58 commit ed28804
Showing 8 changed files with 702 additions and 132 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -86,6 +86,7 @@ message DygraphShardingConfig {
optional bool tensor_fusion = 1 [ default = false ];
optional int32 accumulate_steps = 2 [ default = 1 ];
optional bool comm_overlap = 3 [ default = false ];
optional bool split_param = 4 [ default = false ];
}

message HybridConfig {
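
For context (not part of this diff): the new split_param switch selects the split-parameter sharding path (DygraphShardingOptimizerV2 below). A minimal sketch of how a user might turn it on, assuming the usual dygraph hybrid-parallel setup; the degrees and the way sharding_configs is mutated here mirror how the optimizer reads the strategy further down, but are not taken from this commit:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 8,  # assumed 8-way sharding group
}
# Mirrors the new proto field above; enables the split-param (V2) sharding optimizer.
strategy.hybrid_configs["sharding_configs"].split_param = True
fleet.init(is_collective=True, strategy=strategy)
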
7 changes: 6 additions & 1 deletion python/paddle/amp/auto_cast.py
@@ -245,13 +245,15 @@ def check_models(models):
def _is_valid_optimizer(optimizer):
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
DygraphShardingOptimizerV2,
)

return isinstance(
optimizer,
(
paddle.optimizer.Optimizer,
DygraphShardingOptimizer,
DygraphShardingOptimizerV2,
),
)

@@ -483,11 +485,14 @@ def __call__(self, state_dict):
def _set_multi_precision(optimizer, multi_precision):
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
DygraphShardingOptimizerV2,
)

optimizer = (
optimizer._inner_opt
if isinstance(optimizer, DygraphShardingOptimizer)
if isinstance(
optimizer, (DygraphShardingOptimizer, DygraphShardingOptimizerV2)
)
else optimizer
)
if hasattr(optimizer, "_multi_precision"):
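
For context (not part of this diff): both AMP helpers treat the sharding classes as thin wrappers that keep the real optimizer in _inner_opt, which is why DygraphShardingOptimizerV2 has to join the isinstance checks. A toy, framework-free sketch of that unwrap-then-configure pattern (class names here are stand-ins, not Paddle APIs):

class _InnerOpt:
    def __init__(self):
        self._multi_precision = False


class _ShardingWrapper:  # stands in for DygraphShardingOptimizer / DygraphShardingOptimizerV2
    def __init__(self, inner):
        self._inner_opt = inner


def _set_multi_precision(optimizer, multi_precision):
    # unwrap the sharding wrapper first, then flip the flag on the real optimizer
    optimizer = (
        optimizer._inner_opt
        if isinstance(optimizer, _ShardingWrapper)
        else optimizer
    )
    if hasattr(optimizer, "_multi_precision"):
        optimizer._multi_precision = multi_precision


opt = _ShardingWrapper(_InnerOpt())
_set_multi_precision(opt, True)
assert opt._inner_opt._multi_precision is True
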
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -18,18 +18,24 @@

import paddle
from paddle import framework
from paddle.base.framework import EagerParamBase
from paddle.distributed import fleet

from ...utils.log_util import logger
from ...utils.tensor_fusion_helper import fused_parameters
from ...utils.tensor_fusion_helper import (
HOOK_ACTION,
FusedCommBuffer,
assign_group_by_size,
fused_parameters,
)

g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 0))

if g_shard_norm_align_dp:
assert (
not g_shard_use_reduce
), "g_shard_norm_align_dp is not support if g_shard_use_reduce is true"
), "g_shard_norm_align_dp is not supported if g_shard_use_reduce is true"


def _is_trainable(param):
@@ -54,6 +60,7 @@ class DygraphShardingOptimizer:
# 4. option to choose fuse comm (more GPU MEM needed) or un-fuse comm

def __init__(self, optimizer, hcg):
logger.info("init DygraphShardingOptimizer")
# TODO(pangengzheng): support param_groups
if isinstance(optimizer._parameter_list[0], dict):
raise TypeError(
Expand All @@ -76,6 +83,7 @@ def __init__(self, optimizer, hcg):
self.tensor_fusion = strategy.hybrid_configs[
'sharding_configs'
].tensor_fusion

self.accumulate_steps = strategy.hybrid_configs[
'sharding_configs'
].accumulate_steps
@@ -416,3 +424,281 @@ def _set_inner_opt_attr(self, attr_name, value):

def __getattr__(self, item):
return getattr(self._inner_opt, item)


class DygraphShardingOptimizerV2:
"""
A wrapper for Sharding Optimizer in Dygraph, which splits params.
.. warning: DygraphShardingOptimizerV2 is experimental and subject to change.
.. ZeRO: https://arxiv.org/abs/1910.02054
"""

# TODO (JZ-LIANG)
# To support the following features in the future:
# 1. fused update parameter sync
# 2. parameters_groups
# 3. dynamic trainable params, which is the case between pretraining and finetuning
# 4. option to choose fuse comm (more GPU MEM needed) or un-fuse comm
# 5. do not shard small params

def __init__(self, optimizer, hcg):
logger.info("init DygraphShardingOptimizerV2")
assert (
g_shard_use_reduce
), "g_shard_use_reduce must be true if DygraphShardingOptimizerV2 is used"

# TODO(pangengzheng): support param_groups
if isinstance(optimizer._parameter_list[0], dict):
raise TypeError(
"Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter"
)
if not hasattr(optimizer, '_apply_optimize') or not callable(
optimizer._apply_optimize
):
raise ValueError(
"the optimzier object should have _apply_optimize function"
)

self._inner_opt = optimizer
self._hcg = hcg
self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
self._sharding_rank = self._hcg.get_sharding_parallel_rank()

self._parameter_list = optimizer._parameter_list

# param name -> slice_param
self._slice_params = {}
# comm_buffer_list = []
self._comm_buffer_list = []

# slice parameter list
self._local_parameter_list = [
self._create_slice_param(p) for p in optimizer._parameter_list
]

strategy = fleet.fleet._user_defined_strategy
self.tensor_fusion = strategy.hybrid_configs[
'sharding_configs'
].tensor_fusion

assert not self.tensor_fusion, "not supported yet"

self.accumulate_steps = strategy.hybrid_configs[
'sharding_configs'
].accumulate_steps
self.comm_overlap = strategy.hybrid_configs[
'sharding_configs'
].comm_overlap

self.pp_overlap = strategy.hybrid_configs[
'pp_configs'
].sharding_comm_overlap

# TODO(liuzhenhai): support it later
assert not self.comm_overlap, "not supported yet"

self._build_comm_buffers()
self._set_inner_opt_attr('_parameter_list', self._local_parameter_list)
self._set_inner_opt_attr('_param_groups', self._local_parameter_list)

def _build_comm_buffers(self, group_size=256 * 1024 * 1024):
if self.pp_overlap:
return

comm_group = self._hcg.get_sharding_parallel_group()
var_groups = assign_group_by_size(self._parameter_list, group_size)
for group_idx, parameters in var_groups.items():
buffer = FusedCommBuffer(
group_idx,
parameters,
comm_group,
act=HOOK_ACTION.REDUCE_SCATTER,
)
self._comm_buffer_list.append(buffer)

def clear_grad(self, set_to_zero=True):
"""
should clear grads for all parameters in the model
"""
assert set_to_zero, "should not erase grad buffer"

def clear_grad_func(p):
if hasattr(p, "main_grad") and p.main_grad is not None:
assert p._grad_ivar() is None
if set_to_zero:
p.main_grad.zero_()
else:
p.main_grad._clear()
p.main_grad = None
elif not hasattr(p, "main_grad"):
if self.tensor_fusion:
if set_to_zero:
p.grad.zero_()
else:
p.grad._clear()
p.grad = None
else:
p.clear_gradient(set_to_zero)

for p in self._parameter_list:
clear_grad_func(p)

def filter_parameters(self, parameter_list, hcg):
parameter_list = [
self._slice_params[param.name] for param in parameter_list
]
parameter_list = [
param for param in parameter_list if param._is_initialized()
]
return parameter_list

def reduce_gradients(self, parameter_list, hcg):
# TODO merge grad / nrank with dp
logger.debug("sharding start gradients sync")
with framework.no_grad():
for comm_buffer in self._comm_buffer_list:
comm_buffer._comm_grads()
comm_buffer.scale_grads()

def _sharding_sync_parameters(self):
"""
sync parameters across the sharding group
"""

logger.debug("sharding start sync parameters")
with framework.no_grad():
for comm_buffer in self._comm_buffer_list:
comm_buffer.sync_params()

def _update_trainable(self):
"""
allow the user to update the trainable parameter list during training
"""
raise NotImplementedError

def minimize(
self, loss, startup_program=None, parameters=None, no_grad_set=None
):
# NOTE in dygraph mode, the only difference between step and minimize is that minimize
# allows the user to customize the parameters to update on each step
raise AssertionError("not supported yet")

def _create_slice_param(self, param):
# not initialized yet
slice_param = EagerParamBase(shape=[1], dtype=param.dtype)
slice_param.name = param.name

def copy_attr(attr_name):
if hasattr(param, attr_name):
setattr(slice_param, attr_name, getattr(param, attr_name))

copy_attr("is_distributed")
copy_attr("optimize_attr")
copy_attr("do_model_average")
copy_attr("need_clip")

self._slice_params[param.name] = slice_param
return slice_param

def _collect_comm_buffers(self):
if self._comm_buffer_list:
return
for param in self._parameter_list:
if not hasattr(param, "comm_buffer_ref"):
continue
comm_buffer_ref = param.comm_buffer_ref
del param.comm_buffer_ref
comm_buffer = comm_buffer_ref()
self._comm_buffer_list.append(comm_buffer)

assert self._comm_buffer_list

def _assign_slice_grad(self):
param_num = 0
for comm_buffer in self._comm_buffer_list:
param_num = param_num + len(comm_buffer.params)
for param in comm_buffer.params:
assert param.name in self._slice_params
slice_param = self._slice_params[param.name]
comm_buffer.assign_slice_grad(param, slice_param)

assert param_num == len(self._parameter_list)

def step(self):
# TODO Check whether the model's trainable params changed and update state accordingly
# hack for pp comm overlap
self._collect_comm_buffers()
self._assign_slice_grad()

if not isinstance(self._parameter_list[0], dict):
params_grads = []
for param in self._parameter_list:
if (
hasattr(param, "regularizer")
and param.regularizer is not None
):
raise ValueError(
f"param {param.name} should not has the regularizer attribute"
)
if param.stop_gradient:
continue
# update on slice
assert param.name in self._slice_params
param = self._slice_params[param.name]
grad_var = param._grad_ivar()
if hasattr(param, "main_grad") and param.main_grad is not None:
grad_var = param.main_grad
if grad_var is not None:
params_grads.append((param, grad_var))

self._apply_optimize(
loss=None,
startup_program=None,
params_grads=params_grads,
)

# sync parameters across sharding ranks
self._sharding_sync_parameters()

@framework.dygraph_only
def set_state_dict(self, state_dict):
inner_state = {}
parameters = self._parameter_list

if "LR_Scheduler" in state_dict:
inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")

if "master_weights" in state_dict:
master = state_dict.pop("master_weights")
inner_state["master_weights"] = {}
for p in parameters:
for k, v in master.items():
if p.name == k:
v.name = self._inner_opt._gen_master_weight_var_name(p)
inner_state["master_weights"][k] = v

for p in parameters:
for k, v in state_dict.items():
if p.name in k:
inner_state[k] = v

self._inner_opt.set_state_dict(inner_state)

def _set_inner_opt_attr(self, attr_name, value):
inner_opt = self._inner_opt
inner_opt_name = '_inner_opt'
if not isinstance(attr_name, str):
raise TypeError(
f"attr_name should be str type, but is {type(attr_name)}"
)
while hasattr(inner_opt, attr_name):
setattr(inner_opt, attr_name, value)
inner_opt = getattr(inner_opt, inner_opt_name, None)
if inner_opt is None:
break

def __getattr__(self, item):
return getattr(self._inner_opt, item)
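
To make the new flow concrete (not part of this diff): per step, reduce_gradients() reduce-scatters and scales gradients through the FusedCommBuffers so each sharding rank owns one slice, the inner optimizer updates only that rank's slice parameters, and _sharding_sync_parameters() re-assembles the full parameters across the group. A toy, framework-free sketch of that cycle; the SGD rule, the sizes, and the in-process simulation of ranks are illustrative assumptions:

import numpy as np

world_size, lr = 4, 0.1
param = np.arange(8.0)                                       # full parameter, replicated on every rank
grads = [np.ones(8) * (r + 1) for r in range(world_size)]    # per-rank gradients

# reduce-scatter + scale: rank r keeps only the averaged gradient of its slice
slice_len = param.size // world_size
full_grad = np.sum(grads, axis=0) / world_size
grad_slices = [full_grad[r * slice_len:(r + 1) * slice_len] for r in range(world_size)]

# each rank updates just the slice it owns (the "slice param" created in _create_slice_param)
param_slices = [
    param[r * slice_len:(r + 1) * slice_len] - lr * grad_slices[r]
    for r in range(world_size)
]

# all-gather (sync_params on the comm buffers): every rank re-assembles the full parameter
param = np.concatenate(param_slices)
print(param)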
