Replace atlas/cblas routines with Eigen in the math functions #85

Closed
wants to merge 19 commits

Changes from all commits
19 commits
1bcdfd4
compile caffe without MKL (dependency replaced by boost::random, Eigen3)
rodrigob Dec 8, 2013
8a1ede9
Fixed uniform distribution upper bound to be inclusive
kloudkl Jan 11, 2014
9293cc2
Fixed FlattenLayer Backward_cpu/gpu have no return value
kloudkl Jan 11, 2014
d8dd5d0
Fix test stochastic pooling stepsize/threshold to be same as max pooling
kloudkl Jan 11, 2014
00b450b
Fix math funcs, add tests, change Eigen Map to unaligned for lrn_layer
kloudkl Jan 12, 2014
958f038
Fix test_data_layer segfault by adding destructor to join pthread
kloudkl Jan 12, 2014
5385b74
relax precision of MultinomialLogisticLossLayer test
shelhamer Jan 9, 2014
8d894f0
Merge pull request #28 from kloudkl/boost-eigen
shelhamer Jan 22, 2014
d74c16d
nextafter templates off one type
Jan 22, 2014
7ac4a30
mean_bound and sample_mean need referencing with this
Jan 22, 2014
3122c8a
Merge pull request #47 from alito/compileerrorsboosteigenkloudkl
jeffdonahue Jan 22, 2014
3d2696e
make uniform distribution usage compatible with boost 1.46
jeffdonahue Jan 22, 2014
f76b296
use boost variate_generator to pass tests w/ boost 1.46 (Gaussian filler
jeffdonahue Jan 22, 2014
fae6944
change all Rng's to use variate_generator for consistency
jeffdonahue Jan 22, 2014
a5f2cb1
Merge pull request #49 from jeffdonahue/boosteigencompilewithboost146
shelhamer Jan 22, 2014
6639f8f
add bernoulli rng test to demonstrate bug (generates all 0s unless p ==
jeffdonahue Jan 29, 2014
d1c9111
fix bernoulli generator bug
jeffdonahue Jan 29, 2014
ca1c462
Merge pull request #63 from jeffdonahue/bernoullirngbugfix
shelhamer Jan 30, 2014
50f0491
Replace atlas/cblas routines with Eigen in the math functions
kloudkl Feb 8, 2014
11 changes: 6 additions & 5 deletions Makefile
@@ -69,17 +69,18 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64

INCLUDE_DIRS += ./src ./include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
LIBRARY_DIRS += $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
LIBRARIES := cudart cublas curand protobuf opencv_core opencv_highgui \
glog mkl_rt mkl_intel_thread leveldb snappy pthread boost_system \
opencv_imgproc
LIBRARIES := cudart cublas curand pthread gomp \
glog protobuf leveldb snappy boost_system \
opencv_core opencv_highgui opencv_imgproc
PYTHON_LIBRARIES := boost_python python2.7
WARNINGS := -Wall

COMMON_FLAGS := -DNDEBUG -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
CXXFLAGS += -pthread -fPIC $(COMMON_FLAGS)
CXXFLAGS += -pthread -fPIC -fopenmp $(COMMON_FLAGS)
NVCCFLAGS := -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \
$(foreach library,$(LIBRARIES),-l$(library))
TEST_LDFLAGS += -lopenblas
PYTHON_LDFLAGS := $(LDFLAGS) $(foreach library,$(PYTHON_LIBRARIES),-l$(library))


@@ -132,7 +133,7 @@ runtest: test
for testbin in $(TEST_BINS); do $$testbin $(TEST_GPUID); done

$(TEST_BINS): %.testbin : %.o $(GTEST_OBJ) $(STATIC_NAME) $(TEST_HDRS)
$(CXX) $< $(GTEST_OBJ) $(STATIC_NAME) -o $@ $(LDFLAGS) $(WARNINGS)
$(CXX) $< $(GTEST_OBJ) $(STATIC_NAME) -o $@ $(LDFLAGS) $(TEST_LDFLAGS) $(WARNINGS)

$(EXAMPLE_BINS): %.bin : %.o $(STATIC_NAME)
$(CXX) $< $(STATIC_NAME) -o $@ $(LDFLAGS) $(WARNINGS)
8 changes: 8 additions & 0 deletions include/caffe/blob.hpp
@@ -27,6 +27,14 @@ class Blob {
inline int count() const {return count_; }
inline int offset(const int n, const int c = 0, const int h = 0,
const int w = 0) const {
CHECK_GE(n, 0);
CHECK_LE(n, num_);
CHECK_GE(channels_, 0);
CHECK_LE(c, channels_);
CHECK_GE(height_, 0);
CHECK_LE(h, height_);
CHECK_GE(width_, 0);
CHECK_LE(w, width_);
return ((n * channels_ + c) * height_ + h) * width_ + w;
}
// Copy from source. If copy_diff is false, we copy the data; if copy_diff
14 changes: 11 additions & 3 deletions include/caffe/common.hpp
@@ -3,14 +3,15 @@
#ifndef CAFFE_COMMON_HPP_
#define CAFFE_COMMON_HPP_

#include <boost/random/mersenne_twister.hpp>
#include <boost/shared_ptr.hpp>
#include <cublas_v2.h>
#include <cuda.h>
#include <curand.h>
// cuda driver types
#include <driver_types.h>
#include <glog/logging.h>
#include <mkl_vsl.h>
//#include <mkl_vsl.h>

// various checks for different function calls.
#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
@@ -78,8 +79,13 @@ class Caffe {
inline static curandGenerator_t curand_generator() {
return Get().curand_generator_;
}

// Returns the MKL random stream.
inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }
//inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }

typedef boost::mt19937 random_generator_t;
inline static random_generator_t &vsl_stream() { return Get().random_generator_; }

// Returns the mode: running on CPU or GPU.
inline static Brew mode() { return Get().mode_; }
// Returns the phase: TRAIN or TEST.
@@ -103,7 +109,9 @@ class Caffe {
protected:
cublasHandle_t cublas_handle_;
curandGenerator_t curand_generator_;
VSLStreamStatePtr vsl_stream_;
//VSLStreamStatePtr vsl_stream_;
random_generator_t random_generator_;

Brew mode_;
Phase phase_;
static shared_ptr<Caffe> singleton_;
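Note (not part of the patch itself): with the VSL stream gone, callers draw samples from the shared boost::mt19937 exposed by Caffe::vsl_stream(), typically wrapped in a boost::variate_generator as the later commits in this PR do. A minimal sketch of a uniform fill in that style, assuming boost::uniform_real; the actual caffe_vRngUniform in this PR additionally makes the upper bound inclusive (see the caffe_nextafter declaration and the "upper bound to be inclusive" commit above), which this sketch omits.

#include <boost/random/uniform_real.hpp>
#include <boost/random/variate_generator.hpp>

// Illustrative only: fill r[0..n) with uniform samples in [a, b)
// drawn from the shared boost::mt19937 engine.
template <typename Dtype>
void uniform_fill_sketch(const int n, Dtype* r, const Dtype a, const Dtype b) {
  boost::uniform_real<Dtype> dist(a, b);
  boost::variate_generator<Caffe::random_generator_t&, boost::uniform_real<Dtype> >
      sampler(Caffe::vsl_stream(), dist);
  for (int i = 0; i < n; ++i) {
    r[i] = sampler();
  }
}
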
2 changes: 1 addition & 1 deletion include/caffe/filler.hpp
@@ -7,7 +7,7 @@
#ifndef CAFFE_FILLER_HPP
#define CAFFE_FILLER_HPP

#include <mkl.h>
//#include <mkl.h>
#include <string>

#include "caffe/common.hpp"
38 changes: 36 additions & 2 deletions include/caffe/util/math_functions.hpp
@@ -3,11 +3,39 @@
#ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_
#define CAFFE_UTIL_MATH_FUNCTIONS_H_

#include <mkl.h>
#include <cublas_v2.h>
//#include <mkl.h>
#include <eigen3/Eigen/Dense>

namespace caffe {

// Operations on aligned memory are faster than on unaligned memory.
// But unfortunately, the pointers passed in are not always aligned.
// Therefore, the memory-aligned Eigen::Map objects that wrap them
// cannot be assigned to. This happens in lrn_layer and makes
// test_lrn_layer crash with segmentation fault.
// TODO: Use aligned Eigen::Map when the pointer to be wrapped is aligned.

// Though the default map option is unaligned, making it explicit is no harm.
//const int data_alignment = Eigen::Aligned; // how is data allocated ?
const int data_alignment = Eigen::Unaligned;
typedef Eigen::Map<const Eigen::VectorXf, data_alignment> const_map_vector_float_t;
typedef Eigen::Map<Eigen::VectorXf, data_alignment> map_vector_float_t;
typedef Eigen::Map<const Eigen::VectorXd, data_alignment> const_map_vector_double_t;
typedef Eigen::Map<Eigen::VectorXd, data_alignment> map_vector_double_t;

// The default in Eigen is column-major. This is also the case if one
// of the convenience typedefs (Matrix3f, ArrayXXd, etc.) is used.
// http://eigen.tuxfamily.org/dox-devel/group__TopicStorageOrders.html
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatXf;
typedef Eigen::Map<MatXf, data_alignment> map_matrix_float_t;
typedef Eigen::Map<const MatXf, data_alignment> const_map_matrix_float_t;
typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatXd;
typedef Eigen::Map<MatXd, data_alignment> map_matrix_double_t;
typedef Eigen::Map<const MatXd, data_alignment> const_map_matrix_double_t;

// From cblas.h
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};

// Decaf gemm provides a simpler interface to the gemm functions, with the
// limitation that the data has to be contiguous in memory.
template <typename Dtype>
@@ -84,13 +112,19 @@ void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
template <typename Dtype>
void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);

template <typename Dtype>
Dtype caffe_nextafter(const Dtype b);

template <typename Dtype>
void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);

template <typename Dtype>
void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
const Dtype sigma);

template <typename Dtype>
void caffe_vRngBernoulli(const int n, Dtype* r, const double p);

template <typename Dtype>
void caffe_exp(const int n, const Dtype* a, Dtype* y);

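As a rough sketch of how the Map typedefs above translate a cblas call (illustrative, not code from the patch), an axpy-style update y = alpha * x + y can be written directly against the unaligned vector maps.

// Illustrative only: y = alpha * x + y through unaligned Eigen maps,
// playing the role of cblas_saxpy. Assumes x and y each hold n floats.
inline void caffe_axpy_sketch(const int n, const float alpha,
                              const float* x, float* y) {
  const_map_vector_float_t x_map(x, n);
  map_vector_float_t y_map(y, n);
  y_map += alpha * x_map;  // element-wise scaled accumulate, evaluated lazily by Eigen
}

The unaligned maps trade a little speed for safety: as the comment in the header notes, the pointers handed in (e.g. by lrn_layer) are not guaranteed to be aligned, so aligned maps would segfault.
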
1 change: 1 addition & 0 deletions include/caffe/vision_layers.hpp
@@ -294,6 +294,7 @@ class DataLayer : public Layer<Dtype> {
public:
explicit DataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual ~DataLayer();
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

24 changes: 15 additions & 9 deletions src/caffe/common.cpp
@@ -21,7 +21,10 @@ long cluster_seedgen(void) {

Caffe::Caffe()
: mode_(Caffe::CPU), phase_(Caffe::TRAIN), cublas_handle_(NULL),
curand_generator_(NULL), vsl_stream_(NULL) {
curand_generator_(NULL),
//vsl_stream_(NULL)
random_generator_()
{
// Try to create a cublas handler, and report an error if failed (but we will
// keep the program running as one might just want to run CPU code).
if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
@@ -34,21 +37,22 @@ Caffe::Caffe()
!= CURAND_STATUS_SUCCESS) {
LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
}

// Try to create a vsl stream. This should almost always work, but we will
// check it anyway.
if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) {
LOG(ERROR) << "Cannot create vsl stream. VSL random number generator "
<< "won't be available.";
}
//if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) {
// LOG(ERROR) << "Cannot create vsl stream. VSL random number generator "
// << "won't be available.";
//}
}

Caffe::~Caffe() {
if (cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_));
if (curand_generator_) {
CURAND_CHECK(curandDestroyGenerator(curand_generator_));
}
if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
};
//if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
}

void Caffe::set_random_seed(const unsigned int seed) {
// Curand seed
@@ -64,8 +68,10 @@ void Caffe::set_random_seed(const unsigned int seed) {
LOG(ERROR) << "Curand not available. Skipping setting the curand seed.";
}
// VSL seed
VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_)));
VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
//VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_)));
//VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
Get().random_generator_ = random_generator_t(seed);

}

void Caffe::SetDevice(const int device_id) {
10 changes: 10 additions & 0 deletions src/caffe/layers/data_layer.cpp
@@ -17,8 +17,11 @@ namespace caffe {

template <typename Dtype>
void* DataLayerPrefetch(void* layer_pointer) {
CHECK(layer_pointer);
DataLayer<Dtype>* layer = reinterpret_cast<DataLayer<Dtype>*>(layer_pointer);
CHECK(layer);
Datum datum;
CHECK(layer->prefetch_data_);
Dtype* top_data = layer->prefetch_data_->mutable_cpu_data();
Dtype* top_label = layer->prefetch_label_->mutable_cpu_data();
const Dtype scale = layer->layer_param_.scale();
@@ -38,6 +41,8 @@ void* DataLayerPrefetch(void* layer_pointer) {
const Dtype* mean = layer->data_mean_.cpu_data();
for (int itemid = 0; itemid < batchsize; ++itemid) {
// get a blob
CHECK(layer->iter_);
CHECK(layer->iter_->Valid());
datum.ParseFromString(layer->iter_->value().ToString());
const string& data = datum.data();
if (cropsize) {
@@ -109,6 +114,11 @@ void* DataLayerPrefetch(void* layer_pointer) {
return (void*)NULL;
}

template <typename Dtype>
DataLayer<Dtype>::~DataLayer<Dtype>() {
// Finally, join the thread
CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
}

template <typename Dtype>
void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
7 changes: 5 additions & 2 deletions src/caffe/layers/dropout_layer.cu
@@ -4,6 +4,7 @@
#include <limits>

#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/layer.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/vision_layers.hpp"
@@ -34,8 +35,10 @@ void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const int count = bottom[0]->count();
if (Caffe::phase() == Caffe::TRAIN) {
// Create random numbers
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
count, mask, 1. - threshold_);
//viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
// count, mask, 1. - threshold_);
caffe_vRngBernoulli<int>(count, mask, 1. - threshold_);

for (int i = 0; i < count; ++i) {
top_data[i] = bottom_data[i] * mask[i] * scale_;
}
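For reference, a minimal sketch (not the patch's actual code) of what a caffe_vRngBernoulli built on the same Boost machinery could look like, producing the int mask the dropout layer consumes. Parameterizing the distribution over double rather than the integer Dtype matters here: the "fix bernoulli generator bug" commit above suggests an integer-typed distribution truncates p and yields an all-zero mask.

#include <boost/random/bernoulli_distribution.hpp>
#include <boost/random/variate_generator.hpp>

// Illustrative only: write n Bernoulli(p) samples (0 or 1) into r.
template <typename Dtype>
void bernoulli_fill_sketch(const int n, Dtype* r, const double p) {
  boost::bernoulli_distribution<double> dist(p);
  boost::variate_generator<Caffe::random_generator_t&,
      boost::bernoulli_distribution<double> > sampler(Caffe::vsl_stream(), dist);
  for (int i = 0; i < n; ++i) {
    r[i] = static_cast<Dtype>(sampler());
  }
}
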
2 changes: 2 additions & 0 deletions src/caffe/layers/flatten_layer.cpp
@@ -43,6 +43,7 @@ Dtype FlattenLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
caffe_copy(count_, top_diff, bottom_diff);
return Dtype(0);
}


@@ -52,6 +53,7 @@ Dtype FlattenLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
caffe_gpu_copy(count_, top_diff, bottom_diff);
return Dtype(0);
}

INSTANTIATE_CLASS(FlattenLayer);
2 changes: 1 addition & 1 deletion src/caffe/layers/inner_product_layer.cpp
@@ -1,7 +1,7 @@
// Copyright 2013 Yangqing Jia


#include <mkl.h>
//#include <mkl.h>
#include <cublas_v2.h>

#include <vector>
17 changes: 11 additions & 6 deletions src/caffe/test/test_common.cpp
@@ -6,7 +6,7 @@
#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"

#include "caffe/util/math_functions.hpp"
#include "caffe/test/test_caffe_main.hpp"

namespace caffe {
@@ -20,7 +20,8 @@ TEST_F(CommonTest, TestCublasHandler) {
}

TEST_F(CommonTest, TestVslStream) {
EXPECT_TRUE(Caffe::vsl_stream());
//EXPECT_TRUE(Caffe::vsl_stream());
EXPECT_TRUE(true);
}

TEST_F(CommonTest, TestBrewMode) {
@@ -39,11 +40,15 @@ TEST_F(CommonTest, TestRandSeedCPU) {
SyncedMemory data_a(10 * sizeof(int));
SyncedMemory data_b(10 * sizeof(int));
Caffe::set_random_seed(1701);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
10, (int*)data_a.mutable_cpu_data(), 0.5);
//viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
// 10, (int*)data_a.mutable_cpu_data(), 0.5);
caffe_vRngBernoulli(10, (int*)data_a.mutable_cpu_data(), 0.5);

Caffe::set_random_seed(1701);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
10, (int*)data_b.mutable_cpu_data(), 0.5);
//viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
// 10, (int*)data_b.mutable_cpu_data(), 0.5);
caffe_vRngBernoulli(10, (int*)data_b.mutable_cpu_data(), 0.5);

for (int i = 0; i < 10; ++i) {
EXPECT_EQ(((const int*)(data_a.cpu_data()))[i],
((const int*)(data_b.cpu_data()))[i]);
4 changes: 2 additions & 2 deletions src/caffe/test/test_data_layer.cpp
@@ -81,8 +81,8 @@ TYPED_TEST(DataLayerTest, TestRead) {
EXPECT_EQ(this->blob_top_label_->channels(), 1);
EXPECT_EQ(this->blob_top_label_->height(), 1);
EXPECT_EQ(this->blob_top_label_->width(), 1);
// Go throught the data twice
for (int iter = 0; iter < 2; ++iter) {
// Go through the data 100 times
for (int iter = 0; iter < 100; ++iter) {
layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
for (int i = 0; i < 5; ++i) {
EXPECT_EQ(i, this->blob_top_label_->cpu_data()[i]);
3 changes: 3 additions & 0 deletions src/caffe/test/test_flatten_layer.cpp
@@ -22,6 +22,7 @@ class FlattenLayerTest : public ::testing::Test {
FlattenLayerTest()
: blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
blob_top_(new Blob<Dtype>()) {
Caffe::set_random_seed(1701);
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
@@ -72,6 +73,8 @@ TYPED_TEST(FlattenLayerTest, TestGPU) {
for (int c = 0; c < 3 * 6 * 5; ++c) {
EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
this->blob_bottom_->data_at(0, c / (6 * 5), (c / 5) % 6, c % 5));
EXPECT_EQ(this->blob_top_->data_at(1, c, 0, 0),
this->blob_bottom_->data_at(1, c / (6 * 5), (c / 5) % 6, c % 5));
}
}
