Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/Add Broadcast and Gather op handle #9825

10 changes: 9 additions & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cc_library(var_handle SRCS var_handle.cc DEPS place)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
Expand All @@ -20,3 +20,11 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)

cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)

cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context gather_op_handle)
111 changes: 111 additions & 0 deletions paddle/fluid/framework/details/broadcast_op_handle.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/broadcast_op_handle.h"

namespace paddle {
namespace framework {
namespace details {

Tensor *GetTensorFromVar(Variable *in_var) {
if (in_var->IsType<LoDTensor>()) {
return in_var->GetMutable<LoDTensor>();
} else if (in_var->IsType<SelectedRows>()) {
return in_var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
}
return nullptr;
}

BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}

void BroadcastOpHandle::RunImpl() {
// the input may have dummy var.
std::vector<VarHandle *> in_var_handle;
for (auto *in : inputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(in);
if (out_handle) {
in_var_handle.push_back(out_handle);
}
}
PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
"The number of input should be one.");

// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
}

PADDLE_ENFORCE_EQ(
out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");

// Wait input done, this Wait is asynchronous operation
auto &in_place = in_var_handle[0]->place_;
if (in_var_handle[0]->generated_op_) {
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}

//
auto in_scope_idx = in_var_handle[0]->scope_idx_;
auto in_var =
local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
Tensor *in_tensor = GetTensorFromVar(in_var);

for (auto *out : out_var_handles) {
auto &out_p = out->place_;
auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);

PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
"Places must be all on CPU or all on CUDA.");

if (in_var->IsType<framework::SelectedRows>()) {
auto &in_sr = in_var->Get<framework::SelectedRows>();
auto out_sr = out_var->GetMutable<framework::SelectedRows>();
if (&in_sr == out_sr) continue;
out_sr->set_height(in_sr.height());
out_sr->set_rows(in_sr.rows());
out_sr->mutable_value()->Resize(in_sr.value().dims());
out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
} else if (in_var->IsType<framework::LoDTensor>()) {
auto in_lod = in_var->Get<framework::LoDTensor>();
auto out_lod = out_var->GetMutable<framework::LoDTensor>();
if (&in_lod == out_lod) continue;
out_lod->set_lod(in_lod.lod());
out_lod->Resize(in_lod.dims());
out_lod->mutable_data(out_p, in_lod.type());
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
}

Tensor *out_tensor = GetTensorFromVar(out_var);
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
out_tensor);
}
}

std::string BroadcastOpHandle::Name() const { return "broadcast"; }
} // namespace details
} // namespace framework
} // namespace paddle
48 changes: 48 additions & 0 deletions paddle/fluid/framework/details/broadcast_op_handle.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {
namespace details {

struct BroadcastOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;

BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);

std::string Name() const override;

bool IsMultiDeviceTransfer() override { return false; };

protected:
void RunImpl() override;
};

} // namespace details
} // namespace framework
} // namespace paddle
Loading