Commit

[BugFix] Fix segmentation fault in order setting (#52293)
* fix bug in proto

* add utest
ForFishes authored Mar 30, 2023
1 parent: 155018e · commit: d2cdc7e
Showing 4 changed files with 25 additions and 5 deletions.
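
In short: the `order` entry is no longer written into the protobuf HybridConfig message (its `repeated string` field is deleted below); the Python setter now peels it off and keeps it on the strategy object as `hybrid_parallel_order`. A minimal usage sketch of the behavior after this commit, mirroring the new unit test (assumes a Paddle build that includes this change):

    # Sketch: how "order" is consumed after this fix (mirrors the added test).
    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 2,
        "pp_degree": 4,
        "order": ['sharding', 'mp', 'dp', 'pp'],  # kept on the Python side only
    }
    print(strategy.hybrid_parallel_order)  # ['sharding', 'mp', 'dp', 'pp']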
1 change: 0 additions & 1 deletion paddle/fluid/framework/distributed_strategy.proto
@@ -55,7 +55,6 @@ message HybridConfig {
   optional int32 mp_degree = 2 [ default = 1 ];
   optional int32 pp_degree = 3 [ default = 1 ];
   optional int32 sharding_degree = 4 [ default = 1 ];
-  repeated string order = 5 ;
 }

 message AMPConfig {
10 changes: 9 additions & 1 deletion python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
+
 import google.protobuf
 import google.protobuf.text_format

@@ -149,6 +151,7 @@ def __init__(self):
         if _global_flags().is_public(key):
             self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])

+        self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']
         self.__lock_attr = True
         logger.info("distributed strategy initialized")

@@ -1691,8 +1694,13 @@ def hybrid_configs(self):

     @hybrid_configs.setter
     def hybrid_configs(self, configs):
+        hybrid_config = copy.deepcopy(configs)
+        if "order" in hybrid_config:
+            self.hybrid_parallel_order = hybrid_config["order"]
+            hybrid_config.pop('order')
+
         check_configs_key(
-            self.strategy.hybrid_configs, configs, "hybrid_configs"
+            self.strategy.hybrid_configs, hybrid_config, "hybrid_configs"
         )
         assign_configs_value(self.strategy.hybrid_configs, configs)
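
Because the `order` field no longer exists in the HybridConfig protobuf message, the setter must strip it before `check_configs_key` validates the dict against the message. A standalone sketch of that copy-then-pop pattern (`split_order` is a hypothetical helper for illustration, not Paddle API):

    import copy

    def split_order(configs, default_order):
        # Hypothetical helper: deep-copy so the caller's dict is untouched,
        # then peel off the Python-only 'order' key before protobuf validation.
        hybrid_config = copy.deepcopy(configs)
        order = hybrid_config.pop('order', default_order)
        return hybrid_config, order

    cfg, order = split_order(
        {"dp_degree": 2, "order": ['dp', 'pp', 'sharding', 'mp']},
        ['dp', 'pp', 'sharding', 'mp'],
    )
    print(cfg)    # {'dp_degree': 2}
    print(order)  # ['dp', 'pp', 'sharding', 'mp']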
4 changes: 1 addition & 3 deletions python/paddle/distributed/fleet/fleet.py
@@ -412,9 +412,7 @@ def _init_hybrid_parallel_env(self):
             "mp": ['model', self.mp_degree],
         }

-        order = self.hybrid_configs["order"]
-        if not order:
-            order = ['dp', 'pp', 'sharding', 'mp']
+        order = self._user_defined_strategy.hybrid_parallel_order
         if order[:].sort() != list(d_hybrid_degree.keys())[:].sort():
             raise AssertionError(
                 'The order of hybrid_config setting is incorrect.'
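
A note on the retained context line: `list.sort()` sorts in place and returns None, so `order[:].sort() != list(d_hybrid_degree.keys())[:].sort()` compares None with None and the assertion can never fire. A sketch of a check that actually compares contents, using `sorted()`, shown for illustration only (not part of this commit):

    def check_order(order, expected_keys):
        # sorted() returns a new list, so this comparison inspects contents.
        if sorted(order) != sorted(expected_keys):
            raise AssertionError(
                'The order of hybrid_config setting is incorrect.'
            )

    check_order(['sharding', 'mp', 'dp', 'pp'], ['dp', 'pp', 'sharding', 'mp'])  # passes
    # check_order(['dp', 'pp'], ['dp', 'pp', 'sharding', 'mp'])  # would raise AssertionError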
@@ -84,6 +84,21 @@ def test_hybrid_parallel_configs(self):
         self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
         self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)

+    def test_hybrid_parallel_configs_order(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": 2,
+            "pp_degree": 4,
+            "order": ['sharding', 'mp', 'dp', 'pp'],
+        }
+        self.assertEqual(strategy.hybrid_configs["dp_degree"], 1)
+        self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
+        self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
+        self.assertEqual(
+            strategy.hybrid_parallel_order, ['sharding', 'mp', 'dp', 'pp']
+        )
+
     def test_localsgd(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.localsgd = True
