Fix graph hang (#42768)
* fix device_free

* fix hang
Thunderbrook authored May 18, 2022
1 parent fa8c755 commit 133d63f
Showing 4 changed files with 25 additions and 11 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/ps/table/common_graph_table.cc
@@ -1441,7 +1441,7 @@ std::vector<std::vector<int64_t>> GraphTable::get_all_id(int type_id, int idx,
   }
   for (size_t i = 0; i < tasks.size(); i++) {
     auto ids = tasks[i].get();
-    for (auto &id : ids) res[id % slice_num].push_back(id);
+    for (auto &id : ids) res[(uint64_t)(id) % slice_num].push_back(id);
   }
   return res;
 }
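In C++ the remainder of % takes the sign of the dividend, so if a node id's bit pattern reads as negative through int64_t, `id % slice_num` produces a negative bucket and `res[bucket]` indexes out of range. A minimal standalone sketch of that failure mode (illustrative only, not part of the patch; it assumes ids may carry bit patterns above INT64_MAX):

// signed_modulo_sketch.cc — why casting to uint64_t before % matters.
#include <cstdint>
#include <iostream>

int main() {
  // A 64-bit id whose top bit is set looks negative when held in an int64_t.
  int64_t id = static_cast<int64_t>(0xF000000000000001ULL);
  int slice_num = 8;

  // Signed modulo: the result is negative, so res[bucket] would be out of range.
  std::cout << "signed bucket:   " << id % slice_num << std::endl;   // -7

  // Unsigned modulo (the fix): always lands in [0, slice_num).
  std::cout << "unsigned bucket: " << (uint64_t)(id) % slice_num << std::endl;  // 1
  return 0;
}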
4 changes: 2 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -23,10 +23,10 @@
 #ifdef PADDLE_WITH_HETERPS
 namespace paddle {
 namespace framework {
-class GpuPsGraphTable : public HeterComm<int64_t, int64_t, int> {
+class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
  public:
   GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware)
-      : HeterComm<int64_t, int64_t, int>(1, resource) {
+      : HeterComm<uint64_t, int64_t, int>(1, resource) {
     load_factor_ = 0.25;
     rw_lock.reset(new pthread_rwlock_t());
     gpu_num = resource_->total_device();
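The header moves the base HeterComm to an unsigned 64-bit key, and the .cu changes below adjust each call site to pass that same type. A hypothetical, much-simplified sketch of keeping the key type consistent across the template layer (DeviceCommSketch and GraphTableSketch are illustrative names, not Paddle's API):

#include <cstddef>
#include <cstdint>

// Stand-in for the templated comm/hash layer: the KeyType chosen here fixes
// how every downstream kernel hashes and shards the keys it receives.
template <typename KeyType, typename ValType, typename GradType>
class DeviceCommSketch {
 public:
  explicit DeviceCommSketch(size_t shards) : shards_(shards) {}
  size_t shard_of(KeyType key) const {
    // With KeyType = uint64_t this index is always in [0, shards_).
    return static_cast<size_t>(key % shards_);
  }

 private:
  size_t shards_;
};

// Mirrors the direction of the change above: derive from the uint64_t-keyed
// base so template arguments, casts, and kernels all agree on the key type.
class GraphTableSketch : public DeviceCommSketch<uint64_t, int64_t, int> {
 public:
  GraphTableSketch() : DeviceCommSketch<uint64_t, int64_t, int>(8) {}
};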
23 changes: 15 additions & 8 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -499,7 +499,7 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i) {
       keys.push_back(g.node_list[j].node_id);
       offset.push_back(j);
     }
-    build_ps(i, keys.data(), offset.data(), keys.size(), 1024, 8);
+    build_ps(i, (uint64_t*)keys.data(), offset.data(), keys.size(), 1024, 8);
     gpu_graph_list[i].node_size = g.node_size;
   } else {
     build_ps(i, NULL, NULL, 0, 1024, 8);
@@ -572,7 +572,8 @@ void GpuPsGraphTable::build_graph_from_cpu(
       keys.push_back(cpu_graph_list[i].node_list[j].node_id);
       offset.push_back(j);
     }
-    build_ps(i, keys.data(), offset.data(), keys.size(), 1024, 8);
+    build_ps(i, (uint64_t*)(keys.data()), offset.data(), keys.size(), 1024,
+             8);
     gpu_graph_list[i].node_size = cpu_graph_list[i].node_size;
   } else {
     build_ps(i, NULL, NULL, 0, 1024, 8);
@@ -665,7 +666,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
   int* d_shard_actual_sample_size_ptr =
       reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
 
-  split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
+  split_input_to_shard((uint64_t*)(key), d_idx_ptr, len, d_left_ptr,
+                       d_right_ptr, gpu_id);
 
   heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len,
                                      stream);
@@ -708,7 +710,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
                    sizeof(int) * (shard_len + shard_len % 2));
     // auto& node = path_[gpu_id][i].nodes_[0];
   }
-  walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
+  walk_to_dest(gpu_id, total_gpu, h_left, h_right,
+               (uint64_t*)(d_shard_keys_ptr), NULL);
 
   for (int i = 0; i < total_gpu; ++i) {
     if (h_left[i] == -1) {
@@ -720,7 +723,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
                     node.in_stream);
     cudaStreamSynchronize(node.in_stream);
     platform::CUDADeviceGuard guard(resource_->dev_id(i));
-    tables_[i]->get(reinterpret_cast<int64_t*>(node.key_storage),
+    tables_[i]->get(reinterpret_cast<uint64_t*>(node.key_storage),
                     reinterpret_cast<int64_t*>(node.val_storage),
                     h_right[i] - h_left[i] + 1,
                     resource_->remote_stream(i, gpu_id));
@@ -805,7 +808,10 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
   auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
   int* d_shard_actual_sample_size_ptr =
       reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
-  split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
+  split_input_to_shard((uint64_t*)(key), d_idx_ptr, len, d_left_ptr,
+                       d_right_ptr, gpu_id);
   heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len,
                                      stream);
@@ -824,7 +830,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
                    shard_len * (1 + sample_size) * sizeof(int64_t) +
                    sizeof(int) * (shard_len + shard_len % 2));
   }
-  walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
+  walk_to_dest(gpu_id, total_gpu, h_left, h_right,
+               (uint64_t*)(d_shard_keys_ptr), NULL);
   for (int i = 0; i < total_gpu; ++i) {
     if (h_left[i] == -1) {
@@ -837,7 +844,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
     cudaStreamSynchronize(node.in_stream);
     platform::CUDADeviceGuard guard(resource_->dev_id(i));
     // If not found, val is -1.
-    tables_[i]->get(reinterpret_cast<int64_t*>(node.key_storage),
+    tables_[i]->get(reinterpret_cast<uint64_t*>(node.key_storage),
                     reinterpret_cast<int64_t*>(node.val_storage),
                     h_right[i] - h_left[i] + 1,
                     resource_->remote_stream(i, gpu_id));
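Each edit in this file follows the same pattern: the keys cross GPUs through untyped staging buffers, so the cast at the call site is what tells the hash table how to read them, and after the header change it has to name uint64_t to match the table's key type. A simplified sketch under assumed names (PathNode, HashTableLike, and lookup are illustrative, not Paddle's API):

#include <cstddef>
#include <cstdint>

struct PathNode {      // stand-in for path_[gpu_id][i].nodes_[0]
  char* key_storage;   // raw device buffer holding packed 64-bit node ids
  char* val_storage;   // raw device buffer receiving looked-up values
};

template <typename KeyType, typename ValType>
struct HashTableLike {  // stand-in for tables_[i]
  void get(const KeyType* d_keys, ValType* d_vals, size_t len) {
    // ... launch the lookup kernel over len keys ...
  }
};

void lookup(HashTableLike<uint64_t, int64_t>* table, PathNode& node,
            size_t len) {
  // The cast must name the same KeyType the table was instantiated with so
  // the insert and lookup paths interpret the stored bits identically.
  table->get(reinterpret_cast<uint64_t*>(node.key_storage),
             reinterpret_cast<int64_t*>(node.val_storage), len);
}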
7 changes: 7 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
@@ -320,6 +320,8 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
 
 template class HashTable<unsigned long, paddle::framework::FeatureValue>;
 template class HashTable<long, int>;
+template class HashTable<unsigned long, int>;
+template class HashTable<unsigned long, unsigned long>;
 template class HashTable<long, long>;
 template class HashTable<long, unsigned long>;
 template class HashTable<long, unsigned int>;
@@ -333,6 +335,8 @@ template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                       int* d_vals, size_t len,
                                                       cudaStream_t stream);
 
+template void HashTable<unsigned long, int>::get<cudaStream_t>(
+    const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
 template void HashTable<long, unsigned long>::get<cudaStream_t>(
     const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
 template void HashTable<long, long>::get<cudaStream_t>(const long* d_keys,
@@ -359,6 +363,9 @@ template void HashTable<long, long>::insert<cudaStream_t>(const long* d_keys,
                                                           size_t len,
                                                           cudaStream_t stream);
 
+template void HashTable<unsigned long, int>::insert<cudaStream_t>(
+    const unsigned long* d_keys, const int* d_vals, size_t len,
+    cudaStream_t stream);
 template void HashTable<long, unsigned long>::insert<cudaStream_t>(
     const long* d_keys, const unsigned long* d_vals, size_t len,
     cudaStream_t stream);
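The additions in hashtable_kernel.cu are explicit template instantiations: the usual reason for this pattern is that the HashTable member definitions live only in this .cu translation unit, so every key/value combination used by other objects must be instantiated here or the link step fails with undefined symbols; the commit adds the unsigned-long-keyed variants the GPU graph table now uses. A simplified sketch of the pattern (HashTableSketch and its member signatures are assumptions for illustration):

#include <cstddef>
#include <cuda_runtime.h>

// Definitions live only in this .cu file, mirroring hashtable_kernel.cu.
template <typename KeyType, typename ValType>
class HashTableSketch {
 public:
  template <typename StreamType>
  void get(const KeyType* d_keys, ValType* d_vals, size_t len,
           StreamType stream) { /* launch lookup kernel */ }

  template <typename StreamType>
  void insert(const KeyType* d_keys, const ValType* d_vals, size_t len,
              StreamType stream) { /* launch insert kernel */ }
};

// Explicit instantiations: the class instantiation covers ordinary members,
// and the member templates (get/insert) need their own instantiations so
// other translation units can link against them.
template class HashTableSketch<unsigned long, int>;
template void HashTableSketch<unsigned long, int>::get<cudaStream_t>(
    const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
template void HashTableSketch<unsigned long, int>::insert<cudaStream_t>(
    const unsigned long* d_keys, const int* d_vals, size_t len,
    cudaStream_t stream);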
