From 06e30c5ae60275327d228679033d6c31366c09e4 Mon Sep 17 00:00:00 2001
From: Tung Le Duc
Date: Wed, 8 Mar 2017 03:10:04 -0500
Subject: [PATCH 1/3] fix the problem of convergence for resnet50

---
 src/caffe/parallel.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp
index 51b00c42361..bd8b5837117 100644
--- a/src/caffe/parallel.cpp
+++ b/src/caffe/parallel.cpp
@@ -655,6 +655,7 @@ void OverlapSync<Dtype>::on_gradients_layers_ready(int l) {
     const vector<int> curr_params_vecs = solver_->net()
                                            ->learnable_params_id_vecs(l);
     if (curr_params_vecs.size() > 0){
+      CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); // wait for resetting global variables
       if (curr_params_vecs[0] == (blobs_num_ - curr_params_vecs.size())){
         int last = criticals_free_->size() - 1;

From 5ab307b07923f88fd8aa27a8ec5023f6bd2da6af Mon Sep 17 00:00:00 2001
From: Tung Le Duc
Date: Fri, 10 Mar 2017 04:31:47 -0500
Subject: [PATCH 2/3] a better fix, runs faster

---
 src/caffe/layers/scale_layer.cu | 4 ++++
 src/caffe/parallel.cpp          | 1 -
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/caffe/layers/scale_layer.cu b/src/caffe/layers/scale_layer.cu
index fc9a8064db5..94916b89c98 100644
--- a/src/caffe/layers/scale_layer.cu
+++ b/src/caffe/layers/scale_layer.cu
@@ -6,6 +6,8 @@
 
 namespace caffe {
 
+__global__ void sync_scale() { }
+
 template <typename Dtype>
 __global__ void ScaleForward(const int n, const Dtype* in,
     const Dtype* scale, const int scale_dim, const int inner_dim,
@@ -53,6 +55,7 @@ void ScaleLayer<Dtype>::Forward_gpu(
         <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, bottom_data, scale_data, scale_dim_, inner_dim_, top_data);
   }
+  sync_scale<<<1, 1>>>();
 }
 
 template <typename Dtype>
@@ -128,6 +131,7 @@ void ScaleLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, top_diff, scale_data, scale_dim_, inner_dim_, bottom_diff);
   }
+  sync_scale<<<1, 1>>>();
 }
 
 INSTANTIATE_LAYER_GPU_FUNCS(ScaleLayer);
diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp
index bd8b5837117..51b00c42361 100644
--- a/src/caffe/parallel.cpp
+++ b/src/caffe/parallel.cpp
@@ -655,7 +655,6 @@ void OverlapSync<Dtype>::on_gradients_layers_ready(int l) {
     const vector<int> curr_params_vecs = solver_->net()
                                            ->learnable_params_id_vecs(l);
     if (curr_params_vecs.size() > 0){
-      CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); // wait for resetting global variables
       if (curr_params_vecs[0] == (blobs_num_ - curr_params_vecs.size())){
         int last = criticals_free_->size() - 1;

From 8cf0da392f9bbcaca20603f458276c4590b465a4 Mon Sep 17 00:00:00 2001
From: Tung Le Duc
Date: Mon, 13 Mar 2017 23:51:49 -0400
Subject: [PATCH 3/3] do a synchronization over the default stream at the end
 of the Scale layer's backward

---
 src/caffe/layers/scale_layer.cu | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/caffe/layers/scale_layer.cu b/src/caffe/layers/scale_layer.cu
index 94916b89c98..fce15a8076c 100644
--- a/src/caffe/layers/scale_layer.cu
+++ b/src/caffe/layers/scale_layer.cu
@@ -6,8 +6,6 @@
 
 namespace caffe {
 
-__global__ void sync_scale() { }
-
 template <typename Dtype>
 __global__ void ScaleForward(const int n, const Dtype* in,
     const Dtype* scale, const int scale_dim, const int inner_dim,
@@ -55,7 +53,6 @@ void ScaleLayer<Dtype>::Forward_gpu(
         <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, bottom_data, scale_data, scale_dim_, inner_dim_, top_data);
   }
-  sync_scale<<<1, 1>>>();
 }
 
 template <typename Dtype>
@@ -131,7 +128,7 @@ void ScaleLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
         count, top_diff, scale_data, scale_dim_, inner_dim_, bottom_diff);
   }
-  sync_scale<<<1, 1>>>();
+  CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault));
 }
 
 INSTANTIATE_LAYER_GPU_FUNCS(ScaleLayer);
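
Note on the stream semantics the three patches rely on: the legacy default stream
implicitly synchronizes with every other blocking stream in the same context, so
both patch 2's empty sync_scale<<<1, 1>>>() launch and patch 3's host-side
cudaStreamSynchronize(cudaStreamDefault) act as barriers between the scale layer's
gradient kernels and work issued on other streams. The sketch below is not part of
the patches; it is a minimal standalone illustration of that ordering, in which
produce_grad, sync_kernel, and the worker stream are made-up names, and it assumes
nvcc's default stream model (no --default-stream per-thread).

// Standalone sketch (not part of the patches). produce_grad stands in for a
// gradient-writing kernel such as ScaleBackward; sync_kernel mirrors patch 2's
// sync_scale; `worker` stands in for the non-default stream the layer runs on.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void produce_grad(float* g, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) g[i] = 1.0f;  // pretend this writes a layer's gradient
}

// Empty kernel on the legacy default stream: it cannot start until prior work
// on all blocking streams has finished, and later work on those streams cannot
// start until it has finished. The host is not blocked by the launch.
__global__ void sync_kernel() { }

int main() {
  const int n = 1 << 20;
  float* g = nullptr;
  cudaStream_t worker;
  cudaMalloc(&g, n * sizeof(float));
  cudaStreamCreate(&worker);  // a blocking stream, so implicit sync applies

  produce_grad<<<(n + 255) / 256, 256, 0, worker>>>(g, n);

  // Patch 2's approach: a device-side barrier, asynchronous for the host.
  sync_kernel<<<1, 1>>>();

  // Patch 3's approach: block the host until the default stream (and, through
  // the implicit ordering above, the worker's prior work) has drained.
  cudaStreamSynchronize(cudaStreamDefault);

  float first = 0.0f;
  cudaMemcpy(&first, g, sizeof(float), cudaMemcpyDeviceToHost);
  printf("g[0] = %f\n", first);  // 1.0: the write is visible after the sync

  cudaFree(g);
  cudaStreamDestroy(worker);
  return 0;
}

The commit subjects suggest the trade-off: the empty-kernel barrier of patch 2 is
cheaper than patch 1's host-side sync inside on_gradients_layers_ready, and patch 3
then confines a single host-side sync to the end of Backward_gpu, the one point
where the ordering is apparently needed before OverlapSync reads the gradients,
dropping the forward-pass barrier.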