[Dygraph] Fix bugs of mp in eager mode (#46303)
* fix bugs of mp

* fix bugs of mp

* update

* update

* fix bug
haohongxiang authored Sep 22, 2022
1 parent 8bed319 commit 1100243
Showing 2 changed files with 22 additions and 3 deletions.
@@ -265,7 +265,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
     auto map = distributed::ProcessGroupMapFromGid::getInstance();
     distributed::ProcessGroup* pg = map->get(rid);
     distributed::AllreduceOptions opts;
-    opts.reduce_op = distributed::ReduceOp::SUM;
+    opts.reduce_op = distributed::ReduceOp::MAX;

     // allocate memory on device.
     softmax->mutable_data<T>(place);
@@ -348,6 +348,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {

     in_out.clear();
     in_out.push_back(predicted_logits);
+    opts.reduce_op = distributed::ReduceOp::SUM;
     pg->AllReduce(in_out, in_out, opts)->Synchronize();

     // step 4, obtain exp(logit)
@@ -364,6 +365,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {

     in_out.clear();
     in_out.push_back(sum_exp_logits);
+    opts.reduce_op = distributed::ReduceOp::SUM;
     pg->AllReduce(in_out, in_out, opts)->Synchronize();

     auto eigen_loss = math::EigenMatrix<T>::From(loss_2d);
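Note on the kernel change above: all three collectives share a single AllreduceOptions object. The first all-reduce (over the per-rank logit maxima) now uses ReduceOp::MAX, so the two later all-reduces over predicted_logits and sum_exp_logits must explicitly reset reduce_op to SUM, which is what the two added lines do; otherwise they would inherit MAX from the shared options object. Below is a minimal single-process NumPy sketch of that max-then-sum reduction pattern. The shard list stands in for model-parallel ranks, and the closing loss line is a plain softmax-cross-entropy reading of the visible steps, not the kernel's exact code.

# Hypothetical single-process illustration (NumPy): each "rank" holds a slice
# of the class dimension for one sample.
import numpy as np

logit_shards = [np.array([2.0, 0.5]), np.array([1.5, 3.0])]  # rank 0, rank 1
shard_offsets = [0, 2]
label = 3  # global class index; it lives on rank 1

# step 1: per-rank max, then a MAX "all-reduce" (this is the SUM -> MAX fix)
global_max = max(s.max() for s in logit_shards)

# step 2: the rank that owns the label contributes (logit - max), others 0,
# combined with a SUM "all-reduce"
predicted_logit = sum(
    (s[label - off] - global_max) if off <= label < off + s.size else 0.0
    for s, off in zip(logit_shards, shard_offsets))

# step 3: per-rank sum of exp(logit - max), then a SUM "all-reduce"
sum_exp = sum(np.exp(s - global_max).sum() for s in logit_shards)

# softmax cross-entropy in the max-shifted form
loss = np.log(sum_exp) - predicted_logit
print(loss)  # ~0.515, i.e. -log(softmax([2, 0.5, 1.5, 3])[3])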
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py — 19 additions and 2 deletions
@@ -106,21 +106,38 @@ def _broadcast_data_help(data, shape, dtype, hcg):
                                  group=model_parallel_group,
                                  sync_op=True)

     if mp_rank != 0:
         if in_dygraph_mode():
             data._clear_data()
             input_data._share_buffer_to(data)
         else:
             data.value().get_tensor()._clear()
             data.value().get_tensor()._share_data_with(
                 input_data.value().get_tensor())


 def broadcast_input_data(hcg, *inputs, **kwargs):
     cur_device = paddle.get_device()
     for v in inputs:
         if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
-                v = v.cuda() if "gpu" in cur_device else v
+                if "gpu" in cur_device and in_dygraph_mode() \
+                        and not v.place.is_gpu_place():
+                    v_gpu = v.cuda(int(cur_device.split(":")[1]))
+                    v._clear_data()
+                    v_gpu._share_buffer_to(v)
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
         else:
             logger.error("it doesn't support data type {}".format(type(v)))

     for k, v in kwargs.items():
         if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
-                v = v.cuda() if "gpu" in cur_device else v
+                if "gpu" in cur_device and in_dygraph_mode() \
+                        and not v.place.is_gpu_place():
+                    v_gpu = v.cuda(int(cur_device.split(":")[1]))
+                    v._clear_data()
+                    v_gpu._share_buffer_to(v)
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
             kwargs[k] = v
         else:
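Note on the Python change above: v.cuda() returns a new tensor, so the old line only rebound the local name v. The new eager-mode branch instead copies the data onto the current GPU and shares that buffer back into the original tensor with _clear_data() followed by _share_buffer_to(), so the tensor object the caller passed in is updated in place; the same pattern is applied to both the positional inputs and the keyword arguments. A minimal sketch of why in-place buffer sharing differs from rebinding, using a hypothetical plain-Python FakeTensor as a stand-in for paddle.Tensor:

# Hypothetical stand-in for paddle.Tensor; the real code uses
# Tensor._clear_data() and Tensor._share_buffer_to(dst).
class FakeTensor:
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def cuda(self):
        # returns a *new* tensor holding a GPU copy
        return FakeTensor(self.data, device="gpu")

    def _share_buffer_to(self, dst):
        # makes `dst` point at this tensor's storage, in place
        dst.data = self.data
        dst.device = self.device


def move_by_rebinding(v):
    v = v.cuda()  # only the local name changes; the caller's object does not


def move_in_place(v):
    v_gpu = v.cuda()
    v_gpu._share_buffer_to(v)  # the caller's object now holds the GPU buffer


t = FakeTensor([1.0, 2.0, 3.0])
move_by_rebinding(t)
print(t.device)  # still "cpu"
move_in_place(t)
print(t.device)  # "gpu"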
