[2.0/cherrypick] cherry-pick Sharding PR:29518 (#29593)
* Sharding add hybrid-dp feature

* update sharding in distributed_strategy

* update sharding unittest

* revise code format for sharding
JZ-LIANG authored Dec 16, 2020
1 parent d82b030 commit ab04bf0
Showing 7 changed files with 317 additions and 93 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
100644 → 100755
@@ -26,6 +26,8 @@ message RecomputeConfig { repeated string checkpoints = 1; }

message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_group_size = 3 [ default = 8 ];
}

message AMPConfig {
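For context, the two fields added to ShardingConfig above are meant to be set from the Python side through fleet's DistributedStrategy. The snippet below is a minimal sketch of enabling them; the sharding_configs keys mirror the proto field names, the values are illustrative, and the surrounding training setup (program, optimizer, fleet.init) is assumed rather than taken from this commit.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
# keys mirror the ShardingConfig proto fields; values here are illustrative
strategy.sharding_configs = {
    "fuse_broadcast_MB": 32.0,
    "hybrid_dp": True,          # new field: run data parallelism across sharding groups
    "sharding_group_size": 8,   # new field: number of ranks inside each sharding group
}
# the strategy is then passed to fleet.distributed_optimizer(...) in the usual
# static-graph training setup (not shown here)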
@@ -71,17 +71,21 @@ def remove_cast_op(block, params, segment, offset):
return inserted_op_num

@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, nrings):
def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
"""
1. prune all cast_fp32_to_fp16 ops if the param does not belong to this shard
2. revise amp infinite grad checking for sharding
"""
# remove cast
for idx, op in reversed(list(enumerate(block.ops))):
if not FP16Utils.is_fp32_cast_op(block, op):
continue
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Input 'X' of check_finite_and_unscale must"
"be grads, but {} is not a grad".format(
input_name))
raise ValueError("Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad".format(
output_name))
if output_name in reduced_grads_to_param:
continue
if shard.has_param(param_name):
@@ -137,10 +141,12 @@ def prune_fp16(block, shard, reduced_grads_to_param, nrings):
type='c_allreduce_max',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_fp32},
attrs={'ring_id': 0,
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize})
comm_op_num = insert_sync_comm_ops(
block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])

comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
ring_id, [inf_var_fp32])

block._insert_op_without_sync(
update_loss_scaling_op_idx + 3 + comm_op_num,
type='cast',
@@ -16,14 +16,19 @@


class GradientClipHelper(object):
def __init__(self):
pass
def __init__(self, sharding_ring_id):
self.sharding_ring_id = sharding_ring_id

def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")

def prune_gradient_clip(self, block, shard):
"""
prune gradient_clip related ops for params that do not belong to the current shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
deperated_vars = set()
deperate_op_idx = set()
for idx, op in enumerate(block.ops):
@@ -75,8 +80,10 @@ def prune_gradient_clip(self, block, shard):
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
attrs={
'ring_id': self.sharding_ring_id,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
idx + 1,
type='c_sync_calc_stream',
@@ -43,6 +43,7 @@ def get_var_deps(self, var_name):
return None

def _build_deps(self, ):

for var_name in self._start_vars:
self._var_to_use_op[var_name] = []
self._var_to_generate_op[var_name] = []
@@ -124,6 +124,14 @@ def is_opti_var(self, var_name):
return True
return False

def filter_grads(self, grads):
grads_in_shard = []
for grad in grads:
param = grad.split("@")[0]
if self.has_param(param):
grads_in_shard.append(grad)
return grads_in_shard


class ProgramSegment(object):
def __init__(self, block):
193 changes: 145 additions & 48 deletions python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
@@ -78,52 +78,137 @@ def check_broadcast(block):
return


def check_allreduce_sum(block):
def check_allreduce_sum(block, shard, dp_ring_id=-1):
"""
if a Var is allreduced, the op order should be:
- 0: op that generate Var
- 1: sync_calc
- 2: allreduce_sum op
- 3: sync_comm
- 4: op that use Var
the op order should be:
grad:
- 0: op that generate Var
- 1: sync_calc
- 2: allreduce_sum_sharding
- 3: sync_comm
- 4: allreduce_sum_dp (dp_grads)
- 5: sync_comm (dp_grads)
- 6: op that use Var (dp_grads & sum)
"""
var_status = {}
for op in block.ops:
vars_status = {}
dp_grads_status = {}
idx_last_grad_allreduce = -1
idx_amp_allreduce = -1
idx_gradient_clip_allreduce = -1
for idx, op in enumerate(block.ops):
if op.type == "c_allreduce_sum":
ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
var_status[var_name] = -1
param = var_name.split("@")[0]

assert 'sum' in var_name or ("@GRAD" in var_name)
if 'sum' in var_name or (not shard.has_param(param)):
vars_status[var_name] = -1
else:
dp_grads_status[var_name] = -1

if ring_id != 0:
assert shard.has_param(param)
assert ring_id == dp_ring_id

if "sum" in var_name:
idx_amp_allreduce = idx
elif "@GRAD":
idx_last_grad_allreduce = idx

if op.type == "c_allreduce_max":
idx_gradient_clip_allreduce = idx

for op in block.ops:
if op.type == "c_sync_calc_stream":
for var_name in var_status:
if var_name in var_status and var_status[var_name] == 0:
var_status[var_name] = 1
for var_name in vars_status:
if var_name in vars_status and vars_status[var_name] == 0:
vars_status[var_name] = 1
for var_name in dp_grads_status:
if var_name in dp_grads_status and dp_grads_status[
var_name] == 0:
dp_grads_status[var_name] = 1

elif op.type == "c_allreduce_sum":
var_name = op.desc.input_arg_names()[0]
if var_status[var_name] == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name))
if var_status[var_name] == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (var_status[var_name] == 1)
var_status[var_name] = 2
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
if var_name in vars_status:
_status = vars_status[var_name]
else:
_status = dp_grads_status[var_name]
if _status == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name))
if _status == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (_status == 1)
if var_name in vars_status:
vars_status[var_name] = 2
else:
dp_grads_status[var_name] = 2
else:
assert ring_id == dp_ring_id
param = var_name.split("@")[0]
assert shard.has_param(param)
assert dp_grads_status[var_name] == 3
dp_grads_status[var_name] = 4

elif op.type == "c_sync_comm_stream":
for var_name in op.desc.input_arg_names():
if var_name in var_status and var_status[var_name] == 2:
var_status[var_name] = 3
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
for var_name in op.desc.input_arg_names():
if var_name in vars_status:
assert vars_status[var_name] == 2
vars_status[var_name] = 3
elif var_name in dp_grads_status:
assert dp_grads_status[var_name] == 2
dp_grads_status[var_name] = 3
else:
for var_name in op.desc.input_arg_names():
param = var_name.split("@")[0]
assert ring_id == dp_ring_id
assert shard.has_param(param)
assert dp_grads_status[var_name] == 4
dp_grads_status[var_name] = 5
else:
for input_name in op.desc.input_arg_names():
if input_name in var_status:
if var_status[input_name] != 3:
if input_name in vars_status:
if vars_status[input_name] != 3:
raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".format(
var_name))
input_name))
if input_name in dp_grads_status:
if dp_ring_id == -1:
if dp_grads_status[input_name] != 3:
raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".
format(input_name))
else:
if dp_grads_status[input_name] != 5:
raise ValueError(
"The grad in shard should be allreduce and sync"
"twice before usage {}".format(input_name))

for output_name in op.desc.output_arg_names():
if output_name in var_status and \
var_status[output_name] == -1:
var_status[output_name] = 0
if output_name in vars_status and \
vars_status[output_name] == -1:
vars_status[output_name] = 0
if output_name in dp_grads_status and \
dp_grads_status[output_name] == -1:
dp_grads_status[output_name] = 0

# check sharding with amp
if idx_amp_allreduce != -1:
assert idx_amp_allreduce > idx_last_grad_allreduce

# check sharding with gradient_clip_by_global_norm
if idx_gradient_clip_allreduce != -1:
assert idx_gradient_clip_allreduce > idx_last_grad_allreduce

return


@@ -155,20 +240,34 @@ def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
return


def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
"""
_insert_sync_comm_ops
insert sync_comm_op for single var
"""
op_role = get_valid_op_role(block, insert_idx)
for i in range(nrings):
block._insert_op_without_sync(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': i,
OP_ROLE_KEY: op_role})
return nrings
block._insert_op_without_sync(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: op_role})
return 1


def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
"""
insert sync_comm_op for vars
"""
op_role = get_valid_op_role(block, insert_idx)
block._insert_op_without_sync(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': int(ring_id),
OP_ROLE_KEY: op_role})
return 1


def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
@@ -210,31 +309,28 @@ def insert_cast_ops(block, insert_idx, cast_ops):
return


def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
"""
_add_allreduce_ops
"""
ring_id = -1
for var in allreduce_vars:
ring_id = (ring_id + 1) % nrings
block._insert_op_without_sync(
insert_idx,
type='c_allreduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})

return


def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
"""
_add_broadcast_ops
"""
ring_id = -1
op_role = get_valid_op_role(block, insert_idx)
for broadcast_name, root_device in broadcast2root:
ring_id = (ring_id + 1) % nrings
block._insert_op_without_sync(
insert_idx,
type='c_broadcast',
@@ -245,6 +341,7 @@ def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
'root': root_device,
OP_ROLE_KEY: op_role
})

return


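As a reading aid for the op-order contract documented in check_allreduce_sum above, the following toy walkthrough replays the documented grad sequence with a simple status counter. It is an illustration only, not Paddle code: the op names are simplified strings, and the -1..5 status progression mirrors the vars_status / dp_grads_status bookkeeping in the diff.

# Toy illustration of the grad op-order contract from check_allreduce_sum:
# -1 not generated, 0 generated, 1 after sync_calc, 2 after sharding allreduce,
#  3 after sharding sync_comm, 4 after dp allreduce, 5 after dp sync_comm.
EXPECTED_ORDER = [
    "generate",            # op that produces the grad
    "c_sync_calc_stream",
    "c_allreduce_sum",     # sharding ring
    "c_sync_comm_stream",  # sharding ring
    "c_allreduce_sum",     # dp ring (hybrid_dp only)
    "c_sync_comm_stream",  # dp ring (hybrid_dp only)
    "use",                 # op that consumes the fully reduced grad
]

def replay(ops):
    """Advance a status counter through the toy op stream, checking the order."""
    status = -1
    for op in ops:
        if op == "generate":
            assert status == -1
            status = 0
        elif op == "c_sync_calc_stream":
            assert status == 0
            status = 1
        elif op == "c_allreduce_sum":
            assert status in (1, 3)   # sharding ring first, then dp ring
            status = 2 if status == 1 else 4
        elif op == "c_sync_comm_stream":
            assert status in (2, 4)
            status = 3 if status == 2 else 5
        elif op == "use":
            assert status == 5        # grad must be reduced and synced on both rings
    return status

assert replay(EXPECTED_ORDER) == 5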
(diffs for the remaining changed files are not shown)
