Skip to content

Commit

Permalink
fix nccl block, fix dual merge sparse bug, thrust used memory pool an…
Browse files Browse the repository at this point in the history
…d remote default stream used, remove sample used default stream (PaddlePaddle#202)
  • Loading branch information
qingshui authored Feb 1, 2023
1 parent cf23031 commit f91ae8d
Show file tree
Hide file tree
Showing 28 changed files with 1,293 additions and 800 deletions.
31 changes: 19 additions & 12 deletions paddle/fluid/framework/data_feed.cu
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2055,6 +2055,7 @@ void GraphDataGenerator::DoWalkandSage() {
}

void GraphDataGenerator::clear_gpu_mem() {
platform::CUDADeviceGuard guard(gpuid_);
d_len_per_row_.reset();
d_sample_keys_.reset();
d_prefix_sum_.reset();
Expand Down Expand Up @@ -2356,8 +2357,9 @@ int GraphDataGenerator::FillWalkBuf() {
buf_state_.Reset(total_row_);
int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr());

paddle::memory::ThrustAllocator<cudaStream_t> allocator(place_, sample_stream_);
thrust::random::default_random_engine engine(shuffle_seed_);
const auto &exec_policy = thrust::cuda::par.on(sample_stream_);
const auto &exec_policy = thrust::cuda::par(allocator).on(sample_stream_);
thrust::counting_iterator<int> cnt_iter(0);
thrust::shuffle_copy(exec_policy,
cnt_iter,
Expand Down Expand Up @@ -2591,8 +2593,9 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
buf_state_.Reset(total_row_);
int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr());

paddle::memory::ThrustAllocator<cudaStream_t> allocator(place_, sample_stream_);
thrust::random::default_random_engine engine(shuffle_seed_);
const auto &exec_policy = thrust::cuda::par.on(sample_stream_);
const auto &exec_policy = thrust::cuda::par(allocator).on(sample_stream_);
thrust::counting_iterator<int> cnt_iter(0);
thrust::shuffle_copy(exec_policy,
cnt_iter,
Expand Down Expand Up @@ -2647,22 +2650,22 @@ void GraphDataGenerator::AllocResource(int thread_id,
debug_gpu_memory_info(gpuid_, "AllocResource start");

platform::CUDADeviceGuard guard(gpuid_);
sample_stream_ = gpu_graph_ptr->get_local_stream(gpuid_);
train_stream_ = dynamic_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
if (gpu_graph_training_) {
table_ = new HashTable<uint64_t, uint64_t>(
train_table_cap_ / FLAGS_gpugraph_hbm_table_load_factor);
train_table_cap_ / FLAGS_gpugraph_hbm_table_load_factor, sample_stream_);
} else {
table_ = new HashTable<uint64_t, uint64_t>(
infer_table_cap_ / FLAGS_gpugraph_hbm_table_load_factor);
infer_table_cap_ / FLAGS_gpugraph_hbm_table_load_factor, sample_stream_);
}
}
VLOG(1) << "AllocResource gpuid " << gpuid_
<< " feed_vec.size: " << feed_vec.size()
<< " table cap: " << train_table_cap_;
sample_stream_ = gpu_graph_ptr->get_local_stream(gpuid_);
train_stream_ = dynamic_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
// feed_vec_ = feed_vec;
if (!sage_mode_) {
slot_num_ = (feed_vec.size() - 3) / 2;
Expand Down Expand Up @@ -2782,8 +2785,10 @@ void GraphDataGenerator::AllocResource(int thread_id,
ins_buf_pair_len_ = 0;
if (!sage_mode_) {
d_ins_buf_ =
memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(uint64_t));
d_pair_num_ = memory::AllocShared(place_, sizeof(int));
memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(uint64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
d_pair_num_ = memory::AllocShared(place_, sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
} else {
d_ins_buf_ = memory::AllocShared(
place_,
Expand All @@ -2796,9 +2801,11 @@ void GraphDataGenerator::AllocResource(int thread_id,
}

d_slot_tensor_ptr_ =
memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *));
memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
d_slot_lod_tensor_ptr_ =
memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *));
memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));

