Skip to content

Commit

Permalink
cpu gpu transform function (#7191)
Browse files Browse the repository at this point in the history
* add rename guard

* add device_data_transform

* add device_data_transform_test

* modify GetExpectedKernelType

* update operator.run

* support test test_label_semantic_roles

* optimize code

* optimize code

* rename GetActualKernelType to GetExpectedKernelType

* fix chunk_eval_op and device_data_transform_test

* add is_same_place to place

* optimize code, refine rename_guard

* refine rename guard, add GetKernelTypeForVar

* optimize code

* add some log

* rename guard

* use sub scope to create var

* fix compile

* add IsInitialized for Tensor

* add VarIsTensor

* fix op_registry_test

* test

* tmp disable priority

* restore switch_kernel.md

* code clean
  • Loading branch information
jacquesqiao authored Jan 8, 2018
1 parent 8814bec commit 0f353ab
Show file tree
Hide file tree
Showing 49 changed files with 429 additions and 200 deletions.
7 changes: 6 additions & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_test(scope_test SRCS scope_test.cc DEPS scope)

cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto)
cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor)

cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)

cc_library(attribute SRCS attribute.cc DEPS framework_proto)
Expand Down Expand Up @@ -77,3 +79,6 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operat
cc_test(init_test SRCS init_test.cc DEPS init)

cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)

nv_test(device_data_transform_test SRCS device_data_transform_test.cu
DEPS operator op_registry init math_function)
33 changes: 33 additions & 0 deletions paddle/framework/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ limitations under the License. */
#include <functional>

#include "paddle/framework/data_transform.h"
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"

namespace paddle {
Expand All @@ -25,6 +27,37 @@ DataTransformFnMap& DataTransformFnMap::Instance() {
return data_transform_map;
}

Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor) {
Tensor* out = nullptr;
if (!platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_type.place_)) {
out = DeviceTransform(input_tensor, expected_kernel_type.place_);
}
PADDLE_ENFORCE_NOT_NULL(out, "out should not be null");
return out;
}

void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var) {
if (in_var.IsType<LoDTensor>()) {
auto& in_lod_tensor = in_var.Get<LoDTensor>();
auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
tran_lod_tensor->set_lod(in_lod_tensor.lod());
tran_lod_tensor->set_layout(in_lod_tensor.layout());
tran_lod_tensor->ShareDataWith(tensor);
} else if (in_var.IsType<SelectedRows>()) {
auto& in_selected_rows = in_var.Get<SelectedRows>();
auto* trans_selected_rows = out_var.GetMutable<SelectedRows>();
trans_selected_rows->set_height(in_selected_rows.height());
trans_selected_rows->set_rows(in_selected_rows.rows());
trans_selected_rows->mutable_value()->ShareDataWith(tensor);
} else {
PADDLE_THROW("unknown var type");
}
}

auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(),
DataLayout::kNHWC, LibraryType::kPlain);

Expand Down
8 changes: 8 additions & 0 deletions paddle/framework/data_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
#include <vector>

#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/variable.h"
#include "paddle/operators/math/math_function.h"
Expand Down Expand Up @@ -49,6 +50,13 @@ struct KernelTypePairHash {
}
};

Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor);

void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var);

template <typename InType, typename OutType>
struct CastDataTypeFunctor {
HOSTDEVICE inline OutType operator()(InType in) const {
Expand Down
46 changes: 46 additions & 0 deletions paddle/framework/device_data_transform.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/device_data_transform.h"

namespace paddle {
namespace framework {

static const platform::DeviceContext* GetDeviceContext(
const platform::Place& src_place, const platform::Place& dst_place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
return pool.Get(src_place);
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
return pool.Get(dst_place);
} else {
PADDLE_THROW(
"Currently, model parallelism is only supported between CPU and CUDA");
}
}

Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place) {
VLOG(3) << "DeviceTransform in, src_place " << in.place()
<< " dst_place: " << dst_place;
Tensor* out = new Tensor();
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
dev_ctx->Wait();
CopyFrom(in, dst_place, *dev_ctx, out);
dev_ctx->Wait();
return out;
}

} // namespace framework
} // namespace paddle
27 changes: 27 additions & 0 deletions paddle/framework/device_data_transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"

namespace paddle {
namespace framework {

Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place);

} // namespace framework
} // namespace paddle
168 changes: 168 additions & 0 deletions paddle/framework/device_data_transform_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "gtest/gtest.h"

#include "paddle/framework/init.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/device_context.h"

namespace paddle {
namespace framework {

template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};

class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input1 of test op");
AddOutput("output", "output of test op");
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
AddComment("This is test op");
}
};

