Skip to content

Commit

Permalink
Optimizing the zero key problem in the push phase (#40)
Browse files Browse the repository at this point in the history
Co-authored-by: root <[email protected]>
  • Loading branch information
lxsbupt and root authored Jun 21, 2022
1 parent 33ba59a commit 4179390
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 19 deletions.
21 changes: 11 additions & 10 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#endif

DECLARE_double(gpugraph_hbm_table_load_factor);
DECLARE_bool(gpugraph_enable_gpu_direct_access);

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -682,7 +683,7 @@ void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
uniq_len, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
heter_comm_kernel_->merge_gradient(
d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
(char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr,
Expand Down Expand Up @@ -802,7 +803,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
memory_copy(dst_place, h_right, src_place, d_right_ptr,
total_device * sizeof(int), stream);

if (!direct_access_) {
if (!FLAGS_gpugraph_enable_gpu_direct_access) {
for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
Expand All @@ -818,12 +819,12 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
continue;
}
auto& node = path_[num][i].nodes_.back();
if (!direct_access_) {
if (!FLAGS_gpugraph_enable_gpu_direct_access) {
sync_stream(node.in_stream);
}
AnyDeviceGuard guard(resource_->dev_id(i));
ptr_tables_[i]->rwlock_->RDLock();
if (!direct_access_) {
if (!FLAGS_gpugraph_enable_gpu_direct_access) {
ptr_tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, num));
Expand All @@ -842,7 +843,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
}
ptr_tables_[i]->rwlock_->UNLock();
}
if (!direct_access_) {
if (!FLAGS_gpugraph_enable_gpu_direct_access) {
walk_to_src(num, total_device, h_left, h_right,
reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);
for (int i = 0; i < total_device; ++i) {
Expand All @@ -855,7 +856,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
val_type_size, stream);

sync_stream(stream);
if (!direct_access_) {
if (!FLAGS_gpugraph_enable_gpu_direct_access) {
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
Expand Down Expand Up @@ -946,7 +947,7 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
memory_copy(dst_place, h_right, src_place, d_right_ptr,
total_device * sizeof(int), stream);

if (!direct_access_) {
if (!FLAGS_gpugraph_enable_gpu_direct_access) {
for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
Expand All @@ -965,13 +966,13 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
continue;
}
auto& node = path_[dev_num][i].nodes_.back();
if (!direct_access_) {
if (!FLAGS_gpugraph_enable_gpu_direct_access) {
sync_stream(node.in_stream);
}

AnyDeviceGuard guard(resource_->dev_id(i));
ptr_tables_[i]->rwlock_->WRLock();
if (!direct_access_) {
if (!FLAGS_gpugraph_enable_gpu_direct_access) {
ptr_tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1,
sgd, resource_->remote_stream(i, dev_num));
Expand All @@ -995,7 +996,7 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
}
}

if (!direct_access_) {
if (!FLAGS_gpugraph_enable_gpu_direct_access) {
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
Expand Down
28 changes: 21 additions & 7 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ __global__ void dy_mf_fill_shard_grads_kernel(
}
}

__global__ void merge_gradients_kernel(const uint32_t* offset,
template <typename KeyType>
__global__ void merge_gradients_kernel(const KeyType* d_keys,
const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index, const char* input,
char* output, int n,
Expand All @@ -163,10 +165,13 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
float* in =
(float*)(input + size_t(ori_index) * grad_value_size);
merger_.update_one(out, in, feature_value_accessor);
for (int j = 1; j < num; ++j) {
ori_index = index[start + j];
in = (float*)(input + size_t(ori_index) * grad_value_size);
merger_.merge_one(out, in, feature_value_accessor);
KeyType key = d_keys[i];
if (key != 0) {
for (int j = 1; j < num; ++j) {
ori_index = index[start + j];
in = (float*)(input + size_t(ori_index) * grad_value_size);
merger_.merge_one(out, in, feature_value_accessor);
}
}
}
}
Expand Down Expand Up @@ -316,13 +321,15 @@ void HeterCommKernel::dy_mf_fill_shard_grads(
grad_value_size, feature_value_accessor_);
}

template <typename StreamType>
template <typename KeyType, typename StreamType>
void HeterCommKernel::merge_gradient(
const KeyType* d_keys,
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
DynamicGradMerger& merger_, const StreamType& stream) {
int grid_size = (n - 1) / block_size_ + 1;
merge_gradients_kernel<<<grid_size, block_size_, 0, stream>>>(
d_keys,
offset, fea_num, index, input, output, n, grad_value_size, merger_, feature_value_accessor_);
}

Expand Down Expand Up @@ -407,7 +414,14 @@ template void HeterCommKernel::dy_mf_fill_shard_grads<
float* d_shard_grads, float* d_grads, int* idx, long long len,
size_t grad_value_size, const cudaStream_t& stream);

template void HeterCommKernel::merge_gradient<cudaStream_t>(
template void HeterCommKernel::merge_gradient<uint32_t, cudaStream_t>(
const uint32_t* d_keys,
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
DynamicGradMerger& merger_, const cudaStream_t& stream);

template void HeterCommKernel::merge_gradient<uint64_t, cudaStream_t>(
const uint64_t* d_keys,
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
DynamicGradMerger& merger_, const cudaStream_t& stream);
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ class HeterCommKernel {
T* idx, long long len, size_t grad_value_size,
const StreamType& stream);

template <typename StreamType>
void merge_gradient(const uint32_t* offset, const uint32_t* fea_num,
template <typename KeyType, typename StreamType>
void merge_gradient(const KeyType* d_shard_keys, const uint32_t* offset, const uint32_t* fea_num,
const uint32_t* index, const char* input, char* output,
int n, size_t grad_value_size, DynamicGradMerger& merger_,
const StreamType& stream);
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,9 @@ PADDLE_DEFINE_EXPORTED_bool(
PADDLE_DEFINE_EXPORTED_double(
gpugraph_hbm_table_load_factor, 0.75,
"the load factor of hbm table, default 0.75");
PADDLE_DEFINE_EXPORTED_bool(
gpugraph_enable_gpu_direct_access, false,
"enable hash collisions stat for hbm table, default false");

/**
* ProcessGroupNCCL related FLAG
Expand Down

0 comments on commit 4179390

Please sign in to comment.