Skip to content

Commit

Permalink
add PADDLE_LOSS_SCALE env, add PADDLE_FUSE_ALLREDUCE env (PaddlePaddle#167)
Browse files Browse the repository at this point in the history

fix lightgcn loss bug
  • Loading branch information
qingshui authored and root committed Nov 28, 2022
1 parent 8d73deb commit 9b59b75
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions python/paddle/fluid/transpiler/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,10 @@ class SingleProcessMultiThread(GradAllReduce):
def __init__(self):
    """Set up single-process multi-thread gradient allreduce.

    Runs as one process (nranks handled by GradAllReduce with 1 node)
    driving multiple GPU threads; tunables come from environment vars.
    """
    GradAllReduce.__init__(self, 1)
    self.mode = "single_process_multi_thread"
    # PADDLE_FUSE_ALLREDUCE > 0 enables fusing gradient allreduce ops
    # (default on). Parsed with int() so "0" correctly disables it —
    # the old FLAGS_fuse_allreduce getenv returned a non-empty string,
    # which was always truthy.
    self.fuse_allreduce = int(os.getenv("PADDLE_FUSE_ALLREDUCE", "1"))
    # PADDLE_LOSS_SCALE != 0 enables inserting loss-grad scaling ops
    # (default on).
    self.loss_scale = int(os.getenv("PADDLE_LOSS_SCALE", "1"))
    # Number of GPU cards used by this process, counted from the
    # comma-separated FLAGS_selected_gpus list (default: 8 cards, 0-7).
    self.gpu_nums = len(
        os.getenv("FLAGS_selected_gpus", "0,1,2,3,4,5,6,7").split(","))

def _transpile_startup_program(self):
nodes_num = 0
Expand All @@ -504,13 +505,18 @@ def _transpile_startup_program(self):
block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})

def _transpile_main_program(self):
if self._get_update_param_count() == 0:
# not need loss scale and no dense param
param_cnt = self._get_update_param_count()
if self.loss_scale is 0 and param_cnt is 0:
return
# scale loss
self._insert_scale_loss_grad_ops()
# no param
if param_cnt is 0:
return
# fuse allreduce
if self.fuse_allreduce:
print("begin used fuse_allreduce")
if self.fuse_allreduce > 0:
print("begin used fuse_allreduce param count = %s" % (param_cnt))
# use fuse allreduce
self._insert_fuse_allreduce_ops()
else:
Expand Down Expand Up @@ -546,6 +552,7 @@ def _insert_scale_loss_grad_ops(self):
training workers, we scale the loss grad by the number of workers
'''
scale = 1.0 / self.nranks / self.gpu_nums
print("begin _insert_scale_loss_grad_ops scale = %s" % (scale))
block = self.main_program.global_block()
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_loss_grad_op(op):
Expand Down

0 comments on commit 9b59b75

Please sign in to comment.