
Commit

get_features: int64_t -> uint64_t (xuewujiao#38)
* get_features: int64_t -> uint64_t

* remove useless var

Co-authored-by: root <[email protected]>
huwei02 and root authored Jun 21, 2022
1 parent 0e37d85 commit 359d32f
Showing 6 changed files with 53 additions and 58 deletions.
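
Why unsigned: the graph tables are keyed on uint64_t (see GpuPsGraphTable : HeterComm<uint64_t, int64_t, int> below), so a node key with its top bit set would turn negative if the id and feature buffers stayed typed as int64_t. A minimal host-side sketch of the difference (illustration only, not part of the commit; the key value is made up):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical node key with the most significant bit set.
  uint64_t node_key = 0x9F00000000000001ULL;
  // The same bit pattern reinterpreted as a signed 64-bit value is negative
  // on the usual two's-complement targets.
  int64_t as_signed = static_cast<int64_t>(node_key);
  printf("as uint64_t: %llu\n", (unsigned long long)node_key);  // large positive id
  printf("as int64_t:  %lld\n", (long long)as_signed);          // negative value
  // The cudaMemsetAsync(..., -1, ...) "not found" marker used below is the
  // all-ones pattern: -1 as int64_t, 0xFFFFFFFFFFFFFFFF as uint64_t.
  uint64_t not_found = ~0ULL;
  printf("sentinel:    0x%llx\n", (unsigned long long)not_found);
  return 0;
}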
76 changes: 38 additions & 38 deletions paddle/fluid/framework/data_feed.cu
@@ -181,8 +181,8 @@ int GraphDataGenerator::AcquireInstance(BufState *state) {
}

// TODO opt
-__global__ void GraphFillFeatureKernel(int64_t *id_tensor, int *fill_ins_num,
-int64_t *walk, int64_t *feature,
+__global__ void GraphFillFeatureKernel(uint64_t *id_tensor, int *fill_ins_num,
+uint64_t *walk, uint64_t *feature,
int *row, int central_word, int step,
int len, int col_num, int slot_num) {
__shared__ int32_t local_key[CUDA_NUM_THREADS * 16];
@@ -221,10 +221,10 @@ __global__ void GraphFillFeatureKernel(int64_t *id_tensor, int *fill_ins_num,
}
}

-__global__ void GraphFillIdKernel(int64_t *id_tensor, int *fill_ins_num,
-int64_t *walk, int *row, int central_word,
+__global__ void GraphFillIdKernel(uint64_t *id_tensor, int *fill_ins_num,
+uint64_t *walk, int *row, int central_word,
int step, int len, int col_num) {
-__shared__ int64_t local_key[CUDA_NUM_THREADS * 2];
+__shared__ uint64_t local_key[CUDA_NUM_THREADS * 2];
__shared__ int local_num;
__shared__ int global_num;

@@ -257,22 +257,22 @@ __global__ void GraphFillIdKernel(int64_t *id_tensor, int *fill_ins_num,
}
}

-__global__ void GraphFillSlotKernel(int64_t *id_tensor, int64_t *feature_buf,
+__global__ void GraphFillSlotKernel(uint64_t *id_tensor, uint64_t *feature_buf,
int len, int total_ins, int slot_num) {
CUDA_KERNEL_LOOP(idx, len) {
int slot_idx = idx / total_ins;
int ins_idx = idx % total_ins;
-((int64_t *)(id_tensor[slot_idx]))[ins_idx] =
+((uint64_t *)(id_tensor[slot_idx]))[ins_idx] =
feature_buf[ins_idx * slot_num + slot_idx];
}
}

-__global__ void GraphFillSlotLodKernelOpt(int64_t *id_tensor, int len,
+__global__ void GraphFillSlotLodKernelOpt(uint64_t *id_tensor, int len,
int total_ins) {
CUDA_KERNEL_LOOP(idx, len) {
int slot_idx = idx / total_ins;
int ins_idx = idx % total_ins;
-((int64_t *)(id_tensor[slot_idx]))[ins_idx] = ins_idx;
+((uint64_t *)(id_tensor[slot_idx]))[ins_idx] = ins_idx;
}
}

@@ -327,8 +327,8 @@ int GraphDataGenerator::FillInsBuf() {
}
}

-int64_t *walk = reinterpret_cast<int64_t *>(d_walk_->ptr());
-int64_t *ins_buf = reinterpret_cast<int64_t *>(d_ins_buf_->ptr());
+uint64_t *walk = reinterpret_cast<uint64_t *>(d_walk_->ptr());
+uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
int *random_row = reinterpret_cast<int *>(d_random_row_->ptr());
int *d_pair_num = reinterpret_cast<int *>(d_pair_num_->ptr());
cudaMemsetAsync(d_pair_num, 0, sizeof(int), stream_);
@@ -342,8 +342,8 @@ int GraphDataGenerator::FillInsBuf() {
stream_);

if (!FLAGS_enable_opt_get_features && slot_num_ > 0) {
-int64_t *feature_buf = reinterpret_cast<int64_t *>(d_feature_buf_->ptr());
-int64_t *feature = reinterpret_cast<int64_t *>(d_feature_->ptr());
+uint64_t *feature_buf = reinterpret_cast<uint64_t *>(d_feature_buf_->ptr());
+uint64_t *feature = reinterpret_cast<uint64_t *>(d_feature_->ptr());
cudaMemsetAsync(d_pair_num, 0, sizeof(int), stream_);
int len = buf_state_.len;
VLOG(2) << "feature_buf start[" << ins_buf_pair_len_ * 2 * slot_num_
@@ -358,8 +358,8 @@ int GraphDataGenerator::FillInsBuf() {
ins_buf_pair_len_ += h_pair_num;

if (debug_mode_) {
-int64_t *h_ins_buf = new int64_t[ins_buf_pair_len_ * 2];
-cudaMemcpy(h_ins_buf, ins_buf, 2 * ins_buf_pair_len_ * sizeof(int64_t),
+uint64_t h_ins_buf[ins_buf_pair_len_ * 2];
+cudaMemcpy(h_ins_buf, ins_buf, 2 * ins_buf_pair_len_ * sizeof(uint64_t),
cudaMemcpyDeviceToHost);
VLOG(2) << "h_pair_num = " << h_pair_num
<< ", ins_buf_pair_len = " << ins_buf_pair_len_;
@@ -369,10 +369,10 @@ int GraphDataGenerator::FillInsBuf() {
delete[] h_ins_buf;

if (!FLAGS_enable_opt_get_features && slot_num_ > 0) {
-int64_t *feature_buf = reinterpret_cast<int64_t *>(d_feature_buf_->ptr());
-int64_t h_feature_buf[(batch_size_ * 2 * 2) * slot_num_];
+uint64_t *feature_buf = reinterpret_cast<uint64_t *>(d_feature_buf_->ptr());
+uint64_t h_feature_buf[(batch_size_ * 2 * 2) * slot_num_];
cudaMemcpy(h_feature_buf, feature_buf,
-(batch_size_ * 2 * 2) * slot_num_ * sizeof(int64_t),
+(batch_size_ * 2 * 2) * slot_num_ * sizeof(uint64_t),
cudaMemcpyDeviceToHost);
for (int xx = 0; xx < (batch_size_ * 2 * 2) * slot_num_; xx++) {
VLOG(2) << "h_feature_buf[" << xx << "]: " << h_feature_buf[xx];
@@ -454,19 +454,19 @@ int GraphDataGenerator::GenerateBatch() {
}
if (FLAGS_enable_opt_get_features) {
cudaMemcpyAsync(d_slot_tensor_ptr_->ptr(), slot_tensor_ptr_,
-sizeof(int64_t *) * slot_num_, cudaMemcpyHostToDevice,
+sizeof(uint64_t *) * slot_num_, cudaMemcpyHostToDevice,
stream_);
cudaMemcpyAsync(d_slot_lod_tensor_ptr_->ptr(), slot_lod_tensor_ptr_,
-sizeof(int64_t *) * slot_num_, cudaMemcpyHostToDevice,
+sizeof(uint64_t *) * slot_num_, cudaMemcpyHostToDevice,
stream_);
}
}

VLOG(2) << "total_instance: " << total_instance
<< ", ins_buf_pair_len = " << ins_buf_pair_len_;
-int64_t *ins_buf = reinterpret_cast<int64_t *>(d_ins_buf_->ptr());
-int64_t *ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance;
-cudaMemcpyAsync(id_tensor_ptr_, ins_cursor, sizeof(int64_t) * total_instance,
+uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
+uint64_t *ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance;
+cudaMemcpyAsync(id_tensor_ptr_, ins_cursor, sizeof(uint64_t) * total_instance,
cudaMemcpyDeviceToDevice, stream_);

GraphFillCVMKernel<<<GET_BLOCKS(total_instance), CUDA_NUM_THREADS, 0,
Expand All @@ -475,7 +475,7 @@ int GraphDataGenerator::GenerateBatch() {
stream_>>>(clk_tensor_ptr_, total_instance);

if (slot_num_ > 0) {
-int64_t *feature_buf = reinterpret_cast<int64_t *>(d_feature_buf_->ptr());
+uint64_t *feature_buf = reinterpret_cast<uint64_t *>(d_feature_buf_->ptr());
if (FLAGS_enable_opt_get_features) {
FillFeatureBuf(ins_cursor, feature_buf, total_instance);
if (debug_mode_) {
@@ -500,11 +500,11 @@ int GraphDataGenerator::GenerateBatch() {

GraphFillSlotKernel<<<GET_BLOCKS(total_instance * slot_num_),
CUDA_NUM_THREADS, 0, stream_>>>(
-(int64_t *)d_slot_tensor_ptr_->ptr(), feature_buf,
+(uint64_t *)d_slot_tensor_ptr_->ptr(), feature_buf,
total_instance * slot_num_, total_instance, slot_num_);
GraphFillSlotLodKernelOpt<<<GET_BLOCKS((total_instance + 1) * slot_num_),
CUDA_NUM_THREADS, 0, stream_>>>(
-(int64_t *)d_slot_lod_tensor_ptr_->ptr(),
+(uint64_t *)d_slot_lod_tensor_ptr_->ptr(),
(total_instance + 1) * slot_num_, total_instance + 1);
} else {
for (int i = 0; i < slot_num_; ++i) {
@@ -517,7 +517,7 @@ int GraphDataGenerator::GenerateBatch() {
<< feature_buf_offset + j * slot_num_ + 1 << "]";
cudaMemcpyAsync(slot_tensor_ptr_[i] + j,
&feature_buf[feature_buf_offset + j * slot_num_],
-sizeof(int64_t) * 2, cudaMemcpyDeviceToDevice,
+sizeof(uint64_t) * 2, cudaMemcpyDeviceToDevice,
stream_);
}
GraphFillSlotLodKernel<<<GET_BLOCKS(total_instance), CUDA_NUM_THREADS,
@@ -543,19 +543,19 @@ int GraphDataGenerator::GenerateBatch() {
cudaStreamSynchronize(stream_);

if (debug_mode_) {
-int64_t h_slot_tensor[slot_num_][total_instance];
-int64_t h_slot_lod_tensor[slot_num_][total_instance + 1];
+uint64_t h_slot_tensor[slot_num_][total_instance];
+uint64_t h_slot_lod_tensor[slot_num_][total_instance + 1];
for (int i = 0; i < slot_num_; ++i) {
cudaMemcpy(h_slot_tensor[i], slot_tensor_ptr_[i],
-total_instance * sizeof(int64_t), cudaMemcpyDeviceToHost);
+total_instance * sizeof(uint64_t), cudaMemcpyDeviceToHost);
int len = total_instance > 5000 ? 5000 : total_instance;
for (int j = 0; j < len; ++j) {
VLOG(2) << "gpu[" << gpuid_ << "] slot_tensor[" << i << "][" << j
<< "] = " << h_slot_tensor[i][j];
}

cudaMemcpy(h_slot_lod_tensor[i], slot_lod_tensor_ptr_[i],
-(total_instance + 1) * sizeof(int64_t),
+(total_instance + 1) * sizeof(uint64_t),
cudaMemcpyDeviceToHost);
len = total_instance + 1 > 5000 ? 5000 : total_instance + 1;
for (int j = 0; j < len; ++j) {
@@ -666,7 +666,7 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, uint64_t *walk,
int *h_prefix_sum = new int[len + 1];
int *h_actual_size = new int[len];
int *h_offset2idx = new int[once_max_sample_keynum];
-int64_t *h_sample_keys = new int64_t[once_max_sample_keynum];
+uint64_t h_sample_keys[once_max_sample_keynum];
cudaMemcpy(h_offset2idx, d_tmp_sampleidx2row,
once_max_sample_keynum * sizeof(int), cudaMemcpyDeviceToHost);

@@ -687,7 +687,7 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, uint64_t *walk,
cur_sampleidx2row_ = 1 - cur_sampleidx2row_;
}

-int GraphDataGenerator::FillFeatureBuf(int64_t *d_walk, int64_t *d_feature,
+int GraphDataGenerator::FillFeatureBuf(uint64_t *d_walk, uint64_t *d_feature,
size_t key_num) {
platform::CUDADeviceGuard guard(gpuid_);

@@ -704,7 +704,7 @@ int GraphDataGenerator::FillFeatureBuf(

auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
int ret = gpu_graph_ptr->get_feature_of_nodes(
-gpuid_, (int64_t *)d_walk->ptr(), (int64_t *)d_feature->ptr(), buf_size_,
+gpuid_, (uint64_t *)d_walk->ptr(), (uint64_t *)d_feature->ptr(), buf_size_,
slot_num_);
return ret;
}
@@ -928,17 +928,17 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place,

ins_buf_pair_len_ = 0;
d_ins_buf_ =
-memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(int64_t));
+memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(uint64_t));
if (slot_num_ > 0) {
d_feature_buf_ = memory::AllocShared(
-place_, (batch_size_ * 2 * 2) * slot_num_ * sizeof(int64_t));
+place_, (batch_size_ * 2 * 2) * slot_num_ * sizeof(uint64_t));
}
d_pair_num_ = memory::AllocShared(place_, sizeof(int));
if (FLAGS_enable_opt_get_features && slot_num_ > 0) {
d_slot_tensor_ptr_ =
-memory::AllocShared(place_, slot_num_ * sizeof(int64_t *));
+memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *));
d_slot_lod_tensor_ptr_ =
-memory::AllocShared(place_, slot_num_ * sizeof(int64_t *));
+memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *));
}

cudaStreamSynchronize(stream_);
5 changes: 1 addition & 4 deletions paddle/fluid/framework/data_feed.h
@@ -895,7 +895,7 @@ class GraphDataGenerator {
int AcquireInstance(BufState* state);
int GenerateBatch();
int FillWalkBuf(std::shared_ptr<phi::Allocation> d_walk);
-int FillFeatureBuf(int64_t* d_walk, int64_t* d_feature, size_t key_num);
+int FillFeatureBuf(uint64_t* d_walk, uint64_t* d_feature, size_t key_num);
int FillFeatureBuf(std::shared_ptr<phi::Allocation> d_walk,
std::shared_ptr<phi::Allocation> d_feature);
void FillOneStep(uint64_t* start_ids, uint64_t* walk, int len,
@@ -921,9 +921,6 @@ class GraphDataGenerator {
// point to device_keys_
size_t cursor_;
size_t jump_rows_;
-int64_t* id_tensor_ptr_;
-int64_t* show_tensor_ptr_;
-int64_t* clk_tensor_ptr_;
cudaStream_t stream_;
paddle::platform::Place place_;
std::vector<LoDTensor*> feed_vec_;
4 changes: 2 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -128,8 +128,8 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
uint64_t *key, int sample_size,
int len, bool cpu_query_switch);

-int get_feature_of_nodes(int gpu_id, int64_t* d_walk,
-int64_t* d_offset, int size, int slot_num);
+int get_feature_of_nodes(int gpu_id, uint64_t* d_walk,
+uint64_t* d_offset, int size, int slot_num);

NodeQueryResult query_node_list(int gpu_id, int idx, int start,
int query_size);
20 changes: 9 additions & 11 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -68,7 +68,7 @@ __global__ void copy_buffer_ac_to_final_place(
}
}

-__global__ void get_features_kernel(GpuPsCommGraphFea graph, int64_t* node_offset_array,
+__global__ void get_features_kernel(GpuPsCommGraphFea graph, uint64_t* node_offset_array,
int* actual_size, uint64_t* feature, int slot_num, int n) {
int idx = blockIdx.x * blockDim.y + threadIdx.y;
if (idx < n) {
@@ -899,8 +899,8 @@ NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int idx, int start,
return result;
}
-int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, int64_t* d_nodes,
-int64_t* d_feature , int node_num, int slot_num) {
+int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, uint64_t* d_nodes,
+uint64_t* d_feature , int node_num, int slot_num) {
if (node_num == 0) {
return -1;
}
@@ -928,10 +928,9 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, int64_t* d_nodes,
auto d_shard_actual_size = memory::Alloc(place, node_num * sizeof(int));
int* d_shard_actual_size_ptr = reinterpret_cast<int*>(d_shard_actual_size->ptr());
-uint64_t* key = (uint64_t*)d_nodes;
-split_input_to_shard((uint64_t*)(key), d_idx_ptr, node_num, d_left_ptr, d_right_ptr, gpu_id);
+split_input_to_shard(d_nodes, d_idx_ptr, node_num, d_left_ptr, d_right_ptr, gpu_id);
-heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, node_num, stream);
+heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, d_nodes, d_idx_ptr, node_num, stream);
cudaStreamSynchronize(stream);
int h_left[total_gpu]; // NOLINT
@@ -944,7 +944,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, int64_t* d_nodes,
continue;
}
create_storage(gpu_id, i, shard_len * sizeof(uint64_t),
-shard_len * slot_num * sizeof(uint64_t) + shard_len * sizeof(int64_t)
+shard_len * slot_num * sizeof(uint64_t) + shard_len * sizeof(uint64_t)
+ sizeof(int) * (shard_len + shard_len % 2));
}
@@ -956,7 +955,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, int64_t* d_nodes,
}
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
auto& node = path_[gpu_id][i].nodes_.back();
-cudaMemsetAsync(node.val_storage, -1, shard_len * sizeof(int64_t), node.in_stream);
+cudaMemsetAsync(node.val_storage, -1, shard_len * sizeof(uint64_t), node.in_stream);
cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
// If not found, val is -1.
@@ -968,7 +967,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, int64_t* d_nodes,
int offset = i * feature_table_num_;
auto graph = gpu_graph_fea_list_[offset];
-int64_t* val_array = reinterpret_cast<int64_t*>(node.val_storage);
+uint64_t* val_array = reinterpret_cast<uint64_t*>(node.val_storage);
int* actual_size_array = (int*)(val_array + shard_len);
uint64_t* feature_array = (uint64_t*)(actual_size_array + shard_len + shard_len % 2);
dim3 grid((shard_len - 1) / dim_y + 1);
@@ -988,8 +987,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, int64_t* d_nodes,
d_shard_vals_ptr, d_shard_actual_size_ptr);
int grid_size = (node_num - 1) / block_size_ + 1;
-uint64_t* result = (uint64_t*)d_feature;
-fill_dvalues<<<grid_size, block_size_, 0, stream>>>(d_shard_vals_ptr, result,
+fill_dvalues<<<grid_size, block_size_, 0, stream>>>(d_shard_vals_ptr, d_feature,
d_shard_actual_size_ptr, d_idx_ptr, slot_num, node_num);
for (int i = 0; i < total_gpu; ++i) {
4 changes: 2 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
@@ -247,8 +247,8 @@ NeighborSampleResult GraphGpuWrapper::graph_neighbor_sample_v3(
->graph_neighbor_sample_v3(q, cpu_switch);
}

-int GraphGpuWrapper::get_feature_of_nodes(int gpu_id, int64_t* d_walk,
-int64_t* d_offset, uint32_t size, int slot_num) {
+int GraphGpuWrapper::get_feature_of_nodes(int gpu_id, uint64_t* d_walk,
+uint64_t* d_offset, uint32_t size, int slot_num) {
platform::CUDADeviceGuard guard(gpu_id);
PADDLE_ENFORCE_NOT_NULL(graph_table);
return ((GpuPsGraphTable *)graph_table)
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
@@ -70,7 +70,7 @@ class GraphGpuWrapper {
std::vector<uint64_t>& key,
int sample_size);
void set_feature_separator(std::string ch);
-int get_feature_of_nodes(int gpu_id, int64_t* d_walk, int64_t* d_offset,
+int get_feature_of_nodes(int gpu_id, uint64_t* d_walk, uint64_t* d_offset,
uint32_t size, int slot_num);

std::unordered_map<std::string, int> edge_to_id, feature_to_id;
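
As a usage reference, a hedged caller-side sketch of the updated wrapper API (an assumption for illustration, mirroring the FillFeatureBuf call in data_feed.cu above; the helper name is made up, and the snippet is assumed to sit in the same namespace and translation unit context as the existing callers):

// d_nodes    : device buffer of node_num node keys (uint64_t)
// d_features : device buffer with room for node_num * slot_num uint64_t values,
//              sized the same way AllocResource sizes d_feature_buf_ above
int FetchNodeFeatures(int gpu_id, uint64_t* d_nodes, uint64_t* d_features,
                      uint32_t node_num, int slot_num) {
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
  // Returns -1 when node_num is 0, as in GpuPsGraphTable::get_feature_of_nodes.
  return gpu_graph_ptr->get_feature_of_nodes(gpu_id, d_nodes, d_features,
                                             node_num, slot_num);
}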

