Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix nccl sync #205

Merged
merged 128 commits into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
128 commits
Select commit Hold shift + click to select a range
09dac41
Merge pull request #39 from PaddlePaddle/develop
qingshui May 19, 2022
6468546
Merge pull request #41 from xuewujiao/gpugraph
qingshui Jun 13, 2022
4cc6ef9
Merge pull request #42 from xuewujiao/gpugraph
qingshui Jun 15, 2022
01ebde5
Merge pull request #43 from xuewujiao/gpugraph
qingshui Jun 22, 2022
fd0a0fd
Merge pull request #44 from xuewujiao/gpugraph
qingshui Jun 22, 2022
c36e14b
Merge pull request #45 from xuewujiao/gpugraph
qingshui Jun 24, 2022
a43ca88
Merge pull request #46 from xuewujiao/gpugraph
qingshui Jun 29, 2022
de1e8bf
Optimize graph loading and sample loading performance
Jul 5, 2022
5946f70
Optimize graph loading and sample loading performance
Jul 5, 2022
a5b4aaa
fix
Jul 5, 2022
f5bfadb
fix function name
Jul 5, 2022
8c50898
Merge pull request #47 from xuewujiao/gpugraph
qingshui Jul 5, 2022
5fcc277
Optimization chart CPU processing
Jul 6, 2022
57c5821
Optimization chart CPU processing
Jul 6, 2022
4a059af
pull push dedup, use pull feature value clip size
Jul 11, 2022
ee41abb
pull push dedup, use pull feature value clip size
Jul 11, 2022
0a7044e
remove same not used kernel
Jul 12, 2022
be891d3
merge master fix conflict
Jul 12, 2022
56adf43
fix merge
Jul 12, 2022
5b90815
add fix restore idx auto fit mode according to repetition ratio
Jul 12, 2022
43c8bc6
filter zero
Jul 12, 2022
fc822cf
fix load node bug
Jul 12, 2022
4bb4f18
fix
Jul 12, 2022
ec7df9a
fix mf dim bug
Jul 12, 2022
db0415f
fix mf dim bug
Jul 12, 2022
c41a460
fix mf size bug
Jul 12, 2022
7813d5f
fix mf size bug
Jul 12, 2022
3f00d46
Merge pull request #50 from xuewujiao/gpugraph
qingshui Jul 13, 2022
c67a0c7
format ps_gpu_wrapper.cc
Jul 13, 2022
2b3bcc9
merge master
Oct 11, 2022
5439c1b
add all2all
Oct 26, 2022
d227a21
add all2all
Oct 26, 2022
8d30c7d
fix trans enable
Oct 27, 2022
2d6e683
fix dualbox
Oct 31, 2022
16ecdb9
fix dualbox
Oct 31, 2022
2924ca7
fix dualbox
Oct 31, 2022
72d3534
fix dualbox
Oct 31, 2022
99550b1
fix dualbox
Oct 31, 2022
1764a2a
fix dualbox
Oct 31, 2022
ee73d21
fix alloc
Nov 1, 2022
782e413
fix alloc
Nov 1, 2022
99ed9dd
fix alloc
Nov 1, 2022
55c8232
fix comm init bug
Nov 1, 2022
5b2f614
add auto detect flags
Nov 1, 2022
d146326
add auto detect flags
Nov 1, 2022
67b12ec
fix update one table
Nov 1, 2022
5f33ecd
add log
Nov 1, 2022
e479b78
add log
Nov 1, 2022
9a0f381
fix trans
Nov 1, 2022
b71d2ff
fix trans
Nov 1, 2022
ee30e93
fix trans
Nov 1, 2022
a1cc217
debug
Nov 1, 2022
e9bece9
fix merge keys
Nov 1, 2022
de2f9ad
fix all2all push
Nov 2, 2022
ec3d008
fix barrier
Nov 2, 2022
ce75785
add debug
Nov 2, 2022
6f8ec08
fix merge grad
Nov 2, 2022
f4a09ee
fix merge grad
Nov 2, 2022
9cbddf8
fix merge grad
Nov 2, 2022
a14d17c
fix check
Nov 2, 2022
5ea075e
fix check
Nov 2, 2022
dbcf2f7
fix check
Nov 2, 2022
95fd124
fix check
Nov 2, 2022
bb2ad3f
fix check
Nov 2, 2022
ec40bbd
fix check
Nov 2, 2022
1e957bd
add check
Nov 2, 2022
12aaf5e
add check
Nov 2, 2022
20d2a9e
add check
Nov 2, 2022
e807b15
add check
Nov 2, 2022
0b1abd5
add check
Nov 2, 2022
84497b8
add debug
Nov 3, 2022
a104d52
add debug
Nov 3, 2022
64cc988
add debug info
Nov 3, 2022
5152bf5
add debug info
Nov 3, 2022
3cac362
fix check
Nov 3, 2022
d063d1a
fix check
Nov 3, 2022
ccd79e6
add debug
Nov 3, 2022
bad1e91
debug
Nov 3, 2022
c90849e
fix alloc
Nov 3, 2022
58ca1f9
add debug
Nov 4, 2022
4881fe8
fix merge grad
Nov 4, 2022
a56601a
add debug
Nov 4, 2022
485cb2f
add debug
Nov 4, 2022
1ffb67f
fix dual div show
Nov 4, 2022
36bc396
fix dual div show
Nov 4, 2022
fa630b2
fix dual div show
Nov 4, 2022
cc181d8
fix dual div show
Nov 4, 2022
0a8d6fd
fix dual div show
Nov 4, 2022
43e9ead
fix all2all nan inf
Nov 4, 2022
f1938c7
add hashtable collisions stats, train time 1183sec->942.81sec, close …
Nov 14, 2022
2b45e4d
add multi node stream flow optimization
Nov 14, 2022
513bd60
add gflags
Nov 14, 2022
ebab0bc
merge master from gpugraph branch
Nov 14, 2022
586f0cb
fix hashtable localcount add
Nov 14, 2022
4258850
add debug gflags, adjust log levels, optimize heter_comm create and r…
Nov 15, 2022
101f1c5
infer not need dump
Nov 15, 2022
97ffe49
train need check miss key
Nov 17, 2022
eadcb52
merge gpugraph
Nov 17, 2022
0ea5e16
fix allreduce no param scale, add fuse allreduce support
Nov 22, 2022
35f4749
fix allreduce no param scale, add fuse allreduce support
Nov 22, 2022
d182a95
fix conflicts
Nov 22, 2022
3f6eb14
fix conflicts
Nov 22, 2022
4e0343a
add PADDLE_LOSS_SCALE env, add PADDLE_FUSE_ALLREDUCE env
Nov 24, 2022
195cf64
add PADDLE_LOSS_SCALE env, add PADDLE_FUSE_ALLREDUCE env
Nov 24, 2022
4e95133
fix conflicts
Nov 24, 2022
0655495
Merge pull request #63 from xuewujiao/gpugraph
qingshui Nov 25, 2022
77dc33a
Merge pull request #64 from xuewujiao/gpugraph
qingshui Nov 29, 2022
d702bde
fix full hbm infer mode assert error
Nov 29, 2022
f4a2ced
add all2all cpu rpc server and pull ssd
Dec 5, 2022
0a95836
add all2all cpu rpc server and pull ssd
Dec 5, 2022
ede278a
fix CONFLICT
Dec 5, 2022
f35330d
add multi node fix gpu block
Dec 7, 2022
666e700
add multi node equal batch num
Dec 12, 2022
7eb447f
merge master
Dec 13, 2022
9cb5891
fix dualbox core, adjust ps_gpu_wrapper thread
Dec 23, 2022
764ea57
fix dual trainer
Dec 26, 2022
1ee6af3
sync gpugraph
Dec 30, 2022
8d0ab4f
fix merge
Dec 30, 2022
0e566df
remove default stream, add nccl sync stream dual allocator not block
Jan 20, 2023
6bf3100
fix dual merge sparse bug, thrust used paddle memory pool, thrust rem…
Jan 29, 2023
b2f2d93
merge gpugraph
Jan 29, 2023
c4adc5e
fix compiler
Jan 29, 2023
967b9a5
fix set constant, remove default stream, add c_allreduce_xsum op
Jan 31, 2023
1354f4a
fix nccl block, fix dual merge sparse bug, thrust used memory pool an…
qingshui Feb 1, 2023
cb4508a
fix hpi
Feb 2, 2023
d5e59b0
fix hpi
Feb 2, 2023
e370086
fix hpi (#204) (#74)
qingshui Feb 2, 2023
89f76da
fix nccl sync
Feb 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3848,6 +3848,7 @@ size_t HeterComm<KeyType, ValType, GradType, GPUAccessor>::
all_shard_part_size * sizeof(int),
cudaMemcpyHostToDevice,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
// barrier
// barrier_.wait();
my_cache.node_barrier_.Resume();
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/framework/hogwild_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ bool HogwildWorker::CheckBatchNum(int flag) {
ncclProd,
comm->comm(),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&ret, // output
&stat_ptr[2],
sizeof(float),
Expand Down
5 changes: 4 additions & 1 deletion paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
}
const size_t& memset_bytes = memset_size * sizeof(T);

#ifdef PADDLE_WITH_HIP
hipMemset(p_output, 0, memset_bytes);
#else
cudaMemsetAsync(p_output, 0, memset_bytes, ctx.stream());

#endif
if (index_size == 0) return;

int64_t slice_size = 1;
Expand Down
6 changes: 4 additions & 2 deletions paddle/phi/kernels/gpu/graph_send_recv_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,11 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
dst_count->Resize({input_size});
ctx.template Alloc<int32_t>(dst_count);
int* p_dst_count = dst_count->data<int>();

#ifdef PADDLE_WITH_HIP
hipMemset(p_dst_count, 0, input_size * sizeof(int));
#else
cudaMemsetAsync(p_dst_count, 0, input_size * sizeof(int), ctx.stream());

#endif
int64_t grid_count = (index_size + block - 1) / block;
ComputeCountCUDAKernel<T, IndexT><<<grid_count, block, 0, ctx.stream()>>>(
p_dst_count, d_index, index_size);
Expand Down
16 changes: 16 additions & 0 deletions paddle/phi/kernels/gpu/unique_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,12 @@ static void UniqueFlattendCUDATensor(const Context& context,
indices->Resize(phi::make_ddim({num_input}));
auto* indices_data = context.template Alloc<IndexT>(indices);

#ifdef PADDLE_WITH_CUDA
paddle::memory::ThrustAllocator<cudaStream_t> allocator(context.GetPlace(), context.stream());
const auto &exec_policy = thrust::cuda::par(allocator).on(context.stream());
#else
const auto &exec_policy = thrust::hip::par.on(context.stream());
#endif

thrust::sequence(exec_policy, indices_data, indices_data + num_input);
thrust::sort_by_key(
Expand Down Expand Up @@ -232,7 +236,11 @@ static void UniqueFlattendCUDATensor(const Context& context,
in_data_hat + num_input,
inv_loc_data_ptr,
not_equal);
#ifdef PADDLE_WITH_HIP
hipMemset(inv_loc_data_ptr, 0, sizeof(IndexT));
#else
cudaMemsetAsync(inv_loc_data_ptr, 0, sizeof(IndexT), context.stream());
#endif
size_t temp_storage_bytes = 0;
cub::DeviceScan::InclusiveSum(NULL,
temp_storage_bytes,
Expand Down Expand Up @@ -305,8 +313,12 @@ static void ComputeUniqueDims(const Context& context,
equal_T equal,
not_equal_T not_equal,
int64_t row) {
#ifdef PADDLE_WITH_CUDA
paddle::memory::ThrustAllocator<cudaStream_t> allocator(context.GetPlace(), context.stream());
const auto &exec_policy = thrust::cuda::par(allocator).on(context.stream());
#else
const auto &exec_policy = thrust::hip::par.on(context.stream());
#endif
// 1. inverse indices: 'inverse'
inverse->Resize(phi::make_ddim({row}));
auto* inverse_data = context.template Alloc<IndexT>(inverse);
Expand Down Expand Up @@ -401,8 +413,12 @@ static void UniqueDimsCUDATensor(const Context& context,

// 2. Calculate 'indices', 'inverse', 'counts'
// Init index and sort
#ifdef PADDLE_WITH_CUDA
paddle::memory::ThrustAllocator<cudaStream_t> allocator(context.GetPlace(), context.stream());
const auto &exec_policy = thrust::cuda::par(allocator).on(context.stream());
#else
const auto &exec_policy = thrust::hip::par.on(context.stream());
#endif
thrust::sequence(
exec_policy, sorted_indices_data, sorted_indices_data + row);
thrust::sort(exec_policy,
Expand Down