Skip to content

Commit

Permalink
Gpugraph.0621 (#45)
Browse files Browse the repository at this point in the history
* Optimizing the zero key problem in the push phase

* Optimize CUDA thread parallelism in MergeGrad phase

* Optimize CUDA thread parallelism in MergeGrad phase

Co-authored-by: root <[email protected]>
  • Loading branch information
lxsbupt and root authored Jun 24, 2022
1 parent fdf59b6 commit 79675fd
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -841,12 +841,14 @@ x.second );

__host__ void print_collision(int id) {
if (m_enable_collision_stat) {
printf("collision stat for hbm table %d, insert(%lu:%lu), query(%lu:%lu)\n",
printf("collision stat for hbm table %d, insert(%lu:%lu:%.2f), query(%lu:%lu:%.2f)\n",
id,
m_insert_times,
m_insert_collisions,
m_insert_collisions / (double)m_insert_times,
m_query_times,
m_query_collisions);
m_query_collisions,
m_query_collisions / (double)m_query_times);
}
}

Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/feature_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct GpuAccessorInfo {
size_t dim;
// value各个维度的size
size_t size;
// embedx维度
size_t embedx_dim;
// push value维度
size_t update_dim;
// push value各个维度的size
Expand Down Expand Up @@ -192,6 +194,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor {
? 8
: int(_config["embedx_dim"]);
// VLOG(0) << "feature value InitAccessorInfo embedx_dim:" << embedx_dim;
_accessor_info.embedx_dim = embedx_dim;
_accessor_info.update_dim = 5 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size =
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ class HeterComm {
std::vector<std::vector<Path>> path_;
float load_factor_{0.75};
int block_size_{256};
int direct_access_ = 1;
std::unique_ptr<HeterCommKernel> heter_comm_kernel_;

private:
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(

size_t temp_storage_bytes;

size_t grad_dim = feature_value_accessor_.GetAccessorInfo().embedx_dim;
size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size);

auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
Expand Down Expand Up @@ -687,7 +688,7 @@ void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
heter_comm_kernel_->merge_gradient(
d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
(char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, stream);
(char*)d_merge_grads_ptr, uniq_len, grad_dim, grad_value_size, merger_, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr,
grad_value_size * uniq_len,
Expand Down
60 changes: 49 additions & 11 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ __global__ void dy_mf_fill_shard_grads_kernel(
}

template <typename KeyType>
__global__ void merge_gradients_kernel(const KeyType* d_keys,
__global__ void merge_gradients_basic_kernel(const KeyType* d_keys,
const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index, const char* input,
char* output, int n,
size_t grad_value_size,
DynamicGradMerger& merger_,
DynamicGradMerger& merger,
CommonFeatureValueAccessor& feature_value_accessor) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

Expand All @@ -164,13 +164,45 @@ __global__ void merge_gradients_kernel(const KeyType* d_keys,
float* out = (float*)(output + i * grad_value_size);
float* in =
(float*)(input + size_t(ori_index) * grad_value_size);
merger_.update_one(out, in, feature_value_accessor);
merger.update_basic(out, in, feature_value_accessor);
KeyType key = d_keys[i];
if (key != 0) {
for (int j = 1; j < num; ++j) {
ori_index = index[start + j];
in = (float*)(input + size_t(ori_index) * grad_value_size);
merger_.merge_one(out, in, feature_value_accessor);
merger.merge_basic(out, in, feature_value_accessor);
}
}
}
}

template <typename KeyType>
__global__ void merge_gradients_embedx_kernel(const KeyType* d_keys,
const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index, const char* input,
char* output, int n,
size_t grad_dim,
size_t grad_value_size,
DynamicGradMerger& merger,
CommonFeatureValueAccessor& feature_value_accessor) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

if (i < n) {
size_t value_idx = i / grad_dim;
size_t field_idx = i % grad_dim;
uint32_t start = offset[value_idx];
uint32_t num = fea_num[value_idx];
int ori_index = index[start];
float* in = (float*)(input + size_t(ori_index) * grad_value_size);
float* out = (float*)(output + value_idx * grad_value_size);
merger.update_embedx(out, in, field_idx, feature_value_accessor);
KeyType key = d_keys[value_idx];
if (key != 0) {
for (int j = 1; j < num; ++j) {
int ori_index = index[start + j];
float* in = (float*)(input + size_t(ori_index) * grad_value_size);
merger.merge_embedx(out, in, field_idx, feature_value_accessor);
}
}
}
Expand Down Expand Up @@ -325,12 +357,18 @@ template <typename KeyType, typename StreamType>
void HeterCommKernel::merge_gradient(
const KeyType* d_keys,
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
DynamicGradMerger& merger_, const StreamType& stream) {
int grid_size = (n - 1) / block_size_ + 1;
merge_gradients_kernel<<<grid_size, block_size_, 0, stream>>>(
const char* input, char* output, int n, size_t grad_dim, size_t grad_value_size,
DynamicGradMerger& merger, const StreamType& stream) {
int grid_size1 = (n - 1) / block_size_ + 1;
merge_gradients_basic_kernel<<<grid_size1, block_size_, 0, stream>>>(
d_keys,
offset, fea_num, index, input, output, n, grad_value_size, merger_, feature_value_accessor_);
offset, fea_num, index, input, output, n, grad_value_size, merger, feature_value_accessor_);
if (grad_dim > 0) {
int grid_size2 = (n * grad_dim - 1) / block_size_ + 1;
merge_gradients_embedx_kernel<<<grid_size2, block_size_, 0, stream>>>(
d_keys,
offset, fea_num, index, input, output, n * grad_dim, grad_dim, grad_value_size, merger, feature_value_accessor_);
}
}

template <typename T, typename StreamType>
Expand Down Expand Up @@ -417,13 +455,13 @@ template void HeterCommKernel::dy_mf_fill_shard_grads<
template void HeterCommKernel::merge_gradient<uint32_t, cudaStream_t>(
const uint32_t* d_keys,
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
const char* input, char* output, int n, size_t grad_dim, size_t grad_value_size,
DynamicGradMerger& merger_, const cudaStream_t& stream);

template void HeterCommKernel::merge_gradient<uint64_t, cudaStream_t>(
const uint64_t* d_keys,
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
const char* input, char* output, int n, size_t grad_dim, size_t grad_value_size,
DynamicGradMerger& merger_, const cudaStream_t& stream);

template void HeterCommKernel::dy_mf_fill_dvals<int, cudaStream_t>(
Expand Down
42 changes: 41 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,46 @@ struct DynamicGradMerger {
input[feature_value_accessor.common_push_value.EmbedxGIndex() + j];
}
}

__device__ __forceinline__ void update_basic(float* output, const float* input,
CommonFeatureValueAccessor& fv_accessor) {
output[fv_accessor.common_push_value.SlotIndex()] =
input[fv_accessor.common_push_value.SlotIndex()];
output[fv_accessor.common_push_value.ShowIndex()] =
input[fv_accessor.common_push_value.ShowIndex()];
output[fv_accessor.common_push_value.ClickIndex()] =
input[fv_accessor.common_push_value.ClickIndex()];
output[fv_accessor.common_push_value.MfDimIndex()] =
input[fv_accessor.common_push_value.MfDimIndex()];
output[fv_accessor.common_push_value.EmbedGIndex()] =
input[fv_accessor.common_push_value.EmbedGIndex()];
}

__device__ __forceinline__ void merge_basic(float* output, const float* input,
CommonFeatureValueAccessor& fv_accessor) {
output[fv_accessor.common_push_value.ShowIndex()] +=
input[fv_accessor.common_push_value.ShowIndex()];
output[fv_accessor.common_push_value.ClickIndex()] +=
input[fv_accessor.common_push_value.ClickIndex()];
output[fv_accessor.common_push_value.EmbedGIndex()] +=
input[fv_accessor.common_push_value.EmbedGIndex()];
}

__device__ __forceinline__ void update_embedx(float* output, const float* input, size_t embedx_idx,
CommonFeatureValueAccessor& fv_accessor) {
if (embedx_idx < output[fv_accessor.common_push_value.MfDimIndex()]) {
output[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx] =
input[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx];
}
}

__device__ __forceinline__ void merge_embedx(float* output, const float* input, size_t embedx_idx,
CommonFeatureValueAccessor& fv_accessor) {
if (embedx_idx < output[fv_accessor.common_push_value.MfDimIndex()]) {
output[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx] +=
input[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx];
}
}
};

class HeterCommKernel {
Expand Down Expand Up @@ -138,7 +178,7 @@ class HeterCommKernel {
template <typename KeyType, typename StreamType>
void merge_gradient(const KeyType* d_shard_keys, const uint32_t* offset, const uint32_t* fea_num,
const uint32_t* index, const char* input, char* output,
int n, size_t grad_value_size, DynamicGradMerger& merger_,
int n, size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger,
const StreamType& stream);

template <typename T, typename StreamType>
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ PADDLE_DEFINE_EXPORTED_double(
"the load factor of hbm table, default 0.75");
PADDLE_DEFINE_EXPORTED_bool(
gpugraph_enable_gpu_direct_access, false,
"enable hash collisions stat for hbm table, default false");
"enable direct access bwtween multi gpu cards, default false");

/**
* ProcessGroupNCCL related FLAG
Expand Down

0 comments on commit 79675fd

Please sign in to comment.