diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu index e3b2fb1def809a..7c1c665aa2a8e5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu @@ -21,6 +21,11 @@ #ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" +#define ALIGN_INT64(LEN) (uint64_t((LEN) + 7) & uint64_t(~7)) +#define HBMPS_MAX_BUFF 1024 * 1024 + +DECLARE_bool(enable_neighbor_list_use_uva); + namespace paddle { namespace framework { /* @@ -882,8 +887,14 @@ void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g, gpu_graph_list_[offset].node_size = 0; } if (g.neighbor_size) { - cudaError_t cudaStatus = cudaMalloc(&gpu_graph_list_[offset].neighbor_list, + cudaError_t cudaStatus; + if (!FLAGS_enable_neighbor_list_use_uva) { + cudaStatus = cudaMalloc(&gpu_graph_list_[offset].neighbor_list, + g.neighbor_size * sizeof(uint64_t)); + } else { + cudaStatus = cudaMallocManaged(&gpu_graph_list_[offset].neighbor_list, g.neighbor_size * sizeof(uint64_t)); + } PADDLE_ENFORCE_EQ(cudaStatus, cudaSuccess, platform::errors::InvalidArgument( @@ -945,9 +956,13 @@ void GpuPsGraphTable::build_graph_from_cpu( gpu_graph_list_[offset].node_size = 0; } if (cpu_graph_list[i].neighbor_size) { - CUDA_CHECK( - cudaMalloc(&gpu_graph_list_[offset].neighbor_list, - cpu_graph_list[i].neighbor_size * sizeof(uint64_t))); + if (!FLAGS_enable_neighbor_list_use_uva) { + CUDA_CHECK(cudaMalloc(&gpu_graph_list_[offset].neighbor_list, + cpu_graph_list[i].neighbor_size * sizeof(uint64_t))); + } else { + CUDA_CHECK(cudaMallocManaged(&gpu_graph_list_[offset].neighbor_list, + cpu_graph_list[i].neighbor_size * sizeof(uint64_t))); + } CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].neighbor_list, cpu_graph_list[i].neighbor_list, diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index b2a6b58e1bd225..0cdcec8ba69dd3 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -852,6 +852,18 @@ PHI_DEFINE_EXPORTED_bool(graph_load_in_parallel, false, "It controls whether load graph node and edge with " "mutli threads parallely."); + +/** + * Distributed related FLAG + * Name: FLAGS_enable_neighbor_list_use_uva + * Since Version: 2.2.0 + * Value Range: bool, default=false + * Example: + * Note: Control whether store neighbor_list with UVA + */ +PHI_DEFINE_EXPORTED_bool(enable_neighbor_list_use_uva, + false, + "It controls whether store neighbor_list with UVA"); /** * Distributed related FLAG