[Auto Parallel] Sharding Optimization: Partition Algorithm & Stage2 Parameter Bucket communication #47180

Merged

Changes from all commits (38 commits)
b309a72
partition param by order
JZ-LIANG Oct 19, 2022
c8c9fb7
bugfix
JZ-LIANG Oct 21, 2022
a88af49
bugfix
JZ-LIANG Oct 21, 2022
4121b55
bugfix
JZ-LIANG Oct 21, 2022
0335bce
bugfix
JZ-LIANG Oct 21, 2022
5539596
add logging
JZ-LIANG Oct 25, 2022
e93def0
reorder opt
JZ-LIANG Oct 26, 2022
faaea3c
config
JZ-LIANG Oct 26, 2022
c4c4b54
bugfix
JZ-LIANG Oct 26, 2022
482a947
bugfix
JZ-LIANG Oct 26, 2022
184bbc9
bugfix
JZ-LIANG Oct 26, 2022
2483f38
bugfix
JZ-LIANG Oct 26, 2022
5b5a61a
bugfix
JZ-LIANG Oct 26, 2022
5646f8f
bugfix
JZ-LIANG Oct 26, 2022
c462aba
bugfix
JZ-LIANG Oct 26, 2022
139aa0f
bugfix
JZ-LIANG Oct 26, 2022
36d97be
bugfix
JZ-LIANG Oct 26, 2022
309b5d5
stage2 bucket
JZ-LIANG Nov 1, 2022
d8bea00
bugfix
JZ-LIANG Nov 1, 2022
6d8c7cf
bugfix
JZ-LIANG Nov 1, 2022
7ccaff1
logging
JZ-LIANG Nov 1, 2022
34415f3
bugfix
JZ-LIANG Nov 1, 2022
038204e
bugfix
JZ-LIANG Nov 1, 2022
6a2f678
bugfix
JZ-LIANG Nov 1, 2022
c5943bb
bugfix
JZ-LIANG Nov 1, 2022
36fdfe2
bugfix
JZ-LIANG Nov 1, 2022
7eb8f63
bugfix
JZ-LIANG Nov 1, 2022
ce5531a
debug
JZ-LIANG Nov 1, 2022
8ed2703
debug
JZ-LIANG Nov 1, 2022
aeac121
debug
JZ-LIANG Nov 1, 2022
8c88d57
debug
JZ-LIANG Nov 1, 2022
17d4f1b
debug
JZ-LIANG Nov 1, 2022
2fdd2ac
debug
JZ-LIANG Nov 1, 2022
fa7f049
old engine
JZ-LIANG Nov 3, 2022
ee8abd8
Merge remote-tracking branch 'upstream/develop' into AutoParallel/sha…
JZ-LIANG Nov 7, 2022
9f2534f
rm engine
JZ-LIANG Nov 7, 2022
104e674
update unitest
JZ-LIANG Nov 7, 2022
314ebee
Merge remote-tracking branch 'upstream/develop' into AutoParallel/sha…
JZ-LIANG Nov 8, 2022
4 changes: 3 additions & 1 deletion python/paddle/distributed/auto_parallel/constants.py
@@ -82,7 +82,9 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "degree", 8)
set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
set_field_default_config(SHARDING, "overlap_grad_comm", False)
set_field_default_config(SHARDING, "bucket_size_numel", -1)
set_field_default_config(SHARDING, "partition_algor", "greedy_even")
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])

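For context, these new SHARDING fields are meant to be toggled through the auto-parallel strategy. The snippet below is only a sketch, assuming the strategy object used with the auto-parallel Engine exposes these fields under strategy.sharding (field names taken from the defaults registered above); the bucket size value is purely illustrative.

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
sharding = strategy.sharding
sharding.enable = True
sharding.stage = 2                       # stage-2 sharding targeted by this PR
sharding.degree = 8                      # number of ranks the states are sharded over
sharding.overlap_grad_comm = True        # overlap gradient communication with computation
sharding.bucket_size_numel = 50 * 1024 * 1024   # illustrative element count per bucket (default -1 leaves it unset)
sharding.partition_algor = "greedy_even"        # partition algorithm added in this PR

# the strategy would then be handed to the auto-parallel Engine, e.g.
# engine = auto.Engine(model, loss, optimizer, strategy=strategy)
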
13 changes: 13 additions & 0 deletions python/paddle/distributed/auto_parallel/utils.py
@@ -22,6 +22,7 @@
from functools import reduce

import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.process_group import (
get_all_process_groups,
@@ -1790,6 +1791,18 @@ def find_higher_order_backward_op(program):
return False


def get_var_numel(var):
"""
input:
- var: variable
return:
number of elements in var
"""
assert isinstance(var, Variable)
assert -1 not in var.shape
return reduce(lambda x, y: x * y, var.shape)


def get_lr(optimizer):
if isinstance(optimizer, paddle.optimizer.Optimizer):
return optimizer.get_lr()
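The "greedy_even" value registered for partition_algor above suggests a size-balanced greedy assignment of parameters to sharding ranks, which is where a helper like get_var_numel comes in. The function below is only an illustrative sketch of such a scheme; the name greedy_even_partition, its signature, and its return format are hypothetical and not the pass's actual code.

def greedy_even_partition(params, dp_degree):
    # Illustrative greedy-even partition: hand the next-largest parameter
    # to the rank that currently owns the fewest elements.
    buckets = [[] for _ in range(dp_degree)]
    sizes = [0] * dp_degree
    for param in sorted(params, key=get_var_numel, reverse=True):
        rank = sizes.index(min(sizes))
        buckets[rank].append(param)
        sizes[rank] += get_var_numel(param)
    return buckets  # buckets[i] holds the parameters owned by rank i
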
@@ -13,12 +13,12 @@
# limitations under the License.

from collections import OrderedDict
import numpy as np

import paddle
from paddle.fluid import unique_name
from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from .pass_base import PassBase, PassType, register_pass
from paddle.distributed.auto_parallel.operators.common import (
is_data_parallel_scale_op,
is_data_parallel_reduce_op,
@@ -28,8 +28,8 @@
is_loss_grad_op,
is_optimize_op,
ring_id_to_process_group,
get_var_numel,
)
from .pass_base import PassBase, PassType, register_pass

# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
@@ -44,10 +44,6 @@
__max_stream_num_allow__ = 16


def numel(var):
return np.prod(list(var.shape))


@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
"""
@@ -430,7 +426,7 @@ def op_depend_on_group(op, group):
ring_id = op.attr("ring_id")
grad_name = op.output_arg_names[0]
grad_var = block.var(grad_name)
grad_numel = numel(grad_var)
grad_numel = get_var_numel(grad_var)

if cur_group.acceptable(grad_var, ring_id):
assert grad_name not in grouped_grad_names
@@ -594,7 +590,7 @@ def acceptable(self, grad_var, ring_id):
return True
if ring_id != self.ring_id:
return False
if numel(grad_var) + self.numel > self.max_group_size:
if get_var_numel(grad_var) + self.numel > self.max_group_size:
return False
if grad_var.dtype != self.dtype:
return False
@@ -605,7 +601,7 @@ def add(self, grad_var, ring_id, i):
self.gradients.append(grad_var)
self.ring_id = ring_id
self.dtype = grad_var.dtype
self.numel += numel(grad_var)
self.numel += get_var_numel(grad_var)

# remove auxiliary ops in non-fuse dp allreduce
self.remove_allreduce_op_indices.append(i)
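The acceptable/add pattern above caps each fused allreduce group by element count and keeps the dtype and ring_id uniform within a group. Below is a condensed, self-contained sketch of that grouping logic; class and function names are simplified and hypothetical, not the pass's exact code.

class GradGroup:
    # Illustrative size-capped gradient bucket, simplified from the pass.
    def __init__(self, max_numel):
        self.max_numel = max_numel
        self.numel = 0
        self.dtype = None
        self.ring_id = None
        self.grads = []

    def acceptable(self, grad, ring_id):
        if not self.grads:            # an empty group accepts anything
            return True
        if ring_id != self.ring_id:
            return False
        if self.numel + get_var_numel(grad) > self.max_numel:
            return False
        return grad.dtype == self.dtype

    def add(self, grad, ring_id):
        if not self.grads:
            self.ring_id, self.dtype = ring_id, grad.dtype
        self.grads.append(grad)
        self.numel += get_var_numel(grad)


def group_gradients(grads_with_rings, max_numel):
    # Open a new bucket whenever the current one cannot take the next gradient.
    groups, cur = [], GradGroup(max_numel)
    for grad, ring_id in grads_with_rings:
        if not cur.acceptable(grad, ring_id):
            groups.append(cur)
            cur = GradGroup(max_numel)
        cur.add(grad, ring_id)
    if cur.grads:
        groups.append(cur)
    return groups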