Skip to content

Commit

Permalink
follow comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Apr 13, 2018
1 parent 02842cf commit 384d6ee
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 40 deletions.
37 changes: 6 additions & 31 deletions paddle/fluid/framework/details/broadcast_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,33 +61,24 @@ void BroadcastOpHandle::RunImpl() {
// Wait input done, this Wait is asynchronous operation
auto &in_place = in_var_handle[0]->place_;
if (in_var_handle[0]->generated_op_) {
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[in_place]);
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
if (platform::is_same_place(in_place, out_p)) continue;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}

//
auto in_scope_idx = in_var_handle[0]->scope_idx_;
PADDLE_ENFORCE_LT(in_scope_idx, local_scopes_.size(),
"The input(%s) is not in the local_scopes.",
in_var_handle[0]->name_);
auto in_var = local_scopes_[in_scope_idx]->FindVar(in_var_handle[0]->name_);
auto in_var =
local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
Tensor *in_tensor = GetTensorFromVar(in_var);

for (auto *out : out_var_handles) {
auto &out_p = out->place_;
auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);

auto out_scope_idx = out->scope_idx_;
PADDLE_ENFORCE_LT(out_scope_idx, local_scopes_.size(),
"%s is not in the local_scopes ", out->name_);

auto *s = local_scopes_[out_scope_idx];
auto out_var = s->FindVar(out->name_);
PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
"The place of input and output should be the same.");
"Places must be all on CPU or all on CUDA.");

if (in_var->IsType<framework::SelectedRows>()) {
auto &in_sr = in_var->Get<framework::SelectedRows>();
Expand All @@ -109,24 +100,8 @@ void BroadcastOpHandle::RunImpl() {
}

Tensor *out_tensor = GetTensorFromVar(out_var);
if (platform::is_cpu_place(in_place)) {
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
out_tensor);
} else if (platform::is_gpu_place(in_place)) {
#ifdef PADDLE_WITH_CUDA
auto src_gpu_place = boost::get<platform::CUDAPlace>(in_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(out_p);
void *dst_ptr = out_tensor->mutable_data(out_p);
void *src_ptr = in_tensor->data<void>();
int64_t size = in_tensor->numel() * SizeOfType(in_tensor->type());
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<platform::CUDADeviceContext *>(dev_ctxes_[out_p])
->stream());
#else
PADDLE_THROW("CUDAPlace is not supported in CPU device.");
#endif
}
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
out_tensor);
}
}

Expand Down
9 changes: 3 additions & 6 deletions paddle/fluid/framework/details/gather_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,10 @@ void GatherOpHandle::RunImpl() {
auto in_handle = static_cast<VarHandle *>(in);
auto in_p = in_handle->place_;
in_places.push_back(in_p);
PADDLE_ENFORCE_LT(in_handle->scope_idx_, local_scopes_.size(),
"%s is not the the local_scopes ", in_handle->name_);
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"The place of input should be the same.");
auto *s = local_scopes_[in_handle->scope_idx_];
auto in_var = s->FindVar(in_handle->name_);

"Places must be all on CPU or all on CUDA.");
auto in_var =
local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
auto &in_sr = in_var->Get<framework::SelectedRows>();

PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/tensor_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/tensor_util.h"
#include <algorithm>
#include <limits>
#include <vector>

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -65,8 +67,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
Expand Down

0 comments on commit 384d6ee

Please sign in to comment.