From d12588d261b55e5fe54bb215c33af74766826455 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Mon, 10 Apr 2023 13:25:33 +0800 Subject: [PATCH] Broadcast the master weight along with param for distributed training. (#52638) * Broadcast the master weight along with param for distributed training. * Fix codestyle. --- .../meta_optimizers/sharding_optimizer.py | 1298 +++++++++++------ 1 file changed, 814 insertions(+), 484 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index ccac803e72130..6e92810118227 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -28,6 +28,7 @@ from .sharding.offload_helper import OffloadHelper from .sharding.prune import ProgramDeps from .sharding import utils + # FIXME: import * from .sharding.utils import * import logging @@ -84,7 +85,7 @@ def _enable_strategy(self, dist_strategy, context): dist_strategy.sharding_configs = {"segment_broadcast_MB": 32} def _get_sharding_segment_strategy(self): - """ get + """get self._sharding_segment_strategy 1. if by_size: self._broadcast_MB 2. if by_anchors: self._sharding_segment_anchors @@ -97,21 +98,26 @@ def _get_sharding_segment_strategy(self): if segment_strategy == "segment_broadcast_MB": self._broadcast_MB = sharding_configs["segment_broadcast_MB"] - assert self._broadcast_MB > 0, "segment size should larger than zero !" + assert ( + self._broadcast_MB > 0 + ), "segment size should larger than zero !" elif segment_strategy == "segment_anchors": self._sharding_segment_anchors = sharding_configs["segment_anchors"] - assert len(self._sharding_segment_anchors - ) > 0, "you should set the sharding segment anchors !" + assert ( + len(self._sharding_segment_anchors) > 0 + ), "you should set the sharding segment anchors !" self._backward_remain_anchors = self._sharding_segment_anchors[:] self._forward_remain_anchors = [] else: raise NotImplementedError( "the sharding segment strategy [{}] is not implemented".format( - str(segment_strategy))) + str(segment_strategy) + ) + ) self._sharding_segment_strategy = segment_strategy def _get_hybrid_degree(self): - """ get + """get self.hybrid_dp self.sharding_degree self.mp_degree @@ -135,21 +141,32 @@ def _get_hybrid_degree(self): assert strategy.pipeline is True if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None): - assert pp_degree == 2, ("For manually set pipeline, only " - "pp_degree = 2 is supported.") - assert global_world_size == mp_degree * sharding_degree * dp_degree, \ - "global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format( - global_world_size, mp_degree, sharding_degree, dp_degree) + assert pp_degree == 2, ( + "For manually set pipeline, only " "pp_degree = 2 is supported." + ) + assert ( + global_world_size == mp_degree * sharding_degree * dp_degree + ), "global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format( + global_world_size, mp_degree, sharding_degree, dp_degree + ) else: - assert global_world_size == mp_degree * sharding_degree * pp_degree * dp_degree, \ - "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format( - global_world_size, mp_degree, sharding_degree, pp_degree, dp_degree) + assert ( + global_world_size + == mp_degree * sharding_degree * pp_degree * dp_degree + ), "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format( + global_world_size, + mp_degree, + sharding_degree, + pp_degree, + dp_degree, + ) # FIXME (JZ-LIANG) deprecated hybrid_dp if sharding_configs["hybrid_dp"]: logger.warning( "[hybrid_dp] API setting is deprecated. Now when " - "dp_degree >= 2, its will be in hybrid dp mode automatically") + "dp_degree >= 2, its will be in hybrid dp mode automatically" + ) assert dp_degree >= 1 self.hybrid_dp = True if dp_degree > 1 else False @@ -159,7 +176,7 @@ def _get_hybrid_degree(self): self.dp_degree = dp_degree def _get_hybrid_dp_mode(self): - """ get + """get self.hybrid_dp_mode = 'pp_hybrid_dp' or 'sharding_hybrid_dp' self.gradient_merge_mode = 'pp_gm' or 'sharding_gm' self._gradient_merge_acc_step @@ -183,9 +200,10 @@ def _get_hybrid_dp_mode(self): if self.pp_degree > 1: dp_mode = "pp_hybrid_dp" else: - assert self.sharding_degree > 1, \ - "by now we only support five kind of hybrid dp: sharding_hybrid_dp, " \ + assert self.sharding_degree > 1, ( + "by now we only support five kind of hybrid dp: sharding_hybrid_dp, " "mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp." + ) dp_mode = "sharding_hybrid_dp" # gradient merge @@ -198,23 +216,33 @@ def _get_hybrid_dp_mode(self): gm_mode = "pp_gm" gm_acc_step = strategy.pipeline_configs['accumulate_steps'] gradient_scale_configs = strategy.gradient_scale_configs - assert gradient_scale_configs['scale_strategy'] == 'avg', \ - 'For pipeline mode, the ' 'gradient scale mode should ' \ - 'be "avg", but got {}'.format(gradient_scale_configs['scale_strategy']) + assert gradient_scale_configs['scale_strategy'] == 'avg', ( + 'For pipeline mode, the ' + 'gradient scale mode should ' + 'be "avg", but got {}'.format( + gradient_scale_configs['scale_strategy'] + ) + ) # Note (Yuang Liu): this avg_loss flag determines where to do the average op for grad merge. # If True, will do sum firstly for gradient merge, then do scale by gm_acc_step. # If False, will scale loss by gm_acc_step first, then do sum for gradient merge. self.scale_gradient = gradient_scale_configs['scale_gradient'] if gm_acc_step > 1: - logger.info("Gradient merge in [{}], acc step = [{}]".format( - gm_mode, gm_acc_step)) + logger.info( + "Gradient merge in [{}], acc step = [{}]".format( + gm_mode, gm_acc_step + ) + ) optimizer_sharding = False # TODO(wangxi): need support dp_as_opt_sharding with sharding # need support without pp in future - if self.sharding_degree == 1 and self.dp_degree > 1 \ - and sharding_configs['_dp_as_optimizer_sharding'] \ - and self.pp_degree > 1: + if ( + self.sharding_degree == 1 + and self.dp_degree > 1 + and sharding_configs['_dp_as_optimizer_sharding'] + and self.pp_degree > 1 + ): optimizer_sharding = True self.hybrid_dp_mode = dp_mode @@ -224,19 +252,23 @@ def _get_hybrid_dp_mode(self): # this feature is design for ascend, and should NOT be used in GPU training self.pp_allreduce_in_optimize = sharding_configs[ - "pp_allreduce_in_optimize"] + "pp_allreduce_in_optimize" + ] - def _inner_opt_minimize(self, loss, startup_program, parameter_list, - no_grad_set): + def _inner_opt_minimize( + self, loss, startup_program, parameter_list, no_grad_set + ): pipeline_configs = self.user_defined_strategy.pipeline_configs if self.inner_opt is None: raise ValueError( - "self.inner_opt of ShardingOptimizer should not be None.") + "self.inner_opt of ShardingOptimizer should not be None." + ) if self.pp_degree > 1: pp_optimizer = fluid.optimizer.PipelineOptimizer( - self.inner_opt, self._gradient_merge_acc_step) + self.inner_opt, self._gradient_merge_acc_step + ) self._pp_optimizer = pp_optimizer global_rank = self.role_maker._worker_index() @@ -253,17 +285,25 @@ def _inner_opt_minimize(self, loss, startup_program, parameter_list, 'global_ring_id': 3, 'mp_degree': self.mp_degree, 'mp_rank': global_rank % self.mp_degree, - 'scale_gradient': self.scale_gradient + 'scale_gradient': self.scale_gradient, } main_program = loss.block.program main_program._pipeline_opt = pipeline_opt - optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize( - loss, startup_program, parameter_list, no_grad_set) + ( + optimize_ops, + params_grads, + program_list, + self.pipeline_pair, + self.pp_ring_map, + ) = pp_optimizer.minimize( + loss, startup_program, parameter_list, no_grad_set + ) assert self.pp_degree == len(program_list) else: optimize_ops, params_grads = self.inner_opt.minimize( - loss, startup_program, parameter_list, no_grad_set) + loss, startup_program, parameter_list, no_grad_set + ) if startup_program is None: startup_program = default_startup_program() @@ -272,8 +312,9 @@ def _inner_opt_minimize(self, loss, startup_program, parameter_list, startup_program = startup_program._pipeline_opt['startup_program'] print("pp_rank:", self.pp_rank) if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None): - main_program = program_list[int( - os.getenv("PADDLE_MANUAL_PIPELINE_STAGE"))] + main_program = program_list[ + int(os.getenv("PADDLE_MANUAL_PIPELINE_STAGE")) + ] else: main_program = program_list[self.pp_rank] with open("main_%d" % self.role_maker._worker_index(), 'w') as f: @@ -299,14 +340,16 @@ def _inner_opt_minimize(self, loss, startup_program, parameter_list, return optimize_ops, params_grads def _apply_sharding_pass(self, params_grads): - if self.sharding_degree == 1: return + if self.sharding_degree == 1: + return main_block = self._main_program.global_block() startup_block = self._startup_program.global_block() # step1: build shard - self._build_shard(params_grads, self.sharding_rank, - self.sharding_degree) + self._build_shard( + params_grads, self.sharding_rank, self.sharding_degree + ) # step2: split_program self._split_program(main_block) @@ -318,13 +361,16 @@ def _apply_sharding_pass(self, params_grads): # step4: remove unneeded ops and vars from block self._prune_main_program( - main_block, self._shard, - [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id]) + main_block, + self._shard, + [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id], + ) self._prune_startup_program(startup_block, self._shard) def _apply_opt_sharding_pass(self, params_grads): - """ outer dp as optimizer sharding """ - if self._optimizer_sharding is False: return + """outer dp as optimizer sharding""" + if self._optimizer_sharding is False: + return main_block = self._main_program.global_block() startup_block = self._startup_program.global_block() @@ -338,12 +384,15 @@ def _apply_opt_sharding_pass(self, params_grads): # step4: remove unneeded ops and vars from block self._prune_main_program( - main_block, self._shard, - [self.mp_ring_id, self.pp_ring_id, self.dp_ring_id]) + main_block, + self._shard, + [self.mp_ring_id, self.pp_ring_id, self.dp_ring_id], + ) self._prune_startup_program(startup_block, self._shard) def _insert_allreduce_for_pp(self, params_grads): - if self.pp_degree == 1: return + if self.pp_degree == 1: + return strategy = self.user_defined_strategy sharding_configs = strategy.sharding_configs @@ -363,10 +412,12 @@ def _insert_allreduce_for_pp(self, params_grads): main_block._remove_op(idx) for idx, op in reversed(list(enumerate(main_block.ops))): - if op.type != 'cast': continue + if op.type != 'cast': + continue in_name = op.input_arg_names[0] - if in_name not in self._params: continue - #if self._shard.has_param(param_name): continue + if in_name not in self._params: + continue + # if self._shard.has_param(param_name): continue if in_name not in main_block.vars: main_block._remove_op(idx) @@ -376,7 +427,8 @@ def _insert_allreduce_for_pp(self, params_grads): shard = self._shard if self._optimizer_sharding else None accumulated_grad_names = self._pp_optimizer._accumulate_gradients( - main_block, strategy=strategy, shard=shard) + main_block, strategy=strategy, shard=shard + ) len_of_ops = len(main_block.ops) if self.scale_gradient: @@ -384,8 +436,9 @@ def _insert_allreduce_for_pp(self, params_grads): first_optimize_op_index = get_first_optimize_op_idx(main_block) if self.pp_allreduce_in_optimize: - logger.info("Pipeline Persistable grad is {}".format( - accumulated_grad_names)) + logger.info( + "Pipeline Persistable grad is {}".format(accumulated_grad_names) + ) # FIXME(wangxi): accumulated_grad get from pipeline is not # include sharding's param@BroadCast grad when # pp_allreduce_in_optimize @@ -397,10 +450,11 @@ def _insert_allreduce_for_pp(self, params_grads): self._shard, core.op_proto_and_checker_maker.OpRole.Optimize, use_calc_stream=True, - rank=self.sharding_rank) + rank=self.sharding_rank, + ) logger.info("PP-Sharding grad is {}".format(accumulated_grad_names)) - first_optimize_op_index += (len(main_block.ops) - len_of_ops) + first_optimize_op_index += len(main_block.ops) - len_of_ops len_of_ops = len(main_block.ops) if self._optimizer_sharding: @@ -413,10 +467,12 @@ def _insert_allreduce_for_pp(self, params_grads): OpRole.Optimize, use_calc_stream=True, rank=self.dp_rank, - strategy=strategy) + strategy=strategy, + ) logger.info( - "Optimizer grad in this rank {}".format(accumulated_grad_names)) - first_optimize_op_index += (len(main_block.ops) - len_of_ops) + "Optimizer grad in this rank {}".format(accumulated_grad_names) + ) + first_optimize_op_index += len(main_block.ops) - len_of_ops len_of_ops = len(main_block.ops) # NOTE(wangxi): we fused after optimize_cast @@ -424,14 +480,17 @@ def _insert_allreduce_for_pp(self, params_grads): optimizer_param = utils.insert_broadcast_param_ops( main_block, len_of_ops, - self.dp_ring_id, [x[0].name for x in params_grads], + self.dp_ring_id, + [x[0].name for x in params_grads], self._shard, OpRole.Optimize, use_calc_stream=True, rank=self.dp_rank, - strategy=None if optimize_cast else strategy) + strategy=None if optimize_cast else strategy, + ) logger.info( - "Optimizer param in this rank {}".format(optimizer_param)) + "Optimizer param in this rank {}".format(optimizer_param) + ) if not strategy.fuse_grad_merge and not optimize_cast: assert len(accumulated_grad_names) == len(optimizer_param) elif self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp": @@ -442,15 +501,20 @@ def _insert_allreduce_for_pp(self, params_grads): accumulated_grad_names, core.op_proto_and_checker_maker.OpRole.Optimize, use_calc_stream=True, - user_defined_strategy=strategy) - first_optimize_op_index += (len(main_block.ops) - len_of_ops) + user_defined_strategy=strategy, + ) + first_optimize_op_index += len(main_block.ops) - len_of_ops len_of_ops = len(main_block.ops) # FIXME(wangxi): if fp16_allreduce, put cast fp16->fp32 to there? def _avg_grad_merge_after_sum(self, main_block, accumulated_grad_names): - if self.user_defined_strategy.amp and \ - self.user_defined_strategy.amp_configs['use_dynamic_loss_scaling']: + if ( + self.user_defined_strategy.amp + and self.user_defined_strategy.amp_configs[ + 'use_dynamic_loss_scaling' + ] + ): # For AMP, if using dynamic loss scaling the avg # operation can be simple done by modify the LossScaling op. for idx, op in enumerate(main_block.ops): @@ -461,7 +525,8 @@ def _avg_grad_merge_after_sum(self, main_block, accumulated_grad_names): loss_scale_tmp_var = main_block.create_var( name=loss_scale_tmp_var_name, shape=loss_scaling_var.shape, - dtype=loss_scaling_var.dtype) + dtype=loss_scaling_var.dtype, + ) main_block._insert_op_without_sync( idx, type='scale', @@ -471,8 +536,9 @@ def _avg_grad_merge_after_sum(self, main_block, accumulated_grad_names): 'scale': self._gradient_merge_acc_step, 'bias': 0.0, 'bias_after_scale': False, - OP_ROLE_KEY: OpRole.Optimize - }) + OP_ROLE_KEY: OpRole.Optimize, + }, + ) op._rename_input(loss_scale_name, loss_scale_tmp_var_name) break else: @@ -483,7 +549,9 @@ def _avg_grad_merge_after_sum(self, main_block, accumulated_grad_names): if is_optimizer_op(op) and op.type != 'c_sync_comm_stream': tmp_first_opt_idx = idx break - assert tmp_first_opt_idx is not None, 'Occurs some errors, no optimize ops' + assert ( + tmp_first_opt_idx is not None + ), 'Occurs some errors, no optimize ops' for grad in accumulated_grad_names: main_block._insert_op_without_sync( tmp_first_opt_idx, @@ -494,14 +562,17 @@ def _avg_grad_merge_after_sum(self, main_block, accumulated_grad_names): 'scale': 1.0 / self._gradient_merge_acc_step, 'bias': 0.0, 'bias_after_scale': False, - OP_ROLE_KEY: OpRole.Optimize - }) + OP_ROLE_KEY: OpRole.Optimize, + }, + ) def _adapt_amp_clip_without_sharding(self): # if not use sharding, adapt amp/clip, for remain parallelism. # cast --> amp --> clip --> opt - if self.sharding_degree > 1: return - if self._optimizer_sharding: return + if self.sharding_degree > 1: + return + if self._optimizer_sharding: + return main_block = self._main_program.global_block() startup_block = self._startup_program.global_block() @@ -515,9 +586,9 @@ def _adapt_amp_clip_without_sharding(self): FP16Utils.sync_amp_check_nan_inf(main_block, rings) gradientclip_helper = GradientClipHelper(None) - gradientclip_helper.sync_global_norm(main_block, - [self.mp_ring_id, self.pp_ring_id], - self.mp_rank) + gradientclip_helper.sync_global_norm( + main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank + ) def _insert_loss_grad_scale_op(self): main_block = self._main_program.global_block() @@ -538,8 +609,9 @@ def _apply_optimize_offload_pass(self, params_grads): mp_ring_id = self.mp_ring_id if self.mp_degree > 1 else None dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None - offload_helper = OffloadHelper(mp_ring_id=mp_ring_id, - dp_ring_id=dp_ring_id) + offload_helper = OffloadHelper( + mp_ring_id=mp_ring_id, dp_ring_id=dp_ring_id + ) # optimize offload should be enable while gradient merge is enable and # acc_step is quite large (e.g. >> 100). Since its memcpy could not be @@ -555,32 +627,32 @@ def _apply_optimize_offload_pass(self, params_grads): # will take more memory, but will be faster. Trade space for time. if self._optimizer_sharding: offload_helper.opt_sharding_cast_fp32param( - main_block, startup_block, - [x[0].name for x in params_grads]) + main_block, startup_block, [x[0].name for x in params_grads] + ) # NOTE(wangxi): fused after optimize_cast - utils.fuse_opt_broadcast_param_ops(main_block, - dp_ring_id, - self._shard, - strategy=strategy) + utils.fuse_opt_broadcast_param_ops( + main_block, dp_ring_id, self._shard, strategy=strategy + ) else: offload_helper.cast_fp32param_in_optimize( - main_block, startup_block) + main_block, startup_block + ) def _dump_program_for_debug(self): main_block = self._main_program.global_block() startup_block = self._startup_program.global_block() - with open("start_sharding_%d" % self.role_maker._worker_index(), - 'w') as f: + with open( + "start_sharding_%d" % self.role_maker._worker_index(), 'w' + ) as f: f.writelines(str(startup_block.program)) - with open("main_sharding_%d" % self.role_maker._worker_index(), - 'w') as f: + with open( + "main_sharding_%d" % self.role_maker._worker_index(), 'w' + ) as f: f.writelines(str(main_block.program)) - def minimize_impl(self, - loss, - startup_program=None, - parameter_list=None, - no_grad_set=None): + def minimize_impl( + self, loss, startup_program=None, parameter_list=None, no_grad_set=None + ): # TODO: (JZ-LIANG) support multiple comm in future # self._nrings = self.user_defined_strategy.nccl_comm_num self._nrings_sharding = 1 @@ -595,7 +667,8 @@ def minimize_impl(self, # inner optimize minimize optimize_ops, params_grads = self._inner_opt_minimize( - loss, startup_program, parameter_list, no_grad_set) + loss, startup_program, parameter_list, no_grad_set + ) self._init_comm() @@ -644,13 +717,15 @@ def _init_pair_comm(self, pair, ring_id): ] pp_rank = 0 if self.pp_rank == pair[0] else 1 if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None: - self._collective_helper._init_communicator(self._startup_program, - self.current_endpoint, - pp_group_endpoints, - pp_rank, - ring_id, - False, - sync=False) + self._collective_helper._init_communicator( + self._startup_program, + self.current_endpoint, + pp_group_endpoints, + pp_rank, + ring_id, + False, + sync=False, + ) def _init_npu_pipeline_comm(self, startup_block): # NOTE(wangxi): some bug with hccl, must set pp_degree be even number @@ -668,15 +743,22 @@ def _init_npu_pipeline_comm(self, startup_block): my_pair.append(pair) # for example: self.pp_rank=2, self.pp_degree=4 - send_to_next_pair = (self.pp_rank, (self.pp_rank + 1) % self.pp_degree - ) # 2->3 + send_to_next_pair = ( + self.pp_rank, + (self.pp_rank + 1) % self.pp_degree, + ) # 2->3 recv_from_next_pair = ( - (self.pp_rank + 1) % self.pp_degree, self.pp_rank) # 3->2 + (self.pp_rank + 1) % self.pp_degree, + self.pp_rank, + ) # 3->2 recv_from_prev_pair = ( - (self.pp_rank - 1 + self.pp_degree) % self.pp_degree, self.pp_rank + (self.pp_rank - 1 + self.pp_degree) % self.pp_degree, + self.pp_rank, ) # 1->2 - send_to_prev_pair = (self.pp_rank, (self.pp_rank - 1 + self.pp_degree) % - self.pp_degree) # 2->1 + send_to_prev_pair = ( + self.pp_rank, + (self.pp_rank - 1 + self.pp_degree) % self.pp_degree, + ) # 2->1 even = (self.pp_rank % 2) == 0 @@ -685,54 +767,66 @@ def _init_npu_pipeline_comm(self, startup_block): ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]] self._init_pair_comm(pair, ring_id) my_pair.remove(pair) - logger.info("pair0(even->odd): pp pair:{}, ring_id: {}".format( - pair, ring_id)) + logger.info( + "pair0(even->odd): pp pair:{}, ring_id: {}".format(pair, ring_id) + ) # 2. even recv from next, odd send to prev, 1->0, 3->2 pair = recv_from_next_pair if even else send_to_prev_pair ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]] self._init_pair_comm(pair, ring_id) my_pair.remove(pair) - logger.info("pair1(even<-odd): pp pair:{}, ring_id: {}".format( - pair, ring_id)) + logger.info( + "pair1(even<-odd): pp pair:{}, ring_id: {}".format(pair, ring_id) + ) # if pp_degree is 2, only need pair(0->1, 1->0) if self.pp_degree > 2: # 3. odd send to next, even recv from prev, 1->2, 3->0 pair = send_to_next_pair if not even else recv_from_prev_pair - ring_id = self.pp_ring_map.get(pair[0] * 1000 + pair[1], - max_ring_id + - 1) # 3->0 not in pp_ring_map + ring_id = self.pp_ring_map.get( + pair[0] * 1000 + pair[1], max_ring_id + 1 + ) # 3->0 not in pp_ring_map self._init_pair_comm(pair, ring_id) if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1: my_pair.remove(pair) - logger.info("pair2(odd->even): pp pair:{}, ring_id: {}".format( - pair, ring_id)) + logger.info( + "pair2(odd->even): pp pair:{}, ring_id: {}".format( + pair, ring_id + ) + ) # 4. odd recv from next, even send to prev, 2->1, 0->3 pair = recv_from_next_pair if not even else send_to_prev_pair - ring_id = self.pp_ring_map.get(pair[0] * 1000 + pair[1], - max_ring_id + - 2) # 0->3 not in pp_ring_map + ring_id = self.pp_ring_map.get( + pair[0] * 1000 + pair[1], max_ring_id + 2 + ) # 0->3 not in pp_ring_map self._init_pair_comm(pair, ring_id) if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1: my_pair.remove(pair) - logger.info("pair3(odd<-even): pp pair:{}, ring_id: {}".format( - pair, ring_id)) + logger.info( + "pair3(odd<-even): pp pair:{}, ring_id: {}".format( + pair, ring_id + ) + ) - assert len(my_pair) == 0, "Current pipeline does not support cross stage communication, " \ - "please check unexpected pair {}".format(my_pair) + assert len(my_pair) == 0, ( + "Current pipeline does not support cross stage communication, " + "please check unexpected pair {}".format(my_pair) + ) def _init_pipeline_comm(self, startup_block): # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None: - self._collective_helper._init_communicator(self._startup_program, - self.current_endpoint, - self.pp_group_endpoints, - self.pp_rank, - self.pp_ring_id, - False, - sync=False) + self._collective_helper._init_communicator( + self._startup_program, + self.current_endpoint, + self.pp_group_endpoints, + self.pp_rank, + self.pp_ring_id, + False, + sync=False, + ) if core.is_compiled_with_npu(): self._init_npu_pipeline_comm(startup_block) @@ -752,13 +846,15 @@ def _init_comm(self): # mp ring if self.mp_degree > 1: - self._collective_helper._init_communicator(self._startup_program, - self.current_endpoint, - self.mp_group_endpoints, - self.mp_rank, - self.mp_ring_id, - False, - sync=False) + self._collective_helper._init_communicator( + self._startup_program, + self.current_endpoint, + self.mp_group_endpoints, + self.mp_rank, + self.mp_ring_id, + False, + sync=False, + ) # sharding ring if self.sharding_degree > 1: @@ -769,7 +865,8 @@ def _init_comm(self): self.sharding_rank, self.sharding_ring_id, False, - sync=False) + sync=False, + ) # pp ring if self.pp_degree > 1: @@ -777,13 +874,15 @@ def _init_comm(self): # pure dp ring if self.dp_degree > 1: - self._collective_helper._init_communicator(self._startup_program, - self.current_endpoint, - self.dp_group_endpoints, - self.dp_rank, - self.dp_ring_id, - False, - sync=False) + self._collective_helper._init_communicator( + self._startup_program, + self.current_endpoint, + self.dp_group_endpoints, + self.dp_rank, + self.dp_ring_id, + False, + sync=False, + ) startup_block._sync_with_cpp() @@ -794,9 +893,12 @@ def _build_shard(self, params_grads, shard_rank, shard_size): # step 3: get broadcast vars self._broadcast_vars = self._shard.find_broadcast_params( - self._main_program.global_block()) + self._main_program.global_block() + ) - def _wait(self, ): + def _wait( + self, + ): endpoints = self.global_endpoints[:] current_endpoint = endpoints[self.global_rank] if self.global_rank == 0: @@ -821,7 +923,7 @@ def _split_program(self, block): segment._end_idx = last_backward_op_idx for op_idx in reversed(range(last_backward_op_idx)): op = block.ops[op_idx] - assert (int(op.attr('op_role')) != int(OpRole.Optimize)) + assert int(op.attr('op_role')) != int(OpRole.Optimize) if self._sharding_segment_strategy == "segment_broadcast_MB": if segment._param_mem >= self._broadcast_MB: segment = self.collect_segment(segment, op_idx, block) @@ -835,21 +937,27 @@ def _split_program(self, block): if ".cast_fp16@GRAD" not in input_name: continue else: - input_name = input_name[:input_name. - find(".cast_fp16@GRAD")] + input_name = input_name[ + : input_name.find(".cast_fp16@GRAD") + ] if input_name in self._backward_remain_anchors: segment = self.collect_segment( - segment, op_idx, block) - assert input_name not in self._forward_remain_anchors, "segment anchor [{}] met twice !".format( - input_name) + segment, op_idx, block + ) + assert ( + input_name not in self._forward_remain_anchors + ), "segment anchor [{}] met twice !".format( + input_name + ) self._backward_remain_anchors.remove(input_name) self._forward_remain_anchors.append(input_name) elif int(op.attr('op_role')) == int(OpRole.Forward): for output_name in op.desc.output_arg_names(): if output_name in self._forward_remain_anchors: segment = self.collect_segment( - segment, op_idx, block) + segment, op_idx, block + ) self._forward_remain_anchors.remove(output_name) # find broadcast vars @@ -865,47 +973,49 @@ def _split_program(self, block): if self._shard.has_param(input_name): broadcast_var_name = input_name else: - broadcast_var_name = unique_name.generate(input_name + - "@BroadCast") + broadcast_var_name = unique_name.generate( + input_name + "@BroadCast" + ) segment._fill_constant_vars.append(broadcast_var_name) # (JZ-LIANG) should use Param base name ? broadcast_var_base_name = input_name if "subprog" in broadcast_var_base_name: # remove suffix - broadcast_var_base_name = broadcast_var_base_name[: - broadcast_var_base_name - .find( - ".subprog" - )] + broadcast_var_base_name = broadcast_var_base_name[ + : broadcast_var_base_name.find(".subprog") + ] - var2broadcast_time[ - broadcast_var_base_name] = var2broadcast_time.get( - broadcast_var_base_name, 0) + 1 + var2broadcast_time[broadcast_var_base_name] = ( + var2broadcast_time.get(broadcast_var_base_name, 0) + 1 + ) segment._param2broadcast[input_name] = broadcast_var_name segment._broadcast_vars.append( - (broadcast_var_name, self._shard.device(input_name))) + (broadcast_var_name, self._shard.device(input_name)) + ) segment._param_mem += get_var_size( - self._main_program.global_block().var(input_name)) + self._main_program.global_block().var(input_name) + ) # find reduce vars if self.pp_degree > 1 and self.pp_allreduce_in_optimize: # place pipeline gradient allreduce in optimize pass else: - if is_backward_op(op) and \ - OP_ROLE_VAR_KEY in op.attr_names: + if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names: op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] if len(op_role_var) != 0: assert len(op_role_var) % 2 == 0 for i in range(0, len(op_role_var), 2): - param, reduced_grad = op_role_var[i], op_role_var[i - + - 1] + param, reduced_grad = ( + op_role_var[i], + op_role_var[i + 1], + ) segment._allreduce_vars.append(reduced_grad) - assert (reduced_grad - not in self._reduced_grads_to_param) + assert ( + reduced_grad not in self._reduced_grads_to_param + ) self._reduced_grads_to_param[reduced_grad] = param # find cast op @@ -920,29 +1030,40 @@ def _split_program(self, block): self._segments.insert(0, segment) if self._sharding_segment_strategy == "segment_anchors": - assert len( - self._forward_remain_anchors) == 0, "remain anchors {}".format( - self._forward_remain_anchors) - assert len( - self._backward_remain_anchors) == 0, "remain anchors {}".format( - self._backward_remain_anchors) + assert ( + len(self._forward_remain_anchors) == 0 + ), "remain anchors {}".format(self._forward_remain_anchors) + assert ( + len(self._backward_remain_anchors) == 0 + ), "remain anchors {}".format(self._backward_remain_anchors) if self._verbose: - for varname in sorted(var2broadcast_time, - key=var2broadcast_time.get, - reverse=True): - logger.info("Sharding broadcast: [{}] times [{}]".format( - var2broadcast_time[varname], varname)) + for varname in sorted( + var2broadcast_time, key=var2broadcast_time.get, reverse=True + ): + logger.info( + "Sharding broadcast: [{}] times [{}]".format( + var2broadcast_time[varname], varname + ) + ) for idx_ in range(len(self._segments)): logger.info("segment [{}] :".format(idx_)) - logger.info("start op: [{}] [{}]".format( - block.ops[self._segments[idx_]._start_idx].desc.type(), - block.ops[self._segments[idx_]. - _start_idx].desc.input_arg_names())) - logger.info("end op: [{}] [{}]".format( - block.ops[self._segments[idx_]._end_idx].desc.type(), - block.ops[ - self._segments[idx_]._end_idx].desc.input_arg_names())) + logger.info( + "start op: [{}] [{}]".format( + block.ops[self._segments[idx_]._start_idx].desc.type(), + block.ops[ + self._segments[idx_]._start_idx + ].desc.input_arg_names(), + ) + ) + logger.info( + "end op: [{}] [{}]".format( + block.ops[self._segments[idx_]._end_idx].desc.type(), + block.ops[ + self._segments[idx_]._end_idx + ].desc.input_arg_names(), + ) + ) return def _prune_main_program(self, block, shard, rings): @@ -954,7 +1075,7 @@ def _prune_main_program(self, block, shard, rings): 2. prune cast_fp32_to_fp16; update amp_infine_checking 3. prune gradient_clip related; update global_norm_sum 4. prune optimizer op + param + gradient - + """ weightdecay_helper = WeightDecayHelper() weightdecay_helper.prune_weight_decay(block, shard) @@ -975,17 +1096,18 @@ def _prune_main_program(self, block, shard, rings): input_names = op.desc.input_arg_names() output_names = op.desc.output_arg_names() # FIXME(wangxi): need use grads, pipeline grad is @GRAD@MERGE - if op.type == "c_allreduce_sum" and \ - op.attr('use_model_parallel') is False: - assert (len(output_names) == 1) + if ( + op.type == "c_allreduce_sum" + and op.attr('use_model_parallel') is False + ): + assert len(output_names) == 1 output_name = output_names[0] reduced_grads.append(output_name) # prune optimizer state and param pruned_opti_vars = [] for var_name in list(block.vars.keys()): - if shard.is_opti_var(var_name) and \ - not shard.has_opt_var(var_name): + if shard.is_opti_var(var_name) and not shard.has_opt_var(var_name): pruned_opti_vars.append(var_name) program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars) @@ -996,17 +1118,17 @@ def _prune_main_program(self, block, shard, rings): # Prune for idx, op in reversed(list(enumerate(block.ops))): if op.type in [ - "c_allreduce_sum", - "c_sync_comm_stream", - "c_calc_comm_stream", - "c_gen_nccl_id", - "c_comm_init", - 'send_v2', - 'recv_v2', + "c_allreduce_sum", + "c_sync_comm_stream", + "c_calc_comm_stream", + "c_gen_nccl_id", + "c_comm_init", + 'send_v2', + 'recv_v2', ]: pass elif op.type == "conditional_block": - assert (op.desc.has_attr("sub_block")) + assert op.desc.has_attr("sub_block") subblock_idx = op.desc.attr("sub_block").id subblock_deps = program_deps.get_sub_block_deps(subblock_idx) # only prune amp subblock @@ -1022,7 +1144,8 @@ def _prune_main_program(self, block, shard, rings): reversed_output_vars.append(output_name) # prune for sub_op_idx, _ in reversed( - list(enumerate(subblock_deps._block.ops))): + list(enumerate(subblock_deps._block.ops)) + ): if subblock_deps.should_remove_op(sub_op_idx): subblock_deps.remove_op(sub_op_idx) reversed_input_vars = [] @@ -1038,7 +1161,9 @@ def _prune_main_program(self, block, shard, rings): # _should_removed_var: opt state not cur shard if program_deps.should_remove_op(idx): # NOTE(wangxi): need reserve all param in optimizer_sharding - reserved_vars = self._params if self._optimizer_sharding else None + reserved_vars = ( + self._params if self._optimizer_sharding else None + ) program_deps.remove_op(idx, reserved_vars) # NOTE (JZ-LIANG) revise and unify logic here @@ -1049,7 +1174,8 @@ def _prune_main_program(self, block, shard, rings): # remove inputs that not on this card reserved_x = [] for var_name in op.desc.input("X"): - if block.has_var(var_name): reserved_x.append(var_name) + if block.has_var(var_name): + reserved_x.append(var_name) op.desc.set_input('X', reserved_x) block._sync_with_cpp() return @@ -1059,7 +1185,7 @@ def _add_broadcast_allreduce(self, block): add broadcast allreduce op if enable gradient_merge, insert related ops - if combined with pipeline(grad accumulate), + if combined with pipeline(grad accumulate), the grad allreduce should be done in optimize role """ if len(self._segments) < 1: @@ -1072,175 +1198,280 @@ def _add_broadcast_allreduce(self, block): # NOTE (JZ-LIANG) revise and unify logic here # fix the _end_idx for segments[-1] if pp is used. new_end_idx = self._segments[-1]._end_idx - for idx in range(self._segments[-1]._end_idx - 1, - self._segments[-1]._start_idx - 1, -1): + for idx in range( + self._segments[-1]._end_idx - 1, + self._segments[-1]._start_idx - 1, + -1, + ): op = block.ops[idx] if op.type == "fill_constant" or op.type == "sum": - if "MERGED" in op.output_arg_names[0]: new_end_idx = idx + 1 + if "MERGED" in op.output_arg_names[0]: + new_end_idx = idx + 1 elif op.type == "cast": - if "@TMP" in op.output_arg_names[0]: new_end_idx = idx + 1 + if "@TMP" in op.output_arg_names[0]: + new_end_idx = idx + 1 self._segments[-1]._end_idx = new_end_idx if self._segments[-1]._allreduce_vars: shard_allredue_vars = self._shard.filter_grads( - self._segments[-1]._allreduce_vars) - if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: - if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( - shard_allredue_vars) >= 1: - insert_sync_comm_ops(block, self._segments[-1]._end_idx, - self.dp_ring_id, shard_allredue_vars) + self._segments[-1]._allreduce_vars + ) + if ( + self.gradient_merge_mode != "sharding_gm" + or self._gradient_merge_acc_step <= 1 + ): + if ( + self.hybrid_dp + and self.hybrid_dp_mode == "sharding_hybrid_dp" + and len(shard_allredue_vars) >= 1 + ): + insert_sync_comm_ops( + block, + self._segments[-1]._end_idx, + self.dp_ring_id, + shard_allredue_vars, + ) insert_allreduce_ops( block, self._segments[-1]._end_idx, self.dp_ring_id, shard_allredue_vars, - user_defined_strategy=self.user_defined_strategy) + user_defined_strategy=self.user_defined_strategy, + ) # gradient merge - elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: + elif ( + self.gradient_merge_mode == "sharding_gm" + and self._gradient_merge_acc_step > 1 + ): self.create_persistable_gradients_and_insert_merge_ops( - block, self._startup_program.global_block(), - self._segments[-1]._end_idx, shard_allredue_vars, - self._shard) - - insert_sync_comm_ops(block, self._segments[-1]._end_idx, - self.sharding_ring_id, - self._segments[-1]._allreduce_vars) + block, + self._startup_program.global_block(), + self._segments[-1]._end_idx, + shard_allredue_vars, + self._shard, + ) + + insert_sync_comm_ops( + block, + self._segments[-1]._end_idx, + self.sharding_ring_id, + self._segments[-1]._allreduce_vars, + ) # allreduce --> reduce - insert_reduce_ops(block, - self._segments[-1]._end_idx, - self.sharding_ring_id, - self._segments[-1]._allreduce_vars, - self._shard, - op_role=OpRole.Backward, - use_calc_stream=False) + insert_reduce_ops( + block, + self._segments[-1]._end_idx, + self.sharding_ring_id, + self._segments[-1]._allreduce_vars, + self._shard, + op_role=OpRole.Backward, + use_calc_stream=False, + ) for idx, segment in reversed(list(enumerate(self._segments))): - allreduce_vars = self._segments[ - idx - 1]._allreduce_vars if idx > 0 else [] - broadcast_vars = self._segments[ - idx + - 1]._broadcast_vars if idx < len(self._segments) - 1 else [] - fill_constant_vars = self._segments[ - idx + - 2]._fill_constant_vars if idx < len(self._segments) - 2 else [] - cast_ops = self._segments[ - idx + 2]._cast_ops if idx < len(self._segments) - 2 else {} + allreduce_vars = ( + self._segments[idx - 1]._allreduce_vars if idx > 0 else [] + ) + broadcast_vars = ( + self._segments[idx + 1]._broadcast_vars + if idx < len(self._segments) - 1 + else [] + ) + fill_constant_vars = ( + self._segments[idx + 2]._fill_constant_vars + if idx < len(self._segments) - 2 + else [] + ) + cast_ops = ( + self._segments[idx + 2]._cast_ops + if idx < len(self._segments) - 2 + else {} + ) for op_idx in reversed(range(segment._start_idx, segment._end_idx)): op = block.ops[op_idx] for input_name in op.desc.input_arg_names(): - if input_name in segment._param2broadcast and \ - input_name != segment._param2broadcast[input_name]: - op._rename_input(input_name, - segment._param2broadcast[input_name]) + if ( + input_name in segment._param2broadcast + and input_name != segment._param2broadcast[input_name] + ): + op._rename_input( + input_name, segment._param2broadcast[input_name] + ) for param_name, broadcast_name in segment._param2broadcast.items(): if param_name != broadcast_name: block.create_var( name=broadcast_name, - shape=self._main_program.global_block().var( - param_name).shape, - dtype=self._main_program.global_block().var( - param_name).dtype, - persistable=False) + shape=self._main_program.global_block() + .var(param_name) + .shape, + dtype=self._main_program.global_block() + .var(param_name) + .dtype, + persistable=False, + ) # step1: remove cast ops block._sync_with_cpp() segment._end_idx += FP16Utils.remove_cast_op( - block, self._params, segment, 0) + block, self._params, segment, 0 + ) # step2: add Sync ops shard_allredue_vars = self._shard.filter_grads(allreduce_vars) - if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: - if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( - shard_allredue_vars) >= 1: - insert_sync_comm_ops(block, segment._end_idx, - self.dp_ring_id, shard_allredue_vars) + if ( + self.gradient_merge_mode != "sharding_gm" + or self._gradient_merge_acc_step <= 1 + ): + if ( + self.hybrid_dp + and self.hybrid_dp_mode == "sharding_hybrid_dp" + and len(shard_allredue_vars) >= 1 + ): + insert_sync_comm_ops( + block, + segment._end_idx, + self.dp_ring_id, + shard_allredue_vars, + ) broad_cast_vars = [x[0] for x in broadcast_vars] if len(broad_cast_vars) > 0: - insert_sync_comm_ops(block, segment._end_idx, - self.sharding_ring_id, - broad_cast_vars) + insert_sync_comm_ops( + block, + segment._end_idx, + self.sharding_ring_id, + broad_cast_vars, + ) else: comm_dep_vars = allreduce_vars + [ x[0] for x in broadcast_vars ] if len(comm_dep_vars) > 0: - insert_sync_comm_ops(block, segment._end_idx, - self.sharding_ring_id, - comm_dep_vars) + insert_sync_comm_ops( + block, + segment._end_idx, + self.sharding_ring_id, + comm_dep_vars, + ) # gradient merge - elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: + elif ( + self.gradient_merge_mode == "sharding_gm" + and self._gradient_merge_acc_step > 1 + ): broad_cast_vars = [x[0] for x in broadcast_vars] if len(broad_cast_vars) > 0: - insert_sync_comm_ops(block, segment._end_idx, - self.sharding_ring_id, broad_cast_vars) + insert_sync_comm_ops( + block, + segment._end_idx, + self.sharding_ring_id, + broad_cast_vars, + ) - calc_dep_vars = fill_constant_vars + [ - k for k, v in cast_ops.items() - ] + self._segments[idx]._allreduce_vars + calc_dep_vars = ( + fill_constant_vars + + [k for k, v in cast_ops.items()] + + self._segments[idx]._allreduce_vars + ) if len(calc_dep_vars) > 0: - insert_sync_calc_op(block, segment._end_idx, - [calc_dep_vars[-1]]) + insert_sync_calc_op( + block, segment._end_idx, [calc_dep_vars[-1]] + ) # step3: insert `fill_constant` ops - insert_fill_constant_ops(block, segment._end_idx, - fill_constant_vars) + insert_fill_constant_ops( + block, segment._end_idx, fill_constant_vars + ) # step4: add `cast` ops insert_cast_ops(block, segment._end_idx, cast_ops) # step5: add broadcast ops # gradient merge - if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: + if ( + self.gradient_merge_mode == "sharding_gm" + and self._gradient_merge_acc_step > 1 + ): self.create_persistable_gradients_and_insert_merge_ops( - block, self._startup_program.global_block(), - segment._start_idx, shard_allredue_vars, self._shard) + block, + self._startup_program.global_block(), + segment._start_idx, + shard_allredue_vars, + self._shard, + ) - insert_broadcast_ops(block, segment._start_idx, - self.sharding_ring_id, broadcast_vars) + insert_broadcast_ops( + block, segment._start_idx, self.sharding_ring_id, broadcast_vars + ) # step6: add all_reduce ops # dp - if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: - if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( - shard_allredue_vars) >= 1: + if ( + self.gradient_merge_mode != "sharding_gm" + or self._gradient_merge_acc_step <= 1 + ): + if ( + self.hybrid_dp + and self.hybrid_dp_mode == "sharding_hybrid_dp" + and len(shard_allredue_vars) >= 1 + ): insert_allreduce_ops( block, segment._start_idx, self.dp_ring_id, shard_allredue_vars, - user_defined_strategy=self.user_defined_strategy) - insert_sync_comm_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + user_defined_strategy=self.user_defined_strategy, + ) + insert_sync_comm_ops( + block, + segment._start_idx, + self.sharding_ring_id, + allreduce_vars, + ) # gradient merge - elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: - insert_sync_comm_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + elif ( + self.gradient_merge_mode == "sharding_gm" + and self._gradient_merge_acc_step > 1 + ): + insert_sync_comm_ops( + block, + segment._start_idx, + self.sharding_ring_id, + allreduce_vars, + ) # sharding # allreduce --> reduce # TODO temp change if len(allreduce_vars) > 0: - insert_reduce_ops(block, - segment._start_idx, - self.sharding_ring_id, - allreduce_vars, - self._shard, - op_role=OpRole.Backward, - use_calc_stream=False) + insert_reduce_ops( + block, + segment._start_idx, + self.sharding_ring_id, + allreduce_vars, + self._shard, + op_role=OpRole.Backward, + use_calc_stream=False, + ) block._sync_with_cpp() if self._segments[0]._broadcast_vars: broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars] - insert_sync_comm_ops(block, self._segments[0]._start_idx, - self.sharding_ring_id, broadcast_vars) - insert_broadcast_ops(block, self._segments[0]._start_idx, - self.sharding_ring_id, - self._segments[0]._broadcast_vars) + insert_sync_comm_ops( + block, + self._segments[0]._start_idx, + self.sharding_ring_id, + broadcast_vars, + ) + insert_broadcast_ops( + block, + self._segments[0]._start_idx, + self.sharding_ring_id, + self._segments[0]._broadcast_vars, + ) fill_constant_vars = [] for x in self._segments[:2]: @@ -1254,12 +1485,14 @@ def _add_broadcast_allreduce(self, block): calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()] if fill_constant_vars or cast_ops: - insert_sync_calc_op(block, self._segments[0]._start_idx, - [calc_deps_vars[-1]]) + insert_sync_calc_op( + block, self._segments[0]._start_idx, [calc_deps_vars[-1]] + ) if fill_constant_vars: - insert_fill_constant_ops(block, self._segments[0]._start_idx, - fill_constant_vars) + insert_fill_constant_ops( + block, self._segments[0]._start_idx, fill_constant_vars + ) if cast_ops: insert_cast_ops(block, self._segments[0]._start_idx, cast_ops) @@ -1273,7 +1506,7 @@ def _prune_startup_program(self, block, shard): continue if self._optimizer_sharding and shard.is_param(output_name): continue - #TODO why do we remove op, when only one var is removed + # TODO why do we remove op, when only one var is removed block._remove_op(idx, sync=False) break @@ -1295,23 +1528,36 @@ def _build_groups(self): pp: 4 pp-pair: >= 20 if one parallelism is not enable: -1 - and only support parallelism hierarchy: mp --> sharding --> pp --> dp + and only support parallelism hierarchy: mp --> sharding --> pp --> dp """ # step 1: initialize nccl self.global_word_size = self.role_maker._worker_num() self.global_rank = self.role_maker._worker_index() self.global_endpoints = self.role_maker._get_trainer_endpoints() self.current_endpoint = self.global_endpoints[self.global_rank] - self._collective_helper = CollectiveHelper(self.role_maker, - nrings=self._nrings_sharding) - assert self.global_word_size % self.mp_degree == 0, \ - "global_word_size: {} should be divisible to the mp_degree: {}".format(self.global_word_size, self.mp_degree) - assert self.global_word_size % self.sharding_degree == 0, \ - "global_word_size: {} should be divisible to the sharding_degree: {}".format(self.global_word_size, self.sharding_degree) - assert self.global_word_size % self.pp_degree == 0, \ - "global_word_size: {} should be divisible to the pp_degree: {}".format(self.global_word_size, self.pp_degree) - assert self.global_word_size % self.dp_degree == 0, \ - "global_word_size: {} should be divisible to the dp_degree: {}".format(self.global_word_size, self.dp_degree) + self._collective_helper = CollectiveHelper( + self.role_maker, nrings=self._nrings_sharding + ) + assert ( + self.global_word_size % self.mp_degree == 0 + ), "global_word_size: {} should be divisible to the mp_degree: {}".format( + self.global_word_size, self.mp_degree + ) + assert ( + self.global_word_size % self.sharding_degree == 0 + ), "global_word_size: {} should be divisible to the sharding_degree: {}".format( + self.global_word_size, self.sharding_degree + ) + assert ( + self.global_word_size % self.pp_degree == 0 + ), "global_word_size: {} should be divisible to the pp_degree: {}".format( + self.global_word_size, self.pp_degree + ) + assert ( + self.global_word_size % self.dp_degree == 0 + ), "global_word_size: {} should be divisible to the dp_degree: {}".format( + self.global_word_size, self.dp_degree + ) # mp group if self.mp_degree > 1: @@ -1319,14 +1565,16 @@ def _build_groups(self): self.mp_rank = self.global_rank % self.mp_degree self.mp_group_id = self.global_rank // self.mp_degree self.mp_group_endpoints = [ - ep for idx, ep in enumerate(self.global_endpoints) + ep + for idx, ep in enumerate(self.global_endpoints) if idx // self.mp_degree == self.mp_group_id ] assert self.current_endpoint in self.mp_group_endpoints - assert len( - self.mp_group_endpoints - ) == self.mp_degree, "num of mp worker in group is [{}], but mp group size is [{}]".format( - len(self.mp_group_endpoints), self.mp_degree) + assert ( + len(self.mp_group_endpoints) == self.mp_degree + ), "num of mp worker in group is [{}], but mp group size is [{}]".format( + len(self.mp_group_endpoints), self.mp_degree + ) else: self.mp_degree = 1 self.mp_ring_id = -1 @@ -1337,23 +1585,28 @@ def _build_groups(self): # sharding if self.sharding_degree > 1: self.sharding_ring_id = 1 - self.sharding_rank = (self.global_rank // - self.mp_degree) % self.sharding_degree - self.sharding_group_id = self.global_rank // (self.mp_degree * - self.sharding_degree) + self.sharding_rank = ( + self.global_rank // self.mp_degree + ) % self.sharding_degree + self.sharding_group_id = self.global_rank // ( + self.mp_degree * self.sharding_degree + ) # mp + sharding + ... if self.mp_degree > 1: self.sharding_group_endpoints = [ - ep for idx, ep in enumerate(self.global_endpoints) - if (idx // (self.mp_degree * self.sharding_degree)) == self. - sharding_group_id and idx % self.mp_degree == self.mp_rank + ep + for idx, ep in enumerate(self.global_endpoints) + if (idx // (self.mp_degree * self.sharding_degree)) + == self.sharding_group_id + and idx % self.mp_degree == self.mp_rank ] # sharding + ... else: self.sharding_group_endpoints = [ - ep for idx, ep in enumerate(self.global_endpoints) - if (idx // (self.mp_degree * self.sharding_degree) - ) == self.sharding_group_id + ep + for idx, ep in enumerate(self.global_endpoints) + if (idx // (self.mp_degree * self.sharding_degree)) + == self.sharding_group_id ] assert self.current_endpoint in self.sharding_group_endpoints else: @@ -1368,20 +1621,28 @@ def _build_groups(self): self.pp_pair_ring_id = 20 # pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3 self.pp_ring_id = 4 - self.pp_rank = self.global_rank // (self.sharding_degree * - self.mp_degree) % self.pp_degree + self.pp_rank = ( + self.global_rank + // (self.sharding_degree * self.mp_degree) + % self.pp_degree + ) # (NOTE): Already adjust for (outter-pure) dp self.pp_group_id = self.global_rank // ( - self.mp_degree * self.sharding_degree * self.pp_degree) + self.mp_degree * self.sharding_degree * self.pp_degree + ) pp_first_stage_idx = self.global_rank % ( - self.sharding_degree * self.mp_degree) + self.pp_group_id * ( - self.mp_degree * self.sharding_degree * self.pp_degree) + self.sharding_degree * self.mp_degree + ) + self.pp_group_id * ( + self.mp_degree * self.sharding_degree * self.pp_degree + ) pp_stage_offset = self.sharding_degree * self.mp_degree self.pp_group_endpoints = [] for i in range(self.pp_degree): self.pp_group_endpoints.append( - self.global_endpoints[pp_first_stage_idx + - pp_stage_offset * i]) + self.global_endpoints[ + pp_first_stage_idx + pp_stage_offset * i + ] + ) assert self.current_endpoint in self.pp_group_endpoints else: self.pp_ring_id = -1 @@ -1397,29 +1658,48 @@ def _build_groups(self): # sharding-hybrid-dp as one senario of outter-pure-dp local_pp_degree = self.pp_degree if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None): - assert self.pp_degree == 2, ("For manually set pipeline, only " - "pp_degree = 2 is supported.") - assert self.global_word_size == self.mp_degree * self.sharding_degree * self.dp_degree, \ - "global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format( - self.global_word_size, self.mp_degree, self.sharding_degree, self.dp_degree) + assert self.pp_degree == 2, ( + "For manually set pipeline, only " "pp_degree = 2 is supported." + ) + assert ( + self.global_word_size + == self.mp_degree * self.sharding_degree * self.dp_degree + ), "global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format( + self.global_word_size, + self.mp_degree, + self.sharding_degree, + self.dp_degree, + ) local_pp_degree = 1 else: - assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format( - self.mp_degree, self.sharding_degree, self.pp_degree, - self.dp_degree, self.global_word_size) + assert ( + self.global_word_size + == self.mp_degree + * self.sharding_degree + * self.pp_degree + * self.dp_degree + ), "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format( + self.mp_degree, + self.sharding_degree, + self.pp_degree, + self.dp_degree, + self.global_word_size, + ) if self.dp_degree > 1: self.dp_ring_id = 2 self.dp_rank = self.global_rank // ( - self.sharding_degree * self.mp_degree * local_pp_degree) + self.sharding_degree * self.mp_degree * local_pp_degree + ) dp_first_rank_idx = self.global_rank % ( - self.sharding_degree * self.mp_degree * local_pp_degree) - dp_offset = (self.sharding_degree * self.mp_degree * - local_pp_degree) + self.sharding_degree * self.mp_degree * local_pp_degree + ) + dp_offset = self.sharding_degree * self.mp_degree * local_pp_degree self.dp_group_endpoints = [] for i in range(self.dp_degree): self.dp_group_endpoints.append( - self.global_endpoints[dp_first_rank_idx + dp_offset * i]) + self.global_endpoints[dp_first_rank_idx + dp_offset * i] + ) assert self.current_endpoint in self.dp_group_endpoints logger.info("Hybrid DP mode turn on !") else: @@ -1448,8 +1728,9 @@ def _build_groups(self): logger.info("sharding group size: {}".format(self.sharding_degree)) logger.info("sharding rank: {}".format(self.sharding_rank)) logger.info("sharding group id: {}".format(self.sharding_group_id)) - logger.info("sharding group endpoints: {}".format( - self.sharding_group_endpoints)) + logger.info( + "sharding group endpoints: {}".format(self.sharding_group_endpoints) + ) logger.info("sharding ring id: {}".format(self.sharding_ring_id)) logger.info("#####" * 6) @@ -1462,15 +1743,15 @@ def _build_groups(self): logger.info("pure dp group size: {}".format(self.dp_degree)) logger.info("pure dp rank: {}".format(self.dp_rank)) - logger.info("pure dp group endpoints: {}".format( - self.dp_group_endpoints)) + logger.info( + "pure dp group endpoints: {}".format(self.dp_group_endpoints) + ) logger.info("pure dp ring id: {}".format(self.dp_ring_id)) logger.info("#####" * 6) return def _recreate_not_persist_param_as_var(self): - def recreate_not_persist_param_as_var(program): block = program.global_block() params = block.all_parameters() @@ -1494,14 +1775,16 @@ def recreate_not_persist_param_as_var(program): is_distributed = param.is_distributed block._remove_var(name, sync=False) - var = block.create_var(name=name, - shape=shape, - dtype=dtype, - type=type, - lod_level=lod_level, - stop_gradient=stop_gradient, - trainable=trainable, - persistable=False) + var = block.create_var( + name=name, + shape=shape, + dtype=dtype, + type=type, + lod_level=lod_level, + stop_gradient=stop_gradient, + trainable=trainable, + persistable=False, + ) if have_dist_attr: var.is_distributed = is_distributed @@ -1516,6 +1799,13 @@ def _initialization_broadcast(self): identical when hybrid-dp is used, and the initialization of not distributed param between mp group to be identical. """ + + def _find_master_param(all_vars_name, param_name): + for var_name in all_vars_name: + if param_name in var_name and "fp32_master" in var_name: + return var_name + return None + if self.dp_degree <= 1 and self.mp_degree <= 1: return @@ -1536,8 +1826,10 @@ def _initialization_broadcast(self): if op.type == 'c_broadcast': broadcast_params.add(op.desc.output_arg_names()[0]) + all_vars_name = startup_block.vars for param in params_name: - if param in broadcast_params: continue + if param in broadcast_params: + continue rings = [] # need sync not distributed param in mp group @@ -1547,30 +1839,51 @@ def _initialization_broadcast(self): rings.append(self.dp_ring_id) for ring in rings: - startup_block.append_op(type='c_broadcast', - inputs={'X': param}, - outputs={'Out': param}, - attrs={ - 'ring_id': ring, - 'root': 0, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Forward - }) + startup_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': ring, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward, + }, + ) + # Broadcast the master weight at the same time for AMP-O2 training. + master_param = _find_master_param(all_vars_name, param) + if master_param is not None: + startup_block.append_op( + type='c_broadcast', + inputs={'X': master_param}, + outputs={'Out': master_param}, + attrs={ + 'ring_id': ring, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward, + }, + ) startup_block._sync_with_cpp() # sharding gradient merge def create_persistable_gradients_and_insert_merge_ops( - self, main_block, startup_block, insert_idx, grad_names, shard): + self, main_block, startup_block, insert_idx, grad_names, shard + ): for grad_name in grad_names: - assert get_grad_device( - grad_name, shard - ) == shard.worker_idx, "try to merge gradient not belong to current shard: [{}]".format( - grad_name) + assert ( + get_grad_device(grad_name, shard) == shard.worker_idx + ), "try to merge gradient not belong to current shard: [{}]".format( + grad_name + ) persistable_grad_name = grad_name + '@GradiantMerge' - assert grad_name not in self._grad2merged_grad, "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format( - grad_name) + assert ( + grad_name not in self._grad2merged_grad + ), "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format( + grad_name + ) self._grad2merged_grad[grad_name] = persistable_grad_name grad_var = main_block.var(grad_name) # create var @@ -1578,36 +1891,38 @@ def create_persistable_gradients_and_insert_merge_ops( name=persistable_grad_name, shape=grad_var.shape, dtype=grad_var.dtype, - persistable=True) + persistable=True, + ) startup_gradient_merge_var = startup_block.create_var( name=persistable_grad_name, shape=grad_var.shape, dtype=grad_var.dtype, - persistable=True) + persistable=True, + ) # merge gradient main_block._insert_op_without_sync( insert_idx, type="elementwise_add", - inputs={ - 'X': grad_name, - 'Y': gradient_merge_var - }, + inputs={'X': grad_name, 'Y': gradient_merge_var}, outputs={'Out': gradient_merge_var}, attrs={ 'axis': -1, 'use_mkldnn': False, - OP_ROLE_KEY: OpRole.Backward - }) + OP_ROLE_KEY: OpRole.Backward, + }, + ) # startup initialization - startup_block.append_op(type="fill_constant", - outputs={"Out": startup_gradient_merge_var}, - attrs={ - "shape": grad_var.shape, - "dtype": grad_var.dtype, - "value": float(0), - }) + startup_block.append_op( + type="fill_constant", + outputs={"Out": startup_gradient_merge_var}, + attrs={ + "shape": grad_var.shape, + "dtype": grad_var.dtype, + "value": float(0), + }, + ) main_block._sync_with_cpp() startup_block._sync_with_cpp() @@ -1620,14 +1935,17 @@ def _create_gm_cond(self, main_block): value=int(self._gradient_merge_acc_step), dtype='int32', persistable=True, - force_cpu=True) + force_cpu=True, + ) - zero_var = layers.create_global_var(name="gradient_merge_zero", - shape=[1], - value=int(0), - dtype='int32', - persistable=True, - force_cpu=True) + zero_var = layers.create_global_var( + name="gradient_merge_zero", + shape=[1], + value=int(0), + dtype='int32', + persistable=True, + force_cpu=True, + ) # Add step var & cond var current_step_var = layers.create_global_var( @@ -1636,42 +1954,40 @@ def _create_gm_cond(self, main_block): value=int(0), dtype='int32', persistable=True, - force_cpu=True) + force_cpu=True, + ) - cond_var = main_block.create_var(name="gradient_merge_cond", - shape=[1], - dtype='bool') + cond_var = main_block.create_var( + name="gradient_merge_cond", shape=[1], dtype='bool' + ) with device_guard("cpu"): # step_var = (step_var + 1) % k_step - main_block.append_op(type='increment', - inputs={'X': [current_step_var]}, - outputs={'Out': [current_step_var]}, - attrs={ - 'step': float(1), - OP_ROLE_KEY: OpRole.Optimize - }) - - main_block.append_op(type='elementwise_mod', - inputs={ - 'X': current_step_var, - 'Y': acc_step_var - }, - outputs={'Out': current_step_var}, - attrs={ - 'axis': -1, - OP_ROLE_KEY: OpRole.Optimize, - 'use_mkldnn': False - }) + main_block.append_op( + type='increment', + inputs={'X': [current_step_var]}, + outputs={'Out': [current_step_var]}, + attrs={'step': float(1), OP_ROLE_KEY: OpRole.Optimize}, + ) + + main_block.append_op( + type='elementwise_mod', + inputs={'X': current_step_var, 'Y': acc_step_var}, + outputs={'Out': current_step_var}, + attrs={ + 'axis': -1, + OP_ROLE_KEY: OpRole.Optimize, + 'use_mkldnn': False, + }, + ) # cond_var = (step_var == 0) - main_block.append_op(type='equal', - inputs={ - 'X': current_step_var, - 'Y': zero_var - }, - outputs={'Out': cond_var}, - attrs={OP_ROLE_KEY: OpRole.Optimize}) + main_block.append_op( + type='equal', + inputs={'X': current_step_var, 'Y': zero_var}, + outputs={'Out': cond_var}, + attrs={OP_ROLE_KEY: OpRole.Optimize}, + ) # paddle.static.Print(current_step_var, message="in FWBW last conditional") return cond_var @@ -1681,7 +1997,7 @@ def _true_apply_gradient(self): grad@gradientmerge / acc_step re-create all optimize ops of origin main block and rename them cast(backward) - amp + amp clip opt # fill constant grad@gradientmerge @@ -1698,35 +2014,37 @@ def _true_apply_gradient(self): # allreduce grad@gradientmerge if self.hybrid_dp: - assert self.dp_ring_id >= 0, "dp_ring_id should larger than 0 when in sharding&DP mode" + assert ( + self.dp_ring_id >= 0 + ), "dp_ring_id should larger than 0 when in sharding&DP mode" for grad, merged_grad in self._grad2merged_grad.items(): merged_grad_var = main_block.var(merged_grad) - cur_block.append_op(type='c_allreduce_sum', - inputs={'X': merged_grad_var}, - outputs={'Out': merged_grad_var}, - attrs={ - 'ring_id': self.dp_ring_id, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Optimize - }) + cur_block.append_op( + type='c_allreduce_sum', + inputs={'X': merged_grad_var}, + outputs={'Out': merged_grad_var}, + attrs={ + 'ring_id': self.dp_ring_id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, + }, + ) # grad@gradientmerge / acc_step for grad, merged_grad in self._grad2merged_grad.items(): # grad /= k_steps merged_grad_var = main_block.var(merged_grad) - cur_block.append_op(type='scale', - inputs={'X': merged_grad_var}, - outputs={'Out': merged_grad_var}, - attrs={ - 'scale': - 1.0 / float(self._gradient_merge_acc_step), - 'bias': - 0.0, - 'bias_after_scale': - False, - OP_ROLE_KEY: - OpRole.Optimize - }) + cur_block.append_op( + type='scale', + inputs={'X': merged_grad_var}, + outputs={'Out': merged_grad_var}, + attrs={ + 'scale': 1.0 / float(self._gradient_merge_acc_step), + 'bias': 0.0, + 'bias_after_scale': False, + OP_ROLE_KEY: OpRole.Optimize, + }, + ) # re-create optimize ops already_moved_var_names = [] @@ -1737,15 +2055,19 @@ def _true_apply_gradient(self): for input_name in new_op_desc.input_arg_names(): if input_name in self._grad2merged_grad: new_op_desc._rename_input( - input_name, self._grad2merged_grad[input_name]) + input_name, self._grad2merged_grad[input_name] + ) for output_name in new_op_desc.output_arg_names(): if output_name in self._grad2merged_grad: new_op_desc._rename_output( - output_name, self._grad2merged_grad[output_name]) + output_name, self._grad2merged_grad[output_name] + ) # move non temp optimize vars from block0 to cond block - if output_name not in already_moved_var_names and output_name not in self._grad2merged_grad.keys( + if ( + output_name not in already_moved_var_names + and output_name not in self._grad2merged_grad.keys() ): var_ = self._main_program.global_block().var(output_name) if not var_.persistable: @@ -1754,11 +2076,14 @@ def _true_apply_gradient(self): shape_ = var_.shape type_ = var_.dtype self._main_program.global_block()._remove_var( - var_.name, sync=False) - self.cond_block.create_var(name=name_, - shape=shape_, - dtype=type_, - persistable=False) + var_.name, sync=False + ) + self.cond_block.create_var( + name=name_, + shape=shape_, + dtype=type_, + persistable=False, + ) already_moved_var_names.append(name_) self._main_program.global_block()._sync_with_cpp() @@ -1767,14 +2092,16 @@ def _true_apply_gradient(self): # fill zero to grad@gradientmerge for grad, merged_grad in self._grad2merged_grad.items(): merged_grad_var = main_block.var(merged_grad) - cur_block.append_op(type='fill_constant', - outputs={'Out': merged_grad_var}, - attrs={ - "shape": merged_grad_var.shape, - "dtype": merged_grad_var.dtype, - "value": float(0), - OP_ROLE_KEY: OpRole.Optimize - }) + cur_block.append_op( + type='fill_constant', + outputs={'Out': merged_grad_var}, + attrs={ + "shape": merged_grad_var.shape, + "dtype": merged_grad_var.dtype, + "value": float(0), + OP_ROLE_KEY: OpRole.Optimize, + }, + ) # lr_var = main_block.var("gradient_merge_current_step") # paddle.static.Print(lr_var, message="in OPTIMIZE last conditional") @@ -1786,7 +2113,10 @@ def _sharding_gradient_merge(self): create cond block """ - if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: + if ( + self.gradient_merge_mode != "sharding_gm" + or self._gradient_merge_acc_step <= 1 + ): return main_block = self._main_program.global_block() @@ -1805,7 +2135,8 @@ def _sharding_gradient_merge(self): main_block._remove_op(op_idx, sync=False) tmp_copy_block._sync_with_cpp() self.original_optimize_ops_desc = list( - reversed(self.original_optimize_ops_desc)) + reversed(self.original_optimize_ops_desc) + ) # back to block 0 self._main_program._rollback() @@ -1822,18 +2153,17 @@ def _sharding_gradient_merge(self): # cond op step_scope = self._main_program.global_block().create_var( - type=core.VarDesc.VarType.STEP_SCOPES) + type=core.VarDesc.VarType.STEP_SCOPES + ) conditional_block_op = self._main_program.global_block().append_op( type='conditional_block', inputs={ 'Cond': cond, 'Input': [], }, - outputs={ - 'Out': [], - 'Scope': [step_scope] - }, + outputs={'Out': [], 'Scope': [step_scope]}, attrs={ 'sub_block': cond_block, 'is_scalar_condition': True, - }) + }, + )