Merge pull request #9072 from chengduoZH/feature/refine_parallel_do
Refine parallel_do_grad
chengduo authored Mar 15, 2018
2 parents 41894da + ef28e7d commit 11c43e5
Showing 3 changed files with 9 additions and 3 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/operators/assign_op.cc
@@ -56,6 +56,7 @@ class AssignFunctor {
  private:
   void copy_tensor(const framework::LoDTensor &lod_tensor,
                    framework::LoDTensor *out) const {
+    if (lod_tensor.numel() == 0) return;
     auto &out_tensor = *out;
     TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor);
     out_tensor.set_lod(lod_tensor.lod());
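The added numel() guard makes assign a no-op when the source tensor is empty. A minimal sketch of the idea in plain Python (hypothetical helper names and numpy stand-ins, not the Paddle API):

    import numpy as np

    def copy_tensor(src, dst):
        # Mirrors the new guard: a zero-element source means there is
        # nothing to copy, so return before touching dst.
        if src.size == 0:
            return
        np.copyto(dst, src)

This matters for the next file: after the nccl_op change below, non-root GPUs hold zero-element outputs, and assign must skip them rather than copy them.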
2 changes: 2 additions & 0 deletions paddle/fluid/operators/nccl_op.cu.cc
@@ -106,6 +106,8 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     T* recvbuffer = nullptr;
     if (root == gpu_id) {
       recvbuffer = out->mutable_data<T>(ctx.GetPlace());
+    } else {
+      out->Resize(framework::make_ddim({0}));
     }
     VLOG(3) << "gpu : " << gpu_id << " invoke reduce. send " << x->numel()
             << " recv " << out->numel();
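Only the root GPU receives the reduced result, so the kernel now resizes every non-root output to zero elements. A rough sketch of the per-GPU behavior (plain Python stand-in, not the CUDA kernel):

    import numpy as np

    def nccl_reduce_output(gpu_id, root, grads):
        # Root GPU receives the element-wise sum of all per-GPU gradients.
        if gpu_id == root:
            return np.sum(grads, axis=0)
        # Non-root GPUs get a zero-element output, matching
        # out->Resize(framework::make_ddim({0})) above.
        return np.empty((0,))

The zero-element output is exactly the case the new assign_op guard checks for.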
9 changes: 6 additions & 3 deletions python/paddle/fluid/backward.py
@@ -248,12 +248,15 @@ def __call__(self, block, context):
         if o_argu in self.param_grad_names:
             allreduce_out_name = o_argu + "__nccl_all_reduce__"
             op_desc = _create_op_desc_(
-                "ncclAllReduce", {
+                "ncclReduce",
+                {
                     "X": [o_argu],
                     "Communicator":
                     ['nccl_com__do_not_change_']
-                }, {"Out": [allreduce_out_name]},
-                {"reduction": "ncclSum"})
+                },
+                {"Out": [allreduce_out_name]},
+                {"reduction": "ncclSum",
+                 "root": 0}, )
             block.desc.append_op().copy_from(op_desc)

         op_desc = _create_op_desc_(
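Taken together, parallel_do_grad's gradient aggregation changes from ncclAllReduce (every GPU ends up with the summed gradient) to ncclReduce with root 0 (only GPU 0 does). A hedged before/after sketch in plain Python (illustrative values only, not Paddle code):

    import numpy as np

    grads = [np.full(4, g) for g in (1.0, 2.0, 3.0)]  # one gradient per GPU

    # Before: ncclAllReduce, every GPU ends up holding the sum.
    before = [np.sum(grads, axis=0) for _ in grads]

    # After: ncclReduce with root=0, only GPU 0 holds the sum; the other
    # outputs become zero-element tensors, which the patched assign_op skips.
    after = [np.sum(grads, axis=0)] + [np.empty((0,)) for _ in grads[1:]]

Since an all-reduce is conceptually a reduce followed by a broadcast, reducing to a single root skips the redistribution step when only one copy of the aggregated gradient is needed.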
