From aa5cd1a88d682364423043e358a31feafa07c855 Mon Sep 17 00:00:00 2001 From: qshuihu Date: Mon, 25 Jul 2022 17:58:35 +0800 Subject: [PATCH 1/4] add trade weight support remove template object get gpu pull push size --- paddle/fluid/framework/fleet/box_wrapper.cc | 146 +----- paddle/fluid/framework/fleet/box_wrapper.cu | 122 ++--- paddle/fluid/framework/fleet/box_wrapper.h | 24 +- .../fluid/framework/fleet/box_wrapper_impl.h | 70 +-- .../fused/fused_seqpool_cvm_tradew_op.cc | 222 ++++++++ .../fused/fused_seqpool_cvm_tradew_op.cu | 484 ++++++++++++++++++ .../fused/fused_seqpool_cvm_tradew_op.h | 50 ++ python/paddle/fluid/contrib/layers/nn.py | 59 +++ 8 files changed, 925 insertions(+), 252 deletions(-) create mode 100644 paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cc create mode 100644 paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cu create mode 100644 paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.h diff --git a/paddle/fluid/framework/fleet/box_wrapper.cc b/paddle/fluid/framework/fleet/box_wrapper.cc index c8b60c018d0d6..7db27c5f43e87 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cc +++ b/paddle/fluid/framework/fleet/box_wrapper.cc @@ -18,7 +18,6 @@ #include #include #include - #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/platform/collective_helper.h" @@ -407,80 +406,10 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place, const std::vector& slot_lengths, const int hidden_size, const int expand_embed_dim, const int skip_offset, bool expand_only) { -#define EMBEDX_CASE(i, ...) \ - case i: { \ - constexpr size_t EmbedxDim = i; \ - switch (expand_embed_dim_) { \ - __VA_ARGS__ \ - default: \ - PADDLE_THROW(platform::errors::InvalidArgument( \ - "Unsupport this expand embedding size [%d]", expand_embed_dim)); \ - } \ - } break - -#define PULLSPARSE_CASE(i, ...) 
\ - case i: { \ - constexpr size_t ExpandDim = i; \ - if (feature_type_ == static_cast(boxps::FEATURE_SHARE_EMBEDDING)) { \ - PullSparseCase< \ - boxps::FeaturePullValueGpuShareEmbedding>( \ - place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \ - skip_offset, expand_only); \ - } else if (feature_type_ == static_cast(boxps::FEATURE_PCOC)) { \ - PullSparseCase>( \ - place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \ - skip_offset, expand_only); \ - } else if (feature_type_ == static_cast(boxps::FEATURE_QUANT) || \ - feature_type_ == static_cast(boxps::FEATURE_SHOWCLK)) { \ - PullSparseCase>( \ - place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \ - skip_offset, expand_only); \ - } else if (feature_type_ == static_cast(boxps::FEATURE_CONV)) { \ - PullSparseCase>( \ - place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \ - skip_offset, expand_only); \ - } else if (feature_type_ == static_cast(boxps::FEATURE_VARIABLE)) { \ - PullSparseCase>( \ - place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \ - skip_offset, expand_only); \ - } else if (EmbedxDim == 0 && \ - feature_type_ == static_cast(boxps::FEATURE_ADAM)) { \ - PullSparseCase>( \ - place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \ - skip_offset, expand_only); \ - } else { \ - PullSparseCase>( \ - place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \ - skip_offset, expand_only); \ - } \ - } break - CheckEmbedSizeIsValid(hidden_size + skip_offset - cvm_offset_, expand_embed_dim); - switch (embedx_dim_) { - EMBEDX_CASE(0, PULLSPARSE_CASE(0); PULLSPARSE_CASE(255); - PULLSPARSE_CASE(767);); - EMBEDX_CASE(2, PULLSPARSE_CASE(0);); - EMBEDX_CASE(4, PULLSPARSE_CASE(0);); - EMBEDX_CASE(8, PULLSPARSE_CASE(0); PULLSPARSE_CASE(1); PULLSPARSE_CASE(2); - PULLSPARSE_CASE(3); PULLSPARSE_CASE(4); PULLSPARSE_CASE(5); - PULLSPARSE_CASE(6); PULLSPARSE_CASE(7); PULLSPARSE_CASE(8); - PULLSPARSE_CASE(64);); - EMBEDX_CASE(16, PULLSPARSE_CASE(0); PULLSPARSE_CASE(1); PULLSPARSE_CASE(2); - PULLSPARSE_CASE(3); PULLSPARSE_CASE(4); PULLSPARSE_CASE(5); - PULLSPARSE_CASE(6); PULLSPARSE_CASE(7); PULLSPARSE_CASE(8); - PULLSPARSE_CASE(64);); - EMBEDX_CASE(32, PULLSPARSE_CASE(0);); - EMBEDX_CASE(64, PULLSPARSE_CASE(0);); - EMBEDX_CASE(256, PULLSPARSE_CASE(0);); - EMBEDX_CASE(128, PULLSPARSE_CASE(0);); - EMBEDX_CASE(280, PULLSPARSE_CASE(0);); - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupport this embedding size [%d]", hidden_size - cvm_offset_)); - } -#undef PULLSPARSE_CASE -#undef EMBEDX_CASE + PullSparseCase(place, keys, values, slot_lengths, hidden_size, + expand_embed_dim, skip_offset, expand_only); } void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place, @@ -491,77 +420,10 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place, const int expand_embed_dim, const int batch_size, const int skip_offset, bool expand_only) { -#define EMBEDX_CASE(i, ...) \ - case i: { \ - constexpr size_t EmbedxDim = i; \ - switch (expand_embed_dim_) { \ - __VA_ARGS__ \ - default: \ - PADDLE_THROW(platform::errors::InvalidArgument( \ - "Unsupport this expand embedding size [%d]", expand_embed_dim)); \ - } \ - } break - -#define PUSHSPARSE_CASE(i, ...) 
\ - case i: { \ - constexpr size_t ExpandDim = i; \ - if (feature_type_ == static_cast(boxps::FEATURE_SHARE_EMBEDDING)) { \ - PushSparseGradCase< \ - boxps::FeaturePushValueGpuShareEmbedding>( \ - place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset, expand_only); \ - } else if (feature_type_ == static_cast(boxps::FEATURE_PCOC)) { \ - PushSparseGradCase< \ - boxps::FeaturePushValueGpuPCOC>( \ - place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset, expand_only); \ - } else if (feature_type_ == static_cast(boxps::FEATURE_VARIABLE)) { \ - PushSparseGradCase>( \ - place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset, expand_only); \ - } else if (feature_type_ == static_cast(boxps::FEATURE_CONV)) { \ - PushSparseGradCase< \ - boxps::FeaturePushValueGpuConv>( \ - place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset, expand_only); \ - } else if (EmbedxDim == 0 && \ - feature_type_ == static_cast(boxps::FEATURE_ADAM)) { \ - PushSparseGradCase>( \ - place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset, expand_only); \ - } else { \ - PushSparseGradCase>( \ - place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset, expand_only); \ - } \ - } break - CheckEmbedSizeIsValid(hidden_size + skip_offset - cvm_offset_, expand_embed_dim); - switch (embedx_dim_) { - EMBEDX_CASE(0, PUSHSPARSE_CASE(0); PUSHSPARSE_CASE(255); - PUSHSPARSE_CASE(767);); - EMBEDX_CASE(2, PUSHSPARSE_CASE(0);); - EMBEDX_CASE(4, PUSHSPARSE_CASE(0);); - EMBEDX_CASE(8, PUSHSPARSE_CASE(0); PUSHSPARSE_CASE(1); PUSHSPARSE_CASE(2); - PUSHSPARSE_CASE(3); PUSHSPARSE_CASE(4); PUSHSPARSE_CASE(5); - PUSHSPARSE_CASE(6); PUSHSPARSE_CASE(7); PUSHSPARSE_CASE(8); - PUSHSPARSE_CASE(64);); - EMBEDX_CASE(16, PUSHSPARSE_CASE(0); PUSHSPARSE_CASE(1); PUSHSPARSE_CASE(2); - PUSHSPARSE_CASE(3); PUSHSPARSE_CASE(4); PUSHSPARSE_CASE(5); - PUSHSPARSE_CASE(6); PUSHSPARSE_CASE(7); PUSHSPARSE_CASE(8); - PUSHSPARSE_CASE(64);); - EMBEDX_CASE(32, PUSHSPARSE_CASE(0);); - EMBEDX_CASE(64, PUSHSPARSE_CASE(0);); - EMBEDX_CASE(256, PUSHSPARSE_CASE(0);); - EMBEDX_CASE(128, PUSHSPARSE_CASE(0);); - EMBEDX_CASE(280, PUSHSPARSE_CASE(0);); - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupport this embedding size [%d]", hidden_size - cvm_offset_)); - } -#undef PUSHSPARSE_CASE -#undef EMBEDX_CASE + PushSparseGradCase(place, keys, grad_values, slot_lengths, hidden_size, + expand_embed_dim, batch_size, skip_offset, expand_only); } void BasicAucCalculator::calculate_bucket_error() { diff --git a/paddle/fluid/framework/fleet/box_wrapper.cu b/paddle/fluid/framework/fleet/box_wrapper.cu index 417b08de4feb5..34d86b4c2964c 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cu +++ b/paddle/fluid/framework/fleet/box_wrapper.cu @@ -280,8 +280,8 @@ __global__ void PullCopyExpandNNCrossWithEmb( if (dest[x + slot_num] == 0) { return; } - int offset_2 = y * (embedx_dim + cvm_offset + expand_dim) - + cvm_offset + col; + int offset_2 = + y * (embedx_dim + cvm_offset + expand_dim) + cvm_offset + col; if (total_dims[idx] & 0x02) { *(dest[x + slot_num] + offset_2) = src_val.embedx[col] * scale; } else { @@ -291,8 +291,8 @@ __global__ void PullCopyExpandNNCrossWithEmb( if (dest[x + slot_num] == 0) { return; } - int offset = y * (embedx_dim + cvm_offset + expand_dim) - + cvm_offset + col; + int offset = + y * 
(embedx_dim + cvm_offset + expand_dim) + cvm_offset + col; if (total_dims[idx] & 0x02) { *(dest[x + slot_num] + offset) = src_val.embed_expand[col - embedx_dim] * scale; @@ -427,8 +427,8 @@ __global__ void PullDedupCopyExpandNNCrossWithEmb( if (dest[x + slot_num] == 0) { return; } - int offset_2 = y * (embedx_dim + cvm_offset + expand_dim) - + cvm_offset + col; + int offset_2 = + y * (embedx_dim + cvm_offset + expand_dim) + cvm_offset + col; if (total_dims[idx] & 0x02) { *(dest[x + slot_num] + offset_2) = src_val.embedx[col] * scale; } else { @@ -438,8 +438,8 @@ __global__ void PullDedupCopyExpandNNCrossWithEmb( if (dest[x + slot_num] == 0) { return; } - int offset = y * (embedx_dim + cvm_offset + expand_dim) - + cvm_offset + col; + int offset = + y * (embedx_dim + cvm_offset + expand_dim) + cvm_offset + col; if (total_dims[idx] & 0x02) { *(dest[x + slot_num] + offset) = src_val.embed_expand[col - embedx_dim] * scale; @@ -594,8 +594,10 @@ __global__ void PullDedupCopyExpandVariable( } // end kernel loop } //========================== end ================================== -__global__ void FillKey2Slot(const int total_len, const int64_t* slot_lens, - const int slot_num, int* key2slots) { +__global__ void CopyKeysKernel(const int total_len, uint64_t** src_keys, + uint64_t* dest_total_keys, + const int64_t* slot_lens, const int slot_num, + int* key2slots) { CUDA_KERNEL_LOOP(i, total_len) { int low = 0; int high = slot_num - 1; @@ -608,16 +610,8 @@ __global__ void FillKey2Slot(const int total_len, const int64_t* slot_lens, } } key2slots[i] = low; - } -} - -__global__ void CopyKeysKernel(const int total_len, uint64_t** src_keys, - uint64_t* dest_total_keys, - const int64_t* slot_lens, const int* key2slot) { - CUDA_KERNEL_LOOP(i, total_len) { - int x = key2slot[i]; - int y = i - slot_lens[x]; - dest_total_keys[i] = src_keys[x][y]; + int y = i - slot_lens[low]; + dest_total_keys[i] = src_keys[low][y]; } } @@ -941,11 +935,11 @@ __global__ void PushCopyExpandNNCrossWithEmb( dest_val.embedx_g[col] = 0; } } else { // expand - int offset = y * (embedx_dim + cvm_offset + expand_dim) - + cvm_offset + col; + int offset = + y * (embedx_dim + cvm_offset + expand_dim) + cvm_offset + col; if ((total_dims[idx] & 0x02) && src[x + slot_num] != 0) { dest_val.embed_expand_g[col - embedx_dim] = - *(src[x + slot_num] + offset) * -1. * bs; + *(src[x + slot_num] + offset) * -1. * bs; } else { dest_val.embed_expand_g[col - embedx_dim] = 0; } @@ -1080,9 +1074,9 @@ __global__ void PushMergeCopyExpandNNCrossWithEmb( } } dest_val.embedx_g[col] = val * -1. 
* bs; - } else { // expand - int offset = y * (embedx_dim + cvm_offset + expand_dim) - + cvm_offset + col; + } else { // expand + int offset = + y * (embedx_dim + cvm_offset + expand_dim) + cvm_offset + col; double val = 0.0; for (uint32_t j = 0; j < count; ++j) { const uint32_t& pos = d_sort_idx[start + j]; @@ -1394,15 +1388,13 @@ void FeaturePullCopyNNCross(cudaStream_t stream, uint64_t** gpu_keys, } template -void FeaturePullCopyNNCrossWithEmb(cudaStream_t stream, uint64_t** gpu_keys, - float** gpu_values, void* src, - const int hidden_size, const size_t embedx_dim, - const size_t expand_dim, const int total_length, - int* total_dims, const int64_t* slot_lens, - const int slot_num, const int* key2slot, - const float scale, const int cvm_offset, - const uint32_t* gpu_restore_idx, - const int skip_offset) { +void FeaturePullCopyNNCrossWithEmb( + cudaStream_t stream, uint64_t** gpu_keys, float** gpu_values, void* src, + const int hidden_size, const size_t embedx_dim, const size_t expand_dim, + const int total_length, int* total_dims, const int64_t* slot_lens, + const int slot_num, const int* key2slot, const float scale, + const int cvm_offset, const uint32_t* gpu_restore_idx, + const int skip_offset) { FeaturePullValueType* pull_values_gpu = reinterpret_cast(src); if (gpu_restore_idx != nullptr) { @@ -1485,8 +1477,7 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, const int slot_num, const int* key2slot, const int hidden_size, const int expand_embed_dim, const int64_t total_length, int* total_dims, - const int skip_offset, - bool expand_only=true, + const int skip_offset, bool expand_only = true, const uint32_t* gpu_restore_idx) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( @@ -1545,18 +1536,18 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, feature_type_ == static_cast(boxps::FEATURE_SHOWCLK)) { \ if (expand_only) { \ FeaturePullCopyNNCross< \ - boxps::FeaturePullValueGpuQuant>( \ - stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ - EmbedxDim, ExpandDim, total_length, total_dims, slot_lens, slot_num, \ - key2slot, pull_embedx_scale_, cvm_offset, gpu_restore_idx, \ - skip_offset); \ + boxps::FeaturePullValueGpuQuant>( \ + stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ + EmbedxDim, ExpandDim, total_length, total_dims, slot_lens, \ + slot_num, key2slot, pull_embedx_scale_, cvm_offset, \ + gpu_restore_idx, skip_offset); \ } else { \ FeaturePullCopyNNCrossWithEmb< \ - boxps::FeaturePullValueGpuQuant>( \ - stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ - EmbedxDim, ExpandDim, total_length, total_dims, slot_lens, slot_num, \ - key2slot, pull_embedx_scale_, cvm_offset, gpu_restore_idx, \ - skip_offset); \ + boxps::FeaturePullValueGpuQuant>( \ + stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ + EmbedxDim, ExpandDim, total_length, total_dims, slot_lens, \ + slot_num, key2slot, pull_embedx_scale_, cvm_offset, \ + gpu_restore_idx, skip_offset); \ } \ } else if (feature_type_ == static_cast(boxps::FEATURE_VARIABLE)) { \ FeaturePullCopyVariable< \ @@ -1589,6 +1580,12 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, (hidden_size - cvm_offset), total_length, total_dims, slot_lens, \ slot_num, key2slot, pull_embedx_scale_, cvm_offset, gpu_restore_idx, \ skip_offset); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_TRADEW)) { \ + FeaturePullCopy>( \ + stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ + (hidden_size - cvm_offset), 
total_length, total_dims, slot_lens, \ + slot_num, key2slot, pull_embedx_scale_, cvm_offset, gpu_restore_idx, \ + skip_offset); \ } \ } break @@ -1641,10 +1638,8 @@ void BoxWrapper::CopyKeys(const paddle::platform::Place& place, platform::DeviceContextPool::Instance().Get( BOOST_GET_CONST(platform::CUDAPlace, place))) ->stream(); - FillKey2Slot<<>>(total_len, slot_lens, - slot_num, key2slot); CopyKeysKernel<<>>( - total_len, origin_keys, total_keys, slot_lens, key2slot); + total_len, origin_keys, total_keys, slot_lens, slot_num, key2slot); cudaStreamSynchronize(stream); } @@ -1911,18 +1906,18 @@ void BoxWrapper::CopyForPush( if (feature_type_ == static_cast(boxps::FEATURE_PCOC)) { \ if (expand_only) { \ FeaturePushCopyNNCross< \ - boxps::FeaturePushValueGpuPCOC>( \ - stream, total_grad_values_gpu, grad_values, hidden_size, EmbedxDim, \ - ExpandDim, total_length, batch_size, d_slot_vector, total_dims, \ - slot_lens, slot_num, key2slot, cvm_offset, gpu_sort_idx, \ - gpu_sort_offset, gpu_sort_lens, skip_offset); \ + boxps::FeaturePushValueGpuPCOC>( \ + stream, total_grad_values_gpu, grad_values, hidden_size, \ + EmbedxDim, ExpandDim, total_length, batch_size, d_slot_vector, \ + total_dims, slot_lens, slot_num, key2slot, cvm_offset, \ + gpu_sort_idx, gpu_sort_offset, gpu_sort_lens, skip_offset); \ } else { \ FeaturePushCopyNNCrossWithEmb< \ - boxps::FeaturePushValueGpuPCOC>( \ - stream, total_grad_values_gpu, grad_values, hidden_size, EmbedxDim, \ - ExpandDim, total_length, batch_size, d_slot_vector, total_dims, \ - slot_lens, slot_num, key2slot, cvm_offset, gpu_sort_idx, \ - gpu_sort_offset, gpu_sort_lens, skip_offset); \ + boxps::FeaturePushValueGpuPCOC>( \ + stream, total_grad_values_gpu, grad_values, hidden_size, \ + EmbedxDim, ExpandDim, total_length, batch_size, d_slot_vector, \ + total_dims, slot_lens, slot_num, key2slot, cvm_offset, \ + gpu_sort_idx, gpu_sort_offset, gpu_sort_lens, skip_offset); \ } \ } else if (feature_type_ == static_cast(boxps::FEATURE_VARIABLE)) { \ FeaturePushCopyVariable< \ @@ -1958,6 +1953,13 @@ void BoxWrapper::CopyForPush( (hidden_size - cvm_offset), total_length, batch_size, d_slot_vector, \ total_dims, slot_lens, slot_num, key2slot, cvm_offset, gpu_sort_idx, \ gpu_sort_offset, gpu_sort_lens, skip_offset); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_TRADEW)) { \ + FeaturePushCopyShareEmbedding< \ + boxps::FeatureTradeWPushValueGpu>( \ + stream, total_grad_values_gpu, grad_values, hidden_size, \ + (hidden_size - cvm_offset), total_length, batch_size, d_slot_vector, \ + total_dims, slot_lens, slot_num, key2slot, cvm_offset, gpu_sort_idx, \ + gpu_sort_offset, gpu_sort_lens, skip_offset); \ } \ } break diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index 6165f1e6e20da..9acc2cd885e1b 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -425,7 +425,6 @@ class BoxWrapper { void EndPass(bool need_save_delta); void SetTestMode(bool is_test) const; - template void PullSparseCase(const paddle::platform::Place& place, const std::vector& keys, const std::vector& values, @@ -433,7 +432,6 @@ class BoxWrapper { const int hidden_size, const int expand_embed_dim, const int skip_offset, bool expand_only); - template void PullSparseCaseGPU(const paddle::platform::Place& place, const std::vector& keys, const std::vector& values, @@ -441,7 +439,6 @@ class BoxWrapper { const int hidden_size, const int expand_embed_dim, const int skip_offset, bool expand_only); 
- template void PullSparseCaseCPU(const paddle::platform::Place& place, const std::vector& keys, const std::vector& values, @@ -456,7 +453,6 @@ class BoxWrapper { const int hidden_size, const int expand_embed_dim, const int skip_offset, bool expand_only); - template void PushSparseGradCase(const paddle::platform::Place& place, const std::vector& keys, const std::vector& grad_values, @@ -464,7 +460,6 @@ class BoxWrapper { const int hidden_size, const int expand_embed_dim, const int batch_size, const int skip_offset, bool expand_only); - template void PushSparseGradCaseGPU(const paddle::platform::Place& place, const std::vector& keys, const std::vector& grad_values, @@ -473,7 +468,6 @@ class BoxWrapper { const int batch_size, const int skip_offset, bool expand_only); - template void PushSparseGradCaseCPU(const paddle::platform::Place& place, const std::vector& keys, const std::vector& grad_values, @@ -592,10 +586,16 @@ class BoxWrapper { } else if (s_instance_->feature_type_ == static_cast(boxps::FEATURE_CONV)) { s_instance_->cvm_offset_ = 4; + } else if (s_instance_->feature_type_ == + static_cast(boxps::FEATURE_TRADEW)) { + // embed_w + n * tradew + s_instance_->cvm_offset_ = expand_embed_dim + 3; } else { s_instance_->cvm_offset_ = 3; } s_instance_->gpu_num_ = platform::GetCUDADeviceCount(); + // get feature offset info + s_instance_->GetFeatureOffsetInfo(); if (boxps::MPICluster::Ins().size() > 1) { data_shuffle_.reset(boxps::PaddleShuffler::New()); @@ -723,6 +723,11 @@ class BoxWrapper { } return device_id; } + // get feature offset info + void GetFeatureOffsetInfo(void) { + feature_pull_size_ = boxps_ptr_->GetFeaturePullSize(pull_info_); + feature_push_size_ = boxps_ptr_->GetFeaturePushSize(push_info_); + } private: static cudaStream_t stream_list_[MAX_GPU_NUM]; @@ -741,6 +746,11 @@ class BoxWrapper { int feature_type_ = 0; float pull_embedx_scale_ = 1.0; int cvm_offset_ = 3; + // Need to refactor wrapper.cu + size_t feature_pull_size_ = 0; + size_t feature_push_size_ = 0; + boxps::FeaturePullOffset pull_info_; + boxps::FeaturePushOffset push_info_; // Metric Related int phase_ = 1; @@ -1166,5 +1176,3 @@ class BoxHelper { } // end namespace framework } // end namespace paddle - -#include "paddle/fluid/framework/fleet/box_wrapper_impl.h" diff --git a/paddle/fluid/framework/fleet/box_wrapper_impl.h b/paddle/fluid/framework/fleet/box_wrapper_impl.h index 18459eca741bb..da16e3848b17d 100644 --- a/paddle/fluid/framework/fleet/box_wrapper_impl.h +++ b/paddle/fluid/framework/fleet/box_wrapper_impl.h @@ -21,7 +21,6 @@ DECLARE_bool(enable_pullpush_dedup_keys); namespace paddle { namespace framework { -template void BoxWrapper::PullSparseCaseGPU(const paddle::platform::Place& place, const std::vector& keys, const std::vector& values, @@ -113,10 +112,9 @@ void BoxWrapper::PullSparseCaseGPU(const paddle::platform::Place& place, "dedup keys need more than zero failed in BoxPS.")); dev.dedup_key_length = dedup_size; - int64_t total_bytes = dedup_size * sizeof(FEATURE_VALUE_GPU_TYPE); - FEATURE_VALUE_GPU_TYPE* total_values_gpu = - dev.pull_push_tensor.mutable_data(total_bytes, - place); + int64_t total_bytes = dedup_size * feature_pull_size_; + void* total_values_gpu = + dev.pull_push_tensor.mutable_data(total_bytes, place); pull_boxps_timer.Resume(); @@ -137,10 +135,9 @@ void BoxWrapper::PullSparseCaseGPU(const paddle::platform::Place& place, total_length, total_dims, skip_offset, expand_only, d_restore_idx); } else { - int64_t total_bytes = total_length * sizeof(FEATURE_VALUE_GPU_TYPE); - 
FEATURE_VALUE_GPU_TYPE* total_values_gpu = - dev.pull_push_tensor.mutable_data(total_bytes, - place); + int64_t total_bytes = total_length * feature_pull_size_; + void* total_values_gpu = + dev.pull_push_tensor.mutable_data(total_bytes, place); pull_boxps_timer.Resume(); @@ -163,7 +160,6 @@ void BoxWrapper::PullSparseCaseGPU(const paddle::platform::Place& place, all_timer.Pause(); } -template void BoxWrapper::PullSparseCaseCPU(const paddle::platform::Place& place, const std::vector& keys, const std::vector& values, @@ -227,10 +223,9 @@ void BoxWrapper::PullSparseCaseCPU(const paddle::platform::Place& place, "dedup keys need more than zero failed in BoxPS.")); dev.dedup_key_length = dedup_size; - int64_t total_bytes = dedup_size * sizeof(FEATURE_VALUE_GPU_TYPE); - FEATURE_VALUE_GPU_TYPE* total_values_gpu = - dev.pull_push_tensor.mutable_data(total_bytes, - place); + int64_t total_bytes = dedup_size * feature_pull_size_; + void* total_values_gpu = + dev.pull_push_tensor.mutable_data(total_bytes, place); pull_boxps_timer.Resume(); @@ -249,7 +244,6 @@ void BoxWrapper::PullSparseCaseCPU(const paddle::platform::Place& place, all_timer.Pause(); } -template void BoxWrapper::PullSparseCase(const paddle::platform::Place& place, const std::vector& keys, const std::vector& values, @@ -258,17 +252,14 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place, const int expand_embed_dim, const int skip_offset, bool expand_only) { if (!platform::is_gpu_place(place)) { - PullSparseCaseCPU(place, keys, values, slot_lengths, - hidden_size, expand_embed_dim, - skip_offset, expand_only); + PullSparseCaseCPU(place, keys, values, slot_lengths, hidden_size, + expand_embed_dim, skip_offset, expand_only); } else { - PullSparseCaseGPU(place, keys, values, slot_lengths, - hidden_size, expand_embed_dim, - skip_offset, expand_only); + PullSparseCaseGPU(place, keys, values, slot_lengths, hidden_size, + expand_embed_dim, skip_offset, expand_only); } } -template void BoxWrapper::PushSparseGradCaseGPU( const paddle::platform::Place& place, const std::vector& keys, @@ -319,10 +310,9 @@ void BoxWrapper::PushSparseGradCaseGPU( uint64_t* d_merged_keys = &total_keys[total_length]; int64_t dedup_size = dev.dedup_key_length; - int64_t total_bytes = dedup_size * sizeof(FeaturePushValueGpuType); - FeaturePushValueGpuType* total_grad_values_gpu = - dev.pull_push_tensor.mutable_data(total_bytes, - place); + int64_t total_bytes = dedup_size * feature_push_size_; + void* total_grad_values_gpu = + dev.pull_push_tensor.mutable_data(total_bytes, place); this->CopyForPush(place, gpu_values, total_grad_values_gpu, d_slot_vector, slot_lens, slot_num, hidden_size, expand_embed_dim, dedup_size, batch_size, total_dims, key2slot, skip_offset, @@ -336,10 +326,9 @@ void BoxWrapper::PushSparseGradCaseGPU( "PushSparseGPU failed in BoxPS.")); push_boxps_timer.Pause(); } else { - int64_t total_bytes = total_length * sizeof(FeaturePushValueGpuType); - FeaturePushValueGpuType* total_grad_values_gpu = - dev.pull_push_tensor.mutable_data(total_bytes, - place); + int64_t total_bytes = total_length * feature_push_size_; + void* total_grad_values_gpu = + dev.pull_push_tensor.mutable_data(total_bytes, place); this->CopyForPush(place, gpu_values, total_grad_values_gpu, d_slot_vector, slot_lens, slot_num, hidden_size, expand_embed_dim, total_length, batch_size, total_dims, key2slot, @@ -356,7 +345,6 @@ void BoxWrapper::PushSparseGradCaseGPU( all_timer.Pause(); } -template void BoxWrapper::PushSparseGradCaseCPU( const paddle::platform::Place& 
place, const std::vector& keys, @@ -391,10 +379,9 @@ void BoxWrapper::PushSparseGradCaseCPU( uint64_t* d_merged_keys = &total_keys[total_length]; int64_t dedup_size = dev.dedup_key_length; - int64_t total_bytes = dedup_size * sizeof(FeaturePushValueGpuType); - FeaturePushValueGpuType* total_grad_values_gpu = - dev.pull_push_tensor.mutable_data(total_bytes, - place); + int64_t total_bytes = dedup_size * feature_push_size_; + void* total_grad_values_gpu = + dev.pull_push_tensor.mutable_data(total_bytes, place); this->CopyForPushCPU(place, grad_values, total_grad_values_gpu, slot_vector_.data(), slot_lens, slot_num, hidden_size, expand_embed_dim, dedup_size, batch_size, total_dims, @@ -412,7 +399,6 @@ void BoxWrapper::PushSparseGradCaseCPU( all_timer.Pause(); } -template void BoxWrapper::PushSparseGradCase( const paddle::platform::Place& place, const std::vector& keys, @@ -421,13 +407,13 @@ void BoxWrapper::PushSparseGradCase( const int expand_embed_dim, const int batch_size, const int skip_offset, bool expand_only) { if (!platform::is_gpu_place(place)) { - PushSparseGradCaseCPU( - place, keys, grad_values, slot_lengths, hidden_size, expand_embed_dim, - batch_size, skip_offset, expand_only); + PushSparseGradCaseCPU(place, keys, grad_values, slot_lengths, hidden_size, + expand_embed_dim, batch_size, skip_offset, + expand_only); } else { - PushSparseGradCaseGPU( - place, keys, grad_values, slot_lengths, hidden_size, expand_embed_dim, - batch_size, skip_offset, expand_only); + PushSparseGradCaseGPU(place, keys, grad_values, slot_lengths, hidden_size, + expand_embed_dim, batch_size, skip_offset, + expand_only); } } diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cc new file mode 100644 index 0000000000000..f8be98b5c4c03 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cc @@ -0,0 +1,222 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.h" +#include +namespace paddle { +namespace operators { + +class FusedSeqpoolCVMTradeWOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, + "Inputs(X) of FusedSeqpoolCVMOp should not be empty."); + PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL, + "Outputs(Out) of FusedSeqpoolCVMOp should not be empty."); + + auto cvm_dims = ctx->GetInputDim("CVM"); + PADDLE_ENFORCE_EQ( + cvm_dims.size(), 2UL, + platform::errors::InvalidArgument("Input(CVM)'s rank should be 2.")); + PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, platform::errors::InvalidArgument( + "The 2nd dimension of " + "Input(CVM) should be 2.")); + + auto ins_dims = ctx->GetInputsDim("X"); + const int cvm_offset = ctx->Attrs().Get("cvm_offset"); + const size_t num_inputs = ins_dims.size(); + std::vector outs_dims; + outs_dims.resize(num_inputs); + bool use_cvm = ctx->Attrs().Get("use_cvm"); + const int trade_num = ctx->Attrs().Get("trade_num"); + + PADDLE_ENFORCE_GT(num_inputs, 0UL, + platform::errors::InvalidArgument( + "Input tensors count should be greater than 0, " + "but received value is %d.", + num_inputs)); + + // The output height should be confirmed in Compute, + // since input lod is not accessible here. + PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2, + platform::errors::InvalidArgument( + "The dims size of first input should be equal to 2, " + "but received value is %d.", + ins_dims[0].size())); + + for (size_t i = 0; i < num_inputs; ++i) { + const auto dims = ins_dims[i]; + int rank = dims.size(); + if (use_cvm) { + PADDLE_ENFORCE_GT( + dims[rank - 1], 2, + "Shape error in %lu id, the last dimension(embedding) of the " + "'X' tensor must be larger than 2.", + i); + } + // input lod is not accessible here + std::vector out_dim; + if (use_cvm) { + out_dim = {-1, dims[rank - 1] - trade_num}; + } else { + out_dim = {-1, dims[rank - 1] - cvm_offset - trade_num}; + } + outs_dims[i] = framework::make_ddim(out_dim); + } + ctx->SetOutputsDim("Out", outs_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.device_context()); + } +}; + +class FusedSeqpoolCVMTradeWOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(vector) The input tensors of" + " operator.") + .AsDuplicable(); + AddInput("CVM", + "(Tensor), a 2-D Tensor with shape [N x 2], where N is the batch " + "size, 2 is show and click."); + AddOutput("Out", + "(vector) The output of Op does not contain LoD " + "information.") + .AsDuplicable(); + AddAttr("pooltype", + "(string, default 'SUM') the pooling pooltype of " + "SequencePoolOp, only support SUM now.") + .SetDefault("SUM") + .InEnum({"SUM"}); + AddAttr("pad_value", + "(float, default 0.0) The value to pad for empty sequence.") + .SetDefault(0.0); + AddAttr("use_cvm", "bool, use cvm or not").SetDefault(true); + AddAttr("cvm_offset", "(int, default 2)").SetDefault(2); + AddAttr("trade_id", "(int, default -1)").SetDefault(-1); + AddAttr("trade_num", "(int, default 2)").SetDefault(2); + + AddComment(R"DOC( +Fuse multiple pairs of Sequence Pool and CVM Operator. 
+ +)DOC"); + } +}; + +class FusedSeqpoolCVMTradeWGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto og_dims = ctx->GetInputsDim(framework::GradVarName("Out")); + auto x_dims = ctx->GetInputsDim("X"); + auto cvm_dims = ctx->GetInputDim("CVM"); + const int cvm_offset = ctx->Attrs().Get("cvm_offset"); + bool use_cvm = ctx->Attrs().Get("use_cvm"); + const int trade_num = ctx->Attrs().Get("trade_num"); + + PADDLE_ENFORCE_EQ( + cvm_dims.size(), 2, + platform::errors::InvalidArgument("Input(CVM)'s rank should be 2.")); + + for (size_t i = 0; i < og_dims.size(); i++) { + PADDLE_ENFORCE_EQ( + og_dims[i].size(), x_dims[i].size(), + platform::errors::InvalidArgument( + "The rank of output grad must equal to Input(X). But " + "received: input rank %u, input shape [%s].", + og_dims[i].size(), og_dims[i])); + if (use_cvm) { + auto o_dim = og_dims[i][og_dims[i].size() - 1]; + PADDLE_ENFORCE_EQ( + o_dim + trade_num, x_dims[i][og_dims[i].size() - 1], + platform::errors::InvalidArgument( + "The dimension mismatch between Input(OUT@GRAD) and " + "Input(X). Received Input(OUT@GRAD): input rank %u, " + "input shape [%s]; received Input(X): input rank %u, " + "input shape [%s].", + og_dims[i].size(), og_dims[i], x_dims[i].size(), x_dims[i])); + } else { + PADDLE_ENFORCE_EQ( + og_dims[i][og_dims[i].size() - 1], + x_dims[i][og_dims[i].size() - 1] - cvm_offset - trade_num, + platform::errors::InvalidArgument( + "The dimension mismatch between Input(OUT@GRAD) and " + "Input(X). Received Input(OUT@GRAD): input rank %u, " + "input shape [%s]; received Input(X): input rank %u, " + "input shape [%s].", + og_dims[i].size(), og_dims[i], x_dims[i].size(), x_dims[i])); + } + } + for (size_t i = 0; i < x_dims.size(); ++i) { + ctx->ShareLoD("X", framework::GradVarName("X"), i, i); + ctx->ShareDim("X", framework::GradVarName("X"), i, i); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +template +class FusedSeqpoolCVMTradeWGradOpMaker + : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op_desc_ptr) const override { + op_desc_ptr->SetType("fused_seqpool_cvm_tradew_grad"); + op_desc_ptr->SetInput("X", this->Input("X")); + op_desc_ptr->SetInput("CVM", this->Input("CVM")); + + op_desc_ptr->SetInput(framework::GradVarName("Out"), + this->OutputGrad("Out")); + op_desc_ptr->SetOutput(framework::GradVarName("X"), + this->InputGrad("X", false)); + op_desc_ptr->SetOutput(framework::GradVarName("CVM"), + this->InputGrad("CVM")); + op_desc_ptr->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR( + fused_seqpool_cvm_tradew, ops::FusedSeqpoolCVMTradeWOp, + ops::FusedSeqpoolCVMTradeWOpMaker, + ops::FusedSeqpoolCVMTradeWGradOpMaker, + ops::FusedSeqpoolCVMTradeWGradOpMaker); +REGISTER_OPERATOR(fused_seqpool_cvm_tradew_grad, + ops::FusedSeqpoolCVMTradeWGradOp) + +REGISTER_OP_CPU_KERNEL(fused_seqpool_cvm_tradew, + ops::FusedSeqpoolCVMTradeWOpCPUKernel) +REGISTER_OP_CPU_KERNEL(fused_seqpool_cvm_tradew_grad, + ops::FusedSeqpoolCVMTradeWGradOpCPUKernel) diff 
--git a/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cu new file mode 100644 index 0000000000000..518c191dc4eea --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cu @@ -0,0 +1,484 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +#define GET_BLOCK(N) \ + ((N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS) + +#define CUDA_KERNEL_LOOP(i, n) \ + for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +// normal +template +__global__ void FusedSeqpoolTradeWKernelNormal( + const size_t N, T **input_values, T **seqpool_output_values, + size_t **lods_values, const int batch_size, const int hidden_size, + const int embedding_size, const float pad_value, const int cvm_offset, + const int trade_num) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + + double val = pad_value; + if (offset < cvm_offset) { + for (auto k = start; k < end; ++k) { + val += *(input_values[x] + k * hidden_size + offset); + } + } else { + for (auto k = start; k < end; ++k) { + val += *(input_values[x] + k * hidden_size + trade_num + offset); + } + } + *(seqpool_output_values[x] + y * embedding_size + offset) = val; + } +} + +// normal +template +__global__ void FusedSeqpoolTradeWKernelWithTradeId( + const size_t N, T **input_values, T **seqpool_output_values, + size_t **lods_values, const int batch_size, const int hidden_size, + const int embedding_size, const float pad_value, const int cvm_offset, + const int trade_id, const int trade_num) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + + double val = pad_value; + if (offset < cvm_offset) { + for (auto k = start; k < end; ++k) { + val += *(input_values[x] + k * hidden_size + offset); + } + } else { + for (auto k = start; k < end; ++k) { + val += (*(input_values[x] + k * hidden_size + trade_num + offset)) * + (*(input_values[x] + k * hidden_size + cvm_offset + trade_id)); + } + } + *(seqpool_output_values[x] + y * embedding_size + offset) = val; + } +} + +// join need show click input +template +__global__ void FusedCVMTradeWKernelWithCVM(const size_t N, T **output_values, + T **seqpool_output_values, + const int batch_size, + const int embedding_size, + const int cvm_offset) { + 
CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + if (offset == 0) { // show + *(output_values[x] + y * embedding_size) = + log(*(seqpool_output_values[x] + y * embedding_size) + 1); + } else if (offset == 1) { // click + *(output_values[x] + y * embedding_size + offset) = + log(*(seqpool_output_values[x] + y * embedding_size + 1) + 1) - + log(*(seqpool_output_values[x] + y * embedding_size) + 1); + } else { + *(output_values[x] + y * embedding_size + offset) = + *(seqpool_output_values[x] + y * embedding_size + offset); + } + } +} +// update not need show click input +template +__global__ void FusedCVMTradeWKernelNoCVM(const size_t N, T **output_values, + T **seqpool_output_values, + const int batch_size, + const int no_cvm_embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / no_cvm_embedding_size; + int offset = i % no_cvm_embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + // no cvm + *(output_values[x] + y * no_cvm_embedding_size + offset) = + *(seqpool_output_values[x] + y * (no_cvm_embedding_size + cvm_offset) + + offset + cvm_offset); + } +} + +template +inline void FusedSeqpoolCVMTradeW(const paddle::platform::Place &place, + const std::vector &input_data, + const std::vector &output_data, + const std::vector &seqpool_output_data, + std::vector lods, + const int batch_size, const int slot_num, + const int embedding_size, + const float padding_value, const bool use_cvm, + const int cvm_offset, const int trade_id, + const int trade_num) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); + + size_t total_ptr_len = input_data.size() + output_data.size() + + seqpool_output_data.size() + lods.size(); + auto temp_ptr = memory::AllocShared(place, total_ptr_len * sizeof(void *)); + void *ptr = temp_ptr->ptr(); + + T **gpu_input_values = reinterpret_cast(temp_ptr->ptr()); + cudaMemcpyAsync(gpu_input_values, input_data.data(), + input_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + T **gpu_output_values = + reinterpret_cast(&gpu_input_values[input_data.size()]); + cudaMemcpyAsync(gpu_output_values, output_data.data(), + output_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + T **gpu_seqpool_output_values = + reinterpret_cast(&gpu_output_values[output_data.size()]); + cudaMemcpyAsync(gpu_seqpool_output_values, seqpool_output_data.data(), + seqpool_output_data.size() * sizeof(T *), + cudaMemcpyHostToDevice, stream); + size_t **lods_values = reinterpret_cast( + &gpu_seqpool_output_values[seqpool_output_data.size()]); + cudaMemcpyAsync(lods_values, lods.data(), lods.size() * sizeof(size_t *), + cudaMemcpyHostToDevice, stream); + + size_t N = static_cast(batch_size * slot_num * embedding_size); + // + if (trade_id >= 0) { + FusedSeqpoolTradeWKernelWithTradeId<<>>( + N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, + embedding_size + trade_num, embedding_size, padding_value, cvm_offset, + trade_id, trade_num); + } else { + FusedSeqpoolTradeWKernelNormal<<>>( + N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, + embedding_size + trade_num, embedding_size, padding_value, cvm_offset, + trade_num); + } + // second log + if (use_cvm) { + FusedCVMTradeWKernelWithCVM<<>>( + N, gpu_output_values, gpu_seqpool_output_values, batch_size, + 
embedding_size, cvm_offset); + } else { + // not need show click input + N = static_cast(batch_size * slot_num * + (embedding_size - cvm_offset)); + FusedCVMTradeWKernelNoCVM<<>>( + N, gpu_output_values, gpu_seqpool_output_values, batch_size, + (embedding_size - cvm_offset), cvm_offset); + } +} + +template +class FusedSeqpoolCVMTradeWCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto inputs = ctx.MultiInput("X"); + auto outputs = ctx.MultiOutput("Out"); + + const auto slot_size = inputs.size(); + std::vector input_data(slot_size); + std::vector lods_data(slot_size); + std::vector output_data(slot_size); + + std::vector seqpool_outputs(slot_size); + std::vector seqpool_output_data(slot_size); + + auto padding_value = ctx.Attr("pad_value"); + auto use_cvm = ctx.Attr("use_cvm"); + const int cvm_offset = ctx.Attr("cvm_offset"); + const int trade_id = ctx.Attr("trade_id"); + const int trade_num = ctx.Attr("trade_num"); + + PADDLE_ENFORCE_GE(inputs[0]->dims()[0], 0, "batch ins zero"); + int embedding_size = inputs[0]->numel() / inputs[0]->dims()[0] - trade_num; + PADDLE_ENFORCE_GE(embedding_size, 0, "embedx size is less trade num"); + int batch_size = -1; + for (size_t i = 0; i < slot_size; ++i) { + const auto *input = inputs[i]; + + auto lod = input->lod(); + auto lod_level = lod.size(); + + int cur_batch = lod[lod_level - 1].size() - 1; + if (batch_size == -1) { + batch_size = cur_batch; + } else { + CHECK(batch_size == cur_batch) << "batch: " << batch_size + << ", current: " << cur_batch; + } + input_data[i] = reinterpret_cast(input->data()); + auto *output = outputs[i]; + if (use_cvm) { + output->Resize({batch_size, embedding_size}); + } else { + output->Resize({batch_size, embedding_size - cvm_offset}); + } + output_data[i] = + reinterpret_cast(output->mutable_data(ctx.GetPlace())); + lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace()); + + seqpool_output_data[i] = + reinterpret_cast(seqpool_outputs[i].mutable_data( + {batch_size, embedding_size}, ctx.GetPlace())); + } + FusedSeqpoolCVMTradeW(ctx.GetPlace(), input_data, output_data, + seqpool_output_data, lods_data, batch_size, slot_size, + embedding_size, padding_value, use_cvm, cvm_offset, + trade_id, trade_num); + } +}; +// join grad +template +__global__ void FusedSeqpoolCVMTradeWGradKernelNoTradeId( + const size_t N, T **out_grads_values, T **in_grads_values, T **cvm_values, + size_t **lods_values, const int batch_size, const int hidden_num, + const int embedding_size, const int cvm_offset, const int trade_num, + const int skip_off) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / hidden_num; + int offset = i % hidden_num; // embedx offset + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + + T val = 0.0; + if (offset < cvm_offset) { + val = *(cvm_values[x] + y * cvm_offset + offset); + } else if (offset >= cvm_offset + trade_num) { + val = *(out_grads_values[x] + y * embedding_size + offset - skip_off); + } + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * hidden_num + offset) = val; + } + } +} +// join grad +template +__global__ void FusedSeqpoolCVMTradeWGradKernel( + const size_t N, T **out_grads_values, T **input_values, T **in_grads_values, + T **cvm_values, size_t **lods_values, const int batch_size, + const int hidden_num, const int embedding_size, const int cvm_offset, + const int trade_id, const int trade_num, const int 
skip_off, + const int embedx_off) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / hidden_num; + int offset = i % hidden_num; // embedx offset + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + + if (offset < cvm_offset) { + T &val = *(cvm_values[x] + y * cvm_offset + offset); + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * hidden_num + offset) = val; + } + } else if (offset < cvm_offset + trade_num) { + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + if (trade_id == offset - cvm_offset) { + double sum_val = 0.0; + for (auto k = start; k < end; ++k) { + sum_val = 0.0; + T *in_ptr = input_values[x] + k * hidden_num + cvm_offset + trade_num; + T *g_ptr = out_grads_values[x] + y * embedding_size + embedx_off; + for (int j = 0; j < (embedding_size - embedx_off); ++j) { + sum_val += g_ptr[j] * in_ptr[j]; + } + *(in_grads_values[x] + k * hidden_num + offset) = sum_val; + } + } else { + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * hidden_num + offset) = 0.0; + } + } + } else { + T &val = *(out_grads_values[x] + y * embedding_size + offset - skip_off); + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * hidden_num + offset) = + val * (*(input_values[x] + k * hidden_num + cvm_offset + trade_id)); + } + } + } +} +template +inline void FusedSeqpoolCVMTradeWGrad( + const paddle::platform::Place &place, + const std::vector &out_grads_data, + const std::vector &input_data, + const std::vector &in_grads_data, + const std::vector &cvm_data, + const std::vector &lods, const int batch_size, + const int slot_num, const int embedding_size, const bool use_cvm, + const int cvm_offset, const int trade_id, const int trade_num) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); + size_t total_ptr_len = input_data.size() + out_grads_data.size() + + in_grads_data.size() + cvm_data.size() + lods.size(); + auto temp_ptr = memory::AllocShared(place, total_ptr_len * sizeof(void *)); + T **gpu_out_grads_values = reinterpret_cast(temp_ptr->ptr()); + cudaMemcpyAsync(gpu_out_grads_values, out_grads_data.data(), + out_grads_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + T **gpu_in_values = + reinterpret_cast(&gpu_out_grads_values[out_grads_data.size()]); + cudaMemcpyAsync(gpu_in_values, input_data.data(), + input_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + + T **gpu_in_grads_values = + reinterpret_cast(&gpu_in_values[input_data.size()]); + cudaMemcpyAsync(gpu_in_grads_values, in_grads_data.data(), + in_grads_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + + T **gpu_cvm_values = + reinterpret_cast(&gpu_in_grads_values[in_grads_data.size()]); + cudaMemcpyAsync(gpu_cvm_values, cvm_data.data(), + cvm_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + + size_t **lods_values = + reinterpret_cast(&gpu_cvm_values[cvm_data.size()]); + cudaMemcpyAsync(lods_values, lods.data(), lods.size() * sizeof(size_t *), + cudaMemcpyHostToDevice, stream); + + int hidden_num = embedding_size + trade_num; + size_t N = static_cast(batch_size * slot_num * hidden_num); + if (use_cvm) { + // join grad + if (trade_id >= 0) { + FusedSeqpoolCVMTradeWGradKernel<<>>( + N, gpu_out_grads_values, gpu_in_values, gpu_in_grads_values, + 
gpu_cvm_values, lods_values, batch_size, hidden_num, embedding_size, + cvm_offset, trade_id, trade_num, trade_num, cvm_offset); + } else { + FusedSeqpoolCVMTradeWGradKernelNoTradeId<<< + GET_BLOCK(N), PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, + lods_values, batch_size, hidden_num, embedding_size, cvm_offset, + trade_num, trade_num); + } + } else { + // update grad + if (trade_id >= 0) { + FusedSeqpoolCVMTradeWGradKernel<<>>( + N, gpu_out_grads_values, gpu_in_values, gpu_in_grads_values, + gpu_cvm_values, lods_values, batch_size, hidden_num, + embedding_size - cvm_offset, cvm_offset, trade_id, trade_num, + trade_num + cvm_offset, 0); + } else { + FusedSeqpoolCVMTradeWGradKernelNoTradeId<<< + GET_BLOCK(N), PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, + lods_values, batch_size, hidden_num, embedding_size - cvm_offset, + cvm_offset, trade_num, trade_num + cvm_offset); + } + } +} + +template +class FusedSeqpoolCVMTradeWGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto out_grads = ctx.MultiInput(framework::GradVarName("Out")); + auto in_grads = ctx.MultiOutput(framework::GradVarName("X")); + auto *cvm = ctx.Input("CVM"); + auto inputs = ctx.MultiInput("X"); + + std::string pooltype = ctx.Attr("pooltype"); + auto use_cvm = ctx.Attr("use_cvm"); + const int cvm_offset = ctx.Attr("cvm_offset"); + const int trade_id = ctx.Attr("trade_id"); + const int trade_num = ctx.Attr("trade_num"); + + const auto slot_size = in_grads.size(); + std::vector out_grads_data(slot_size); + std::vector in_grads_data(slot_size); + std::vector cvm_data(slot_size); + std::vector lods_data(slot_size); + std::vector input_data(slot_size); + + int embedding_size = + in_grads[0]->numel() / in_grads[0]->dims()[0] - trade_num; + int batch_size = -1; + for (size_t i = 0; i < slot_size; ++i) { + auto *in_grad = in_grads[i]; + + auto lod = in_grad->lod(); + auto lod_level = lod.size(); + int cur_batch = lod[lod_level - 1].size() - 1; + if (batch_size == -1) { + batch_size = cur_batch; + } else { + CHECK(batch_size == cur_batch) << "batch: " << batch_size + << ", current: " << cur_batch; + } + input_data[i] = reinterpret_cast(inputs[i]->data()); + auto *out_grad = out_grads[i]; + out_grads_data[i] = reinterpret_cast(out_grad->data()); + + in_grads_data[i] = + reinterpret_cast(in_grad->mutable_data(ctx.GetPlace())); + lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace()); + cvm_data[i] = reinterpret_cast(cvm->data()); + } + FusedSeqpoolCVMTradeWGrad(ctx.GetPlace(), out_grads_data, input_data, + in_grads_data, cvm_data, lods_data, batch_size, + slot_size, embedding_size, use_cvm, cvm_offset, + trade_id, trade_num); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(fused_seqpool_cvm_tradew, + ops::FusedSeqpoolCVMTradeWCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(fused_seqpool_cvm_tradew_grad, + ops::FusedSeqpoolCVMTradeWGradCUDAKernel); diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.h b/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.h new file mode 100644 index 0000000000000..4aaacfe2eb2a2 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; + +template +class FusedSeqpoolCVMTradeWOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW( + "Unimplemented CPU kernel for FusedSeqpoolCVMTradeWOp, only support " + "GPU " + "now."); + } +}; + +template +class FusedSeqpoolCVMTradeWGradOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW( + "Unimplemented CPU kernel for FusedSeqpoolCVMTradeWGradOp, only " + "support GPU " + "now."); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 0396131f28a28..6cc2fc6acaf2d 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -77,6 +77,7 @@ 'fused_seqpool_cvm', 'fused_seqpool_cvm_with_conv', 'fused_seqpool_cvm_with_diff_thres', + 'fused_seqpool_cvm_tradew', 'cross_norm_layer_hadamard', 'fused_seqpool_cvm_with_pcoc', 'scaled_fc', @@ -1860,6 +1861,64 @@ def fused_seqpool_cvm_with_pcoc(input, return outs +def fused_seqpool_cvm_tradew(input, + pool_type, + cvm, + pad_value=0.0, + use_cvm=True, + cvm_offset=2, + trade_id=-1, + trade_num=2): + """ + **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now. + :attr:`input`. + Args: + input(Variable|list of Variable): Input is List of LoDTensor. + pool_type(str): pooling type, only support SUM pooling now. + cvm(Variable): cvm Variable. + pad_value(float): padding value of sequence pool. + use_cvm(bool): use cvm or not. + Returns: + Variable|list of Variable: The tensor variable storing sequence pool and cvm + of input. 
+ """ + helper = LayerHelper('fused_seqpool_cvm_tradew', **locals()) + + if pool_type.upper() != 'SUM': + raise ValueError( + "fused_seqpool_cvm only support SUM pooling now, and your type is: " + + pool_type) + + check_type(input, 'input', list, 'fused_seqpool_cvm_tradew') + if isinstance(input, list): + for _input in input: + check_variable_and_dtype(_input, 'input', ['float32'], + 'fused_seqpool_cvm_tradew') + + dtype = helper.input_dtype() + inputs = helper.multiple_input() + outs = [ + helper.create_variable_for_type_inference(dtype) + for i in range(len(inputs)) + ] + + helper.append_op( + type="fused_seqpool_cvm_tradew", + inputs={"X": inputs, + "CVM": cvm}, + outputs={"Out": outs}, + attrs={ + "pooltype": pool_type.upper(), + "pad_value": pad_value, + "use_cvm": use_cvm, + "cvm_offset": cvm_offset, + "trade_id": trade_id, + "trade_num": trade_num + }) + + return outs + + def cross_norm_layer_hadamard(input, fields_num, embed_dim, From 4d78f3726ac98f9abb763bddc120ac4731d3daab Mon Sep 17 00:00:00 2001 From: majun16 Date: Tue, 26 Jul 2022 11:16:19 +0800 Subject: [PATCH 2/4] add credit --- paddle/fluid/framework/fleet/box_wrapper.cu | 24 + paddle/fluid/framework/fleet/box_wrapper.h | 4 + .../fused/fused_seqpool_cvm_with_credit_op.cc | 220 +++++++++ .../fused/fused_seqpool_cvm_with_credit_op.cu | 436 ++++++++++++++++++ .../fused/fused_seqpool_cvm_with_credit_op.h | 48 ++ python/paddle/fluid/contrib/layers/nn.py | 57 +++ 6 files changed, 789 insertions(+) create mode 100644 paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.cc create mode 100644 paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.cu create mode 100644 paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.h diff --git a/paddle/fluid/framework/fleet/box_wrapper.cu b/paddle/fluid/framework/fleet/box_wrapper.cu index 34d86b4c2964c..89b94a9368479 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cu +++ b/paddle/fluid/framework/fleet/box_wrapper.cu @@ -1514,6 +1514,11 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ EmbedxDim, total_length, total_dims, slot_lens, slot_num, key2slot, \ pull_embedx_scale_, cvm_offset, gpu_restore_idx, skip_offset); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_CREDIT)) { \ + FeaturePullCopy>(\ + stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ + EmbedxDim, total_length, total_dims, slot_lens, slot_num, key2slot, \ + pull_embedx_scale_, cvm_offset, gpu_restore_idx, skip_offset); \ } else { \ FeaturePullCopy>( \ stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ @@ -1561,6 +1566,12 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ EmbedxDim, ExpandDim, total_length, total_dims, slot_lens, slot_num, \ key2slot, 1.0, cvm_offset, gpu_restore_idx, skip_offset); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_CREDIT)) { \ + FeaturePullCopyNNCross< \ + boxps::FeaturePullValueGpuCredit>( \ + stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ + EmbedxDim, ExpandDim, total_length, total_dims, slot_lens, slot_num, \ + key2slot, 1.0, cvm_offset, gpu_restore_idx, skip_offset); \ } else { \ FeaturePullCopyNNCross< \ boxps::FeaturePullValueGpu>( \ @@ -1891,6 +1902,12 @@ void BoxWrapper::CopyForPush( total_length, batch_size, d_slot_vector, total_dims, slot_lens, \ slot_num, key2slot, cvm_offset, gpu_sort_idx, gpu_sort_offset, \ gpu_sort_lens, 
skip_offset); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_CREDIT)) { \ + FeaturePushCopy>(\ + stream, total_grad_values_gpu, grad_values, hidden_size, EmbedxDim, \ + total_length, batch_size, d_slot_vector, total_dims, slot_lens, \ + slot_num, key2slot, cvm_offset, gpu_sort_idx, gpu_sort_offset, \ + gpu_sort_lens, skip_offset); \ } else { \ FeaturePushCopy>( \ stream, total_grad_values_gpu, grad_values, hidden_size, EmbedxDim, \ @@ -1933,6 +1950,13 @@ void BoxWrapper::CopyForPush( ExpandDim, total_length, batch_size, d_slot_vector, total_dims, \ slot_lens, slot_num, key2slot, cvm_offset, gpu_sort_idx, \ gpu_sort_offset, gpu_sort_lens, skip_offset); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_CREDIT)) { \ + FeaturePushCopyNNCross< \ + boxps::FeaturePushValueGpuCredit>( \ + stream, total_grad_values_gpu, grad_values, hidden_size, EmbedxDim, \ + ExpandDim, total_length, batch_size, d_slot_vector, total_dims, \ + slot_lens, slot_num, key2slot, cvm_offset, gpu_sort_idx, \ + gpu_sort_offset, gpu_sort_lens, skip_offset); \ } else { \ FeaturePushCopyNNCross< \ boxps::FeaturePushValueGpu>( \ diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index 9acc2cd885e1b..12ac8a7e62640 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -590,6 +590,10 @@ class BoxWrapper { static_cast(boxps::FEATURE_TRADEW)) { // embed_w + n * tradew s_instance_->cvm_offset_ = expand_embed_dim + 3; + } else if (s_instance_->feature_type_ == + static_cast(boxps::FEATURE_CREDIT)) { + // show/clk/conv/credit/embed + s_instance_->cvm_offset_ = 5; } else { s_instance_->cvm_offset_ = 3; } diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.cc new file mode 100644 index 0000000000000..43d34742685a6 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.cc @@ -0,0 +1,220 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.h" +#include +namespace paddle { +namespace operators { + +class FusedSeqpoolCVMOpWithCredit : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, "Inputs(X) of FusedSeqpoolCVMOpWithCredit should not be empty."); + PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL, "Outputs(Out) of FusedSeqpoolCVMOpWithCredit should not be empty."); + + auto cvm_dims = ctx->GetInputDim("CVM"); + PADDLE_ENFORCE_EQ(cvm_dims.size(), 2UL, platform::errors::InvalidArgument("Input(CVM)'s rank should be 2.")); + PADDLE_ENFORCE_EQ(cvm_dims[1], 4UL, + platform::errors::InvalidArgument("The 2nd dimension of Input(CVM) should be 4.")); + + auto ins_dims = ctx->GetInputsDim("X"); + const int cvm_offset = ctx->Attrs().Get("cvm_offset"); + const size_t num_inputs = ins_dims.size(); + std::vector outs_dims; + outs_dims.resize(num_inputs); + bool use_cvm = ctx->Attrs().Get("use_cvm"); + bool show_filter = ctx->Attrs().Get("show_filter"); + + PADDLE_ENFORCE_GT(num_inputs, 0UL, + platform::errors::InvalidArgument( + "Input tensors count should be greater than 0, " + "but received value is %d.", + num_inputs)); + + // The output height should be confirmed in Compute, + // since input lod is not accessible here. + PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2, + platform::errors::InvalidArgument( + "The dims size of first input should be equal to 2, " + "but received value is %d.", + ins_dims[0].size())); + + for (size_t i = 0; i < num_inputs; ++i) { + const auto dims = ins_dims[i]; + int rank = dims.size(); + if (use_cvm) { + PADDLE_ENFORCE_GT( + dims[rank - 1], 2, + "Shape error in %lu id, the last dimension(embedding) of the " + "'X' tensor must be larger than 2.", + i); + } + // input lod is not accessible here + std::vector out_dim; + if (use_cvm) { + if (show_filter) { + out_dim = {-1, dims[rank - 1] - 1}; + } else { + out_dim = {-1, dims[rank - 1]}; + } + } else { + out_dim = {-1, dims[rank - 1] - cvm_offset}; + } + outs_dims[i] = framework::make_ddim(out_dim); + } + ctx->SetOutputsDim("Out", outs_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.device_context()); + } +}; + +class FusedSeqpoolCVMOpWithCreditMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(vector) The input tensors of" + " operator.") + .AsDuplicable(); + AddInput("CVM", + "(Tensor), a 2-D Tensor with shape [N x 2], where N is the batch " + "size, 2 is show and click."); + AddOutput("Out", + "(vector) The output of Op does not contain LoD " + "information.") + .AsDuplicable(); + AddAttr("pooltype", + "(string, default 'SUM') the pooling pooltype of " + "SequencePoolOp, only support SUM now.") + .SetDefault("SUM") + .InEnum({"SUM"}); + AddAttr("pad_value", + "(float, default 0.0) The value to pad for empty sequence.") + .SetDefault(0.0); + AddAttr("use_cvm", "bool, use cvm or not").SetDefault(true); + AddAttr("cvm_offset", "(int, default 4)").SetDefault(4); + AddAttr("show_filter", "(bool, default false)").SetDefault(false); + + AddComment(R"DOC( +Fuse multiple pairs of Sequence Pool and CVM Operator. 
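+Each input is expected to carry cvm_offset (default 4) leading
+show/click/conv/credit columns. With use_cvm these columns are kept and
+log(x + 1) is applied to them (show_filter additionally drops the show
+column); without use_cvm the first cvm_offset columns are stripped from the
+pooled output.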
+ +)DOC"); + } +}; + +class FusedSeqpoolCVMGradOpWithCredit : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto og_dims = ctx->GetInputsDim(framework::GradVarName("Out")); + auto x_dims = ctx->GetInputsDim("X"); + auto cvm_dims = ctx->GetInputDim("CVM"); + const int cvm_offset = ctx->Attrs().Get("cvm_offset"); + bool use_cvm = ctx->Attrs().Get("use_cvm"); + bool show_filter = ctx->Attrs().Get("show_filter"); + + PADDLE_ENFORCE_EQ( + cvm_dims.size(), 2, + platform::errors::InvalidArgument("Input(CVM)'s rank should be 2.")); + + for (size_t i = 0; i < og_dims.size(); i++) { + PADDLE_ENFORCE_EQ( + og_dims[i].size(), x_dims[i].size(), + platform::errors::InvalidArgument( + "The rank of output grad must equal to Input(X). But " + "received: input rank %u, input shape [%s].", + og_dims[i].size(), og_dims[i])); + if (use_cvm) { + auto o_dim = og_dims[i][og_dims[i].size() - 1]; + if (show_filter) { + o_dim += 1; + } + PADDLE_ENFORCE_EQ( + o_dim, x_dims[i][og_dims[i].size() - 1], + platform::errors::InvalidArgument( + "The dimension mismatch between Input(OUT@GRAD) and " + "Input(X). Received Input(OUT@GRAD): input rank %u, " + "input shape [%s]; received Input(X): input rank %u, " + "input shape [%s].", + og_dims[i].size(), og_dims[i], x_dims[i].size(), x_dims[i])); + } else { + PADDLE_ENFORCE_EQ( + og_dims[i][og_dims[i].size() - 1], + x_dims[i][og_dims[i].size() - 1] - cvm_offset, + platform::errors::InvalidArgument( + "The dimension mismatch between Input(OUT@GRAD) and " + "Input(X). Received Input(OUT@GRAD): input rank %u, " + "input shape [%s]; received Input(X): input rank %u, " + "input shape [%s].", + og_dims[i].size(), og_dims[i], x_dims[i].size(), x_dims[i])); + } + } + for (size_t i = 0; i < x_dims.size(); ++i) { + ctx->ShareLoD("X", framework::GradVarName("X"), i, i); + ctx->ShareDim("X", framework::GradVarName("X"), i, i); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +template +class FusedSeqpoolCVMGradOpWithCreditMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op_desc_ptr) const override { + op_desc_ptr->SetType("fused_seqpool_cvm_with_credit_grad"); + op_desc_ptr->SetInput("X", this->Input("X")); + op_desc_ptr->SetInput("CVM", this->Input("CVM")); + + op_desc_ptr->SetInput(framework::GradVarName("Out"), + this->OutputGrad("Out")); + op_desc_ptr->SetOutput(framework::GradVarName("X"), + this->InputGrad("X", false)); + op_desc_ptr->SetOutput(framework::GradVarName("CVM"), + this->InputGrad("CVM")); + op_desc_ptr->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(fused_seqpool_cvm_with_credit, ops::FusedSeqpoolCVMOpWithCredit, + ops::FusedSeqpoolCVMOpWithCreditMaker, + ops::FusedSeqpoolCVMGradOpWithCreditMaker, + ops::FusedSeqpoolCVMGradOpWithCreditMaker); +REGISTER_OPERATOR(fused_seqpool_cvm_with_credit_grad, ops::FusedSeqpoolCVMGradOpWithCredit) + +REGISTER_OP_CPU_KERNEL(fused_seqpool_cvm_with_credit, + ops::FusedSeqpoolCVMOpWithCreditCPUKernel) 
+REGISTER_OP_CPU_KERNEL(fused_seqpool_cvm_with_credit_grad, + ops::FusedSeqpoolCVMGradOpWithCreditCPUKernel) diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.cu new file mode 100644 index 0000000000000..2989892730b23 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.cu @@ -0,0 +1,436 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +#define GET_BLOCK(N) \ + ((N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS) + +#define CUDA_KERNEL_LOOP(i, n) \ + for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +// normal +template +__global__ void FusedSeqpoolWithCreditKernelNormal(const size_t N, T **input_values, + T **seqpool_output_values, + size_t **lods_values, + const int batch_size, + const int embedding_size, + const float pad_value) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + + double val = pad_value; + for (auto k = start; k < end; ++k) { + val += *(input_values[x] + k * embedding_size + offset); + } + *(seqpool_output_values[x] + y * embedding_size + offset) = val; + } +} +// join need show/click/conv/credit input +template +__global__ void FusedCVMWithCreditKernelNormal(const size_t N, T **output_values, + T **seqpool_output_values, + const int batch_size, + const int embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + if (offset < cvm_offset) { // show/click/conv/credit + *(output_values[x] + y * embedding_size + offset) = + log(*(seqpool_output_values[x] + y * embedding_size + offset) + 1); + } else { // embed emebdx + *(output_values[x] + y * embedding_size + offset) = + *(seqpool_output_values[x] + y * embedding_size + offset); + } + } +} + +// join without show input +template +__global__ void FusedCVMWithCreditKernelWithOutShow(const size_t N, T **output_values, + T **seqpool_output_values, + const int batch_size, + const int embedding_size, + const int noshow_embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / noshow_embedding_size; + int offset = i % noshow_embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + if (offset < cvm_offset - 1) { // filter show + *(output_values[x] + y * noshow_embedding_size + offset) = + 
log(*(seqpool_output_values[x] + y * embedding_size + offset + 1) + 1); + } else { + *(output_values[x] + y * noshow_embedding_size + offset) = + *(seqpool_output_values[x] + y * embedding_size + offset + 1); + } + } +} + +// update not need show click conv credit input +template +__global__ void FusedCVMWithCreditKernelNoCVM(const size_t N, T **output_values, + T **seqpool_output_values, + const int batch_size, + const int no_cvm_embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / no_cvm_embedding_size; + int offset = i % no_cvm_embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + // no cvm + *(output_values[x] + y * no_cvm_embedding_size + offset) = + *(seqpool_output_values[x] + y * (no_cvm_embedding_size + cvm_offset) + + offset + cvm_offset); + } +} + +template +void FusedSeqpoolCVMWithCredit(const paddle::platform::Place &place, + const std::vector &input_data, + const std::vector &output_data, + const std::vector &seqpool_output_data, + std::vector lods, const int batch_size, + const int slot_num, const int embedding_size, + const float padding_value, const bool use_cvm, + const int cvm_offset, bool show_filter) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); + + size_t total_ptr_len = input_data.size() + output_data.size() + + seqpool_output_data.size() + lods.size(); + auto temp_ptr = memory::AllocShared(place, total_ptr_len * sizeof(void *)); + void *ptr = temp_ptr->ptr(); + + T **gpu_input_values = reinterpret_cast(temp_ptr->ptr()); + cudaMemcpyAsync(gpu_input_values, input_data.data(), + input_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + T **gpu_output_values = + reinterpret_cast(&gpu_input_values[input_data.size()]); + cudaMemcpyAsync(gpu_output_values, output_data.data(), + output_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + T **gpu_seqpool_output_values = + reinterpret_cast(&gpu_output_values[output_data.size()]); + cudaMemcpyAsync(gpu_seqpool_output_values, seqpool_output_data.data(), + seqpool_output_data.size() * sizeof(T *), + cudaMemcpyHostToDevice, stream); + size_t **lods_values = reinterpret_cast( + &gpu_seqpool_output_values[seqpool_output_data.size()]); + cudaMemcpyAsync(lods_values, lods.data(), lods.size() * sizeof(size_t *), + cudaMemcpyHostToDevice, stream); + + size_t N = static_cast(batch_size * slot_num * embedding_size); + // first sum pool + FusedSeqpoolWithCreditKernelNormal<<>>( + N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, + embedding_size, padding_value); + // second log + if (use_cvm) { + if (show_filter) { + N = static_cast(batch_size * slot_num * (embedding_size - 1)); + FusedCVMWithCreditKernelWithOutShow<<>>(N, gpu_output_values, + gpu_seqpool_output_values, batch_size, + embedding_size, embedding_size - 1, cvm_offset); + } else { + FusedCVMWithCreditKernelNormal<<>>(N, gpu_output_values, + gpu_seqpool_output_values, batch_size, + embedding_size, cvm_offset); + } + } else { + // not need show click input + N = static_cast(batch_size * slot_num * + (embedding_size - cvm_offset)); + FusedCVMWithCreditKernelNoCVM<<>>( + N, gpu_output_values, gpu_seqpool_output_values, batch_size, + (embedding_size - cvm_offset), cvm_offset); + } +} + // join grad + template + __global__ void FusedSeqpoolCVMWithCreditGradKernelWithCVM( + const size_t N, T **out_grads_values, T **in_grads_values, T **cvm_values, + size_t **lods_values, 
const int batch_size, const int embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; // embedx offset + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + + T &val = (offset < cvm_offset) + ? *(cvm_values[x] + y * cvm_offset + offset) + : *(out_grads_values[x] + y * embedding_size + offset); + + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * embedding_size + offset) = val; + } + } + } + + // join with out show + template + __global__ void FusedSeqpoolCVMWithCreditGradKernelWithOutShow( + const size_t N, T **out_grads_values, T **in_grads_values, T **cvm_values, + size_t **lods_values, const int batch_size, const int embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; // embedx offset + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + + T &val = + (offset < cvm_offset) + ? *(cvm_values[x] + y * cvm_offset + offset) + : *(out_grads_values[x] + y * (embedding_size - 1) + offset - 1); + + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * embedding_size + offset) = val; + } + } + } + +// update grad +template +__global__ void FusedSeqpoolCVMWithCreditGradKernelNoCVM( + const size_t N, T **out_grads_values, T **in_grads_values, T **cvm_values, + size_t **lods_values, const int batch_size, const int embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; // embedx offset + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + + T &val = (offset < cvm_offset) + ? 
*(cvm_values[x] + y * cvm_offset + offset) + : *(out_grads_values[x] + y * (embedding_size - cvm_offset) + + offset - cvm_offset); + + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * embedding_size + offset) = val; + } + } +} +template +void FusedSeqpoolCVMGradWithCredit(const paddle::platform::Place &place, + const std::vector &out_grads_data, + const std::vector &in_grads_data, + const std::vector &cvm_data, + const std::vector &lods, + const int batch_size, const int slot_num, + const int embedding_size, const bool use_cvm, + const int cvm_offset, bool show_filter) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); + size_t total_ptr_len = out_grads_data.size() + in_grads_data.size() + + cvm_data.size() + lods.size(); + auto temp_ptr = memory::AllocShared(place, total_ptr_len * sizeof(void *)); + T **gpu_out_grads_values = reinterpret_cast(temp_ptr->ptr()); + cudaMemcpyAsync(gpu_out_grads_values, out_grads_data.data(), + out_grads_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + + T **gpu_in_grads_values = + reinterpret_cast(&gpu_out_grads_values[out_grads_data.size()]); + cudaMemcpyAsync(gpu_in_grads_values, in_grads_data.data(), + in_grads_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + + T **gpu_cvm_values = + reinterpret_cast(&gpu_in_grads_values[in_grads_data.size()]); + cudaMemcpyAsync(gpu_cvm_values, cvm_data.data(), + cvm_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + + size_t **lods_values = + reinterpret_cast(&gpu_cvm_values[cvm_data.size()]); + cudaMemcpyAsync(lods_values, lods.data(), lods.size() * sizeof(size_t *), + cudaMemcpyHostToDevice, stream); + + size_t N = static_cast(batch_size * slot_num * embedding_size); + if (use_cvm) { + if (show_filter) { + FusedSeqpoolCVMWithCreditGradKernelWithOutShow<<>>( + N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, + lods_values, batch_size, embedding_size, cvm_offset); + + } else { + FusedSeqpoolCVMWithCreditGradKernelWithCVM<<>>( + N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, + lods_values, batch_size, embedding_size, cvm_offset); + } + } else { + // update grad + FusedSeqpoolCVMWithCreditGradKernelNoCVM<<>>( + N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, + lods_values, batch_size, embedding_size, cvm_offset); + } +} + +template +class FusedSeqpoolCVMWithCreditCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto inputs = ctx.MultiInput("X"); + auto outputs = ctx.MultiOutput("Out"); + + const auto slot_size = inputs.size(); + std::vector input_data(slot_size); + std::vector lods_data(slot_size); + std::vector output_data(slot_size); + + std::vector seqpool_outputs(slot_size); + std::vector seqpool_output_data(slot_size); + + auto padding_value = ctx.Attr("pad_value"); + auto use_cvm = ctx.Attr("use_cvm"); + const int cvm_offset = ctx.Attr("cvm_offset"); + bool show_filter = ctx.Attr("show_filter"); + + int embedding_size = inputs[0]->numel() / inputs[0]->dims()[0]; + int batch_size = -1; + for (size_t i = 0; i < slot_size; ++i) { + const auto *input = inputs[i]; + + auto lod = input->lod(); + auto lod_level = lod.size(); + + int cur_batch = lod[lod_level - 1].size() - 1; + if (batch_size == -1) { + batch_size = cur_batch; + } else { + CHECK(batch_size == cur_batch) << "batch: " 
<< batch_size << ", current: " << cur_batch; + } + input_data[i] = reinterpret_cast(input->data()); + auto *output = outputs[i]; + if (use_cvm) { + if (show_filter) { + // show will filtered + output->Resize({batch_size, embedding_size - 1}); + } else { + output->Resize({batch_size, embedding_size}); + } + } else { + output->Resize({batch_size, embedding_size - cvm_offset}); + } + output_data[i] = + reinterpret_cast(output->mutable_data(ctx.GetPlace())); + lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace()); + + seqpool_output_data[i] = + reinterpret_cast(seqpool_outputs[i].mutable_data( + {batch_size, embedding_size}, ctx.GetPlace())); + } + FusedSeqpoolCVMWithCredit(ctx.GetPlace(), input_data, output_data, + seqpool_output_data, lods_data, batch_size, slot_size, + embedding_size, padding_value, use_cvm, cvm_offset, show_filter); + } +}; + +template +class FusedSeqpoolCVMWithCreditGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto out_grads = ctx.MultiInput(framework::GradVarName("Out")); + auto in_grads = ctx.MultiOutput(framework::GradVarName("X")); + auto *cvm = ctx.Input("CVM"); + + std::string pooltype = ctx.Attr("pooltype"); + auto use_cvm = ctx.Attr("use_cvm"); + const int cvm_offset = ctx.Attr("cvm_offset"); + bool show_filter = ctx.Attr("show_filter"); + + const auto slot_size = in_grads.size(); + std::vector out_grads_data(slot_size); + std::vector in_grads_data(slot_size); + std::vector cvm_data(slot_size); + std::vector lods_data(slot_size); + + int embedding_size = in_grads[0]->numel() / in_grads[0]->dims()[0]; + int batch_size = -1; + for (size_t i = 0; i < slot_size; ++i) { + auto *in_grad = in_grads[i]; + + auto lod = in_grad->lod(); + auto lod_level = lod.size(); + int cur_batch = lod[lod_level - 1].size() - 1; + if (batch_size == -1) { + batch_size = cur_batch; + } else { + CHECK(batch_size == cur_batch) << "batch: " << batch_size + << ", current: " << cur_batch; + } + + auto *out_grad = out_grads[i]; + out_grads_data[i] = reinterpret_cast(out_grad->data()); + + in_grads_data[i] = + reinterpret_cast(in_grad->mutable_data(ctx.GetPlace())); + lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace()); + cvm_data[i] = reinterpret_cast(cvm->data()); + } + FusedSeqpoolCVMGradWithCredit(ctx.GetPlace(), out_grads_data, in_grads_data, cvm_data, + lods_data, batch_size, slot_size, embedding_size, + use_cvm, cvm_offset, show_filter); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(fused_seqpool_cvm_with_credit, + ops::FusedSeqpoolCVMWithCreditCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(fused_seqpool_cvm_with_credit_grad, + ops::FusedSeqpoolCVMWithCreditGradCUDAKernel); diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.h b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.h new file mode 100644 index 0000000000000..56fd27c945039 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_credit_op.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; + +template +class FusedSeqpoolCVMOpWithCreditCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW( + "Unimplemented CPU kernel for FusedSeqpoolCVMOpWithCredit only support GPU " + "now."); + } +}; + +template +class FusedSeqpoolCVMGradOpWithCreditCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW( + "Unimplemented CPU kernel for FusedSeqpoolCVMGradOpWithCredit, only support GPU " + "now."); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 6cc2fc6acaf2d..e4e1645055825 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -78,6 +78,7 @@ 'fused_seqpool_cvm_with_conv', 'fused_seqpool_cvm_with_diff_thres', 'fused_seqpool_cvm_tradew', + 'fused_seqpool_cvm_with_credit', 'cross_norm_layer_hadamard', 'fused_seqpool_cvm_with_pcoc', 'scaled_fc', @@ -1919,6 +1920,62 @@ def fused_seqpool_cvm_tradew(input, return outs +def fused_seqpool_cvm_with_credit(input, + pool_type, + cvm, + pad_value=0.0, + use_cvm=True, + show_filter=False, + cvm_offset=4): + """ + **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now. + :attr:`input`. + Args: + input(Variable|list of Variable): Input is List of LoDTensor. + pool_type(str): pooling type, only support SUM pooling now. + cvm(Variable): cvm Variable. + pad_value(float): padding value of sequence pool. + use_cvm(bool): use cvm or not. + Returns: + Variable|list of Variable: The tensor variable storing sequence pool and cvm + of input. 
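+    Note:
+        cvm_offset (default 4) counts the leading show/click/conv/credit columns
+        of each input; when show_filter is True the show column is dropped from
+        the pooled output.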
+ """ + helper = LayerHelper('fused_seqpool_cvm_with_credit', **locals()) + + if pool_type.upper() != 'SUM': + raise ValueError( + "fused_seqpool_cvm_with_credit only support SUM pooling now, and your type is: " + + pool_type) + + check_type(input, 'input', list, 'fused_seqpool_cvm_with_credit') + if isinstance(input, list): + for _input in input: + check_variable_and_dtype(_input, 'input', ['float32'], + 'fused_seqpool_cvm_with_credit') + + dtype = helper.input_dtype() + inputs = helper.multiple_input() + outs = [ + helper.create_variable_for_type_inference(dtype) + for i in range(len(inputs)) + ] + + helper.append_op( + type="fused_seqpool_cvm_with_credit", + inputs={"X": inputs, + "CVM": cvm}, + outputs={"Out": outs}, + attrs={ + "pooltype": pool_type.upper(), + "pad_value": pad_value, + "use_cvm": use_cvm, + "cvm_offset": cvm_offset, + "show_filter": show_filter + }) + + return outs + + def cross_norm_layer_hadamard(input, fields_num, embed_dim, From 2002a44d1edc451953a04ae16b3bcb97b4ad921d Mon Sep 17 00:00:00 2001 From: qshuihu Date: Tue, 26 Jul 2022 15:28:18 +0800 Subject: [PATCH 3/4] fix fuse_seqpool_cvm_tradew show click double counting --- paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cu index 518c191dc4eea..09d742979d307 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cu +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_tradew_op.cu @@ -304,11 +304,12 @@ __global__ void FusedSeqpoolCVMTradeWGradKernel( int y = key % batch_size; // ins id if (offset < cvm_offset) { - T &val = *(cvm_values[x] + y * cvm_offset + offset); + // T &val = *(cvm_values[x] + y * cvm_offset + offset); auto &start = *(lods_values[x] + y); auto &end = *(lods_values[x] + y + 1); for (auto k = start; k < end; ++k) { - *(in_grads_values[x] + k * hidden_num + offset) = val; + // trade not need set show click grad + *(in_grads_values[x] + k * hidden_num + offset) = 0.0; } } else if (offset < cvm_offset + trade_num) { auto &start = *(lods_values[x] + y); From bcc2233f052f374cc601d4c0d0005cb0ad4b24ab Mon Sep 17 00:00:00 2001 From: humingqing Date: Tue, 26 Jul 2022 17:33:07 +0800 Subject: [PATCH 4/4] add new boxps --- cmake/external/box_ps.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/external/box_ps.cmake b/cmake/external/box_ps.cmake index b606cd3603a6c..0219a96ba1006 100644 --- a/cmake/external/box_ps.cmake +++ b/cmake/external/box_ps.cmake @@ -22,10 +22,10 @@ IF((NOT DEFINED BOX_PS_VER) OR (NOT DEFINED BOX_PS_URL)) #SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps.tar.gz" CACHE STRING "" FORCE) IF(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) #cuda10.2 - SET(BOX_PS_URL "data-im.baidu.com:/home/work/var/CI_DATA/im/static/box_ps.tar.gz/box_ps.tar.gz.32" CACHE STRING "" FORCE) + SET(BOX_PS_URL "data-im.baidu.com:/home/work/var/CI_DATA/im/static/box_ps.tar.gz/box_ps.tar.gz.36" CACHE STRING "" FORCE) ELSE() #cuda11.4 - SET(BOX_PS_URL "data-im.baidu.com:/home/work/var/CI_DATA/im/static/box_ps.tar.gz/box_ps.tar.gz.31" CACHE STRING "" FORCE) + SET(BOX_PS_URL "data-im.baidu.com:/home/work/var/CI_DATA/im/static/box_ps.tar.gz/box_ps.tar.gz.35" CACHE STRING "" FORCE) ENDIF() ENDIF() MESSAGE(STATUS "BOX_PS_NAME: ${BOX_PS_NAME}, BOX_PS_URL: ${BOX_PS_URL}")