if (sage_mode_) {
reindex_table_size_ = batch_size_ * 2;
Expand Down
65 changes: 48 additions & 17 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/timer.h"

#include "paddle/fluid/framework/threadpool.h"
#ifdef PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
Expand Down Expand Up @@ -447,18 +447,31 @@ void MultiSlotDataset::PrepareTrain() {
#endif
return;
}

inline std::vector<std::shared_ptr<paddle::framework::ThreadPool>>&
GetReadThreadPool(int thread_num) {
static std::vector<std::shared_ptr<paddle::framework::ThreadPool>>
thread_pools;
if (!thread_pools.empty()) {
return thread_pools;
}
thread_pools.resize(thread_num);
for (int i = 0; i < thread_num; ++i) {
thread_pools[i].reset(new paddle::framework::ThreadPool(1));
}
return thread_pools;
}
// load data into memory, Dataset hold this memory,
// which will later be fed into readers' channel
template <typename T>
void DatasetImpl<T>::LoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
platform::Timer timeline;
timeline.Start();
std::vector<std::thread> load_threads;
if (gpu_graph_mode_) {
VLOG(1) << "in gpu_graph_mode";
#ifdef PADDLE_WITH_HETERPS
std::vector<std::future<void>> wait_futures;
auto pool = GetReadThreadPool(thread_num_);
for (size_t i = 0; i < readers_.size(); i++) {
readers_[i]->SetGpuGraphMode(gpu_graph_mode_);
}
Expand All @@ -473,24 +486,41 @@ void DatasetImpl<T>::LoadIntoMemory() {
}

for (int64_t i = 0; i < thread_num_; ++i) {
load_threads.push_back(std::thread(
&paddle::framework::DataFeed::DoWalkandSage, readers_[i].get()));
wait_futures.emplace_back(
pool[i]->Run([this, i]() { readers_[i]->DoWalkandSage(); }));
}
for (std::thread& t : load_threads) {
t.join();
for (auto& th : wait_futures) {
th.get();
}
wait_futures.clear();

uint64_t node_num = 0;
std::vector<uint64_t> offsets;
offsets.resize(thread_num_);

for (int i = 0; i < thread_num_; i++) {
auto& host_vec = readers_[i]->GetHostVec();
offsets[i] = node_num;
node_num += host_vec.size();
}
gpu_graph_total_keys_.reserve(node_num);
gpu_graph_total_keys_.resize(node_num);
for (int i = 0; i < thread_num_; i++) {
auto& host_vec = readers_[i]->GetHostVec();
for (size_t j = 0; j < host_vec.size(); j++) {
gpu_graph_total_keys_.push_back(host_vec[j]);
}
uint64_t off = offsets[i];
wait_futures.emplace_back(
pool[i]->Run([this, i, off]() {
auto& host_vec = readers_[i]->GetHostVec();
for (size_t j = 0; j < host_vec.size(); j++) {
gpu_graph_total_keys_[off + j] = host_vec[j];
}
if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
readers_[i]->clear_gpu_mem();
}
}));
}
for (auto& th : wait_futures) {
th.get();
}
wait_futures.clear();

if (GetEpochFinish() == true) {
VLOG(0) << "epoch finish, set stat and clear sample stat!";
Expand All @@ -499,16 +529,17 @@ void DatasetImpl<T>::LoadIntoMemory() {
readers_[i]->ClearSampleState();
}
}
if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
for (size_t i = 0; i < readers_.size(); i++) {
readers_[i]->clear_gpu_mem();
}
}
// if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
// for (size_t i = 0; i < readers_.size(); i++) {
// readers_[i]->clear_gpu_mem();
// }
// }

VLOG(2) << "end add edge into gpu_graph_total_keys_ size["
<< gpu_graph_total_keys_.size() << "]";
#endif
} else {
std::vector<std::thread> load_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
load_threads.push_back(std::thread(
&paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ class concurrent_unordered_map : public managed {
public:
concurrent_unordered_map(const concurrent_unordered_map&) = delete;
concurrent_unordered_map& operator=(const concurrent_unordered_map&) = delete;
explicit concurrent_unordered_map(size_type n,
explicit concurrent_unordered_map(cudaStream_t stream,
size_type n,
const mapped_type unused_element,
const Hasher& hf = hasher(),
const Equality& eql = key_equal(),
Expand Down Expand Up @@ -388,13 +389,13 @@ class concurrent_unordered_map : public managed {
int dev_id = 0;
CUDA_RT_CALL(cudaGetDevice(&dev_id));
CUDA_RT_CALL(cudaMemPrefetchAsync(
m_hashtbl_values, m_hashtbl_size * sizeof(value_type), dev_id, 0));
m_hashtbl_values, m_hashtbl_size * sizeof(value_type), dev_id, stream));
}
}
// Initialize kernel, set all entry to unused <K,V>
init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size>>>(
init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size, 0, stream>>>(
m_hashtbl_values, m_hashtbl_size, unused_key, m_unused_element);
CUDA_RT_CALL(cudaStreamSynchronize(0));
CUDA_RT_CALL(cudaStreamSynchronize(stream));
CUDA_RT_CALL(cudaGetLastError());
m_enable_collision_stat = FLAGS_gpugraph_enable_hbm_table_collision_stat;
}
Expand Down
9 changes: 8 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,20 @@ struct NeighborSampleResultV2 {
struct NodeQueryResult {
uint64_t *val;
int actual_sample_size;
cudaStream_t stream = 0;
void set_stream(cudaStream_t stream_t) { stream = stream_t; }
uint64_t get_val() { return (uint64_t)val; }
int get_len() { return actual_sample_size; }
std::shared_ptr<memory::Allocation> val_mem;
void initialize(int query_size, int dev_id) {
platform::CUDADeviceGuard guard(dev_id);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
val_mem = memory::AllocShared(place, query_size * sizeof(uint64_t));
if (stream != 0) {
val_mem = memory::AllocShared(place, query_size * sizeof(uint64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
} else {
val_mem = memory::AllocShared(place, query_size * sizeof(uint64_t));
}
val = (uint64_t *)val_mem->ptr();
actual_sample_size = 0;
}
Expand Down
Loading

0 comments on commit f91ae8d

Please sign in to comment.