Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix]Fix segment fault in order setting #52293

Merged
merged 2 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ message HybridConfig {
optional int32 mp_degree = 2 [ default = 1 ];
optional int32 pp_degree = 3 [ default = 1 ];
optional int32 sharding_degree = 4 [ default = 1 ];
repeated string order = 5 ;
}

message AMPConfig {
Expand Down
10 changes: 9 additions & 1 deletion python/paddle/distributed/fleet/base/distributed_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import google.protobuf
import google.protobuf.text_format

Expand Down Expand Up @@ -149,6 +151,7 @@ def __init__(self):
if _global_flags().is_public(key):
self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])

self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']
self.__lock_attr = True
logger.info("distributed strategy initialized")

Expand Down Expand Up @@ -1691,8 +1694,13 @@ def hybrid_configs(self):

@hybrid_configs.setter
def hybrid_configs(self, configs):
hybrid_config = copy.deepcopy(configs)
if "order" in hybrid_config:
self.hybrid_parallel_order = hybrid_config["order"]
hybrid_config.pop('order')

check_configs_key(
self.strategy.hybrid_configs, configs, "hybrid_configs"
self.strategy.hybrid_configs, hybrid_config, "hybrid_configs"
)
assign_configs_value(self.strategy.hybrid_configs, configs)

Expand Down
4 changes: 1 addition & 3 deletions python/paddle/distributed/fleet/fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,7 @@ def _init_hybrid_parallel_env(self):
"mp": ['model', self.mp_degree],
}

order = self.hybrid_configs["order"]
if not order:
order = ['dp', 'pp', 'sharding', 'mp']
order = self._user_defined_strategy.hybrid_parallel_order
if order[:].sort() != list(d_hybrid_degree.keys())[:].sort():
raise AssertionError(
'The order of hybrid_config setting is incorrect.'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ def test_hybrid_parallel_configs(self):
self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)

def test_hybrid_parallel_configs_order(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 2,
"pp_degree": 4,
"order": ['sharding', 'mp', 'dp', 'pp'],
}
self.assertEqual(strategy.hybrid_configs["dp_degree"], 1)
self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
self.assertEqual(
strategy.hybrid_parallel_order, ['sharding', 'mp', 'dp', 'pp']
)

def test_localsgd(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.localsgd = True
Expand Down