From 7006b829d808b2af04061e1c6a212031f538813c Mon Sep 17 00:00:00 2001
From: huwei02 <53012141+huwei02@users.noreply.github.com>
Date: Sat, 10 Dec 2022 10:51:59 +0800
Subject: [PATCH] add excluded_train_pair and infer_node_type (#187)
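
Add two options to GraphConfig:

* excluded_train_pair: semicolon-separated node-type pairs, each pair
  written with the same "2" separator used by metapaths (e.g. "a2b").
  The node type of every random-walk slot is recorded in a parallel
  walk_ntype buffer, and GraphFillIdKernel drops any (src, dst) window
  whose type pair matches an excluded pair, so such pairs never become
  training instances.

* infer_node_type: semicolon-separated node types to restrict
  inference to. FillInferBuf skips the device-key cursors of all other
  node types.

A minimal usage sketch (the node-type names "a"/"b" are illustrative;
dataset is an InMemoryDataset configured for GPU graph mode):

    dataset.set_graph_config({
        # ... existing graph config keys ...
        "excluded_train_pair": "a2b;b2a",  # never emit a->b or b->a pairs
        "infer_node_type": "a;b",          # only infer on a and b nodes
    })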
Co-authored-by: root
---
 paddle/fluid/framework/data_feed.cu           | 207 +++++++++++++-----
 paddle/fluid/framework/data_feed.h            |   7 +
 paddle/fluid/framework/data_feed.proto        |   2 +
 .../fleet/heter_ps/graph_gpu_ps_table_inl.cu  |   1 +
 .../fleet/heter_ps/graph_gpu_wrapper.cu       |  57 +++--
 .../fleet/heter_ps/graph_gpu_wrapper.h        |  10 +-
 .../fluid/framework/fleet/ps_gpu_wrapper.cc   |   2 +-
 python/paddle/fluid/dataset.py                |   4 +
 8 files changed, 210 insertions(+), 80 deletions(-)

diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index 1c819d29ca4d5..b393d1a344c70 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -449,66 +449,21 @@ int GraphDataGenerator::AcquireInstance(BufState *state) {
   return 0;
 }
 
-// TODO opt
-__global__ void GraphFillFeatureKernel(uint64_t *id_tensor,
-                                       int *fill_ins_num,
-                                       uint64_t *walk,
-                                       uint64_t *feature,
-                                       int *row,
-                                       int central_word,
-                                       int step,
-                                       int len,
-                                       int col_num,
-                                       int slot_num) {
-  __shared__ int32_t local_key[CUDA_NUM_THREADS * 16];
-  __shared__ int local_num;
-  __shared__ int global_num;
-
-  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (threadIdx.x == 0) {
-    local_num = 0;
-  }
-  __syncthreads();
-  if (idx < len) {
-    int src = row[idx] * col_num + central_word;
-    if (walk[src] != 0 && walk[src + step] != 0) {
-      size_t dst = atomicAdd(&local_num, 1);
-      for (int i = 0; i < slot_num; ++i) {
-        local_key[dst * 2 * slot_num + i * 2] = feature[src * slot_num + i];
-        local_key[dst * 2 * slot_num + i * 2 + 1] =
-            feature[(src + step) * slot_num + i];
-      }
-    }
-  }
-
-  __syncthreads();
-
-  if (threadIdx.x == 0) {
-    global_num = atomicAdd(fill_ins_num, local_num);
-  }
-  __syncthreads();
-
-  if (threadIdx.x < local_num) {
-    for (int i = 0; i < slot_num; ++i) {
-      id_tensor[(global_num * 2 + 2 * threadIdx.x) * slot_num + i] =
-          local_key[(2 * threadIdx.x) * slot_num + i];
-      id_tensor[(global_num * 2 + 2 * threadIdx.x + 1) * slot_num + i] =
-          local_key[(2 * threadIdx.x + 1) * slot_num + i];
-    }
-  }
-}
-
 __global__ void GraphFillIdKernel(uint64_t *id_tensor,
                                   int *fill_ins_num,
                                   uint64_t *walk,
+                                  uint8_t *walk_ntype,
                                   int *row,
                                   int central_word,
                                   int step,
                                   int len,
-                                  int col_num) {
+                                  int col_num,
+                                  uint8_t *excluded_train_pair,
+                                  int excluded_train_pair_len) {
   __shared__ uint64_t local_key[CUDA_NUM_THREADS * 2];
   __shared__ int local_num;
   __shared__ int global_num;
+  bool need_filter = false;
 
   size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (threadIdx.x == 0) {
@@ -521,9 +476,19 @@ __global__ void GraphFillIdKernel(uint64_t *id_tensor,
   if (idx < len) {
     int src = row[idx] * col_num + central_word;
     if (walk[src] != 0 && walk[src + step] != 0) {
-      size_t dst = atomicAdd(&local_num, 1);
-      local_key[dst * 2] = walk[src];
-      local_key[dst * 2 + 1] = walk[src + step];
+      for (int i = 0; i < excluded_train_pair_len; i += 2) {
+        if (walk_ntype[src] == excluded_train_pair[i] &&
+            walk_ntype[src + step] == excluded_train_pair[i + 1]) {
+          // filter this pair
+          need_filter = true;
+          break;
+        }
+      }
+      if (!need_filter) {
+        size_t dst = atomicAdd(&local_num, 1);
+        local_key[dst * 2] = walk[src];
+        local_key[dst * 2 + 1] = walk[src + step];
+      }
     }
   }
 
@@ -753,6 +718,12 @@ int GraphDataGenerator::FillGraphSlotFeature(
 
 int GraphDataGenerator::MakeInsPair(cudaStream_t stream) {
   uint64_t *walk = reinterpret_cast<uint64_t *>(d_walk_->ptr());
+  uint8_t *walk_ntype = NULL;
+  uint8_t *excluded_train_pair = NULL;
+  if (excluded_train_pair_len_ > 0) {
+    walk_ntype = reinterpret_cast<uint8_t *>(d_walk_ntype_->ptr());
+    excluded_train_pair = reinterpret_cast<uint8_t *>(d_excluded_train_pair_->ptr());
+  }
   uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
   int *random_row = reinterpret_cast<int *>(d_random_row_->ptr());
   int *d_pair_num = reinterpret_cast<int *>(d_pair_num_->ptr());
@@ -763,11 +734,14 @@ int GraphDataGenerator::MakeInsPair(cudaStream_t stream) {
       ins_buf + ins_buf_pair_len_ * 2,
       d_pair_num,
       walk,
+      walk_ntype,
       random_row + buf_state_.cursor,
       buf_state_.central_word,
       window_step_[buf_state_.step],
       len,
-      walk_len_);
+      walk_len_,
+      excluded_train_pair,
+      excluded_train_pair_len_);
   int h_pair_num;
   cudaMemcpyAsync(
       &h_pair_num, d_pair_num, sizeof(int), cudaMemcpyDeviceToHost, stream);
@@ -782,8 +756,8 @@ int GraphDataGenerator::MakeInsPair(cudaStream_t stream) {
                     cudaMemcpyDeviceToHost);
     VLOG(2) << "h_pair_num = " << h_pair_num
             << ", ins_buf_pair_len = " << ins_buf_pair_len_;
-    for (int xx = 0; xx < 2 * ins_buf_pair_len_; xx++) {
-      VLOG(2) << "h_ins_buf[" << xx << "]: " << h_ins_buf[xx];
+    for (int xx = 0; xx < ins_buf_pair_len_; xx++) {
+      VLOG(2) << "h_ins_buf: " << h_ins_buf[xx * 2] << ", " << h_ins_buf[xx * 2 + 1];
     }
   }
   return ins_buf_pair_len_;
@@ -809,6 +783,7 @@ int GraphDataGenerator::GenerateBatch() {
   platform::CUDADeviceGuard guard(gpuid_);
   int res = 0;
   if (!gpu_graph_training_) {
+    // infer
     if (!sage_mode_) {
       total_instance = (infer_node_start_ + batch_size_ <= infer_node_end_)
                            ? batch_size_
@@ -829,6 +804,7 @@ int GraphDataGenerator::GenerateBatch() {
                            sage_batch_count_);
     }
   } else {
+    // train
     if (!sage_mode_) {
       while (ins_buf_pair_len_ < batch_size_) {
         res = FillInsBuf(train_stream_);
@@ -908,6 +884,7 @@ __global__ void GraphFillSampleKeysKernel(uint64_t *neighbors,
 
 __global__ void GraphDoWalkKernel(uint64_t *neighbors,
                                   uint64_t *walk,
+                                  uint8_t *walk_ntype,
                                   int *d_prefix_sum,
                                   int *actual_sample_size,
                                   int cur_degree,
@@ -915,7 +892,8 @@ __global__ void GraphDoWalkKernel(uint64_t *neighbors,
                                   int len,
                                   int *id_cnt,
                                   int *sampleidx2row,
-                                  int col_size) {
+                                  int col_size,
+                                  uint8_t edge_dst_id) {
   CUDA_KERNEL_LOOP(i, len) {
     for (int k = 0; k < actual_sample_size[i]; k++) {
       // int idx = sampleidx2row[i];
@@ -924,6 +902,9 @@ __global__ void GraphDoWalkKernel(uint64_t *neighbors,
       size_t col = step;
       size_t offset = (row * col_size + col);
       walk[offset] = neighbors[i * cur_degree + k];
+      if (walk_ntype != NULL) {
+        walk_ntype[offset] = edge_dst_id;
+      }
     }
   }
 }
@@ -932,7 +913,10 @@ __global__ void GraphDoWalkKernel(uint64_t *neighbors,
 __global__ void GraphFillFirstStepKernel(int *prefix_sum,
                                          int *sampleidx2row,
                                          uint64_t *walk,
+                                         uint8_t *walk_ntype,
                                          uint64_t *keys,
+                                         uint8_t edge_src_id,
+                                         uint8_t edge_dst_id,
                                          int len,
                                          int walk_degree,
                                          int col_size,
@@ -948,6 +932,10 @@ __global__ void GraphFillFirstStepKernel(int *prefix_sum,
       size_t offset = col_size * row;
       walk[offset] = keys[idx];
       walk[offset + 1] = neighbors[idx * walk_degree + k];
+      if (walk_ntype != NULL) {
+        walk_ntype[offset] = edge_src_id;
+        walk_ntype[offset + 1] = edge_dst_id;
+      }
     }
   }
 }
@@ -1071,12 +1059,18 @@ __global__ void UniqueFeature(uint64_t *d_in,
 }
 // Fill sample_res to the stepth column of walk
 void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids,
+                                     int etype_id,
                                      uint64_t *walk,
+                                     uint8_t *walk_ntype,
                                      int len,
                                      NeighborSampleResult &sample_res,
                                      int cur_degree,
                                      int step,
                                      int *len_per_row) {
+  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
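+  // edge_to_node_map_ packs src_node_type << 32 | dst_node_type for each
+  // edge type; unpack the two ids here so walk slots can be tagged below.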
+  uint64_t node_id = gpu_graph_ptr->edge_to_node_map_[etype_id];
+  uint8_t edge_src_id = node_id >> 32;
+  uint8_t edge_dst_id = node_id;
   size_t temp_storage_bytes = 0;
   int *d_actual_sample_size = sample_res.actual_sample_size;
   uint64_t *d_neighbors = sample_res.val;
@@ -1114,7 +1108,10 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids,
                        sample_stream_>>>(d_prefix_sum,
                                          d_tmp_sampleidx2row,
                                          walk,
+                                         walk_ntype,
                                          d_start_ids,
+                                         edge_src_id,
+                                         edge_dst_id,
                                          len,
                                          walk_degree_,
                                          walk_len_,
@@ -1138,6 +1135,7 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids,
     GraphDoWalkKernel<<<GET_BLOCKS(len), CUDA_NUM_THREADS, 0, sample_stream_>>>(
         d_neighbors,
         walk,
+        walk_ntype,
        d_prefix_sum,
         d_actual_sample_size,
         cur_degree,
@@ -1145,7 +1143,8 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids,
         len,
         len_per_row,
         d_tmp_sampleidx2row,
-        walk_len_);
+        walk_len_,
+        edge_dst_id);
   }
   if (debug_mode_) {
     size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_;
@@ -1892,6 +1891,7 @@ void GraphDataGenerator::DoWalkandSage() {
   debug_gpu_memory_info(device_id, "DoWalkandSage start");
   platform::CUDADeviceGuard guard(gpuid_);
   if (gpu_graph_training_) {
+    // train
     bool train_flag;
     if (FLAGS_graph_metapath_split_opt) {
       train_flag = FillWalkBufMultiPath();
@@ -1950,6 +1950,7 @@ void GraphDataGenerator::DoWalkandSage() {
       }
     }
   } else {
+    // infer
     bool infer_flag = FillInferBuf();
     if (sage_mode_) {
       sage_batch_num_ = 0;
@@ -2042,6 +2043,22 @@ int GraphDataGenerator::FillInferBuf() {
       return 0;
     }
   }
+  if (!infer_node_type_index_set_.empty()) {
+    while (infer_cursor < h_device_keys_len_.size()) {
+      if (infer_node_type_index_set_.find(infer_cursor) ==
+          infer_node_type_index_set_.end()) {
+        VLOG(2) << "Skip cursor[" << infer_cursor << "]";
+        infer_cursor++;
+        continue;
+      } else {
+        VLOG(2) << "Not skip cursor[" << infer_cursor << "]";
+        break;
+      }
+    }
+    if (infer_cursor >= h_device_keys_len_.size()) {
+      return 0;
+    }
+  }
+
   size_t device_key_size = h_device_keys_len_[infer_cursor];
   total_row_ =
       (global_infer_node_type_start[infer_cursor] + infer_table_cap_ <=
@@ -2104,6 +2121,11 @@ int GraphDataGenerator::FillWalkBuf() {
   int *len_per_row = reinterpret_cast<int *>(d_len_per_row_->ptr());
   uint64_t *d_sample_keys = reinterpret_cast<uint64_t *>(d_sample_keys_->ptr());
   cudaMemsetAsync(walk, 0, buf_size_ * sizeof(uint64_t), sample_stream_);
+  uint8_t *walk_ntype = NULL;
+  if (excluded_train_pair_len_ > 0) {
+    walk_ntype = reinterpret_cast<uint8_t *>(d_walk_ntype_->ptr());
+    cudaMemsetAsync(walk_ntype, 0, buf_size_ * sizeof(uint8_t), sample_stream_);
+  }
   // cudaMemsetAsync(
   //   len_per_row, 0, once_max_sample_keynum * sizeof(int), sample_stream_);
   int sample_times = 0;
@@ -2157,6 +2179,10 @@ int GraphDataGenerator::FillWalkBuf() {
       VLOG(2) << "gpuid = " << gpuid_ << " path[0] = " << path[0];
 
       uint64_t *cur_walk = walk + i;
+      uint8_t *cur_walk_ntype = NULL;
+      if (excluded_train_pair_len_ > 0) {
+        cur_walk_ntype = walk_ntype + i;
+      }
 
       NeighborSampleQuery q;
       q.initialize(gpuid_,
@@ -2197,7 +2223,9 @@ int GraphDataGenerator::FillWalkBuf() {
         }
       }
       FillOneStep(d_type_keys + start,
+                  path[0],
                   cur_walk,
+                  cur_walk_ntype,
                   tmp_len,
                   sample_res,
                   walk_degree_,
@@ -2248,7 +2276,9 @@ int GraphDataGenerator::FillWalkBuf() {
           }
         }
         FillOneStep(d_type_keys + start,
+                    edge_type_id,
                     cur_walk,
+                    cur_walk_ntype,
                     sample_key_len,
                     sample_res,
                     1,
@@ -2341,6 +2371,10 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
   ///////
   auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
   uint64_t *walk = reinterpret_cast<uint64_t *>(d_walk_->ptr());
+  uint8_t *walk_ntype = NULL;
+  if (excluded_train_pair_len_ > 0) {
+    walk_ntype = reinterpret_cast<uint8_t *>(d_walk_ntype_->ptr());
+  }
   int *len_per_row = reinterpret_cast<int *>(d_len_per_row_->ptr());
   uint64_t *d_sample_keys = reinterpret_cast<uint64_t *>(d_sample_keys_->ptr());
   cudaMemsetAsync(walk, 0, buf_size_ * sizeof(uint64_t), sample_stream_);
@@ -2359,7 +2393,7 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
   size_t node_type_len = first_node_type.size();
   std::string first_node =
       paddle::string::split_string<std::string>(cur_metapath, "2")[0];
-  auto it = gpu_graph_ptr->feature_to_id.find(first_node);
+  auto it = gpu_graph_ptr->node_to_id.find(first_node);
   auto node_type = it->second;
 
   int remain_size =
@@ -2383,6 +2417,10 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
     VLOG(2) << "gpuid = " << gpuid_ << " path[0] = " << path[0];
 
     uint64_t *cur_walk = walk + i;
+    uint8_t *cur_walk_ntype = NULL;
+    if (excluded_train_pair_len_ > 0) {
+      cur_walk_ntype = walk_ntype + i;
+    }
 
     NeighborSampleQuery q;
     q.initialize(gpuid_,
@@ -2421,7 +2459,9 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
     }
 
     FillOneStep(d_type_keys + start,
+                path[0],
                 cur_walk,
+                cur_walk_ntype,
                 tmp_len,
                 sample_res,
                 walk_degree_,
@@ -2472,7 +2512,9 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
         }
       }
       FillOneStep(d_type_keys + start,
+                  edge_type_id,
                   cur_walk,
+                  cur_walk_ntype,
                   sample_key_len,
                   sample_res,
                   1,
@@ -2638,6 +2680,27 @@ void GraphDataGenerator::AllocResource(int thread_id,
       phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
   cudaMemsetAsync(
       d_walk_->ptr(), 0, buf_size_ * sizeof(uint64_t), sample_stream_);
+
+  excluded_train_pair_len_ = gpu_graph_ptr->excluded_train_pair_.size();
+  if (excluded_train_pair_len_ > 0) {
+    d_excluded_train_pair_ = memory::AllocShared(
+        place_,
+        excluded_train_pair_len_ * sizeof(uint8_t),
+        phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
+    CUDA_CHECK(cudaMemcpyAsync(d_excluded_train_pair_->ptr(),
+                               gpu_graph_ptr->excluded_train_pair_.data(),
+                               excluded_train_pair_len_ * sizeof(uint8_t),
+                               cudaMemcpyHostToDevice,
+                               sample_stream_));
+
+    d_walk_ntype_ = memory::AllocShared(
+        place_,
+        buf_size_ * sizeof(uint8_t),
+        phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
+    cudaMemsetAsync(
+        d_walk_ntype_->ptr(), 0, buf_size_ * sizeof(uint8_t), sample_stream_);
+  }
+
   d_sample_keys_ = memory::AllocShared(
       place_,
       once_max_sample_keynum * sizeof(uint64_t),
@@ -2733,6 +2796,26 @@ void GraphDataGenerator::AllocResource(int thread_id,
         (batch_size_ * 2 * 2) * sizeof(uint32_t),
         phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
   }
+
+  // parse infer_node_type
+  auto &type_to_index = gpu_graph_ptr->get_graph_type_to_index();
+  if (!gpu_graph_training_) {
+    auto node_types =
+        paddle::string::split_string<std::string>(infer_node_type_, ";");
+    auto node_to_id = gpu_graph_ptr->node_to_id;
+    for (auto &type : node_types) {
+      auto iter = node_to_id.find(type);
+      PADDLE_ENFORCE_NE(
+          iter,
+          node_to_id.end(),
+          platform::errors::NotFound("(%s) is not found in node_to_id.", type));
+      int node_type = iter->second;
+      int type_index = type_to_index[node_type];
+      VLOG(2) << "add node[" << type
+              << "] into infer_node_type, type_index(cursor)[" << type_index
+              << "]";
+      infer_node_type_index_set_.insert(type_index);
+    }
+    VLOG(2) << "infer_node_type_index_set_num: "
+            << infer_node_type_index_set_.size();
+  }
 
   cudaStreamSynchronize(sample_stream_);
@@ -2787,7 +2870,7 @@ void GraphDataGenerator::SetConfig(
   std::string str_samples = graph_config.samples();
   auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
   debug_gpu_memory_info("init_conf start");
-  gpu_graph_ptr->init_conf(first_node_type, meta_path);
+  gpu_graph_ptr->init_conf(first_node_type,
+                           meta_path,
+                           graph_config.excluded_train_pair());
   debug_gpu_memory_info("init_conf end");
 
   auto edge_to_id = gpu_graph_ptr->edge_to_id;
@@ -2799,6 +2882,10 @@ void GraphDataGenerator::SetConfig(
     samples_.emplace_back(sample_size);
   }
   copy_unique_len_ = 0;
+
+  if (!gpu_graph_training_) {
+    infer_node_type_ = graph_config.infer_node_type();
+  }
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h
index 29d13f89c46bd..20685e7ea90fc 100644
--- a/paddle/fluid/framework/data_feed.h
+++ b/paddle/fluid/framework/data_feed.h
@@ -914,7 +914,9 @@ class GraphDataGenerator {
   int FillFeatureBuf(std::shared_ptr<phi::Allocation> d_walk,
                      std::shared_ptr<phi::Allocation> d_feature);
   void FillOneStep(uint64_t* start_ids,
+                   int etype_id,
                    uint64_t* walk,
+                   uint8_t* walk_ntype,
                    int len,
                    NeighborSampleResult& sample_res,
                    int cur_degree,
@@ -999,6 +1001,8 @@ class GraphDataGenerator {
   std::shared_ptr<phi::Allocation> d_train_metapath_keys_;
 
   std::shared_ptr<phi::Allocation> d_walk_;
+  std::shared_ptr<phi::Allocation> d_walk_ntype_;
+  std::shared_ptr<phi::Allocation> d_excluded_train_pair_;
   std::shared_ptr<phi::Allocation> d_feature_list_;
   std::shared_ptr<phi::Allocation> d_feature_;
   std::shared_ptr<phi::Allocation> d_len_per_row_;
@@ -1038,6 +1042,7 @@ class GraphDataGenerator {
   std::vector<std::vector<std::shared_ptr<phi::Allocation>>> graph_edges_vec_;
   std::vector<std::vector<std::vector<int>>> edges_split_num_vec_;
 
+  int excluded_train_pair_len_;
   int64_t reindex_table_size_;
   int sage_batch_count_;
   int sage_batch_num_;
@@ -1066,6 +1071,8 @@ class GraphDataGenerator {
   int total_row_;
   size_t infer_node_start_;
   size_t infer_node_end_;
+  std::set<int> infer_node_type_index_set_;
+  std::string infer_node_type_;
 };
 
 class DataFeed {
diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto
index 25610eea23781..904bd1698ee80 100644
--- a/paddle/fluid/framework/data_feed.proto
+++ b/paddle/fluid/framework/data_feed.proto
@@ -42,6 +42,8 @@ message GraphConfig {
   optional string samples = 12;
   optional int64 train_table_cap = 13 [ default = 80000 ];
   optional int64 infer_table_cap = 14 [ default = 80000 ];
+  optional string excluded_train_pair = 15;
+  optional string infer_node_type = 16;
 }
 
 message DataFeedDesc {
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
index 81fe3a50c6556..0f9e92a726f14 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -1329,6 +1329,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
   return result;
 }
 
+// only for graphsage
 NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type(
     int gpu_id,
     int edge_type_len,
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
index 6c074edeec12c..41a0d6d06be43 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
@@ -32,7 +32,8 @@ void GraphGpuWrapper::set_device(std::vector<int> ids) {
 }
 
 void GraphGpuWrapper::init_conf(const std::string &first_node_type,
-                                const std::string &meta_path) {
+                                const std::string &meta_path,
+                                const std::string &excluded_train_pair) {
   static std::mutex mutex;
   {
     std::lock_guard<std::mutex> lock(mutex);
@@ -45,12 +46,12 @@ void GraphGpuWrapper::init_conf(const std::string &first_node_type,
         paddle::string::split_string<std::string>(first_node_type, ";");
     VLOG(2) << "node_types: " << first_node_type;
     for (auto &type : node_types) {
-      auto iter = feature_to_id.find(type);
+      auto iter = node_to_id.find(type);
       PADDLE_ENFORCE_NE(iter,
-                        feature_to_id.end(),
+                        node_to_id.end(),
                         platform::errors::NotFound(
-                            "(%s) is not found in feature_to_id.", type));
-      VLOG(2) << "feature_to_id[" << type << "] = " << iter->second;
+                            "(%s) is not found in node_to_id.", type));
+      VLOG(2) << "node_to_id[" << type << "] = " << iter->second;
       first_node_type_.push_back(iter->second);
     }
     meta_path_.resize(first_node_type_.size());
@@ -58,17 +59,39 @@ void GraphGpuWrapper::init_conf(const std::string &first_node_type,
 
     for (size_t i = 0; i < meta_paths.size(); i++) {
       auto path = meta_paths[i];
-      auto nodes = paddle::string::split_string<std::string>(path, "-");
+      auto edges = paddle::string::split_string<std::string>(path, "-");
+      for (auto &edge : edges) {
+        auto iter = edge_to_id.find(edge);
+        PADDLE_ENFORCE_NE(iter,
+                          edge_to_id.end(),
+                          platform::errors::NotFound(
+                              "(%s) is not found in edge_to_id.", edge));
+        VLOG(2) << "edge_to_id[" << edge << "] = " << iter->second;
+        meta_path_[i].push_back(iter->second);
+        if (edge_to_node_map_.find(iter->second) == edge_to_node_map_.end()) {
+          auto nodes = paddle::string::split_string<std::string>(edge, "2");
+          uint64_t src_node_id = node_to_id.find(nodes[0])->second;
+          uint64_t dst_node_id = node_to_id.find(nodes[1])->second;
+          edge_to_node_map_[iter->second] = src_node_id << 32 | dst_node_id;
+        }
+      }
+    }
+
+    auto paths =
+        paddle::string::split_string<std::string>(excluded_train_pair, ";");
+    VLOG(2) << "excluded_train_pair[" << excluded_train_pair << "]";
+    for (auto &path : paths) {
+      auto nodes = paddle::string::split_string<std::string>(path, "2");
       for (auto &node : nodes) {
-        auto iter = edge_to_id.find(node);
+        auto iter = node_to_id.find(node);
         PADDLE_ENFORCE_NE(iter,
-                          edge_to_id.end(),
+                          node_to_id.end(),
                           platform::errors::NotFound(
-                              "(%s) is not found in edge_to_id.", node));
-        VLOG(2) << "edge_to_id[" << node << "] = " << iter->second;
-        meta_path_[i].push_back(iter->second);
+                              "(%s) is not found in node_to_id.", node));
+        VLOG(2) << "node_to_id[" << node << "] = " << iter->second;
+        excluded_train_pair_.push_back(iter->second);
       }
     }
+
     int max_dev_id = 0;
     for (size_t i = 0; i < device_id_mapping.size(); i++) {
       if (device_id_mapping[i] > max_dev_id) {
@@ -85,11 +108,11 @@ void GraphGpuWrapper::init_conf(const std::string &first_node_type,
       auto &finish_node_type = finish_node_type_[i];
       finish_node_type.clear();
 
-      for (size_t idx = 0; idx < feature_to_id.size(); idx++) {
+      for (size_t idx = 0; idx < node_to_id.size(); idx++) {
         infer_node_type_start[idx] = 0;
       }
       for (auto &type : node_types) {
-        auto iter = feature_to_id.find(type);
+        auto iter = node_to_id.find(type);
         node_type_start[iter->second] = 0;
         infer_node_type_start[iter->second] = 0;
       }
@@ -188,7 +211,7 @@ void GraphGpuWrapper::init_metapath(std::string cur_metapath,
   int first_node_idx;
   std::string first_node =
       paddle::string::split_string<std::string>(cur_metapath_, "2")[0];
-  auto it = feature_to_id.find(first_node);
+  auto it = node_to_id.find(first_node);
   first_node_idx = it->second;
   d_graph_train_total_keys_.resize(thread_num);
   h_graph_train_keys_len_.resize(thread_num);
@@ -315,8 +338,8 @@ void GraphGpuWrapper::set_up_types(std::vector<std::string> &edge_types,
   }
   id_to_feature = node_types;
   for (size_t table_id = 0; table_id < node_types.size(); table_id++) {
-    int res = feature_to_id.size();
-    feature_to_id[node_types[table_id]] = res;
+    int res = node_to_id.size();
+    node_to_id[node_types[table_id]] = res;
   }
   table_feat_mapping.resize(node_types.size());
   this->table_feat_conf_feat_name.resize(node_types.size());
@@ -399,7 +422,7 @@ void GraphGpuWrapper::load_node_file(std::string name, std::string filepath) {
   std::string params = "n" + name;
 
-  if (feature_to_id.find(name) != feature_to_id.end()) {
+  if (node_to_id.find(name) != node_to_id.end()) {
     ((GpuPsGraphTable *)graph_table)
         ->cpu_graph_table_->Load(std::string(filepath), params);
   }
@@ -427,8 +450,8 @@ void GraphGpuWrapper::add_table_feat_conf(std::string table_name,
                                           std::string feat_name,
                                           std::string feat_dtype,
                                           int feat_shape) {
-  if (feature_to_id.find(table_name) != feature_to_id.end()) {
-    int idx = feature_to_id[table_name];
+  if (node_to_id.find(table_name) != node_to_id.end()) {
+    int idx = node_to_id[table_name];
     if (table_feat_mapping[idx].find(feat_name) ==
         table_feat_mapping[idx].end()) {
       int res = (int)table_feat_mapping[idx].size();
@@ -776,7 +799,7 @@ std::string &GraphGpuWrapper::get_node_type_size(std::string first_node_type) {
   auto &type_to_index = get_graph_type_to_index();
   std::vector<int> node_type_size;
   for (auto node : uniq_first_node_) {
-    auto it = feature_to_id.find(node);
+    auto it = node_to_id.find(node);
     auto first_node_idx = it->second;
     size_t f_idx = type_to_index[first_node_idx];
     int type_total_key_size = graph_all_type_total_keys[f_idx].size();
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
index 900d3486e68cd..84ddef9fa2e62 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
@@ -41,7 +41,8 @@ class GraphGpuWrapper {
   }
   static std::shared_ptr<GraphGpuWrapper> s_instance_;
   void init_conf(const std::string& first_node_type,
-                 const std::string& meta_path);
+                 const std::string& meta_path,
+                 const std::string& excluded_train_pair);
   void initialize();
   void finalize();
   void set_device(std::vector<int> ids);
@@ -160,7 +161,7 @@ class GraphGpuWrapper {
   std::string& get_node_type_size(std::string first_node_type);
   std::string& get_edge_type_size();
 
-  std::unordered_map<std::string, int> edge_to_id, feature_to_id;
+  std::unordered_map<std::string, int> edge_to_id, node_to_id;
   std::vector<std::string> id_to_feature, id_to_edge;
   std::vector<std::unordered_map<std::string, int>> table_feat_mapping;
   std::vector<std::vector<std::string>> table_feat_conf_feat_name;
@@ -175,6 +176,7 @@ class GraphGpuWrapper {
   std::string feature_separator_ = std::string(" ");
   bool conf_initialized_ = false;
   std::vector<int> first_node_type_;
+  std::vector<uint8_t> excluded_train_pair_;
   std::vector<std::vector<int>> meta_path_;
 
   std::vector<std::set<int>> finish_node_type_;
@@ -187,6 +189,10 @@ class GraphGpuWrapper {
   std::vector<uint64_t> h_graph_train_keys_len_;
   std::vector<std::vector<std::shared_ptr<phi::Allocation>>>
       d_graph_all_type_total_keys_;
+  std::map<int, uint64_t> edge_to_node_map_;
+  std::vector<std::vector<uint64_t>> h_graph_all_type_keys_len_;
 
   std::string slot_feature_separator_ = std::string(" ");
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
index 5c8b9d1998c61..e42c8255287c9 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -1446,7 +1446,7 @@ void PSGPUWrapper::SparseTableToHbm() {
   gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_);
   gpu_task->pass_id_ = (uint16_t)(dataset_->GetPassID());
   auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
-  auto node_to_id = gpu_graph_ptr->feature_to_id;
+  auto node_to_id = gpu_graph_ptr->node_to_id;
   auto edge_to_id = gpu_graph_ptr->edge_to_id;
   std::vector<uint64_t> vec_data = gpu_graph_ptr->get_graph_total_keys();
diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index b5f6fe4fbcb0c..6d9d042cf343e 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -1090,6 +1090,10 @@ def set_graph_config(self, config):
             "train_table_cap", 800000)
         self.proto_desc.graph_config.infer_table_cap = config.get(
             "infer_table_cap", 800000)
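+        # forward the two new GraphConfig fields (15 and 16 in data_feed.proto)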
+        self.proto_desc.graph_config.excluded_train_pair = config.get(
+            "excluded_train_pair", "")
+        self.proto_desc.graph_config.infer_node_type = config.get(
+            "infer_node_type", "")
         self.dataset.set_gpu_graph_mode(True)
 
     def set_pass_id(self, pass_id):
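
Note: GraphGpuWrapper::init_conf parses excluded_train_pair by splitting
on ";" and then on "2", pushing the resulting node-type ids into
excluded_train_pair_, which the GPU kernel consumes pairwise. A small
Python sketch of that parsing, using a hypothetical node_to_id mapping:

    def parse_excluded_train_pair(excluded_train_pair, node_to_id):
        # mirrors the C++ loops in init_conf: ";" separates pairs,
        # "2" separates the two node types inside one pair
        out = []
        for pair in excluded_train_pair.split(";"):
            for node in pair.split("2"):
                out.append(node_to_id[node])
        return out

    # parse_excluded_train_pair("a2b;b2a", {"a": 0, "b": 1}) -> [0, 1, 1, 0]
    # GraphFillIdKernel skips a window when (walk_ntype[src],
    # walk_ntype[src + step]) equals one of these consecutive id pairs.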