[AutoParallel] add dist_attr in data_parallel optimization (#49744)
* fix dist_attr in data_parallel optimization

* fix grad_clip pass when pp=2

* fix dist_attr
zhaoyinglia authored Feb 27, 2023
1 parent 3c12104 commit a36cdd6
Showing 5 changed files with 148 additions and 117 deletions.
python/paddle/distributed/auto_parallel/dist_op.py (18 changes: 16 additions, 2 deletions)
@@ -221,7 +221,14 @@ def __str__(self):
)

for arg_name in self.serial_op.desc.input_arg_names():
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
try:
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
except IndexError:
raise IndexError(
"There is not input var '{}''s dist_attr in current op '{}'".format(
arg_name, self.serial_op.desc.type()
)
)
if self.dist_attr.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
@@ -238,7 +245,14 @@
)

for arg_name in self.serial_op.desc.output_arg_names():
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
try:
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
except IndexError:
raise IndexError(
"There is not output var '{}''s dist_attr in current op '{}'".format(
arg_name, self.serial_op.desc.type()
)
)
if self.dist_attr.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
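For reference, the lookup-with-context pattern added above can be written as a small standalone helper. This is only an illustrative sketch: get_input_dims_mapping_or_raise is a hypothetical name, and it assumes, as the diff does, that the dist_attr object raises IndexError when no dims_mapping is recorded for the argument.

def get_input_dims_mapping_or_raise(dist_attr, op_type, arg_name):
    # Look up the dims_mapping for one input argument and, if the lookup
    # fails, re-raise with the op type and argument name for easier debugging.
    try:
        return dist_attr.get_input_dims_mapping(arg_name)
    except IndexError:
        raise IndexError(
            "There is no dist_attr for input var '{}' in op '{}'".format(
                arg_name, op_type
            )
        )

The same wrapper applies to get_output_dims_mapping; only the error text changes.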
python/paddle/distributed/auto_parallel/utils.py (33 changes: 23 additions, 10 deletions)
@@ -1426,9 +1426,6 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
def naive_set_dist_op_attr_for_program_by_mesh(
new_op, process_mesh, ctx, is_recompute=False
):
# hack to skip coalesce var for dist attr
if not is_recompute:
return
assert process_mesh is not None

new_op_dist_attr = OperatorDistAttr()
@@ -2314,15 +2311,31 @@ def insert_dependencies_for_vars(
},
outputs={"Out": post_vars},
)

# depend_op.desc.set_type("depend")
depend_op._set_attr(OP_ROLE_KEY, oprole)
# depend_op.desc.set_input("Dep", [first_var.name])
# self.desc.set_output(out_proto.name, out_arg_names)

naive_set_dist_op_attr_for_program_by_mesh(
depend_op, process_mesh, dist_context, is_recompute
)
# TODO: this condition can be removed once correct dist_attr is added for coalesce vars and ops in sharding_pass
if is_recompute or process_mesh != [-1]:
depend_op_dist_attr = OperatorDistAttr()
depend_op_dist_attr.impl_idx = 0
depend_op_dist_attr.impl_type = "default"
depend_op_dist_attr.process_mesh = process_mesh
depend_op_dist_attr.is_recompute = is_recompute
for input_varname in depend_op.desc.input_arg_names():
var = block.var(input_varname)
mapping = dist_context.get_tensor_dist_attr_for_program(
var
).dims_mapping
depend_op_dist_attr.set_input_dims_mapping(input_varname, mapping)
for output_varname in depend_op.desc.output_arg_names():
var = block.var(output_varname)
mapping = dist_context.get_tensor_dist_attr_for_program(
var
).dims_mapping
depend_op_dist_attr.set_output_dims_mapping(output_varname, mapping)
dist_context.set_op_dist_attr_for_program(
depend_op, depend_op_dist_attr
)

if op_namescope is not None:
depend_op._set_attr('op_namescope', "/{}".format(op_namescope))

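The block added to insert_dependencies_for_vars derives the depend op's dist_attr from the dist_attrs of the tensors it touches instead of calling naive_set_dist_op_attr_for_program_by_mesh. A sketch of that pattern as a reusable helper follows; the helper name set_op_dist_attr_from_tensors is hypothetical, while the OperatorDistAttr fields and dist_context calls are the ones used in the diff.

from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr


def set_op_dist_attr_from_tensors(op, block, dist_context, process_mesh, is_recompute=False):
    # Build an OperatorDistAttr whose per-argument dims_mappings are copied
    # from the dist_attrs of the op's input and output tensors.
    op_dist_attr = OperatorDistAttr()
    op_dist_attr.impl_idx = 0
    op_dist_attr.impl_type = "default"
    op_dist_attr.process_mesh = process_mesh
    op_dist_attr.is_recompute = is_recompute
    for in_name in op.desc.input_arg_names():
        var = block.var(in_name)
        mapping = dist_context.get_tensor_dist_attr_for_program(var).dims_mapping
        op_dist_attr.set_input_dims_mapping(in_name, mapping)
    for out_name in op.desc.output_arg_names():
        var = block.var(out_name)
        mapping = dist_context.get_tensor_dist_attr_for_program(var).dims_mapping
        op_dist_attr.set_output_dims_mapping(out_name, mapping)
    dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
    return op_dist_attr

This only works when every input and output tensor already has a dist_attr in the dist_context, which is why the next file also gives the coalesce var a TensorDistAttr before any op that consumes it is processed.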
python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
@@ -15,10 +15,15 @@
from collections import OrderedDict

import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from paddle.distributed.auto_parallel.operators.common import (
is_data_parallel_reduce_op,
is_data_parallel_scale_op,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import (
find_higher_order_backward_op,
get_var_numel,
@@ -463,6 +468,21 @@ def _update_program(self, grad_groups):
group.coalesce_var = group.gradients[0]
continue

ref_process_mesh = set()
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
grad_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(grad_)
)
ref_process_mesh.update(
set(grad_dist_attr.process_mesh.process_ids)
)

shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))

# create coalesce tensor
group.coalesce_var = block.create_var(
name=unique_name.generate(
@@ -473,6 +493,13 @@
stop_gradient=True,
)

tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh))
tensor_dist_attr.dims_mapping = []
self.dist_context.set_tensor_dist_attr_for_program(
group.coalesce_var, tensor_dist_attr
)

# update allreduce & scale op
if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx]
@@ -492,11 +519,27 @@
), "should found c_allreduce_sum op but found {}".format(
str(allreduce_op)
)
allreduce_op._rename_input(
allreduce_op.input_arg_names[0], group.coalesce_var.name
allreduce_op_dist_attr = (
self.dist_context.get_op_dist_attr_for_program(allreduce_op)
)
old_in_name = allreduce_op.input_arg_names[0]
new_in_name = group.coalesce_var.name
allreduce_op._rename_input(old_in_name, new_in_name)
input_dist_attr = allreduce_op_dist_attr.get_input_dist_attr(
old_in_name
)
allreduce_op._rename_output(
allreduce_op.output_arg_names[0], group.coalesce_var.name
allreduce_op_dist_attr.set_input_dist_attr(
new_in_name, input_dist_attr
)

old_out_name = allreduce_op.output_arg_names[0]
new_out_name = group.coalesce_var.name
allreduce_op._rename_output(old_out_name, new_out_name)
out_dist_attr = allreduce_op_dist_attr.get_output_dist_attr(
old_out_name
)
allreduce_op_dist_attr.set_output_dist_attr(
new_out_name, out_dist_attr
)

# remove unused op
@@ -512,15 +555,8 @@ def _update_program(self, grad_groups):
block._remove_op(idx, False)

# insert coalesce op
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))

grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync(
coalesce_op = block._insert_op_without_sync(
group.coalesce_op_idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
Expand All @@ -538,8 +574,32 @@ def _update_program(self, grad_groups):
},
)

op_dist_attr = OperatorDistAttr()
op_dist_attr.impl_idx = 0
op_dist_attr.impl_type = "default"
op_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh))
for in_name in coalesce_op.input_arg_names:
in_var = block.var(in_name)
in_var_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(in_var)
)
op_dist_attr.set_input_dims_mapping(
in_name, in_var_dist_attr.dims_mapping
)
for out_name in coalesce_op.output_arg_names:
out_var = block.var(out_name)
out_var_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(out_var)
)
op_dist_attr.set_output_dims_mapping(
out_name, out_var_dist_attr.dims_mapping
)

self.dist_context.set_op_dist_attr_for_program(
coalesce_op, op_dist_attr
)

block._sync_with_cpp()
# TODO update dist attr
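The coalesce var created above is given a TensorDistAttr whose process mesh is the union of the process ids of all gradients in the group, with an empty dims_mapping, matching the pass above. A minimal sketch of that computation, assuming the TensorDistAttr and ProcessMesh classes imported in the diff; grad_dist_attrs is a hypothetical list standing in for the per-gradient dist_attrs fetched from the dist_context.

from paddle.distributed.auto_parallel.dist_attribute import TensorDistAttr
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh


def build_coalesce_dist_attr(grad_dist_attrs):
    # Union of the process ids of every gradient that will be fused into
    # the coalesce buffer.
    ref_process_ids = set()
    for dist_attr in grad_dist_attrs:
        ref_process_ids.update(dist_attr.process_mesh.process_ids)

    coalesce_dist_attr = TensorDistAttr()
    coalesce_dist_attr.process_mesh = ProcessMesh(sorted(ref_process_ids))
    coalesce_dist_attr.dims_mapping = []  # empty dims_mapping, as in the pass above (buffer is not sharded)
    return coalesce_dist_attr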

def _add_dependencies(self, grad_groups):
# NOTE Currently, auto_parallel need to adopt for two executors: Sequential executor (old exe) and Graph based
@@ -551,22 +611,12 @@ def _add_dependencies(self, grad_groups):
block = default_main_program().global_block()

# Build maps
vars_to_coalesce_map = {}
coalesce_to_vars_map = {}

for group in grad_groups:
grad_names = []
coalesce_name = group.coalesce_var.name
for grad in group.gradients:
vars_to_coalesce_map[grad.name] = coalesce_name
grad_names.append(grad.name)
coalesce_to_vars_map[coalesce_name] = grad_names
coalesce_to_vars_map[group.coalesce_var.name] = group

# analyze dependencies
# Record ONLY the last grad that generated before allreduce
# NOTE need to be update when we allow multiple calc stream for backward calc
not_sync_coalesces = []
prior_allreduce_deps = {}
dep_map = {}
for idx, op in reversed(list(enumerate(block.ops))):
if is_forward_op(op):
break
@@ -575,86 +625,41 @@ def _add_dependencies(self, grad_groups):

if is_data_parallel_reduce_op(op):
coalesce_var_name = op.output_arg_names[0]

# NOTE only add extra deps for fused tensor, other tensor rely on
# data flow analysis of executor.
if self.coalesce_prefix in coalesce_var_name:
prior_allreduce_deps[coalesce_var_name] = [
idx,
None,
coalesce_var_name,
]
not_sync_coalesces.append(coalesce_var_name)
continue

for out_name in op.output_arg_names:
var_name = vars_to_coalesce_map.get(out_name, None)
if var_name in not_sync_coalesces:
prior_allreduce_deps[var_name][1] = out_name
not_sync_coalesces.remove(var_name)
assert (
len(not_sync_coalesces) == 0
), "Unexpected: {} has NOT been add prior Dep before allreduce.".format(
not_sync_coalesces
)

# Record ONLY the first grad that used after allreduce
# NOTE need to be update when we allow multiple calc stream for backward calc
not_sync_coalesces = []
post_allreduce_deps = {}
for idx, op in enumerate(block.ops):
if is_forward_op(op):
continue

if is_data_parallel_reduce_op(op):
coalesce_var_name = op.input_arg_names[0]
if self.coalesce_prefix in coalesce_var_name:
post_allreduce_deps[coalesce_var_name] = [
None,
coalesce_var_name,
None,
group = coalesce_to_vars_map[coalesce_var_name]
dep_map[idx] = [
(
idx,
group.gradients[-1],
group.coalesce_var,
op.attr(OP_ROLE_KEY),
)
]
not_sync_coalesces.append(coalesce_var_name)
continue

for out_name in op.input_arg_names:
var_name = vars_to_coalesce_map.get(out_name, None)
if var_name in not_sync_coalesces:
post_allreduce_deps[var_name][0] = idx
post_allreduce_deps[var_name][2] = out_name
not_sync_coalesces.remove(var_name)

assert (
len(not_sync_coalesces) == 0
), "Unexpected: {} has NOT been add post Dep after allreduce.".format(
not_sync_coalesces
)
dep_map[idx].append(
(
idx + 1,
group.coalesce_var,
group.gradients,
op.attr(OP_ROLE_KEY),
)
)

# Update program IR insert dependencise op
dep_var_pairs = []
for deps in [prior_allreduce_deps, post_allreduce_deps]:
for pair in deps.values():
dep_var_pairs.append(pair)

dep_var_pairs.sort(key=lambda x: x[0], reverse=True)
for idx, prior_name, post_name in dep_var_pairs:
prior_var = block.var(prior_name)
post_var = block.var(post_name)
depend_op = insert_dependencies_for_vars(
block,
idx,
prior_var,
post_var,
self.dist_context,
OpRole.Backward,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesce var
is_recompute=False,
sync=False,
op_namescope="data_parallel_overlap_dep",
)
depend_op.dist_attr.execution_stream = self.gradient_sync_stream
# insert dependency op
indice = sorted(list(dep_map.keys()), reverse=True)
for i in indice:
for idx, prior_vars, post_vars, op_role in dep_map[i][::-1]:
depend_op = insert_dependencies_for_vars(
block,
idx,
prior_vars,
post_vars,
self.dist_context,
op_role,
is_recompute=False,
sync=False,
op_namescope="data_parallel_overlap_dep",
)
depend_op.dist_attr.execution_stream = self.gradient_sync_stream
block._sync_with_cpp()

# remove naive synchronization & assign allreduce stream
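The rewritten _add_dependencies keeps a dep_map keyed by allreduce index: for every fused allreduce it records a dependency from the last gradient in the group to the coalesce var (inserted at the allreduce index) and one from the coalesce var back to the gradients (inserted right after), then performs the insertions in descending index order so that positions recorded earlier stay valid. The standalone sketch below reproduces only that ordering logic with plain Python data; the variable names and toy indices are illustrative.

def plan_dependency_inserts(dep_map):
    # dep_map: {allreduce_idx: [(insert_idx, prior_vars, post_vars, op_role), ...]}
    # Return the recorded dependencies in the order they should be inserted:
    # highest block index first, so inserting one depend op never shifts the
    # indices of the inserts that are still pending.
    plan = []
    for allreduce_idx in sorted(dep_map.keys(), reverse=True):
        for dep in dep_map[allreduce_idx][::-1]:
            plan.append(dep)
    return plan


# Toy usage: two fused allreduces sitting at block indices 5 and 9.
dep_map = {
    5: [(5, "grad_2", "coalesce_0", "Backward"),
        (6, "coalesce_0", ["grad_1", "grad_2"], "Backward")],
    9: [(9, "grad_4", "coalesce_1", "Backward"),
        (10, "coalesce_1", ["grad_3", "grad_4"], "Backward")],
}
for insert_idx, prior_vars, post_vars, op_role in plan_dependency_inserts(dep_map):
    print(insert_idx, prior_vars, "->", post_vars, op_role)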
python/paddle/distributed/passes/auto_parallel_grad_clip.py (2 changes: 2 additions, 0 deletions)
@@ -254,6 +254,8 @@ def _is_pure_data_parallel(self):
"c_allreduce_sum",
] and not is_data_parallel_reduce_op(op):
return False
if op.type in ["send_v2", "recv_v2"]:
return False

return True

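The two added lines make _is_pure_data_parallel bail out as soon as a pipeline send/recv op appears, which is what the "fix grad_clip pass when pp=2" item in the commit message refers to. Below is a self-contained sketch of that kind of screening over plain op-type strings; the surrounding whitelist logic of the real pass is omitted here.

PIPELINE_COMM_OPS = {"send_v2", "recv_v2"}


def is_pure_data_parallel(op_types):
    # Any pipeline-parallel communication op means the program is not pure
    # data parallel, so the grad-clip rewrite should be skipped.
    for op_type in op_types:
        if op_type in PIPELINE_COMM_OPS:
            return False
    return True


# Toy usage
print(is_pure_data_parallel(["matmul_v2", "c_allreduce_sum"]))    # True
print(is_pure_data_parallel(["matmul_v2", "send_v2", "recv_v2"])) # False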
@@ -150,9 +150,6 @@ def _apply_single_impl(self, main_program, startup_program, context):
post_var,
self._dist_context,
OpRole.Optimize,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesc var
is_recompute=False,
sync=False,
op_namescope=op_namescope,
