From 9b59b755cb686fd1d3409a1ec241806004053abf Mon Sep 17 00:00:00 2001
From: qingshui
Date: Thu, 24 Nov 2022 16:23:24 +0800
Subject: [PATCH] add PADDLE_LOSS_SCALE env, add PADDLE_FUSE_ALLREDUCE env
 (#167)

fix lightgcn loss bug
---
 python/paddle/fluid/transpiler/collective.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py
index 433b29d2dddcd..d4fa63930acf1 100644
--- a/python/paddle/fluid/transpiler/collective.py
+++ b/python/paddle/fluid/transpiler/collective.py
@@ -477,9 +477,10 @@ class SingleProcessMultiThread(GradAllReduce):
     def __init__(self):
         GradAllReduce.__init__(self, 1)
         self.mode = "single_process_multi_thread"
-        self.fuse_allreduce = os.getenv("FLAGS_fuse_allreduce", False)
-        self.gpu_nums = os.getenv("FLAGS_selected_gpus",
-                                  "0,1,2,3,4,5,6,7,8").split(",")
+        self.fuse_allreduce = int(os.getenv("PADDLE_FUSE_ALLREDUCE", "1"))
+        self.loss_scale = int(os.getenv("PADDLE_LOSS_SCALE", "1"))
+        self.gpu_nums = len(os.getenv("FLAGS_selected_gpus",
+                                      "0,1,2,3,4,5,6,7").split(","))
 
     def _transpile_startup_program(self):
         nodes_num = 0
@@ -504,13 +505,18 @@ def _transpile_startup_program(self):
         block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})
 
     def _transpile_main_program(self):
-        if self._get_update_param_count() == 0:
+        # nothing to do when loss scaling is disabled and there is no dense param
+        param_cnt = self._get_update_param_count()
+        if self.loss_scale == 0 and param_cnt == 0:
             return
         # scale loss
         self._insert_scale_loss_grad_ops()
+        # no dense param to allreduce
+        if param_cnt == 0:
+            return
         # fuse allreduce
-        if self.fuse_allreduce:
-            print("begin used fuse_allreduce")
+        if self.fuse_allreduce > 0:
+            print("begin fuse_allreduce, param count = %s" % (param_cnt))
             # use fuse allreduce
             self._insert_fuse_allreduce_ops()
         else:
@@ -546,6 +552,7 @@ def _insert_scale_loss_grad_ops(self):
         training workers, we scale the loss grad by the number of workers
         '''
         scale = 1.0 / self.nranks / self.gpu_nums
+        print("begin _insert_scale_loss_grad_ops, scale = %s" % (scale))
         block = self.main_program.global_block()
         for idx, op in reversed(list(enumerate(block.ops))):
             if not self._is_loss_grad_op(op):
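
Note: below is a minimal, hypothetical sketch (not part of the patch) of how the two
new environment variables are picked up. It assumes PaddlePaddle with this patch
applied is installed and that SingleProcessMultiThread can be constructed directly,
which only exercises the __init__ changed in the first hunk:

    import os

    # Both switches added by this patch default to "1" (enabled).
    os.environ["PADDLE_LOSS_SCALE"] = "0"      # together with zero dense params, _transpile_main_program becomes a no-op
    os.environ["PADDLE_FUSE_ALLREDUCE"] = "0"  # use per-parameter allreduce instead of the fused path
    os.environ["FLAGS_selected_gpus"] = "0,1,2,3"

    from paddle.fluid.transpiler.collective import SingleProcessMultiThread

    t = SingleProcessMultiThread()
    assert t.loss_scale == 0
    assert t.fuse_allreduce == 0
    assert t.gpu_nums == 4  # now a GPU count, not a list, after this patch

Since gpu_nums is now an int, the scale computed in _insert_scale_loss_grad_ops works
out to 1.0 / self.nranks / 4 for this configuration.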