diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index c24f44c45e3b7..60278eb4266c1 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -1205,7 +1205,7 @@ std::pair GraphTable::parse_node_file( auto node = feature_shards[idx][index]->add_feature_node(id, false); if (node != NULL) { node->set_feature_size(feat_name[idx].size()); - for (int i = 1; i < n; ++i) { + for (int i = 1; i < num; ++i) { auto &v = vals[i]; parse_feature(idx, v.ptr, v.len, node); } diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 221915fc713a8..134fd1f2760c0 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -145,6 +145,24 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { int optimizer_type_; }; + struct CommonPullValue { + /* + float show; + float click; + float embed_w; + float mf_size + std::vector embedx_w; + */ + __host__ __device__ int ShowIndex() { return 0; } + __host__ __device__ int ClickIndex() { return 1; } + __host__ __device__ int EmbedWIndex() { return 2; } + __host__ __device__ int MfSizeIndex() { return 3; } // actual mf size (ex. 0) + __host__ __device__ int EmbedxWIndex() { return 4; } + __host__ __device__ int Size(const int mf_dim) { + return (4 + mf_dim) * sizeof(float); + } + }; + struct CommonPushValue { /* float slot; @@ -251,6 +269,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { public: CommonFeatureValue common_feature_value; CommonPushValue common_push_value; + CommonPullValue common_pull_value; }; diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index c430dfa669c45..5ef37b80c0b2d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -90,37 +90,24 @@ __global__ void dy_mf_search_kernel(Table* table, // return; if (i < len) { auto it = table->find(keys[i]); - if (it != table->end()) { uint64_t offset = i * pull_feature_value_size; float* cur = (float*)(vals + offset); float* input = it->second; - int mf_dim = int(input[feature_value_accessor.common_feature_value.MfDimIndex()]); - - *(reinterpret_cast(cur + feature_value_accessor.common_feature_value.CpuPtrIndex())) = - *(reinterpret_cast(input + feature_value_accessor.common_feature_value.CpuPtrIndex())); - cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = - input[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; - cur[feature_value_accessor.common_feature_value.ShowIndex()] = - input[feature_value_accessor.common_feature_value.ShowIndex()]; - cur[feature_value_accessor.common_feature_value.ClickIndex()] = - input[feature_value_accessor.common_feature_value.ClickIndex()]; - cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = - input[feature_value_accessor.common_feature_value.EmbedWIndex()]; - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + x] = - input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + x]; - } - cur[feature_value_accessor.common_feature_value.SlotIndex()] = - input[feature_value_accessor.common_feature_value.SlotIndex()]; - cur[feature_value_accessor.common_feature_value.MfDimIndex()] = - input[feature_value_accessor.common_feature_value.MfDimIndex()]; - cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = - input[feature_value_accessor.common_feature_value.MfSizeIndex()]; - - for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex(); - x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ - cur[x] = input[x]; + + cur[feature_value_accessor.common_pull_value.ShowIndex()] = + input[feature_value_accessor.common_feature_value.ShowIndex()]; + cur[feature_value_accessor.common_pull_value.ClickIndex()] = + input[feature_value_accessor.common_feature_value.ClickIndex()]; + cur[feature_value_accessor.common_pull_value.EmbedWIndex()] = + input[feature_value_accessor.common_feature_value.EmbedWIndex()]; + int embedx_dim = int(input[feature_value_accessor.common_feature_value.MfSizeIndex()]); + cur[feature_value_accessor.common_pull_value.MfSizeIndex()] = embedx_dim; + + int embedx_off = feature_value_accessor.common_pull_value.EmbedxWIndex(); + int value_off = feature_value_accessor.common_feature_value.EmbedxWOffsetIndex(input); + for (int i = 0; i < embedx_dim; ++i) { + cur[embedx_off + i] = input[value_off + i]; } } } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index f10b59edf4d77..72dfb476efe0a 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include #include + #include "cub/cub.cuh" #include "cub/util_allocator.cuh" #if defined(PADDLE_WITH_CUDA) @@ -25,6 +26,7 @@ limitations under the License. */ #include "thrust/pair.h" #elif defined(PADDLE_WITH_XPU_KP) #include + #include "paddle/fluid/platform/device/xpu/enforce_xpu.h" #endif @@ -47,7 +49,7 @@ template class HeterComm { public: HeterComm(size_t capacity, std::shared_ptr resource); - HeterComm(size_t capacity, std::shared_ptr resource, + HeterComm(size_t capacity, std::shared_ptr resource, CommonFeatureValueAccessor& accessor); virtual ~HeterComm(); HeterComm(const HeterComm&) = delete; @@ -61,18 +63,19 @@ class HeterComm { uint32_t* d_restore_idx, size_t & uniq_len); void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, - int& uniq_len); // NOLINT + int& uniq_len); // NOLINT void dynamic_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads, - size_t len, int& uniq_len, size_t& segment_len, bool enable_segment_merge_grad); + size_t len, int& uniq_len, size_t& segment_len, + bool enable_segment_merge_grad); void segment_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads, - const uint32_t* d_index, size_t len, - const uint32_t* d_fea_num_info, - size_t uniq_len, size_t& segment_len); + const uint32_t* d_index, size_t len, + const uint32_t* d_fea_num_info, size_t uniq_len, + size_t& segment_len); void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len); void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len, - size_t chunk_size, int stream_num, int offset = -1); + size_t chunk_size, int stream_num, int offset = -1); void build_ps(int num, KeyType* h_keys, char* pool, size_t len, - size_t feature_value_size, size_t chunk_size, int stream_num); + size_t feature_value_size, size_t chunk_size, int stream_num); void dump(); void show_one_table(int gpu_num); void show_table_collisions(); @@ -124,7 +127,7 @@ class HeterComm { } void set_accessor(CommonFeatureValueAccessor& accessor) { - feature_value_accessor_ = accessor; + feature_value_accessor_ = accessor; } #endif @@ -137,6 +140,19 @@ class HeterComm { int get_transfer_devid(int send_id) { return (send_id + 4) % 8; } void end_pass(); +#if defined(PADDLE_WITH_CUDA) + // dedup + int dedup_keys_and_fillidx(const int gpu_id, + const int total_fea_num, + const KeyType* d_keys, // input + KeyType* d_merged_keys, // output + KeyType* d_sorted_keys, + uint32_t* d_restore_idx, + uint32_t* d_sorted_idx, + uint32_t* d_offset, + uint32_t* d_merged_cnts, + bool filter_zero); +#endif struct Node { ppStream in_stream; @@ -243,7 +259,9 @@ class HeterComm { ValType* src_val); void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right, char* src_val, size_t val_size); - + protected: + void pull_merge_sparse(int num, KeyType* d_keys, float* d_vals, size_t len); + void pull_normal_sparse(int num, KeyType* d_keys, float* d_vals, size_t len); protected: using Table = HashTable; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 68f73531928a5..e97bcb561675b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -14,10 +14,11 @@ limitations under the License. */ #pragma once #ifdef PADDLE_WITH_HETERPS #include + #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" +#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h" #ifdef PADDLE_WITH_XPU_KP #include "paddle/fluid/platform/device/xpu/xpu_info.h" #endif @@ -26,6 +27,7 @@ DECLARE_double(gpugraph_hbm_table_load_factor); DECLARE_bool(gpugraph_enable_gpu_direct_access); DECLARE_bool(gpugraph_enable_segment_merge_grads); DECLARE_uint64(gpugraph_merge_grads_segment_size); +DECLARE_int32(gpugraph_dedup_pull_push_mode); namespace paddle { namespace framework { @@ -79,24 +81,28 @@ HeterComm::HeterComm( } else { max_mf_dim_ = resource_->max_mf_dim(); feature_value_accessor_ = feature_value_accessor; - size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); - size_t grad_type_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); - VLOG(0) << " HeterComm init, max feature_value_size:" << val_type_size - << ", feature_value_push_size:" << grad_type_size; + size_t val_type_size = TYPEALIGN( + 8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); + size_t grad_type_size = TYPEALIGN( + 8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + size_t pull_type_size = feature_value_accessor_.common_pull_value.Size(max_mf_dim_); + VLOG(0) << " HeterComm init, max feature_value_size:" << val_type_size + << ", feature_value_push_size:" << grad_type_size + << ", feature_pull_type_size:" << pull_type_size; auto ptr_table = new PtrTable(capacity / load_factor_); ptr_table->set_accessor(feature_value_accessor_); - ptr_table->set_feature_value_size(val_type_size, grad_type_size); + ptr_table->set_feature_value_size(pull_type_size, grad_type_size); ptr_tables_.push_back(ptr_table); } if (multi_node_) { storage_[i].init(feanum_, resource_->dev_id(i)); } } - heter_comm_kernel_ = std::make_unique(block_size_, feature_value_accessor_); + heter_comm_kernel_ = + std::make_unique(block_size_, feature_value_accessor_); init_path(); } - template void HeterComm::init_path() { int total_device = resource_->total_device(); @@ -307,36 +313,37 @@ void HeterComm::walk_to_dest( auto& node = path_[start_index][i].nodes_[0]; CopyTask t(&path_[start_index][i], 0); que.push(t); - CUDA_CHECK(cudaMemcpyAsync(node.key_storage, - reinterpret_cast(src_key + h_left[i]), - node.key_bytes_len, cudaMemcpyDefault, node.in_stream)); + CUDA_CHECK(cudaMemcpyAsync( + node.key_storage, reinterpret_cast(src_key + h_left[i]), + node.key_bytes_len, cudaMemcpyDefault, node.in_stream)); if (need_copy_val) { - CUDA_CHECK(cudaMemcpyAsync(node.val_storage, - src_val + uint64_t(h_left[i]) * uint64_t(val_size), - node.val_bytes_len, cudaMemcpyDefault, node.in_stream)); + CUDA_CHECK(cudaMemcpyAsync( + node.val_storage, src_val + uint64_t(h_left[i]) * uint64_t(val_size), + node.val_bytes_len, cudaMemcpyDefault, node.in_stream)); } } while (!que.empty()) { CopyTask& cur_task = que.front(); que.pop(); if (cur_task.path->nodes_[cur_task.step].sync) { - CUDA_CHECK(cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream)); + CUDA_CHECK(cudaStreamSynchronize( + cur_task.path->nodes_[cur_task.step].in_stream)); } if (cur_task.step != cur_task.path->nodes_.size() - 1) { int cur_step = cur_task.step; CopyTask c(cur_task.path, cur_step + 1); que.push(c); - CUDA_CHECK(cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage, - cur_task.path->nodes_[cur_step].key_storage, - cur_task.path->nodes_[cur_step + 1].key_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step + 1].in_stream)); + CUDA_CHECK(cudaMemcpyAsync( + cur_task.path->nodes_[cur_step + 1].key_storage, + cur_task.path->nodes_[cur_step].key_storage, + cur_task.path->nodes_[cur_step + 1].key_bytes_len, cudaMemcpyDefault, + cur_task.path->nodes_[cur_step + 1].in_stream)); if (need_copy_val) { - CUDA_CHECK(cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage, - cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step + 1].val_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step + 1].in_stream)); + CUDA_CHECK(cudaMemcpyAsync( + cur_task.path->nodes_[cur_step + 1].val_storage, + cur_task.path->nodes_[cur_step].val_storage, + cur_task.path->nodes_[cur_step + 1].val_bytes_len, + cudaMemcpyDefault, cur_task.path->nodes_[cur_step + 1].in_stream)); } } } @@ -355,16 +362,17 @@ void HeterComm::walk_to_src( auto& node = path_[start_index][i].nodes_[cur_step]; if (cur_step == 0) { CUDA_CHECK(cudaMemcpyAsync(src_val + uint64_t(h_left[i]) * val_size, - node.val_storage, node.val_bytes_len, cudaMemcpyDefault, - node.out_stream)); + node.val_storage, node.val_bytes_len, + cudaMemcpyDefault, node.out_stream)); } else { CopyTask t(&path_[start_index][i], cur_step - 1); que.push(t); - CUDA_CHECK(cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage, - node.val_storage, - path_[start_index][i].nodes_[cur_step - 1].val_bytes_len, - cudaMemcpyDefault, - path_[start_index][i].nodes_[cur_step - 1].out_stream)); + CUDA_CHECK(cudaMemcpyAsync( + path_[start_index][i].nodes_[cur_step - 1].val_storage, + node.val_storage, + path_[start_index][i].nodes_[cur_step - 1].val_bytes_len, + cudaMemcpyDefault, + path_[start_index][i].nodes_[cur_step - 1].out_stream)); } } while (!que.empty()) { @@ -377,18 +385,18 @@ void HeterComm::walk_to_src( if (cur_step > 0) { CopyTask c(cur_task.path, cur_step - 1); que.push(c); - CUDA_CHECK(cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage, - cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step - 1].val_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step - 1].out_stream)); + CUDA_CHECK(cudaMemcpyAsync( + cur_task.path->nodes_[cur_step - 1].val_storage, + cur_task.path->nodes_[cur_step].val_storage, + cur_task.path->nodes_[cur_step - 1].val_bytes_len, cudaMemcpyDefault, + cur_task.path->nodes_[cur_step - 1].out_stream)); } else if (cur_step == 0) { int end_index = cur_task.path->nodes_.back().dev_num; - CUDA_CHECK(cudaMemcpyAsync(src_val + uint64_t(h_left[end_index]) * val_size, - cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step].val_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step].out_stream)); + CUDA_CHECK(cudaMemcpyAsync( + src_val + uint64_t(h_left[end_index]) * val_size, + cur_task.path->nodes_[cur_step].val_storage, + cur_task.path->nodes_[cur_step].val_bytes_len, cudaMemcpyDefault, + cur_task.path->nodes_[cur_step].out_stream)); } } } @@ -514,8 +522,8 @@ void HeterComm::build_ps( if (offset == -1) offset = dev_num; tables_[offset]->insert( reinterpret_cast(d_key_bufs[cur_stream]->ptr()), - reinterpret_cast(d_val_bufs[cur_stream]->ptr()), (size_t)tmp_len, - cur_use_stream); + reinterpret_cast(d_val_bufs[cur_stream]->ptr()), + (size_t)tmp_len, cur_use_stream); cur_stream += 1; cur_len += tmp_len; @@ -622,8 +630,8 @@ void HeterComm::merge_grad( template void HeterComm::dynamic_merge_grad( - int gpu_num, KeyType* d_keys, float* d_grads, size_t len, - int& uniq_len, size_t& segment_len, bool enable_segment_merge_grad) { + int gpu_num, KeyType* d_keys, float* d_grads, size_t len, int& uniq_len, + size_t& segment_len, bool enable_segment_merge_grad) { int dev_id = resource_->dev_id(gpu_num); platform::CUDAPlace place = platform::CUDAPlace(dev_id); platform::CUDADeviceGuard guard(dev_id); @@ -631,7 +639,8 @@ void HeterComm::dynamic_merge_grad( size_t temp_storage_bytes; size_t grad_dim = max_mf_dim_; - size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + size_t grad_value_size = + TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); @@ -684,39 +693,40 @@ void HeterComm::dynamic_merge_grad( PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); if (enable_segment_merge_grad) { - segment_merge_grad( - gpu_num, - d_merge_keys_ptr, d_grads, d_index, len, - d_fea_num_info_ptr, uniq_len, - segment_len); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_keys, d_merge_keys_ptr, - sizeof(KeyType) * segment_len, - cudaMemcpyDeviceToDevice, stream)); + segment_merge_grad(gpu_num, d_merge_keys_ptr, d_grads, d_index, len, + d_fea_num_info_ptr, uniq_len, segment_len); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemcpyAsync(d_keys, d_merge_keys_ptr, sizeof(KeyType) * segment_len, + cudaMemcpyDeviceToDevice, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); } else { auto d_merge_grads = memory::Alloc(place, len * grad_value_size); float* d_merge_grads_ptr = reinterpret_cast(d_merge_grads->ptr()); heter_comm_kernel_->merge_gradient( - d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads, - (char*)d_merge_grads_ptr, uniq_len, grad_dim, grad_value_size, merger_, stream); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr, - grad_value_size * uniq_len, - cudaMemcpyDeviceToDevice, stream)); + d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads, + (char*)d_merge_grads_ptr, uniq_len, grad_dim, grad_value_size, merger_, + stream); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemcpyAsync(d_grads, d_merge_grads_ptr, grad_value_size * uniq_len, + cudaMemcpyDeviceToDevice, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); } } template void HeterComm::segment_merge_grad( - int gpu_num, // the device number - KeyType* d_keys, // the sorted keys list, which will be modified after merged - float* d_grads, // the raw grads list, which will be modified after merged - const uint32_t* d_index, // the storage position of d_keys, its length is len. - size_t len, // the number of raw input keys - const uint32_t* d_fea_num_info, // prefix sum array, its length is uniq_len+1 - size_t uniq_len, // the number of unique keys - size_t& segments_num) { // the number of segment merged keys + int gpu_num, // the device number + KeyType* + d_keys, // the sorted keys list, which will be modified after merged + float* d_grads, // the raw grads list, which will be modified after merged + const uint32_t* + d_index, // the storage position of d_keys, its length is len. + size_t len, // the number of raw input keys + const uint32_t* + d_fea_num_info, // prefix sum array, its length is uniq_len+1 + size_t uniq_len, // the number of unique keys + size_t& segments_num) { // the number of segment merged keys int dev_id = resource_->dev_id(gpu_num); platform::CUDAPlace place = platform::CUDAPlace(dev_id); @@ -724,7 +734,8 @@ void HeterComm::segment_merge_grad( auto stream = resource_->local_stream(gpu_num, 0); auto grad_dim = max_mf_dim_; - auto grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + auto grad_value_size = + TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); auto d_buffer1 = memory::Alloc(place, sizeof(uint32_t) * len); auto d_segments = reinterpret_cast(d_buffer1->ptr()); @@ -733,35 +744,32 @@ void HeterComm::segment_merge_grad( auto d_buffer3 = memory::Alloc(place, sizeof(uint32_t) * len); auto d_segments_fea_num_info = reinterpret_cast(d_buffer3->ptr()); auto d_buffer4 = memory::Alloc(place, sizeof(uint32_t) * len); - auto d_segments_fea_num_offset = reinterpret_cast(d_buffer4->ptr()); + auto d_segments_fea_num_offset = + reinterpret_cast(d_buffer4->ptr()); auto d_buffer5 = memory::Alloc(place, sizeof(uint32_t)); auto d_segments_num = reinterpret_cast(d_buffer5->ptr()); CUDA_CHECK(cudaMemsetAsync(d_segments_num, 0, sizeof(uint32_t), stream)); uint32_t segment_size = FLAGS_gpugraph_merge_grads_segment_size; - heter_comm_kernel_->split_segments( - d_fea_num_info, uniq_len, - d_segments, - d_segments_num, - segment_size, stream); + heter_comm_kernel_->split_segments(d_fea_num_info, uniq_len, d_segments, + d_segments_num, segment_size, stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); size_t temp_storage_bytes = 0; PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Sum( - NULL, temp_storage_bytes, d_segments, d_segments_num, - uniq_len, stream)); + NULL, temp_storage_bytes, d_segments, d_segments_num, uniq_len, stream)); auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Sum( - d_temp_storage->ptr(), temp_storage_bytes, d_segments, d_segments_num, - uniq_len, stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + cub::DeviceReduce::Sum(d_temp_storage->ptr(), temp_storage_bytes, + d_segments, d_segments_num, uniq_len, stream)); CUDA_CHECK(cudaMemcpyAsync(&segments_num, d_segments_num, sizeof(uint32_t), - cudaMemcpyDeviceToHost, stream)); + cudaMemcpyDeviceToHost, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); temp_storage_bytes = 0; - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( - NULL, temp_storage_bytes, d_segments, d_segments_offset, - uniq_len, stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + cub::DeviceScan::ExclusiveSum(NULL, temp_storage_bytes, d_segments, + d_segments_offset, uniq_len, stream)); if (d_temp_storage->size() < temp_storage_bytes) { d_temp_storage = NULL; d_temp_storage = memory::Alloc(place, temp_storage_bytes); @@ -771,46 +779,43 @@ void HeterComm::segment_merge_grad( uniq_len, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); - heter_comm_kernel_->expand_segments( - d_fea_num_info, - d_segments_offset, uniq_len, - d_segments_fea_num_info, segment_size, stream); + heter_comm_kernel_->expand_segments(d_fea_num_info, d_segments_offset, + uniq_len, d_segments_fea_num_info, + segment_size, stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( - NULL, temp_storage_bytes, d_segments_fea_num_info, d_segments_fea_num_offset, - segments_num, stream)); + NULL, temp_storage_bytes, d_segments_fea_num_info, + d_segments_fea_num_offset, segments_num, stream)); if (d_temp_storage->size() < temp_storage_bytes) { d_temp_storage = NULL; d_temp_storage = memory::Alloc(place, temp_storage_bytes); } PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( - d_temp_storage->ptr(), temp_storage_bytes, d_segments_fea_num_info, d_segments_fea_num_offset, - segments_num, stream)); + d_temp_storage->ptr(), temp_storage_bytes, d_segments_fea_num_info, + d_segments_fea_num_offset, segments_num, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); auto d_segments_keys = memory::Alloc(place, sizeof(KeyType) * segments_num); auto d_segments_keys_ptr = reinterpret_cast(d_segments_keys->ptr()); - heter_comm_kernel_->shrink_keys( - d_keys, d_segments_fea_num_offset, - d_segments_keys_ptr, segments_num, - stream); + heter_comm_kernel_->shrink_keys(d_keys, d_segments_fea_num_offset, + d_segments_keys_ptr, segments_num, stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); auto d_segment_grads = memory::Alloc(place, segments_num * grad_value_size); auto d_segment_grads_ptr = reinterpret_cast(d_segment_grads->ptr()); heter_comm_kernel_->merge_gradient( - d_segments_keys_ptr, d_segments_fea_num_offset, d_segments_fea_num_info, d_index, - (char*)d_grads, (char*)d_segment_grads_ptr, segments_num, - grad_dim, grad_value_size, merger_, stream); + d_segments_keys_ptr, d_segments_fea_num_offset, d_segments_fea_num_info, + d_index, (char*)d_grads, (char*)d_segment_grads_ptr, segments_num, + grad_dim, grad_value_size, merger_, stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_keys, d_segments_keys_ptr, - sizeof(KeyType) * segments_num, - cudaMemcpyDeviceToDevice, stream)); + sizeof(KeyType) * segments_num, + cudaMemcpyDeviceToDevice, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_segment_grads_ptr, - grad_value_size * segments_num, - cudaMemcpyDeviceToDevice, stream)); + grad_value_size * segments_num, + cudaMemcpyDeviceToDevice, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); } @@ -915,20 +920,14 @@ void HeterComm::merge_keys( PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); heter_comm_kernel_->fill_restore_idx( - d_index, d_offset, d_fea_num_info_ptr, d_merged_keys, uniq_len, - d_restore_idx, stream); + true, len, uniq_len, d_merged_keys, d_index, d_offset, + d_fea_num_info_ptr, d_restore_idx, stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); } template -void HeterComm::pull_sparse(int num, - KeyType* d_keys, - float* d_vals, - size_t len) { - if (len == 0) { - return; - } - +void HeterComm::pull_merge_sparse( + int num, KeyType* d_keys, float* d_vals, size_t len) { int total_device = resource_->total_device(); int dev_id = resource_->dev_id(num); DevPlace place = DevPlace(dev_id); @@ -964,7 +963,7 @@ void HeterComm::pull_sparse(int num, XPUAPIErrorMsg[r2])); #endif - size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); + size_t val_type_size = feature_value_accessor_.common_pull_value.Size(max_mf_dim_); VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; auto d_sorted_keys = memory::Alloc(place, len * sizeof(KeyType)); auto d_sorted_keys_ptr = reinterpret_cast(d_sorted_keys->ptr()); @@ -1076,6 +1075,144 @@ void HeterComm::pull_sparse(int num, } } } +template +void HeterComm::pull_normal_sparse( + int num, KeyType* d_keys, float* d_vals, size_t len) { + int total_device = resource_->total_device(); + int dev_id = resource_->dev_id(num); + DevPlace place = DevPlace(dev_id); + AnyDeviceGuard guard(dev_id); + auto stream = resource_->local_stream(num, 0); + + int h_left[total_device]; // NOLINT + int h_right[total_device]; // NOLINT + + auto d_left = memory::Alloc(place, total_device * sizeof(int)); + auto d_right = memory::Alloc(place, total_device * sizeof(int)); + int* d_left_ptr = reinterpret_cast(d_left->ptr()); + int* d_right_ptr = reinterpret_cast(d_right->ptr()); + +#if defined(PADDLE_WITH_CUDA) + cudaMemsetAsync(d_left_ptr, -1, total_device * sizeof(int), stream); + cudaMemsetAsync(d_right_ptr, -1, total_device * sizeof(int), stream); + +#elif defined(PADDLE_WITH_XPU_KP) + // get XPUDeviceContext according to xpu place + paddle::platform::XPUDeviceContext xpu_dev_ctx(place); + auto xpu_context = xpu_dev_ctx.x_context(); + + int r = xpu::constant(xpu_context, d_left_ptr, total_device, -1); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU constant kernel return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); + int r2 = xpu::constant(xpu_context, d_right_ptr, total_device, -1); + PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, + platform::errors::External( + "XPU constant kernel return wrong value[%d %s]", r2, + XPUAPIErrorMsg[r2])); +#endif + + auto d_idx = memory::Alloc(place, len * sizeof(int)); + int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); + + size_t val_type_size = feature_value_accessor_.common_pull_value.Size(max_mf_dim_); + VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; + auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); + KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); + auto d_shard_vals = memory::Alloc(place, len * val_type_size); + float* d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); + + split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num); + + heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, d_keys, d_idx_ptr, len, + stream); + + sync_stream(stream); + + auto dst_place = platform::CPUPlace(); + auto src_place = place; + + memory_copy(dst_place, h_left, src_place, d_left_ptr, + total_device * sizeof(int), stream); + memory_copy(dst_place, h_right, src_place, d_right_ptr, + total_device * sizeof(int), stream); + + if (!FLAGS_gpugraph_enable_gpu_direct_access) { + for (int i = 0; i < total_device; ++i) { + int shard_len = h_right[i] - h_left[i] + 1; + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + create_storage(num, i, shard_len * sizeof(KeyType), + shard_len * val_type_size); + } + walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL); + } + for (int i = 0; i < total_device; ++i) { + if (h_left[i] == -1) { + continue; + } + auto& node = path_[num][i].nodes_.back(); + if (!FLAGS_gpugraph_enable_gpu_direct_access) { + sync_stream(node.in_stream); + } + AnyDeviceGuard guard(resource_->dev_id(i)); + ptr_tables_[i]->rwlock_->RDLock(); + if (!FLAGS_gpugraph_enable_gpu_direct_access) { + ptr_tables_[i]->get(reinterpret_cast(node.key_storage), + node.val_storage, h_right[i] - h_left[i] + 1, + resource_->remote_stream(i, num)); + } else { + ptr_tables_[i]->get( + d_shard_keys_ptr + h_left[i], + reinterpret_cast(d_shard_vals_ptr) + h_left[i] * val_type_size, + h_right[i] - h_left[i] + 1, resource_->remote_stream(i, num)); + } + } + + for (int i = 0; i < total_device; ++i) { + sync_stream(resource_->remote_stream(i, num)); + if (h_left[i] == -1) { + continue; + } + ptr_tables_[i]->rwlock_->UNLock(); + } + if (!FLAGS_gpugraph_enable_gpu_direct_access) { + walk_to_src(num, total_device, h_left, h_right, + reinterpret_cast(d_shard_vals_ptr), val_type_size); + for (int i = 0; i < total_device; ++i) { + auto& node = path_[num][i].nodes_.front(); + sync_stream(node.out_stream); + } + } + + heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len, + val_type_size, stream); + + sync_stream(stream); + if (!FLAGS_gpugraph_enable_gpu_direct_access) { + for (int i = 0; i < total_device; ++i) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + destroy_storage(num, i); + } + } +} + +template +void HeterComm::pull_sparse( + int num, KeyType* d_keys, float* d_vals, size_t len) { + if (len == 0) { + return; + } + if (!FLAGS_gpugraph_dedup_pull_push_mode) { + pull_merge_sparse(num, d_keys, d_vals, len); + } else { + pull_normal_sparse(num, d_keys, d_vals, len); + } +} #if defined(PADDLE_WITH_CUDA) template @@ -1093,7 +1230,7 @@ void HeterComm::push_sparse(int dev_num, int dev_id = resource_->dev_id(dev_num); size_t grad_value_size = - TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); DevPlace place = DevPlace(dev_id); AnyDeviceGuard guard(dev_id); auto stream = resource_->local_stream(dev_num, 0); @@ -1138,27 +1275,32 @@ void HeterComm::push_sparse(int dev_num, d_shard_grads_ptr = reinterpret_cast(d_shard_grads->ptr()); int uniq_len = len; - size_t segment_len = 0; - if (FLAGS_gpugraph_enable_segment_merge_grads) { - // do two gradient merge - // 1st. do segmented gradient merge - // 2nd. do global gradient merge - dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, true); - len = segment_len; - uniq_len = 0; - segment_len = 0; - dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, false); - } else { - // Perform gradient merge only once - dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, false); + if (!FLAGS_gpugraph_dedup_pull_push_mode) { + size_t segment_len = 0; + if (FLAGS_gpugraph_enable_segment_merge_grads) { + // do two gradient merge + // 1st. do segmented gradient merge + // 2nd. do global gradient merge + dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, + true); + len = segment_len; + uniq_len = 0; + segment_len = 0; + dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, + false); + } else { + // Perform gradient merge only once + dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, + false); + } } split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, dev_num); - + heter_comm_kernel_->dy_mf_fill_shard_grads( - d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr, - uniq_len, grad_value_size, stream); + d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr, uniq_len, + grad_value_size, stream); sync_stream(stream); @@ -1176,11 +1318,11 @@ void HeterComm::push_sparse(int dev_num, continue; } create_storage(dev_num, i, shard_len * sizeof(KeyType), - shard_len * grad_value_size); + shard_len * grad_value_size); } walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr, - reinterpret_cast(d_shard_grads_ptr), grad_value_size); + reinterpret_cast(d_shard_grads_ptr), grad_value_size); } for (int i = 0; i < total_device; ++i) { @@ -1196,14 +1338,14 @@ void HeterComm::push_sparse(int dev_num, ptr_tables_[i]->rwlock_->WRLock(); if (!FLAGS_gpugraph_enable_gpu_direct_access) { ptr_tables_[i]->update(reinterpret_cast(node.key_storage), - node.val_storage, h_right[i] - h_left[i] + 1, - sgd, resource_->remote_stream(i, dev_num)); + node.val_storage, h_right[i] - h_left[i] + 1, sgd, + resource_->remote_stream(i, dev_num)); } else { ptr_tables_[i]->update(d_shard_keys_ptr + h_left[i], - reinterpret_cast(d_shard_grads_ptr) + - grad_value_size * h_left[i], - h_right[i] - h_left[i] + 1, sgd, - resource_->remote_stream(i, dev_num)); + reinterpret_cast(d_shard_grads_ptr) + + grad_value_size * h_left[i], + h_right[i] - h_left[i] + 1, sgd, + resource_->remote_stream(i, dev_num)); } } @@ -1217,7 +1359,7 @@ void HeterComm::push_sparse(int dev_num, } } } - + if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { @@ -1555,6 +1697,84 @@ void HeterComm::end_pass() { } } +#if defined(PADDLE_WITH_CUDA) +template +int HeterComm::dedup_keys_and_fillidx( + const int gpu_id, + const int total_fea_num, + const KeyType* d_keys, // input + KeyType* d_merged_keys, // output + KeyType* d_sorted_keys, + uint32_t* d_restore_idx, + uint32_t* d_sorted_idx, + uint32_t* d_offset, + uint32_t* d_merged_cnts, + bool filter_zero) { + int dev_id = resource_->dev_id(gpu_id); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->local_stream(gpu_id, 0); + + assert(total_fea_num > 0); + int merged_size = 0; + size_t byte_size = sizeof(uint32_t) * (total_fea_num + 1); + + auto d_index_ptr = memory::Alloc(place, byte_size); + uint32_t* d_index_in = reinterpret_cast(d_index_ptr->ptr()); + int* d_merged_size = reinterpret_cast(&d_index_in[total_fea_num]); + + heter_comm_kernel_->fill_idx(d_index_in, total_fea_num, stream); + + void* d_buf = NULL; + size_t temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + NULL, temp_storage_bytes, d_keys, d_sorted_keys, d_index_in, d_sorted_idx, + total_fea_num, 0, 8 * sizeof(KeyType), stream, false)); + auto d_cache_ptr = memory::Alloc(place, temp_storage_bytes); + d_buf = reinterpret_cast(d_cache_ptr->ptr()); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + d_buf, temp_storage_bytes, d_keys, d_sorted_keys, d_index_in, + d_sorted_idx, total_fea_num, 0, 8 * sizeof(KeyType), stream, false)); + + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode( + NULL, temp_storage_bytes, d_sorted_keys, d_merged_keys, d_merged_cnts, + d_merged_size, total_fea_num, stream)); + if (d_cache_ptr->size() < temp_storage_bytes) { + d_cache_ptr = NULL; + d_cache_ptr = memory::Alloc(place, temp_storage_bytes); + } + d_buf = reinterpret_cast(d_cache_ptr->ptr()); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode( + d_buf, temp_storage_bytes, d_sorted_keys, d_merged_keys, d_merged_cnts, + d_merged_size, total_fea_num, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync((void*)&merged_size, + (void*)d_merged_size, sizeof(int), + cudaMemcpyDeviceToHost, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + NULL, temp_storage_bytes, d_merged_cnts, d_offset, merged_size, stream)); + if (d_cache_ptr->size() < temp_storage_bytes) { + d_cache_ptr = NULL; + d_cache_ptr = memory::Alloc(place, temp_storage_bytes); + } + d_buf = reinterpret_cast(d_cache_ptr->ptr()); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + d_buf, temp_storage_bytes, d_merged_cnts, d_offset, merged_size, stream)); + + if (filter_zero) { + cudaMemsetAsync(d_restore_idx, 0, total_fea_num * sizeof(uint32_t), stream); + } + // fill restore idx [1,3,5,2,4,6] = [1,2,1,3,2,1] + heter_comm_kernel_->fill_restore_idx(filter_zero, + total_fea_num, merged_size, d_merged_keys, d_sorted_idx, + d_offset, d_merged_cnts, d_restore_idx, stream); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + return merged_size; +} +#endif // template // void HeterComm::dump_to_cpu(int index) { // auto stream = resource_->local_stream(index, 0); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index abb5cff60f5f8..8d4f2625be70d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -117,44 +117,12 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals, } } -template -__global__ void dy_mf_fill_shard_grads_kernel( - KeyType* d_shard_keys, KeyType* d_keys, float* d_shard_grads, - float* d_grads, T* idx, size_t len, size_t grad_value_size, - CommonFeatureValueAccessor feature_value_accessor) { - const size_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < len) { - d_shard_keys[i] = d_keys[idx[i]]; - float* cur = (float*)((char*)d_shard_grads + i * grad_value_size); - float* shard_val = (float*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); - - cur[feature_value_accessor.common_push_value.SlotIndex()] = - shard_val[feature_value_accessor.common_push_value.SlotIndex()]; - cur[feature_value_accessor.common_push_value.ShowIndex()] = - shard_val[feature_value_accessor.common_push_value.ShowIndex()]; - cur[feature_value_accessor.common_push_value.ClickIndex()] = - shard_val[feature_value_accessor.common_push_value.ClickIndex()]; - cur[feature_value_accessor.common_push_value.MfDimIndex()] = - shard_val[feature_value_accessor.common_push_value.MfDimIndex()]; - cur[feature_value_accessor.common_push_value.EmbedGIndex()] = - shard_val[feature_value_accessor.common_push_value.EmbedGIndex()]; - - for (int x = 0; x < int(shard_val[feature_value_accessor.common_push_value.MfDimIndex()]); x++) { - cur[feature_value_accessor.common_push_value.EmbedxGIndex() + x] = - shard_val[feature_value_accessor.common_push_value.EmbedxGIndex() + x]; - } - } -} - template -__global__ void merge_gradients_basic_kernel(const KeyType* d_keys, - const uint32_t* offset, - const uint32_t* fea_num, - const uint32_t* index, const char* input, - char* output, int n, - size_t grad_value_size, - DynamicGradMerger& merger, - CommonFeatureValueAccessor& feature_value_accessor) { +__global__ void merge_gradients_basic_kernel( + const KeyType* d_keys, const uint32_t* offset, const uint32_t* fea_num, + const uint32_t* index, const char* input, char* output, int n, + size_t grad_value_size, DynamicGradMerger& merger, + CommonFeatureValueAccessor& feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) { @@ -162,8 +130,7 @@ __global__ void merge_gradients_basic_kernel(const KeyType* d_keys, uint32_t num = fea_num[i]; int ori_index = index[start]; float* out = (float*)(output + i * grad_value_size); - float* in = - (float*)(input + size_t(ori_index) * grad_value_size); + float* in = (float*)(input + size_t(ori_index) * grad_value_size); merger.update_basic(out, in, feature_value_accessor); KeyType key = d_keys[i]; if (key != 0) { @@ -177,15 +144,11 @@ __global__ void merge_gradients_basic_kernel(const KeyType* d_keys, } template -__global__ void merge_gradients_embedx_kernel(const KeyType* d_keys, - const uint32_t* offset, - const uint32_t* fea_num, - const uint32_t* index, const char* input, - char* output, int n, - size_t grad_dim, - size_t grad_value_size, - DynamicGradMerger& merger, - CommonFeatureValueAccessor& feature_value_accessor) { +__global__ void merge_gradients_embedx_kernel( + const KeyType* d_keys, const uint32_t* offset, const uint32_t* fea_num, + const uint32_t* index, const char* input, char* output, int n, + size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger, + CommonFeatureValueAccessor& feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) { @@ -208,10 +171,10 @@ __global__ void merge_gradients_embedx_kernel(const KeyType* d_keys, } } -__global__ void split_segments_kernel( - const uint32_t* d_fea_num_info, size_t n, - uint32_t* d_segments, uint32_t* d_segments_num, - uint32_t segment_size) { +__global__ void split_segments_kernel(const uint32_t* d_fea_num_info, size_t n, + uint32_t* d_segments, + uint32_t* d_segments_num, + uint32_t segment_size) { const size_t tx = blockIdx.x * blockDim.x + threadIdx.x; if (tx >= n) { return; @@ -222,10 +185,11 @@ __global__ void split_segments_kernel( d_segments[tx] = seg_num; } -__global__ void expand_segments_kernel( - const uint32_t* d_fea_num_info, - const uint32_t* d_segments_offset, size_t n, - uint32_t* d_segments_fea_num_info, uint32_t segment_size) { +__global__ void expand_segments_kernel(const uint32_t* d_fea_num_info, + const uint32_t* d_segments_offset, + size_t n, + uint32_t* d_segments_fea_num_info, + uint32_t segment_size) { const size_t tx = blockIdx.x * blockDim.x + threadIdx.x; if (tx >= n) { return; @@ -248,9 +212,9 @@ __global__ void expand_segments_kernel( } template -__global__ void shrink_keys_kernel( - const KeyType* d_keys, const uint32_t* d_segments_offset, - KeyType* d_segments_keys, size_t n) { +__global__ void shrink_keys_kernel(const KeyType* d_keys, + const uint32_t* d_segments_offset, + KeyType* d_segments_keys, size_t n) { const size_t tx = blockIdx.x * blockDim.x + threadIdx.x; if (tx >= n) { return; @@ -259,38 +223,12 @@ __global__ void shrink_keys_kernel( d_segments_keys[tx] = d_keys[d_segments_offset[tx]]; } -template -__global__ void fill_restore_idx_kernel( - const T *d_sorted_idx, - const T *d_offset, - const T *d_merged_cnts, - const KeyType *d_merged_keys, - T *d_restore_idx, - size_t n) { - const size_t tx = blockIdx.x * blockDim.x + threadIdx.x; - if (tx >= n) { - return; - } - - const KeyType & key = d_merged_keys[tx]; - if (key == 0) { - return; - } - - const T &off = d_offset[tx]; - const T &num = d_merged_cnts[tx]; - for (size_t k = 0; k < num; ++k) { - d_restore_idx[d_sorted_idx[off + k]] = tx; - } -} - template __global__ void unpack_merged_vals_kernel( const KeyType* d_keys, const float* d_merged_vals, const uint32_t* d_restored_idx, - float* d_out, size_t val_size, const size_t n, - CommonFeatureValueAccessor feature_value_accessor) { + float* d_out, size_t val_size, const size_t n) { const size_t tx = blockIdx.x * blockDim.x + threadIdx.x; if (tx >= n) { return; @@ -305,69 +243,33 @@ __global__ void unpack_merged_vals_kernel( uint64_t dst_offset = uint64_t(tx) * val_size; float* dst = (float*)((char*)d_out + dst_offset); float* src_val = (float*)((char*)d_merged_vals + uint64_t(src_val_idx) * val_size); - int mf_dim = int(src_val[feature_value_accessor.common_feature_value.MfDimIndex()]); - - *(reinterpret_cast(dst + feature_value_accessor.common_feature_value.CpuPtrIndex())) = - *(reinterpret_cast(src_val + feature_value_accessor.common_feature_value.CpuPtrIndex())); - dst[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = - src_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; - dst[feature_value_accessor.common_feature_value.ShowIndex()] = - src_val[feature_value_accessor.common_feature_value.ShowIndex()]; - dst[feature_value_accessor.common_feature_value.ClickIndex()] = - src_val[feature_value_accessor.common_feature_value.ClickIndex()]; - dst[feature_value_accessor.common_feature_value.EmbedWIndex()] = - src_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; - for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { - dst[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = - src_val[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; + + size_t n_float = val_size / sizeof(float); + for (size_t k = 0; k < n_float; ++k) { + dst[k] = src_val[k]; } - dst[feature_value_accessor.common_feature_value.SlotIndex()] = - src_val[feature_value_accessor.common_feature_value.SlotIndex()]; - dst[feature_value_accessor.common_feature_value.MfDimIndex()] = mf_dim; - dst[feature_value_accessor.common_feature_value.MfSizeIndex()] = - src_val[feature_value_accessor.common_feature_value.MfSizeIndex()]; - - for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex(); - x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ - dst[x] = src_val[x]; +} + +template +__global__ void scatter_dvals_by_unit_kernel(TUnit* d_dest_vals, + const TUnit* d_src_vals, T* idx, + size_t len, size_t val_size_unit) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + size_t pos = idx[i / val_size_unit] * val_size_unit + (i % val_size_unit); + d_dest_vals[i] = d_src_vals[pos]; } } -template -__global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, float* d_vals, - T* idx, size_t len, size_t val_size, - CommonFeatureValueAccessor feature_value_accessor) { +template +__global__ void gather_dvals_by_unit_kernel(TUnit* d_dest_vals, + const TUnit* d_src_vals, T* idx, + size_t len, + const size_t val_size_unit) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { - uint64_t new_offset = uint64_t(idx[i]) * val_size; - float* cur = (float*)((char*)d_vals + new_offset); - float* shard_val = (float*)((char*)d_shard_vals + uint64_t(i) * val_size); - int mf_dim = int(shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]); - - *(reinterpret_cast(cur + feature_value_accessor.common_feature_value.CpuPtrIndex())) = - *(reinterpret_cast(shard_val + feature_value_accessor.common_feature_value.CpuPtrIndex())); - cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = - shard_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; - cur[feature_value_accessor.common_feature_value.ShowIndex()] = - shard_val[feature_value_accessor.common_feature_value.ShowIndex()]; - cur[feature_value_accessor.common_feature_value.ClickIndex()] = - shard_val[feature_value_accessor.common_feature_value.ClickIndex()]; - cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = - shard_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; - for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { - cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = - shard_val[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; - } - cur[feature_value_accessor.common_feature_value.SlotIndex()] = - shard_val[feature_value_accessor.common_feature_value.SlotIndex()]; - cur[feature_value_accessor.common_feature_value.MfDimIndex()] = mf_dim; - cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = - shard_val[feature_value_accessor.common_feature_value.MfSizeIndex()]; - - for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex(); - x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ - cur[x] = shard_val[x]; - } + size_t pos = idx[i / val_size_unit] * val_size_unit + (i % val_size_unit); + d_dest_vals[pos] = d_src_vals[i]; } } @@ -468,26 +370,34 @@ void HeterCommKernel::dy_mf_fill_shard_grads( const StreamType& stream) { int grid_size = (len - 1) / block_size_ + 1; size_t c_len = (size_t)len; - dy_mf_fill_shard_grads_kernel<<>>( - d_shard_keys, d_keys, d_shard_grads, d_grads, idx, c_len, - grad_value_size, feature_value_accessor_); + + const size_t grad_value_size_float = grad_value_size / sizeof(float); + // d_keys to d_shard_keys + fill_shard_key_kernel<<>>( + d_shard_keys, d_keys, idx, c_len); + + CHECK((grad_value_size % sizeof(float)) == 0); + size_t N = len * grad_value_size_float; + grid_size = (N - 1) / block_size_ + 1; + scatter_dvals_by_unit_kernel<<>>( + d_shard_grads, d_grads, idx, N, grad_value_size_float); } template void HeterCommKernel::merge_gradient( - const KeyType* d_keys, - const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, - const char* input, char* output, int n, size_t grad_dim, size_t grad_value_size, - DynamicGradMerger& merger, const StreamType& stream) { + const KeyType* d_keys, const uint32_t* offset, const uint32_t* fea_num, + const uint32_t* index, const char* input, char* output, int n, + size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger, + const StreamType& stream) { int grid_size1 = (n - 1) / block_size_ + 1; merge_gradients_basic_kernel<<>>( - d_keys, - offset, fea_num, index, input, output, n, grad_value_size, merger, feature_value_accessor_); + d_keys, offset, fea_num, index, input, output, n, grad_value_size, merger, + feature_value_accessor_); if (grad_dim > 0) { int grid_size2 = (n * grad_dim - 1) / block_size_ + 1; merge_gradients_embedx_kernel<<>>( - d_keys, - offset, fea_num, index, input, output, n * grad_dim, grad_dim, grad_value_size, merger, feature_value_accessor_); + d_keys, offset, fea_num, index, input, output, n * grad_dim, grad_dim, + grad_value_size, merger, feature_value_accessor_); } } @@ -495,30 +405,37 @@ template void HeterCommKernel::dy_mf_fill_dvals(float* d_shard_vals, float* d_vals, T* idx, long long len, size_t val_size, const StreamType& stream) { - int grid_size = (len - 1) / block_size_ + 1; - size_t c_len = (size_t)len; - dy_mf_fill_dvals_kernel<<>>( - d_shard_vals, d_vals, idx, c_len, val_size, feature_value_accessor_); + const size_t val_size_float = val_size / sizeof(float); + CHECK((val_size % sizeof(float)) == 0); + size_t N = len * val_size_float; + const int grid_size = (N - 1) / block_size_ + 1; + // fill by float, d_shard_vals to d_vals + gather_dvals_by_unit_kernel<<>>( + d_vals, d_shard_vals, idx, N, val_size_float); } template void HeterCommKernel::split_segments(const uint32_t* d_fea_num_info, size_t n, - uint32_t* d_segments, uint32_t* d_segments_num, size_t segment_size, const StreamType& stream) { + uint32_t* d_segments, + uint32_t* d_segments_num, + size_t segment_size, + const StreamType& stream) { int grid_size = (n - 1) / block_size_ + 1; split_segments_kernel<<>>( - d_fea_num_info, n, d_segments, d_segments_num, segment_size); + d_fea_num_info, n, d_segments, d_segments_num, segment_size); } template void HeterCommKernel::expand_segments(const uint32_t* d_fea_num_info, - const uint32_t* d_segments_offset, size_t n, - uint32_t* d_segments_fea_num_info, uint32_t segment_size, - const StreamType& stream) { + const uint32_t* d_segments_offset, + size_t n, + uint32_t* d_segments_fea_num_info, + uint32_t segment_size, + const StreamType& stream) { int grid_size = (n - 1) / block_size_ + 1; expand_segments_kernel<<>>( - d_fea_num_info, - d_segments_offset, n, - d_segments_fea_num_info, segment_size); + d_fea_num_info, d_segments_offset, n, d_segments_fea_num_info, + segment_size); } template @@ -526,19 +443,85 @@ void HeterCommKernel::shrink_keys(const KeyType* d_keys, const uint32_t* d_segme KeyType* d_segments_keys, size_t n, const StreamType& stream) { int grid_size = (n - 1) / block_size_ + 1; shrink_keys_kernel<<>>( - d_keys, d_segments_offset, d_segments_keys, n); + d_keys, d_segments_offset, d_segments_keys, n); +} +template +__global__ void kernel_fill_restore_idx( + const size_t N, const T* d_sorted_idx, + const T* d_offset, const T* d_merged_cnts, T* d_restore_idx) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < N) { + const T& off = d_offset[i]; + const T& num = d_merged_cnts[i]; + for (size_t k = 0; k < num; ++k) { + d_restore_idx[d_sorted_idx[off + k]] = i; + } + } +} +template +__global__ void kernel_fill_restore_idx_filter_zero( + const size_t N, const KeyType *d_keys, const T* d_sorted_idx, + const T* d_offset, const T* d_merged_cnts, T* d_restore_idx) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < N) { + if (d_keys[i] == 0) { + return; + } + const T& off = d_offset[i]; + const T& num = d_merged_cnts[i]; + for (size_t k = 0; k < num; ++k) { + d_restore_idx[d_sorted_idx[off + k]] = i; + } + } +} +template +__global__ void kernel_fill_restore_idx_by_search( + const size_t N, const T* d_sorted_idx, + const size_t merge_num, const T* d_offset, T* d_restore_idx) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < N) { + if (i < d_offset[1]) { + d_restore_idx[d_sorted_idx[i]] = 0; + return; + } + int high = merge_num - 1; + int low = 1; + while (low < high) { + int mid = (low + high) / 2; + if (i < d_offset[mid + 1]) { + high = mid; + } else { + low = mid + 1; + } + } + d_restore_idx[d_sorted_idx[i]] = low; + } } - template void HeterCommKernel::fill_restore_idx( - const uint32_t* d_sorted_idx, const uint32_t* d_offset, - const uint32_t* d_merged_cnts, const KeyType* d_merged_keys, - const size_t n, uint32_t *d_restore_idx, const StreamType& stream) { - int grid_size = (n - 1) / block_size_ + 1; - fill_restore_idx_kernel<<>>( - d_sorted_idx, d_offset, d_merged_cnts, d_merged_keys, d_restore_idx, n); + bool filter_zero, const size_t total_num, + const size_t merge_size, const KeyType *d_keys, + const uint32_t* d_sorted_idx, + const uint32_t* d_offset, const uint32_t* d_merged_cnts, + uint32_t* d_restore_idx, const StreamType& stream) { + // fill restore idx [1,3,5,2,4,6] = [1,2,1,3,2,1] + if (merge_size * 3 > total_num) { + // repetition rate is not very high + size_t grid_size = (merge_size - 1) / block_size_ + 1; + if (filter_zero) { + kernel_fill_restore_idx_filter_zero<<>>( + merge_size, d_keys, d_sorted_idx, d_offset, d_merged_cnts, d_restore_idx); + } else { + kernel_fill_restore_idx<<>>( + merge_size, d_sorted_idx, d_offset, d_merged_cnts, d_restore_idx); + } + } else { + size_t grid_size = (total_num - 1) / block_size_ + 1; + // mid search + kernel_fill_restore_idx_by_search<<>>( + total_num, d_sorted_idx, merge_size, d_offset, d_restore_idx); + } } - template void HeterCommKernel::unpack_merged_vals(size_t n, const KeyType* d_keys, const void* d_merged_vals, const uint32_t* d_restore_idx, @@ -546,7 +529,7 @@ void HeterCommKernel::unpack_merged_vals(size_t n, const KeyType* d_keys, int grid_size = (n - 1) / block_size_ + 1; unpack_merged_vals_kernel<<>>( d_keys, (const float *)d_merged_vals, d_restore_idx, - (float *)d_vals, val_size, n, feature_value_accessor_); + (float *)d_vals, val_size, n); } template void HeterCommKernel::fill_idx( @@ -557,10 +540,10 @@ template void HeterCommKernel::fill_idx( template void HeterCommKernel::calc_shard_offset( int* idx, int* left, int* right, long long len, int total_devs, const cudaStream_t& stream); -template void HeterCommKernel::calc_shard_index< - unsigned long, int, cudaStream_t>(unsigned long* d_keys, long long len, - int* shard_index, int total_devs, - const cudaStream_t& stream); +template void +HeterCommKernel::calc_shard_index( + unsigned long* d_keys, long long len, int* shard_index, int total_devs, + const cudaStream_t& stream); template void HeterCommKernel::calc_shard_index( long* d_keys, long long len, int* shard_index, int total_devs, @@ -574,12 +557,10 @@ template void HeterCommKernel::fill_shard_key( unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len, const cudaStream_t& stream); -template void HeterCommKernel::fill_shard_grads< - unsigned long, float, int, cudaStream_t>( - unsigned long* d_shard_keys, unsigned long* d_keys, - float* d_shard_grads, - float* d_grads, int* idx, long long len, - const cudaStream_t& stream); +template void +HeterCommKernel::fill_shard_grads( + unsigned long* d_shard_keys, unsigned long* d_keys, float* d_shard_grads, + float* d_grads, int* idx, long long len, const cudaStream_t& stream); template void HeterCommKernel::fill_dvals( @@ -614,57 +595,54 @@ template void HeterCommKernel::reduce_by_key< paddle::framework::FeaturePushValue* d_aggregates_out, int* d_num_runs_out, int num_items, cudaStream_t stream, bool debug_synchronous); -template void HeterCommKernel::dy_mf_fill_shard_grads< - unsigned long, int, cudaStream_t>( - unsigned long* d_shard_keys, unsigned long* d_keys, - float* d_shard_grads, float* d_grads, int* idx, long long len, - size_t grad_value_size, const cudaStream_t& stream); +template void +HeterCommKernel::dy_mf_fill_shard_grads( + unsigned long* d_shard_keys, unsigned long* d_keys, float* d_shard_grads, + float* d_grads, int* idx, long long len, size_t grad_value_size, + const cudaStream_t& stream); template void HeterCommKernel::merge_gradient( - const uint32_t* d_keys, - const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, - const char* input, char* output, int n, size_t grad_dim, size_t grad_value_size, - DynamicGradMerger& merger_, const cudaStream_t& stream); + const uint32_t* d_keys, const uint32_t* offset, const uint32_t* fea_num, + const uint32_t* index, const char* input, char* output, int n, + size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger_, + const cudaStream_t& stream); template void HeterCommKernel::merge_gradient( - const uint64_t* d_keys, - const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, - const char* input, char* output, int n, size_t grad_dim, size_t grad_value_size, - DynamicGradMerger& merger_, const cudaStream_t& stream); + const uint64_t* d_keys, const uint32_t* offset, const uint32_t* fea_num, + const uint32_t* index, const char* input, char* output, int n, + size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger_, + const cudaStream_t& stream); template void HeterCommKernel::dy_mf_fill_dvals( - float* d_shard_vals, - float* d_vals, int* idx, long long len, + float* d_shard_vals, float* d_vals, int* idx, long long len, size_t val_size, const cudaStream_t& stream); template void HeterCommKernel::split_segments( - const uint32_t* d_fea_num_info, size_t n, - uint32_t* d_segment, uint32_t* d_segments_num, size_t segment_size, - const cudaStream_t& stream); + const uint32_t* d_fea_num_info, size_t n, uint32_t* d_segment, + uint32_t* d_segments_num, size_t segment_size, const cudaStream_t& stream); template void HeterCommKernel::expand_segments( - const uint32_t* d_fea_num_info, - const uint32_t* d_segments_offset, size_t n, + const uint32_t* d_fea_num_info, const uint32_t* d_segments_offset, size_t n, uint32_t* d_segments_fea_num_info, uint32_t segment_size, const cudaStream_t& stream); template void HeterCommKernel::shrink_keys( - const uint32_t* d_keys, const uint32_t* d_segments_offset, - uint32_t* d_segments_keys, size_t segment_num, const cudaStream_t& stream); + const uint32_t* d_keys, const uint32_t* d_segments_offset, + uint32_t* d_segments_keys, size_t segment_num, const cudaStream_t& stream); template void HeterCommKernel::shrink_keys( - const uint64_t* d_keys, const uint32_t* d_segments, - uint64_t* d_segments_keys, size_t total_segment_num, const cudaStream_t& stream); + const uint64_t* d_keys, const uint32_t* d_segments, + uint64_t* d_segments_keys, size_t total_segment_num, const cudaStream_t& stream); template void HeterCommKernel::fill_restore_idx( - const uint32_t* d_sorted_idx, const uint32_t* d_offset, - const uint32_t* d_merged_cnts, const uint64_t* d_merged_keys, - const size_t n, uint32_t* d_restore_idx, const cudaStream_t& stream); + bool filter_zero, const size_t total_num, const size_t merge_size, const uint64_t *d_keys, + const uint32_t* d_sorted_idx, const uint32_t* d_offset, const uint32_t* d_merged_cnts, + uint32_t* d_restore_idx, const cudaStream_t& stream); template void HeterCommKernel::fill_restore_idx( - const uint32_t* d_sorted_idx, const uint32_t* d_offset, - const uint32_t* d_merged_cnts, const uint32_t* d_merged_keys, - const size_t n, uint32_t* d_restore_idx, const cudaStream_t& stream); + bool filter_zero, const size_t total_num, const size_t merge_size, const uint32_t *d_keys, + const uint32_t* d_sorted_idx, const uint32_t* d_offset, const uint32_t* d_merged_cnts, + uint32_t* d_restore_idx, const cudaStream_t& stream); template void HeterCommKernel::unpack_merged_vals( size_t n, const uint64_t* d_keys, const void* d_merged_vals, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index 1cde86e64b6bc..05d93b1d8bcc0 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -41,40 +41,47 @@ struct DynamicGradMerger { return out; } - __device__ __forceinline__ void update_one(float* output, const float* input, - CommonFeatureValueAccessor& feature_value_accessor) { - output[feature_value_accessor.common_push_value.SlotIndex()] = + __device__ __forceinline__ void update_one( + float* output, const float* input, + CommonFeatureValueAccessor& feature_value_accessor) { + output[feature_value_accessor.common_push_value.SlotIndex()] = input[feature_value_accessor.common_push_value.SlotIndex()]; - output[feature_value_accessor.common_push_value.ShowIndex()] = + output[feature_value_accessor.common_push_value.ShowIndex()] = input[feature_value_accessor.common_push_value.ShowIndex()]; - output[feature_value_accessor.common_push_value.ClickIndex()] = + output[feature_value_accessor.common_push_value.ClickIndex()] = input[feature_value_accessor.common_push_value.ClickIndex()]; - output[feature_value_accessor.common_push_value.MfDimIndex()] = + output[feature_value_accessor.common_push_value.MfDimIndex()] = input[feature_value_accessor.common_push_value.MfDimIndex()]; - output[feature_value_accessor.common_push_value.EmbedGIndex()] = + output[feature_value_accessor.common_push_value.EmbedGIndex()] = input[feature_value_accessor.common_push_value.EmbedGIndex()]; - for (int j = 0; j < int(output[feature_value_accessor.common_push_value.MfDimIndex()]); j++) { - output[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = + for (int j = 0; + j < int(output[feature_value_accessor.common_push_value.MfDimIndex()]); + j++) { + output[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = input[feature_value_accessor.common_push_value.EmbedxGIndex() + j]; } } - __device__ __forceinline__ void merge_one(float* output, const float* input, - CommonFeatureValueAccessor& feature_value_accessor) { - output[feature_value_accessor.common_push_value.ShowIndex()] += + __device__ __forceinline__ void merge_one( + float* output, const float* input, + CommonFeatureValueAccessor& feature_value_accessor) { + output[feature_value_accessor.common_push_value.ShowIndex()] += input[feature_value_accessor.common_push_value.ShowIndex()]; - output[feature_value_accessor.common_push_value.ClickIndex()] += + output[feature_value_accessor.common_push_value.ClickIndex()] += input[feature_value_accessor.common_push_value.ClickIndex()]; - output[feature_value_accessor.common_push_value.EmbedGIndex()] += + output[feature_value_accessor.common_push_value.EmbedGIndex()] += input[feature_value_accessor.common_push_value.EmbedGIndex()]; - for (int j = 0; j < int(output[feature_value_accessor.common_push_value.MfDimIndex()]); j++) { - output[feature_value_accessor.common_push_value.EmbedxGIndex() + j] += + for (int j = 0; + j < int(output[feature_value_accessor.common_push_value.MfDimIndex()]); + j++) { + output[feature_value_accessor.common_push_value.EmbedxGIndex() + j] += input[feature_value_accessor.common_push_value.EmbedxGIndex() + j]; } } - __device__ __forceinline__ void update_basic(float* output, const float* input, - CommonFeatureValueAccessor& fv_accessor) { + __device__ __forceinline__ void update_basic( + float* output, const float* input, + CommonFeatureValueAccessor& fv_accessor) { output[fv_accessor.common_push_value.SlotIndex()] = input[fv_accessor.common_push_value.SlotIndex()]; output[fv_accessor.common_push_value.ShowIndex()] = @@ -87,8 +94,9 @@ struct DynamicGradMerger { input[fv_accessor.common_push_value.EmbedGIndex()]; } - __device__ __forceinline__ void merge_basic(float* output, const float* input, - CommonFeatureValueAccessor& fv_accessor) { + __device__ __forceinline__ void merge_basic( + float* output, const float* input, + CommonFeatureValueAccessor& fv_accessor) { output[fv_accessor.common_push_value.ShowIndex()] += input[fv_accessor.common_push_value.ShowIndex()]; output[fv_accessor.common_push_value.ClickIndex()] += @@ -97,16 +105,18 @@ struct DynamicGradMerger { input[fv_accessor.common_push_value.EmbedGIndex()]; } - __device__ __forceinline__ void update_embedx(float* output, const float* input, size_t embedx_idx, - CommonFeatureValueAccessor& fv_accessor) { + __device__ __forceinline__ void update_embedx( + float* output, const float* input, size_t embedx_idx, + CommonFeatureValueAccessor& fv_accessor) { if (embedx_idx < output[fv_accessor.common_push_value.MfDimIndex()]) { output[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx] = input[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx]; } } - __device__ __forceinline__ void merge_embedx(float* output, const float* input, size_t embedx_idx, - CommonFeatureValueAccessor& fv_accessor) { + __device__ __forceinline__ void merge_embedx( + float* output, const float* input, size_t embedx_idx, + CommonFeatureValueAccessor& fv_accessor) { if (embedx_idx < output[fv_accessor.common_push_value.MfDimIndex()]) { output[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx] += input[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx]; @@ -119,7 +129,10 @@ class HeterCommKernel { HeterCommKernel() {} explicit HeterCommKernel(const int block_size) : block_size_(block_size) {} - explicit HeterCommKernel(const int block_size, CommonFeatureValueAccessor& feature_value_accessor) : block_size_(block_size), feature_value_accessor_(feature_value_accessor) {} + explicit HeterCommKernel(const int block_size, + CommonFeatureValueAccessor& feature_value_accessor) + : block_size_(block_size), + feature_value_accessor_(feature_value_accessor) {} template void fill_idx(T* idx, long long len, const StreamType& stream); @@ -171,14 +184,15 @@ class HeterCommKernel { template void dy_mf_fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, - float* d_shard_grads, float* d_grads, - T* idx, long long len, size_t grad_value_size, + float* d_shard_grads, float* d_grads, T* idx, + long long len, size_t grad_value_size, const StreamType& stream); template - void merge_gradient(const KeyType* d_shard_keys, const uint32_t* offset, const uint32_t* fea_num, - const uint32_t* index, const char* input, char* output, - int n, size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger, + void merge_gradient(const KeyType* d_shard_keys, const uint32_t* offset, + const uint32_t* fea_num, const uint32_t* index, + const char* input, char* output, int n, size_t grad_dim, + size_t grad_value_size, DynamicGradMerger& merger, const StreamType& stream); template @@ -187,24 +201,27 @@ class HeterCommKernel { const StreamType& stream); template - void split_segments(const uint32_t* d_fea_num_info, - size_t len, uint32_t* d_segments, uint32_t* d_segments_num, - size_t segment_size, const StreamType& stream); + void split_segments(const uint32_t* d_fea_num_info, size_t len, + uint32_t* d_segments, uint32_t* d_segments_num, + size_t segment_size, const StreamType& stream); template void expand_segments(const uint32_t* d_fea_num_info, - const uint32_t* d_segments_offset, size_t segments_num, - uint32_t* d_segments_fea_num_info, uint32_t segment_size, - const StreamType& stream); + const uint32_t* d_segments_offset, size_t segments_num, + uint32_t* d_segments_fea_num_info, uint32_t segment_size, + const StreamType& stream); template void shrink_keys(const KeyType* d_keys, const uint32_t* d_segments_offset, - KeyType* d_segments_keys, size_t segments_num, const StreamType& stream); + KeyType* d_segments_keys, size_t segments_num, + const StreamType& stream); template - void fill_restore_idx(const uint32_t* d_sorted_idx, const uint32_t* d_offset, - const uint32_t* d_merged_cnts, const KeyType* d_merged_keys, - const size_t len, uint32_t* d_restore_idx, const StreamType& stream); + void fill_restore_idx(bool filter_zero, const size_t total_num, + const size_t merge_size, const KeyType *d_keys, + const uint32_t* d_sorted_idx, const uint32_t* d_offset, + const uint32_t* d_merged_cnts, uint32_t* d_restore_idx, + const StreamType& stream); template void unpack_merged_vals(size_t n, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index 037ec1415c7bd..b812250f012a0 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include + #include "paddle/fluid/framework/fleet/heter_ps/heter_ps.h" #ifdef PADDLE_WITH_HETERPS @@ -22,17 +23,16 @@ namespace framework { HeterPsBase* HeterPsBase::get_instance( size_t capacity, std::shared_ptr resource, - CommonFeatureValueAccessor feature_value_accessor, - int optimizer_type) { - return new HeterPs(capacity, resource, feature_value_accessor, optimizer_type); + CommonFeatureValueAccessor feature_value_accessor, int optimizer_type) { + return new HeterPs(capacity, resource, feature_value_accessor, + optimizer_type); } -HeterPs::HeterPs(size_t capacity, std::shared_ptr resource, - CommonFeatureValueAccessor feature_value_accessor, - int optimizer_type) { - comm_ = - std::make_shared>( - capacity, resource, feature_value_accessor); +HeterPs::HeterPs(size_t capacity, std::shared_ptr resource, + CommonFeatureValueAccessor feature_value_accessor, + int optimizer_type) { + comm_ = std::make_shared>( + capacity, resource, feature_value_accessor); feature_value_accessor_ = feature_value_accessor; optimizer_type_ = optimizer_type; } @@ -67,19 +67,22 @@ void HeterPs::end_pass() { comm_->end_pass(); } void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } -void HeterPs::push_sparse(int num, FeatureKey* d_keys, - float* d_grads, size_t len) { - if (optimizer_type_ == 3) { //adam +void HeterPs::push_sparse(int num, FeatureKey* d_keys, float* d_grads, + size_t len) { + if (optimizer_type_ == 3) { // adam auto optimizer = SparseAdamOptimizer(feature_value_accessor_); - VLOG(5) << "INTO push_sparse SparseAdamOptimizer, EmbedDim():" << optimizer.EmbedDim(); + VLOG(5) << "INTO push_sparse SparseAdamOptimizer, EmbedDim():" + << optimizer.EmbedDim(); comm_->push_sparse(num, d_keys, d_grads, len, optimizer); - } else if (optimizer_type_ == 4) { //shared_adam + } else if (optimizer_type_ == 4) { // shared_adam auto optimizer = SparseAdamSharedOptimizer(feature_value_accessor_); - VLOG(5) << "INTO push_sparse SparseAdamSharedOptimizer, EmbedDim():" << optimizer.EmbedDim(); + VLOG(5) << "INTO push_sparse SparseAdamSharedOptimizer, EmbedDim():" + << optimizer.EmbedDim(); comm_->push_sparse(num, d_keys, d_grads, len, optimizer); } else { auto optimizer = SparseAdagradOptimizer(feature_value_accessor_); - VLOG(5) << "INTO push_sparse SparseAdagradOptimizer, EmbedDim():" << optimizer.EmbedDim(); + VLOG(5) << "INTO push_sparse SparseAdagradOptimizer, EmbedDim():" + << optimizer.EmbedDim(); comm_->push_sparse(num, d_keys, d_grads, len, optimizer); } } @@ -98,8 +101,28 @@ void HeterPs::set_accessor(CommonFeatureValueAccessor& accessor) { comm_->set_accessor(accessor); } -void HeterPs::show_table_collisions() { - comm_->show_table_collisions(); +void HeterPs::show_table_collisions() { comm_->show_table_collisions(); } + +int HeterPs::dedup_keys_and_fillidx(const int gpu_id, + const int total_fea_num, + const FeatureKey* d_keys, // input + FeatureKey* d_merged_keys, // output + FeatureKey* d_sorted_keys, + uint32_t* d_restore_idx, + uint32_t* d_sorted_idx, + uint32_t* d_offset, + uint32_t* d_merged_cnts, + bool filter_zero) { + return comm_->dedup_keys_and_fillidx(gpu_id, + total_fea_num, + d_keys, // input + d_merged_keys, // output + d_sorted_keys, + d_restore_idx, + d_sorted_idx, + d_offset, + d_merged_cnts, + filter_zero); } } // end namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h index 89ec93f63db1c..caa069fb30613 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include + #include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h" #if defined(PADDLE_WITH_CUDA) @@ -29,8 +30,8 @@ class HeterPs : public HeterPsBase { public: HeterPs() {} HeterPs(size_t capacity, std::shared_ptr resource, - CommonFeatureValueAccessor feature_value_accessor, - int optimizer_type); + CommonFeatureValueAccessor feature_value_accessor, + int optimizer_type); virtual ~HeterPs(); HeterPs(const HeterPs&) = delete; HeterPs& operator=(const HeterPs&) = delete; @@ -56,10 +57,21 @@ class HeterPs : public HeterPsBase { void end_pass() override; int get_index_by_devid(int devid) override; void show_one_table(int gpu_num) override; - void push_sparse(int num, FeatureKey* d_keys, float* d_grads, - size_t len); + void push_sparse(int num, FeatureKey* d_keys, float* d_grads, size_t len); void show_table_collisions() override; - +#if defined(PADDLE_WITH_CUDA) + // dedup + int dedup_keys_and_fillidx(const int gpu_id, + const int total_fea_num, + const FeatureKey* d_keys, // input + FeatureKey* d_merged_keys, // output + FeatureKey* d_sorted_keys, + uint32_t* d_restore_idx, + uint32_t* d_sorted_idx, + uint32_t* d_offset, + uint32_t* d_merged_cnts, + bool filter_zero); +#endif private: std::shared_ptr> comm_; #if defined(PADDLE_WITH_CUDA) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h index aa74335b1a5e4..e5fe095f9b011 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include + #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" @@ -48,16 +49,28 @@ class HeterPsBase { virtual void end_pass() = 0; virtual void show_one_table(int gpu_num) = 0; virtual void show_table_collisions() = 0; - virtual void push_sparse(int num, FeatureKey* d_keys, - float* d_grads, size_t len) = 0; + virtual void push_sparse(int num, FeatureKey* d_keys, float* d_grads, + size_t len) = 0; virtual void set_sparse_sgd(const OptimizerConfig& optimizer_config) = 0; virtual void set_embedx_sgd(const OptimizerConfig& optimizer_config) = 0; - static HeterPsBase* get_instance(size_t capacity, - std::shared_ptr resource, - CommonFeatureValueAccessor feature_value_accessor, - int optimizer_type); + static HeterPsBase* get_instance( + size_t capacity, std::shared_ptr resource, + CommonFeatureValueAccessor feature_value_accessor, int optimizer_type); +#if defined(PADDLE_WITH_CUDA) + // dedup + virtual int dedup_keys_and_fillidx(const int gpu_id, + const int total_fea_num, + const FeatureKey* d_keys, // input + FeatureKey* d_merged_keys, // output + FeatureKey* d_sorted_keys, + uint32_t* d_restore_idx, + uint32_t* d_sorted_idx, + uint32_t* d_offset, + uint32_t* d_merged_cnts, + bool filter_zero) = 0; +#endif }; } // end namespace framework diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index db1e817b3f6a1..d09bb5f0badab 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -25,7 +25,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" @@ -33,12 +32,14 @@ limitations under the License. */ #include #include -#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/framework/data_set.h" +#include "paddle/fluid/platform/timer.h" #if defined(PADDLE_WITH_PSCORE) #include "paddle/fluid/distributed/ps/table/depends/feature_value.h" #endif +DECLARE_int32(gpugraph_dedup_pull_push_mode); + namespace paddle { namespace framework { @@ -145,8 +146,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { remain = total_len % thread_keys_thread_num_; VLOG(0) << "total len: " << total_len; auto gen_dynamic_mf_func = [this]( - const std::deque& total_data, int begin_index, - int end_index, int i) { + const std::deque& total_data, + int begin_index, int end_index, int i) { for (auto iter = total_data.begin() + begin_index; iter != total_data.begin() + end_index; iter++) { const auto& ins = *iter; @@ -233,17 +234,17 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { this->thread_keys_[i][shard_id].insert(cur_key); } }; - auto gen_graph_dynamic_mf_func = [this]( - const std::vector& total_data, int begin_index, int end_index, - int i) { - for (auto iter = total_data.begin() + begin_index; - iter != total_data.begin() + end_index; iter++) { - uint64_t cur_key = *iter; - int shard_id = cur_key % thread_keys_shard_num_; - // TODO: feasign <-> slot <-> multi_dim - this->thread_dim_keys_[i][shard_id][0].insert(cur_key); - } - }; + auto gen_graph_dynamic_mf_func = + [this](const std::vector& total_data, int begin_index, + int end_index, int i) { + for (auto iter = total_data.begin() + begin_index; + iter != total_data.begin() + end_index; iter++) { + uint64_t cur_key = *iter; + int shard_id = cur_key % thread_keys_shard_num_; + // TODO: feasign <-> slot <-> multi_dim + this->thread_dim_keys_[i][shard_id][0].insert(cur_key); + } + }; for (int i = 0; i < thread_keys_thread_num_; i++) { if (!multi_mf_dim_) { VLOG(1) << "psgpu graph wrapper genfunc"; @@ -582,136 +583,163 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { return; } std::vector threads(device_num); - HeterPs_ = HeterPsBase::get_instance(size_max, resource_, feature_value_accessor_, optimizer_type_); + HeterPs_ = HeterPsBase::get_instance( + size_max, resource_, feature_value_accessor_, optimizer_type_); #ifdef PADDLE_WITH_CUDA HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_); HeterPs_->set_sparse_sgd(optimizer_config_); HeterPs_->set_embedx_sgd(optimizer_config_); #endif - auto build_dynamic_mf_func = [this, &gpu_task](int i, int j) { - this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); - // this->HeterPs_->set_accessor(feature_value_accessor_); - int mf_dim = this->index_dim_vec_[j]; - VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim - << " feature_value_dim:" << feature_value_accessor_.common_feature_value.Dim(mf_dim) - << " feature_value_size:" << feature_value_accessor_.common_feature_value.Size(mf_dim); - size_t feature_value_size = - TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(mf_dim)); - auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; - auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; - size_t len = device_dim_keys.size(); - CHECK(len == device_dim_ptrs.size()); - this->mem_pools_[i * this->multi_mf_dim_ + j] = - new MemoryPool(len, feature_value_size); - auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j]; - for (size_t k = 0; k < len; k++) { - float* val = (float*)(mem_pool->mem_address(k)); - float* ptr_val = device_dim_ptrs[k]->data(); - size_t dim = device_dim_ptrs[k]->size(); + auto build_dynamic_mf_func = + [this, &gpu_task](int i, int j) { + this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); + // this->HeterPs_->set_accessor(feature_value_accessor_); + int mf_dim = this->index_dim_vec_[j]; + VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim + << " feature_value_dim:" + << feature_value_accessor_.common_feature_value.Dim(mf_dim) + << " feature_value_size:" + << feature_value_accessor_.common_feature_value.Size(mf_dim); + size_t feature_value_size = TYPEALIGN( + 8, feature_value_accessor_.common_feature_value.Size(mf_dim)); + auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; + auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; + size_t len = device_dim_keys.size(); + CHECK(len == device_dim_ptrs.size()); + this->mem_pools_[i * this->multi_mf_dim_ + j] = + new MemoryPool(len, feature_value_size); + auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j]; + for (size_t k = 0; k < len; k++) { + float* val = (float*)(mem_pool->mem_address(k)); + float* ptr_val = device_dim_ptrs[k]->data(); + size_t dim = device_dim_ptrs[k]->size(); #ifdef PADDLE_WITH_PSLIB - val->delta_score = - ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::delta_score_index()]; - val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::show_index()]; - val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::click_index()]; - val->slot = int(ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::slot_index()]); - val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::embed_w_index()]; - val->lr_g2sum = + val->delta_score = + ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::delta_score_index()]; + val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::show_index()]; + val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::click_index()]; + val->slot = + int(ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::slot_index()]); + val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::embed_w_index()]; + val->lr_g2sum = + ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::embed_g2sum_index()]; + // TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::embed_g2sum_index()]; - // TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor - ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - mf_dim_index()] = float(mf_dim); - val->mf_dim = mf_dim; - if (dim > 8) { // CpuPS alreay expand as mf_dim - val->mf_size = mf_dim + 1; - for (int x = 0; x < val->mf_dim + 1; x++) { - val->mf[x] = ptr_val[x + 8]; - } - } else { - val->mf_size = 0; - for (int x = 0; x < val->mf_dim + 1; x++) { - val->mf[x] = 0; + DownpourCtrDymfFeatureValue::mf_dim_index()] = + float(mf_dim); + val->mf_dim = mf_dim; + if (dim > 8) { // CpuPS alreay expand as mf_dim + val->mf_size = mf_dim + 1; + for (int x = 0; x < val->mf_dim + 1; x++) { + val->mf[x] = ptr_val[x + 8]; + } + } else { + val->mf_size = 0; + for (int x = 0; x < val->mf_dim + 1; x++) { + val->mf[x] = 0; + } + } } - } - } #endif #ifdef PADDLE_WITH_PSCORE - VLOG(5) << "cpu build "<< k << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) - << " |: "<< cpu_table_accessor_->ParseToString(ptr_val, dim); - val[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] = - ptr_val[cpu_table_accessor_->common_feature_value.DeltaScoreIndex()]; - val[feature_value_accessor_.common_feature_value.ShowIndex()] = - ptr_val[cpu_table_accessor_->common_feature_value.ShowIndex()]; - val[feature_value_accessor_.common_feature_value.ClickIndex()] = - ptr_val[cpu_table_accessor_->common_feature_value.ClickIndex()]; - val[feature_value_accessor_.common_feature_value.SlotIndex()] = - ptr_val[cpu_table_accessor_->common_feature_value.SlotIndex()]; - val[feature_value_accessor_.common_feature_value.EmbedWIndex()] = - ptr_val[cpu_table_accessor_->common_feature_value.EmbedWIndex()]; - for (int i = 0; i < feature_value_accessor_.common_feature_value.EmbedDim(); i++) { - val[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + i] = - ptr_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i]; - } - - *(reinterpret_cast(val + feature_value_accessor_.common_feature_value.CpuPtrIndex())) = (uint64_t)(device_dim_ptrs[k]); - - ptr_val[cpu_table_accessor_->common_feature_value.MfDimIndex()] = float(mf_dim); - val[feature_value_accessor_.common_feature_value.MfDimIndex()] = mf_dim; - if (dim > cpu_table_accessor_->GetAccessorInfo().dim - - cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)) { - val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); - - for (int x = 0; x < int(feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float)); - x++) { - val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x] = - ptr_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x]; + VLOG(5) << "cpu build " << k + << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) + << " |: " << cpu_table_accessor_->ParseToString(ptr_val, dim); + val[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] = + ptr_val[cpu_table_accessor_->common_feature_value + .DeltaScoreIndex()]; + val[feature_value_accessor_.common_feature_value.ShowIndex()] = + ptr_val[cpu_table_accessor_->common_feature_value.ShowIndex()]; + val[feature_value_accessor_.common_feature_value.ClickIndex()] = + ptr_val[cpu_table_accessor_->common_feature_value.ClickIndex()]; + val[feature_value_accessor_.common_feature_value.SlotIndex()] = + ptr_val[cpu_table_accessor_->common_feature_value.SlotIndex()]; + val[feature_value_accessor_.common_feature_value.EmbedWIndex()] = + ptr_val[cpu_table_accessor_->common_feature_value.EmbedWIndex()]; + for (int i = 0; + i < feature_value_accessor_.common_feature_value.EmbedDim(); i++) { + val[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + + i] = ptr_val + [cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i]; } - } else { - val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = 0; - for (int x = feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(); - x < int(feature_value_accessor_.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ - val[x] = 0; + + *(reinterpret_cast( + val + feature_value_accessor_.common_feature_value.CpuPtrIndex())) = + (uint64_t)(device_dim_ptrs[k]); + + ptr_val[cpu_table_accessor_->common_feature_value.MfDimIndex()] = + float(mf_dim); + val[feature_value_accessor_.common_feature_value.MfDimIndex()] = mf_dim; + if (dim > cpu_table_accessor_->GetAccessorInfo().dim - + cpu_table_accessor_->GetAccessorInfo().mf_size / + sizeof(float)) { + val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / + sizeof(float); + + for (int x = 0; + x < + int(feature_value_accessor_.common_feature_value.MFSize(mf_dim) / + sizeof(float)); + x++) { + val[feature_value_accessor_.common_feature_value + .EmbedxG2SumIndex() + + x] = ptr_val[cpu_table_accessor_->common_feature_value + .EmbedxG2SumIndex() + + x]; + } + } else { + val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = 0; + for (int x = feature_value_accessor_.common_feature_value + .EmbedxG2SumIndex(); + x < + int(feature_value_accessor_.common_feature_value.Size(mf_dim) / + sizeof(float)); + x++) { + val[x] = 0; + } } + VLOG(5) << "build " << k << " : " + << feature_value_accessor_.ParseToString( + val, feature_value_accessor_.common_feature_value.Dim( + mf_dim)); } - VLOG(5) << "build "<< k << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.common_feature_value.Dim(mf_dim)); - } #endif - platform::CUDADeviceGuard guard(resource_->dev_id(i)); + platform::CUDADeviceGuard guard(resource_->dev_id(i)); - this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool); - auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; + this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool); + auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; - this->HeterPs_->build_ps(i, device_dim_keys.data(), cur_pool->mem(), len, - feature_value_size, 500000, 2); + this->HeterPs_->build_ps(i, device_dim_keys.data(), cur_pool->mem(), len, + feature_value_size, 500000, 2); - if (device_dim_keys.size() > 0) { - VLOG(0) << "show ptr table: " << i - << " table kv size: " << device_dim_keys.size() - << "dim: " << mf_dim << " len: " << len; - this->HeterPs_->show_one_table(i); - } - delete mem_pool; - }; - threads.resize(device_num * multi_mf_dim_); - for (int i = 0; i < device_num; i++) { - for (int j = 0; j < multi_mf_dim_; j++) { - threads[i + j * device_num] = std::thread(build_dynamic_mf_func, i, j); - } + if (device_dim_keys.size() > 0) { + VLOG(0) << "show ptr table: " << i + << " table kv size: " << device_dim_keys.size() << "dim: " << mf_dim + << " len: " << len; + this->HeterPs_->show_one_table(i); } - - for (std::thread& t : threads) { - t.join(); + delete mem_pool; +}; +threads.resize(device_num* multi_mf_dim_); +for (int i = 0; i < device_num; i++) { + for (int j = 0; j < multi_mf_dim_; j++) { + threads[i + j * device_num] = std::thread(build_dynamic_mf_func, i, j); } - timeline.Pause(); - VLOG(0) << "GpuPs build table total costs: " << timeline.ElapsedSec() - << " s."; +} + +for (std::thread& t : threads) { + t.join(); +} +timeline.Pause(); +VLOG(0) << "GpuPs build table total costs: " << timeline.ElapsedSec() << " s."; } void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) { @@ -800,8 +828,12 @@ void PSGPUWrapper::BeginPass() { PADDLE_THROW(platform::errors::Fatal( "[BeginPass] after build_task, current task is not null.")); } - - VLOG(0) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s"; + if (FLAGS_gpugraph_dedup_pull_push_mode) { + VLOG(0) << "BeginPass end, cost time: " << timer.ElapsedSec() + << "s, enable pull push dedup mode=" << FLAGS_gpugraph_dedup_pull_push_mode; + } else { + VLOG(0) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s"; + } } void PSGPUWrapper::EndPass() { @@ -821,124 +853,152 @@ void PSGPUWrapper::EndPass() { } } - auto dump_pool_to_cpu_func = [this](int i, int j) { - PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i))); - auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; - auto& device_keys = this->current_task_->device_dim_keys_[i][j]; - size_t len = device_keys.size(); - int mf_dim = this->index_dim_vec_[j]; - size_t feature_value_size = - TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(mf_dim)); - VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim - << " key_len :" << len << " feature_value_size:" << feature_value_size; - - char* test_build_values = (char*)malloc(feature_value_size * len); - cudaMemcpy(test_build_values, hbm_pool->mem(), feature_value_size * len, - cudaMemcpyDeviceToHost); - - CHECK(len == hbm_pool->capacity()); - uint64_t unuse_key = std::numeric_limits::max(); - for (size_t index = 0; index < len; ++index) { - if (device_keys[index] == unuse_key) { - continue; - } - size_t offset = index * feature_value_size; - float* gpu_val = (float*)(test_build_values + offset); + auto dump_pool_to_cpu_func = + [this](int i, int j) { + PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i))); + auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; + auto& device_keys = this->current_task_->device_dim_keys_[i][j]; + size_t len = device_keys.size(); + int mf_dim = this->index_dim_vec_[j]; + size_t feature_value_size = TYPEALIGN( + 8, feature_value_accessor_.common_feature_value.Size(mf_dim)); + VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim + << " key_len :" << len + << " feature_value_size:" << feature_value_size; + + char* test_build_values = (char*)malloc(feature_value_size * len); + cudaMemcpy(test_build_values, hbm_pool->mem(), feature_value_size * len, + cudaMemcpyDeviceToHost); + + CHECK(len == hbm_pool->capacity()); + uint64_t unuse_key = std::numeric_limits::max(); + for (size_t index = 0; index < len; ++index) { + if (device_keys[index] == unuse_key) { + continue; + } + size_t offset = index * feature_value_size; + float* gpu_val = (float*)(test_build_values + offset); #ifdef PADDLE_WITH_PSLIB - auto* downpour_value = - (paddle::ps::DownpourFixedFeatureValue*)(gpu_val->cpu_ptr); - int downpour_value_size = downpour_value->size(); - if (gpu_val->mf_size > 0 && downpour_value_size == 8) { - downpour_value->resize(gpu_val->mf_dim + 1 + downpour_value_size); - } - float* cpu_val = downpour_value->data(); - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - delta_score_index()] = gpu_val->delta_score; - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - show_index()] = gpu_val->show; - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - click_index()] = gpu_val->clk; - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - embed_w_index()] = gpu_val->lr; - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - embed_g2sum_index()] = gpu_val->lr_g2sum; - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - slot_index()] = gpu_val->slot; - - if (gpu_val->mf_size > 0) { - for (int x = 0; x < gpu_val->mf_dim + 1; x++) { - cpu_val[x + 8] = gpu_val->mf[x]; + auto* downpour_value = + (paddle::ps::DownpourFixedFeatureValue*)(gpu_val->cpu_ptr); + int downpour_value_size = downpour_value->size(); + if (gpu_val->mf_size > 0 && downpour_value_size == 8) { + downpour_value->resize(gpu_val->mf_dim + 1 + downpour_value_size); + } + float* cpu_val = downpour_value->data(); + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::delta_score_index()] = + gpu_val->delta_score; + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::show_index()] = + gpu_val->show; + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::click_index()] = + gpu_val->clk; + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::embed_w_index()] = + gpu_val->lr; + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::embed_g2sum_index()] = + gpu_val->lr_g2sum; + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::slot_index()] = + gpu_val->slot; + + if (gpu_val->mf_size > 0) { + for (int x = 0; x < gpu_val->mf_dim + 1; x++) { + cpu_val[x + 8] = gpu_val->mf[x]; + } + } } - } - } #endif #ifdef PADDLE_WITH_PSCORE - auto* downpour_value = - (paddle::distributed::FixedFeatureValue*)(*(reinterpret_cast(gpu_val+ feature_value_accessor_.common_feature_value.CpuPtrIndex()))); - size_t downpour_value_size = downpour_value->size(); - if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0 && - downpour_value_size == (cpu_table_accessor_->GetAccessorInfo().dim - - int(cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)))) { // cpu_accessor - downpour_value->resize(cpu_table_accessor_->common_feature_value.Dim(mf_dim)); - } - float* cpu_val = downpour_value->data(); - - cpu_val[cpu_table_accessor_->common_feature_value.DeltaScoreIndex()] = - gpu_val[feature_value_accessor_.common_feature_value.DeltaScoreIndex()]; - cpu_val[cpu_table_accessor_->common_feature_value.ShowIndex()] = - gpu_val[feature_value_accessor_.common_feature_value.ShowIndex()]; - cpu_val[cpu_table_accessor_->common_feature_value.ClickIndex()] = - gpu_val[feature_value_accessor_.common_feature_value.ClickIndex()]; - cpu_val[cpu_table_accessor_->common_feature_value.EmbedWIndex()] = - gpu_val[feature_value_accessor_.common_feature_value.EmbedWIndex()]; - cpu_val[cpu_table_accessor_->common_feature_value.SlotIndex()] = - gpu_val[feature_value_accessor_.common_feature_value.SlotIndex()]; - - for (int i = 0; i < feature_value_accessor_.common_feature_value.EmbedDim(); i++) { - cpu_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i] = - gpu_val[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + i]; - } + auto* downpour_value = (paddle::distributed::FixedFeatureValue*)(*( + reinterpret_cast( + gpu_val + + feature_value_accessor_.common_feature_value.CpuPtrIndex()))); + size_t downpour_value_size = downpour_value->size(); + if (gpu_val[feature_value_accessor_.common_feature_value + .MfSizeIndex()] > 0 && + downpour_value_size == + (cpu_table_accessor_->GetAccessorInfo().dim - + int(cpu_table_accessor_->GetAccessorInfo().mf_size / + sizeof(float)))) { // cpu_accessor + downpour_value->resize( + cpu_table_accessor_->common_feature_value.Dim(mf_dim)); + } + float* cpu_val = downpour_value->data(); + + cpu_val[cpu_table_accessor_->common_feature_value.DeltaScoreIndex()] = + gpu_val[feature_value_accessor_.common_feature_value + .DeltaScoreIndex()]; + cpu_val[cpu_table_accessor_->common_feature_value.ShowIndex()] = + gpu_val[feature_value_accessor_.common_feature_value.ShowIndex()]; + cpu_val[cpu_table_accessor_->common_feature_value.ClickIndex()] = + gpu_val[feature_value_accessor_.common_feature_value.ClickIndex()]; + cpu_val[cpu_table_accessor_->common_feature_value.EmbedWIndex()] = + gpu_val[feature_value_accessor_.common_feature_value.EmbedWIndex()]; + cpu_val[cpu_table_accessor_->common_feature_value.SlotIndex()] = + gpu_val[feature_value_accessor_.common_feature_value.SlotIndex()]; + + for (int i = 0; + i < feature_value_accessor_.common_feature_value.EmbedDim(); i++) { + cpu_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + + i] = gpu_val[feature_value_accessor_.common_feature_value + .EmbedG2SumIndex() + + i]; + } - if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0) { - - for (int x = 0; x < int(feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float)); - x++) { - cpu_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x] = - gpu_val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x]; + if (gpu_val[feature_value_accessor_.common_feature_value + .MfSizeIndex()] > 0) { + for (int x = 0; + x < + int(feature_value_accessor_.common_feature_value.MFSize(mf_dim) / + sizeof(float)); + x++) { + cpu_val[cpu_table_accessor_->common_feature_value + .EmbedxG2SumIndex() + + x] = gpu_val[feature_value_accessor_.common_feature_value + .EmbedxG2SumIndex() + + x]; + } } + VLOG(5) << "dump to cpu " << index << " : " + << feature_value_accessor_.ParseToString( + gpu_val, + feature_value_accessor_.common_feature_value.Dim(mf_dim)) + << " ===== CPU:" + << cpu_table_accessor_->ParseToString(cpu_val, + downpour_value->size()); } - VLOG(5) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.common_feature_value.Dim(mf_dim)) - << " ===== CPU:" << cpu_table_accessor_->ParseToString(cpu_val, downpour_value->size()); - - } #endif - free(test_build_values); - }; - if (multi_mf_dim_) { - VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_; - size_t device_num = heter_devices_.size(); - std::vector threads(device_num * multi_mf_dim_); - for (size_t i = 0; i < device_num; i++) { - for (int j = 0; j < multi_mf_dim_; j++) { - threads[i + j * device_num] = std::thread(dump_pool_to_cpu_func, i, j); - } - } - for (std::thread& t : threads) { - t.join(); + free(test_build_values); +}; +if (multi_mf_dim_) { + VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_; + size_t device_num = heter_devices_.size(); + std::vector threads(device_num * multi_mf_dim_); + for (size_t i = 0; i < device_num; i++) { + for (int j = 0; j < multi_mf_dim_; j++) { + threads[i + j * device_num] = std::thread(dump_pool_to_cpu_func, i, j); } } - if (keysize_max != 0) { - HeterPs_->end_pass(); - } - VLOG(0) << "HeterPs_->end_pass end"; - for (size_t i = 0; i < hbm_pools_.size(); i++) { - delete hbm_pools_[i]; + for (std::thread& t : threads) { + t.join(); } - gpu_task_pool_.Push(current_task_); - current_task_ = nullptr; - gpu_free_channel_->Put(current_task_); - timer.Pause(); - VLOG(0) << "EndPass end, cost time: " << timer.ElapsedSec() << "s"; +} +if (keysize_max != 0) { + HeterPs_->end_pass(); +} +VLOG(0) << "HeterPs_->end_pass end"; +for (size_t i = 0; i < hbm_pools_.size(); i++) { + delete hbm_pools_[i]; +} +gpu_task_pool_.Push(current_task_); +current_task_ = nullptr; +gpu_free_channel_->Put(current_task_); +timer.Pause(); +VLOG(0) << "EndPass end, cost time: " << timer.ElapsedSec() << "s"; } void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, @@ -947,7 +1007,8 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, const std::vector& values, const std::vector& slot_lengths, const int hidden_size) { - VLOG(0) << "Warning:: recommand use pull_gpups_sparse op instead. This PullSparse is not used."; + VLOG(0) << "Warning:: recommand use pull_gpups_sparse op instead. This " + "PullSparse is not used."; } void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, @@ -961,84 +1022,176 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, platform::Timer all_timer; platform::Timer pull_gpups_timer; all_timer.Start(); - size_t total_length = - std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); - size_t feature_value_size = 0; - feature_value_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); - VLOG(3) << "PullSparse max_dim:" << max_mf_dim_ << " feature_value_size:" << feature_value_size; - -#ifdef PADDLE_WITH_CUDA - VLOG(3) << "Begine Gpu Ps PullSparse"; - auto buf = memory::Alloc(place, total_length * feature_value_size); - float* total_values_gpu = reinterpret_cast(buf->ptr()); -#endif -#ifdef PADDLE_WITH_XPU_KP - VLOG(3) << "Begine Xpu Ps PullSparse"; - FeatureValue* total_values_gpu = nullptr; - xpu_malloc(reinterpret_cast(&total_values_gpu), - total_length * feature_value_size); -#endif + size_t feature_value_size = feature_value_accessor_.common_pull_value.Size(max_mf_dim_); + VLOG(3) << "PullSparse max_dim:" << max_mf_dim_ + << " pull_feature_value_size:" << pull_type_size_; + if (platform::is_cpu_place(place)) { PADDLE_THROW(platform::errors::Unimplemented( "Warning:: CPUPlace is not supported in GpuPs now.")); } else if (platform::is_gpu_place(place)) { - VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; +#ifdef PADDLE_WITH_CUDA int device_id = place.GetDeviceId(); int devid_2_index = HeterPs_->get_index_by_devid(device_id); - LoDTensor& total_keys_tensor = keys_tensor[devid_2_index]; - uint64_t* total_keys = - reinterpret_cast(total_keys_tensor.mutable_data( - {int64_t(total_length), 1}, place)); - - // construct slot_level lod info - auto slot_lengths_lod = slot_lengths; - for (size_t i = 1; i < slot_lengths_lod.size(); i++) { - slot_lengths_lod[i] += slot_lengths_lod[i - 1]; + if (FLAGS_gpugraph_dedup_pull_push_mode > 0) { + auto& dev = device_caches_[devid_2_index]; + int slot_num = static_cast(slot_lengths.size()); + std::vector slot_lengths_lod; + slot_lengths_lod.reserve(slot_num + 1); + slot_lengths_lod.push_back(0); + + int64_t total_length = 0; + for (int i = 0; i < slot_num; ++i) { + total_length += slot_lengths[i]; + slot_lengths_lod.push_back(total_length); + } + dev.total_key_length = total_length; + VLOG(3) << "[" << device_id << "]Begin copy keys, key_num[" + << total_length << "] dedup mode"; + + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + + uint64_t* total_keys = dev.keys_tensor.mutable_data( + (total_length * 3) * sizeof(uint64_t), place); + + int* gpu_slot_dims = dev.dims_tensor.mutable_data( + slot_dim.size() * sizeof(int), place); + uint64_t** gpu_keys = dev.keys_ptr_tensor.mutable_data( + keys.size() * sizeof(uint64_t*), place); + + int64_t* slot_lens = dev.slot_lens.mutable_data( + (slot_num + 1) * sizeof(int64_t), place); + cudaMemcpyAsync(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(slot_lens, slot_lengths_lod.data(), + slot_lengths_lod.size() * sizeof(int64_t), + cudaMemcpyHostToDevice, stream); + + cudaMemcpyAsync(gpu_slot_dims, slot_dim.data(), + slot_dim.size() * sizeof(int), cudaMemcpyHostToDevice, + stream); + float** gpu_values = dev.values_ptr_tensor.mutable_data( + values.size() * sizeof(float*), place); + cudaMemcpyAsync(gpu_values, values.data(), values.size() * sizeof(float*), + cudaMemcpyHostToDevice, stream); + + int* key2slot = dev.keys2slot.mutable_data( + (total_length * 5) * sizeof(int), place); + + this->CopyKeys(place, gpu_keys, total_keys, slot_lens, slot_num, + static_cast(total_length), key2slot); + + uint32_t* d_restore_idx = + reinterpret_cast(&key2slot[total_length]); + uint32_t* d_sorted_idx = + reinterpret_cast(&d_restore_idx[total_length]); + uint32_t* d_offset = + reinterpret_cast(&d_sorted_idx[total_length]); + uint32_t* d_merged_cnts = + reinterpret_cast(&d_offset[total_length]); + uint64_t* d_merged_keys = + reinterpret_cast(&total_keys[total_length]); + uint64_t* d_sorted_keys = + reinterpret_cast(&d_merged_keys[total_length]); + + int dedup_size = HeterPs_->dedup_keys_and_fillidx( + devid_2_index, static_cast(total_length), + total_keys, // input + d_merged_keys, // output + d_sorted_keys, // sort keys + d_restore_idx, // pull fill idx + d_sorted_idx, // sort old idx + d_offset, // offset + d_merged_cnts, + FLAGS_gpugraph_dedup_pull_push_mode & 0x02); +// printf("device %d, end dedup_keys_and_fillidx total %d, " +// "dedup_size %d, slot num: %d, value size: %d\n", +// device_id, int(total_length), dedup_size, slot_num, int(feature_value_size)); + + PADDLE_ENFORCE_GT(dedup_size, 0, + platform::errors::PreconditionNotMet( + "dedup keys need more than zero failed in BoxPS.")); + dev.dedup_key_length = dedup_size; + + int64_t total_bytes = dedup_size * feature_value_size; + float* total_values_gpu = + dev.pull_push_tensor.mutable_data(total_bytes, place); + pull_gpups_timer.Start(); + HeterPs_->pull_sparse(devid_2_index, d_merged_keys, total_values_gpu, + dedup_size); + + // values.size() not sure equal slot_num + this->CopyForPull(place, total_keys, gpu_values, total_values_gpu, + slot_lens, key2slot, max_mf_dim_ + 3, total_length, + gpu_slot_dims, d_restore_idx); + } else { + size_t total_length = + std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); + auto buf = memory::Alloc(place, total_length * feature_value_size); + float* total_values_gpu = reinterpret_cast(buf->ptr()); + VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; + LoDTensor& total_keys_tensor = keys_tensor[devid_2_index]; + uint64_t* total_keys = + reinterpret_cast(total_keys_tensor.mutable_data( + {int64_t(total_length), 1}, place)); + // construct slot_level lod info + auto slot_lengths_lod = slot_lengths; + for (size_t i = 1; i < slot_lengths_lod.size(); i++) { + slot_lengths_lod[i] += slot_lengths_lod[i - 1]; + } + auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*)); + auto buf_length = + memory::Alloc(place, slot_lengths.size() * sizeof(int64_t)); + uint64_t** gpu_keys = reinterpret_cast(buf_key->ptr()); + int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); + cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*), + cudaMemcpyHostToDevice); + cudaMemcpy(gpu_len, slot_lengths_lod.data(), + slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice); + + auto buf_dim = memory::Alloc(place, slot_dim.size() * sizeof(int)); + int* gpu_dim = reinterpret_cast(buf_dim->ptr()); + cudaMemcpy(gpu_dim, slot_dim.data(), slot_dim.size() * sizeof(int), + cudaMemcpyHostToDevice); + + this->CopyKeys(place, gpu_keys, total_keys, gpu_len, + static_cast(slot_lengths.size()), + static_cast(total_length)); + VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index + << " len: " << total_length; + + pull_gpups_timer.Start(); + HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu, + total_length); + + VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length + << "]"; + + this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len, + static_cast(slot_lengths.size()), hidden_size, + total_length, gpu_dim); } - auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*)); - auto buf_length = - memory::Alloc(place, slot_lengths.size() * sizeof(int64_t)); - uint64_t** gpu_keys = reinterpret_cast(buf_key->ptr()); - int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); - cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*), - cudaMemcpyHostToDevice); - cudaMemcpy(gpu_len, slot_lengths_lod.data(), - slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice); - - auto buf_dim = memory::Alloc(place, slot_dim.size() * sizeof(int)); - int* gpu_dim = reinterpret_cast(buf_dim->ptr()); - cudaMemcpy(gpu_dim, slot_dim.data(), slot_dim.size() * sizeof(int), - cudaMemcpyHostToDevice); - - this->CopyKeys(place, gpu_keys, total_keys, gpu_len, - static_cast(slot_lengths.size()), - static_cast(total_length)); - VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index - << " len: " << total_length; - - pull_gpups_timer.Start(); - HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu, - total_length); - - VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length - << "]"; - - this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len, - static_cast(slot_lengths.size()), hidden_size, - total_length, gpu_dim); - pull_gpups_timer.Pause(); #endif } else if (platform::is_xpu_place(place)) { #ifdef PADDLE_WITH_XPU_KP + VLOG(3) << "Begine Xpu Ps PullSparse"; + size_t total_length = + std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); + FeatureValue* total_values_gpu = nullptr; + xpu_malloc(reinterpret_cast(&total_values_gpu), + total_length * feature_value_size); VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; int device_id = place.GetDeviceId(); int devid_2_index = HeterPs_->get_index_by_devid(device_id); LoDTensor& total_keys_tensor = keys_tensor[devid_2_index]; - uint64_t* total_keys = reinterpret_cast( - total_keys_tensor.mutable_data({total_length, 1}, place)); + uint64_t* total_keys = + reinterpret_cast(total_keys_tensor.mutable_data( + {int64_t(total_length), 1}, place)); // construct slot_level lod info auto slot_lengths_lod = slot_lengths; @@ -1094,16 +1247,9 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, platform::Timer all_timer; platform::Timer push_gpups_timer; all_timer.Start(); - int64_t total_length = - std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); - // #ifdef PADDLE_WITH_CUDA - VLOG(3) << "Begin GPUPS PushSparseGrad"; - size_t grad_value_size = - TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); - auto buf = memory::Alloc(place, total_length * grad_value_size); - VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_ << "grad_value_size:" << grad_value_size; - float* total_grad_values_gpu = - reinterpret_cast(buf->ptr()); + size_t grad_value_size = + TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + if (platform::is_cpu_place(place)) { PADDLE_THROW(platform::errors::Unimplemented( "Warning:: CPUPlace is not supported in GPUPS now.")); @@ -1111,28 +1257,104 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, #ifdef PADDLE_WITH_CUDA int device_id = place.GetDeviceId(); int devid_2_index = HeterPs_->get_index_by_devid(device_id); - LoDTensor& cached_total_keys_tensor = keys_tensor[devid_2_index]; - uint64_t* total_keys = - reinterpret_cast(cached_total_keys_tensor.data()); - VLOG(3) << "Begin copy grad tensor to gpups struct"; - - this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths, - total_length, batch_size, grad_value_size); - - VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index - << " len: " << total_length; - push_gpups_timer.Start(); - HeterPs_->push_sparse(devid_2_index, total_keys, total_grad_values_gpu, - static_cast(total_length)); + if (FLAGS_gpugraph_dedup_pull_push_mode > 0) { + auto& dev = device_caches_[devid_2_index]; + int64_t total_length = dev.total_key_length; + VLOG(3) << "Begin push sparse, key_num[" << total_length + << "] dedup mode, device:" << device_id << ", index" + << devid_2_index; + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + uint64_t* total_keys = dev.keys_tensor.data(); + int* slot_dims = dev.dims_tensor.data(); + int slot_num = static_cast(slot_lengths.size()); + if (!dev.d_slot_vector.IsInitialized()) { + int* buf_slot_vector = + dev.d_slot_vector.mutable_data(slot_num * sizeof(int), place); + cudaMemcpyAsync(buf_slot_vector, slot_vector_.data(), + slot_num * sizeof(int), cudaMemcpyHostToDevice, stream); + } + + const int64_t* slot_lens = dev.slot_lens.data(); + const int* d_slot_vector = dev.d_slot_vector.data(); + const int* key2slot = dev.keys2slot.data(); + float** gpu_values = dev.values_ptr_tensor.data(); + cudaMemcpyAsync(gpu_values, grad_values.data(), + grad_values.size() * sizeof(float*), + cudaMemcpyHostToDevice, stream); + + uint64_t* d_merged_keys = &total_keys[total_length]; + + int64_t dedup_size = dev.dedup_key_length; + int64_t total_bytes = dedup_size * grad_value_size; + float* total_grad_values_gpu = + dev.pull_push_tensor.mutable_data(total_bytes, place); + // dedup rate more than 3 + if (total_length > dedup_size * 3) { + const uint32_t* d_restore_idx = + reinterpret_cast(&key2slot[total_length]); + this->CopyForPush(place, total_keys, gpu_values, total_grad_values_gpu, + d_slot_vector, slot_lens, max_mf_dim_ + 3, total_length, + dedup_size, batch_size, slot_dims, key2slot, + d_restore_idx, grad_value_size); + } else { + const uint32_t* d_sorted_idx = + reinterpret_cast(&key2slot[total_length * 2]); + const uint32_t* d_offset = + reinterpret_cast(&d_sorted_idx[total_length]); + const uint32_t* d_merged_cnts = + reinterpret_cast(&d_offset[total_length]); + this->CopyForPush(place, d_merged_keys, gpu_values, total_grad_values_gpu, + d_slot_vector, slot_lens, max_mf_dim_ + 3, total_length, + dedup_size, batch_size, slot_dims, key2slot, + d_sorted_idx, d_offset, d_merged_cnts, grad_value_size); + } + + push_gpups_timer.Start(); + HeterPs_->push_sparse(devid_2_index, d_merged_keys, total_grad_values_gpu, + static_cast(dedup_size)); + } else { + int64_t total_length = + std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); + VLOG(3) << "Begin GPUPS PushSparseGrad"; + + auto buf = memory::Alloc(place, total_length * grad_value_size); + VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_ + << "grad_value_size:" << grad_value_size; + float* total_grad_values_gpu = reinterpret_cast(buf->ptr()); + + LoDTensor& total_keys_tensor = keys_tensor[devid_2_index]; + uint64_t* total_keys = + reinterpret_cast(total_keys_tensor.data()); + VLOG(3) << "Begin copy grad tensor to gpups struct"; + + this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths, + total_length, batch_size, grad_value_size); + + VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index + << " len: " << total_length; + push_gpups_timer.Start(); + HeterPs_->push_sparse(devid_2_index, total_keys, total_grad_values_gpu, + static_cast(total_length)); + } push_gpups_timer.Pause(); #endif } else if (platform::is_xpu_place(place)) { #ifdef PADDLE_WITH_XPU_KP int device_id = place.GetDeviceId(); int devid_2_index = HeterPs_->get_index_by_devid(device_id); - LoDTensor& cached_total_keys_tensor = keys_tensor[devid_2_index]; + int64_t total_length = + std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); + VLOG(3) << "Begin GPUPS PushSparseGrad"; + + auto buf = memory::Alloc(place, total_length * grad_value_size); + VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_ + << "grad_value_size:" << grad_value_size; + float* total_grad_values_gpu = reinterpret_cast(buf->ptr()); + LoDTensor& total_keys_tensor = keys_tensor[devid_2_index]; uint64_t* total_keys = - reinterpret_cast(cached_total_keys_tensor.data()); + reinterpret_cast(total_keys_tensor.data()); VLOG(3) << "Begin copy grad tensor to xpups struct"; this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths, hidden_size, total_length, batch_size); @@ -1159,4 +1381,4 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, } // end namespace framework } // end namespace paddle -// #endif +#endif diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index 15d22ab57428d..3ff0c89c3d67d 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -17,14 +17,20 @@ limitations under the License. */ #include #include #include + #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" namespace paddle { namespace framework { +const int CUDA_NUM_THREADS = platform::PADDLE_CUDA_NUM_THREADS; +#define GET_BLOCK(N) ((N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS) +#define CUDA_BLOCK(N) GET_BLOCK(N), CUDA_NUM_THREADS, 0 + __global__ void PullCopy(float** dest, const FeatureValue* src, const int64_t* len, int hidden, int slot_num, int total_len, uint64_t** keys) { @@ -61,10 +67,11 @@ __global__ void PullCopy(float** dest, const FeatureValue* src, } } -__global__ void PullCopy(float** dest, const float* src, - const int64_t* len, int slot_num, int total_len, - uint64_t** keys, uint64_t max_val_size, int* gpu_dim, - CommonFeatureValueAccessor feature_value_accessor) { +template +__global__ void PullCopy(float** dest, const float* src, const int64_t* len, + int slot_num, int total_len, uint64_t** keys, + uint64_t max_val_size, int* gpu_dim, + TAccess accessor) { CUDA_KERNEL_LOOP(i, total_len) { int low = 0; int high = slot_num - 1; @@ -85,19 +92,23 @@ __global__ void PullCopy(float** dest, const float* src, *(dest[x] + y * (mf_dim + 3) + 1) = 0; *(dest[x] + y * (mf_dim + 3) + 2) = 0; } else { - *(dest[x] + y * (mf_dim + 3)) = feature_value_ptr[feature_value_accessor.common_feature_value.ShowIndex()]; - *(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr[feature_value_accessor.common_feature_value.ClickIndex()]; - *(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr[feature_value_accessor.common_feature_value.EmbedWIndex()]; + *(dest[x] + y * (mf_dim + 3)) = + feature_value_ptr[accessor.ShowIndex()]; + *(dest[x] + y * (mf_dim + 3) + 1) = + feature_value_ptr[accessor.ClickIndex()]; + *(dest[x] + y * (mf_dim + 3) + 2) = + feature_value_ptr[accessor.EmbedWIndex()]; } - if (feature_value_ptr[feature_value_accessor.common_feature_value.MfSizeIndex()] == 0 || *(keys[x] + y) == 0) { + if (feature_value_ptr[accessor.MfSizeIndex()] == 0 || + *(keys[x] + y) == 0) { for (int j = 0; j < mf_dim; j++) { *(dest[x] + y * (mf_dim + 3) + 3 + j) = 0; } } else { for (int j = 0; j < mf_dim; j++) { - *(dest[x] + y * (mf_dim + 3) + 3 + j) = - feature_value_ptr[feature_value_accessor.common_feature_value.EmbedxWOffsetIndex(feature_value_ptr) + j]; + *(dest[x] + y * (mf_dim + 3) + 3 + j) = + feature_value_ptr[accessor.EmbedxWIndex() + j]; } } } @@ -147,11 +158,10 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len, } } -__global__ void PushCopyWithPool(float* dest, float** src, - int64_t* len, int slot_num, uint64_t total_len, - int bs, int* slot_vector, int* mf_dim_vector, - size_t grad_value_size, - CommonFeatureValueAccessor feature_value_accessor) { +__global__ void PushCopyWithPool( + float* dest, float** src, int64_t* len, int slot_num, uint64_t total_len, + int bs, int* slot_vector, int* mf_dim_vector, size_t grad_value_size, + CommonFeatureValueAccessor feature_value_accessor) { CUDA_KERNEL_LOOP(i, total_len) { int low = 0; int high = slot_num - 1; @@ -164,22 +174,22 @@ __global__ void PushCopyWithPool(float* dest, float** src, } int x = low; int y = i - (x ? len[low - 1] : 0); - float* cur = - (float*)((char*)dest + i * grad_value_size); + float* cur = (float*)((char*)dest + i * grad_value_size); - cur[feature_value_accessor.common_push_value.SlotIndex()] = + cur[feature_value_accessor.common_push_value.SlotIndex()] = (float)slot_vector[x]; int mf_dim = mf_dim_vector[x]; cur[feature_value_accessor.common_push_value.MfDimIndex()] = mf_dim; - cur[feature_value_accessor.common_push_value.ShowIndex()] = - *(src[x] + y * (mf_dim + 3)); - cur[feature_value_accessor.common_push_value.ClickIndex()] = - *(src[x] + y * (mf_dim + 3) + 1); - cur[feature_value_accessor.common_push_value.EmbedGIndex()] = - *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs; + cur[feature_value_accessor.common_push_value.ShowIndex()] = + *(src[x] + y * (mf_dim + 3)); + cur[feature_value_accessor.common_push_value.ClickIndex()] = + *(src[x] + y * (mf_dim + 3) + 1); + cur[feature_value_accessor.common_push_value.EmbedGIndex()] = + *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs; for (int j = 0; j < mf_dim; j++) { - cur[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs; + cur[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = + *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs; } } } @@ -222,7 +232,7 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, cudaMemcpyHostToDevice); PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>( gpu_values, total_values_gpu, gpu_len, slot_num, total_length, gpu_keys, - val_type_size_, gpu_dim, feature_value_accessor_); + pull_type_size_, gpu_dim, feature_value_accessor_.common_pull_value); cudaStreamSynchronize(stream); } @@ -238,6 +248,103 @@ void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place, cudaStreamSynchronize(stream); } +__global__ void CopyKeysKernel2( + const int total_len, + uint64_t** src_keys, + uint64_t* dest_total_keys, + const int slot_num, + const int64_t* slot_lens, + int* key2slots) { + CUDA_KERNEL_LOOP(i, total_len) { + int low = 0; + int high = slot_num - 1; + while (low < high) { + int mid = (low + high) / 2; + if (i < slot_lens[mid + 1]) { + high = mid; + } else { + low = mid + 1; + } + } + key2slots[i] = low; + int y = i - slot_lens[low]; + dest_total_keys[i] = src_keys[low][y]; + } +} +void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place, + uint64_t** origin_keys, uint64_t* total_keys, + const int64_t* slot_lens, int slot_num, + int total_len, int* key2slot) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + CopyKeysKernel2<<>>( + total_len, origin_keys, total_keys, slot_num, slot_lens, key2slot); + cudaStreamSynchronize(stream); +} +template +__global__ void PullDedupCopy( + const size_t N, const uint64_t* total_keys, float** dest, const float* src, + const int64_t* slot_lens, uint64_t max_val_size, const int* slot_dims, + const int hidden, const int* key2slot, const uint32_t* restore_idx, + TAccess accessor) { + CUDA_KERNEL_LOOP(idx, N) { + int i = idx / hidden; + int off = idx % hidden; + + int x = key2slot[i]; + int y = i - slot_lens[x]; + + assert(slot_dims[x] == hidden); + float* dest_ptr = dest[x] + y * hidden; + // 0 key fill zero + if (total_keys[i] == 0) { + *(dest_ptr + off) = 0; + return; + } + + float* src_ptr = + (float*)((char*)src + + uint64_t(restore_idx[i]) * uint64_t(max_val_size)); + switch (off) { + case 0: + *(dest_ptr + off) = src_ptr[accessor.ShowIndex()]; + break; + case 1: + *(dest_ptr + off) = src_ptr[accessor.ClickIndex()]; + break; + case 2: + *(dest_ptr + off) = src_ptr[accessor.EmbedWIndex()]; + break; + default: + if (src_ptr[accessor.MfSizeIndex()] == 0) { + *(dest_ptr + off) = 0; + } else { + *(dest_ptr + off) = + src_ptr[accessor.EmbedxWIndex() + off - 3]; + } + break; + } + } +} +void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, + const uint64_t* total_keys, float** gpu_values, + const float* total_values_gpu, + const int64_t* slot_lens, const int* key2slot, + const int hidden_size, + const int64_t total_length, + const int* slot_dims, + const uint32_t* gpu_restore_idx) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + size_t N = total_length * hidden_size; + PullDedupCopy<<>>( + N, total_keys, gpu_values, total_values_gpu, slot_lens, pull_type_size_, + slot_dims, hidden_size, key2slot, gpu_restore_idx, + feature_value_accessor_.common_pull_value); + cudaStreamSynchronize(stream); +} void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, const std::vector& grad_values, FeaturePushValue* total_grad_values_gpu, @@ -309,8 +416,184 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice); PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>( total_grad_values_gpu, gpu_values, gpu_len, slot_lengths.size(), - total_length, batch_size, d_slot_vector, d_mf_dim_vector, - grad_value_size, feature_value_accessor_); + total_length, batch_size, d_slot_vector, d_mf_dim_vector, grad_value_size, + feature_value_accessor_); + cudaStreamSynchronize(stream); +} + +template +__global__ void PushMergeCopyAtomic( + const size_t N, const uint64_t* total_keys, float* dest, float** src, + const int hidden, const int bs, const int* slot_vector, + const int* slot_dims, const int64_t* slot_lens, const int* key2slot, + const uint32_t* d_restore_idx, size_t grad_value_size, + TAccess accessor) { + CUDA_KERNEL_LOOP(idx, N) { + int i = idx / hidden; + int off = idx % hidden; + // filter 0 keys + if (total_keys[i] == 0) { + return; + } + + int x = key2slot[i]; + int y = i - slot_lens[x]; + + const float* ptr = src[x] + y * hidden; + float* cur = (float*)((char*)dest + d_restore_idx[i] * grad_value_size); + int mf_dim = slot_dims[x] - 3; + switch (off) { + case 0: + cur[accessor.SlotIndex()] = (float)slot_vector[x]; + cur[accessor.MfDimIndex()] = mf_dim; + paddle::platform::CudaAtomicAdd( + &cur[accessor.ShowIndex()], *(ptr + off)); + break; + case 1: + paddle::platform::CudaAtomicAdd( + &cur[accessor.ClickIndex()], *(ptr + off)); + break; + case 2: + paddle::platform::CudaAtomicAdd( + &cur[accessor.EmbedGIndex()], *(ptr + off) * -1. * bs); + break; + default: + int embedx_idx = off - 3; + if (mf_dim < embedx_idx) { + return; + } + paddle::platform::CudaAtomicAdd( + &cur[accessor.EmbedxGIndex() + embedx_idx], *(ptr + off) * -1. * bs); + break; + } + } +} + +void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, + const uint64_t* total_keys, float** grad_values, + float* total_grad_values_gpu, const int* slots, + const int64_t* slot_lens, const int hidden_size, + const int64_t total_length, + const int64_t dedup_length, const int batch_size, + const int* slot_dims, const int* key2slot, + const uint32_t* d_restore_idx, + const size_t grad_value_size) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + cudaMemsetAsync(total_grad_values_gpu, 0, dedup_length * grad_value_size, + stream); + size_t N = total_length * hidden_size; + PushMergeCopyAtomic<<>>( + N, total_keys, total_grad_values_gpu, grad_values, hidden_size, + batch_size, slots, slot_dims, slot_lens, key2slot, d_restore_idx, + grad_value_size, feature_value_accessor_.common_push_value); + + cudaStreamSynchronize(stream); +} + +#define SUM_GRAD_VALUE \ + for (uint32_t j = 0; j < count; ++j) { \ + const uint32_t& pos = d_sort_idx[start + j]; \ + const int& x = key2slot[pos]; \ + y = pos - slot_lens[x]; \ + val += *(reinterpret_cast(src[x] + y * hidden + off)); \ + } + +template +__global__ void PushMergeCopy( + const size_t N, const uint64_t* total_keys, float* dest, float** src, + const int hidden, const int bs, const int* slot_vector, + const int* slot_dims, const int64_t* slot_lens, const int* key2slot, + const uint32_t* d_sort_idx, + const uint32_t* d_sort_offset, + const uint32_t* d_sort_cnt, size_t grad_value_size, + TAccess accessor) { + CUDA_KERNEL_LOOP(idx, N) { + int i = idx / hidden; + int off = idx % hidden; + // filter 0 keys + float* cur = (float*)((char*)dest + i * grad_value_size); + + if (total_keys[i] == 0) { + switch (off) { + case 0: + cur[accessor.SlotIndex()] = 0; + cur[accessor.MfDimIndex()] = 0; + cur[accessor.ShowIndex()] = 0.0; + break; + case 1: + cur[accessor.ClickIndex()] = 0.0; + break; + case 2: + cur[accessor.EmbedGIndex()] = 0.0; + break; + default: + cur[accessor.EmbedxGIndex() + off - 3] = 0.0; + break; + } + return; + } + + const uint32_t& start = d_sort_offset[i]; + const uint32_t& count = d_sort_cnt[i]; + const uint32_t& pos = d_sort_idx[start]; + + const int& x = key2slot[pos]; + int y = pos - slot_lens[x]; + int mf_dim = slot_dims[x] - 3; + + double val = 0.0; + + switch (off) { + case 0: + cur[accessor.SlotIndex()] = (float)slot_vector[x]; + cur[accessor.MfDimIndex()] = mf_dim; + SUM_GRAD_VALUE + cur[accessor.ShowIndex()] = val; + break; + case 1: + SUM_GRAD_VALUE + cur[accessor.ClickIndex()] = val; + break; + case 2: + SUM_GRAD_VALUE + cur[accessor.EmbedGIndex()] = val * -1. * bs; + break; + default: + int embedx_idx = off - 3; + if (mf_dim < embedx_idx) { + cur[accessor.EmbedxGIndex() + embedx_idx] = 0.0; + return; + } + SUM_GRAD_VALUE + cur[accessor.EmbedxGIndex() + embedx_idx] = val * -1. * bs; + break; + } + } +} + +void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, + const uint64_t* total_keys, float** grad_values, + float* total_grad_values_gpu, const int* slots, + const int64_t* slot_lens, const int hidden_size, + const int64_t total_length, const int64_t dedup_length, + const int batch_size, const int* slot_dims, + const int* key2slot, + const uint32_t* gpu_sort_idx, + const uint32_t* gpu_sort_offset, + const uint32_t* gpu_sort_lens, + const size_t grad_value_size) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + // merge all grad to one + size_t N = dedup_length * hidden_size; + PushMergeCopy<<>>( + N, total_keys, total_grad_values_gpu, grad_values, hidden_size, + batch_size, slots, slot_dims, slot_lens, key2slot, + gpu_sort_idx, gpu_sort_offset, gpu_sort_lens, + grad_value_size, feature_value_accessor_.common_push_value); cudaStreamSynchronize(stream); } @@ -319,20 +602,22 @@ void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff, float learning_rate, float initial_g2sum, float initial_range, float beta1_decay_rate, float beta2_decay_rate, float ada_epsilon) { - optimizer_config_.set_sparse_sgd(nonclk_coeff, clk_coeff, min_bound, max_bound, - learning_rate, initial_g2sum, initial_range, - beta1_decay_rate, beta2_decay_rate, ada_epsilon); + optimizer_config_.set_sparse_sgd(nonclk_coeff, clk_coeff, min_bound, + max_bound, learning_rate, initial_g2sum, + initial_range, beta1_decay_rate, + beta2_decay_rate, ada_epsilon); } void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, float mf_learning_rate, float mf_initial_g2sum, float mf_initial_range, float mf_min_bound, float mf_max_bound, float mf_beta1_decay_rate, - float mf_beta2_decay_rate, float mf_ada_epsilon) { - optimizer_config_.set_embedx_sgd(mf_create_thresholds, mf_learning_rate, - mf_initial_g2sum, mf_initial_range, - mf_min_bound, mf_max_bound, mf_beta1_decay_rate, - mf_beta2_decay_rate, mf_ada_epsilon); + float mf_beta2_decay_rate, + float mf_ada_epsilon) { + optimizer_config_.set_embedx_sgd( + mf_create_thresholds, mf_learning_rate, mf_initial_g2sum, + mf_initial_range, mf_min_bound, mf_max_bound, mf_beta1_decay_rate, + mf_beta2_decay_rate, mf_ada_epsilon); } } // end namespace framework diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 6da628db72455..a8fab77238e4e 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once - #ifdef PADDLE_WITH_HETERPS #include @@ -27,6 +26,7 @@ limitations under the License. */ #include #ifdef PADDLE_WITH_GLOO #include + #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/fleet/gloo_wrapper.h" #endif @@ -50,10 +50,10 @@ limitations under the License. */ #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/place.h" #ifdef PADDLE_WITH_PSCORE -#include "paddle/fluid/distributed/ps/wrapper/fleet.h" +#include "paddle/fluid/distributed/ps.pb.h" #include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h" -#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/ps/wrapper/fleet.h" #endif #ifdef PADDLE_WITH_PSLIB #include "afs_api.h" @@ -96,6 +96,55 @@ class AfsWrapper { #endif class PSGPUWrapper { + class DCacheBuffer { + public: + DCacheBuffer() : buf_(nullptr) {} + ~DCacheBuffer() {} + /** + * @Brief get data + */ + template + T* mutable_data(const size_t total_bytes, + const paddle::platform::Place& place) { + if (buf_ == nullptr) { + buf_ = memory::AllocShared(place, total_bytes); + } else if (buf_->size() < total_bytes) { + buf_.reset(); + buf_ = memory::AllocShared(place, total_bytes); + } + return reinterpret_cast(buf_->ptr()); + } + template + T* data() { + return reinterpret_cast(buf_->ptr()); + } + size_t memory_size() { + if (buf_ == nullptr) { + return 0; + } + return buf_->size(); + } + bool IsInitialized(void) { return (buf_ != nullptr); } + + private: + std::shared_ptr buf_ = nullptr; + }; + struct PSDeviceData { + DCacheBuffer keys_tensor; + DCacheBuffer dims_tensor; + DCacheBuffer keys_ptr_tensor; + DCacheBuffer values_ptr_tensor; + DCacheBuffer pull_push_tensor; + + DCacheBuffer slot_lens; + DCacheBuffer d_slot_vector; + DCacheBuffer keys2slot; + + int64_t total_key_length = 0; + int64_t dedup_key_length = 0; + }; + PSDeviceData* device_caches_ = nullptr; + public: ~PSGPUWrapper(); @@ -130,6 +179,9 @@ class PSGPUWrapper { void CopyKeys(const paddle::platform::Place& place, uint64_t** origin_keys, uint64_t* total_keys, const int64_t* gpu_len, int slot_num, int total_len); + void CopyKeys(const paddle::platform::Place& place, uint64_t** origin_keys, + uint64_t* total_keys, const int64_t* gpu_len, int slot_num, + int total_len, int* key2slot); void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys, const std::vector& values, const FeatureValue* total_values_gpu, const int64_t* gpu_len, @@ -140,6 +192,12 @@ class PSGPUWrapper { const float* total_values_gpu, const int64_t* gpu_len, const int slot_num, const int hidden_size, const int64_t total_length, int* gpu_dim); + void CopyForPull(const paddle::platform::Place& place, + const uint64_t* total_keys, float** gpu_values, + const float* total_values_gpu, const int64_t* slot_lens, + const int* key2slot, const int hidden_size, + const int64_t total_length, const int* slot_dims, + const uint32_t* gpu_restore_idx); void CopyForPush(const paddle::platform::Place& place, const std::vector& grad_values, FeaturePushValue* total_grad_values_gpu, @@ -152,6 +210,25 @@ class PSGPUWrapper { const std::vector& slot_lengths, const uint64_t total_length, const int batch_size, size_t grad_value_size); + void CopyForPush(const paddle::platform::Place& place, + const uint64_t* total_keys, float** grad_values, + float* total_grad_values_gpu, const int* slots, + const int64_t* slot_lens, const int hidden_size, + const int64_t total_length, const int64_t dedup_length, + const int batch_size, const int* slot_dims, + const int* key2slot, const uint32_t* d_restore_idx, + const size_t grad_value_size); + void CopyForPush(const paddle::platform::Place& place, + const uint64_t* total_keys, float** grad_values, + float* total_grad_values_gpu, const int* slots, + const int64_t* slot_lens, const int hidden_size, + const int64_t total_length, const int64_t dedup_length, + const int batch_size, const int* slot_dims, + const int* key2slot, + const uint32_t* gpu_sort_idx, + const uint32_t* gpu_sort_offset, + const uint32_t* gpu_sort_lens, + const size_t grad_value_size); void BuildGPUTask(std::shared_ptr gpu_task); void PreBuildTask(std::shared_ptr gpu_task); @@ -177,6 +254,10 @@ class PSGPUWrapper { s_instance_ = nullptr; VLOG(3) << "PSGPUWrapper Finalize Finished."; HeterPs_->show_table_collisions(); + if (device_caches_ != nullptr) { + delete[] device_caches_; + device_caches_ = nullptr; + } } void InitializeGPU(const std::vector& dev_ids) { @@ -186,6 +267,7 @@ class PSGPUWrapper { resource_ = std::make_shared(dev_ids); resource_->enable_p2p(); keys_tensor.resize(resource_->total_device()); + device_caches_ = new PSDeviceData[resource_->total_device()]; #ifdef PADDLE_WITH_GLOO auto gloo = paddle::framework::GlooWrapper::GetInstance(); if (gloo->Size() > 1) { @@ -259,7 +341,7 @@ class PSGPUWrapper { float mf_min_bound, float mf_max_bound, float mf_beta1_decay_rate, float mf_beta2_decay_rate, float mf_ada_epsilon); - + #ifdef PADDLE_WITH_PSCORE void add_sparse_optimizer( std::unordered_map& config, // NOLINT @@ -309,7 +391,7 @@ class PSGPUWrapper { void InitializeGPUServer(paddle::distributed::PSParameter ps_param) { auto sparse_table = - ps_param.server_param().downpour_server_param().downpour_table_param(0); + ps_param.server_param().downpour_server_param().downpour_table_param(0); auto sparse_table_accessor = sparse_table.accessor(); auto sparse_table_accessor_parameter = sparse_table_accessor.ctr_accessor_param(); @@ -325,13 +407,13 @@ class PSGPUWrapper { // optimizer config for embed_w and embedx add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param()); add_sparse_optimizer(config, sparse_table_accessor.embedx_sgd_param(), - "mf_"); + "mf_"); } feature_value_accessor_.Configure(config); InitializeGPUServer(config); } - #endif +#endif void InitializeGPUServer(std::unordered_map config) { float nonclk_coeff = (config.find("nonclk_coeff") == config.end()) @@ -342,9 +424,8 @@ class PSGPUWrapper { float min_bound = (config.find("min_bound") == config.end()) ? -10.0 : config["min_bound"]; - float max_bound = (config.find("max_bound") == config.end()) - ? 10.0 - : config["max_bound"]; + float max_bound = + (config.find("max_bound") == config.end()) ? 10.0 : config["max_bound"]; float learning_rate = (config.find("learning_rate") == config.end()) ? 0.05 : config["learning_rate"]; @@ -361,8 +442,8 @@ class PSGPUWrapper { ? 0.999 : config["beta2_decay_rate"]; float ada_epsilon = (config.find("ada_epsilon") == config.end()) - ? 1e-8 - : config["ada_epsilon"]; + ? 1e-8 + : config["ada_epsilon"]; // mf config settings float mf_create_thresholds = (config.find("mf_create_thresholds") == config.end()) @@ -383,35 +464,37 @@ class PSGPUWrapper { float mf_max_bound = (config.find("mf_max_bound") == config.end()) ? 10.0 : config["mf_max_bound"]; - float mf_beta1_decay_rate = (config.find("mf_beta1_decay_rate") == config.end()) - ? 0.9 - : config["mf_beta1_decay_rate"]; - float mf_beta2_decay_rate = (config.find("mf_beta2_decay_rate") == config.end()) - ? 0.999 - : config["mf_beta2_decay_rate"]; + float mf_beta1_decay_rate = + (config.find("mf_beta1_decay_rate") == config.end()) + ? 0.9 + : config["mf_beta1_decay_rate"]; + float mf_beta2_decay_rate = + (config.find("mf_beta2_decay_rate") == config.end()) + ? 0.999 + : config["mf_beta2_decay_rate"]; float mf_ada_epsilon = (config.find("mf_ada_epsilon") == config.end()) - ? 1e-8 - : config["mf_ada_epsilon"]; + ? 1e-8 + : config["mf_ada_epsilon"]; this->SetSparseSGD(nonclk_coeff, clk_coeff, min_bound, max_bound, - learning_rate, initial_g2sum, initial_range, - beta1_decay_rate, beta2_decay_rate, ada_epsilon); - this->SetEmbedxSGD(mf_create_thresholds, mf_learning_rate, - mf_initial_g2sum, mf_initial_range, mf_min_bound, - mf_max_bound, mf_beta1_decay_rate, mf_beta2_decay_rate, - mf_ada_epsilon); + learning_rate, initial_g2sum, initial_range, + beta1_decay_rate, beta2_decay_rate, ada_epsilon); + this->SetEmbedxSGD(mf_create_thresholds, mf_learning_rate, mf_initial_g2sum, + mf_initial_range, mf_min_bound, mf_max_bound, + mf_beta1_decay_rate, mf_beta2_decay_rate, + mf_ada_epsilon); // set optimizer type(naive,adagrad,std_adagrad,adam,share_adam) optimizer_type_ = (config.find("optimizer_type") == config.end()) - ? 1 - : int(config["optimizer_type"]); + ? 1 + : int(config["optimizer_type"]); embedx_dim_ = (config.find("embedx_dim") == config.end()) - ? 8 - : int(config["embedx_dim"]); - if (optimizer_type_ == 3) { //adam + ? 8 + : int(config["embedx_dim"]); + if (optimizer_type_ == 3) { // adam embed_sgd_dim_ = 4; embedx_sgd_dim_ = embedx_dim_ * 2 + 2; - } else if (optimizer_type_ == 4) { //shared_adam + } else if (optimizer_type_ == 4) { // shared_adam embed_sgd_dim_ = 4; embedx_sgd_dim_ = 4; } else { @@ -419,8 +502,9 @@ class PSGPUWrapper { embedx_sgd_dim_ = 1; } - VLOG(0) << "InitializeGPUServer embed_sgd_dim_:" << embed_sgd_dim_ << " embedx_sgd_dim_:" - << embedx_sgd_dim_ << " embedx_dim_:" << embedx_dim_ + VLOG(0) << "InitializeGPUServer embed_sgd_dim_:" << embed_sgd_dim_ + << " embedx_sgd_dim_:" << embedx_sgd_dim_ + << " embedx_dim_:" << embedx_dim_ << " optimizer_type_:" << optimizer_type_; } @@ -507,9 +591,14 @@ class PSGPUWrapper { for (size_t i = 0; i < slot_index_vec_.size(); i++) { slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]]; } - val_type_size_ = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); - grad_type_size_ = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); - VLOG(0) << "InitSlotInfo: val_type_size_" << val_type_size_ << " grad_type_size_:" << grad_type_size_; + val_type_size_ = TYPEALIGN( + 8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); + grad_type_size_ = TYPEALIGN( + 8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + pull_type_size_ = feature_value_accessor_.common_pull_value.Size(max_mf_dim_); + VLOG(0) << "InitSlotInfo: val_type_size_" << val_type_size_ + << " grad_type_size_:" << grad_type_size_ + << " pull_type_size_:" << pull_type_size_; slot_info_initialized_ = true; } #endif @@ -530,11 +619,13 @@ class PSGPUWrapper { #ifdef PADDLE_WITH_PSCORE void SetTableAccessor(paddle::distributed::ValueAccessor* accessor) { - cpu_table_accessor_ = dynamic_cast(accessor); + cpu_table_accessor_ = + dynamic_cast(accessor); } #endif CommonFeatureValueAccessor feature_value_accessor_; + private: static std::shared_ptr s_instance_; Dataset* dataset_; @@ -558,6 +649,7 @@ class PSGPUWrapper { int max_mf_dim_{0}; size_t val_type_size_{0}; size_t grad_type_size_{0}; + size_t pull_type_size_{0}; double time_1 = 0.0; double time_2 = 0.0; @@ -619,6 +711,7 @@ class PSGPUWrapper { std::vector> pull_thread_pool_; std::vector> hbm_thread_pool_; OptimizerConfig optimizer_config_; + protected: static bool is_initialized_; }; diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 33198c11cc2af..c544a426d11f3 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -866,6 +866,9 @@ PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_uint64( gpugraph_merge_grads_segment_size, 128, "segment size with segment gradient merge, default 128"); +PADDLE_DEFINE_EXPORTED_int32( + gpugraph_dedup_pull_push_mode, 0, + "enable dedup keys while pull push sparse, default 0"); /** * ProcessGroupNCCL related FLAG