diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 6364a3f7bf4af9..f211e15b13e285 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -22,6 +22,7 @@ limitations under the License. */ #endif DECLARE_double(gpugraph_hbm_table_load_factor); +DECLARE_bool(gpugraph_enable_gpu_direct_access); namespace paddle { namespace framework { @@ -682,7 +683,7 @@ void HeterComm::dynamic_merge_grad( uniq_len, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); heter_comm_kernel_->merge_gradient( - d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads, + d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads, (char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr, @@ -802,7 +803,7 @@ void HeterComm::pull_sparse(int num, memory_copy(dst_place, h_right, src_place, d_right_ptr, total_device * sizeof(int), stream); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { int shard_len = h_right[i] - h_left[i] + 1; if (h_left[i] == -1 || h_right[i] == -1) { @@ -818,12 +819,12 @@ void HeterComm::pull_sparse(int num, continue; } auto& node = path_[num][i].nodes_.back(); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { sync_stream(node.in_stream); } AnyDeviceGuard guard(resource_->dev_id(i)); ptr_tables_[i]->rwlock_->RDLock(); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { ptr_tables_[i]->get(reinterpret_cast(node.key_storage), node.val_storage, h_right[i] - h_left[i] + 1, resource_->remote_stream(i, num)); @@ -842,7 +843,7 @@ void HeterComm::pull_sparse(int num, } ptr_tables_[i]->rwlock_->UNLock(); } - if (!direct_access_) { + if 
(!FLAGS_gpugraph_enable_gpu_direct_access) { walk_to_src(num, total_device, h_left, h_right, reinterpret_cast(d_shard_vals_ptr), val_type_size); for (int i = 0; i < total_device; ++i) { @@ -855,7 +856,7 @@ void HeterComm::pull_sparse(int num, val_type_size, stream); sync_stream(stream); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { continue; @@ -946,7 +947,7 @@ void HeterComm::push_sparse(int dev_num, memory_copy(dst_place, h_right, src_place, d_right_ptr, total_device * sizeof(int), stream); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { int shard_len = h_right[i] - h_left[i] + 1; if (h_left[i] == -1 || h_right[i] == -1) { @@ -965,13 +966,13 @@ void HeterComm::push_sparse(int dev_num, continue; } auto& node = path_[dev_num][i].nodes_.back(); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { sync_stream(node.in_stream); } AnyDeviceGuard guard(resource_->dev_id(i)); ptr_tables_[i]->rwlock_->WRLock(); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { ptr_tables_[i]->update(reinterpret_cast(node.key_storage), node.val_storage, h_right[i] - h_left[i] + 1, sgd, resource_->remote_stream(i, dev_num)); @@ -995,7 +996,7 @@ void HeterComm::push_sparse(int dev_num, } } - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { continue; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 39a1469d9073c6..38566da3990cc0 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -146,7 +146,9 @@ __global__ void dy_mf_fill_shard_grads_kernel( } } -__global__ void 
merge_gradients_kernel(const uint32_t* offset, +template +__global__ void merge_gradients_kernel(const KeyType* d_keys, + const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, const char* input, char* output, int n, @@ -163,10 +165,13 @@ __global__ void merge_gradients_kernel(const uint32_t* offset, float* in = (float*)(input + size_t(ori_index) * grad_value_size); merger_.update_one(out, in, feature_value_accessor); - for (int j = 1; j < num; ++j) { - ori_index = index[start + j]; - in = (float*)(input + size_t(ori_index) * grad_value_size); - merger_.merge_one(out, in, feature_value_accessor); + KeyType key = d_keys[i]; + if (key != 0) { + for (int j = 1; j < num; ++j) { + ori_index = index[start + j]; + in = (float*)(input + size_t(ori_index) * grad_value_size); + merger_.merge_one(out, in, feature_value_accessor); + } } } } @@ -316,13 +321,15 @@ void HeterCommKernel::dy_mf_fill_shard_grads( grad_value_size, feature_value_accessor_); } -template +template void HeterCommKernel::merge_gradient( + const KeyType* d_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, const char* input, char* output, int n, size_t grad_value_size, DynamicGradMerger& merger_, const StreamType& stream) { int grid_size = (n - 1) / block_size_ + 1; merge_gradients_kernel<<>>( + d_keys, offset, fea_num, index, input, output, n, grad_value_size, merger_, feature_value_accessor_); } @@ -407,7 +414,14 @@ template void HeterCommKernel::dy_mf_fill_shard_grads< float* d_shard_grads, float* d_grads, int* idx, long long len, size_t grad_value_size, const cudaStream_t& stream); -template void HeterCommKernel::merge_gradient( +template void HeterCommKernel::merge_gradient( + const uint32_t* d_keys, + const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, + const char* input, char* output, int n, size_t grad_value_size, + DynamicGradMerger& merger_, const cudaStream_t& stream); + +template void HeterCommKernel::merge_gradient( + const 
uint64_t* d_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, const char* input, char* output, int n, size_t grad_value_size, DynamicGradMerger& merger_, const cudaStream_t& stream); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index 4bf77aaac202d5..d02031f9e7e285 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -135,8 +135,8 @@ class HeterCommKernel { T* idx, long long len, size_t grad_value_size, const StreamType& stream); - template - void merge_gradient(const uint32_t* offset, const uint32_t* fea_num, + template + void merge_gradient(const KeyType* d_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, const char* input, char* output, int n, size_t grad_value_size, DynamicGradMerger& merger_, const StreamType& stream); diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 71daee503e49d5..b165a678f8f938 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -857,6 +857,9 @@ PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_double( gpugraph_hbm_table_load_factor, 0.75, "the load factor of hbm table, default 0.75"); +PADDLE_DEFINE_EXPORTED_bool( + gpugraph_enable_gpu_direct_access, false, + "enable direct gpu access to hbm tables in pull_sparse/push_sparse, default false"); /** * ProcessGroupNCCL related FLAG