class TestOpWithKernel : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
if (Attr<bool>("use_gpu")) {
VLOG(3) << "force use gpu kernel";
return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
} else {
VLOG(3) << "use default kernel";
return OpKernelType(proto::DataType::FP32,
ctx.Input<Tensor>("input")->place());
}
}
};

template <typename DeviceContext, typename T>
class TestKernel : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
std::cout << ctx.op().DebugString() << std::endl;

const Tensor* input = ctx.Input<Tensor>("input");

std::cout << "input place:" << input->place() << std::endl;
auto* output = ctx.Output<framework::LoDTensor>("output");
output->Resize(input->dims());
output->mutable_data<T>(ctx.GetPlace());

operators::TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
input, input, output, ctx.template device_context<DeviceContext>(),
AddFunctor<T>());
functor.Run();
}
};

} // namespace framework
} // namespace paddle

REGISTER_OP_WITHOUT_GRADIENT(
test_op, paddle::framework::TestOpWithKernel,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(
test_op,
paddle::framework::TestKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
test_op,
paddle::framework::TestKernel<paddle::platform::CUDADeviceContext, float>);

static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}

TEST(Operator, CPUtoGPU) {
using namespace paddle::framework;
using namespace paddle::platform;

ASSERT_EQ(InitDevices({"CPU", "GPU:0"}), true);

paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;

// create an op to run on CPU
paddle::framework::proto::OpDesc cpu_op_desc;
cpu_op_desc.set_type("test_op");
BuildVar("input", {"IN1"}, cpu_op_desc.add_inputs());
BuildVar("output", {"OUT1"}, cpu_op_desc.add_outputs());

auto cpu_op = paddle::framework::OpRegistry::CreateOp(cpu_op_desc);
// prepare input
auto* in_t = scope.Var("IN1")->GetMutable<LoDTensor>();
auto* src_ptr = in_t->mutable_data<float>({2, 3}, CPUPlace());
for (int i = 0; i < 2 * 3; ++i) {
src_ptr[i] = static_cast<float>(i);
}

// get output
auto* output = scope.Var("OUT1");
cpu_op->Run(scope, cpu_place);

auto* output_ptr = output->Get<LoDTensor>().data<float>();
for (int i = 0; i < 2 * 3; ++i) {
ASSERT_EQ(output_ptr[i], static_cast<float>(i) * 2);
}

// create an op to run on GPU
paddle::framework::proto::OpDesc gpu_op_desc;
gpu_op_desc.set_type("test_op");
BuildVar("input", {"OUT1"}, gpu_op_desc.add_inputs());
BuildVar("output", {"OUT2"}, gpu_op_desc.add_outputs());

auto attr = gpu_op_desc.mutable_attrs()->Add();
attr->set_name("use_gpu");
attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
attr->set_b(true);

auto gpu_op = paddle::framework::OpRegistry::CreateOp(gpu_op_desc);

paddle::platform::CUDAPlace cuda_place(0);
// get output
auto* output2 = scope.Var("OUT2");
gpu_op->Run(scope, cuda_place);

// auto* output2_ptr = output2->Get<LoDTensor>().data<float>();
DeviceContextPool& pool = DeviceContextPool::Instance();
auto dev_ctx = pool.Get(cuda_place);

paddle::framework::Tensor output_tensor;
CopyFrom(output2->Get<LoDTensor>(), paddle::platform::CPUPlace(), *dev_ctx,
&output_tensor);

dev_ctx->Wait();
float* output2_ptr = output_tensor.data<float>();
for (int i = 0; i < 2 * 3; ++i) {
ASSERT_EQ(output2_ptr[i], static_cast<float>(i) * 4);
}
}
22 changes: 9 additions & 13 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(InferShapeContext* ctx) const override {}

framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
}
Expand Down Expand Up @@ -282,16 +282,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
protected:
void InferShape(InferShapeContext* ctx) const override {}

framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
}

framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& kernel) const override {
return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
kernel.data_layout_,
framework::LibraryType::kCUDNN);
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
};

Expand Down Expand Up @@ -371,6 +366,7 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
op_desc.set_type("op_with_multi_kernel");
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);

// TODO(qiao) add priority back
// use all available kernels
paddle::framework::UseALL();
op->Run(scope, cuda_place);
Expand All @@ -380,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
paddle::framework::UseCPU();
op->Run(scope, cpu_place);

EXPECT_EQ(op_test_value, -9);
EXPECT_EQ(op_test_value, -20);

// add cuda kernels
paddle::framework::UseCUDA();
op->Run(scope, cuda_place);

EXPECT_EQ(op_test_value, -10);
EXPECT_EQ(op_test_value, -30);

// use cudnn kernel
paddle::framework::UseCUDNN();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -20);
EXPECT_EQ(op_test_value, -40);
}
Loading

0 comments on commit 0f353ab

Please sign in to comment.