Merge pull request BVLC#80 from andrei-pokrovsky/caffe-0.14
Add cudnn v4 batch normalization integration
thatguymike committed Nov 24, 2015
2 parents 1206d56 + ff38824 commit e9f8357
Showing 7 changed files with 355 additions and 1 deletion.
29 changes: 29 additions & 0 deletions include/caffe/common_layers.hpp
@@ -145,6 +145,35 @@ class BatchNormLayer : public Layer<Dtype> {
  Blob<Dtype> spatial_sum_multiplier_;
};

#ifdef USE_CUDNN
template <typename Dtype>
class CuDNNBatchNormLayer : public BatchNormLayer<Dtype> {
 public:
  explicit CuDNNBatchNormLayer(const LayerParameter& param)
      : BatchNormLayer<Dtype>(param), epsilon_(1e-4), handles_setup_(false) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual ~CuDNNBatchNormLayer();

 protected:
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

  // cuDNN descriptors / handles
  cudnnTensorDescriptor_t bottom_desc_, top_desc_;
  cudnnTensorDescriptor_t scale_bias_mean_var_desc_;
  cudnnBatchNormMode_t mode_;

  double epsilon_;
  Blob<Dtype> save_mean_, save_inv_var_;
  bool handles_setup_;
};
#endif

/**
* @brief Takes at least two Blob%s and concatenates them along either the num
* or channel dimension, outputting the result.
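For orientation, the transform the new layer delegates to cuDNN is the standard batch-norm mapping, with save_mean_ and save_inv_var_ caching the per-batch mean and 1/sqrt(variance + epsilon_) between the forward and backward passes. A minimal CPU-side sketch of the per-element computation (illustrative only, not part of this change; names are hypothetical):

// Illustrative reference for one element of one channel; cuDNN performs the
// equivalent computation on the GPU across the whole blob.
inline float bn_transform(float x, float mean, float inv_std,
                          float scale, float bias) {
  // inv_std plays the role of the cached save_inv_var_ value: 1 / sqrt(var + eps)
  return scale * (x - mean) * inv_std + bias;
}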
23 changes: 23 additions & 0 deletions src/caffe/layer_factory.cpp
@@ -40,6 +40,29 @@ shared_ptr<Layer<Dtype> > GetConvolutionLayer(

REGISTER_LAYER_CREATOR(Convolution, GetConvolutionLayer);

// Get BN layer according to engine.
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetBatchNormLayer(const LayerParameter& param) {
  BatchNormParameter_Engine engine = param.batch_norm_param().engine();
  if (engine == BatchNormParameter_Engine_DEFAULT) {
    engine = BatchNormParameter_Engine_CAFFE;
#ifdef USE_CUDNN
    engine = BatchNormParameter_Engine_CUDNN;
#endif
  }
  if (engine == BatchNormParameter_Engine_CAFFE) {
    return shared_ptr<Layer<Dtype> >(new BatchNormLayer<Dtype>(param));
#ifdef USE_CUDNN
  } else if (engine == BatchNormParameter_Engine_CUDNN) {
    return shared_ptr<Layer<Dtype> >(new CuDNNBatchNormLayer<Dtype>(param));
#endif
  } else {
    LOG(FATAL) << "Layer " << param.name() << " has unknown engine.";
  }
}

REGISTER_LAYER_CREATOR(BatchNorm, GetBatchNormLayer);

// Get pooling layer according to engine.
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
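Because the factory resolves DEFAULT to CUDNN whenever the build defines USE_CUDNN, existing nets pick up the cuDNN layer without prototxt changes, and setting the engine explicitly forces one implementation. A small usage sketch, assuming the standard LayerRegistry factory interface (hypothetical layer name, mirroring the pattern used in the tests below):

LayerParameter param;
param.set_name("bn1");
param.set_type("BatchNorm");
// Request the cuDNN implementation explicitly; leaving the engine at DEFAULT
// selects it automatically when Caffe is built with USE_CUDNN.
param.mutable_batch_norm_param()->set_engine(BatchNormParameter_Engine_CUDNN);
shared_ptr<Layer<float> > layer = LayerRegistry<float>::CreateLayer(param);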
1 change: 0 additions & 1 deletion src/caffe/layers/batch_norm_layer.cpp
@@ -235,5 +235,4 @@ STUB_GPU(BatchNormLayer);
#endif

INSTANTIATE_CLASS(BatchNormLayer);
REGISTER_LAYER_CLASS(BatchNorm);
} // namespace caffe
97 changes: 97 additions & 0 deletions src/caffe/layers/cudnn_batch_norm_layer.cpp
@@ -0,0 +1,97 @@
#ifdef USE_CUDNN

#include <vector>

#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/im2col.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {

template <typename Dtype>
void CuDNNBatchNormLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  BatchNormLayer<Dtype>::LayerSetUp(bottom, top);

  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
  cudnn::createTensor4dDesc<Dtype>(&top_desc_);
  cudnn::createTensor4dDesc<Dtype>(&scale_bias_mean_var_desc_);

  // currently only SPATIAL mode is supported (most commonly used mode)
  // If there's enough demand we can implement CUDNN_BATCHNORM_PER_ACTIVATION
  // though it's not currently implemented for the CPU layer
  mode_ = CUDNN_BATCHNORM_SPATIAL;

  if (this->blobs_.size() > 5) {
    LOG(INFO) << "Skipping parameter initialization";
  } else {
    this->blobs_.resize(5);
    this->blobs_[0].reset(new Blob<Dtype>(1, bottom[0]->channels(), 1, 1));
    this->blobs_[1].reset(new Blob<Dtype>(1, bottom[0]->channels(), 1, 1));
    this->blobs_[2].reset(new Blob<Dtype>(1, 1, 1, 1));
    this->blobs_[3].reset(new Blob<Dtype>(1, bottom[0]->channels(), 1, 1));
    this->blobs_[4].reset(new Blob<Dtype>(1, bottom[0]->channels(), 1, 1));

    shared_ptr<Filler<Dtype> > scale_filler(
        GetFiller<Dtype>(this->layer_param_.batch_norm_param().scale_filler()));
    scale_filler->Fill(this->blobs_[0].get());

    shared_ptr<Filler<Dtype> > bias_filler(
        GetFiller<Dtype>(this->layer_param_.batch_norm_param().bias_filler()));
    bias_filler->Fill(this->blobs_[1].get());

    for (int i = 2; i < 5; i++) {
      caffe_set(this->blobs_[i]->count(), Dtype(0),
                this->blobs_[i]->mutable_cpu_data());
    }
  }

  handles_setup_ = true;
}

template <typename Dtype>
void CuDNNBatchNormLayer<Dtype>::Reshape(
    const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  BatchNormLayer<Dtype>::Reshape(bottom, top);

  // set up main tensors
  cudnn::setTensor4dDesc<Dtype>(
      &bottom_desc_, bottom[0]->num(),
      bottom[0]->channels(), bottom[0]->height(), bottom[0]->width());
  cudnn::setTensor4dDesc<Dtype>(
      &top_desc_, bottom[0]->num(),
      bottom[0]->channels(), bottom[0]->height(), bottom[0]->width());

  // aux tensors for caching mean & invVar from fwd to bwd pass
  int C = bottom[0]->channels();
  int H = bottom[0]->height();
  int W = bottom[0]->width();
  if (mode_ == CUDNN_BATCHNORM_SPATIAL) {
    save_mean_.Reshape(1, C, 1, 1);
    save_inv_var_.Reshape(1, C, 1, 1);
  } else if (mode_ == CUDNN_BATCHNORM_PER_ACTIVATION) {
    save_mean_.Reshape(1, C, H, W);
    save_inv_var_.Reshape(1, C, H, W);
  } else {
    LOG(FATAL) << "Unknown cudnnBatchNormMode_t";
  }
  CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(scale_bias_mean_var_desc_,
      bottom_desc_, mode_));
}

template <typename Dtype>
CuDNNBatchNormLayer<Dtype>::~CuDNNBatchNormLayer() {
  if (!handles_setup_) return;

  cudnnDestroyTensorDescriptor(bottom_desc_);
  cudnnDestroyTensorDescriptor(top_desc_);
  cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_);
}

INSTANTIATE_CLASS(CuDNNBatchNormLayer);
} // namespace caffe
#endif
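For readability, the parameter blob layout that LayerSetUp allocates above can be summarized as follows (blobs_[2] is zero-initialized but is not referenced by the cuDNN calls in this change):

// blobs_[0]  scale (gamma)      1 x C x 1 x 1
// blobs_[1]  bias  (beta)       1 x C x 1 x 1
// blobs_[2]  scalar slot        1 x 1 x 1 x 1   (unused by the cuDNN calls here)
// blobs_[3]  running mean       1 x C x 1 x 1   (updated during TRAIN forward)
// blobs_[4]  running variance   1 x C x 1 x 1   (updated during TRAIN forward)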
106 changes: 106 additions & 0 deletions src/caffe/layers/cudnn_batch_norm_layer.cu
@@ -0,0 +1,106 @@
#ifdef USE_CUDNN
#include <algorithm>
#include <cfloat>
#include <vector>

#include "thrust/device_vector.h"

#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {

template <typename Dtype>
void CuDNNBatchNormLayer<Dtype>::Forward_gpu(
    const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_data = bottom[0]->gpu_data();
  const Dtype* scale_data = this->blobs_[0]->gpu_data();
  const Dtype* bias_data = this->blobs_[1]->gpu_data();

  Dtype* top_data = top[0]->mutable_gpu_data();
  Dtype* save_mean = save_mean_.mutable_gpu_data();
  Dtype* save_inv_var = save_inv_var_.mutable_gpu_data();

  if (this->phase_ == TRAIN) {
    // Call Batch normalization forward
    CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
        Caffe::cudnn_handle(),
        mode_,
        cudnn::dataType<Dtype>::one,
        cudnn::dataType<Dtype>::zero,
        bottom_desc_,
        bottom_data,
        bottom_desc_,
        top_data,
        scale_bias_mean_var_desc_,
        scale_data,
        bias_data,
        1 - this->moving_average_fraction_,
        this->blobs_[3]->mutable_gpu_data(),  // mean
        this->blobs_[4]->mutable_gpu_data(),  // variance
        epsilon_,
        save_mean,
        save_inv_var));
  } else if (this->phase_ == TEST) {
    CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
        Caffe::cudnn_handle(),
        mode_,
        cudnn::dataType<Dtype>::one,
        cudnn::dataType<Dtype>::zero,
        bottom_desc_,
        bottom_data,
        bottom_desc_,
        top_data,
        scale_bias_mean_var_desc_,
        scale_data,
        bias_data,
        this->blobs_[3]->gpu_data(),  // mean
        this->blobs_[4]->gpu_data(),  // variance
        epsilon_));
  } else {
    LOG(FATAL) << "Unknown phase";
  }
}

template <typename Dtype>
void CuDNNBatchNormLayer<Dtype>::Backward_gpu(
    const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  const Dtype* top_data = top[0]->gpu_data();
  const Dtype* top_diff = top[0]->gpu_diff();
  const Dtype* bottom_data = bottom[0]->gpu_data();
  const Dtype* save_mean = save_mean_.gpu_data();
  const Dtype* save_inv_var = save_inv_var_.gpu_data();

  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
  const Dtype* scale_data = this->blobs_[0]->gpu_data();
  Dtype* scale_diff = this->blobs_[0]->mutable_gpu_diff();
  Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();

  // call Batch Normalization Backward
  CUDNN_CHECK(cudnnBatchNormalizationBackward(
      Caffe::cudnn_handle(),
      mode_,
      cudnn::dataType<Dtype>::one,
      cudnn::dataType<Dtype>::zero,
      bottom_desc_,
      bottom_data,
      bottom_desc_,
      top_diff,
      bottom_desc_,
      bottom_diff,
      scale_bias_mean_var_desc_,
      scale_data,
      scale_diff,
      bias_diff,
      this->epsilon_,
      save_mean,
      save_inv_var));
}

INSTANTIATE_LAYER_GPU_FUNCS(CuDNNBatchNormLayer);

} // namespace caffe
#endif
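The exponentialAverageFactor argument passed to cudnnBatchNormalizationForwardTraining above is 1 - moving_average_fraction_, so cuDNN updates the running mean and variance in blobs_[3] and blobs_[4] as an exponential moving average. Written out as a sketch for clarity (m stands for moving_average_fraction_; this only mirrors the cuDNN semantics and adds no code to the layer):

// running <- m * running + (1 - m) * batch_statistic
inline float update_running_stat(float running, float batch_stat, float m) {
  return m * running + (1.0f - m) * batch_stat;
}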
8 changes: 8 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -472,6 +472,14 @@ message BatchNormParameter {
  // Small value to add to the variance estimate so that we don't divide by
  // zero.
  optional float eps = 3 [default = 1e-5];
  optional FillerParameter scale_filler = 5;
  optional FillerParameter bias_filler = 6;
  enum Engine {
    DEFAULT = 0;
    CAFFE = 1;
    CUDNN = 2;
  }
  optional Engine engine = 15 [default = DEFAULT];
}

message ContrastiveLossParameter {
92 changes: 92 additions & 0 deletions src/caffe/test/test_batch_norm_layer.cpp
@@ -130,4 +130,96 @@ namespace caffe {
this->blob_top_vec_);
}

#ifdef USE_CUDNN
template <typename Dtype>
class CuDNNBatchNormLayerTest : public GPUDeviceTest<Dtype> {
 protected:
  CuDNNBatchNormLayerTest()
      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
        blob_top_(new Blob<Dtype>()) {
    // fill the values
    FillerParameter filler_param;
    filler_param.set_mean(-10);
    filler_param.set_std(5);
    GaussianFiller<Dtype> filler(filler_param);
    filler.Fill(this->blob_bottom_);
    blob_bottom_vec_.push_back(blob_bottom_);
    blob_top_vec_.push_back(blob_top_);
  }
  virtual ~CuDNNBatchNormLayerTest() { delete blob_bottom_; delete blob_top_; }
  void checkMeanVar(const Blob<Dtype> *blob_bottom, int num,
      int channels, int height, int width);
  Blob<Dtype>* const blob_bottom_;
  Blob<Dtype>* const blob_top_;
  vector<Blob<Dtype>*> blob_bottom_vec_;
  vector<Blob<Dtype>*> blob_top_vec_;
};

template <typename TypeParam>
void CuDNNBatchNormLayerTest<TypeParam>::checkMeanVar(
    const Blob<TypeParam> *top,
    int num, int channels, int height, int width) {
  typedef TypeParam Dtype;

  for (int j = 0; j < channels; ++j) {
    Dtype mean = 0, var = 0;
    for (int i = 0; i < num; ++i) {
      for (int k = 0; k < height; ++k) {
        for (int l = 0; l < width; ++l) {
          Dtype data = top->data_at(i, j, k, l);
          mean += data;
          var += data * data;
        }
      }
    }
    mean /= num * height * width;
    var /= num * height * width;

    const Dtype kErrorBound = 0.001;
    EXPECT_NEAR(0, mean, kErrorBound);
    EXPECT_NEAR(1, var, kErrorBound);
  }
}
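Note that checkMeanVar accumulates the per-channel mean of x and mean of x*x; since Var[x] = E[x^2] - (E[x])^2 and the expected mean is ~0, the second accumulator directly estimates the variance, which should be ~1 after normalization (the scale filler is a constant 1 in the tests below). Per channel, with N = num * height * width samples:

// mean  = (1/N) * sum(x)        -> expected ~ 0
// E[x2] = (1/N) * sum(x * x)    -> expected ~ 1, since Var = E[x2] - mean^2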

TYPED_TEST_CASE(CuDNNBatchNormLayerTest, TestDtypes);

TYPED_TEST(CuDNNBatchNormLayerTest, TestForward) {
  Caffe::set_random_seed(1701);
  typedef TypeParam Dtype;
  LayerParameter layer_param;
  BatchNormParameter* bn_param = layer_param.mutable_batch_norm_param();
  FillerParameter *scale_param = bn_param->mutable_scale_filler();
  scale_param->set_value(1);

  CuDNNBatchNormLayer<Dtype> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
  layer.Reshape(this->blob_bottom_vec_, this->blob_top_vec_);
  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);

  // Check that each channel of the output has mean ~0 and variance ~1.
  int num = this->blob_bottom_->num();
  int channels = this->blob_bottom_->channels();
  int height = this->blob_bottom_->height();
  int width = this->blob_bottom_->width();

  this->checkMeanVar(this->blob_top_, num, channels, height, width);
}

TYPED_TEST(CuDNNBatchNormLayerTest, TestGradient) {
  typedef TypeParam Dtype;
  LayerParameter layer_param;
  BatchNormParameter* bn_param = layer_param.mutable_batch_norm_param();
  FillerParameter *scale_param = bn_param->mutable_scale_filler();
  scale_param->set_value(1);
  FillerParameter *bias_param = bn_param->mutable_bias_filler();
  bias_param->set_value(0);

  CuDNNBatchNormLayer<Dtype> layer(layer_param);
  GradientChecker<Dtype> checker(1e-2, 4e-4);
  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
      this->blob_top_vec_);
}
#endif

} // namespace caffe
