[AutoParallel] add dist_attr in data_parallel optimization #49744

Merged · 7 commits · Feb 27, 2023
Changes from all commits
18 changes: 16 additions & 2 deletions python/paddle/distributed/auto_parallel/dist_op.py
@@ -221,7 +221,14 @@ def __str__(self):
)

for arg_name in self.serial_op.desc.input_arg_names():
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
try:
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
except IndexError:
raise IndexError(
"There is not input var '{}''s dist_attr in current op '{}'".format(
arg_name, self.serial_op.desc.type()
)
)
if self.dist_attr.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
@@ -238,7 +245,14 @@ def __str__(self):
)

for arg_name in self.serial_op.desc.output_arg_names():
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
try:
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
except IndexError:
raise IndexError(
"There is not output var '{}''s dist_attr in current op '{}'".format(
arg_name, self.serial_op.desc.type()
)
)
if self.dist_attr.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
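A minimal standalone sketch of the pattern introduced above (plain Python; the helper name and its duck-typed parameters are illustrative, not Paddle APIs). The bare IndexError from the dims-mapping lookup is re-raised with the variable name and op type attached, so a failing __str__ call points directly at the op whose dist_attr entry is missing:

def lookup_dims_mapping(dist_attr, op_type, arg_name, is_input=True):
    # choose the input- or output-side getter on the dist_attr object
    getter = (
        dist_attr.get_input_dims_mapping
        if is_input
        else dist_attr.get_output_dims_mapping
    )
    try:
        return getter(arg_name)
    except IndexError:
        kind = "input" if is_input else "output"
        # re-raise with context: which var of which op has no dist_attr
        raise IndexError(
            "no dist_attr for {} var '{}' in op '{}'".format(
                kind, arg_name, op_type
            )
        )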
33 changes: 23 additions & 10 deletions python/paddle/distributed/auto_parallel/utils.py
@@ -1424,9 +1424,6 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
def naive_set_dist_op_attr_for_program_by_mesh(
new_op, process_mesh, ctx, is_recompute=False
):
# hack to skip coalesce var for dist attr
if not is_recompute:
return
assert process_mesh is not None

new_op_dist_attr = OperatorDistAttr()
@@ -2312,15 +2309,31 @@ def insert_dependencies_for_vars(
},
outputs={"Out": post_vars},
)

# depend_op.desc.set_type("depend")
depend_op._set_attr(OP_ROLE_KEY, oprole)
# depend_op.desc.set_input("Dep", [first_var.name])
# self.desc.set_output(out_proto.name, out_arg_names)

naive_set_dist_op_attr_for_program_by_mesh(
depend_op, process_mesh, dist_context, is_recompute
)
# TODO: this condition can be removed once correct dist_attr is added for coalesce vars and ops in sharding_pass
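# (process_mesh == [-1] is the sentinel that callers previously passed to skip
# dist_attr initialization for coalesce vars; see the deleted
# "process_mesh=[-1]  # hack to avoid initialize the dist attr ..." arguments later in this diff)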
if is_recompute or process_mesh != [-1]:
depend_op_dist_attr = OperatorDistAttr()
depend_op_dist_attr.impl_idx = 0
depend_op_dist_attr.impl_type = "default"
depend_op_dist_attr.process_mesh = process_mesh
depend_op_dist_attr.is_recompute = is_recompute
for input_varname in depend_op.desc.input_arg_names():
var = block.var(input_varname)
mapping = dist_context.get_tensor_dist_attr_for_program(
var
).dims_mapping
depend_op_dist_attr.set_input_dims_mapping(input_varname, mapping)
for output_varname in depend_op.desc.output_arg_names():
var = block.var(output_varname)
mapping = dist_context.get_tensor_dist_attr_for_program(
var
).dims_mapping
depend_op_dist_attr.set_output_dims_mapping(output_varname, mapping)
dist_context.set_op_dist_attr_for_program(
depend_op, depend_op_dist_attr
)

if op_namescope is not None:
depend_op._set_attr('op_namescope', "/{}".format(op_namescope))

@@ -15,10 +15,15 @@
from collections import OrderedDict

import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from paddle.distributed.auto_parallel.operators.common import (
is_data_parallel_reduce_op,
is_data_parallel_scale_op,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import (
find_higher_order_backward_op,
get_var_numel,
@@ -463,6 +468,21 @@ def _update_program(self, grad_groups):
group.coalesce_var = group.gradients[0]
continue

ref_process_mesh = set()
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
grad_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(grad_)
)
ref_process_mesh.update(
set(grad_dist_attr.process_mesh.process_ids)
)

shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))
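# e.g. (hypothetical shapes) two gradients of shape [4, 3] and [5] give
# concated_shapes = [4, 3, 5] and concated_ranks = [2, 1], the layout metadata
# passed to the coalesce_tensor op inserted below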

# create coalesce tensor
group.coalesce_var = block.create_var(
name=unique_name.generate(
@@ -473,6 +493,13 @@
stop_gradient=True,
)

tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh))
tensor_dist_attr.dims_mapping = []
self.dist_context.set_tensor_dist_attr_for_program(
group.coalesce_var, tensor_dist_attr
)

# update allreduce & scale op
if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx]
@@ -492,11 +519,27 @@
), "should found c_allreduce_sum op but found {}".format(
str(allreduce_op)
)
allreduce_op._rename_input(
allreduce_op.input_arg_names[0], group.coalesce_var.name
allreduce_op_dist_attr = (
self.dist_context.get_op_dist_attr_for_program(allreduce_op)
)
old_in_name = allreduce_op.input_arg_names[0]
new_in_name = group.coalesce_var.name
allreduce_op._rename_input(old_in_name, new_in_name)
input_dist_attr = allreduce_op_dist_attr.get_input_dist_attr(
old_in_name
)
allreduce_op._rename_output(
allreduce_op.output_arg_names[0], group.coalesce_var.name
allreduce_op_dist_attr.set_input_dist_attr(
new_in_name, input_dist_attr
)

old_out_name = allreduce_op.output_arg_names[0]
new_out_name = group.coalesce_var.name
allreduce_op._rename_output(old_out_name, new_out_name)
out_dist_attr = allreduce_op_dist_attr.get_output_dist_attr(
old_out_name
)
allreduce_op_dist_attr.set_output_dist_attr(
new_out_name, out_dist_attr
)
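# i.e. after renaming, the per-gradient input/output dist_attrs are re-registered
# under the coalesce var's name, so the fused c_allreduce_sum keeps a complete dist_attr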

# remove unused op
@@ -512,15 +555,8 @@
block._remove_op(idx, False)

# insert coalesce op
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))

grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync(
coalesce_op = block._insert_op_without_sync(
group.coalesce_op_idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
@@ -538,8 +574,32 @@
},
)

op_dist_attr = OperatorDistAttr()
op_dist_attr.impl_idx = 0
op_dist_attr.impl_type = "default"
op_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh))
for in_name in coalesce_op.input_arg_names:
in_var = block.var(in_name)
in_var_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(in_var)
)
op_dist_attr.set_input_dims_mapping(
in_name, in_var_dist_attr.dims_mapping
)
for out_name in coalesce_op.output_arg_names:
out_var = block.var(out_name)
out_var_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(out_var)
)
op_dist_attr.set_output_dims_mapping(
out_name, out_var_dist_attr.dims_mapping
)

self.dist_context.set_op_dist_attr_for_program(
coalesce_op, op_dist_attr
)

block._sync_with_cpp()
# TODO update dist attr

def _add_dependencies(self, grad_groups):
# NOTE Currently, auto_parallel needs to adapt to two executors: Sequential executor (old exe) and Graph based
@@ -551,22 +611,12 @@ def _add_dependencies(self, grad_groups):
block = default_main_program().global_block()

# Build maps
vars_to_coalesce_map = {}
coalesce_to_vars_map = {}

for group in grad_groups:
grad_names = []
coalesce_name = group.coalesce_var.name
for grad in group.gradients:
vars_to_coalesce_map[grad.name] = coalesce_name
grad_names.append(grad.name)
coalesce_to_vars_map[coalesce_name] = grad_names
coalesce_to_vars_map[group.coalesce_var.name] = group

# analyze dependencies
# Record ONLY the last grad that generated before allreduce
# NOTE need to be update when we allow multiple calc stream for backward calc
not_sync_coalesces = []
prior_allreduce_deps = {}
dep_map = {}
for idx, op in reversed(list(enumerate(block.ops))):
if is_forward_op(op):
break
@@ -575,86 +625,41 @@

if is_data_parallel_reduce_op(op):
coalesce_var_name = op.output_arg_names[0]

# NOTE only add extra deps for the fused tensor; other tensors rely on
# the data flow analysis of the executor.
if self.coalesce_prefix in coalesce_var_name:
prior_allreduce_deps[coalesce_var_name] = [
idx,
None,
coalesce_var_name,
]
not_sync_coalesces.append(coalesce_var_name)
continue

for out_name in op.output_arg_names:
var_name = vars_to_coalesce_map.get(out_name, None)
if var_name in not_sync_coalesces:
prior_allreduce_deps[var_name][1] = out_name
not_sync_coalesces.remove(var_name)
assert (
len(not_sync_coalesces) == 0
), "Unexpected: {} has NOT been add prior Dep before allreduce.".format(
not_sync_coalesces
)

# Record ONLY the first grad that used after allreduce
# NOTE need to be update when we allow multiple calc stream for backward calc
not_sync_coalesces = []
post_allreduce_deps = {}
for idx, op in enumerate(block.ops):
if is_forward_op(op):
continue

if is_data_parallel_reduce_op(op):
coalesce_var_name = op.input_arg_names[0]
if self.coalesce_prefix in coalesce_var_name:
post_allreduce_deps[coalesce_var_name] = [
None,
coalesce_var_name,
None,
group = coalesce_to_vars_map[coalesce_var_name]
dep_map[idx] = [
(
idx,
group.gradients[-1],
group.coalesce_var,
op.attr(OP_ROLE_KEY),
)
]
not_sync_coalesces.append(coalesce_var_name)
continue

for out_name in op.input_arg_names:
var_name = vars_to_coalesce_map.get(out_name, None)
if var_name in not_sync_coalesces:
post_allreduce_deps[var_name][0] = idx
post_allreduce_deps[var_name][2] = out_name
not_sync_coalesces.remove(var_name)

assert (
len(not_sync_coalesces) == 0
), "Unexpected: {} has NOT been add post Dep after allreduce.".format(
not_sync_coalesces
)
dep_map[idx].append(
(
idx + 1,
group.coalesce_var,
group.gradients,
op.attr(OP_ROLE_KEY),
)
)
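# In effect, for each fused allreduce at position idx, dep_map[idx] records two
# explicit dependencies: (idx, last gradient -> coalesce_var), so the fused buffer
# waits for the final grad, and (idx + 1, coalesce_var -> gradients), so later
# consumers of the individual grads wait for the allreduce to finish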

# Update program IR insert dependencise op
dep_var_pairs = []
for deps in [prior_allreduce_deps, post_allreduce_deps]:
for pair in deps.values():
dep_var_pairs.append(pair)

dep_var_pairs.sort(key=lambda x: x[0], reverse=True)
for idx, prior_name, post_name in dep_var_pairs:
prior_var = block.var(prior_name)
post_var = block.var(post_name)
depend_op = insert_dependencies_for_vars(
block,
idx,
prior_var,
post_var,
self.dist_context,
OpRole.Backward,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesce var
is_recompute=False,
sync=False,
op_namescope="data_parallel_overlap_dep",
)
depend_op.dist_attr.execution_stream = self.gradient_sync_stream
# insert dependency op
indice = sorted(list(dep_map.keys()), reverse=True)
for i in indice:
for idx, prior_vars, post_vars, op_role in dep_map[i][::-1]:
depend_op = insert_dependencies_for_vars(
block,
idx,
prior_vars,
post_vars,
self.dist_context,
op_role,
is_recompute=False,
sync=False,
op_namescope="data_parallel_overlap_dep",
)
depend_op.dist_attr.execution_stream = self.gradient_sync_stream
block._sync_with_cpp()

# remove naive synchronization & assign allreduce stream
2 changes: 2 additions & 0 deletions python/paddle/distributed/passes/auto_parallel_grad_clip.py
@@ -254,6 +254,8 @@ def _is_pure_data_parallel(self):
"c_allreduce_sum",
] and not is_data_parallel_reduce_op(op):
return False
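# send_v2 / recv_v2 are point-to-point communication ops (e.g. pipeline stages
# exchanging tensors), so a program containing them is not treated as pure data parallel: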
if op.type in ["send_v2", "recv_v2"]:
return False

return True

@@ -150,9 +150,6 @@ def _apply_single_impl(self, main_program, startup_program, context):
post_var,
self._dist_context,
OpRole.Optimize,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesc var
is_recompute=False,
sync=False,
op_namescope=op_namescope,