
Commit 6db96ec
follow comments
chengduoZH committed Apr 11, 2018
1 parent 8eaec5d commit 6db96ec
Showing 4 changed files with 27 additions and 27 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/framework/details/CMakeLists.txt
@@ -5,7 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
if(WITH_GPU)
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda)
-nv_library(broad_cast_op_handle SRCS broad_cast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
+nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
endif()

cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
@@ -15,8 +15,8 @@ cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)

if(WITH_GPU)
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
-nv_test(broad_cast_op_test SRCS broad_cast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
-        device_context broad_cast_op_handle)
+nv_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
+        device_context broadcast_op_handle)
else()
set(multi_devices_graph_builder_deps)
endif()
paddle/fluid/framework/details/{broad_cast_op_handle.cc → broadcast_op_handle.cc}
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/broad_cast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"

namespace paddle {
namespace framework {
@@ -28,16 +28,16 @@ Tensor *GetTensorFromVar(Variable *in_var) {
}
return nullptr;
}
-BCastOpHandle::BCastOpHandle(const std::vector<Scope *> &local_scopes,
-                             const std::vector<platform::Place> &places,
-                             const platform::ContextMap &ctxs)
+BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
+                                     const std::vector<platform::Place> &places,
+                                     const platform::ContextMap &ctxs)
: local_scopes_(local_scopes), places_(places), ctxs_(ctxs) {
for (auto &p : places_) {
this->dev_ctxes_[p] = ctxs_.DevCtx(p);
}
}

-void BCastOpHandle::RunImpl() {
+void BroadcastOpHandle::RunImpl() {
PADDLE_ENFORCE_EQ(this->inputs_.size(), 1);
PADDLE_ENFORCE_EQ(this->outputs_.size(), places_.size());

@@ -97,7 +97,7 @@ void BCastOpHandle::RunImpl() {
}
}

-std::string BCastOpHandle::Name() const { return "broadcast"; }
+std::string BroadcastOpHandle::Name() const { return "broadcast"; }
} // namespace details
} // namespace framework
} // namespace paddle
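The hunk above touches GetTensorFromVar, which returns the underlying tensor whether the variable holds a LoDTensor or a SelectedRows, and nullptr otherwise. As a rough illustration of that dispatch, here is a minimal self-contained sketch; std::variant stands in for Paddle's framework::Variable, and the LoDTensor/SelectedRows structs below are simplified stand-ins, not the real framework classes:

// Simplified sketch (not Paddle code) of GetTensorFromVar's type dispatch.
#include <iostream>
#include <variant>
#include <vector>

struct LoDTensor { std::vector<float> data; };
struct SelectedRows { LoDTensor value; };  // wraps a tensor (row indices omitted)
using Variable = std::variant<LoDTensor, SelectedRows>;

// Dispatch on the held type and return the underlying tensor,
// or nullptr when the variable type is unsupported.
LoDTensor* GetTensorFromVar(Variable* in_var) {
  if (auto* t = std::get_if<LoDTensor>(in_var)) return t;
  if (auto* sr = std::get_if<SelectedRows>(in_var)) return &sr->value;
  return nullptr;
}

int main() {
  Variable v = SelectedRows{LoDTensor{{1.f, 2.f, 3.f}}};
  std::cout << GetTensorFromVar(&v)->data.size() << " elements\n";  // prints: 3 elements
}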
paddle/fluid/framework/details/{broad_cast_op_handle.h → broadcast_op_handle.h}
@@ -29,17 +29,17 @@ namespace framework {
namespace details {

/*
- * BroadCast the input to all scope.
+ * Broadcast the input to all scope.
*
*/
-struct BCastOpHandle : public OpHandleBase {
+struct BroadcastOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
const platform::ContextMap &ctxs_;

-  BCastOpHandle(const std::vector<Scope *> &local_scopes,
-                const std::vector<platform::Place> &places,
-                const platform::ContextMap &ctxs);
+  BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
+                    const std::vector<platform::Place> &places,
+                    const platform::ContextMap &ctxs);

std::string Name() const override;

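The header's doc comment says the handle broadcasts the input to all scopes. A conceptual sketch of that semantics follows, with plain containers standing in for Paddle's Scope and Tensor; this is an illustration under those assumptions, not the real device-aware implementation, which copies tensors between places through the per-place device contexts held in ctxs_:

// Conceptual sketch (not Paddle code): copy one named variable from a
// source scope into every per-device local scope.
#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Tensor = std::vector<float>;
using Scope = std::map<std::string, Tensor>;  // variable name -> tensor

void Broadcast(const std::string& name, const Scope& src,
               std::vector<Scope*>& local_scopes) {
  auto it = src.find(name);
  assert(it != src.end() && "input variable must already exist in the source scope");
  // Mirror "broadcast the input to all scopes": one copy per scope.
  for (Scope* scope : local_scopes) (*scope)[name] = it->second;
}

int main() {
  Scope s0{{"input", {1.f, 2.f}}}, s1, s2;
  std::vector<Scope*> scopes{&s1, &s2};
  Broadcast("input", s0, scopes);
  std::cout << s1.at("input").size() << " " << s2.at("input").size() << "\n";  // prints: 2 2
}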
paddle/fluid/framework/details/{broad_cast_op_handle_test.cc → broadcast_op_handle_test.cc}
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/broad_cast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "gtest/gtest.h"

#include "paddle/fluid/platform/device_context.h"
@@ -23,12 +23,12 @@ namespace p = paddle::platform;
// test data amount
const f::DDim kDims = {20, 20};

-class BroadCastTester : public ::testing::Test {
+class BroadcastTester : public ::testing::Test {
public:
void SetUp() override {
int count = p::GetCUDADeviceCount();
if (count <= 1) {
-    LOG(WARNING) << "Cannot test multi-gpu BroadCast, because the CUDA "
+    LOG(WARNING) << "Cannot test multi-gpu Broadcast, because the CUDA "
"device count is "
<< count;
exit(0);
@@ -40,7 +40,7 @@ class BroadCastTester : public ::testing::Test {
}

template <class T>
-  void BroadCastInitOp(int gpu_id = 0) {
+  void BroadcastInitOp(int gpu_id = 0) {
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scope_.push_back(&g_scope_.NewScope());
auto* out_var = local_scope_[j]->Var("out");
@@ -50,7 +50,7 @@ class BroadCastTester : public ::testing::Test {
in_var->GetMutable<T>();

bc_op_handle_ =
-        new f::details::BCastOpHandle(local_scope_, gpu_list_, *ctxs_);
+        new f::details::BroadcastOpHandle(local_scope_, gpu_list_, *ctxs_);

f::details::VarHandle* in_var_handle = new f::details::VarHandle();
in_var_handle->place_ = gpu_list_[gpu_id];
@@ -68,7 +68,7 @@ class BroadCastTester : public ::testing::Test {
bc_op_handle_->AddOutput(out_var_handle);
}
}
-  void BroadCastDestroy() {
+  void BroadcastDestroy() {
delete ctxs_;
for (auto in : bc_op_handle_->inputs_) {
delete in;
@@ -84,12 +84,12 @@ class BroadCastTester : public ::testing::Test {
p::ContextMap* ctxs_;
std::vector<f::Scope*> local_scope_;
std::vector<p::Place> gpu_list_;
-  f::details::BCastOpHandle* bc_op_handle_;
+  f::details::BroadcastOpHandle* bc_op_handle_;
};

-TEST_F(BroadCastTester, BroadCastTestLodTensor) {
+TEST_F(BroadcastTester, BroadcastTestLodTensor) {
int gpu_id = 0;
BroadCastInitOp<f::LoDTensor>(gpu_id);
BroadcastInitOp<f::LoDTensor>(gpu_id);

auto in_var = local_scope_[gpu_id]->Var("input");
auto in_lod_tensor = in_var->GetMutable<f::LoDTensor>();
@@ -122,12 +122,12 @@ TEST_F(BroadCastTester, BroadCastTestLodTensor) {
}
}

-  BroadCastDestroy();
+  BroadcastDestroy();
}

-TEST_F(BroadCastTester, BroadCastTestSelectedRows) {
+TEST_F(BroadcastTester, BroadcastTestSelectedRows) {
int gpu_id = 0;
BroadCastInitOp<f::SelectedRows>(gpu_id);
BroadcastInitOp<f::SelectedRows>(gpu_id);

auto in_var = local_scope_[gpu_id]->Var("input");
auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
@@ -170,5 +170,5 @@ TEST_F(BroadCastTester, BroadCastTestSelectedRows) {
}
}

-  BroadCastDestroy();
+  BroadcastDestroy();
}
