diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu old mode 100644 new mode 100755 index 698e9455c04cf..7b7bc92862f4b --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -2055,6 +2055,7 @@ void GraphDataGenerator::DoWalkandSage() { } void GraphDataGenerator::clear_gpu_mem() { + platform::CUDADeviceGuard guard(gpuid_); d_len_per_row_.reset(); d_sample_keys_.reset(); d_prefix_sum_.reset(); @@ -2356,8 +2357,9 @@ int GraphDataGenerator::FillWalkBuf() { buf_state_.Reset(total_row_); int *d_random_row = reinterpret_cast(d_random_row_->ptr()); + paddle::memory::ThrustAllocator allocator(place_, sample_stream_); thrust::random::default_random_engine engine(shuffle_seed_); - const auto &exec_policy = thrust::cuda::par.on(sample_stream_); + const auto &exec_policy = thrust::cuda::par(allocator).on(sample_stream_); thrust::counting_iterator cnt_iter(0); thrust::shuffle_copy(exec_policy, cnt_iter, @@ -2591,8 +2593,9 @@ int GraphDataGenerator::FillWalkBufMultiPath() { buf_state_.Reset(total_row_); int *d_random_row = reinterpret_cast(d_random_row_->ptr()); + paddle::memory::ThrustAllocator allocator(place_, sample_stream_); thrust::random::default_random_engine engine(shuffle_seed_); - const auto &exec_policy = thrust::cuda::par.on(sample_stream_); + const auto &exec_policy = thrust::cuda::par(allocator).on(sample_stream_); thrust::counting_iterator cnt_iter(0); thrust::shuffle_copy(exec_policy, cnt_iter, @@ -2647,22 +2650,22 @@ void GraphDataGenerator::AllocResource(int thread_id, debug_gpu_memory_info(gpuid_, "AllocResource start"); platform::CUDADeviceGuard guard(gpuid_); + sample_stream_ = gpu_graph_ptr->get_local_stream(gpuid_); + train_stream_ = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place_)) + ->stream(); if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { if (gpu_graph_training_) { table_ = new HashTable( - train_table_cap_ / FLAGS_gpugraph_hbm_table_load_factor); + train_table_cap_ / FLAGS_gpugraph_hbm_table_load_factor, sample_stream_); } else { table_ = new HashTable( - infer_table_cap_ / FLAGS_gpugraph_hbm_table_load_factor); + infer_table_cap_ / FLAGS_gpugraph_hbm_table_load_factor, sample_stream_); } } VLOG(1) << "AllocResource gpuid " << gpuid_ << " feed_vec.size: " << feed_vec.size() << " table cap: " << train_table_cap_; - sample_stream_ = gpu_graph_ptr->get_local_stream(gpuid_); - train_stream_ = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place_)) - ->stream(); // feed_vec_ = feed_vec; if (!sage_mode_) { slot_num_ = (feed_vec.size() - 3) / 2; @@ -2782,8 +2785,10 @@ void GraphDataGenerator::AllocResource(int thread_id, ins_buf_pair_len_ = 0; if (!sage_mode_) { d_ins_buf_ = - memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(uint64_t)); - d_pair_num_ = memory::AllocShared(place_, sizeof(int)); + memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(uint64_t), + phi::Stream(reinterpret_cast(sample_stream_))); + d_pair_num_ = memory::AllocShared(place_, sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_))); } else { d_ins_buf_ = memory::AllocShared( place_, @@ -2796,9 +2801,11 @@ void GraphDataGenerator::AllocResource(int thread_id, } d_slot_tensor_ptr_ = - memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *)); + memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *), + phi::Stream(reinterpret_cast(sample_stream_))); d_slot_lod_tensor_ptr_ = - memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *)); + memory::AllocShared(place_, 
slot_num_ * sizeof(uint64_t *), + phi::Stream(reinterpret_cast(sample_stream_))); if (sage_mode_) { reindex_table_size_ = batch_size_ * 2; diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index afe686ea48dd6..0a5bc6c94c1d0 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -24,7 +24,7 @@ #include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/platform/monitor.h" #include "paddle/fluid/platform/timer.h" - +#include "paddle/fluid/framework/threadpool.h" #ifdef PADDLE_WITH_PSCORE #include "paddle/fluid/distributed/ps/wrapper/fleet.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h" @@ -447,7 +447,19 @@ void MultiSlotDataset::PrepareTrain() { #endif return; } - +inline std::vector>& +GetReadThreadPool(int thread_num) { + static std::vector> + thread_pools; + if (!thread_pools.empty()) { + return thread_pools; + } + thread_pools.resize(thread_num); + for (int i = 0; i < thread_num; ++i) { + thread_pools[i].reset(new paddle::framework::ThreadPool(1)); + } + return thread_pools; +} // load data into memory, Dataset hold this memory, // which will later be fed into readers' channel template @@ -455,10 +467,11 @@ void DatasetImpl::LoadIntoMemory() { VLOG(3) << "DatasetImpl::LoadIntoMemory() begin"; platform::Timer timeline; timeline.Start(); - std::vector load_threads; if (gpu_graph_mode_) { VLOG(1) << "in gpu_graph_mode"; #ifdef PADDLE_WITH_HETERPS + std::vector> wait_futures; + auto pool = GetReadThreadPool(thread_num_); for (size_t i = 0; i < readers_.size(); i++) { readers_[i]->SetGpuGraphMode(gpu_graph_mode_); } @@ -473,24 +486,41 @@ void DatasetImpl::LoadIntoMemory() { } for (int64_t i = 0; i < thread_num_; ++i) { - load_threads.push_back(std::thread( - &paddle::framework::DataFeed::DoWalkandSage, readers_[i].get())); + wait_futures.emplace_back( + pool[i]->Run([this, i]() { readers_[i]->DoWalkandSage(); })); } - for (std::thread& t : load_threads) { - t.join(); + for (auto& th : wait_futures) { + th.get(); } + wait_futures.clear(); + uint64_t node_num = 0; + std::vector offsets; + offsets.resize(thread_num_); + for (int i = 0; i < thread_num_; i++) { auto& host_vec = readers_[i]->GetHostVec(); + offsets[i] = node_num; node_num += host_vec.size(); } - gpu_graph_total_keys_.reserve(node_num); + gpu_graph_total_keys_.resize(node_num); for (int i = 0; i < thread_num_; i++) { - auto& host_vec = readers_[i]->GetHostVec(); - for (size_t j = 0; j < host_vec.size(); j++) { - gpu_graph_total_keys_.push_back(host_vec[j]); - } + uint64_t off = offsets[i]; + wait_futures.emplace_back( + pool[i]->Run([this, i, off]() { + auto& host_vec = readers_[i]->GetHostVec(); + for (size_t j = 0; j < host_vec.size(); j++) { + gpu_graph_total_keys_[off + j] = host_vec[j]; + } + if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { + readers_[i]->clear_gpu_mem(); + } + })); } + for (auto& th : wait_futures) { + th.get(); + } + wait_futures.clear(); if (GetEpochFinish() == true) { VLOG(0) << "epoch finish, set stat and clear sample stat!"; @@ -499,16 +529,17 @@ void DatasetImpl::LoadIntoMemory() { readers_[i]->ClearSampleState(); } } - if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { - for (size_t i = 0; i < readers_.size(); i++) { - readers_[i]->clear_gpu_mem(); - } - } +// if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { +// for (size_t i = 0; i < readers_.size(); i++) { +// readers_[i]->clear_gpu_mem(); +// } +// } VLOG(2) << "end add edge into 
gpu_graph_total_keys_ size[" << gpu_graph_total_keys_.size() << "]"; #endif } else { + std::vector load_threads; for (int64_t i = 0; i < thread_num_; ++i) { load_threads.push_back(std::thread( &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get())); diff --git a/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h b/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h index b4590548d70fb..56567837af9fc 100644 --- a/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h @@ -354,7 +354,8 @@ class concurrent_unordered_map : public managed { public: concurrent_unordered_map(const concurrent_unordered_map&) = delete; concurrent_unordered_map& operator=(const concurrent_unordered_map&) = delete; - explicit concurrent_unordered_map(size_type n, + explicit concurrent_unordered_map(cudaStream_t stream, + size_type n, const mapped_type unused_element, const Hasher& hf = hasher(), const Equality& eql = key_equal(), @@ -388,13 +389,13 @@ class concurrent_unordered_map : public managed { int dev_id = 0; CUDA_RT_CALL(cudaGetDevice(&dev_id)); CUDA_RT_CALL(cudaMemPrefetchAsync( - m_hashtbl_values, m_hashtbl_size * sizeof(value_type), dev_id, 0)); + m_hashtbl_values, m_hashtbl_size * sizeof(value_type), dev_id, stream)); } } // Initialize kernel, set all entry to unused - init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size>>>( + init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size, 0, stream>>>( m_hashtbl_values, m_hashtbl_size, unused_key, m_unused_element); - CUDA_RT_CALL(cudaStreamSynchronize(0)); + CUDA_RT_CALL(cudaStreamSynchronize(stream)); CUDA_RT_CALL(cudaGetLastError()); m_enable_collision_stat = FLAGS_gpugraph_enable_hbm_table_collision_stat; } diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index e5445dc9b0592..cf90e75c153d7 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -349,13 +349,20 @@ struct NeighborSampleResultV2 { struct NodeQueryResult { uint64_t *val; int actual_sample_size; + cudaStream_t stream = 0; + void set_stream(cudaStream_t stream_t) { stream = stream_t; } uint64_t get_val() { return (uint64_t)val; } int get_len() { return actual_sample_size; } std::shared_ptr val_mem; void initialize(int query_size, int dev_id) { platform::CUDADeviceGuard guard(dev_id); platform::CUDAPlace place = platform::CUDAPlace(dev_id); - val_mem = memory::AllocShared(place, query_size * sizeof(uint64_t)); + if (stream != 0) { + val_mem = memory::AllocShared(place, query_size * sizeof(uint64_t), + phi::Stream(reinterpret_cast(stream))); + } else { + val_mem = memory::AllocShared(place, query_size * sizeof(uint64_t)); + } val = (uint64_t *)val_mem->ptr(); actual_sample_size = 0; } diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu index 1db827f7d082b..898bda5b2ada5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu @@ -21,6 +21,8 @@ #ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" +#define ALIGN_INT64(LEN) (uint64_t((LEN) + 7) & uint64_t(~7)) +#define HBMPS_MAX_BUFF 1024 * 
1024 namespace paddle { namespace framework { /* @@ -370,39 +372,35 @@ void GpuPsGraphTable::move_result_to_source_gpu(int start_index, shard_len[i] = h_right[i] - h_left[i] + 1; int cur_step = (int)path_[start_index][i].nodes_.size() - 1; for (int j = cur_step; j > 0; j--) { - CUDA_CHECK( - cudaMemcpyAsync(path_[start_index][i].nodes_[j - 1].val_storage, - path_[start_index][i].nodes_[j].val_storage, - path_[start_index][i].nodes_[j - 1].val_bytes_len, - cudaMemcpyDefault, - path_[start_index][i].nodes_[j - 1].out_stream)); + auto& dst_node = path_[start_index][i].nodes_[j - 1]; + auto& src_node = path_[start_index][i].nodes_[j]; + MemcpyPeerAsync(dst_node.val_storage, + src_node.val_storage, + dst_node.val_bytes_len, + src_node.out_stream); + if (src_node.sync) { + CUDA_CHECK(cudaStreamSynchronize(src_node.out_stream)); + } } auto& node = path_[start_index][i].nodes_.front(); - if (fea_num_list[i] > 0) { - CUDA_CHECK(cudaMemcpyAsync( - reinterpret_cast(feature_list + fea_left[i]), - node.val_storage + - sizeof(uint32_t) * (shard_len[i] + shard_len[i] % 2), - sizeof(uint64_t) * fea_num_list[i], - cudaMemcpyDefault, - node.out_stream)); - CUDA_CHECK(cudaMemcpyAsync( - reinterpret_cast(slot_list + fea_left[i]), - node.val_storage + - sizeof(uint32_t) * (shard_len[i] + shard_len[i] % 2) + - sizeof(uint64_t) * fea_num_list[i], - sizeof(uint8_t) * fea_num_list[i], - cudaMemcpyDefault, - node.out_stream)); + MemcpyPeerAsync(reinterpret_cast(feature_list + fea_left[i]), + node.val_storage + + sizeof(uint32_t) * (shard_len[i] + shard_len[i] % 2), + sizeof(uint64_t) * fea_num_list[i], + node.out_stream); + MemcpyPeerAsync(reinterpret_cast(slot_list + fea_left[i]), + node.val_storage + + sizeof(uint32_t) * (shard_len[i] + shard_len[i] % 2) + + sizeof(uint64_t) * fea_num_list[i], + sizeof(uint8_t) * fea_num_list[i], + node.out_stream); } if (shard_len[i] > 0) { - CUDA_CHECK(cudaMemcpyAsync( - reinterpret_cast(actual_feature_size + h_left[i]), - node.val_storage, - sizeof(uint32_t) * shard_len[i], - cudaMemcpyDefault, - node.out_stream)); + MemcpyPeerAsync(reinterpret_cast(actual_feature_size + h_left[i]), + node.val_storage, + sizeof(uint32_t) * shard_len[i], + node.out_stream); } } for (int i = 0; i < gpu_num; ++i) { @@ -429,27 +427,27 @@ void GpuPsGraphTable::move_result_to_source_gpu(int start_index, shard_len[i] = h_right[i] - h_left[i] + 1; int cur_step = (int)path_[start_index][i].nodes_.size() - 1; for (int j = cur_step; j > 0; j--) { - CUDA_CHECK( - cudaMemcpyAsync(path_[start_index][i].nodes_[j - 1].val_storage, - path_[start_index][i].nodes_[j].val_storage, - path_[start_index][i].nodes_[j - 1].val_bytes_len, - cudaMemcpyDefault, - path_[start_index][i].nodes_[j - 1].out_stream)); + auto& dst_node = path_[start_index][i].nodes_[j - 1]; + auto& src_node = path_[start_index][i].nodes_[j]; + MemcpyPeerAsync(dst_node.val_storage, + src_node.val_storage, + dst_node.val_bytes_len, + src_node.out_stream); + if (src_node.sync) { + CUDA_CHECK(cudaStreamSynchronize(src_node.out_stream)); + } } auto& node = path_[start_index][i].nodes_.front(); - CUDA_CHECK(cudaMemcpyAsync( + MemcpyPeerAsync( reinterpret_cast(src_sample_res + h_left[i] * sample_size), node.val_storage + sizeof(int64_t) * shard_len[i] + sizeof(int) * (shard_len[i] + shard_len[i] % 2), sizeof(uint64_t) * shard_len[i] * sample_size, - cudaMemcpyDefault, - node.out_stream)); - CUDA_CHECK( - cudaMemcpyAsync(reinterpret_cast(actual_sample_size + h_left[i]), - node.val_storage + sizeof(int64_t) * shard_len[i], - sizeof(int) * shard_len[i], 
- cudaMemcpyDefault, - node.out_stream)); + node.out_stream); + MemcpyPeerAsync(reinterpret_cast(actual_sample_size + h_left[i]), + node.val_storage + sizeof(int64_t) * shard_len[i], + sizeof(int) * shard_len[i], + node.out_stream); } for (int i = 0; i < gpu_num; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { @@ -461,11 +459,8 @@ void GpuPsGraphTable::move_result_to_source_gpu(int start_index, } } -void GpuPsGraphTable::move_degree_to_source_gpu(int start_index, - int gpu_num, - int* h_left, - int* h_right, - int* node_degree) { +void GpuPsGraphTable::move_degree_to_source_gpu( + int start_index, int gpu_num, int* h_left, int* h_right, int* node_degree) { int shard_len[gpu_num]; for (int i = 0; i < gpu_num; i++) { if (h_left[i] == -1 || h_right[i] == -1) { @@ -474,20 +469,21 @@ void GpuPsGraphTable::move_degree_to_source_gpu(int start_index, shard_len[i] = h_right[i] - h_left[i] + 1; int cur_step = (int)path_[start_index][i].nodes_.size() - 1; for (int j = cur_step; j > 0; j--) { - CUDA_CHECK( - cudaMemcpyAsync(path_[start_index][i].nodes_[j - 1].val_storage, - path_[start_index][i].nodes_[j].val_storage, - path_[start_index][i].nodes_[j - 1].val_bytes_len, - cudaMemcpyDefault, - path_[start_index][i].nodes_[j - 1].out_stream)); + auto& dst_node = path_[start_index][i].nodes_[j - 1]; + auto& src_node = path_[start_index][i].nodes_[j]; + MemcpyPeerAsync(dst_node.val_storage, + src_node.val_storage, + dst_node.val_bytes_len, + src_node.out_stream); + if (src_node.sync) { + CUDA_CHECK(cudaStreamSynchronize(src_node.out_stream)); + } } auto& node = path_[start_index][i].nodes_.front(); - CUDA_CHECK(cudaMemcpyAsync( - reinterpret_cast(node_degree + h_left[i]), - node.val_storage + sizeof(int64_t) * shard_len[i], - sizeof(int) * shard_len[i], - cudaMemcpyDefault, - node.out_stream)); + MemcpyPeerAsync(reinterpret_cast(node_degree + h_left[i]), + node.val_storage + sizeof(int64_t) * shard_len[i], + sizeof(int) * shard_len[i], + node.out_stream); } for (int i = 0; i < gpu_num; ++i) { @@ -510,7 +506,6 @@ void GpuPsGraphTable::move_result_to_source_gpu_all_edge_type( int edge_type_len, int len) { int shard_len[gpu_num]; - for (int i = 0; i < gpu_num; i++) { if (h_left[i] == -1 || h_right[i] == -1) { continue; @@ -518,12 +513,15 @@ void GpuPsGraphTable::move_result_to_source_gpu_all_edge_type( shard_len[i] = h_right[i] - h_left[i] + 1; int cur_step = (int)path_[start_index][i].nodes_.size() - 1; for (int j = cur_step; j > 0; j--) { - CUDA_CHECK( - cudaMemcpyAsync(path_[start_index][i].nodes_[j - 1].val_storage, - path_[start_index][i].nodes_[j].val_storage, - path_[start_index][i].nodes_[j - 1].val_bytes_len, - cudaMemcpyDefault, - path_[start_index][i].nodes_[j - 1].out_stream)); + auto& dst_node = path_[start_index][i].nodes_[j - 1]; + auto& src_node = path_[start_index][i].nodes_[j]; + MemcpyPeerAsync(dst_node.val_storage, + src_node.val_storage, + dst_node.val_bytes_len, + src_node.out_stream); + if (src_node.sync) { + CUDA_CHECK(cudaStreamSynchronize(src_node.out_stream)); + } } } @@ -533,7 +531,7 @@ void GpuPsGraphTable::move_result_to_source_gpu_all_edge_type( continue; } auto& node = path_[start_index][j].nodes_.front(); - CUDA_CHECK(cudaMemcpyAsync( + MemcpyPeerAsync( reinterpret_cast(src_sample_res + i * len * sample_size + h_left[j] * sample_size), node.val_storage + sizeof(int64_t) * shard_len[j] * edge_type_len + @@ -541,15 +539,13 @@ void GpuPsGraphTable::move_result_to_source_gpu_all_edge_type( (shard_len[j] * edge_type_len) % 2) + sizeof(uint64_t) * i * shard_len[j] * sample_size, 
sizeof(uint64_t) * shard_len[j] * sample_size, - cudaMemcpyDefault, - node.out_stream)); - CUDA_CHECK(cudaMemcpyAsync( + node.out_stream); + MemcpyPeerAsync( reinterpret_cast(actual_sample_size + i * len + h_left[j]), node.val_storage + sizeof(int64_t) * shard_len[j] * edge_type_len + sizeof(int) * i * shard_len[j], sizeof(int) * shard_len[j], - cudaMemcpyDefault, - node.out_stream)); + node.out_stream); } } @@ -710,10 +706,11 @@ void GpuPsGraphTable::reset_feature_info(int gpu_id, size_t capacity, size_t feature_size) { int idx = 0; + auto stream = get_local_stream(gpu_id); int offset = get_table_offset(gpu_id, GraphTableType::FEATURE_TABLE, idx); if (offset < tables_.size()) { delete tables_[offset]; - tables_[offset] = new Table(capacity); + tables_[offset] = new Table(capacity, stream); } int graph_fea_idx = get_graph_fea_list_offset(gpu_id); auto& graph = gpu_graph_fea_list_[graph_fea_idx]; @@ -722,7 +719,7 @@ void GpuPsGraphTable::reset_feature_info(int gpu_id, CUDA_CHECK(cudaMalloc((void**)&graph.feature_list, feature_size * sizeof(uint64_t))); CUDA_CHECK(cudaMalloc((void**)&graph.slot_id_list, - feature_size * sizeof(uint8_t))); + ALIGN_INT64(feature_size * sizeof(uint8_t)))); graph.feature_capacity = feature_size; } else if (graph.feature_capacity < feature_size) { cudaFree(graph.feature_list); @@ -730,13 +727,14 @@ void GpuPsGraphTable::reset_feature_info(int gpu_id, CUDA_CHECK(cudaMalloc((void**)&graph.feature_list, feature_size * sizeof(uint64_t))); CUDA_CHECK(cudaMalloc((void**)&graph.slot_id_list, - feature_size * sizeof(uint8_t))); + ALIGN_INT64(feature_size * sizeof(uint8_t)))); graph.feature_capacity = feature_size; } else { - CUDA_CHECK( - cudaMemset(graph.feature_list, 0, feature_size * sizeof(uint64_t))); - CUDA_CHECK( - cudaMemset(graph.slot_id_list, 0, feature_size * sizeof(uint8_t))); + CUDA_CHECK(cudaMemsetAsync( + graph.feature_list, 0, feature_size * sizeof(uint64_t), stream)); + CUDA_CHECK(cudaMemsetAsync( + graph.slot_id_list, 0, feature_size * sizeof(uint8_t), stream)); + cudaStreamSynchronize(stream); } } @@ -782,23 +780,27 @@ void GpuPsGraphTable::build_graph_fea_on_single_gpu(const GpuPsCommGraphFea& g, g.node_list, (uint64_t*)g.fea_info_list, g.node_size, - 1024, + HBMPS_MAX_BUFF, 8, table_offset); gpu_graph_fea_list_[offset].node_size = g.node_size; } else { - build_ps(gpu_id, NULL, NULL, 0, 1024, 8, table_offset); + build_ps(gpu_id, NULL, NULL, 0, HBMPS_MAX_BUFF, 8, table_offset); gpu_graph_fea_list_[offset].node_size = 0; } if (g.feature_size) { - CUDA_CHECK(cudaMemcpy(gpu_graph_fea_list_[offset].feature_list, - g.feature_list, - g.feature_size * sizeof(uint64_t), - cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(gpu_graph_fea_list_[offset].slot_id_list, - g.slot_id_list, - g.feature_size * sizeof(uint8_t), - cudaMemcpyHostToDevice)); + auto stream = get_local_stream(gpu_id); + CUDA_CHECK(cudaMemcpyAsync(gpu_graph_fea_list_[offset].feature_list, + g.feature_list, + g.feature_size * sizeof(uint64_t), + cudaMemcpyHostToDevice, + stream)); + CUDA_CHECK(cudaMemcpyAsync(gpu_graph_fea_list_[offset].slot_id_list, + g.slot_id_list, + g.feature_size * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream)); + cudaStreamSynchronize(stream); gpu_graph_fea_list_[offset].feature_size = g.feature_size; } else { @@ -831,10 +833,12 @@ GpuPsGraphTable::get_edge_type_graph(int gpu_id, int edge_type_len) { phi::Stream(reinterpret_cast(stream))); GpuPsCommGraph* d_commgraph_ptr = reinterpret_cast(d_commgraph_mem->ptr()); - CUDA_CHECK(cudaMemcpy(d_commgraph_ptr, - graphs, - 
sizeof(GpuPsCommGraph) * edge_type_len, - cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpyAsync(d_commgraph_ptr, + graphs, + sizeof(GpuPsCommGraph) * edge_type_len, + cudaMemcpyHostToDevice, + stream)); + cudaStreamSynchronize(stream); graphs_vec.emplace_back(d_commgraph_mem); } @@ -858,27 +862,30 @@ void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g, gpu_graph_list_[offset] = GpuPsCommGraph(); int table_offset = get_table_offset(gpu_id, GraphTableType::EDGE_TABLE, edge_idx); size_t capacity = std::max((uint64_t)1, (uint64_t)g.node_size) / load_factor_; - tables_[table_offset] = new Table(capacity); + auto stream = get_local_stream(gpu_id); + tables_[table_offset] = new Table(capacity, stream); if (g.node_size > 0) { if (FLAGS_gpugraph_load_node_list_into_hbm) { CUDA_CHECK(cudaMalloc((void**)&gpu_graph_list_[offset].node_list, g.node_size * sizeof(uint64_t))); - CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].node_list, - g.node_list, - g.node_size * sizeof(uint64_t), - cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpyAsync(gpu_graph_list_[offset].node_list, + g.node_list, + g.node_size * sizeof(uint64_t), + cudaMemcpyHostToDevice, + stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); } build_ps(gpu_id, g.node_list, (uint64_t*)(g.node_info_list), g.node_size, - 1024, + HBMPS_MAX_BUFF, 8, table_offset); gpu_graph_list_[offset].node_size = g.node_size; } else { - build_ps(gpu_id, NULL, NULL, 0, 1024, 8, table_offset); + build_ps(gpu_id, NULL, NULL, 0, HBMPS_MAX_BUFF, 8, table_offset); gpu_graph_list_[offset].node_list = NULL; gpu_graph_list_[offset].node_size = 0; } @@ -893,15 +900,17 @@ void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g, VLOG(0) << "sucessfully allocate " << g.neighbor_size * sizeof(uint64_t) << " bytes of memory for graph-edges on gpu " << resource_->dev_id(gpu_id); - CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].neighbor_list, - g.neighbor_list, - g.neighbor_size * sizeof(uint64_t), - cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpyAsync(gpu_graph_list_[offset].neighbor_list, + g.neighbor_list, + g.neighbor_size * sizeof(uint64_t), + cudaMemcpyHostToDevice, + stream)); gpu_graph_list_[offset].neighbor_size = g.neighbor_size; } else { gpu_graph_list_[offset].neighbor_list = NULL; gpu_graph_list_[offset].neighbor_size = 0; } + cudaStreamSynchronize(stream); VLOG(0) << " gpu node_neighbor info card: " << gpu_id << " ,node_size is " << gpu_graph_list_[offset].node_size << ", neighbor_size is " << gpu_graph_list_[offset].neighbor_size; @@ -922,26 +931,30 @@ void GpuPsGraphTable::build_graph_from_cpu( int offset = get_graph_list_offset(i, edge_idx); platform::CUDADeviceGuard guard(resource_->dev_id(i)); gpu_graph_list_[offset] = GpuPsCommGraph(); + auto stream = get_local_stream(i); tables_[table_offset] = new Table(std::max((uint64_t)1, (uint64_t)cpu_graph_list[i].node_size) / - load_factor_); + load_factor_, + stream); if (cpu_graph_list[i].node_size > 0) { CUDA_CHECK(cudaMalloc((void**)&gpu_graph_list_[offset].node_list, cpu_graph_list[i].node_size * sizeof(uint64_t))); - CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].node_list, - cpu_graph_list[i].node_list, - cpu_graph_list[i].node_size * sizeof(uint64_t), - cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpyAsync(gpu_graph_list_[offset].node_list, + cpu_graph_list[i].node_list, + cpu_graph_list[i].node_size * sizeof(uint64_t), + cudaMemcpyHostToDevice, + stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); build_ps(i, cpu_graph_list[i].node_list, 
(uint64_t*)(cpu_graph_list[i].node_info_list), cpu_graph_list[i].node_size, - 1024, + HBMPS_MAX_BUFF, 8, table_offset); gpu_graph_list_[offset].node_size = cpu_graph_list[i].node_size; } else { - build_ps(i, NULL, NULL, 0, 1024, 8, table_offset); + build_ps(i, NULL, NULL, 0, HBMPS_MAX_BUFF, 8, table_offset); gpu_graph_list_[offset].node_list = NULL; gpu_graph_list_[offset].node_size = 0; } @@ -950,17 +963,19 @@ void GpuPsGraphTable::build_graph_from_cpu( cudaMalloc((void**)&gpu_graph_list_[offset].neighbor_list, cpu_graph_list[i].neighbor_size * sizeof(uint64_t))); - CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].neighbor_list, - cpu_graph_list[i].neighbor_list, - cpu_graph_list[i].neighbor_size * sizeof(uint64_t), - cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpyAsync(gpu_graph_list_[offset].neighbor_list, + cpu_graph_list[i].neighbor_list, + cpu_graph_list[i].neighbor_size * sizeof(uint64_t), + cudaMemcpyHostToDevice, + stream)); gpu_graph_list_[offset].neighbor_size = cpu_graph_list[i].neighbor_size; } else { gpu_graph_list_[offset].neighbor_list = NULL; gpu_graph_list_[offset].neighbor_size = 0; } + cudaStreamSynchronize(stream); } - CUDA_CHECK(cudaDeviceSynchronize()); } NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3( @@ -1001,7 +1016,6 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); - int* actual_sample_size = result.actual_sample_size; uint64_t* val = result.val; int total_gpu = resource_->total_device(); @@ -1052,14 +1066,17 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( int* d_shard_actual_sample_size_ptr = reinterpret_cast(d_shard_actual_sample_size->ptr()); - split_input_to_shard( - (uint64_t*)(key), d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); + split_idx_to_shard((uint64_t*)(key), + d_idx_ptr, + len, + d_left_ptr, + d_right_ptr, + gpu_id, + stream); heter_comm_kernel_->fill_shard_key( d_shard_keys_ptr, key, d_idx_ptr, len, stream); - CUDA_CHECK(cudaStreamSynchronize(stream)); - CUDA_CHECK(cudaMemcpyAsync(h_left, d_left_ptr, total_gpu * sizeof(int), @@ -1071,6 +1088,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); + device_mutex_[gpu_id]->lock(); for (int i = 0; i < total_gpu; ++i) { int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; @@ -1092,19 +1110,21 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( continue; } int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; + CHECK(shard_len > 0); auto& node = path_[gpu_id][i].nodes_.back(); - - CUDA_CHECK(cudaMemsetAsync( - node.val_storage, 0, shard_len * sizeof(uint64_t), node.in_stream)); - CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); + // CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); platform::CUDADeviceGuard guard(resource_->dev_id(i)); + auto cur_stream = resource_->remote_stream(i, gpu_id); + CUDA_CHECK(cudaMemsetAsync( + node.val_storage, 0, shard_len * sizeof(uint64_t), cur_stream)); // If not found, val is -1. 
int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx); + CHECK(table_offset >= 0); int offset = get_graph_list_offset(i, idx); tables_[table_offset]->get(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), (size_t)(h_right[i] - h_left[i] + 1), - resource_->remote_stream(i, gpu_id)); + cur_stream); auto graph = gpu_graph_list_[offset]; GpuPsNodeInfo* node_info_list = @@ -1119,14 +1139,13 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( const dim3 block(WARP_SIZE, BLOCK_WARPS); const dim3 grid((shard_len + TILE_SIZE - 1) / TILE_SIZE); neighbor_sample_kernel_walking - <<remote_stream(i, gpu_id)>>>( - graph, - node_info_list, - actual_size_array, - sample_array, - sample_size, - shard_len, - default_value); + <<>>(graph, + node_info_list, + actual_size_array, + sample_array, + sample_size, + shard_len, + default_value); } for (int i = 0; i < total_gpu; ++i) { @@ -1375,14 +1394,17 @@ NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type( int* d_shard_actual_sample_size_ptr = reinterpret_cast(d_shard_actual_sample_size->ptr()); - split_input_to_shard( - (uint64_t*)(key), d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); + split_idx_to_shard((uint64_t*)(key), + d_idx_ptr, + len, + d_left_ptr, + d_right_ptr, + gpu_id, + stream); heter_comm_kernel_->fill_shard_key( d_shard_keys_ptr, key, d_idx_ptr, len, stream); - CUDA_CHECK(cudaStreamSynchronize(stream)); - CUDA_CHECK(cudaMemcpyAsync(h_left, d_left_ptr, total_gpu * sizeof(int), @@ -1420,13 +1442,13 @@ NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type( } int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; auto& node = path_[gpu_id][i].nodes_.back(); + // CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + auto cur_stream = resource_->remote_stream(i, gpu_id); CUDA_CHECK(cudaMemsetAsync(node.val_storage, 0, shard_len * edge_type_len * sizeof(uint64_t), - node.in_stream)); - CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); - platform::CUDADeviceGuard guard(resource_->dev_id(i)); - + cur_stream)); GpuPsNodeInfo* node_info_base = reinterpret_cast(node.val_storage); for (int idx = 0; idx < edge_type_len; idx++) { @@ -1439,7 +1461,7 @@ NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type( reinterpret_cast(node.key_storage), reinterpret_cast(node_info_base + idx * shard_len), (size_t)(shard_len), - resource_->remote_stream(i, gpu_id)); + cur_stream); } auto d_commgraph_mem = edge_type_graphs[i]; @@ -1453,8 +1475,7 @@ NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type( neighbor_sample_kernel_all_edge_type<<remote_stream(i, - gpu_id)>>>( + cur_stream>>>( d_commgraph_ptr, node_info_base, actual_size_base, @@ -1506,10 +1527,13 @@ NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type( } void GpuPsGraphTable::get_node_degree( - int gpu_id, int edge_idx, uint64_t* key, int len, + int gpu_id, + int edge_idx, + uint64_t* key, + int len, std::shared_ptr node_degree) { int* node_degree_ptr = - reinterpret_cast(node_degree->ptr()) + edge_idx * len; + reinterpret_cast(node_degree->ptr()) + edge_idx * len; int total_gpu = resource_->total_device(); platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); @@ -1543,12 +1567,16 @@ void GpuPsGraphTable::get_node_degree( memory::Alloc(place, len * sizeof(int), phi::Stream(reinterpret_cast(stream))); - 
int* d_shard_degree_ptr = reinterpret_cast(d_shard_degree->ptr()); - split_input_to_shard( - (uint64_t*)(key), d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); + int* d_shard_degree_ptr = reinterpret_cast(d_shard_degree->ptr()); + split_idx_to_shard((uint64_t*)(key), + d_idx_ptr, + len, + d_left_ptr, + d_right_ptr, + gpu_id, + stream); heter_comm_kernel_->fill_shard_key( d_shard_keys_ptr, key, d_idx_ptr, len, stream); - CUDA_CHECK(cudaStreamSynchronize(stream)); CUDA_CHECK(cudaMemcpyAsync(h_left, d_left_ptr, total_gpu * sizeof(int), @@ -1560,17 +1588,18 @@ void GpuPsGraphTable::get_node_degree( cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); + device_mutex_[gpu_id]->lock(); for (int i = 0; i < total_gpu; ++i) { - int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; + int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; if (shard_len == 0) { continue; } - create_storage(gpu_id, - i, - shard_len * sizeof(uint64_t), - shard_len * sizeof(uint64_t) + - sizeof(int) * shard_len + shard_len % 2); + create_storage( + gpu_id, + i, + shard_len * sizeof(uint64_t), + shard_len * sizeof(uint64_t) + sizeof(int) * shard_len + shard_len % 2); } walk_to_dest( gpu_id, total_gpu, h_left, h_right, (uint64_t*)(d_shard_keys_ptr), NULL); @@ -1580,26 +1609,23 @@ void GpuPsGraphTable::get_node_degree( } int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; auto& node = path_[gpu_id][i].nodes_.back(); - CUDA_CHECK(cudaMemsetAsync(node.val_storage, - 0, - shard_len * sizeof(uint64_t), - node.in_stream)); - CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); + // CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); platform::CUDADeviceGuard guard(resource_->dev_id(i)); - int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, edge_idx); + auto cur_stream = resource_->remote_stream(i, gpu_id); + CUDA_CHECK(cudaMemsetAsync( + node.val_storage, 0, shard_len * sizeof(uint64_t), cur_stream)); + int table_offset = + get_table_offset(i, GraphTableType::EDGE_TABLE, edge_idx); tables_[table_offset]->get(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), (size_t)(h_right[i] - h_left[i] + 1), - resource_->remote_stream(i, gpu_id)); + cur_stream); GpuPsNodeInfo* node_info_list = reinterpret_cast(node.val_storage); int* node_degree_array = (int*)(node_info_list + shard_len); int grid_size_ = (shard_len - 1) / block_size_ + 1; - get_node_degree_kernel<<< - grid_size_, block_size_, 0, resource_->remote_stream(i, gpu_id)>>>( - node_info_list, - node_degree_array, - shard_len); + get_node_degree_kernel<<>>( + node_info_list, node_degree_array, shard_len); } for (int i = 0; i < total_gpu; ++i) { if (h_left[i] == -1) { @@ -1607,19 +1633,13 @@ void GpuPsGraphTable::get_node_degree( } CUDA_CHECK(cudaStreamSynchronize(resource_->remote_stream(i, gpu_id))); } - move_degree_to_source_gpu(gpu_id, - total_gpu, - h_left, - h_right, - d_shard_degree_ptr); + move_degree_to_source_gpu( + gpu_id, total_gpu, h_left, h_right, d_shard_degree_ptr); fill_dvalues<<>>( - d_shard_degree_ptr, - node_degree_ptr, - d_idx_ptr, - len); + d_shard_degree_ptr, node_degree_ptr, d_idx_ptr, len); CUDA_CHECK(cudaStreamSynchronize(stream)); for (int i = 0; i < total_gpu; i++) { - int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; + int shard_len = h_left[i] == -1 ? 
0 : h_right[i] - h_left[i] + 1; if (shard_len == 0) { continue; } @@ -1661,6 +1681,7 @@ NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, return result; } + result.set_stream(resource_->local_stream(gpu_id, 0)); result.initialize(len, resource_->dev_id(gpu_id)); result.actual_sample_size = len; uint64_t* val = result.val; @@ -1723,12 +1744,11 @@ int GpuPsGraphTable::get_feature_info_of_nodes( phi::Stream(reinterpret_cast(stream))); int* d_shard_actual_size_ptr = reinterpret_cast(d_shard_actual_size->ptr()); - split_input_to_shard( - d_nodes, d_idx_ptr, node_num, d_left_ptr, d_right_ptr, gpu_id); + split_idx_to_shard( + d_nodes, d_idx_ptr, node_num, d_left_ptr, d_right_ptr, gpu_id, stream); heter_comm_kernel_->fill_shard_key( d_shard_keys_ptr, d_nodes, d_idx_ptr, node_num, stream); - CUDA_CHECK(cudaStreamSynchronize(stream)); std::vector d_fea_info(total_gpu, NULL); std::vector d_fea_size(total_gpu, NULL); @@ -1737,11 +1757,19 @@ int GpuPsGraphTable::get_feature_info_of_nodes( std::vector fea_left(total_gpu, -1); int h_left[total_gpu]; // NOLINT - CUDA_CHECK(cudaMemcpy( - h_left, d_left_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpyAsync(h_left, + d_left_ptr, + total_gpu * sizeof(int), + cudaMemcpyDeviceToHost, + stream)); int h_right[total_gpu]; // NOLINT - CUDA_CHECK(cudaMemcpy( - h_right, d_right_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpyAsync(h_right, + d_right_ptr, + total_gpu * sizeof(int), + cudaMemcpyDeviceToHost, + stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + device_mutex_[gpu_id]->lock(); int shard_len[total_gpu]; void* d_temp_storage[total_gpu]; @@ -1753,6 +1781,7 @@ int GpuPsGraphTable::get_feature_info_of_nodes( if (h_left[i] == -1) { continue; } + // create keys storage create_storage(gpu_id, i, shard_len[i] * sizeof(uint64_t), 0); platform::CUDADeviceGuard guard(resource_->dev_id(i)); auto& node = path_[gpu_id][i].nodes_.back(); @@ -1796,8 +1825,8 @@ int GpuPsGraphTable::get_feature_info_of_nodes( auto& node = path_[gpu_id][i].nodes_.back(); // If not found, val is -1. 
int table_offset = get_table_offset(i, GraphTableType::FEATURE_TABLE, 0); - CUDA_CHECK(cudaStreamSynchronize( - node.in_stream)); // wait for walk_to_dest and memset + // CUDA_CHECK(cudaStreamSynchronize( + // node.in_stream)); // wait for walk_to_dest and memset tables_[table_offset]->get(reinterpret_cast(node.key_storage), (uint64_t*)d_fea_info[i], (size_t)(h_right[i] - h_left[i] + 1), @@ -1838,7 +1867,7 @@ int GpuPsGraphTable::get_feature_info_of_nodes( CUDA_CHECK(cudaStreamSynchronize( resource_->remote_stream(i, gpu_id))); // wait for fea_num_list - + // create vals storage create_storage(gpu_id, i, 0, @@ -2051,19 +2080,26 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, int* d_shard_actual_size_ptr = reinterpret_cast(d_shard_actual_size->ptr()); - split_input_to_shard( - d_nodes, d_idx_ptr, node_num, d_left_ptr, d_right_ptr, gpu_id); + split_idx_to_shard( + d_nodes, d_idx_ptr, node_num, d_left_ptr, d_right_ptr, gpu_id, stream); heter_comm_kernel_->fill_shard_key( d_shard_keys_ptr, d_nodes, d_idx_ptr, node_num, stream); - CUDA_CHECK(cudaStreamSynchronize(stream)); int h_left[total_gpu]; // NOLINT - CUDA_CHECK(cudaMemcpy( - h_left, d_left_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpyAsync(h_left, + d_left_ptr, + total_gpu * sizeof(int), + cudaMemcpyDeviceToHost, + stream)); int h_right[total_gpu]; // NOLINT - CUDA_CHECK(cudaMemcpy( - h_right, d_right_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpyAsync(h_right, + d_right_ptr, + total_gpu * sizeof(int), + cudaMemcpyDeviceToHost, + stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + device_mutex_[gpu_id]->lock(); for (int i = 0; i < total_gpu; ++i) { int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; @@ -2088,16 +2124,17 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; auto& node = path_[gpu_id][i].nodes_.back(); - CUDA_CHECK(cudaMemsetAsync( - node.val_storage, 0, shard_len * sizeof(uint64_t), node.in_stream)); - CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); + // CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); platform::CUDADeviceGuard guard(resource_->dev_id(i)); + auto cur_stream = resource_->remote_stream(i, gpu_id); + CUDA_CHECK(cudaMemsetAsync( + node.val_storage, 0, shard_len * sizeof(uint64_t), cur_stream)); // If not found, val is -1. 
int table_offset = get_table_offset(i, GraphTableType::FEATURE_TABLE, 0); tables_[table_offset]->get(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), (size_t)(h_right[i] - h_left[i] + 1), - resource_->remote_stream(i, gpu_id)); + cur_stream); int offset = get_graph_fea_list_offset(i); auto graph = gpu_graph_fea_list_[offset]; @@ -2108,18 +2145,14 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, (uint64_t*)(actual_size_array + shard_len + shard_len % 2); dim3 grid((shard_len - 1) / dim_y + 1); dim3 block(1, dim_y); - get_features_kernel<<remote_stream(i, gpu_id)>>>( - graph, - val_array, - actual_size_array, - feature_array, - d_slot_feature_num_map, - slot_num, - shard_len, - fea_num_per_node); + get_features_kernel<<>>(graph, + val_array, + actual_size_array, + feature_array, + d_slot_feature_num_map, + slot_num, + shard_len, + fea_num_per_node); } for (int i = 0; i < total_gpu; ++i) { diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu index 3d888b5e9ac2a..63cb77ed1c235 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu @@ -157,7 +157,8 @@ void GraphGpuWrapper::init_type_keys() { auto place = platform::CUDAPlace(gpuid); platform::CUDADeviceGuard guard(gpuid); d_graph_all_type_total_keys_[f_idx][j] = - memory::AllocShared(place, tmp_keys[j].size() * sizeof(uint64_t)); + memory::AllocShared(place, tmp_keys[j].size() * sizeof(uint64_t), + phi::Stream(reinterpret_cast(stream))); cudaMemcpyAsync(d_graph_all_type_total_keys_[f_idx][j]->ptr(), tmp_keys[j].data(), sizeof(uint64_t) * tmp_keys[j].size(), @@ -243,7 +244,9 @@ void GraphGpuWrapper::init_metapath(std::string cur_metapath, auto place = platform::CUDAPlace(gpuid); platform::CUDADeviceGuard guard(gpuid); d_graph_train_total_keys_[j] = - memory::AllocShared(place, tmp_keys[j].size() * sizeof(uint64_t)); + memory::AllocShared(place, + tmp_keys[j].size() * sizeof(uint64_t), + phi::Stream(reinterpret_cast(stream))); cudaMemcpyAsync(d_graph_train_total_keys_[j]->ptr(), tmp_keys[j].data(), sizeof(uint64_t) * tmp_keys[j].size(), diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index 35ef6b363c22c..b473d8b521d24 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -55,11 +55,11 @@ class TableContainer ValType, std::numeric_limits::max()> { public: - TableContainer(size_t capacity) + TableContainer(size_t capacity, cudaStream_t stream) : concurrent_unordered_map::max()>( - capacity, ValType()) {} + stream, capacity, ValType()) {} }; #elif defined(PADDLE_WITH_XPU_KP) template @@ -113,7 +113,11 @@ class XPUCacheArray { template class HashTable { public: +#if defined(PADDLE_WITH_CUDA) + explicit HashTable(size_t capacity, cudaStream_t stream = 0); +#else explicit HashTable(size_t capacity); +#endif virtual ~HashTable(); HashTable(const HashTable&) = delete; HashTable& operator=(const HashTable&) = delete; @@ -218,6 +222,7 @@ class HashTable { private: #if defined(PADDLE_WITH_CUDA) TableContainer* container_; + cudaStream_t stream_ = 0; #elif defined(PADDLE_WITH_XPU_KP) XPUCacheArray* container_; #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 9b8906daf5bf5..c599f2a77821e 100644 --- 
a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -227,14 +227,16 @@ __global__ void get_keys_kernel(Table* table, } template -HashTable::HashTable(size_t capacity) { - container_ = new TableContainer(capacity); +HashTable::HashTable(size_t capacity, cudaStream_t stream) { + stream_ = stream; + container_ = new TableContainer(capacity, stream); CUDA_RT_CALL( cudaMalloc((void**)&device_optimizer_config_, sizeof(OptimizerConfig))); - CUDA_RT_CALL(cudaMemcpy((void*)device_optimizer_config_, + CUDA_RT_CALL(cudaMemcpyAsync((void*)device_optimizer_config_, &host_optimizer_config_, sizeof(OptimizerConfig), - cudaMemcpyHostToDevice)); + cudaMemcpyHostToDevice, stream)); + cudaStreamSynchronize(stream); rwlock_.reset(new phi::RWLock); } @@ -248,20 +250,24 @@ template void HashTable::set_sparse_sgd( const OptimizerConfig& optimizer_config) { host_optimizer_config_.set_sparse_sgd(optimizer_config); - cudaMemcpy((void*)device_optimizer_config_, + cudaMemcpyAsync((void*)device_optimizer_config_, &host_optimizer_config_, sizeof(OptimizerConfig), - cudaMemcpyHostToDevice); + cudaMemcpyHostToDevice, + stream_); + cudaStreamSynchronize(stream_); } template void HashTable::set_embedx_sgd( const OptimizerConfig& optimizer_config) { host_optimizer_config_.set_embedx_sgd(optimizer_config); - cudaMemcpy((void*)device_optimizer_config_, + cudaMemcpyAsync((void*)device_optimizer_config_, &host_optimizer_config_, sizeof(OptimizerConfig), - cudaMemcpyHostToDevice); + cudaMemcpyHostToDevice, + stream_); + cudaStreamSynchronize(stream_); } template diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 3256895f23427..888ddfce396d5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -59,7 +59,7 @@ class HeterComm { HeterComm(size_t capacity, std::shared_ptr resource); HeterComm(size_t capacity, std::shared_ptr resource, - GPUAccessor& gpu_accessor); // NOLINT + GPUAccessor& gpu_accessor); // NOLINT virtual ~HeterComm(); HeterComm(const HeterComm&) = delete; HeterComm& operator=(const HeterComm&) = delete; @@ -82,8 +82,8 @@ class HeterComm { KeyType* d_keys, float* d_grads, size_t len, - int& uniq_len, // NOLINT - size_t& segment_len, // NOLINT + int& uniq_len, // NOLINT + size_t& segment_len, // NOLINT bool enable_segment_merge_grad); void segment_merge_grad(int gpu_num, KeyType* d_keys, @@ -92,7 +92,7 @@ class HeterComm { size_t len, const uint32_t* d_fea_num_info, size_t uniq_len, - size_t& segment_len); // NOLINT + size_t& segment_len); // NOLINT void build_ps(int num, KeyType* h_keys, ValType* h_vals, @@ -111,8 +111,11 @@ class HeterComm { GradType* d_grads, size_t len, int& uniq_len); // NOLINT - void dynamic_merge_grad( - int gpu_num, KeyType* d_keys, float* d_grads, size_t len, int& uniq_len); // NOLINT + void dynamic_merge_grad(int gpu_num, + KeyType* d_keys, + float* d_grads, + size_t len, + int& uniq_len); // NOLINT void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len); void build_ps(int num, KeyType* h_keys, @@ -149,6 +152,11 @@ class HeterComm { const void* src, size_t count, StreamType stream = 0); + template + void MemcpyPeerAsync(void *dst, + const void *src, + size_t count, + StreamType stream); #if defined(PADDLE_WITH_CUDA) template @@ -311,7 +319,7 @@ class HeterComm { } else if (need_mem > alloc->size()) { if (need_copy) { std::shared_ptr tmp = - memory::Alloc(place_, 
need_mem, stream_); + memory::Alloc(place_, need_mem); #if defined(PADDLE_WITH_CUDA) PADDLE_ENFORCE_GPU_SUCCESS( cudaMemcpyAsync(tmp->ptr(), // output @@ -319,6 +327,8 @@ class HeterComm { alloc->size(), cudaMemcpyDeviceToDevice, reinterpret_cast(stream_.id()))); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize( + reinterpret_cast(stream_.id()))); #else memory::Copy(place_, tmp->ptr(), @@ -470,11 +480,11 @@ class HeterComm { int end_index, size_t keylen, size_t vallen); - void create_tmp_storage(void*& dest, // NOLINT + void create_tmp_storage(void*& dest, // NOLINT int start_index, int end_index, size_t vallen); - void destroy_tmp_storage(void*& p, int start_index, int end_index); // NOLINT + void destroy_tmp_storage(void*& p, int start_index, int end_index); // NOLINT void destroy_storage(int start_index, int end_index); void walk_to_dest(int start_index, int gpu_num, @@ -537,7 +547,7 @@ class HeterComm { const cudaStream_t& stream); void gather_inner_keys_p2p(const size_t& total_fea_num, const KeyType* d_keys, - InnerResource& res, // NOLINT + InnerResource& res, // NOLINT const int& gpu_id, const int& gpu_num, const int& trans_id, @@ -578,7 +588,7 @@ class HeterComm { const cudaStream_t& stream); void scatter_inner_vals_p2p(const size_t& total_fea_num, void* d_out_vals, - InnerResource& res, // NOLINT + InnerResource& res, // NOLINT const int& gpu_id, const int& gpu_num, const int& trans_id, @@ -593,7 +603,7 @@ class HeterComm { void gather_inner_data_p2p(const size_t& total_fea_num, const KeyType* d_keys, const void* d_vals, - InnerResource& res, // NOLINT + InnerResource& res, // NOLINT const int& gpu_id, const int& gpu_num, const int& trans_id, @@ -655,18 +665,33 @@ class HeterComm { // debug time void print_debug_time(const int& gpu_id, bool force = false); // alloc temp memory - template + template T* AllocCache(std::shared_ptr* alloc, const TPlace& place, - const size_t& byte_len, - const StreamType& stream) { + const size_t& byte_len) { if (alloc->get() == nullptr || byte_len > (*alloc)->size()) { alloc->reset(); - auto id = phi::Stream(reinterpret_cast(stream)); - *alloc = memory::Alloc(place, byte_len, id); + if (resource_->multi_mf()) { + *alloc = memory::Alloc(place, byte_len); + } else { + auto stream = resource_->local_stream(place.GetDeviceId(), 0); + auto id = phi::Stream(reinterpret_cast(stream)); + *alloc = memory::Alloc(place, byte_len, id); + } } return reinterpret_cast((*alloc)->ptr()); } + template + std::shared_ptr MemoryAlloc(const TPlace& place, + const size_t& byte_len) { + if (resource_->multi_mf()) { + return memory::Alloc(place, byte_len); + } else { + auto stream = resource_->local_stream(place.GetDeviceId(), 0); + auto id = phi::Stream(reinterpret_cast(stream)); + return memory::Alloc(place, byte_len, id); + } + } using Table = HashTable; using PtrTable = HashTable; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 9c91cfe3e61c6..fa69db27c983e 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -85,6 +85,7 @@ HeterComm::HeterComm( max_type_size_ = std::max(pull_type_size_, grad_type_size_); for (int i = 0; i < device_num_; ++i) { + auto stream = resource_->local_stream(i, 0); #if defined(PADDLE_WITH_CUDA) platform::CUDADeviceGuard guard(resource_->dev_id(i)); allocators_.push_back(std::make_shared( @@ -92,19 +93,26 @@ HeterComm::HeterComm( #endif if (!multi_mf_dim_) { if (capacity > 0) { +#if 
defined(PADDLE_WITH_CUDA) + auto table = new Table(capacity / load_factor_, stream); +#else auto table = new Table(capacity / load_factor_); +#endif tables_.push_back(table); } } else { +#if defined(PADDLE_WITH_CUDA) + auto ptr_table = new PtrTable(capacity / load_factor_, stream); +#else auto ptr_table = new PtrTable(capacity / load_factor_); +#endif ptr_table->set_feature_value_size(pull_type_size_, grad_type_size_); ptr_tables_.push_back(ptr_table); } if (multi_node_) { storage_[i].init(device_num_, resource_->dev_id(i), - phi::Stream(reinterpret_cast( - resource_->comm_stream(i, 0)))); + phi::Stream(reinterpret_cast(stream))); } } barrier_.reset(device_num_); @@ -159,24 +167,32 @@ HeterComm::HeterComm( max_type_size_ = std::max(pull_type_size_, grad_type_size_); for (int i = 0; i < device_num_; ++i) { + auto stream = resource_->local_stream(i, 0); #if defined(PADDLE_WITH_CUDA) platform::CUDADeviceGuard guard(resource_->dev_id(i)); allocators_.push_back(std::make_shared( 8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT #endif if (!multi_mf_dim_) { +#if defined(PADDLE_WITH_CUDA) + auto table = new Table(capacity / load_factor_, stream); +#else auto table = new Table(capacity / load_factor_); +#endif tables_.push_back(table); } else { +#if defined(PADDLE_WITH_CUDA) + auto ptr_table = new PtrTable(capacity / load_factor_, stream); +#else auto ptr_table = new PtrTable(capacity / load_factor_); +#endif ptr_table->set_feature_value_size(pull_type_size_, grad_type_size_); ptr_tables_.push_back(ptr_table); } if (multi_node_) { storage_[i].init(device_num_, resource_->dev_id(i), - phi::Stream(reinterpret_cast( - resource_->comm_stream(i, 0)))); + phi::Stream(reinterpret_cast(stream))); } } barrier_.reset(device_num_); @@ -198,8 +214,8 @@ void HeterComm::init_path() { for (int j = 0; j < total_device; ++j) { auto &nodes = path_[i][j].nodes_; nodes.resize(1); - nodes[0].in_stream = resource_->remote_stream(i, j); - nodes[0].out_stream = resource_->remote_stream(i, j); + nodes[0].in_stream = resource_->remote_stream(i, j); // i->j + nodes[0].out_stream = resource_->remote_stream(j, i); // j->i nodes[0].key_storage = NULL; nodes[0].val_storage = NULL; nodes[0].sync = 0; @@ -229,8 +245,8 @@ void HeterComm::init_path() { } nodes.push_back(Node()); Node &node = nodes.back(); - node.in_stream = resource_->remote_stream(i, transfer_id); - node.out_stream = resource_->remote_stream(transfer_id, i); + node.in_stream = resource_->remote_stream(transfer_id, j); + node.out_stream = resource_->remote_stream(j, transfer_id); node.key_storage = NULL; node.val_storage = NULL; node.sync = 0; @@ -256,31 +272,40 @@ void HeterComm::reset_table( device_num_); #if defined(PADDLE_WITH_CUDA) platform::CUDADeviceGuard guard(resource_->dev_id(dev_id)); + auto stream = resource_->local_stream(dev_id, 0); #endif size_t need_capacity = capacity / load_factor_; if (!multi_mf_dim_) { auto table = tables_[dev_id]; if (static_cast(table->size()) < need_capacity) { delete table; +#if defined(PADDLE_WITH_CUDA) + table = new Table(need_capacity, stream); +#else table = new Table(need_capacity); +#endif table->set_sparse_sgd(sgd_config); table->set_embedx_sgd(sgd_config); tables_[dev_id] = table; } else { - table->clear(); + table->clear(stream); } table->set_mode(infer_mode); } else { auto table = ptr_tables_[dev_id]; if (static_cast(table->size()) < need_capacity) { delete table; +#if defined(PADDLE_WITH_CUDA) + table = new PtrTable(need_capacity, stream); +#else table = new PtrTable(need_capacity); +#endif 
table->set_feature_value_size(pull_type_size_, grad_type_size_); table->set_sparse_sgd(sgd_config); table->set_embedx_sgd(sgd_config); ptr_tables_[dev_id] = table; } else { - table->clear(); + table->clear(stream); } table->set_mode(infer_mode); } @@ -352,6 +377,44 @@ void HeterComm::memory_copy( #endif } +#if defined(PADDLE_WITH_CUDA) +inline int get_dev_by_ptr(const void *ptr) { + cudaPointerAttributes attr; + CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr)); + int dev = -1; +#if CUDART_VERSION >= 10000 + if (attr.type == cudaMemoryTypeDevice) +#else + if (attr.memoryType == cudaMemoryTypeDevice) +#endif + { + dev = attr.device; + } + return dev; +} +#endif +template +template +void HeterComm::MemcpyPeerAsync( + void *dst, const void *src, size_t count, StreamType stream) { +#if defined(PADDLE_WITH_CUDA) + int src_device = get_dev_by_ptr(src); + int dst_device = get_dev_by_ptr(dst); + if (dst_device == -1 || src_device == -1) { + CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, stream)); + } else if (dst_device == src_device) { + CUDA_CHECK( + cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream)); + } else { + CUDA_CHECK( + cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream)); + } +#endif +} + template ::create_storage( platform::XPUDeviceGuard guard(resource_->dev_id(nodes[i].dev_num)); auto place = DevPlace(resource_->dev_id(nodes[i].dev_num)); if (keylen > 0) { - auto node_keys_mem = memory::Alloc(place, keylen); + auto node_keys_mem = MemoryAlloc(place, keylen); nodes[i].key_storage = reinterpret_cast(node_keys_mem->ptr()); nodes[i].key_bytes_len = keylen; } if (vallen > 0) { - auto node_vals_mem = memory::Alloc(place, vallen); + auto node_vals_mem = MemoryAlloc(place, vallen); nodes[i].val_storage = reinterpret_cast(node_vals_mem->ptr()); nodes[i].val_bytes_len = vallen; } @@ -417,7 +480,7 @@ void HeterComm::create_tmp_storage( #elif defined(PADDLE_WITH_XPU_KP) platform::XPUDeviceGuard guard(resource_->dev_id(end_index)); auto place = DevPlace(resource_->dev_id(end_index)); - auto node_vals_mem = memory::Alloc(place, vallen); + auto node_vals_mem = MemoryAlloc(place, vallen); dest = reinterpret_cast(node_vals_mem->ptr()); #endif } @@ -467,80 +530,56 @@ void HeterComm::walk_to_dest( int *h_right, KeyType *src_key, GradType *src_val) { - int need_copy_val = 0; - if (src_val) { - need_copy_val = 1; - } - std::queue que; for (int i = 0; i < num; i++) { if (h_left[i] == -1 || h_right[i] == -1) { continue; } - // int size = path_[start_index][i].nodes_.size(); - auto &node = path_[start_index][i].nodes_[0]; - - CopyTask t(&path_[start_index][i], 0); - que.push(t); - auto src_dev_id = resource_->dev_id(start_index); - auto dst_dev_id = resource_->dev_id(i); - auto src_place = DevPlace(src_dev_id); - auto dst_place = DevPlace(dst_dev_id); - - memory_copy(dst_place, - node.key_storage, - src_place, - reinterpret_cast(src_key + h_left[i]), - node.key_bytes_len, - node.in_stream); - // #if defined(PADDLE_WITH_CUDA) // adapt for gpu-graph - // cudaMemsetAsync(node.val_storage, -1, node.val_bytes_len, - // node.in_stream); - // #endif - - if (need_copy_val) { - memory_copy(dst_place, - node.val_storage, - src_place, - reinterpret_cast(src_val + h_left[i]), - node.val_bytes_len, - node.in_stream); + auto &nodes = path_[start_index][i].nodes_; + auto &node = nodes[0]; + MemcpyPeerAsync(node.key_storage, + reinterpret_cast(src_key + h_left[i]), + node.key_bytes_len, + node.in_stream); + if (src_val) { + MemcpyPeerAsync(node.val_storage, + 
reinterpret_cast(src_val + h_left[i]), + node.val_bytes_len, + node.in_stream); } - } - while (!que.empty()) { - CopyTask &cur_task = que.front(); - que.pop(); - if (cur_task.path->nodes_[cur_task.step].sync) { - sync_stream(cur_task.path->nodes_[cur_task.step].in_stream); + // transfer + int step_num = static_cast(nodes.size()) - 1; + if (step_num == 0) { + continue; + } + if (node.sync) { + sync_stream(node.in_stream); } - if (static_cast(cur_task.step) != - cur_task.path->nodes_.size() - 1) { - int cur_step = cur_task.step; - CopyTask c(cur_task.path, cur_step + 1); - que.push(c); - - auto src_dev_id = - resource_->dev_id(cur_task.path->nodes_[cur_step].dev_num); - auto dst_dev_id = - resource_->dev_id(cur_task.path->nodes_[cur_step + 1].dev_num); - auto src_place = DevPlace(src_dev_id); - auto dst_place = DevPlace(dst_dev_id); - - memory_copy(dst_place, - cur_task.path->nodes_[cur_step + 1].key_storage, - src_place, - cur_task.path->nodes_[cur_step].key_storage, - cur_task.path->nodes_[cur_step + 1].key_bytes_len, - cur_task.path->nodes_[cur_step + 1].in_stream); - if (need_copy_val) { - memory_copy(dst_place, - cur_task.path->nodes_[cur_step + 1].val_storage, - src_place, - cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step + 1].val_bytes_len, - cur_task.path->nodes_[cur_step + 1].in_stream); + for (int cur_step = 0; cur_step < step_num; ++cur_step) { + auto &src_node = nodes[cur_step]; + auto &dst_node = nodes[cur_step + 1]; + MemcpyPeerAsync(dst_node.key_storage, + src_node.key_storage, + dst_node.key_bytes_len, + src_node.in_stream); + if (src_val) { + MemcpyPeerAsync(dst_node.val_storage, + src_node.val_storage, + dst_node.val_bytes_len, + src_node.in_stream); + } + if (src_node.sync) { + sync_stream(src_node.in_stream); } } } + // wait stream to finish + for (int i = 0; i < num; i++) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + auto &node = path_[start_index][i].nodes_.back(); + sync_stream(node.in_stream); + } } template ::walk_to_dest( KeyType *src_key, char *src_val, size_t val_size) { - int need_copy_val = 0; - if (src_val) { - need_copy_val = 1; - } - std::queue que; for (int i = 0; i < gpu_num; i++) { if (h_left[i] == -1 || h_right[i] == -1) { continue; } - int size = path_[start_index][i].nodes_.size(); - auto &node = path_[start_index][i].nodes_[0]; - CopyTask t(&path_[start_index][i], 0); - que.push(t); - CUDA_CHECK(cudaMemcpyAsync(node.key_storage, - reinterpret_cast(src_key + h_left[i]), - node.key_bytes_len, - cudaMemcpyDefault, - node.in_stream)); - if (need_copy_val) { - CUDA_CHECK( - cudaMemcpyAsync(node.val_storage, - src_val + uint64_t(h_left[i]) * uint64_t(val_size), - node.val_bytes_len, - cudaMemcpyDefault, - node.in_stream)); + auto &nodes = path_[start_index][i].nodes_; + auto &node = nodes[0]; + MemcpyPeerAsync(node.key_storage, + reinterpret_cast(src_key + h_left[i]), + node.key_bytes_len, + node.in_stream); + if (src_val) { + MemcpyPeerAsync(node.val_storage, + src_val + uint64_t(h_left[i]) * uint64_t(val_size), + node.val_bytes_len, + node.in_stream); } - } - while (!que.empty()) { - CopyTask &cur_task = que.front(); - que.pop(); - if (cur_task.path->nodes_[cur_task.step].sync) { - CUDA_CHECK(cudaStreamSynchronize( - cur_task.path->nodes_[cur_task.step].in_stream)); + int step_num = static_cast(nodes.size()) - 1; + if (step_num == 0) { + continue; + } + if (node.sync) { + sync_stream(node.in_stream); } - if (cur_task.step != cur_task.path->nodes_.size() - 1) { - int cur_step = cur_task.step; - CopyTask 
c(cur_task.path, cur_step + 1); - que.push(c); - CUDA_CHECK( - cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage, - cur_task.path->nodes_[cur_step].key_storage, - cur_task.path->nodes_[cur_step + 1].key_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step + 1].in_stream)); - if (need_copy_val) { - CUDA_CHECK( - cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage, - cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step + 1].val_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step + 1].in_stream)); + // transfer + for (int cur_step = 0; cur_step < step_num; ++cur_step) { + auto &src_node = nodes[cur_step]; + auto &dest_node = nodes[cur_step + 1]; + MemcpyPeerAsync(dest_node.key_storage, + src_node.key_storage, + src_node.key_bytes_len, + src_node.in_stream); + if (src_val) { + MemcpyPeerAsync(dest_node.val_storage, + src_node.val_storage, + src_node.val_bytes_len, + src_node.in_stream); + } + if (src_node.sync) { + sync_stream(src_node.in_stream); } } } + // wait stream to finish + for (int i = 0; i < gpu_num; i++) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + auto &node = path_[start_index][i].nodes_.back(); + sync_stream(node.in_stream); + } } template ::walk_to_src( int *h_right, char *src_val, size_t val_size) { - std::queue que; for (int i = 0; i < gpu_num; i++) { if (h_left[i] == -1 || h_right[i] == -1) { continue; } - int cur_step = path_[start_index][i].nodes_.size() - 1; - auto &node = path_[start_index][i].nodes_[cur_step]; - if (cur_step == 0) { - CUDA_CHECK(cudaMemcpyAsync(src_val + uint64_t(h_left[i]) * val_size, - node.val_storage, - node.val_bytes_len, - cudaMemcpyDefault, - node.out_stream)); - } else { - CopyTask t(&path_[start_index][i], cur_step - 1); - que.push(t); - CUDA_CHECK(cudaMemcpyAsync( - path_[start_index][i].nodes_[cur_step - 1].val_storage, - node.val_storage, - path_[start_index][i].nodes_[cur_step - 1].val_bytes_len, - cudaMemcpyDefault, - path_[start_index][i].nodes_[cur_step - 1].out_stream)); + auto &nodes = path_[start_index][i].nodes_; + int step_num = static_cast(nodes.size() - 1); + if (step_num > 0) { + // transfer + for (int cur_step = step_num; cur_step > 0; --cur_step) { + auto &src_node = nodes[cur_step]; + auto &dst_node = nodes[cur_step - 1]; + MemcpyPeerAsync(dst_node.val_storage, + src_node.val_storage, + dst_node.val_bytes_len, + src_node.out_stream); + if (src_node.sync) { + sync_stream(src_node.out_stream); + } + } } + auto &node = nodes[0]; + MemcpyPeerAsync(src_val + uint64_t(h_left[i]) * val_size, + node.val_storage, + node.val_bytes_len, + node.out_stream); } - while (!que.empty()) { - CopyTask &cur_task = que.front(); - que.pop(); - int cur_step = cur_task.step; - if (cur_task.path->nodes_[cur_step].sync) { - cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream); - } - if (cur_step > 0) { - CopyTask c(cur_task.path, cur_step - 1); - que.push(c); - CUDA_CHECK( - cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage, - cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step - 1].val_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step - 1].out_stream)); - } else if (cur_step == 0) { - int end_index = cur_task.path->nodes_.back().dev_num; - CUDA_CHECK( - cudaMemcpyAsync(src_val + uint64_t(h_left[end_index]) * val_size, - cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step].val_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step].out_stream)); + // wait 
stream to finish + for (int i = 0; i < gpu_num; i++) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; } + auto &node = path_[start_index][i].nodes_.front(); + sync_stream(node.out_stream); } } @@ -800,18 +819,28 @@ void HeterComm::build_ps( } int dev_id = resource_->dev_id(dev_num); - std::vector d_key_bufs; - std::vector d_val_bufs; + std::vector> d_key_bufs; + std::vector> d_val_bufs; + + // auto adjust stream num by data length + int max_stream = (len + chunk_size - 1) / chunk_size; + if (max_stream < stream_num) { + stream_num = max_stream; + } + if (stream_num > device_num_) { + stream_num = device_num_; + } DevPlace place = DevPlace(dev_id); AnyDeviceGuard guard(dev_id); ppStream streams[stream_num]; // NOLINT + + d_key_bufs.resize(stream_num); + d_val_bufs.resize(stream_num); for (int i = 0; i < stream_num; ++i) { - create_stream(&(streams[i])); - auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType)); - auto d_v_buf = memory::Alloc(place, chunk_size * sizeof(ValType)); - d_key_bufs.push_back(std::move(d_k_buf)); - d_val_bufs.push_back(std::move(d_v_buf)); + streams[i] = resource_->local_stream(dev_num, i); + d_key_bufs[i] = MemoryAlloc(place, chunk_size * sizeof(KeyType)); + d_val_bufs[i] = MemoryAlloc(place, chunk_size * sizeof(ValType)); } int cur_len = 0; @@ -853,7 +882,6 @@ void HeterComm::build_ps( } for (int i = 0; i < stream_num; ++i) { sync_stream(streams[i]); - destroy_stream(streams[i]); } } @@ -874,17 +902,26 @@ void HeterComm::build_ps( } int dev_id = resource_->dev_id(num); + // auto adjust stream num by data length + int max_stream = (len + chunk_size - 1) / chunk_size; + if (max_stream < stream_num) { + stream_num = max_stream; + } + if (stream_num > device_num_) { + stream_num = device_num_; + } + DevPlace place = DevPlace(dev_id); AnyDeviceGuard guard(dev_id); // use hbm pool - std::vector d_key_bufs; + std::vector> d_key_bufs; ppStream streams[stream_num]; // NOLINT + d_key_bufs.resize(stream_num); for (int i = 0; i < stream_num; ++i) { - create_stream(&(streams[i])); - auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType)); - d_key_bufs.push_back(std::move(d_k_buf)); + streams[i] = resource_->local_stream(num, i); + d_key_bufs[i] = MemoryAlloc(place, chunk_size * sizeof(KeyType)); } int cur_len = 0; @@ -919,7 +956,6 @@ void HeterComm::build_ps( } for (int i = 0; i < stream_num; ++i) { sync_stream(streams[i]); - destroy_stream(streams[i]); } } @@ -938,9 +974,9 @@ void HeterComm::merge_grad( AnyDeviceGuard guard(dev_id); auto stream = resource_->local_stream(dev_num, 0); size_t temp_storage_bytes; - auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto d_merge_keys = MemoryAlloc(place, len * sizeof(KeyType)); KeyType *d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); - auto d_merge_grads = memory::Alloc(place, len * sizeof(GradType)); + auto d_merge_grads = MemoryAlloc(place, len * sizeof(GradType)); GradType *d_merge_grads_ptr = reinterpret_cast(d_merge_grads->ptr()); heter_comm_kernel_->sort_pairs(NULL, @@ -954,7 +990,7 @@ void HeterComm::merge_grad( 8 * sizeof(KeyType), stream, false); - auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); + auto d_temp_storage = MemoryAlloc(place, temp_storage_bytes); heter_comm_kernel_->sort_pairs(d_temp_storage->ptr(), temp_storage_bytes, d_keys, @@ -967,7 +1003,7 @@ void HeterComm::merge_grad( stream, false); temp_storage_bytes = 0; - auto d_num_runs_out_mem = memory::Alloc(place, sizeof(int)); + auto d_num_runs_out_mem = MemoryAlloc(place, sizeof(int)); 
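// --- Illustrative sketch, not part of the patch above ----------------------
// merge_grad()/dynamic_merge_grad() above follow the standard CUB two-phase
// workspace pattern: a first call with a null workspace only reports
// temp_storage_bytes, the workspace is then allocated (the patch routes this
// through MemoryAlloc so it is stream-aware), and a second call does the real
// work on the caller's stream. A minimal, self-contained version using plain
// cudaMalloc; the function and buffer names are hypothetical.
#include <cub/cub.cuh>
#include <cuda_runtime.h>

void sort_pairs_two_phase(const unsigned long long* d_keys_in,
                          unsigned long long* d_keys_out,
                          const unsigned int* d_vals_in,
                          unsigned int* d_vals_out,
                          int num_items,
                          cudaStream_t stream) {
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  // Phase 1: query the workspace size (no sorting happens here).
  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                  d_keys_in, d_keys_out,
                                  d_vals_in, d_vals_out,
                                  num_items, 0, 64, stream);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  // Phase 2: sort for real, asynchronously on `stream`.
  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                  d_keys_in, d_keys_out,
                                  d_vals_in, d_vals_out,
                                  num_items, 0, 64, stream);
  // Make sure the sort finished before releasing the workspace.
  cudaStreamSynchronize(stream);
  cudaFree(d_temp_storage);
}
// ----------------------------------------------------------------------------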
int *d_num_runs_out = reinterpret_cast(d_num_runs_out_mem->ptr()); heter_comm_kernel_->reduce_by_key(NULL, temp_storage_bytes, @@ -981,7 +1017,7 @@ void HeterComm::merge_grad( false); if (d_temp_storage->size() < temp_storage_bytes) { d_temp_storage = NULL; - d_temp_storage = memory::Alloc(place, temp_storage_bytes); + d_temp_storage = MemoryAlloc(place, temp_storage_bytes); } heter_comm_kernel_->reduce_by_key(d_temp_storage->ptr(), temp_storage_bytes, @@ -1023,9 +1059,9 @@ void HeterComm::dynamic_merge_grad( GlobalAccessorFactory::GetInstance().GetAccessorWrapper(); size_t grad_value_size = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_); - auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto d_merge_keys = MemoryAlloc(place, len * sizeof(KeyType)); KeyType *d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); - auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1)); + auto d_fea_num_info = MemoryAlloc(place, sizeof(uint32_t) * (len * 3 + 1)); uint32_t *d_fea_num_info_ptr = reinterpret_cast(d_fea_num_info->ptr()); uint32_t *d_index = &d_fea_num_info_ptr[len]; @@ -1044,7 +1080,7 @@ void HeterComm::dynamic_merge_grad( 0, 8 * sizeof(KeyType), stream)); - auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); + auto d_temp_storage = MemoryAlloc(place, temp_storage_bytes); PADDLE_ENFORCE_GPU_SUCCESS( cub::DeviceRadixSort::SortPairs(d_temp_storage->ptr(), temp_storage_bytes, @@ -1070,7 +1106,7 @@ void HeterComm::dynamic_merge_grad( stream)); if (d_temp_storage->size() < temp_storage_bytes) { d_temp_storage = NULL; - d_temp_storage = memory::Alloc(place, temp_storage_bytes); + d_temp_storage = MemoryAlloc(place, temp_storage_bytes); } PADDLE_ENFORCE_GPU_SUCCESS( cub::DeviceRunLengthEncode::Encode(d_temp_storage->ptr(), @@ -1100,7 +1136,7 @@ void HeterComm::dynamic_merge_grad( stream)); if (d_temp_storage->size() < temp_storage_bytes) { d_temp_storage = NULL; - d_temp_storage = memory::Alloc(place, temp_storage_bytes); + d_temp_storage = MemoryAlloc(place, temp_storage_bytes); } PADDLE_ENFORCE_GPU_SUCCESS( cub::DeviceScan::ExclusiveSum(d_temp_storage->ptr(), @@ -1127,7 +1163,7 @@ void HeterComm::dynamic_merge_grad( stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); } else { - auto d_merge_grads = memory::Alloc(place, len * grad_value_size); + auto d_merge_grads = MemoryAlloc(place, len * grad_value_size); float *d_merge_grads_ptr = reinterpret_cast(d_merge_grads->ptr()); // copy merge keys to d_keys PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_keys, @@ -1184,16 +1220,16 @@ void HeterComm::segment_merge_grad( GlobalAccessorFactory::GetInstance().GetAccessorWrapper(); size_t grad_value_size = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_); - auto d_buffer1 = memory::Alloc(place, sizeof(uint32_t) * len); + auto d_buffer1 = MemoryAlloc(place, sizeof(uint32_t) * len); auto d_segments = reinterpret_cast(d_buffer1->ptr()); - auto d_buffer2 = memory::Alloc(place, sizeof(uint32_t) * len); + auto d_buffer2 = MemoryAlloc(place, sizeof(uint32_t) * len); auto d_segments_offset = reinterpret_cast(d_buffer2->ptr()); - auto d_buffer3 = memory::Alloc(place, sizeof(uint32_t) * len); + auto d_buffer3 = MemoryAlloc(place, sizeof(uint32_t) * len); auto d_segments_fea_num_info = reinterpret_cast(d_buffer3->ptr()); - auto d_buffer4 = memory::Alloc(place, sizeof(uint32_t) * len); + auto d_buffer4 = MemoryAlloc(place, sizeof(uint32_t) * len); auto d_segments_fea_num_offset = reinterpret_cast(d_buffer4->ptr()); - auto d_buffer5 = 
memory::Alloc(place, sizeof(uint32_t)); + auto d_buffer5 = MemoryAlloc(place, sizeof(uint32_t)); auto d_segments_num = reinterpret_cast(d_buffer5->ptr()); CUDA_CHECK(cudaMemsetAsync(d_segments_num, 0, sizeof(uint32_t), stream)); @@ -1209,7 +1245,7 @@ void HeterComm::segment_merge_grad( size_t temp_storage_bytes = 0; PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Sum( NULL, temp_storage_bytes, d_segments, d_segments_num, uniq_len, stream)); - auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); + auto d_temp_storage = MemoryAlloc(place, temp_storage_bytes); PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Sum(d_temp_storage->ptr(), temp_storage_bytes, d_segments, @@ -1232,7 +1268,7 @@ void HeterComm::segment_merge_grad( stream)); if (d_temp_storage->size() < temp_storage_bytes) { d_temp_storage = NULL; - d_temp_storage = memory::Alloc(place, temp_storage_bytes); + d_temp_storage = MemoryAlloc(place, temp_storage_bytes); } PADDLE_ENFORCE_GPU_SUCCESS( cub::DeviceScan::ExclusiveSum(d_temp_storage->ptr(), @@ -1260,7 +1296,7 @@ void HeterComm::segment_merge_grad( stream)); if (d_temp_storage->size() < temp_storage_bytes) { d_temp_storage = NULL; - d_temp_storage = memory::Alloc(place, temp_storage_bytes); + d_temp_storage = MemoryAlloc(place, temp_storage_bytes); } PADDLE_ENFORCE_GPU_SUCCESS( cub::DeviceScan::ExclusiveSum(d_temp_storage->ptr(), @@ -1271,7 +1307,7 @@ void HeterComm::segment_merge_grad( stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); - auto d_segments_keys = memory::Alloc(place, sizeof(KeyType) * segments_num); + auto d_segments_keys = MemoryAlloc(place, sizeof(KeyType) * segments_num); auto d_segments_keys_ptr = reinterpret_cast(d_segments_keys->ptr()); heter_comm_kernel_->shrink_keys(d_keys, @@ -1281,7 +1317,7 @@ void HeterComm::segment_merge_grad( stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); - auto d_segment_grads = memory::Alloc(place, segments_num * grad_value_size); + auto d_segment_grads = MemoryAlloc(place, segments_num * grad_value_size); auto d_segment_grads_ptr = reinterpret_cast(d_segment_grads->ptr()); heter_comm_kernel_->merge_gradient( d_segments_keys_ptr, @@ -1344,8 +1380,7 @@ void HeterComm::split_idx_to_shard( AnyDeviceGuard guard(dev_id); thread_local std::shared_ptr d_idx_tmp = nullptr; - T *d_idx_tmp_ptr = - AllocCache(&d_idx_tmp, place, 3 * len * sizeof(T), stream); + T *d_idx_tmp_ptr = AllocCache(&d_idx_tmp, place, 3 * len * sizeof(T)); T *d_shard_index_ptr = reinterpret_cast(&d_idx_tmp_ptr[len]); T *d_shard_index_tmp_ptr = reinterpret_cast(&d_shard_index_ptr[len]); @@ -1367,8 +1402,7 @@ void HeterComm::split_idx_to_shard( stream); thread_local std::shared_ptr d_temp_storage = nullptr; - void *d_buf = - AllocCache(&d_temp_storage, place, temp_storage_bytes, stream); + void *d_buf = AllocCache(&d_temp_storage, place, temp_storage_bytes); heter_comm_kernel_->sort_pairs(d_buf, temp_storage_bytes, d_shard_index_tmp_ptr, @@ -1403,7 +1437,7 @@ size_t HeterComm::merge_keys( thread_local std::shared_ptr d_fea_num_info = nullptr; uint32_t *d_offset = AllocCache( - &d_fea_num_info, place, sizeof(uint32_t) * (len * 3), stream); + &d_fea_num_info, place, sizeof(uint32_t) * (len * 3)); uint32_t *d_merged_cnts = &d_offset[len]; uint32_t *d_sorted_idx = &d_merged_cnts[len]; @@ -1438,8 +1472,8 @@ void HeterComm::pull_merge_sparse( int h_left[total_device]; // NOLINT int h_right[total_device]; // NOLINT - auto d_left = memory::Alloc(place, total_device * sizeof(int)); - auto d_right = memory::Alloc(place, total_device * 
sizeof(int)); + auto d_left = MemoryAlloc(place, total_device * sizeof(int)); + auto d_right = MemoryAlloc(place, total_device * sizeof(int)); int *d_left_ptr = reinterpret_cast(d_left->ptr()); int *d_right_ptr = reinterpret_cast(d_right->ptr()); @@ -1472,15 +1506,15 @@ void HeterComm::pull_merge_sparse( GlobalAccessorFactory::GetInstance().GetAccessorWrapper(); size_t val_type_size = accessor_wrapper_ptr->GetPullValueSize(max_mf_dim_); VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; - auto d_sorted_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto d_sorted_keys = MemoryAlloc(place, len * sizeof(KeyType)); auto d_sorted_keys_ptr = reinterpret_cast(d_sorted_keys->ptr()); - auto d_merged_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto d_merged_keys = MemoryAlloc(place, len * sizeof(KeyType)); auto d_merged_keys_ptr = reinterpret_cast(d_merged_keys->ptr()); - auto d_restore_idx = memory::Alloc(place, len * sizeof(uint32_t)); + auto d_restore_idx = MemoryAlloc(place, len * sizeof(uint32_t)); auto d_restore_idx_ptr = reinterpret_cast(d_restore_idx->ptr()); - auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto d_shard_keys = MemoryAlloc(place, len * sizeof(KeyType)); auto d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); - auto d_shard_vals = memory::Alloc(place, len * val_type_size); + auto d_shard_vals = MemoryAlloc(place, len * val_type_size); auto d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); size_t uniq_len = merge_keys(num, @@ -1492,7 +1526,7 @@ void HeterComm::pull_merge_sparse( stream); sync_stream(stream); - auto d_idx = memory::Alloc(place, uniq_len * sizeof(int)); + auto d_idx = MemoryAlloc(place, uniq_len * sizeof(int)); auto d_idx_ptr = reinterpret_cast(d_idx->ptr()); split_idx_to_shard(d_merged_keys_ptr, d_idx_ptr, @@ -1503,7 +1537,6 @@ void HeterComm::pull_merge_sparse( stream); heter_comm_kernel_->fill_shard_key( d_shard_keys_ptr, d_merged_keys_ptr, d_idx_ptr, uniq_len, stream); - sync_stream(stream); auto dst_place = platform::CPUPlace(); auto src_place = place; @@ -1520,6 +1553,7 @@ void HeterComm::pull_merge_sparse( d_right_ptr, total_device * sizeof(int), stream); + sync_stream(stream); if (!enable_gpu_direct_access_) { for (int i = 0; i < total_device; ++i) { @@ -1538,9 +1572,6 @@ void HeterComm::pull_merge_sparse( continue; } auto &node = path_[num][i].nodes_.back(); - if (!enable_gpu_direct_access_) { - sync_stream(node.in_stream); - } AnyDeviceGuard guard(resource_->dev_id(i)); ptr_tables_[i]->rwlock_->RDLock(); if (!enable_gpu_direct_access_) { @@ -1574,13 +1605,9 @@ void HeterComm::pull_merge_sparse( h_right, reinterpret_cast(d_shard_vals_ptr), val_type_size); - for (int i = 0; i < total_device; ++i) { - auto &node = path_[num][i].nodes_.front(); - sync_stream(node.out_stream); - } } - auto d_merged_vals = memory::Alloc(place, uniq_len * val_type_size); + auto d_merged_vals = MemoryAlloc(place, uniq_len * val_type_size); auto d_merged_vals_ptr = reinterpret_cast(d_merged_vals->ptr()); heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_merged_vals_ptr, @@ -1623,8 +1650,8 @@ void HeterComm::pull_normal_sparse( int h_left[total_device]; // NOLINT int h_right[total_device]; // NOLINT - auto d_left = memory::Alloc(place, total_device * sizeof(int)); - auto d_right = memory::Alloc(place, total_device * sizeof(int)); + auto d_left = MemoryAlloc(place, total_device * sizeof(int)); + auto d_right = MemoryAlloc(place, total_device * sizeof(int)); int *d_left_ptr = 
reinterpret_cast(d_left->ptr()); int *d_right_ptr = reinterpret_cast(d_right->ptr()); @@ -1653,16 +1680,16 @@ void HeterComm::pull_normal_sparse( XPUAPIErrorMsg[r2])); #endif - auto d_idx = memory::Alloc(place, len * sizeof(int)); + auto d_idx = MemoryAlloc(place, len * sizeof(int)); int *d_idx_ptr = reinterpret_cast(d_idx->ptr()); auto accessor_wrapper_ptr = GlobalAccessorFactory::GetInstance().GetAccessorWrapper(); size_t val_type_size = accessor_wrapper_ptr->GetPullValueSize(max_mf_dim_); VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; - auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto d_shard_keys = MemoryAlloc(place, len * sizeof(KeyType)); KeyType *d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); - auto d_shard_vals = memory::Alloc(place, len * val_type_size); + auto d_shard_vals = MemoryAlloc(place, len * val_type_size); float *d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); split_idx_to_shard( @@ -1671,8 +1698,6 @@ void HeterComm::pull_normal_sparse( heter_comm_kernel_->fill_shard_key( d_shard_keys_ptr, d_keys, d_idx_ptr, len, stream); - sync_stream(stream); - auto dst_place = platform::CPUPlace(); auto src_place = place; @@ -1688,6 +1713,7 @@ void HeterComm::pull_normal_sparse( d_right_ptr, total_device * sizeof(int), stream); + sync_stream(stream); if (!enable_gpu_direct_access_) { for (int i = 0; i < total_device; ++i) { @@ -1705,9 +1731,6 @@ void HeterComm::pull_normal_sparse( continue; } auto &node = path_[num][i].nodes_.back(); - if (!enable_gpu_direct_access_) { - sync_stream(node.in_stream); - } AnyDeviceGuard guard(resource_->dev_id(i)); ptr_tables_[i]->rwlock_->RDLock(); if (!enable_gpu_direct_access_) { @@ -1740,10 +1763,6 @@ void HeterComm::pull_normal_sparse( h_right, reinterpret_cast(d_shard_vals_ptr), val_type_size); - for (int i = 0; i < total_device; ++i) { - auto &node = path_[num][i].nodes_.front(); - sync_stream(node.out_stream); - } } heter_comm_kernel_->dy_mf_fill_dvals( d_shard_vals_ptr, d_vals, d_idx_ptr, len, val_type_size, stream); @@ -1826,8 +1845,8 @@ void HeterComm::push_normal_sparse( int h_left[total_device]; // NOLINT int h_right[total_device]; // NOLINT - auto d_left = memory::Alloc(place, total_device * sizeof(int)); - auto d_right = memory::Alloc(place, total_device * sizeof(int)); + auto d_left = MemoryAlloc(place, total_device * sizeof(int)); + auto d_right = MemoryAlloc(place, total_device * sizeof(int)); int *d_left_ptr = reinterpret_cast(d_left->ptr()); int *d_right_ptr = reinterpret_cast(d_right->ptr()); @@ -1856,14 +1875,14 @@ void HeterComm::push_normal_sparse( XPUAPIErrorMsg[r2])); #endif - auto d_idx = memory::Alloc(place, len * sizeof(int)); + auto d_idx = MemoryAlloc(place, len * sizeof(int)); int *d_idx_ptr = reinterpret_cast(d_idx->ptr()); - auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto d_shard_keys = MemoryAlloc(place, len * sizeof(KeyType)); KeyType *d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); float *d_shard_grads_ptr; - auto d_shard_grads = memory::Alloc(place, len * grad_value_size); + auto d_shard_grads = MemoryAlloc(place, len * grad_value_size); d_shard_grads_ptr = reinterpret_cast(d_shard_grads->ptr()); int uniq_len = len; @@ -1900,8 +1919,6 @@ void HeterComm::push_normal_sparse( stream, gpu_accessor_); - sync_stream(stream); - auto dst_place = platform::CPUPlace(); auto src_place = place; memory_copy(dst_place, @@ -1916,6 +1933,7 @@ void HeterComm::push_normal_sparse( d_right_ptr, total_device * sizeof(int), 
stream); + sync_stream(stream); if (!enable_gpu_direct_access_) { for (int i = 0; i < total_device; ++i) { @@ -1941,9 +1959,9 @@ void HeterComm::push_normal_sparse( continue; } auto &node = path_[dev_num][i].nodes_.back(); - if (!enable_gpu_direct_access_) { - sync_stream(node.in_stream); - } + // if (!enable_gpu_direct_access_) { + // sync_stream(node.in_stream); + // } AnyDeviceGuard guard(resource_->dev_id(i)); ptr_tables_[i]->rwlock_->WRLock(); @@ -2001,8 +2019,8 @@ void HeterComm::push_sparse( int h_left[total_device]; // NOLINT int h_right[total_device]; // NOLINT - auto d_left = memory::Alloc(place, total_device * sizeof(int)); - auto d_right = memory::Alloc(place, total_device * sizeof(int)); + auto d_left = MemoryAlloc(place, total_device * sizeof(int)); + auto d_right = MemoryAlloc(place, total_device * sizeof(int)); int *d_left_ptr = reinterpret_cast(d_left->ptr()); int *d_right_ptr = reinterpret_cast(d_right->ptr()); @@ -2031,12 +2049,12 @@ void HeterComm::push_sparse( XPUAPIErrorMsg[r2])); #endif - auto d_idx = memory::Alloc(place, len * sizeof(int)); + auto d_idx = MemoryAlloc(place, len * sizeof(int)); int *d_idx_ptr = reinterpret_cast(d_idx->ptr()); - auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto d_shard_keys = MemoryAlloc(place, len * sizeof(KeyType)); KeyType *d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); - auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType)); + auto d_shard_grads = MemoryAlloc(place, len * sizeof(GradType)); GradType *d_shard_grads_ptr = reinterpret_cast(d_shard_grads->ptr()); @@ -2054,8 +2072,6 @@ void HeterComm::push_sparse( uniq_len, stream); - sync_stream(stream); - auto dst_place = platform::CPUPlace(); auto src_place = place; memory_copy(dst_place, @@ -2070,6 +2086,7 @@ void HeterComm::push_sparse( d_right_ptr, total_device * sizeof(int), stream); + sync_stream(stream); for (int i = 0; i < total_device; ++i) { int shard_len = h_right[i] - h_left[i] + 1; @@ -2092,8 +2109,6 @@ void HeterComm::push_sparse( continue; } auto &node = path_[dev_num][i].nodes_.back(); - sync_stream(node.in_stream); - AnyDeviceGuard guard(resource_->dev_id(i)); tables_[i]->rwlock_->WRLock(); tables_[i]->update(reinterpret_cast(node.key_storage), @@ -2205,7 +2220,7 @@ int HeterComm::gather_one_node_grad( ncclComm_t nccl_inner_comm = nccl_inner_comms_[gpu_num]; // alloc for size int h_node_len[total_gpu]; // NOLINT - auto d_node_len_mem = memory::Alloc(place, total_gpu * sizeof(int)); + auto d_node_len_mem = MemoryAlloc(place, total_gpu * sizeof(int)); int *d_node_len = reinterpret_cast(d_node_len_mem->ptr()); h_node_len[gpu_num] = len; @@ -2252,15 +2267,15 @@ int HeterComm::gather_one_node_grad( int h_left[total_gpu]; // NOLINT int h_right[total_gpu]; // NOLINT - auto d_left = memory::Alloc(place, total_gpu * sizeof(int)); - auto d_right = memory::Alloc(place, total_gpu * sizeof(int)); + auto d_left = MemoryAlloc(place, total_gpu * sizeof(int)); + auto d_right = MemoryAlloc(place, total_gpu * sizeof(int)); int *d_left_ptr = reinterpret_cast(d_left->ptr()); int *d_right_ptr = reinterpret_cast(d_right->ptr()); int merge_num = 0; for (int i = 0; i < total_gpu; ++i) { int index = i * max_size; - auto d_idx = memory::Alloc(place, h_node_len[i] * sizeof(int)); + auto d_idx = MemoryAlloc(place, h_node_len[i] * sizeof(int)); int *d_idx_ptr = reinterpret_cast(d_idx->ptr()); cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int)); @@ -2313,7 +2328,7 @@ int HeterComm::gather_multi_node_grad( ncclComm_t nccl_inter_comm = 
nccl_inter_comms_[gpu_num]; // alloc for size int h_node_len[node_size_]; // NOLINT - auto d_node_len_mem = memory::Alloc(place, node_size_ * sizeof(int)); + auto d_node_len_mem = MemoryAlloc(place, node_size_ * sizeof(int)); int *d_node_len = reinterpret_cast(d_node_len_mem->ptr()); h_node_len[0] = len; @@ -2428,8 +2443,7 @@ int HeterComm::dedup_keys_and_fillidx( size_t byte_size = sizeof(uint32_t) * (total_fea_num + 1); thread_local std::shared_ptr d_index_ptr = nullptr; - uint32_t *d_index_in = - AllocCache(&d_index_ptr, place, byte_size, stream); + uint32_t *d_index_in = AllocCache(&d_index_ptr, place, byte_size); int *d_merged_size = reinterpret_cast(&d_index_in[total_fea_num]); heter_comm_kernel_->fill_idx(d_index_in, total_fea_num, stream); @@ -2449,7 +2463,7 @@ int HeterComm::dedup_keys_and_fillidx( stream, false)); thread_local std::shared_ptr d_cache_ptr = nullptr; - d_buf = AllocCache(&d_cache_ptr, place, temp_storage_bytes, stream); + d_buf = AllocCache(&d_cache_ptr, place, temp_storage_bytes); PADDLE_ENFORCE_GPU_SUCCESS( cub::DeviceRadixSort::SortPairs(d_buf, temp_storage_bytes, @@ -2472,7 +2486,7 @@ int HeterComm::dedup_keys_and_fillidx( d_merged_size, total_fea_num, stream)); - d_buf = AllocCache(&d_cache_ptr, place, temp_storage_bytes, stream); + d_buf = AllocCache(&d_cache_ptr, place, temp_storage_bytes); PADDLE_ENFORCE_GPU_SUCCESS( cub::DeviceRunLengthEncode::Encode(d_buf, temp_storage_bytes, @@ -2492,7 +2506,7 @@ int HeterComm::dedup_keys_and_fillidx( PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( NULL, temp_storage_bytes, d_merged_cnts, d_offset, merged_size, stream)); - d_buf = AllocCache(&d_cache_ptr, place, temp_storage_bytes, stream); + d_buf = AllocCache(&d_cache_ptr, place, temp_storage_bytes); PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( d_buf, temp_storage_bytes, d_merged_cnts, d_offset, merged_size, stream)); @@ -2556,7 +2570,7 @@ void HeterComm::pull_sparse_all2all( AnyDeviceGuard guard(gpu_id); auto &loc = storage_[gpu_id]; // get from local table - auto stream = resource_->comm_stream(gpu_id, 0); + auto stream = resource_->local_stream(gpu_id, 0); size_t gather_inner_size = 0; size_t pull_size = 0; @@ -2927,7 +2941,7 @@ void HeterComm::partition_shard_keys( thread_local std::shared_ptr d_offset_tmp = nullptr; uint32_t *d_left = AllocCache( - &d_offset_tmp, place, (len * 3 + shard_num * 2) * sizeof(int), stream); + &d_offset_tmp, place, (len * 3 + shard_num * 2) * sizeof(int)); uint32_t *d_right = &d_left[shard_num]; // init cudaMemsetAsync(d_left, -1, shard_num * 2 * sizeof(int), stream); @@ -2954,8 +2968,7 @@ void HeterComm::partition_shard_keys( stream); thread_local std::shared_ptr d_temp_storage = nullptr; - void *d_buf = - AllocCache(&d_temp_storage, place, temp_storage_bytes, stream); + void *d_buf = AllocCache(&d_temp_storage, place, temp_storage_bytes); heter_comm_kernel_->sort_pairs(d_buf, temp_storage_bytes, d_shard_index_tmp_ptr, @@ -3017,8 +3030,10 @@ size_t HeterComm::send_data_by_all2all( send_size * value_bytes, cudaMemcpyDeviceToDevice, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); CHECK_EQ(send_size, h_recv_part_sizes[nccl_rank_id]); + auto nccl_stream = resource_->comm_stream(gpu_id, 0); size_t total_fea_num = 0; PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); for (int i = 0; i < nccl_node_size; i++) { @@ -3034,7 +3049,7 @@ size_t HeterComm::send_data_by_all2all( ncclInt8, i, comm, - stream)); + nccl_stream)); total_fea_num += send_size; } const size_t &recv_size = 
h_recv_part_sizes[i]; @@ -3046,11 +3061,12 @@ size_t HeterComm::send_data_by_all2all( ncclInt8, i, comm, - stream)); + nccl_stream)); total_fea_num += recv_size; } } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(nccl_stream)); return total_fea_num; } @@ -3080,7 +3096,7 @@ size_t HeterComm:: node_size_, stream); // barrier - barrier_.wait(); + // barrier_.wait(); int all_shard_part_size = node_size_ * node_size_; int rank_offset = rank_id_ * node_size_; @@ -3097,22 +3113,27 @@ size_t HeterComm:: node_size_ * sizeof(int), cudaMemcpyHostToDevice, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + cache.node_barrier_.Resume(); auto &comm = nccl_inter_comms_[gpu_id]; + auto nccl_stream = resource_->comm_stream(gpu_id, 0); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( &res.d_node_size_ptr[rank_offset], reinterpret_cast(res.d_node_size_ptr), node_size_, ncclInt, comm, - stream)); + nccl_stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(nccl_stream)); + cache.node_barrier_.Pause(); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&h_push_fea_sizes[0], res.d_node_size_ptr, all_shard_part_size * sizeof(int), cudaMemcpyDeviceToHost, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); - cache.node_barrier_.Pause(); size_t *h_remote_part_sizes = res.h_remote_part_sizes.data(); size_t *h_remote_part_offsets = res.h_remote_part_offsets.data(); @@ -3126,7 +3147,7 @@ size_t HeterComm:: size_t &remote_size = h_remote_part_offsets[node_size_]; cache.alloc(remote_size, max_type_size_, HeterCommType::COPY_KEY); // barrier - barrier_.wait(); + // barrier_.wait(); size_t total_fea_num = 0; if (rdma_checker_->need_rdma_trans()) { @@ -3296,9 +3317,9 @@ void HeterComm:: const cudaStream_t &stream) { auto &my_cache = storage_[gpu_id]; // restore vals - heter_comm_kernel_->scatter_vals( - reinterpret_cast(d_in_vals), // in + heter_comm_kernel_->gather_vals( reinterpret_cast(my_cache.d_merged_push_vals), // out + reinterpret_cast(d_in_vals), // in my_cache.pull_res.d_restore_keys_idx, my_cache.pull_res.h_recv_fea_num, value_bytes, @@ -3458,7 +3479,7 @@ void HeterComm::push_sparse_all2all( } auto &my_cache = storage_[gpu_id]; my_cache.all2all_span_.Resume(); - auto stream = resource_->comm_stream(gpu_id, 0); + auto stream = resource_->local_stream(gpu_id, 0); // tracker if (FLAGS_enable_tracker_all2all) { // check push grads @@ -3648,15 +3669,15 @@ size_t HeterComm::merge_grad( platform::CUDADeviceGuard guard(gpu_id); auto place = platform::CUDAPlace(gpu_id); thread_local std::shared_ptr d_fea_num_info = nullptr; - uint32_t *d_offset = AllocCache( - &d_fea_num_info, place, sizeof(uint32_t) * len * 4, stream); + uint32_t *d_offset = + AllocCache(&d_fea_num_info, place, sizeof(uint32_t) * len * 4); uint32_t *d_sorted_idx = &d_offset[len]; uint32_t *d_restore_idx = &d_sorted_idx[len]; uint32_t *d_merged_cnts = &d_restore_idx[len]; thread_local std::shared_ptr d_sort_keys_ptr = nullptr; - KeyType *d_sorted_keys = AllocCache( - &d_sort_keys_ptr, place, sizeof(KeyType) * len, stream); + KeyType *d_sorted_keys = + AllocCache(&d_sort_keys_ptr, place, sizeof(KeyType) * len); size_t merge_size = dedup_keys_and_fillidx(gpu_id, len, @@ -3828,23 +3849,27 @@ size_t HeterComm:: cudaMemcpyHostToDevice, stream)); // barrier - barrier_.wait(); + // barrier_.wait(); my_cache.node_barrier_.Resume(); auto &comm = nccl_inter_comms_[gpu_id]; + auto nccl_stream = resource_->comm_stream(gpu_id, 0); 
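// --- Illustrative sketch, not part of the patch above ----------------------
// The hunk above moves the collectives onto a dedicated comm stream
// (resource_->comm_stream) while the memcpys stay on the local compute
// stream, so each side must be synchronized explicitly before the other
// consumes its result. Minimal version of that ordering for the size
// all-gather; the buffer names are hypothetical.
#include <nccl.h>
#include <cuda_runtime.h>

void allgather_part_sizes(ncclComm_t comm, int rank_id,
                          const int* h_local_size,
                          int* d_sizes,  // one int slot per rank
                          cudaStream_t compute_stream,
                          cudaStream_t nccl_stream) {
  // Stage this rank's slot on the compute stream, then make it visible
  // before the collective starts on a different stream.
  cudaMemcpyAsync(d_sizes + rank_id, h_local_size, sizeof(int),
                  cudaMemcpyHostToDevice, compute_stream);
  cudaStreamSynchronize(compute_stream);

  // All-gather one int per rank on the dedicated NCCL stream.
  ncclAllGather(d_sizes + rank_id, d_sizes, 1, ncclInt, comm, nccl_stream);
  cudaStreamSynchronize(nccl_stream);  // sizes are valid from here on
}
// ----------------------------------------------------------------------------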
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( &res.d_node_size_ptr[rank_id_ * node_size_], reinterpret_cast(res.d_node_size_ptr), node_size_, ncclInt, comm, - stream)); + nccl_stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(nccl_stream)); + my_cache.node_barrier_.Pause(); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&h_push_fea_sizes[0], res.d_node_size_ptr, all_shard_part_size * sizeof(int), cudaMemcpyDeviceToHost, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); - my_cache.node_barrier_.Pause(); + size_t *h_remote_part_sizes = res.h_remote_part_sizes.data(); size_t *h_remote_part_offsets = res.h_remote_part_offsets.data(); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 45a87278d603b..25828e9160cab 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -458,15 +458,18 @@ void PSGPUWrapper::add_slot_feature(std::shared_ptr gpu_task) { &node_ids, &feature_ids](int i) { platform::CUDADeviceGuard guard(resource_->dev_id(i)); + auto stream = resource_->local_stream(i, 0); int* d_slot_feature_num_map; uint64_t* d_node_list_ptr; uint64_t* d_feature_list_ptr; CUDA_CHECK(cudaMalloc(reinterpret_cast(&d_slot_feature_num_map), slot_num * sizeof(int))); - CUDA_CHECK(cudaMemcpy(d_slot_feature_num_map, + CUDA_CHECK(cudaMemcpyAsync(d_slot_feature_num_map, h_slot_feature_num_map.data(), sizeof(int) * slot_num, - cudaMemcpyHostToDevice)); + cudaMemcpyHostToDevice, + stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); CUDA_CHECK(cudaMalloc(reinterpret_cast(&d_node_list_ptr), batch * sizeof(uint64_t))); CUDA_CHECK(cudaMalloc(reinterpret_cast(&d_feature_list_ptr), @@ -479,10 +482,11 @@ void PSGPUWrapper::add_slot_feature(std::shared_ptr gpu_task) { real_batch = (pos + batch) <= node_ids[i].size() ? 
batch : node_ids[i].size() - pos; - CUDA_CHECK(cudaMemcpy(d_node_list_ptr, + CUDA_CHECK(cudaMemcpyAsync(d_node_list_ptr, node_ids[i].data() + pos, real_batch * sizeof(uint64_t), - cudaMemcpyHostToDevice)); + cudaMemcpyHostToDevice, + stream)); int ret = gpu_graph_ptr->get_feature_of_nodes(i, d_node_list_ptr, d_feature_list_ptr, @@ -495,12 +499,14 @@ void PSGPUWrapper::add_slot_feature(std::shared_ptr gpu_task) { 0, platform::errors::PreconditionNotMet("get_feature_of_nodes error")); - CUDA_CHECK(cudaMemcpy(feature_ids[i].data() + pos * fea_num_per_node, + CUDA_CHECK(cudaMemcpyAsync(feature_ids[i].data() + pos * fea_num_per_node, d_feature_list_ptr, real_batch * fea_num_per_node * sizeof(uint64_t), - cudaMemcpyDeviceToHost)); + cudaMemcpyDeviceToHost, + stream)); pos += real_batch; } + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); cudaFree(d_slot_feature_num_map); cudaFree(d_node_list_ptr); cudaFree(d_feature_list_ptr); @@ -825,10 +831,12 @@ void PSGPUWrapper::FilterPull(std::shared_ptr gpu_task, continue; } if (dedup_size == pos) { + CHECK(shard_values[dedup_size] != 0); ++dedup_size; continue; } shard_keys[dedup_size] = shard_keys[pos]; + CHECK(shard_values[dedup_size] != 0); ++dedup_size; } shard_keys.resize(dedup_size); @@ -856,12 +864,14 @@ void PSGPUWrapper::MergePull(std::shared_ptr gpu_task) { timeline.Start(); auto fleet_ptr = paddle::distributed::FleetWrapper::GetInstance(); std::vector> task_futures; + std::vector> dim_pass_values(multi_mf_dim_, nullptr); for (int dim_id = 0; dim_id < multi_mf_dim_; ++dim_id) { auto pass_values = fleet_ptr->worker_ptr_->TakePassSparseReferedValues( table_id_, gpu_task->pass_id_, dim_id); if (pass_values == nullptr) { continue; } + dim_pass_values[dim_id] = pass_values; for (int shard_id = 0; shard_id < thread_keys_shard_num_; ++shard_id) { auto& merge_values = pass_values->at(shard_id); task_futures.emplace_back(pull_thread_pool_[shard_id]->enqueue( @@ -896,6 +906,8 @@ void PSGPUWrapper::MergePull(std::shared_ptr gpu_task) { } last_key = merge_key; shard_keys[dedup_index] = merge_key; + CHECK(merge_values.values[k] != 0) + << "num=" << merge_num << ", pos=" << k << ", key=" << merge_key << " is nullptr"; shard_values[dedup_index] = CONV2FEATURE_PTR(merge_values.values[k]); ++k; @@ -910,6 +922,8 @@ void PSGPUWrapper::MergePull(std::shared_ptr gpu_task) { } last_key = merge_key; shard_keys[dedup_index] = merge_key; + CHECK(merge_values.values[k] != 0) + << "num=" << merge_num << ", pos=" << k << ", key=" << merge_key << " is nullptr"; shard_values[dedup_index] = CONV2FEATURE_PTR(merge_values.values[k]); ++k; @@ -1435,7 +1449,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { this->hbm_pools_[i * this->multi_mf_dim_ + j]->mem(), len, feature_value_size, - 500000, + 4 * 1024 * 1024, 2); if (device_dim_keys.size() > 0) { VLOG(3) << "show table: " << i @@ -1466,6 +1480,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { size_t total_len = 0; stagetime.Start(); struct task_info task; + auto stream = resource_->local_stream(i, 0); while (cpu_reday_channels_[i]->Get(task)) { auto hbm = this->hbm_pools_[task.device_id * this->multi_mf_dim_ + task.multi_mf_dim] @@ -1475,12 +1490,14 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { accessor_wrapper_ptr->GetFeatureValueSize(mf_dim); auto hbm_start = hbm + task.offset * feature_value_size; CUDA_CHECK( - cudaMemcpy(hbm_start, + cudaMemcpyAsync(hbm_start, task.build_values.get() + task.start * feature_value_size, (task.end - task.start) * feature_value_size, - 
cudaMemcpyHostToDevice)); + cudaMemcpyHostToDevice, + stream)); total_len += (task.end - task.start); } + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); stagetime.Pause(); timer.Pause(); @@ -1760,7 +1777,7 @@ void PSGPUWrapper::HbmToSparseTable() { platform::Timer tm; tm.Start(); PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i))); - + auto stream = resource_->local_stream(i, 0); size_t total_len = 0; // multi mf dim for (int j = 0; j < this->multi_mf_dim_; ++j) { @@ -1783,10 +1800,11 @@ void PSGPUWrapper::HbmToSparseTable() { uint64_t offset = start * feature_value_size; char* test_build_values = build_values.get(); - cudaMemcpy(test_build_values, + cudaMemcpyAsync(test_build_values, hbm_pool->mem() + offset, feature_value_size * real_len, - cudaMemcpyDeviceToHost); + cudaMemcpyDeviceToHost, + stream); for (size_t k = 0; k < real_len; k = k + once_cpu_num) { struct task_info task; task.build_values = build_values; @@ -1802,6 +1820,7 @@ void PSGPUWrapper::HbmToSparseTable() { } total_len += len; } + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); tm.Pause(); VLOG(1) << "dump_pool_to_cpu_func i=" << i << ", total len=" << total_len << ", span=" << tm.ElapsedSec(); diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc old mode 100644 new mode 100755 index d86fcfa0ae660..abbb553a6ba09 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -30,6 +30,9 @@ limitations under the License. */ #endif DECLARE_bool(enable_exit_when_partial_worker); +PADDLE_DEFINE_EXPORTED_bool(gpugraph_force_device_batch_num_equal, + false, + "enable force_device_batch_num_equal, default false"); namespace paddle { namespace framework { @@ -145,11 +148,13 @@ bool HogwildWorker::CheckBatchNum(int flag) { } else if (flag < 0) { flag = 0; } - g_barrier.wait(); +// g_barrier.wait(); float *stat_ptr = sync_stat_.data(); auto comm = platform::NCCLCommContext::Instance().Get(0, place_.GetDeviceId()); - auto stream = static_cast(dev_ctx_)->stream(); +// auto stream = static_cast(dev_ctx_)->stream(); +// PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + auto stream = comm->stream(); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(&stat_ptr[flag], &stat_ptr[2], 1, @@ -157,13 +162,14 @@ bool HogwildWorker::CheckBatchNum(int flag) { ncclProd, comm->comm(), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&ret, // output &stat_ptr[2], sizeof(float), cudaMemcpyDeviceToHost, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); - g_barrier.wait(); +// g_barrier.wait(); #endif return (ret > 0.0); } @@ -210,7 +216,7 @@ void HogwildWorker::TrainFilesWithProfiler() { while (1) { cur_batch = device_reader_->Next(); #if defined(PADDLE_WITH_GPU_GRAPH) - if (is_multi_node) { + if (FLAGS_gpugraph_force_device_batch_num_equal || is_multi_node) { if (!CheckBatchNum(cur_batch)) { break; } @@ -342,7 +348,7 @@ void HogwildWorker::TrainFiles() { while (1) { cur_batch = device_reader_->Next(); #if defined(PADDLE_WITH_GPU_GRAPH) - if (is_multi_node) { + if (FLAGS_gpugraph_force_device_batch_num_equal || is_multi_node) { if (!CheckBatchNum(cur_batch)) { break; } diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index f7f05da634013..fa191ea749fe4 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -451,7 +451,6 @@ void TensorCopy(const 
Tensor& src, Tensor* dst) { TensorCopyImpl(src, dst_place, ctx, dst); } - void TensorCopySync(const Tensor& src, const platform::Place& dst_place, Tensor* dst) { @@ -571,32 +570,52 @@ void TensorCopySync(const Tensor& src, } else if (platform::is_gpu_place(src_place) && // NOLINT platform::is_cuda_pinned_place(dst_place)) { - memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr); + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(src_place)) + ->stream(); + memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream); + platform::GpuStreamSync(stream); } else if (platform::is_gpu_place(src_place) && // NOLINT platform::is_cpu_place(dst_place)) { auto src_gpu_place = src_place; auto dst_cpu_place = dst_place; - memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr); + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(src_place)) + ->stream(); + memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); + platform::GpuStreamSync(stream); } else if (platform::is_cpu_place(src_place) && // NOLINT platform::is_gpu_place(dst_place)) { auto src_cpu_place = src_place; auto dst_gpu_place = dst_place; - memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, nullptr); + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(dst_place)) + ->stream(); + memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream); + platform::GpuStreamSync(stream); } else if (platform::is_gpu_place(src_place) && // NOLINT platform::is_gpu_place(dst_place)) { auto src_gpu_place = src_place; auto dst_gpu_place = dst_place; - memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr); + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(src_place)) + ->stream(); + memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); + platform::GpuStreamSync(stream); } else if (platform::is_cuda_pinned_place(src_place) && // NOLINT platform::is_gpu_place(dst_place)) { auto src_pinned_place = src_place; auto dst_gpu_place = dst_place; + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(dst_place)) + ->stream(); memory::Copy( - dst_gpu_place, dst_ptr, src_pinned_place, src_ptr, size, nullptr); + dst_gpu_place, dst_ptr, src_pinned_place, src_ptr, size, stream); + platform::GpuStreamSync(stream); } else { // NOLINT PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc index dc1091d47a2d8..6c62328dcab1e 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc @@ -37,6 +37,11 @@ PADDLE_DEFINE_EXPORTED_READONLY_bool( "chunk would be freed when out of memory occurs. 
This flag " "only works when FLAGS_allocator_strategy=auto_growth."); +PADDLE_DEFINE_EXPORTED_READONLY_bool( + print_allocator_trace_info, + false, + "print trace memory info"); + namespace paddle { namespace memory { namespace allocation { @@ -186,8 +191,9 @@ uint64_t AutoGrowthBestFitAllocator::FreeIdleChunks() { ++chunk_it; } } - - Trace(); + if (FLAGS_print_allocator_trace_info) { + Trace(); + } return bytes; } diff --git a/paddle/fluid/memory/malloc.h b/paddle/fluid/memory/malloc.h index 49ced76c3371c..1b76e3a863350 100644 --- a/paddle/fluid/memory/malloc.h +++ b/paddle/fluid/memory/malloc.h @@ -56,6 +56,37 @@ extern uint64_t Release(const platform::CUDAPlace& place, gpuStream_t stream); void RecordStream(std::shared_ptr allocation, gpuStream_t stream); gpuStream_t GetStream(const std::shared_ptr& allocation); + +template +struct ThrustAllocator { + typedef char value_type; + ThrustAllocator(platform::Place place, StreamType stream) { + VLOG(2) << "construct allocator"; + place_ = place; + stream_ = stream; + } + ~ThrustAllocator() { VLOG(2) << "destory allocator"; } + char *allocate(std::ptrdiff_t num_bytes) { + VLOG(2) << "allocate " << num_bytes << " bytes"; + auto storage = memory::AllocShared(place_, num_bytes, + phi::Stream(reinterpret_cast(stream_))); + char *ptr = reinterpret_cast(storage->ptr()); + busy_allocation_.emplace(std::make_pair(ptr, storage)); + return ptr; + } + void deallocate(char *ptr, size_t) { + VLOG(2) << "deallocate "; + allocation_map_type::iterator iter = busy_allocation_.find(ptr); + CHECK(iter != busy_allocation_.end()); + busy_allocation_.erase(iter); + } + private: + typedef std::unordered_map> + allocation_map_type; + allocation_map_type busy_allocation_; + platform::Place place_; + StreamType stream_; +}; #endif } // namespace memory } // namespace paddle diff --git a/paddle/fluid/operators/collective/c_allreduce_x_op.cc b/paddle/fluid/operators/collective/c_allreduce_x_op.cc new file mode 100644 index 0000000000000..8d42c7c94053c --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_x_op.cc @@ -0,0 +1,194 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#elif defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_XPU) +#include "paddle/fluid/platform/device/xpu/xpu_info.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" +#endif +#include "paddle/fluid/platform/collective_helper.h" +namespace paddle { +namespace operators { + +class CAllReduceXOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override {} + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } +}; +template +class CAllReduceXOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto in_tensors = ctx.MultiInput("X"); + auto out_tensors = ctx.MultiOutput("Out"); + + PADDLE_ENFORCE_EQ(in_tensors.size(), + out_tensors.size(), + platform::errors::InvalidArgument( + "The number of CReduceX operator's input and " + "output is not match, " + "input number is %u, output number is %u.", + in_tensors.size(), + out_tensors.size())); + + int ring_id = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place); + +#if defined(PADDLE_WITH_NCCL) + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + stream = dynamic_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } +#elif defined(PADDLE_WITH_XPU_BKCL) + auto comm = platform::BKCLCommContext::Instance().Get(ring_id, place); + XPUStream stream = static_cast(dev_ctx) + ->x_context() + ->xpu_stream; +#else + PADDLE_THROW("PaddlePaddle should compile with NCCL OR XPU."); +#endif + + // Init the output as input + for (size_t i = 0; i < in_tensors.size(); ++i) { + auto &out_tensor = out_tensors[i]; + if (out_tensor->IsInitialized()) { + PADDLE_ENFORCE_EQ(out_tensor->numel(), + in_tensors[i]->numel(), + platform::errors::InvalidArgument( + "The number of CReduceX operator's X[%u] and " + "Out[%u] is not match, " + "input numel is %u, output numel is %u.", + i, + i, + out_tensor->numel(), + in_tensors[i]->numel())); + } else { + out_tensor->Resize(in_tensors[i]->dims()); + out_tensor->mutable_data(place); + } + } +#if defined(PADDLE_WITH_NCCL) + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); +#endif + // allreduce sum data + for (size_t i = 0; i < in_tensors.size(); ++i) { + auto &in_tensor = in_tensors[i]; + auto &out_tensor = out_tensors[i]; + int64_t numel = in_tensor->numel(); + const T *sendbuff = in_tensor->data(); + T *recvbuff = out_tensor->mutable_data(place); +#if defined(PADDLE_WITH_NCCL) + ncclDataType_t nccl_dtype = + platform::ToNCCLDataType(framework::TransToProtoVarType(in_tensor->dtype())); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclAllReduce( + sendbuff, + recvbuff, + numel, + nccl_dtype, + ncclSum, + comm->comm(), + stream)); +#elif defined(PADDLE_WITH_XPU_BKCL) + BKCLDataType 
bkcl_dtype = + platform::ToBKCLDataType(framework::TransToProtoVarType(in_tensor->dtype())); + PADDLE_ENFORCE_EQ( + bkcl_all_reduce(comm->comm(), + sendbuff, + recvbuff, + numel, + bkcl_dtype, + BKCL_ADD, + stream), + BKCL_SUCCESS, + platform::errors::PreconditionNotMet("BKCL all reduce failed")); +#endif + } +#if defined(PADDLE_WITH_NCCL) + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); +#endif + } +}; +class CAllReduceXOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", + "(vector) The input tensors of callreduce_x_tensor " + "operator.") + .AsDuplicable(); + AddOutput("Out", + "(LoDTensor) The output tensor ") + .AsDuplicable(); + AddAttr("ring_id", "(int default -1) nccl ring id num.") + .SetDefault(-1); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddComment(string::Sprintf(R"DOC( +CAllReduceX %s Operator + +Call collective ReduceX with reduce type %s. If input and output are +the same variable, in-place allreduce will be used. +Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allreduce +)DOC", + GetName(), GetName())); + } + protected: + virtual std::string GetName() { return "ReduceX"; } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(c_allreduce_xsum, ops::CAllReduceXOp, + ops::CAllReduceXOpMaker); +REGISTER_OP_CPU_KERNEL(c_allreduce_xsum, ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel); +#if defined(PADDLE_WITH_NCCL) +REGISTER_OP_CUDA_KERNEL(c_allreduce_xsum, ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel); +#endif +#if defined(PADDLE_WITH_XPU_BKCL) +REGISTER_OP_XPU_KERNEL(c_allreduce_xsum, ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel, + ops::CAllReduceXOpKernel); +#endif diff --git a/paddle/fluid/operators/shuffle_batch_op.cu b/paddle/fluid/operators/shuffle_batch_op.cu index 19dc830cc430f..3a89255dc7528 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cu +++ b/paddle/fluid/operators/shuffle_batch_op.cu @@ -27,37 +27,6 @@ namespace paddle { namespace operators { -struct CacheAllocator { - typedef char value_type; - CacheAllocator(platform::Place place) { - VLOG(2) << "construct allocator"; - place_ = place; - } - - ~CacheAllocator() { VLOG(2) << "destory allocator"; } - - char *allocate(std::ptrdiff_t num_bytes) { - VLOG(2) << "allocate " << num_bytes << " bytes"; - auto storage = memory::AllocShared(place_, num_bytes); - char *ptr = reinterpret_cast(storage->ptr()); - busy_allocation_.emplace(std::make_pair(ptr, storage)); - return ptr; - } - - void deallocate(char *ptr, size_t) { - VLOG(2) << "deallocate "; - allocation_map_type::iterator iter = busy_allocation_.find(ptr); - CHECK(iter != busy_allocation_.end()); - busy_allocation_.erase(iter); - } - - private: - typedef std::unordered_map> - allocation_map_type; - allocation_map_type busy_allocation_; - platform::Place place_; -}; - template struct ReorderFunctor { ReorderFunctor(const T *x, const int64_t *shuffle_idx, T *y, int64_t stride) @@ -121,7 +90,7 @@ class ShuffleBatchCUDAKernel : public framework::OpKernel { auto &dev_ctx = ctx.template device_context(); #ifdef PADDLE_WITH_CUDA - 
CacheAllocator allocator(ctx.GetPlace()); + paddle::memory::ThrustAllocator allocator(ctx.GetPlace(), dev_ctx.stream()); const auto &exec_policy = thrust::cuda::par(allocator).on(dev_ctx.stream()); #else const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream()); diff --git a/paddle/fluid/platform/device/gpu/gpu_info.cc b/paddle/fluid/platform/device/gpu/gpu_info.cc index 7cceb8ccec3e1..c9258cd1ca5dd 100644 --- a/paddle/fluid/platform/device/gpu/gpu_info.cc +++ b/paddle/fluid/platform/device/gpu/gpu_info.cc @@ -63,7 +63,6 @@ PADDLE_DEFINE_EXPORTED_bool(enable_gpu_memory_usage_log_mb, constexpr static float fraction_reserve_gpu_memory = 0.05f; -USE_GPU_MEM_STAT; namespace paddle { namespace platform { @@ -221,7 +220,7 @@ class RecordedGpuMallocHelper { gpuError_t Malloc(void **ptr, size_t size, bool malloc_managed_memory = false) { - LockGuardPtr lock(mtx_); +// LockGuardPtr lock(mtx_); if (UNLIKELY(NeedRecord() && cur_size_.load() + size > limit_size_)) { return gpuErrorOutOfMemory; } @@ -246,7 +245,6 @@ class RecordedGpuMallocHelper { #endif if (result == gpuSuccess) { cur_size_.fetch_add(size); - STAT_INT_ADD("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size); DEVICE_MEMORY_STAT_UPDATE(Reserved, dev_id_, size); platform::RecordMemEvent(ptr, GPUPlace(dev_id_), @@ -288,7 +286,6 @@ class RecordedGpuMallocHelper { #endif PADDLE_ENFORCE_GPU_SUCCESS(err); cur_size_.fetch_sub(size); - STAT_INT_SUB("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size); DEVICE_MEMORY_STAT_UPDATE(Reserved, dev_id_, -size); platform::RecordMemEvent(ptr, GPUPlace(dev_id_), diff --git a/paddle/phi/kernels/funcs/math_function.cc b/paddle/phi/kernels/funcs/math_function.cc index 15a708f02f497..8390e63e0dc9e 100644 --- a/paddle/phi/kernels/funcs/math_function.cc +++ b/paddle/phi/kernels/funcs/math_function.cc @@ -33,6 +33,9 @@ limitations under the License. 
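Note: the Thrust calls touched by this patch (shuffle_batch_op.cu above, cum_kernel.cu and unique_kernel.cu below) now run through thrust::cuda::par(allocator).on(stream) instead of thrust::device or par.on(stream), so Thrust's internal temporary buffers come from paddle::memory::ThrustAllocator rather than raw synchronous cudaMalloc/cudaFree. The sketch below is a standalone illustration of the allocator contract Thrust expects from par(allocator); it is not Paddle's ThrustAllocator, and cudaMallocAsync/cudaFreeAsync are used only as a stand-in for a stream-aware pool (they assume CUDA 11.2+).

    // Minimal allocator satisfying thrust::cuda::par(allocator): raw-byte
    // allocate/deallocate, here serviced asynchronously on one stream.
    #include <cuda_runtime.h>
    #include <thrust/device_vector.h>
    #include <thrust/execution_policy.h>
    #include <thrust/sort.h>
    #include <cstddef>

    struct StreamTempAllocator {
      typedef char value_type;                      // Thrust requests raw bytes
      explicit StreamTempAllocator(cudaStream_t s) : stream(s) {}

      char* allocate(std::ptrdiff_t num_bytes) {
        void* ptr = nullptr;
        cudaMallocAsync(&ptr, static_cast<size_t>(num_bytes), stream);
        return static_cast<char*>(ptr);
      }
      void deallocate(char* ptr, size_t) { cudaFreeAsync(ptr, stream); }

      cudaStream_t stream;
    };

    void sort_on_stream(thrust::device_vector<int>& keys, cudaStream_t stream) {
      StreamTempAllocator alloc(stream);
      // Temporary storage is obtained from `alloc` and the kernels are launched
      // on `stream`, so the sort does not serialize on the default stream.
      thrust::sort(thrust::cuda::par(alloc).on(stream), keys.begin(), keys.end());
    }

The removed CacheAllocator in shuffle_batch_op.cu implemented the same contract on top of memory::AllocShared; centralizing it in ThrustAllocator lets all of these kernels share one implementation.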
*/ #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function_impl.h" #include "unsupported/Eigen/CXX11/Tensor" +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/platform/device/xpu/xpu_info.h" +#endif namespace phi { namespace funcs { @@ -138,23 +141,116 @@ DEFINE_CPU_TRANS_NORMAL(phi::dtype::complex); DEFINE_CPU_TRANS_NORMAL(phi::dtype::complex); struct TensorSetConstantCPU { - TensorSetConstantCPU(paddle::framework::Tensor* tensor, float value) + TensorSetConstantCPU(paddle::framework::Tensor* tensor, const void* value) : tensor_(tensor), value_(value) {} template void apply() const { auto cpu = phi::CPUPlace(); auto* begin = tensor_->mutable_data(cpu); - std::fill(begin, begin + tensor_->numel(), static_cast(value_)); + const T* num = reinterpret_cast(value_); + std::fill(begin, begin + tensor_->numel(), static_cast(*num)); } paddle::framework::Tensor* tensor_; - float value_; + const void* value_; +}; +struct TensorSetConstantEx { + TensorSetConstantEx( + paddle::framework::Tensor* tensor, + const void* value, + paddle::platform::Place place) + : tensor_(tensor), value_(value), place_(place) {} + template + void apply() const { + auto* data = tensor_->mutable_data(place_); + int numel = tensor_->numel(); + const T* num = reinterpret_cast(value_); + if (paddle::platform::is_cpu_place(place_)) { + std::fill(data, data + numel, static_cast(*num)); + } else { + std::unique_ptr data_cpu(new T[numel]); + std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast(*num)); + paddle::memory::Copy(place_, + data, + phi::CPUPlace(), + static_cast(data_cpu.get()), + numel * sizeof(T)); + } + } + paddle::framework::Tensor* tensor_; + const void* value_; + paddle::platform::Place place_; +}; +#ifdef PADDLE_WITH_XPU +template +class XPUTensorTrait { + public: + using Type = T; +}; +template <> +class XPUTensorTrait { + public: + using Type = ::float16; +}; +template <> +class XPUTensorTrait { + public: + using Type = ::float16; +}; +template<> +class XPUTensorTrait> { +public: + using Type = int64_t; +}; +template<> +class XPUTensorTrait> { +public: + using Type = float; }; - +template <> +class XPUTensorTrait { + public: + using Type = bool; +}; +template <> +class XPUTensorTrait { + public: + using Type = int64_t; +}; +struct TensorSetConstantXPU { + TensorSetConstantXPU(const paddle::platform::DeviceContext& context, + paddle::framework::Tensor* tensor, + const void* value, + paddle::platform::Place place) + : context_(context), tensor_(tensor), value_(value), place_(place) {} + template + void apply() const { + auto* data = tensor_->mutable_data(place_); + int numel = tensor_->numel(); + using XPUInTDType = typename XPUTensorTrait::Type; + float num = static_cast(*reinterpret_cast(value_)); + auto dev_ctx = reinterpret_cast(&context_); + int ret = xpu::constant(dev_ctx->x_context(), + reinterpret_cast(data), + numel, + static_cast(num)); + PADDLE_ENFORCE_EQ( + ret, + XPU_SUCCESS, + phi::errors::External("XPU CONSTANT API return wrong value[%d %s].", + ret, + XPUAPIErrorMsg[ret])); + } + const paddle::platform::DeviceContext& context_; + paddle::framework::Tensor* tensor_; + const void* value_; + paddle::platform::Place place_; +}; +#endif template <> void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) { + const void* value) { PADDLE_THROW(phi::errors::Unimplemented("XPUPlace is not supported")); } @@ -162,7 +258,7 @@ template <> void set_constant_with_place( const 
paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) { + const void* value) { PADDLE_THROW(phi::errors::Unimplemented("NPUPlace is not supported")); } @@ -170,7 +266,7 @@ template <> void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) { + const void* value) { PADDLE_THROW(phi::errors::Unimplemented("NPUPinnedPlace is not supported")); } @@ -178,7 +274,7 @@ template <> void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) { + const void* value) { PADDLE_THROW(phi::errors::Unimplemented("IPUPlace is not supported")); } @@ -186,7 +282,7 @@ template <> void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) { + const void* value) { PADDLE_THROW(phi::errors::Unimplemented("CustomPlace is not supported")); } @@ -194,7 +290,7 @@ template <> void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) { + const void* value) { phi::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value)); } @@ -202,7 +298,7 @@ template <> void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) { + const void* value) { PADDLE_THROW(phi::errors::Unimplemented("MLUPlace is not supported")); } @@ -210,7 +306,7 @@ template <> void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) { + const void* value) { phi::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value)); } @@ -218,7 +314,7 @@ struct TensorSetConstantWithPlace : public std::unary_function { TensorSetConstantWithPlace(const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) + const void* value) : context_(context), tensor_(tensor), value_(value) {} template @@ -228,25 +324,27 @@ struct TensorSetConstantWithPlace const paddle::platform::DeviceContext& context_; paddle::framework::Tensor* tensor_; - float value_; + const void* value_; }; void set_constant(const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) { - TensorSetConstantWithPlace func(context, tensor, value); -#ifdef PADDLE_WITH_CUSTOM_DEVICE - if (paddle::platform::is_custom_place(context.GetPlace())) { - func(phi::CPUPlace()); - return; - } -#endif + const void* value) { + auto place = context.GetPlace(); + if (paddle::platform::is_gpu_place(place)) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - // tensor->place().apply_visitor(func); - paddle::platform::VisitPlace(tensor->place(), func); -#else - func(phi::CPUPlace()); + TensorSetConstantWithPlace func(context, tensor, value); + paddle::platform::VisitPlace(tensor->place(), func); +#endif + } else if (paddle::platform::is_xpu_place(place)) { +#ifdef PADDLE_WITH_XPU + phi::VisitDataType(tensor->dtype(), + TensorSetConstantXPU(context, tensor, value, place)); #endif + } else { + phi::VisitDataType(tensor->dtype(), + TensorSetConstantEx(tensor, value, place)); + } } template struct ColwiseSum; diff --git a/paddle/phi/kernels/funcs/math_function.cu b/paddle/phi/kernels/funcs/math_function.cu index 9f0c20ccf14dc..47749101c8901 100644 --- a/paddle/phi/kernels/funcs/math_function.cu +++ b/paddle/phi/kernels/funcs/math_function.cu @@ -80,6 +80,14 @@ DEFINE_GPU_TRANS(4); 
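Note on the set_constant signature change in math_function.cc above: the fill value now crosses the place dispatch as a type-erased const void* and is reinterpreted with the tensor's own dtype inside the typed visitor, instead of being squeezed through a float parameter (which could lose precision for int64/double fills). A minimal standalone sketch of that pattern follows; DType, fill_typed and set_constant_erased are illustrative names, not Paddle's API.

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    enum class DType { kFloat32, kFloat64, kInt64 };

    template <typename T>
    void fill_typed(void* data, std::size_t numel, const void* value) {
      const T v = *static_cast<const T*>(value);   // reinterpret with the real dtype
      T* begin = static_cast<T*>(data);
      std::fill(begin, begin + numel, v);
    }

    // Stands in for phi::VisitDataType dispatching into a typed apply<T>().
    void set_constant_erased(DType dtype, void* data, std::size_t numel,
                             const void* value) {
      switch (dtype) {
        case DType::kFloat32: fill_typed<float>(data, numel, value); break;
        case DType::kFloat64: fill_typed<double>(data, numel, value); break;
        case DType::kInt64:   fill_typed<int64_t>(data, numel, value); break;
      }
    }

    // Mirrors the templated wrapper added in math_function.h: callers keep
    // passing a typed scalar, and only its address crosses the erased boundary.
    template <typename T>
    void set_constant_typed(DType dtype, void* data, std::size_t numel, T value) {
      set_constant_erased(dtype, data, numel, static_cast<const void*>(&value));
    }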
DEFINE_GPU_TRANS(5); DEFINE_GPU_TRANS(6); +template +__global__ void FillConstantKernel(const int N, T* a, const T val) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + a[i] = val; + } +} + #define REINTERPRET(T, DST_PTR, SRC_PTR) \ T* DST_PTR = reinterpret_cast(SRC_PTR) @@ -217,27 +225,35 @@ DEFINE_GPU_TRANS_NORMAL(phi::dtype::complex); struct TensorSetConstantGPU { TensorSetConstantGPU(const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) + const void* value) : context_(context), tensor_(tensor), value_(value) {} template void apply() const { - SetConstant functor; - functor(reinterpret_cast(context_), - tensor_, - static_cast(value_)); +// SetConstant functor; +// functor(reinterpret_cast(context_), +// tensor_, +// static_cast(value_)); + int N = static_cast(tensor_->numel()); + if (N <= 0) { + return; + } + auto& ctx = reinterpret_cast(context_); + const T* num = reinterpret_cast(value_); + FillConstantKernel<<<(N + 512 - 1) / 512, 512, 0, ctx.stream()>>>( + N, tensor_->mutable_data(ctx.GetPlace()), static_cast(*num)); } const paddle::platform::DeviceContext& context_; paddle::framework::Tensor* tensor_; - float value_; + const void* value_; }; template <> void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value) { + const void* value) { phi::VisitDataType(tensor->dtype(), TensorSetConstantGPU(context, tensor, value)); } diff --git a/paddle/phi/kernels/funcs/math_function.h b/paddle/phi/kernels/funcs/math_function.h index b735587d3d53d..a939ab51206fd 100644 --- a/paddle/phi/kernels/funcs/math_function.h +++ b/paddle/phi/kernels/funcs/math_function.h @@ -57,11 +57,17 @@ struct SetConstant { template void set_constant_with_place(const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value); + const void* value); void set_constant(const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, - float value); + const void* value); +template +void set_constant(const paddle::platform::DeviceContext& context, + paddle::framework::Tensor* tensor, const T value) { + set_constant(context, tensor, reinterpret_cast(&value)); +} + template struct RowwiseAdd { @@ -100,31 +106,6 @@ struct RowwiseMean { paddle::framework::Tensor* vec); }; -#ifdef PADDLE_WITH_XPU -template -struct TensorSetConstantXPU { - TensorSetConstantXPU(paddle::framework::Tensor* tensor, - U value, - paddle::platform::Place place) - : tensor_(tensor), value_(value), place_(place) {} - template - void apply() const { - auto* begin = tensor_->mutable_data(place_); - int numel = tensor_->numel(); - std::unique_ptr data_cpu(new T[numel]); - std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast(value_)); - paddle::memory::Copy(place_, - begin, - phi::CPUPlace(), - static_cast(data_cpu.get()), - numel * sizeof(T)); - } - paddle::framework::Tensor* tensor_; - U value_; - paddle::platform::Place place_; -}; -#endif - template inline void TransCompute(const int dim, const Context& dev_ctx, diff --git a/paddle/phi/kernels/funcs/math_function_impl.h b/paddle/phi/kernels/funcs/math_function_impl.h index f9055fb56c913..b7eb987c60714 100644 --- a/paddle/phi/kernels/funcs/math_function_impl.h +++ b/paddle/phi/kernels/funcs/math_function_impl.h @@ -27,19 +27,11 @@ using paddle::framework::To32BitIndex; template void SetConstant::operator()( const DeviceContext& context, paddle::framework::Tensor* tensor, T num) { - bool 
xpu_place = false; -#ifdef PADDLE_WITH_XPU - if (paddle::platform::is_xpu_place(context.GetPlace())) { - xpu_place = true; - phi::VisitDataType( - tensor->dtype(), - TensorSetConstantXPU(tensor, num, context.GetPlace())); - } -#endif - if (!xpu_place) { - auto t = paddle::framework::EigenVector::Flatten(*tensor); - t.device(*context.eigen_device()) = t.constant(static_cast(num)); - } +// if (!paddle::platform::is_xpu_place(context.GetPlace())) { +// auto t = paddle::framework::EigenVector::Flatten(*tensor); +// t.device(*context.eigen_device()) = t.constant(static_cast(num)); +// } + set_constant(context, tensor, reinterpret_cast(&num)); } template diff --git a/paddle/phi/kernels/gpu/auc_kernel.cu b/paddle/phi/kernels/gpu/auc_kernel.cu index 5a1bb9874fe19..54a841aa7044a 100644 --- a/paddle/phi/kernels/gpu/auc_kernel.cu +++ b/paddle/phi/kernels/gpu/auc_kernel.cu @@ -204,20 +204,20 @@ void AucKernel(const Context &dev_ctx, auto *neg_in_data = stat_neg.data(); #ifdef PADDLE_WITH_CUDA if (stat_pos_in_tensor != stat_pos_out) { - cudaMemcpy( + cudaMemcpyAsync( origin_stat_pos, pos_in_data, ((1 + slide_steps) * (num_thresholds + 1) + (slide_steps > 0 ? 1 : 0)) * sizeof(int64_t), - cudaMemcpyDeviceToDevice); + cudaMemcpyDeviceToDevice, dev_ctx.stream()); } if (stat_neg_in_tensor != stat_neg_out) { - cudaMemcpy( + cudaMemcpyAsync( origin_stat_neg, neg_in_data, ((1 + slide_steps) * (num_thresholds + 1) + (slide_steps > 0 ? 1 : 0)) * sizeof(int64_t), - cudaMemcpyDeviceToDevice); + cudaMemcpyDeviceToDevice, dev_ctx.stream()); } #else if (stat_pos_in_tensor != stat_pos_out) { diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu index 40d7f74379fa7..084070ab76400 100644 --- a/paddle/phi/kernels/gpu/cum_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -29,6 +29,7 @@ namespace cub = hipcub; #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/fluid/memory/memory.h" namespace phi { @@ -245,7 +246,8 @@ void ScanKernel(const Context& dev_ctx, #ifdef __HIPCC__ const auto& policy = thrust::hip::par.on(dev_ctx.stream()); #else - const auto& policy = thrust::cuda::par.on(dev_ctx.stream()); + paddle::memory::ThrustAllocator allocator(dev_ctx.GetPlace(), dev_ctx.stream()); + const auto &policy = thrust::cuda::par(allocator).on(dev_ctx.stream()); #endif if (reverse) { thrust::reverse_iterator> reversed_in( diff --git a/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu index e78fb7892ed7d..c00411bc2d50b 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu @@ -46,12 +46,8 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( memset_size *= src_dims[i]; } const size_t& memset_bytes = memset_size * sizeof(T); - -#ifdef PADDLE_WITH_HIP - hipMemset(p_output, 0, memset_bytes); -#else - cudaMemset(p_output, 0, memset_bytes); -#endif + + cudaMemsetAsync(p_output, 0, memset_bytes, ctx.stream()); if (index_size == 0) return; diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index 4dc2794d9c949..a46588edcd836 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -24,6 +24,7 @@ #include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h" +#include 
"paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -58,25 +59,14 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, } ctx.template Alloc(out); T* p_output = out->data(); - const size_t& memset_bytes = memset_size * sizeof(T); + + funcs::SetConstant constant_functor; if (pool_type == "SUM" || pool_type == "MEAN") { -#ifdef PADDLE_WITH_HIP - hipMemset(p_output, 0, memset_bytes); -#else - cudaMemset(p_output, 0, memset_bytes); -#endif + constant_functor(ctx, out, static_cast(0)); } else if (pool_type == "MAX") { - thrust::device_ptr p_output_ptr(p_output); - thrust::fill(thrust::device, - p_output_ptr, - p_output_ptr + memset_size, - std::numeric_limits::lowest()); + constant_functor(ctx, out, std::numeric_limits::lowest()); } else if (pool_type == "MIN") { - thrust::device_ptr p_output_ptr(p_output); - thrust::fill(thrust::device, - p_output_ptr, - p_output_ptr + memset_size, - std::numeric_limits::max()); + constant_functor(ctx, out, std::numeric_limits::lowest()); } if (index_size == 0) return; @@ -135,11 +125,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ctx.template Alloc(dst_count); int* p_dst_count = dst_count->data(); -#ifdef PADDLE_WITH_HIP - hipMemset(p_dst_count, 0, input_size * sizeof(int)); -#else - cudaMemset(p_dst_count, 0, input_size * sizeof(int)); -#endif + cudaMemsetAsync(p_dst_count, 0, input_size * sizeof(int), ctx.stream()); int64_t grid_count = (index_size + block - 1) / block; ComputeCountCUDAKernel<<>>( diff --git a/paddle/phi/kernels/gpu/unique_kernel.cu b/paddle/phi/kernels/gpu/unique_kernel.cu index 74cbae1b3c36a..8e7beab34860f 100644 --- a/paddle/phi/kernels/gpu/unique_kernel.cu +++ b/paddle/phi/kernels/gpu/unique_kernel.cu @@ -195,23 +195,26 @@ static void UniqueFlattendCUDATensor(const Context& context, indices->Resize(phi::make_ddim({num_input})); auto* indices_data = context.template Alloc(indices); + + paddle::memory::ThrustAllocator allocator(context.GetPlace(), context.stream()); + const auto &exec_policy = thrust::cuda::par(allocator).on(context.stream()); - thrust::sequence(thrust::device, indices_data, indices_data + num_input); + thrust::sequence(exec_policy, indices_data, indices_data + num_input); thrust::sort_by_key( - thrust::device, in_data_hat, in_data_hat + num_input, indices_data); + exec_policy, in_data_hat, in_data_hat + num_input, indices_data); // 1. 
Calculate op result: 'out' DenseTensor range; range.Resize(phi::make_ddim({num_input + 1})); auto* range_data_ptr = context.template Alloc(&range); thrust::sequence( - thrust::device, range_data_ptr, range_data_ptr + num_input + 1); + exec_policy, range_data_ptr, range_data_ptr + num_input + 1); phi::Copy(context, in_hat, context.GetPlace(), false, out); int num_out; auto out_data = context.template Alloc(out); num_out = thrust::unique_by_key( - thrust::device, out_data, out_data + num_input, range_data_ptr, equal) + exec_policy, out_data, out_data + num_input, range_data_ptr, equal) .first - out_data; out->Resize(phi::make_ddim({num_out})); @@ -224,25 +227,27 @@ static void UniqueFlattendCUDATensor(const Context& context, DenseTensor inv_loc; inv_loc.Resize(phi::make_ddim({num_input})); auto inv_loc_data_ptr = context.template Alloc(&inv_loc); - thrust::adjacent_difference(thrust::device, + thrust::adjacent_difference(exec_policy, in_data_hat, in_data_hat + num_input, inv_loc_data_ptr, not_equal); - cudaMemset(inv_loc_data_ptr, 0, sizeof(IndexT)); + cudaMemsetAsync(inv_loc_data_ptr, 0, sizeof(IndexT), context.stream()); size_t temp_storage_bytes = 0; cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes, inv_loc_data_ptr, inv_loc_data_ptr, - num_input); + num_input, + context.stream()); auto d_temp_storage = paddle::memory::Alloc(place, temp_storage_bytes); cub::DeviceScan::InclusiveSum(d_temp_storage->ptr(), temp_storage_bytes, inv_loc_data_ptr, inv_loc_data_ptr, - num_input); - thrust::scatter(thrust::device, + num_input, + context.stream()); + thrust::scatter(exec_policy, inv_loc_data_ptr, inv_loc_data_ptr + num_input, indices_data, @@ -254,11 +259,11 @@ static void UniqueFlattendCUDATensor(const Context& context, DenseTensor tmp_indices; tmp_indices.Resize(phi::make_ddim({num_input})); auto* tmp_indices_data_ptr = context.template Alloc(&tmp_indices); - thrust::copy(thrust::device, + thrust::copy(exec_policy, in_data_hat, in_data_hat + num_input, tmp_indices_data_ptr); - thrust::unique_by_key(thrust::device, + thrust::unique_by_key(exec_policy, tmp_indices_data_ptr, tmp_indices_data_ptr + num_input, indices_data, @@ -271,10 +276,10 @@ static void UniqueFlattendCUDATensor(const Context& context, counts->Resize(phi::make_ddim({num_out})); auto count_data = context.template Alloc(counts); // init 'count_data' as 0 - thrust::fill(thrust::device, count_data, count_data + num_out, 0); + thrust::fill(exec_policy, count_data, count_data + num_out, 0); thrust::device_ptr range_data_ptr_dev(range_data_ptr); range_data_ptr_dev[num_out] = num_input; - thrust::adjacent_difference(thrust::device, + thrust::adjacent_difference(exec_policy, range_data_ptr + 1, range_data_ptr + num_out + 1, count_data); @@ -300,24 +305,26 @@ static void ComputeUniqueDims(const Context& context, equal_T equal, not_equal_T not_equal, int64_t row) { + paddle::memory::ThrustAllocator allocator(context.GetPlace(), context.stream()); + const auto &exec_policy = thrust::cuda::par(allocator).on(context.stream()); // 1. 
inverse indices: 'inverse' inverse->Resize(phi::make_ddim({row})); auto* inverse_data = context.template Alloc(inverse); DenseTensor inv_loc; inv_loc.Resize(phi::make_ddim({row})); auto inv_loc_data_ptr = context.template Alloc(&inv_loc); - thrust::adjacent_difference(thrust::device, + thrust::adjacent_difference(exec_policy, sorted_indices_data, sorted_indices_data + row, inv_loc_data_ptr, not_equal); thrust::device_ptr inv_loc_data_dev(inv_loc_data_ptr); inv_loc_data_dev[0] = 0; - thrust::inclusive_scan(thrust::device, + thrust::inclusive_scan(exec_policy, inv_loc_data_ptr, inv_loc_data_ptr + row, inv_loc_data_ptr); - thrust::scatter(thrust::device, + thrust::scatter(exec_policy, inv_loc_data_ptr, inv_loc_data_ptr + row, sorted_indices_data, @@ -327,9 +334,9 @@ static void ComputeUniqueDims(const Context& context, DenseTensor range; range.Resize(phi::make_ddim({row + 1})); auto range_data_ptr = context.template Alloc(&range); - thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1); + thrust::sequence(exec_policy, range_data_ptr, range_data_ptr + row + 1); int num_out; - num_out = thrust::unique_by_key(thrust::device, + num_out = thrust::unique_by_key(exec_policy, sorted_indices_data, sorted_indices_data + row, range_data_ptr, @@ -343,9 +350,9 @@ static void ComputeUniqueDims(const Context& context, // 3. counts: 'counts' counts->Resize(phi::make_ddim({num_out})); auto* count_data = context.template Alloc(counts); - thrust::fill(thrust::device, count_data, count_data + row, 0); + thrust::fill(exec_policy, count_data, count_data + row, 0); thrust::adjacent_difference( - thrust::device, range_data_ptr + 1, range_data_ptr + row + 1, count_data); + exec_policy, range_data_ptr + 1, range_data_ptr + row + 1, count_data); } // Calculate unique when 'axis' is set @@ -393,10 +400,12 @@ static void UniqueDimsCUDATensor(const Context& context, auto* sorted_indices_data = context.template Alloc(indices); // 2. 
Calculate 'indices', 'inverse', 'counts' - // Init index and sort + // Init index and sort + paddle::memory::ThrustAllocator allocator(context.GetPlace(), context.stream()); + const auto &exec_policy = thrust::cuda::par(allocator).on(context.stream()); thrust::sequence( - thrust::device, sorted_indices_data, sorted_indices_data + row); - thrust::sort(thrust::device, + exec_policy, sorted_indices_data, sorted_indices_data + row); + thrust::sort(exec_policy, sorted_indices_data, sorted_indices_data + row, LessThan(col, in_trans_data)); diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py index 5983af9ecaa09..ac68b89ab02f0 100644 --- a/python/paddle/fluid/transpiler/collective.py +++ b/python/paddle/fluid/transpiler/collective.py @@ -544,57 +544,86 @@ def _insert_fuse_allreduce_ops(self): if grad is None: return - # init output_grads - output_grads = input_grads - # init fused_output with temp shape, it will calculate real shape depend on inputs - fused_output = block.create_var(name="fused_output", - shape=[1], - persistable=False, - dtype=core.VarDesc.VarType.FP32, - stop_gradient=True) - # fuse all grad tensors - coalesce_tensor_attrs = { - "copy_data": True, - "set_constant": False, - "dtype": core.VarDesc.VarType.FP32 - } - block._insert_op(global_offset, - type='coalesce_tensor', - inputs={'Input': input_grads}, - outputs={ - 'Output': output_grads, - 'FusedOutput': fused_output - }, - attrs=coalesce_tensor_attrs) - global_offset += 1 - # grads aggregation of multi-gpus - block._insert_op(global_offset, - type='c_sync_calc_stream', - inputs={'X': fused_output}, - outputs={'Out': fused_output}, - attrs={self.op_role_key: OpRole.Backward}) - global_offset += 1 - ring_id = (ring_id + 1) % self.nrings - block._insert_op(global_offset, - type='c_allreduce_sum', - inputs={'X': fused_output}, - outputs={'Out': fused_output}, - attrs={ - 'ring_id': ring_id, - self.op_role_key: OpRole.Backward - }) - global_offset += 1 - - # sync before adam - block._insert_op(global_offset, - type='c_sync_comm_stream', - inputs={'X': fused_output}, - outputs={'Out': fused_output}, - attrs={ - 'ring_id': ring_id, - self.op_role_key: OpRole.Backward - }) - global_offset += 1 + if self.fuse_allreduce is 2: + # grads aggregation of multi-gpus + block._insert_op(global_offset, + type='c_sync_calc_stream', + inputs={'X': input_grads[0]}, + outputs={'Out': input_grads[0]}, + attrs={self.op_role_key: OpRole.Backward}) + global_offset += 1 + ring_id = (ring_id + 1) % self.nrings + block._insert_op(global_offset, + type='c_allreduce_xsum', + inputs={'X': input_grads}, + outputs={'Out': input_grads}, + attrs={ + 'ring_id': ring_id, + self.op_role_key: OpRole.Backward + }) + global_offset += 1 + # sync before adam + block._insert_op(global_offset, + type='c_sync_comm_stream', + inputs={'X': input_grads[0]}, + outputs={'Out': input_grads[0]}, + attrs={ + 'ring_id': ring_id, + self.op_role_key: OpRole.Backward + }) + global_offset += 1 + else: + # init output_grads + output_grads = input_grads + # init fused_output with temp shape, it will calculate real shape depend on inputs + fused_output = block.create_var(name="fused_output", + shape=[1], + persistable=False, + dtype=core.VarDesc.VarType.FP32, + stop_gradient=True) + # fuse all grad tensors + coalesce_tensor_attrs = { + "copy_data": True, + "set_constant": False, + "dtype": core.VarDesc.VarType.FP32 + } + block._insert_op(global_offset, + type='coalesce_tensor', + inputs={'Input': input_grads}, + outputs={ + 'Output': 
output_grads, + 'FusedOutput': fused_output + }, + attrs=coalesce_tensor_attrs) + global_offset += 1 + # grads aggregation of multi-gpus + block._insert_op(global_offset, + type='c_sync_calc_stream', + inputs={'X': fused_output}, + outputs={'Out': fused_output}, + attrs={self.op_role_key: OpRole.Backward}) + global_offset += 1 + ring_id = (ring_id + 1) % self.nrings + block._insert_op(global_offset, + type='c_allreduce_sum', + inputs={'X': fused_output}, + outputs={'Out': fused_output}, + attrs={ + 'ring_id': ring_id, + self.op_role_key: OpRole.Backward + }) + global_offset += 1 + + # sync before adam + block._insert_op(global_offset, + type='c_sync_comm_stream', + inputs={'X': fused_output}, + outputs={'Out': fused_output}, + attrs={ + 'ring_id': ring_id, + self.op_role_key: OpRole.Backward + }) + global_offset += 1 class MultiThread(GradAllReduce):
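Note on the new fuse_allreduce == 2 path above: instead of coalescing gradients into one fused buffer and issuing a single c_allreduce_sum, the transpiler emits one c_allreduce_xsum over the whole gradient list, and that operator (added earlier in this patch) brackets its per-tensor ncclAllReduce calls with ncclGroupStart/ncclGroupEnd so they are fused at the NCCL level without first copying the tensors into one contiguous buffer. Below is a standalone sketch of that grouping, assuming an already-initialized communicator and device-resident buffers (not the operator implementation itself).

    #include <cuda_runtime.h>
    #include <nccl.h>
    #include <vector>

    struct GradBuffer {
      float* data;    // device pointer on the GPU bound to `comm`
      size_t numel;
    };

    void allreduce_xsum_like(const std::vector<GradBuffer>& grads,
                             ncclComm_t comm, cudaStream_t stream) {
      ncclGroupStart();                      // fuse the calls below into one NCCL launch
      for (const GradBuffer& g : grads) {
        // In-place sum all-reduce, mirroring Out == X in the transpiler pass.
        ncclAllReduce(g.data, g.data, g.numel, ncclFloat, ncclSum, comm, stream);
      }
      ncclGroupEnd();                        // communication is issued asynchronously on `stream`
    }

Grouping avoids the extra coalesce_tensor copy at the cost of NCCL handling many small messages; which variant is faster depends on the gradient sizes and the interconnect.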