Enhance GPU kernel of sequence erase op #7603

Merged
5 commits merged on Jan 19, 2018
3 changes: 2 additions & 1 deletion paddle/operators/sequence_erase_op.cc
@@ -86,4 +86,5 @@ REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp,
                             ops::SequenceEraseOpMaker);
REGISTER_OP_CPU_KERNEL(
    sequence_erase,
-    ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>);
+    ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>,
+    ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int64_t>);
79 changes: 33 additions & 46 deletions paddle/operators/sequence_erase_op.cu
@@ -23,39 +23,34 @@ using platform::PADDLE_CUDA_NUM_THREADS;
using LoDTensor = framework::LoDTensor;

template <typename T>
-__global__ void LabelErasedIdx(const T* in_dat, const int in_len,
-                               const T* tokens, const int tokens_len,
-                               int* num_erased) {
+__global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
Contributor: Why does in_len use int64_t while tokens_len is size_t?

Contributor Author: They have different data types: in_len comes from in->numel(), which returns an int64_t, while tokens_len comes from tokens.size(), which returns a size_t.
+                               const int* tokens, const size_t tokens_len,
+                               size_t* num_erased) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < in_len) {
-    int erased = 0;
-    for (int i = 0; i < tokens_len; ++i) {
+    for (size_t i = 0; i < tokens_len; ++i) {
if (in_dat[index] == tokens[i]) {
-        erased = 1;
+        num_erased[index + 1] = 1;
break;
}
}
-    num_erased[index + 1] = erased;
-    if (index == 0) {
-      num_erased[0] = 0;
-    }
}
}

-template <typename T>
-__global__ void GetOutLod(const T* num_erased, const int* in_lod,
-                          const int lod_len, int* out_lod0) {
+__global__ void GetOutLod(const size_t* num_erased, const size_t* in_lod,
+                          const size_t lod_len, size_t* out_lod0) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < lod_len) {
out_lod0[index] = in_lod[index] - num_erased[in_lod[index]];
}
}

template <typename T>
-__global__ void SetOutput(const T* in_dat, const int in_len,
-                          const int* num_erased, T* out_dat) {
+__global__ void SetOutput(const T* in_dat, const int64_t in_len,
+                          const size_t* num_erased, T* out_dat) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < in_len) {
-    if (in_dat[index] != in_dat[index + 1]) {
+    if (num_erased[index] == num_erased[index + 1]) {
out_dat[index - num_erased[index]] = in_dat[index];
}
}
@@ -72,53 +67,44 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
"The actual size mismatches with the LoD information.");
-    auto tokens = ctx.Attr<std::vector<T>>("tokens");
-    auto tokens_len = tokens.size();
+    auto tokens = ctx.Attr<std::vector<int>>("tokens");
auto in_len = in->numel();
auto in_dat = in->data<T>();
Contributor: Additionally, we should register an int64_t kernel.

Contributor Author: Done.
-    auto lod0 = lod[0];

-    thrust::host_vector<T> host_tokens(tokens_len);
-    for (size_t i = 0; i < tokens.size(); ++i) {
-      host_tokens[i] = tokens[i];
-    }
-    thrust::device_vector<T> dev_tokens = host_tokens;
-    thrust::device_vector<int> num_erased(in_len + 1);

-    T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
-    int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
+    // Copy tokens to GPU
+    thrust::device_vector<int> dev_tokens(tokens.begin(), tokens.end());
+    int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());

+    // Count number of elements to be erased
+    thrust::device_vector<size_t> num_erased(in_len + 1, 0);
+    size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
auto stream = ctx.cuda_device_context().stream();
LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
-        in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr);
+        in_dat, in_len, dev_tokens_ptr, tokens.size(), num_erased_ptr);
thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(),
num_erased.begin() + 1);

-    // Calc LoD
+    // Copy LoD to GPU
+    auto lod0 = lod[0];
auto lod_len = lod0.size();
-    thrust::host_vector<int> host_lod(lod_len);
-    for (size_t i = 0; i < lod_len; ++i) {
-      host_lod[i] = lod0[i];
-    }
-    thrust::device_vector<int> dev_in_lod = host_lod;
-    thrust::device_vector<int> dev_out_lod(lod_len);
-    int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
-    int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
+    thrust::device_vector<size_t> dev_in_lod = lod0;
+    size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());

+    // Calc output LoD
+    thrust::device_vector<size_t> dev_out_lod(lod_len);
+    size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
-    thrust::host_vector<int> host_out_lod = dev_out_lod;
-    std::vector<int> out_lod0(lod_len, 0);
-    for (size_t i = 0; i < lod_len; i++) {
-      out_lod0[i] = host_out_lod[i];
-    }

+    // Set LoD for output
+    thrust::host_vector<size_t> out_lod0 = dev_out_lod;
framework::LoD out_lod;
out_lod.push_back(out_lod0);
out->set_lod(out_lod);

-    out->Resize({out_lod0.back(), 1});
+    // Set output
+    out->Resize({static_cast<int64_t>(out_lod0.back()), 1});
auto out_dat = out->mutable_data<T>(ctx.GetPlace());
SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
@@ -130,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle

REGISTER_OP_CUDA_KERNEL(sequence_erase,
-                        paddle::operators::SequenceEraseOpCUDAKernel<int32_t>);
+                        paddle::operators::SequenceEraseOpCUDAKernel<int32_t>,
+                        paddle::operators::SequenceEraseOpCUDAKernel<int64_t>);
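For readers following the refactored CUDA path above: it is a standard erase-via-prefix-sum (stream compaction) scheme. The NumPy sketch below is an illustrative walkthrough, not code from this PR; it mirrors the four device steps on host data: mark erased positions (LabelErasedIdx), run an inclusive scan, recompute the LoD offsets (GetOutLod), and compact the output (SetOutput).

```python
import numpy as np

# Toy data: one LoD level with two sub-sequences, [2, 5, 1] and [3, 7, 2, 9].
in_seq = np.array([2, 5, 1, 3, 7, 2, 9])
lod0 = [0, 3, 7]
tokens = [2, 3]  # values to erase

# 1) LabelErasedIdx: mark position i (shifted by one) if its value is a token.
num_erased = np.zeros(len(in_seq) + 1, dtype=np.int64)
for i, v in enumerate(in_seq):
    if v in tokens:
        num_erased[i + 1] = 1

# 2) Inclusive scan: num_erased[i] becomes the count of erased elements before i.
num_erased[1:] = np.cumsum(num_erased[1:])

# 3) GetOutLod: new offset = old offset minus the erased elements before it.
out_lod0 = [off - num_erased[off] for off in lod0]

# 4) SetOutput: keep element i iff it was not marked as erased.
out_seq = np.array([v for i, v in enumerate(in_seq)
                    if num_erased[i] == num_erased[i + 1]])

print(out_seq)   # [5 1 7 9]
print(out_lod0)  # [0, 2, 4]
```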
32 changes: 31 additions & 1 deletion python/paddle/v2/fluid/tests/test_sequence_erase_op.py
@@ -29,7 +29,7 @@ def sequence_erase(in_seq, lod0, tokens):
return np.array(out_seq).astype("int32"), new_lod0
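The body of the sequence_erase helper above is collapsed in this hunk. A minimal reference implementation consistent with its visible return statement and with the tests below would look roughly like the following sketch (illustrative only, with the hypothetical name sequence_erase_ref; it is not the file's actual code):

```python
import numpy as np

def sequence_erase_ref(in_seq, lod0, tokens):
    # in_seq: (N, 1) integer array; lod0: level-0 LoD offsets; tokens: values to erase.
    out_rows, new_lod0 = [], [0]
    for i in range(len(lod0) - 1):
        # Keep only the rows of this sub-sequence whose value is not a token.
        kept = [row for row in in_seq[lod0[i]:lod0[i + 1]] if row[0] not in tokens]
        out_rows.extend(kept)
        new_lod0.append(new_lod0[-1] + len(kept))
    return np.array(out_rows).astype("int32"), new_lod0
```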


-class TestSequenceEraseOp(OpTest):
+class TestSequenceEraseOpInt32(OpTest):
def setUp(self):
self.op_type = "sequence_erase"
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
@@ -44,5 +44,35 @@ def test_check_output(self):
self.check_output()


+class TestSequenceEraseOpInt64(OpTest):
+    def setUp(self):
+        self.op_type = "sequence_erase"
+        in_seq = np.random.randint(0, 10, (30, 1)).astype("int64")
+        lod = [[0, 9, 13, 24, 30]]
+        tokens = [2, 3, 5]
+        out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
+        self.attrs = {'tokens': tokens}
+        self.inputs = {'X': (in_seq, lod)}
+        self.outputs = {'Out': (out_seq, [new_lod0])}

+    def test_check_output(self):
+        self.check_output()


+class TestSequenceEraseOpEmpty(OpTest):
+    def setUp(self):
+        self.op_type = "sequence_erase"
+        in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
+        lod = [[0, 9, 13, 24, 30]]
+        tokens = []
+        out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
+        self.attrs = {'tokens': tokens}
+        self.inputs = {'X': (in_seq, lod)}
+        self.outputs = {'Out': (out_seq, [new_lod0])}

+    def test_check_output(self):
+        self.check_output()


if __name__ == '__main__':
unittest.main()