Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2.0/cherrypick] cherry-pick Sharding PR:29518 #29593

Merged
merged 4 commits into from
Dec 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ message RecomputeConfig { repeated string checkpoints = 1; }

message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_group_size = 3 [ default = 8 ];
}

message AMPConfig {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,21 @@ def remove_cast_op(block, params, segment, offset):
return inserted_op_num

@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, nrings):
def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
"""
1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard
2. revise amp inifine grad checking for sharding
"""
# remove cast
for idx, op in reversed(list(enumerate(block.ops))):
if not FP16Utils.is_fp32_cast_op(block, op):
continue
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Input 'X' of check_finite_and_unscale must"
"be grads, but {} is not a grad".format(
input_name))
raise ValueError("Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad".format(
output_name))
if output_name in reduced_grads_to_param:
continue
if shard.has_param(param_name):
Expand Down Expand Up @@ -137,10 +141,12 @@ def prune_fp16(block, shard, reduced_grads_to_param, nrings):
type='c_allreduce_max',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_fp32},
attrs={'ring_id': 0,
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize})
comm_op_num = insert_sync_comm_ops(
block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])

comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
ring_id, [inf_var_fp32])

block._insert_op_without_sync(
update_loss_scaling_op_idx + 3 + comm_op_num,
type='cast',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@


class GradientClipHelper(object):
def __init__(self):
pass
def __init__(self, sharding_ring_id):
self.sharding_ring_id = sharding_ring_id

def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")

def prune_gradient_clip(self, block, shard):
"""
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
deperated_vars = set()
deperate_op_idx = set()
for idx, op in enumerate(block.ops):
Expand Down Expand Up @@ -75,8 +80,10 @@ def prune_gradient_clip(self, block, shard):
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
attrs={
'ring_id': self.sharding_ring_id,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
idx + 1,
type='c_sync_calc_stream',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def get_var_deps(self, var_name):
return None

def _build_deps(self, ):

for var_name in self._start_vars:
self._var_to_use_op[var_name] = []
self._var_to_generate_op[var_name] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ def is_opti_var(self, var_name):
return True
return False

def filter_grads(self, grads):
grads_in_shard = []
for grad in grads:
param = grad.split("@")[0]
if self.has_param(param):
grads_in_shard.append(grad)
return grads_in_shard


class ProgramSegment(object):
def __init__(self, block):
Expand Down
193 changes: 145 additions & 48 deletions python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,52 +78,137 @@ def check_broadcast(block):
return


def check_allreduce_sum(block):
def check_allreduce_sum(block, shard, dp_ring_id=-1):
"""
if a Var is allreduced, the op order should be:
- 0: op that generate Var
- 1: sync_calc
- 2: allreduce_sum op
- 3: sync_comm
- 4: op that use Var
the op order should be:
grad:
- 0: op that generate Var
- 1: sync_calc
- 2: allreduce_sum_sharding
- 3: sync_comm
- 4: allreuce_sum_dp (dp_grads)
- 5: sync_comm (dp_grads)
- 6: op that use Var (dp_grads & sum)
"""
var_status = {}
for op in block.ops:
vars_status = {}
dp_grads_status = {}
idx_last_grad_allreduce = -1
idx_amp_allreduce = -1
idx_gradient_clip_allreduce = -1
for idx, op in enumerate(block.ops):
if op.type == "c_allreduce_sum":
ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
var_status[var_name] = -1
param = var_name.split("@")[0]

assert 'sum' in var_name or ("@GRAD" in var_name)
if 'sum' in var_name or (not shard.has_param(param)):
vars_status[var_name] = -1
else:
dp_grads_status[var_name] = -1

if ring_id != 0:
assert shard.has_param(param)
assert ring_id == dp_ring_id

if "sum" in var_name:
idx_amp_allreduce = idx
elif "@GRAD":
idx_last_grad_allreduce = idx

if op.type == "c_allreduce_max":
idx_gradient_clip_allreduce = idx

for op in block.ops:
if op.type == "c_sync_calc_stream":
for var_name in var_status:
if var_name in var_status and var_status[var_name] == 0:
var_status[var_name] = 1
for var_name in vars_status:
if var_name in vars_status and vars_status[var_name] == 0:
vars_status[var_name] = 1
for var_name in dp_grads_status:
if var_name in dp_grads_status and dp_grads_status[
var_name] == 0:
dp_grads_status[var_name] = 1

elif op.type == "c_allreduce_sum":
var_name = op.desc.input_arg_names()[0]
if var_status[var_name] == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name))
if var_status[var_name] == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (var_status[var_name] == 1)
var_status[var_name] = 2
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
if var_name in vars_status:
_status = vars_status[var_name]
else:
_status = dp_grads_status[var_name]
if _status == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name))
if _status == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (_status == 1)
if var_name in vars_status:
vars_status[var_name] = 2
else:
dp_grads_status[var_name] = 2
else:
assert ring_id == dp_ring_id
param = var_name.split("@")[0]
assert shard.has_param(param)
assert dp_grads_status[var_name] == 3
dp_grads_status[var_name] = 4

elif op.type == "c_sync_comm_stream":
for var_name in op.desc.input_arg_names():
if var_name in var_status and var_status[var_name] == 2:
var_status[var_name] = 3
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
for var_name in op.desc.input_arg_names():
if var_name in vars_status:
assert vars_status[var_name] == 2
vars_status[var_name] = 3
elif var_name in dp_grads_status:
assert dp_grads_status[var_name] == 2
dp_grads_status[var_name] = 3
else:
for var_name in op.desc.input_arg_names():
param = var_name.split("@")[0]
assert ring_id == dp_ring_id
assert shard.has_param(param)
assert dp_grads_status[var_name] == 4
dp_grads_status[var_name] = 5
else:
for input_name in op.desc.input_arg_names():
if input_name in var_status:
if var_status[input_name] != 3:
if input_name in vars_status:
if vars_status[input_name] != 3:
raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".format(
var_name))
input_name))
if input_name in dp_grads_status:
if dp_ring_id == -1:
if dp_grads_status[input_name] != 3:
raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".
format(input_name))
else:
if dp_grads_status[input_name] != 5:
raise ValueError(
"The grad in shard should be allreduce and sync"
"twice before usage {}".format(input_name))

for output_name in op.desc.output_arg_names():
if output_name in var_status and \
var_status[output_name] == -1:
var_status[output_name] = 0
if output_name in vars_status and \
vars_status[output_name] == -1:
vars_status[output_name] = 0
if output_name in dp_grads_status and \
dp_grads_status[output_name] == -1:
dp_grads_status[output_name] = 0

# check sharding with amp
if idx_amp_allreduce != -1:
assert idx_amp_allreduce > idx_last_grad_allreduce

# check sharding with gradient_clip_by_global_norm
if idx_gradient_clip_allreduce != -1:
assert idx_gradient_clip_allreduce > idx_last_grad_allreduce

return


Expand Down Expand Up @@ -155,20 +240,34 @@ def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
return


def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
"""
_insert_sync_comm_ops
insert sync_comm_op for single var
"""
op_role = get_valid_op_role(block, insert_idx)
for i in range(nrings):
block._insert_op_without_sync(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': i,
OP_ROLE_KEY: op_role})
return nrings
block._insert_op_without_sync(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: op_role})
return 1


def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
"""
insert sync_comm_op for vars
"""
op_role = get_valid_op_role(block, insert_idx)
block._insert_op_without_sync(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': int(ring_id),
OP_ROLE_KEY: op_role})
return 1


def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
Expand Down Expand Up @@ -210,31 +309,28 @@ def insert_cast_ops(block, insert_idx, cast_ops):
return


def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
"""
_add_allreduce_ops
"""
ring_id = -1
for var in allreduce_vars:
ring_id = (ring_id + 1) % nrings
block._insert_op_without_sync(
insert_idx,
type='c_allreduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})

return


def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
"""
_add_broadcast_ops
"""
ring_id = -1
op_role = get_valid_op_role(block, insert_idx)
for broadcast_name, root_device in broadcast2root:
ring_id = (ring_id + 1) % nrings
block._insert_op_without_sync(
insert_idx,
type='c_broadcast',
Expand All @@ -245,6 +341,7 @@ def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
'root': root_device,
OP_ROLE_KEY: op_role
})

return


Expand Down
Loading