diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index d9182c488f23f..b9055d38d38c5 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -55,7 +55,6 @@ message HybridConfig {
   optional int32 mp_degree = 2 [ default = 1 ];
   optional int32 pp_degree = 3 [ default = 1 ];
   optional int32 sharding_degree = 4 [ default = 1 ];
-  repeated string order = 5 ;
 }
 
 message AMPConfig {
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index 455f7bca37528..950fddaf9dba7 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
+
 import google.protobuf
 import google.protobuf.text_format
 
@@ -149,6 +151,7 @@ def __init__(self):
         if _global_flags().is_public(key):
             self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])
 
+        self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']
         self.__lock_attr = True
         logger.info("distributed strategy initialized")
 
@@ -1691,8 +1694,13 @@ def hybrid_configs(self):
 
     @hybrid_configs.setter
     def hybrid_configs(self, configs):
+        hybrid_config = copy.deepcopy(configs)
+        if "order" in hybrid_config:
+            self.hybrid_parallel_order = hybrid_config["order"]
+            hybrid_config.pop('order')
+
         check_configs_key(
-            self.strategy.hybrid_configs, configs, "hybrid_configs"
+            self.strategy.hybrid_configs, hybrid_config, "hybrid_configs"
         )
         assign_configs_value(self.strategy.hybrid_configs, configs)
 
diff --git a/python/paddle/distributed/fleet/fleet.py b/python/paddle/distributed/fleet/fleet.py
index eda074100eec3..9debd488d2ea7 100755
--- a/python/paddle/distributed/fleet/fleet.py
+++ b/python/paddle/distributed/fleet/fleet.py
@@ -412,9 +412,7 @@ def _init_hybrid_parallel_env(self):
            "mp": ['model', self.mp_degree],
        }
 
-        order = self.hybrid_configs["order"]
-        if not order:
-            order = ['dp', 'pp', 'sharding', 'mp']
+        order = self._user_defined_strategy.hybrid_parallel_order
         if order[:].sort() != list(d_hybrid_degree.keys())[:].sort():
             raise AssertionError(
                 'The order of hybrid_config setting is incorrect.'
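Note: after this change the `order` list is no longer stored in the `HybridConfig` protobuf message; it lives on the strategy object as `hybrid_parallel_order` and is read by `_init_hybrid_parallel_env`. The sketch below is not part of the patch: `build_topology_args`, the group names, and the degree values are illustrative assumptions, meant only to show how a user-defined order could be combined with per-axis degrees before the consistency check that raises the AssertionError above.

    # Hypothetical helper, for illustration only -- not code from fleet.py.
    def build_topology_args(order, degrees):
        # `order` is e.g. ['sharding', 'mp', 'dp', 'pp']; `degrees` maps each
        # key to (group name, parallel degree), mirroring d_hybrid_degree.
        if sorted(order) != sorted(degrees):
            raise AssertionError(
                'The order of hybrid_config setting is incorrect.'
            )
        names = [degrees[k][0] for k in order]
        dims = [degrees[k][1] for k in order]
        return names, dims

    names, dims = build_topology_args(
        ['sharding', 'mp', 'dp', 'pp'],
        {
            'dp': ('data', 2),
            'pp': ('pipe', 2),
            'sharding': ('sharding', 1),
            'mp': ('model', 2),
        },
    )
    # names == ['sharding', 'model', 'data', 'pipe'], dims == [1, 2, 2, 2]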
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py
index e773014629da5..99f235b5887fd 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py
@@ -84,6 +84,21 @@ def test_hybrid_parallel_configs(self):
         self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
         self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
 
+    def test_hybrid_parallel_configs_order(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": 2,
+            "pp_degree": 4,
+            "order": ['sharding', 'mp', 'dp', 'pp'],
+        }
+        self.assertEqual(strategy.hybrid_configs["dp_degree"], 1)
+        self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
+        self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
+        self.assertEqual(
+            strategy.hybrid_parallel_order, ['sharding', 'mp', 'dp', 'pp']
+        )
+
     def test_localsgd(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.localsgd = True
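Usage note: a minimal sketch of the patched behaviour exercised by the new test above (the degree values here are arbitrary). An "order" entry passed through `hybrid_configs` is split off into `strategy.hybrid_parallel_order`, while the remaining degree keys still populate the protobuf-backed config.

    import paddle

    strategy = paddle.distributed.fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 2,
        "mp_degree": 2,
        "pp_degree": 2,
        "order": ['sharding', 'mp', 'dp', 'pp'],
    }

    # Degrees are stored in the protobuf-backed hybrid_configs ...
    print(strategy.hybrid_configs["mp_degree"])  # 2
    # ... while the order is kept as a plain Python attribute on the strategy.
    print(strategy.hybrid_parallel_order)  # ['sharding', 'mp', 'dp', 'pp']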