[Auto Parallel-Performance] Sharding Comm Optimization (#48604)
* remove deps and prior comm

* grad comm fuse

* add deps for amp&global norm

* stage2 broadcast prior deps

* stage2 grad overlap

* stream_analyzer bugfix

* overlap enable

* dep op namescope

* depend support multiple inputs

* check finite deps

* stage2 param comm overlap

* Set kD2HStream

* grad comm hierarchical

* grad comm hierarchical

* new unitest

Co-authored-by: chenruibiao <[email protected]>
JZ-LIANG and From00 authored Jan 4, 2023
1 parent 852c8db commit 5592f8a
Showing 15 changed files with 1,350 additions and 198 deletions.
8 changes: 6 additions & 2 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -90,8 +90,12 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "degree", 8)
set_field_default_config(SHARDING, "overlap_grad_comm", False)
set_field_default_config(SHARDING, "bucket_size_numel", -1)
set_field_default_config(SHARDING, "enable_overlap", False)
set_field_default_config(SHARDING, "param_comm_stream_num", 1)
set_field_default_config(SHARDING, "grad_comm_stream_num", 1)
set_field_default_config(SHARDING, "param_bucket_size_numel", 1)
set_field_default_config(SHARDING, "grad_bucket_size_numel", 1)
set_field_default_config(SHARDING, "enable_hierarchical_comm", False)
set_field_default_config(SHARDING, "partition_algor", "greedy_even")
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])
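These fields replace the former overlap_grad_comm/bucket_size_numel pair. The following sketch (not part of the diff) shows how they might be set through the auto-parallel strategy API; the values and the commented-out engine setup are illustrative assumptions.

# Hypothetical usage sketch: enabling the new sharding communication options.
# Field names come from the defaults registered above; the surrounding
# engine/model setup is assumed and not part of this commit.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()

sharding = strategy.sharding
sharding.enable = True
sharding.stage = 2
sharding.degree = 8
sharding.enable_overlap = True              # overlap sharding comm with computation
sharding.param_comm_stream_num = 2          # dedicated streams for parameter broadcast
sharding.grad_comm_stream_num = 2           # dedicated streams for gradient reduce
sharding.param_bucket_size_numel = 2**20    # fuse parameter comm into ~1M-element buckets (illustrative)
sharding.grad_bucket_size_numel = 2**20     # fuse gradient comm into ~1M-element buckets (illustrative)
sharding.enable_hierarchical_comm = True    # two-level intra-node/inter-node communication

# engine = auto.Engine(model, loss, optimizer, strategy=strategy)
# engine.fit(dataset, epochs=1, batch_size=8)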
27 changes: 26 additions & 1 deletion python/paddle/distributed/auto_parallel/operators/common.py
@@ -45,6 +45,15 @@ class ParallelMode:
MoEParallel = "auto_parallel/moe_parallel"


class SyncMode:
"""
the synchronization mode for communication or auxiliary operators
"""

AmpFlagSync = "auto_parallel/amp_flag_synchorization"
GlobalNormSync = "auto_parallel/global_norm_synchorization"


def is_elementwise_op(op_type):
if op_type in _g_elementwise_ops:
return True
@@ -441,7 +450,7 @@ def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names):
dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name)
assert (
dims_mapping is not None
), "Unexception: dims_mapping of output [{}] of op [{}] is None".format(
), "Unexpected: dims_mapping of output [{}] of op [{}] is None".format(
grad_var.name, op_dist_attr.op_type
)
# NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
@@ -502,6 +511,22 @@ def is_data_parallel_reduce_op(op):
)


def is_amp_flag_sync_op(op):
return (
op.type == "c_allreduce_max"
and op.desc.has_attr("op_namescope")
and SyncMode.AmpFlagSync in op.desc.attr("op_namescope")
)


def is_global_norm_sync_op(op):
return (
op.type == "c_allreduce_sum"
and op.desc.has_attr("op_namescope")
and SyncMode.GlobalNormSync in op.desc.attr("op_namescope")
)


def is_in_backward_phase(dist_ctx):
# NOTE currently high-order differentiation in Paddle does NOT distinguish gradient computation operators
# in Forward phase and operators in Backward phase (both with op_role=1), which will mislead
@@ -24,6 +24,7 @@
from .common import (
DistributedOperatorImpl,
DistributedOperatorImplContainer,
SyncMode,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
@@ -166,6 +167,7 @@ def backward(ctx, *args, **kwargs):
OP_ROLE_KEY: OpRole.Optimize,
},
)
allreduce_op._set_attr('op_namescope', str('/') + SyncMode.AmpFlagSync)
cast_op2 = main_block.append_op(
type='cast',
inputs={'X': inf_var_int32},
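Tagging the allreduce with SyncMode.AmpFlagSync lets later passes locate it via the new predicates added to common.py above. A hedged sketch of such a lookup follows; the helper function and its name are illustrative, not part of the diff.

# Illustrative helper: collect the synchronization ops that the new predicates
# and the op_namescope tag above make identifiable in a built program.
from paddle.distributed.auto_parallel.operators.common import (
    is_amp_flag_sync_op,
    is_global_norm_sync_op,
)


def collect_sync_ops(main_program):
    """Return (amp_flag_sync_ops, global_norm_sync_ops) found in the program."""
    amp_sync_ops, norm_sync_ops = [], []
    for block in main_program.blocks:
        for op in block.ops:
            if is_amp_flag_sync_op(op):
                # c_allreduce_max over the found_inf flag, tagged with SyncMode.AmpFlagSync
                amp_sync_ops.append(op)
            elif is_global_norm_sync_op(op):
                # c_allreduce_sum used by global-norm gradient clipping
                norm_sync_ops.append(op)
    return amp_sync_ops, norm_sync_ops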
10 changes: 10 additions & 0 deletions python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -318,6 +318,16 @@ def _apply_post_optimization(
[main_program], [startup_program], self._pass_context
)

# explicit dependencies for the new (standalone) executor
config = {}
config["dist_context"] = self._dist_context
APSED_pass = new_pass(
"auto_parallel_supplement_explicit_dependencies", config
)
APSED_pass.apply(
[main_program], [startup_program], self._pass_context
)

# gradient_merge is a train-only optimization
if self._mode == "train" and self._strategy.gradient_merge.enable:
config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
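For reference, a stand-alone sketch of applying this pass outside the parallelizer; the program and context objects are assumed to come from an existing auto-parallel build, and gating on the standalone executor is an assumption that mirrors the use_standalone_executor helper in utils.py rather than something shown in this hunk.

# Hypothetical stand-alone application of the new dependency pass.
from paddle.distributed.auto_parallel.utils import use_standalone_executor
from paddle.distributed.passes import new_pass


def apply_explicit_dependency_pass(
    main_program, startup_program, dist_context, pass_context
):
    if not use_standalone_executor():
        return
    config = {"dist_context": dist_context}
    dep_pass = new_pass(
        "auto_parallel_supplement_explicit_dependencies", config
    )
    # Intended (per the commit messages) to add explicit dependencies around
    # sharding broadcasts, AMP flag sync, and global-norm sync so the executor
    # preserves a safe ordering when comm/compute overlap is enabled.
    dep_pass.apply([main_program], [startup_program], pass_context)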
17 changes: 9 additions & 8 deletions python/paddle/distributed/auto_parallel/process_group.py
@@ -48,14 +48,16 @@ def clear_all_process_groups():
_g_process_group_map[0] = ProcessGroup(0, [])


def new_process_group(ranks, group_id=None):
def new_process_group(ranks, group_id=None, force_new_group=False):

global _g_process_group_map
# A key constructed from ranks is used for avoiding duplication
new_key = ''.join(map(str, sorted(ranks)))
for pg_id, pg in _g_process_group_map.items():
cur_key = ''.join(map(str, sorted(pg.ranks)))
if pg_id != 0 and new_key == cur_key:
return pg
if not force_new_group:
# A key constructed from ranks is used for avoiding duplication
new_key = ''.join(map(str, sorted(ranks)))
for pg_id, pg in _g_process_group_map.items():
cur_key = ''.join(map(str, sorted(pg.ranks)))
if pg_id != 0 and new_key == cur_key:
return pg
# If no existing group matches, construct a new process group
num_groups = len(_g_process_group_map)
# Note: our process group may interfere with the original implementation
@@ -137,7 +139,6 @@ def instantiate(self):
]
strategy.current_endpoint = genv.current_endpoint
strategy.nrings = 1

if core.is_compiled_with_cuda():
place = core.CUDAPlace(genv.device_id)
core.NCCLParallelContext(strategy, place).init_with_ring_id(
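force_new_group=True bypasses the rank-based deduplication, which is what allows several communicators over the same ranks, e.g. one per communication stream. A usage sketch under assumed ranks and stream count:

# Illustrative sketch (ranks and stream count are made up).
from paddle.distributed.auto_parallel.process_group import new_process_group

ranks = [0, 1, 2, 3]

# Reuses a previously registered group for these ranks, if any.
shared_group = new_process_group(ranks)

# Always creates fresh groups, so each one can back an independent
# communication stream for overlapped gradient allreduce.
grad_comm_groups = [
    new_process_group(ranks, force_new_group=True) for _ in range(2)
]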
78 changes: 59 additions & 19 deletions python/paddle/distributed/auto_parallel/utils.py
@@ -1184,6 +1184,8 @@ def _get_split_indices(


def set_grad_var_shape(program, dist_context):
from paddle.distributed.fleet.meta_optimizers.common import OpRole

from .operators.common import infer_shape

block = program.global_block()
@@ -1955,6 +1957,9 @@ def set_recompute_segments(model, losses, strategy, program):
and hasattr(model.gpt, "checkpoints")
):
ckpts = model.gpt.checkpoints
# the last recompute segment does not need to be recomputed
if len(ckpts) > 2:
ckpts.pop()
else:
ckpts = recompute.checkpoints
else:
@@ -2189,6 +2194,7 @@ def insert_dependencies_for_two_ops(
dist_context,
is_recompute=False,
sync=False,
op_namescope=None,
):
"""
dependency: prior_op should be run before posterior_op
@@ -2233,49 +2239,74 @@ def _select_best_depend_var(vars):
[block.var(name) for name in posterior_op.input_arg_names]
)

return insert_dependencies_for_two_vars(
return insert_dependencies_for_vars(
block,
idx,
first_var,
second_var,
dist_context,
OpRole.Backward,
prior_op_mesh,
is_recompute,
sync,
process_mesh=prior_op_mesh,
is_recompute=is_recompute,
sync=sync,
op_namescope=op_namescope,
use_nop=False,
)


def insert_dependencies_for_two_vars(
def insert_dependencies_for_vars(
block,
idx,
prior_var,
post_var,
prior_vars,
post_vars,
dist_context,
oprole,
process_mesh=None,
is_recompute=False,
sync=False,
op_namescope=None,
use_nop=False,
):
"""
dependency: op that generates prior_var should be run before op that generates post_var
dependency: op that generates prior_vars should be run before op that generates post_vars
"""
assert block.has_var(prior_var.name)
assert block.has_var(post_var.name)

if isinstance(prior_vars, Variable):
prior_vars = [prior_vars]
if isinstance(post_vars, Variable):
post_vars = [post_vars]
for prior_var in prior_vars:
assert block.has_var(prior_var.name)
for post_var in post_vars:
assert block.has_var(post_var.name)

if process_mesh is None:
process_mesh = dist_context.get_tensor_dist_attr_for_program(
post_var
post_vars[0]
).process_mesh
assert process_mesh is not None

depend_op = block._insert_op_without_sync(
idx,
type='nop',
inputs={
"X": prior_var,
},
outputs={"Out": post_var},
)
use_nop = True
if use_nop:
depend_op = block._insert_op_without_sync(
idx,
type='nop',
inputs={
"X": prior_vars,
},
outputs={"Out": post_vars},
)
else:
depend_op = block._insert_op_without_sync(
idx,
type='depend',
inputs={
"X": post_vars,
"Dep": prior_vars,
},
outputs={"Out": post_vars},
)

# depend_op.desc.set_type("depend")
depend_op._set_attr(OP_ROLE_KEY, oprole)
# depend_op.desc.set_input("Dep", [first_var.name])
@@ -2284,13 +2315,22 @@ def insert_dependencies_for_two_vars(
naive_set_dist_op_attr_for_program_by_mesh(
depend_op, process_mesh, dist_context, is_recompute
)
if op_namescope is not None:
depend_op._set_attr('op_namescope', "/{}".format(op_namescope))

if sync:
block._sync_with_cpp()

return depend_op


def is_dep_skip_op(op):
if "c_" in op.type:
return True

return False


def use_standalone_executor():
return os.environ.get('FLAGS_CONVERT_GRAPH_TO_PROGRAM', None) in [
1,
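The refactored helper now accepts lists of variables plus an op_namescope tag. A hedged usage sketch follows; the block, index, variable names, and namescope tag are placeholders, not values taken from the diff.

# Hypothetical call: force the producers of `prior_names` to run before the
# consumer of `post_name` by inserting a dependency op at position `idx`.
from paddle.distributed.auto_parallel.utils import insert_dependencies_for_vars
from paddle.distributed.fleet.meta_optimizers.common import OpRole


def add_sharding_dep(block, idx, dist_context, prior_names, post_name):
    prior_vars = [block.var(name) for name in prior_names]
    post_vars = [block.var(post_name)]
    return insert_dependencies_for_vars(
        block,
        idx,
        prior_vars,
        post_vars,
        dist_context,
        OpRole.Optimize,                 # role recorded on the inserted op
        process_mesh=None,               # inferred from post_vars[0] when omitted
        is_recompute=False,
        sync=True,                       # run block._sync_with_cpp() afterwards
        op_namescope="sharding_stage2_dep",  # hypothetical debug tag
        use_nop=False,                   # note: the helper currently forces use_nop=True
    )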
1 change: 1 addition & 0 deletions python/paddle/distributed/passes/__init__.py
@@ -23,6 +23,7 @@
from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403
from .auto_parallel_grad_clip import * # noqa: F403
from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403
from .cpp_pass import * # noqa: F403
from .ps_trainer_pass import * # noqa: F403
from .ps_server_pass import * # noqa: F403