From ff3882473beacbf335b858210719c857374e3200 Mon Sep 17 00:00:00 2001
From: Andrei Pokrovsky
Date: Mon, 23 Nov 2015 19:08:02 -0800
Subject: [PATCH] Add cudnn v4 batch normalization integration

---
 include/caffe/common_layers.hpp             |  29 ++++++
 src/caffe/layer_factory.cpp                 |  23 +++++
 src/caffe/layers/batch_norm_layer.cpp       |   1 -
 src/caffe/layers/cudnn_batch_norm_layer.cpp |  97 ++++++++++++++++++
 src/caffe/layers/cudnn_batch_norm_layer.cu  | 106 ++++++++++++++++++++
 src/caffe/proto/caffe.proto                 |   8 ++
 src/caffe/test/test_batch_norm_layer.cpp    |  92 +++++++++++++++++
 7 files changed, 355 insertions(+), 1 deletion(-)
 create mode 100644 src/caffe/layers/cudnn_batch_norm_layer.cpp
 create mode 100644 src/caffe/layers/cudnn_batch_norm_layer.cu

diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp
index 3277f1b6c2c..385615c71a5 100644
--- a/include/caffe/common_layers.hpp
+++ b/include/caffe/common_layers.hpp
@@ -145,6 +145,35 @@ class BatchNormLayer : public Layer<Dtype> {
   Blob<Dtype> spatial_sum_multiplier_;
 };
 
+#ifdef USE_CUDNN
+template <typename Dtype>
+class CuDNNBatchNormLayer : public BatchNormLayer<Dtype> {
+ public:
+  explicit CuDNNBatchNormLayer(const LayerParameter& param)
+      : BatchNormLayer<Dtype>(param), epsilon_(1e-4), handles_setup_(false) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~CuDNNBatchNormLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  // cuDNN descriptors / handles
+  cudnnTensorDescriptor_t bottom_desc_, top_desc_;
+  cudnnTensorDescriptor_t scale_bias_mean_var_desc_;
+  cudnnBatchNormMode_t mode_;
+
+  double epsilon_;
+  Blob<Dtype> save_mean_, save_inv_var_;
+  bool handles_setup_;
+};
+#endif
+
 /**
  * @brief Takes at least two Blob%s and concatenates them along either the num
  *        or channel dimension, outputting the result.
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 989c348c1bb..0d1625602db 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -40,6 +40,29 @@ shared_ptr<Layer<Dtype> > GetConvolutionLayer(
 
 REGISTER_LAYER_CREATOR(Convolution, GetConvolutionLayer);
 
+// Get BN layer according to engine.
+template <typename Dtype>
+shared_ptr<Layer<Dtype> > GetBatchNormLayer(const LayerParameter& param) {
+  BatchNormParameter_Engine engine = param.batch_norm_param().engine();
+  if (engine == BatchNormParameter_Engine_DEFAULT) {
+    engine = BatchNormParameter_Engine_CAFFE;
+#ifdef USE_CUDNN
+    engine = BatchNormParameter_Engine_CUDNN;
+#endif
+  }
+  if (engine == BatchNormParameter_Engine_CAFFE) {
+    return shared_ptr<Layer<Dtype> >(new BatchNormLayer<Dtype>(param));
+#ifdef USE_CUDNN
+  } else if (engine == BatchNormParameter_Engine_CUDNN) {
+    return shared_ptr<Layer<Dtype> >(new CuDNNBatchNormLayer<Dtype>(param));
+#endif
+  } else {
+    LOG(FATAL) << "Layer " << param.name() << " has unknown engine.";
+  }
+}
+
+REGISTER_LAYER_CREATOR(BatchNorm, GetBatchNormLayer);
+
 // Get pooling layer according to engine.
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
diff --git a/src/caffe/layers/batch_norm_layer.cpp b/src/caffe/layers/batch_norm_layer.cpp
index b5c91b5e1b3..64e3f4d77da 100644
--- a/src/caffe/layers/batch_norm_layer.cpp
+++ b/src/caffe/layers/batch_norm_layer.cpp
@@ -235,5 +235,4 @@ STUB_GPU(BatchNormLayer);
 #endif
 
 INSTANTIATE_CLASS(BatchNormLayer);
-REGISTER_LAYER_CLASS(BatchNorm);
 }  // namespace caffe
diff --git a/src/caffe/layers/cudnn_batch_norm_layer.cpp b/src/caffe/layers/cudnn_batch_norm_layer.cpp
new file mode 100644
index 00000000000..eec324e25d1
--- /dev/null
+++ b/src/caffe/layers/cudnn_batch_norm_layer.cpp
@@ -0,0 +1,97 @@
+#ifdef USE_CUDNN
+
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/im2col.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNBatchNormLayer<Dtype>::LayerSetUp(
+    const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  BatchNormLayer<Dtype>::LayerSetUp(bottom, top);
+
+  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
+  cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+  cudnn::createTensor4dDesc<Dtype>(&scale_bias_mean_var_desc_);
+
+  // currently only SPATIAL mode is supported (most commonly used mode)
+  // If there's enough demand we can implement CUDNN_BATCHNORM_PER_ACTIVATION
+  // though it's not currently implemented for the CPU layer
+  mode_ = CUDNN_BATCHNORM_SPATIAL;
+
+  if (this->blobs_.size() > 5) {
+    LOG(INFO) << "Skipping parameter initialization";
+  } else {
+    this->blobs_.resize(5);
+    this->blobs_[0].reset(new Blob<Dtype>(1, bottom[0]->channels(), 1, 1));
+    this->blobs_[1].reset(new Blob<Dtype>(1, bottom[0]->channels(), 1, 1));
+    this->blobs_[2].reset(new Blob<Dtype>(1, 1, 1, 1));
+    this->blobs_[3].reset(new Blob<Dtype>(1, bottom[0]->channels(), 1, 1));
+    this->blobs_[4].reset(new Blob<Dtype>(1, bottom[0]->channels(), 1, 1));
+
+    shared_ptr<Filler<Dtype> > scale_filler(
+        GetFiller<Dtype>(this->layer_param_.batch_norm_param().scale_filler()));
+    scale_filler->Fill(this->blobs_[0].get());
+
+    shared_ptr<Filler<Dtype> > bias_filler(
+        GetFiller<Dtype>(this->layer_param_.batch_norm_param().bias_filler()));
+    bias_filler->Fill(this->blobs_[1].get());
+
+    for (int i = 2; i < 5; i++) {
+      caffe_set(this->blobs_[i]->count(), Dtype(0),
+          this->blobs_[i]->mutable_cpu_data());
+    }
+  }
+
+  handles_setup_ = true;
+}
+
+template <typename Dtype>
+void CuDNNBatchNormLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  BatchNormLayer<Dtype>::Reshape(bottom, top);
+
+  // set up main tensors
+  cudnn::setTensor4dDesc<Dtype>(
+      &bottom_desc_, bottom[0]->num(),
+      bottom[0]->channels(), bottom[0]->height(), bottom[0]->width());
+  cudnn::setTensor4dDesc<Dtype>(
+      &top_desc_, bottom[0]->num(),
+      bottom[0]->channels(), bottom[0]->height(), bottom[0]->width());
+
+  // aux tensors for caching mean & invVar from fwd to bwd pass
+  int C = bottom[0]->channels();
+  int H = bottom[0]->height();
+  int W = bottom[0]->width();
+  if (mode_ == CUDNN_BATCHNORM_SPATIAL) {
+    save_mean_.Reshape(1, C, 1, 1);
+    save_inv_var_.Reshape(1, C, 1, 1);
+  } else if (mode_ == CUDNN_BATCHNORM_PER_ACTIVATION) {
+    save_mean_.Reshape(1, C, H, W);
+    save_inv_var_.Reshape(1, C, H, W);
+  } else {
+    LOG(FATAL) << "Unknown cudnnBatchNormMode_t";
+  }
+  CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(scale_bias_mean_var_desc_,
+      bottom_desc_, mode_));
+}
+
+template <typename Dtype>
+CuDNNBatchNormLayer<Dtype>::~CuDNNBatchNormLayer() {
+  if (!handles_setup_) return;
+
+  cudnnDestroyTensorDescriptor(bottom_desc_);
+  cudnnDestroyTensorDescriptor(top_desc_);
+  cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_);
+}
+
+INSTANTIATE_CLASS(CuDNNBatchNormLayer);
+}  // namespace caffe
+#endif
diff --git a/src/caffe/layers/cudnn_batch_norm_layer.cu b/src/caffe/layers/cudnn_batch_norm_layer.cu
new file mode 100644
index 00000000000..bc1ee25154d
--- /dev/null
+++ b/src/caffe/layers/cudnn_batch_norm_layer.cu
@@ -0,0 +1,106 @@
+#ifdef USE_CUDNN
+#include <algorithm>
+#include <cfloat>
+#include <vector>
+
+#include "thrust/device_vector.h"
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNBatchNormLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  const Dtype* scale_data = this->blobs_[0]->gpu_data();
+  const Dtype* bias_data = this->blobs_[1]->gpu_data();
+
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  Dtype* save_mean = save_mean_.mutable_gpu_data();
+  Dtype* save_inv_var = save_inv_var_.mutable_gpu_data();
+
+  if (this->phase_ == TRAIN) {
+    // Call Batch normalization forward
+    CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
+        Caffe::cudnn_handle(),
+        mode_,
+        cudnn::dataType<Dtype>::one,
+        cudnn::dataType<Dtype>::zero,
+        bottom_desc_,
+        bottom_data,
+        bottom_desc_,
+        top_data,
+        scale_bias_mean_var_desc_,
+        scale_data,
+        bias_data,
+        1 - this->moving_average_fraction_,
+        this->blobs_[3]->mutable_gpu_data(),  // mean
+        this->blobs_[4]->mutable_gpu_data(),  // variance
+        epsilon_,
+        save_mean,
+        save_inv_var));
+  } else if (this->phase_ == TEST) {
+    CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
+        Caffe::cudnn_handle(),
+        mode_,
+        cudnn::dataType<Dtype>::one,
+        cudnn::dataType<Dtype>::zero,
+        bottom_desc_,
+        bottom_data,
+        bottom_desc_,
+        top_data,
+        scale_bias_mean_var_desc_,
+        scale_data,
+        bias_data,
+        this->blobs_[3]->gpu_data(),  // mean
+        this->blobs_[4]->gpu_data(),  // variance
+        epsilon_));
+  } else {
+    LOG(FATAL) << "Unknown phase";
+  }
+}
+
+template <typename Dtype>
+void CuDNNBatchNormLayer<Dtype>::Backward_gpu(
+    const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  const Dtype* save_mean = save_mean_.gpu_data();
+  const Dtype* save_inv_var = save_inv_var_.gpu_data();
+
+  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+  const Dtype* scale_data = this->blobs_[0]->gpu_data();
+  Dtype* scale_diff = this->blobs_[0]->mutable_gpu_diff();
+  Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
+
+  // call Batch Normalization Backward
+  CUDNN_CHECK(cudnnBatchNormalizationBackward(
+      Caffe::cudnn_handle(),
+      mode_,
+      cudnn::dataType<Dtype>::one,
+      cudnn::dataType<Dtype>::zero,
+      bottom_desc_,
+      bottom_data,
+      bottom_desc_,
+      top_diff,
+      bottom_desc_,
+      bottom_diff,
+      scale_bias_mean_var_desc_,
+      scale_data,
+      scale_diff,
+      bias_diff,
+      this->epsilon_,
+      save_mean,
+      save_inv_var));
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(CuDNNBatchNormLayer);
+
+}  // namespace caffe
+#endif
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index b81f39f90fa..98a10cbdae4 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -472,6 +472,14 @@ message BatchNormParameter {
   // Small value to add to the variance estimate so that we don't divide by
   // zero.
   optional float eps = 3 [default = 1e-5];
+  optional FillerParameter scale_filler = 5;
+  optional FillerParameter bias_filler = 6;
+  enum Engine {
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 15 [default = DEFAULT];
 }
 
 message ContrastiveLossParameter {
diff --git a/src/caffe/test/test_batch_norm_layer.cpp b/src/caffe/test/test_batch_norm_layer.cpp
index 22b9667f31b..59bf9b6ed2a 100644
--- a/src/caffe/test/test_batch_norm_layer.cpp
+++ b/src/caffe/test/test_batch_norm_layer.cpp
@@ -130,4 +130,96 @@ namespace caffe {
       this->blob_top_vec_);
 }
 
+#ifdef USE_CUDNN
+template <typename TypeParam>
+class CuDNNBatchNormLayerTest : public GPUDeviceTest<TypeParam> {
+ protected:
+  CuDNNBatchNormLayerTest()
+      : blob_bottom_(new Blob<TypeParam>(2, 3, 4, 5)),
+        blob_top_(new Blob<TypeParam>()) {
+    // fill the values
+    FillerParameter filler_param;
+    filler_param.set_mean(-10);
+    filler_param.set_std(5);
+    GaussianFiller<TypeParam> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~CuDNNBatchNormLayerTest() { delete blob_bottom_; delete blob_top_; }
+  void checkMeanVar(const Blob<TypeParam>* blob_bottom, int num,
+      int channels, int height, int width);
+  Blob<TypeParam>* const blob_bottom_;
+  Blob<TypeParam>* const blob_top_;
+  vector<Blob<TypeParam>*> blob_bottom_vec_;
+  vector<Blob<TypeParam>*> blob_top_vec_;
+};
+
+template <typename TypeParam>
+void CuDNNBatchNormLayerTest<TypeParam>::checkMeanVar(
+    const Blob<TypeParam>* top,
+    int num, int channels, int height, int width) {
+  typedef TypeParam Dtype;
+
+  for (int j = 0; j < channels; ++j) {
+    Dtype mean = 0, var = 0;
+    for (int i = 0; i < num; ++i) {
+      for (int k = 0; k < height; ++k) {
+        for (int l = 0; l < width; ++l) {
+          Dtype data = top->data_at(i, j, k, l);
+          mean += data;
+          var += data * data;
+        }
+      }
+    }
+    mean /= num * height * width;
+    var /= num * height * width;
+
+    const Dtype kErrorBound = 0.001;
+    EXPECT_NEAR(0, mean, kErrorBound);
+    EXPECT_NEAR(1, var, kErrorBound);
+  }
+}
+
+TYPED_TEST_CASE(CuDNNBatchNormLayerTest, TestDtypes);
+
+TYPED_TEST(CuDNNBatchNormLayerTest, TestForward) {
+  Caffe::set_random_seed(1701);
+  typedef TypeParam Dtype;
+  LayerParameter layer_param;
+  BatchNormParameter* bn_param = layer_param.mutable_batch_norm_param();
+  FillerParameter* scale_param = bn_param->mutable_scale_filler();
+  scale_param->set_value(1);
+
+  CuDNNBatchNormLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Reshape(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  // Test mean
+  Dtype mean, var;
+  int num = this->blob_bottom_->num();
+  int channels = this->blob_bottom_->channels();
+  int height = this->blob_bottom_->height();
+  int width = this->blob_bottom_->width();
+
+  this->checkMeanVar(this->blob_top_, num, channels, height, width);
+}
+
+TYPED_TEST(CuDNNBatchNormLayerTest, TestGradient) {
+  typedef TypeParam Dtype;
+  LayerParameter layer_param;
+  BatchNormParameter* bn_param = layer_param.mutable_batch_norm_param();
+  FillerParameter* scale_param = bn_param->mutable_scale_filler();
+  scale_param->set_value(1);
+  FillerParameter* bias_param = bn_param->mutable_bias_filler();
+  bias_param->set_value(0);
+
+  CuDNNBatchNormLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 4e-4);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+#endif
+
 }  // namespace caffe
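
Usage note (not part of the patch): with the factory change above, a plain "BatchNorm" layer resolves to the cuDNN implementation automatically when Caffe is built with USE_CUDNN, since engine DEFAULT falls through to CUDNN in GetBatchNormLayer. Below is a minimal prototxt sketch exercising the new BatchNormParameter fields; the layer and blob names ("bn1", "conv1") are hypothetical, not taken from the patch.

layer {
  name: "bn1"        # hypothetical layer name
  type: "BatchNorm"
  bottom: "conv1"    # hypothetical input blob
  top: "bn1"
  batch_norm_param {
    # DEFAULT auto-selects CUDNN when built with USE_CUDNN and CAFFE
    # otherwise; requesting CUDNN explicitly aborts on non-cuDNN builds.
    engine: CUDNN
    scale_filler { type: "constant" value: 1 }  # scale (gamma) starts at 1
    bias_filler { type: "constant" value: 0 }   # bias (beta) starts at 0
  }
}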