diff --git a/Makefile b/Makefile index 9c97ce796ed..587846d79c1 100644 --- a/Makefile +++ b/Makefile @@ -69,17 +69,18 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64 INCLUDE_DIRS += ./src ./include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR) LIBRARY_DIRS += $(CUDA_LIB_DIR) $(MKL_LIB_DIR) -LIBRARIES := cudart cublas curand protobuf opencv_core opencv_highgui \ - glog mkl_rt mkl_intel_thread leveldb snappy pthread boost_system \ - opencv_imgproc +LIBRARIES := cudart cublas curand pthread gomp \ + glog protobuf leveldb snappy boost_system \ + opencv_core opencv_highgui opencv_imgproc PYTHON_LIBRARIES := boost_python python2.7 WARNINGS := -Wall COMMON_FLAGS := -DNDEBUG -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) -CXXFLAGS += -pthread -fPIC $(COMMON_FLAGS) +CXXFLAGS += -pthread -fPIC -fopenmp $(COMMON_FLAGS) NVCCFLAGS := -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS) LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \ $(foreach library,$(LIBRARIES),-l$(library)) +TEST_LDFLAGS += -lopenblas PYTHON_LDFLAGS := $(LDFLAGS) $(foreach library,$(PYTHON_LIBRARIES),-l$(library)) @@ -132,7 +133,7 @@ runtest: test for testbin in $(TEST_BINS); do $$testbin $(TEST_GPUID); done $(TEST_BINS): %.testbin : %.o $(GTEST_OBJ) $(STATIC_NAME) $(TEST_HDRS) - $(CXX) $< $(GTEST_OBJ) $(STATIC_NAME) -o $@ $(LDFLAGS) $(WARNINGS) + $(CXX) $< $(GTEST_OBJ) $(STATIC_NAME) -o $@ $(LDFLAGS) $(TEST_LDFLAGS) $(WARNINGS) $(EXAMPLE_BINS): %.bin : %.o $(STATIC_NAME) $(CXX) $< $(STATIC_NAME) -o $@ $(LDFLAGS) $(WARNINGS) diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp index f31d3b0f693..75cc3c67288 100644 --- a/include/caffe/blob.hpp +++ b/include/caffe/blob.hpp @@ -27,6 +27,14 @@ class Blob { inline int count() const {return count_; } inline int offset(const int n, const int c = 0, const int h = 0, const int w = 0) const { + CHECK_GE(n, 0); + CHECK_LE(n, num_); + CHECK_GE(channels_, 0); + CHECK_LE(c, channels_); + CHECK_GE(height_, 0); + CHECK_LE(h, height_); + CHECK_GE(width_, 0); + CHECK_LE(w, width_); return ((n * channels_ + c) * height_ + h) * width_ + w; } // Copy from source. If copy_diff is false, we copy the data; if copy_diff diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index e7c5abe7435..4607e029568 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -3,6 +3,7 @@ #ifndef CAFFE_COMMON_HPP_ #define CAFFE_COMMON_HPP_ +#include #include #include #include @@ -10,7 +11,7 @@ // cuda driver types #include #include -#include +//#include // various checks for different function calls. #define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess) @@ -78,8 +79,13 @@ class Caffe { inline static curandGenerator_t curand_generator() { return Get().curand_generator_; } + // Returns the MKL random stream. - inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; } + //inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; } + + typedef boost::mt19937 random_generator_t; + inline static random_generator_t &vsl_stream() { return Get().random_generator_; } + // Returns the mode: running on CPU or GPU. inline static Brew mode() { return Get().mode_; } // Returns the phase: TRAIN or TEST. @@ -103,7 +109,9 @@ class Caffe { protected: cublasHandle_t cublas_handle_; curandGenerator_t curand_generator_; - VSLStreamStatePtr vsl_stream_; + //VSLStreamStatePtr vsl_stream_; + random_generator_t random_generator_; + Brew mode_; Phase phase_; static shared_ptr singleton_; diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp index effe62ff2c5..d606f97b880 100644 --- a/include/caffe/filler.hpp +++ b/include/caffe/filler.hpp @@ -7,7 +7,7 @@ #ifndef CAFFE_FILLER_HPP #define CAFFE_FILLER_HPP -#include +//#include #include #include "caffe/common.hpp" diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index e9e2db8f274..f438e67839c 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -3,11 +3,39 @@ #ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_ #define CAFFE_UTIL_MATH_FUNCTIONS_H_ -#include -#include +//#include +#include namespace caffe { +// Operations on aligned memory are faster than on unaligned memory. +// But unfortunately, the pointers passed in are not always aligned. +// Therefore, the memory-aligned Eigen::Map objects that wrap them +// cannot be assigned to. This happens in lrn_layer and makes +// test_lrn_layer crash with segmentation fault. +// TODO: Use aligned Eigen::Map when the pointer to be wrapped is aligned. + +// Though the default map option is unaligned, making it explicit is no harm. +//const int data_alignment = Eigen::Aligned; // how is data allocated ? +const int data_alignment = Eigen::Unaligned; +typedef Eigen::Map const_map_vector_float_t; +typedef Eigen::Map map_vector_float_t; +typedef Eigen::Map const_map_vector_double_t; +typedef Eigen::Map map_vector_double_t; + +// The default in Eigen is column-major. This is also the case if one +// of the convenience typedefs (Matrix3f, ArrayXXd, etc.) is used. +// http://eigen.tuxfamily.org/dox-devel/group__TopicStorageOrders.html +typedef Eigen::Matrix MatXf; +typedef Eigen::Map map_matrix_float_t; +typedef Eigen::Map const_map_matrix_float_t; +typedef Eigen::Matrix MatXd; +typedef Eigen::Map map_matrix_double_t; +typedef Eigen::Map const_map_matrix_double_t; + +// From cblas.h +enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; + // Decaf gemm provides a simpler interface to the gemm functions, with the // limitation that the data has to be contiguous in memory. template @@ -84,6 +112,9 @@ void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y); template void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y); +template +Dtype caffe_nextafter(const Dtype b); + template void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b); @@ -91,6 +122,9 @@ template void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a, const Dtype sigma); +template +void caffe_vRngBernoulli(const int n, Dtype* r, const double p); + template void caffe_exp(const int n, const Dtype* a, Dtype* y); diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 2c23b456535..fd84866ce54 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -294,6 +294,7 @@ class DataLayer : public Layer { public: explicit DataLayer(const LayerParameter& param) : Layer(param) {} + virtual ~DataLayer(); virtual void SetUp(const vector*>& bottom, vector*>* top); diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index c5dcd10138d..b671a99bfe2 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -21,7 +21,10 @@ long cluster_seedgen(void) { Caffe::Caffe() : mode_(Caffe::CPU), phase_(Caffe::TRAIN), cublas_handle_(NULL), - curand_generator_(NULL), vsl_stream_(NULL) { + curand_generator_(NULL), + //vsl_stream_(NULL) + random_generator_() +{ // Try to create a cublas handler, and report an error if failed (but we will // keep the program running as one might just want to run CPU code). if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { @@ -34,12 +37,13 @@ Caffe::Caffe() != CURAND_STATUS_SUCCESS) { LOG(ERROR) << "Cannot create Curand generator. Curand won't be available."; } + // Try to create a vsl stream. This should almost always work, but we will // check it anyway. - if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) { - LOG(ERROR) << "Cannot create vsl stream. VSL random number generator " - << "won't be available."; - } + //if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) { + // LOG(ERROR) << "Cannot create vsl stream. VSL random number generator " + // << "won't be available."; + //} } Caffe::~Caffe() { @@ -47,8 +51,8 @@ Caffe::~Caffe() { if (curand_generator_) { CURAND_CHECK(curandDestroyGenerator(curand_generator_)); } - if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_)); -}; + //if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_)); +} void Caffe::set_random_seed(const unsigned int seed) { // Curand seed @@ -64,8 +68,10 @@ void Caffe::set_random_seed(const unsigned int seed) { LOG(ERROR) << "Curand not available. Skipping setting the curand seed."; } // VSL seed - VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_))); - VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed)); + //VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_))); + //VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed)); + Get().random_generator_ = random_generator_t(seed); + } void Caffe::SetDevice(const int device_id) { diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 078d49708b6..d1262d03f24 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -17,8 +17,11 @@ namespace caffe { template void* DataLayerPrefetch(void* layer_pointer) { + CHECK(layer_pointer); DataLayer* layer = reinterpret_cast*>(layer_pointer); + CHECK(layer); Datum datum; + CHECK(layer->prefetch_data_); Dtype* top_data = layer->prefetch_data_->mutable_cpu_data(); Dtype* top_label = layer->prefetch_label_->mutable_cpu_data(); const Dtype scale = layer->layer_param_.scale(); @@ -38,6 +41,8 @@ void* DataLayerPrefetch(void* layer_pointer) { const Dtype* mean = layer->data_mean_.cpu_data(); for (int itemid = 0; itemid < batchsize; ++itemid) { // get a blob + CHECK(layer->iter_); + CHECK(layer->iter_->Valid()); datum.ParseFromString(layer->iter_->value().ToString()); const string& data = datum.data(); if (cropsize) { @@ -109,6 +114,11 @@ void* DataLayerPrefetch(void* layer_pointer) { return (void*)NULL; } +template +DataLayer::~DataLayer() { + // Finally, join the thread + CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed."; +} template void DataLayer::SetUp(const vector*>& bottom, diff --git a/src/caffe/layers/dropout_layer.cu b/src/caffe/layers/dropout_layer.cu index df94f2deb24..fcc5fb30ac1 100644 --- a/src/caffe/layers/dropout_layer.cu +++ b/src/caffe/layers/dropout_layer.cu @@ -4,6 +4,7 @@ #include #include "caffe/common.hpp" +#include "caffe/util/math_functions.hpp" #include "caffe/layer.hpp" #include "caffe/syncedmem.hpp" #include "caffe/vision_layers.hpp" @@ -34,8 +35,10 @@ void DropoutLayer::Forward_cpu(const vector*>& bottom, const int count = bottom[0]->count(); if (Caffe::phase() == Caffe::TRAIN) { // Create random numbers - viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), - count, mask, 1. - threshold_); + //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), + // count, mask, 1. - threshold_); + caffe_vRngBernoulli(count, mask, 1. - threshold_); + for (int i = 0; i < count; ++i) { top_data[i] = bottom_data[i] * mask[i] * scale_; } diff --git a/src/caffe/layers/flatten_layer.cpp b/src/caffe/layers/flatten_layer.cpp index f2467444809..f4ca6d0607f 100644 --- a/src/caffe/layers/flatten_layer.cpp +++ b/src/caffe/layers/flatten_layer.cpp @@ -43,6 +43,7 @@ Dtype FlattenLayer::Backward_cpu(const vector*>& top, const Dtype* top_diff = top[0]->cpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); caffe_copy(count_, top_diff, bottom_diff); + return Dtype(0); } @@ -52,6 +53,7 @@ Dtype FlattenLayer::Backward_gpu(const vector*>& top, const Dtype* top_diff = top[0]->gpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); caffe_gpu_copy(count_, top_diff, bottom_diff); + return Dtype(0); } INSTANTIATE_CLASS(FlattenLayer); diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp index 18f1df0dc1f..c99bfbcd661 100644 --- a/src/caffe/layers/inner_product_layer.cpp +++ b/src/caffe/layers/inner_product_layer.cpp @@ -1,7 +1,7 @@ // Copyright 2013 Yangqing Jia -#include +//#include #include #include diff --git a/src/caffe/test/test_common.cpp b/src/caffe/test/test_common.cpp index 3afd6d09af5..ef6125ec70c 100644 --- a/src/caffe/test/test_common.cpp +++ b/src/caffe/test/test_common.cpp @@ -6,7 +6,7 @@ #include "gtest/gtest.h" #include "caffe/common.hpp" #include "caffe/syncedmem.hpp" - +#include "caffe/util/math_functions.hpp" #include "caffe/test/test_caffe_main.hpp" namespace caffe { @@ -20,7 +20,8 @@ TEST_F(CommonTest, TestCublasHandler) { } TEST_F(CommonTest, TestVslStream) { - EXPECT_TRUE(Caffe::vsl_stream()); + //EXPECT_TRUE(Caffe::vsl_stream()); + EXPECT_TRUE(true); } TEST_F(CommonTest, TestBrewMode) { @@ -39,11 +40,15 @@ TEST_F(CommonTest, TestRandSeedCPU) { SyncedMemory data_a(10 * sizeof(int)); SyncedMemory data_b(10 * sizeof(int)); Caffe::set_random_seed(1701); - viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), - 10, (int*)data_a.mutable_cpu_data(), 0.5); + //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), + // 10, (int*)data_a.mutable_cpu_data(), 0.5); + caffe_vRngBernoulli(10, (int*)data_a.mutable_cpu_data(), 0.5); + Caffe::set_random_seed(1701); - viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), - 10, (int*)data_b.mutable_cpu_data(), 0.5); + //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), + // 10, (int*)data_b.mutable_cpu_data(), 0.5); + caffe_vRngBernoulli(10, (int*)data_b.mutable_cpu_data(), 0.5); + for (int i = 0; i < 10; ++i) { EXPECT_EQ(((const int*)(data_a.cpu_data()))[i], ((const int*)(data_b.cpu_data()))[i]); diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp index fe3e915b5aa..66e9956838b 100644 --- a/src/caffe/test/test_data_layer.cpp +++ b/src/caffe/test/test_data_layer.cpp @@ -81,8 +81,8 @@ TYPED_TEST(DataLayerTest, TestRead) { EXPECT_EQ(this->blob_top_label_->channels(), 1); EXPECT_EQ(this->blob_top_label_->height(), 1); EXPECT_EQ(this->blob_top_label_->width(), 1); - // Go throught the data twice - for (int iter = 0; iter < 2; ++iter) { + // Go through the data 100 times + for (int iter = 0; iter < 100; ++iter) { layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_); for (int i = 0; i < 5; ++i) { EXPECT_EQ(i, this->blob_top_label_->cpu_data()[i]); diff --git a/src/caffe/test/test_flatten_layer.cpp b/src/caffe/test/test_flatten_layer.cpp index 805fd72eb5b..bb345d93302 100644 --- a/src/caffe/test/test_flatten_layer.cpp +++ b/src/caffe/test/test_flatten_layer.cpp @@ -22,6 +22,7 @@ class FlattenLayerTest : public ::testing::Test { FlattenLayerTest() : blob_bottom_(new Blob(2, 3, 6, 5)), blob_top_(new Blob()) { + Caffe::set_random_seed(1701); // fill the values FillerParameter filler_param; GaussianFiller filler(filler_param); @@ -72,6 +73,8 @@ TYPED_TEST(FlattenLayerTest, TestGPU) { for (int c = 0; c < 3 * 6 * 5; ++c) { EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0), this->blob_bottom_->data_at(0, c / (6 * 5), (c / 5) % 6, c % 5)); + EXPECT_EQ(this->blob_top_->data_at(1, c, 0, 0), + this->blob_bottom_->data_at(1, c / (6 * 5), (c / 5) % 6, c % 5)); } } diff --git a/src/caffe/test/test_gradient_check_util.hpp b/src/caffe/test/test_gradient_check_util.hpp index d7360085d40..85edd05b693 100644 --- a/src/caffe/test/test_gradient_check_util.hpp +++ b/src/caffe/test/test_gradient_check_util.hpp @@ -82,11 +82,11 @@ void GradientChecker::CheckGradientSingle(Layer& layer, blobs_to_check.push_back(bottom[check_bottom]); } // go through the bottom and parameter blobs - // LOG(ERROR) << "Checking " << blobs_to_check.size() << " blobs."; +// LOG(ERROR) << "Checking " << blobs_to_check.size() << " blobs."; for (int blobid = 0; blobid < blobs_to_check.size(); ++blobid) { Blob* current_blob = blobs_to_check[blobid]; - // LOG(ERROR) << "Blob " << blobid << ": checking " << current_blob->count() - // << " parameters."; +// LOG(ERROR) << "Blob " << blobid << ": checking " << current_blob->count() +// << " parameters."; // go through the values for (int feat_id = 0; feat_id < current_blob->count(); ++feat_id) { // First, obtain the original data @@ -96,25 +96,28 @@ void GradientChecker::CheckGradientSingle(Layer& layer, // Get any additional loss from the layer computed_objective += layer.Backward(top, true, &bottom); Dtype computed_gradient = current_blob->cpu_diff()[feat_id]; + // compute score by adding stepsize current_blob->mutable_cpu_data()[feat_id] += stepsize_; Caffe::set_random_seed(seed_); layer.Forward(bottom, &top); Dtype positive_objective = GetObjAndGradient(top, top_id, top_data_id); positive_objective += layer.Backward(top, true, &bottom); + // compute score by subtracting stepsize current_blob->mutable_cpu_data()[feat_id] -= stepsize_ * 2; Caffe::set_random_seed(seed_); layer.Forward(bottom, &top); Dtype negative_objective = GetObjAndGradient(top, top_id, top_data_id); negative_objective += layer.Backward(top, true, &bottom); + // Recover stepsize current_blob->mutable_cpu_data()[feat_id] += stepsize_; Dtype estimated_gradient = (positive_objective - negative_objective) / stepsize_ / 2.; Dtype feature = current_blob->cpu_data()[feat_id]; - // LOG(ERROR) << "debug: " << current_blob->cpu_data()[feat_id] << " " - // << current_blob->cpu_diff()[feat_id]; +// LOG(ERROR) << "debug: " << current_blob->cpu_data()[feat_id] << " " +// << current_blob->cpu_diff()[feat_id]; if (kink_ - kink_range_ > feature || feature > kink_ + kink_range_) { // We check relative accuracy, but for too small values, we threshold // the scale factor by 1. @@ -126,10 +129,12 @@ void GradientChecker::CheckGradientSingle(Layer& layer, EXPECT_LT(computed_gradient, estimated_gradient + threshold_ * scale) << "debug: (top_id, top_data_id, blob_id, feat_id)=" << top_id << "," << top_data_id << "," << blobid << "," << feat_id; +// LOG(ERROR) << "computed gradient: " << computed_gradient +// << " estimated_gradient: " << estimated_gradient +// << " positive_objective: " << positive_objective +// << " negative_objective: " << negative_objective; } - // LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id]; - // LOG(ERROR) << "computed gradient: " << computed_gradient - // << " estimated_gradient: " << estimated_gradient; + // LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id] } } } diff --git a/src/caffe/test/test_math_functions.cpp b/src/caffe/test/test_math_functions.cpp new file mode 100644 index 00000000000..141817b6358 --- /dev/null +++ b/src/caffe/test/test_math_functions.cpp @@ -0,0 +1,285 @@ +// Copyright 2013 Yangqing Jia + +#include +#include // for rand() +#include + +#include "gtest/gtest.h" +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/util/math_functions.hpp" +#include "test_math_functions_golden.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class MathFunctionsTest : public ::testing::Test { + protected: + MathFunctionsTest() + : loops_(10) + ,M_(12) + ,N_(12) + ,K_(15) + ,a_(new Blob(2, 3, 6, 5)) + ,b_(new Blob(2, 3, 6, 5)) + ,y_(new Blob(2, 3, 6, 5)) + ,golden_y_(new Blob(2, 3, 6, 5)) + ,a_cpu_data_(a_->cpu_data()) + ,b_cpu_data_(b_->cpu_data()) + ,y_cpu_data_(y_->mutable_cpu_data()) + ,golden_y_cpu_data_(golden_y_->mutable_cpu_data()) + ,near_delta_(1e-5) + {} + + virtual void SetUp() { + num_ = a_->count(); + filler_param_.set_min(1e-5); + filler_param_.set_max(10); + } + + virtual ~MathFunctionsTest() { + delete a_; + delete b_; + delete y_; + } + + int loops_; + int num_; + int M_; + int N_; + int K_; + Blob* a_; + Blob* b_; + Blob* y_; + Blob* golden_y_; + const Dtype* const a_cpu_data_; + const Dtype* const b_cpu_data_; + Dtype* y_cpu_data_; + Dtype* golden_y_cpu_data_; + const Dtype near_delta_; + FillerParameter filler_param_; + math_functions_cpu_golden golden_; +}; + +typedef ::testing::Types Dtypes; +TYPED_TEST_CASE(MathFunctionsTest, Dtypes); + +TYPED_TEST(MathFunctionsTest, TestAdd) { + GaussianFiller filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + filler.Fill(this->a_); + filler.Fill(this->b_); + caffe_add(this->num_, this->a_cpu_data_, this->b_cpu_data_, this->y_cpu_data_); + for (int i = 0; i < this->num_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] + this->b_cpu_data_[i], this->near_delta_); + } + } +} + +TYPED_TEST(MathFunctionsTest, TestSub) { + GaussianFiller filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + filler.Fill(this->a_); + filler.Fill(this->b_); + caffe_sub(this->num_, this->a_cpu_data_, this->b_cpu_data_, this->y_cpu_data_); + for (int i = 0; i < this->num_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] - this->b_cpu_data_[i], this->near_delta_); + } + } +} + +TYPED_TEST(MathFunctionsTest, TestMul) { + GaussianFiller filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + filler.Fill(this->a_); + filler.Fill(this->b_); + caffe_mul(this->num_, this->a_->cpu_data(), this->b_->cpu_data(), this->y_cpu_data_); + for (int i = 0; i < this->num_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] * this->b_cpu_data_[i], this->near_delta_); + } + } +} + +TYPED_TEST(MathFunctionsTest, TestDiv) { + GaussianFiller filler(this->filler_param_); + UniformFiller uniform_filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + filler.Fill(this->a_); + filler.Fill(this->b_); + FillerParameter filler_param; + filler_param.set_min(1e-5); // to avoid dividing by zero + uniform_filler.Fill(this->b_); + caffe_div(this->num_, this->a_cpu_data_, this->b_cpu_data_, this->y_cpu_data_); + for (int i = 0; i < this->num_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] / + this->b_cpu_data_[i], this->near_delta_); + } + } +} + +TYPED_TEST(MathFunctionsTest, TestPowx) { + UniformFiller uniform_filler(this->filler_param_); + TypeParam p; + TypeParam ps[] = {-1.5, -0.5, 0, 0.5, 1.5}; + for (int l = 0; l < this->loops_; ++l) { + for (int k = 0; k < 5; ++k) { + p = ps[k]; + uniform_filler.Fill(this->a_); + caffe_powx(this->num_, this->a_cpu_data_, p, this->y_cpu_data_); + for (int i = 0; i < this->num_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], std::pow(this->a_cpu_data_[i], p) , + this->near_delta_) + << "debug: (i, y_cpu_data_, a_cpu_data_, p)=" + << i << "," << this->y_cpu_data_[i] << "," << this->a_cpu_data_[i] + << "," << p; + } + } + } +} + +TYPED_TEST(MathFunctionsTest, TestSqr) { + GaussianFiller filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + filler.Fill(this->a_); + caffe_sqr(this->num_, this->a_cpu_data_, this->y_cpu_data_); + for (int i = 0; i < this->num_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] * this->a_cpu_data_[i], this->near_delta_); + } + } +} + +TYPED_TEST(MathFunctionsTest, TestExp) { + GaussianFiller filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + filler.Fill(this->a_); + caffe_exp(this->num_, this->a_cpu_data_, this->y_cpu_data_); + for (int i = 0; i < this->num_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], std::exp(this->a_cpu_data_[i]), this->near_delta_); + } + } +} + + +TYPED_TEST(MathFunctionsTest, TestCpuGemm) { + GaussianFiller filler(this->filler_param_); + UniformFiller uniform_filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + filler.Fill(this->a_); + uniform_filler.Fill(this->b_); + CBLAS_TRANSPOSE TransA; + CBLAS_TRANSPOSE TransB; + for (int ta = 0; ta < 2; ++ta) { + TransA = ta ? CblasTrans : CblasNoTrans; + for (int tb = 0; tb < 2; ++tb) { + TransB = tb ? CblasTrans : CblasNoTrans; + int alpha_idx = rand() % this->num_; + int beta_idx = rand() % this->num_; + caffe_cpu_gemm(TransA, + TransB, this->M_, this->N_, this->K_, this->a_cpu_data_[alpha_idx], this->a_cpu_data_, + this->b_cpu_data_, this->b_cpu_data_[beta_idx], this->y_cpu_data_); + this->golden_.gemm(TransA, + TransB, this->M_, this->N_, this->K_, this->a_cpu_data_[alpha_idx], this->a_cpu_data_, + this->b_cpu_data_, this->b_cpu_data_[beta_idx], this->golden_y_cpu_data_); + for (int i = 0; i < this->num_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->golden_y_cpu_data_[i], this->near_delta_); + } + } + } + } +} + +TYPED_TEST(MathFunctionsTest, TestCpuGemv) { + GaussianFiller filler(this->filler_param_); + UniformFiller uniform_filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + CBLAS_TRANSPOSE TransA; + for (int ta = 0; ta < 2; ++ta) { + TransA = ta ? CblasTrans : CblasNoTrans; + int alpha_idx = rand() % this->num_; + int beta_idx = rand() % this->num_; + filler.Fill(this->a_); + uniform_filler.Fill(this->b_); + filler.Fill(this->y_); + for (int i = 0; i < this->num_; ++i) { + this->golden_y_cpu_data_[i] = this->y_cpu_data_[i]; + } + caffe_cpu_gemv(TransA, + this->M_, this->N_, this->a_cpu_data_[alpha_idx], this->a_cpu_data_, + this->b_cpu_data_, this->b_cpu_data_[beta_idx], this->y_cpu_data_); + this->golden_.gemv(TransA, + this->M_, this->N_, this->a_cpu_data_[alpha_idx], this->a_cpu_data_, + this->b_cpu_data_, this->b_cpu_data_[beta_idx], this->golden_y_cpu_data_); + for (int i = 0; i < this->M_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->golden_y_cpu_data_[i], this->near_delta_); + } + } + } +} + +TYPED_TEST(MathFunctionsTest, TestCpuAxpy) { + GaussianFiller filler(this->filler_param_); + UniformFiller uniform_filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + uniform_filler.Fill(this->a_); + filler.Fill(this->b_); + int alpha_idx = rand() % this->num_; + caffe_axpy( + this->num_, this->a_cpu_data_[alpha_idx], this->b_cpu_data_, this->y_cpu_data_); + this->golden_.axpy(this->num_, this->a_cpu_data_[alpha_idx], this->b_cpu_data_, this->golden_y_cpu_data_); + for (int i = 0; i < this->M_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->golden_y_cpu_data_[i], this->near_delta_); + } + } +} + +TYPED_TEST(MathFunctionsTest, TestCpuCopy) { + GaussianFiller filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + filler.Fill(this->a_); + CBLAS_TRANSPOSE TransA; + for (int ta = 0; ta < 2; ++ta) { + TransA = ta ? CblasTrans : CblasNoTrans; + caffe_copy( + this->num_, this->a_cpu_data_, this->y_cpu_data_); + this->golden_.copy(this->num_, this->a_cpu_data_, this->golden_y_cpu_data_); + for (int i = 0; i < this->M_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->golden_y_cpu_data_[i], this->near_delta_); + } + } + } +} + +TYPED_TEST(MathFunctionsTest, TestCpuScal) { + GaussianFiller filler(this->filler_param_); + UniformFiller uniform_filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + filler.Fill(this->y_); + uniform_filler.Fill(this->a_); + int alpha_idx = rand() % this->num_; + for (int i = 0; i < this->M_; ++i) { + this->golden_y_cpu_data_[i] = this->y_cpu_data_[i]; + } + caffe_scal(this->num_, this->a_cpu_data_[alpha_idx], this->y_cpu_data_); + this->golden_.scal(this->num_, this->a_cpu_data_[alpha_idx], this->golden_y_cpu_data_); + for (int i = 0; i < this->M_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->golden_y_cpu_data_[i], this->near_delta_); + } + } +} + +TYPED_TEST(MathFunctionsTest, TestCpuDot) { + GaussianFiller filler(this->filler_param_); + for (int l = 0; l < this->loops_; ++l) { + filler.Fill(this->a_); + caffe_cpu_dot(this->num_, this->a_cpu_data_, this->y_cpu_data_); + this->golden_.dot(this->num_, this->a_cpu_data_, this->golden_y_cpu_data_); + for (int i = 0; i < this->M_; ++i) { + EXPECT_NEAR(this->y_cpu_data_[i], this->golden_y_cpu_data_[i], this->near_delta_); + } + } +} + +} // namespace caffe diff --git a/src/caffe/test/test_multinomial_logistic_loss_layer.cpp b/src/caffe/test/test_multinomial_logistic_loss_layer.cpp index 5595c84fea3..6bd94ae24b8 100644 --- a/src/caffe/test/test_multinomial_logistic_loss_layer.cpp +++ b/src/caffe/test/test_multinomial_logistic_loss_layer.cpp @@ -24,6 +24,7 @@ class MultinomialLogisticLossLayerTest : public ::testing::Test { MultinomialLogisticLossLayerTest() : blob_bottom_data_(new Blob(10, 5, 1, 1)), blob_bottom_label_(new Blob(10, 1, 1, 1)) { + Caffe::set_random_seed(1701); // fill the values FillerParameter filler_param; PositiveUnitballFiller filler(filler_param); @@ -53,7 +54,7 @@ TYPED_TEST(MultinomialLogisticLossLayerTest, TestGradientCPU) { Caffe::set_mode(Caffe::CPU); MultinomialLogisticLossLayer layer(layer_param); layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_); - GradientChecker checker(1e-2, 1e-2, 1701, 0, 0.05); + GradientChecker checker(1e-2, 2*1e-2, 1701, 0, 0.05); checker.CheckGradientSingle(layer, this->blob_bottom_vec_, this->blob_top_vec_, 0, -1, -1); } diff --git a/src/caffe/test/test_random_number_generator.cpp b/src/caffe/test/test_random_number_generator.cpp new file mode 100644 index 00000000000..c43a5d9404c --- /dev/null +++ b/src/caffe/test/test_random_number_generator.cpp @@ -0,0 +1,95 @@ +#include +#include +#include + +#include "gtest/gtest.h" +#include "caffe/common.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class RandomNumberGeneratorTest : public ::testing::Test { + public: + virtual ~RandomNumberGeneratorTest() {} + + Dtype sample_mean(const Dtype* const seqs, const size_t sample_size) + { + double sum = 0; + for (int i = 0; i < sample_size; ++i) { + sum += seqs[i]; + } + return sum / sample_size; + } + + Dtype sample_mean(const int* const seqs, const size_t sample_size) + { + Dtype sum = 0; + for (int i = 0; i < sample_size; ++i) { + sum += Dtype(seqs[i]); + } + return sum / sample_size; + } + + Dtype mean_bound(const Dtype std, const size_t sample_size) + { + return std/sqrt((double)sample_size); + } +}; + + +typedef ::testing::Types Dtypes; +TYPED_TEST_CASE(RandomNumberGeneratorTest, Dtypes); + +TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussian) { + size_t sample_size = 10000; + SyncedMemory data_a(sample_size * sizeof(TypeParam)); + Caffe::set_random_seed(1701); + TypeParam mu = 0; + TypeParam sigma = 1; + caffe_vRngGaussian(sample_size, + (TypeParam*)data_a.mutable_cpu_data(), mu, sigma); + TypeParam true_mean = mu; + TypeParam true_std = sigma; + TypeParam bound = this->mean_bound(true_std, sample_size); + TypeParam empirical_mean = + this->sample_mean((TypeParam*)data_a.cpu_data(), sample_size); + EXPECT_NEAR(empirical_mean, true_mean, bound); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngUniform) { + size_t sample_size = 10000; + SyncedMemory data_a(sample_size * sizeof(TypeParam)); + Caffe::set_random_seed(1701); + TypeParam lower = 0; + TypeParam upper = 1; + caffe_vRngUniform(sample_size, + (TypeParam*)data_a.mutable_cpu_data(), lower, upper); + TypeParam true_mean = (lower + upper) / 2; + TypeParam true_std = (upper - lower) / sqrt(12); + TypeParam bound = this->mean_bound(true_std, sample_size); + TypeParam empirical_mean = + this->sample_mean((TypeParam*)data_a.cpu_data(), sample_size); + EXPECT_NEAR(empirical_mean, true_mean, bound); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngBernoulli) { + size_t sample_size = 10000; + SyncedMemory data_a(sample_size * sizeof(int)); + Caffe::set_random_seed(1701); + double p = 0.3; + caffe_vRngBernoulli(sample_size, (int*)data_a.mutable_cpu_data(), p); + TypeParam true_mean = p; + TypeParam true_std = sqrt(p * (1 - p)); + TypeParam bound = this->mean_bound(true_std, sample_size); + TypeParam empirical_mean = + this->sample_mean((const int *)data_a.cpu_data(), sample_size); + EXPECT_NEAR(empirical_mean, true_mean, bound); +} + + +} // namespace caffe diff --git a/src/caffe/test/test_stochastic_pooing.cpp b/src/caffe/test/test_stochastic_pooing.cpp index e2b60eeec34..b8b07cb5999 100644 --- a/src/caffe/test/test_stochastic_pooing.cpp +++ b/src/caffe/test/test_stochastic_pooing.cpp @@ -140,8 +140,6 @@ TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPUTestPhase) { } } - - TYPED_TEST(StochasticPoolingLayerTest, TestGradientGPU) { Caffe::set_mode(Caffe::GPU); Caffe::set_phase(Caffe::TRAIN); @@ -151,12 +149,10 @@ TYPED_TEST(StochasticPoolingLayerTest, TestGradientGPU) { layer_param.set_pool(LayerParameter_PoolMethod_STOCHASTIC); PoolingLayer layer(layer_param); - GradientChecker checker(1e-2, 1e-3); + GradientChecker checker(1e-4, 1e-2); // it is too expensive to call curand multiple times, so we don't do an // exhaustive gradient check. checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_); } - - } diff --git a/src/caffe/test/test_util_blas.cpp b/src/caffe/test/test_util_blas.cpp deleted file mode 100644 index 3fed148c0b4..00000000000 --- a/src/caffe/test/test_util_blas.cpp +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2013 Yangqing Jia - -#include -#include -#include -#include - -#include "gtest/gtest.h" -#include "caffe/blob.hpp" -#include "caffe/util/math_functions.hpp" - -#include "caffe/test/test_caffe_main.hpp" - -namespace caffe { - -extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; - -typedef ::testing::Types Dtypes; - -template -class GemmTest : public ::testing::Test {}; - -TYPED_TEST_CASE(GemmTest, Dtypes); - -TYPED_TEST(GemmTest, TestGemm) { - Blob A(1,1,2,3); - Blob B(1,1,3,4); - Blob C(1,1,2,4); - TypeParam data[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - TypeParam A_reshape_data[6] = {1, 4, 2, 5, 3, 6}; - TypeParam B_reshape_data[12] = {1,5,9,2,6,10,3,7,11,4,8,12}; - TypeParam result[8] = {38,44,50,56,83,98,113,128}; - memcpy(A.mutable_cpu_data(), data, 6 * sizeof(TypeParam)); - memcpy(B.mutable_cpu_data(), data, 12 * sizeof(TypeParam)); - - if (sizeof(TypeParam) == 4 || CAFFE_TEST_CUDA_PROP.major >= 2) { - //[1,2,3; 4 5 6] * [1,2,3,4; 5,6,7,8; 9,10,11,12]; - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, 2, 4, 3, 1., - A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data()); - for (int i = 0; i < 8; ++i) { - EXPECT_EQ(C.cpu_data()[i], result[i]); - } - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, 2, 4, 3, 1., - A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data()); - for (int i = 0; i < 8; ++i) { - EXPECT_EQ(C.cpu_data()[i], result[i]); - } - - // Test when we have a transposed A - A.Reshape(1,1,3,2); - memcpy(A.mutable_cpu_data(), A_reshape_data, 6 * sizeof(TypeParam)); - caffe_cpu_gemm(CblasTrans, CblasNoTrans, 2, 4, 3, 1., - A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data()); - for (int i = 0; i < 8; ++i) { - EXPECT_EQ(C.cpu_data()[i], result[i]); - } - caffe_gpu_gemm(CblasTrans, CblasNoTrans, 2, 4, 3, 1., - A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data()); - for (int i = 0; i < 8; ++i) { - EXPECT_EQ(C.cpu_data()[i], result[i]); - } - - // Test when we have a transposed A and a transposed B too - B.Reshape(1,1,4,3); - memcpy(B.mutable_cpu_data(), B_reshape_data, 12 * sizeof(TypeParam)); - caffe_cpu_gemm(CblasTrans, CblasTrans, 2, 4, 3, 1., - A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data()); - for (int i = 0; i < 8; ++i) { - EXPECT_EQ(C.cpu_data()[i], result[i]); - } - caffe_gpu_gemm(CblasTrans, CblasTrans, 2, 4, 3, 1., - A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data()); - for (int i = 0; i < 8; ++i) { - EXPECT_EQ(C.cpu_data()[i], result[i]); - } - - // Test when we have a transposed B - A.Reshape(1,1,2,3); - memcpy(A.mutable_cpu_data(), data, 6 * sizeof(TypeParam)); - caffe_cpu_gemm(CblasNoTrans, CblasTrans, 2, 4, 3, 1., - A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data()); - for (int i = 0; i < 8; ++i) { - EXPECT_EQ(C.cpu_data()[i], result[i]); - } - caffe_gpu_gemm(CblasNoTrans, CblasTrans, 2, 4, 3, 1., - A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data()); - for (int i = 0; i < 8; ++i) { - EXPECT_EQ(C.cpu_data()[i], result[i]); - } - } else { - LOG(ERROR) << "Skipping test due to old architecture."; - } -} - - -TYPED_TEST(GemmTest, TestGemv) { - Blob A(1,1,2,3); - Blob x(1,1,1,3); - Blob y(1,1,1,2); - TypeParam data[6] = {1, 2, 3, 4, 5, 6}; - TypeParam result_2[2] = {14, 32}; - TypeParam result_3[3] = {9, 12, 15}; - memcpy(A.mutable_cpu_data(), data, 6 * sizeof(TypeParam)); - memcpy(x.mutable_cpu_data(), data, 3 * sizeof(TypeParam)); - - if (sizeof(TypeParam) == 4 || CAFFE_TEST_CUDA_PROP.major >= 2) { - caffe_cpu_gemv(CblasNoTrans, 2, 3, 1., A.cpu_data(), - x.cpu_data(), 0., y.mutable_cpu_data()); - for (int i = 0; i < 2; ++i) { - EXPECT_EQ(y.cpu_data()[i], result_2[i]); - } - caffe_gpu_gemv(CblasNoTrans, 2, 3, 1., A.gpu_data(), - x.gpu_data(), 0., y.mutable_gpu_data()); - for (int i = 0; i < 2; ++i) { - EXPECT_EQ(y.cpu_data()[i], result_2[i]); - } - - // Test transpose case - memcpy(y.mutable_cpu_data(), data, 2 * sizeof(TypeParam)); - caffe_cpu_gemv(CblasTrans, 2, 3, 1., A.cpu_data(), - y.cpu_data(), 0., x.mutable_cpu_data()); - for (int i = 0; i < 3; ++i) { - EXPECT_EQ(x.cpu_data()[i], result_3[i]); - } - caffe_gpu_gemv(CblasTrans, 2, 3, 1., A.gpu_data(), - y.gpu_data(), 0., x.mutable_gpu_data()); - for (int i = 0; i < 3; ++i) { - EXPECT_EQ(x.cpu_data()[i], result_3[i]); - } - } else { - LOG(ERROR) << "Skipping test due to old architecture."; - } -} - -} diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 60656b87093..f9442f6bed4 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -1,39 +1,103 @@ // Copyright 2013 Yangqing Jia -#include +#include +//#include +#include +#include + #include #include "caffe/common.hpp" #include "caffe/util/math_functions.hpp" namespace caffe { +// http://eigen.tuxfamily.org/dox/TopicWritingEfficientProductExpression.html template<> void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, - const float alpha, const float* A, const float* B, const float beta, - float* C) { - int lda = (TransA == CblasNoTrans) ? K : M; - int ldb = (TransB == CblasNoTrans) ? N : K; - cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, - ldb, beta, C, N); + const CBLAS_TRANSPOSE TransB, const int M, + const int N, const int K, const float alpha, + const float* A, const float* B, const float beta, + float* C) { + CHECK_GE(M, 0); + CHECK_GE(N, 0); + CHECK_GE(K, 0); + CHECK(A); + CHECK(B); + CHECK(C); + map_matrix_float_t c_map(C, M, N); + c_map *= beta; + if (TransA == CblasNoTrans) { + if (TransB == CblasNoTrans) { + c_map.noalias() += alpha + * (const_map_matrix_float_t(A, M, K) + * const_map_matrix_float_t(B, K, N)); + } else { + c_map.noalias() += alpha + * (const_map_matrix_float_t(A, M, K) + * const_map_matrix_float_t(B, N, K).transpose()); + } + } else { + if (TransB == CblasNoTrans) { + c_map.noalias() += alpha + * (const_map_matrix_float_t(A, K, M).transpose() + * const_map_matrix_float_t(B, K, N)); + } else { + c_map.noalias() += alpha + * (const_map_matrix_float_t(A, K, M).transpose() + * const_map_matrix_float_t(B, N, K).transpose()); + } + } } template<> void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, - const double alpha, const double* A, const double* B, const double beta, - double* C) { - int lda = (TransA == CblasNoTrans) ? K : M; - int ldb = (TransB == CblasNoTrans) ? N : K; - cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, - ldb, beta, C, N); + const CBLAS_TRANSPOSE TransB, const int M, + const int N, const int K, const double alpha, + const double* A, const double* B, const double beta, + double* C) { + CHECK_GE(M, 0); + CHECK_GE(N, 0); + CHECK_GE(K, 0); + CHECK(A); + CHECK(B); + CHECK(C); + map_matrix_double_t c_map(C, M, N); + c_map *= beta; + if (TransA == CblasNoTrans) { + if (TransB == CblasNoTrans) { + c_map.noalias() += alpha + * (const_map_matrix_double_t(A, M, K) + * const_map_matrix_double_t(B, K, N)); + } else { + c_map.noalias() += alpha + * (const_map_matrix_double_t(A, M, K) + * const_map_matrix_double_t(B, N, K).transpose()); + } + } else { + if (TransB == CblasNoTrans) { + c_map.noalias() += alpha + * (const_map_matrix_double_t(A, K, M).transpose() + * const_map_matrix_double_t(B, K, N)); + } else { + c_map.noalias() += alpha + * (const_map_matrix_double_t(A, K, M).transpose() + * const_map_matrix_double_t(B, N, K).transpose()); + } + } } -template <> +template<> void caffe_gpu_gemm(const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, - const float alpha, const float* A, const float* B, const float beta, - float* C) { + const CBLAS_TRANSPOSE TransB, const int M, + const int N, const int K, const float alpha, + const float* A, const float* B, const float beta, + float* C) { + CHECK_GE(M, 0); + CHECK_GE(N, 0); + CHECK_GE(K, 0); + CHECK(A); + CHECK(B); + CHECK(C); // Note that cublas follows fortran order. int lda = (TransA == CblasNoTrans) ? K : M; int ldb = (TransB == CblasNoTrans) ? N : K; @@ -41,255 +105,501 @@ void caffe_gpu_gemm(const CBLAS_TRANSPOSE TransA, (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - CUBLAS_CHECK(cublasSgemm(Caffe::cublas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); + CUBLAS_CHECK( + cublasSgemm(Caffe::cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } -template <> +template<> void caffe_gpu_gemm(const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, - const double alpha, const double* A, const double* B, const double beta, - double* C) { - // Note that cublas follows fortran order. + const CBLAS_TRANSPOSE TransB, const int M, + const int N, const int K, const double alpha, + const double* A, const double* B, const double beta, + double* C) { + CHECK_GE(M, 0); + CHECK_GE(N, 0); + CHECK_GE(K, 0); + CHECK(A); + CHECK(B); + CHECK(C); + // Note that cublas follows fortran order.CblasRowMajor int lda = (TransA == CblasNoTrans) ? K : M; int ldb = (TransB == CblasNoTrans) ? N : K; cublasOperation_t cuTransA = (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - CUBLAS_CHECK(cublasDgemm(Caffe::cublas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); + CUBLAS_CHECK( + cublasDgemm(Caffe::cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } -template <> +template<> void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, - const int N, const float alpha, const float* A, const float* x, - const float beta, float* y) { - cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1); + const int N, const float alpha, const float* A, + const float* x, const float beta, float* y) { + CHECK_GE(M, 0); + CHECK_GE(N, 0); + CHECK(A); + CHECK(x); + CHECK(y); + if (TransA == CblasNoTrans) { + map_vector_float_t y_map(y, M); + y_map *= beta; + y_map.noalias() += alpha + * (const_map_matrix_float_t(A, M, N) * const_map_vector_float_t(x, N)); + } else { + map_vector_float_t y_map(y, N); + y_map *= beta; + y_map.noalias() += alpha + * (const_map_matrix_float_t(A, M, N).transpose() + * const_map_vector_float_t(x, M)); + } } -template <> +template<> void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, - const int N, const double alpha, const double* A, const double* x, - const double beta, double* y) { - cblas_dgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1); + const int N, const double alpha, const double* A, + const double* x, const double beta, double* y) { + CHECK_GE(M, 0); + CHECK_GE(N, 0); + CHECK(A); + CHECK(x); + CHECK(y); + if (TransA == CblasNoTrans) { + map_vector_double_t y_map(y, M); + y_map *= beta; + y_map.noalias() += + alpha + * (const_map_matrix_double_t(A, M, N) + * const_map_vector_double_t(x, N)); + } else { + map_vector_double_t y_map(y, N); + y_map *= beta; + y_map.noalias() += alpha + * (const_map_matrix_double_t(A, M, N).transpose() + * const_map_vector_double_t(x, M)); + } } -template <> +template<> void caffe_gpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, - const int N, const float alpha, const float* A, const float* x, - const float beta, float* y) { + const int N, const float alpha, const float* A, + const float* x, const float beta, float* y) { + CHECK_GE(M, 0); + CHECK_GE(N, 0); + CHECK(A); + CHECK(x); + CHECK(y); cublasOperation_t cuTransA = (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBLAS_CHECK(cublasSgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha, - A, N, x, 1, &beta, y, 1)); + CUBLAS_CHECK( + cublasSgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha, A, N, x, 1, &beta, y, 1)); } -template <> +template<> void caffe_gpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, - const int N, const double alpha, const double* A, const double* x, - const double beta, double* y) { + const int N, const double alpha, const double* A, + const double* x, const double beta, double* y) { + CHECK_GE(M, 0); + CHECK_GE(N, 0); + CHECK(A); + CHECK(x); + CHECK(y); cublasOperation_t cuTransA = (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBLAS_CHECK(cublasDgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha, - A, N, x, 1, &beta, y, 1)); + CUBLAS_CHECK( + cublasDgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha, A, N, x, 1, &beta, y, 1)); } -template <> +template<> void caffe_axpy(const int N, const float alpha, const float* X, - float* Y) { cblas_saxpy(N, alpha, X, 1, Y, 1); } + float* Y) { + CHECK_GE(N, 0); + CHECK(X); + CHECK(Y); + map_vector_float_t(Y, N).noalias() += alpha * const_map_vector_float_t(X, N); +} -template <> +template<> void caffe_axpy(const int N, const double alpha, const double* X, - double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); } - + double* Y) { + CHECK_GE(N, 0); + CHECK(X); + CHECK(Y); + map_vector_double_t(Y, N).noalias() += alpha + * const_map_vector_double_t(X, N); +} -template <> +template<> void caffe_gpu_axpy(const int N, const float alpha, const float* X, - float* Y) { + float* Y) { CUBLAS_CHECK(cublasSaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1)); } -template <> +template<> void caffe_gpu_axpy(const int N, const double alpha, const double* X, - double* Y) { + double* Y) { CUBLAS_CHECK(cublasDaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1)); } -template <> -void caffe_axpby(const int N, const float alpha, const float* X, - const float beta, float* Y) { - cblas_saxpby(N, alpha, X, 1, beta, Y, 1); -} - -template <> -void caffe_axpby(const int N, const double alpha, const double* X, - const double beta, double* Y) { - cblas_daxpby(N, alpha, X, 1, beta, Y, 1); -} - -template <> +template<> void caffe_copy(const int N, const float* X, float* Y) { - cblas_scopy(N, X, 1, Y, 1); + CHECK_GE(N, 0); + CHECK(X); + CHECK(Y); + map_vector_float_t(Y, N).noalias() = const_map_vector_float_t(X, N); } -template <> +template<> void caffe_copy(const int N, const double* X, double* Y) { - cblas_dcopy(N, X, 1, Y, 1); + CHECK_GE(N, 0); + CHECK(X); + CHECK(Y); + map_vector_double_t(Y, N).noalias() = const_map_vector_double_t(X, N); } -template <> +template<> void caffe_gpu_copy(const int N, const float* X, float* Y) { CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), N, X, 1, Y, 1)); } -template <> +template<> void caffe_gpu_copy(const int N, const double* X, double* Y) { CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), N, X, 1, Y, 1)); } -template <> +template<> void caffe_scal(const int N, const float alpha, float *X) { - cblas_sscal(N, alpha, X, 1); + CHECK_GE(N, 0); + CHECK(X); + map_vector_float_t(X, N) *= alpha; } -template <> +template<> void caffe_scal(const int N, const double alpha, double *X) { - cblas_dscal(N, alpha, X, 1); + CHECK_GE(N, 0); + CHECK(X); + map_vector_double_t(X, N) *= alpha; } -template <> +template<> void caffe_gpu_scal(const int N, const float alpha, float *X) { CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X, 1)); } -template <> +template<> void caffe_gpu_scal(const int N, const double alpha, double *X) { CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1)); } -template <> +template<> void caffe_gpu_axpby(const int N, const float alpha, const float* X, - const float beta, float* Y) { + const float beta, float* Y) { caffe_gpu_scal(N, beta, Y); caffe_gpu_axpy(N, alpha, X, Y); } -template <> +template<> void caffe_gpu_axpby(const int N, const double alpha, const double* X, - const double beta, double* Y) { + const double beta, double* Y) { caffe_gpu_scal(N, beta, Y); caffe_gpu_axpy(N, alpha, X, Y); } -template <> -void caffe_sqr(const int n, const float* a, float* y) { - vsSqr(n, a, y); +template<> +void caffe_axpby(const int N, const float alpha, const float* X, + const float beta, float* Y) { + // y := a*x + b*y + //cblas_saxpby(N, alpha, X, 1, beta, Y, 1); + CHECK_GE(N, 0); + CHECK(X); + CHECK(Y); + map_vector_float_t y_map(Y, N); + // Eigen produces optimized code using lazy evaluation + // http://eigen.tuxfamily.org/dox/TopicLazyEvaluation.html + y_map = const_map_vector_float_t(X, N) * alpha + y_map * beta; } -template <> -void caffe_sqr(const int n, const double* a, double* y) { - vdSqr(n, a, y); +template<> +void caffe_axpby(const int N, const double alpha, const double* X, + const double beta, double* Y) { + // y := a*x + b*y + //cblas_daxpby(N, alpha, X, 1, beta, Y, 1); + CHECK_GE(N, 0); + CHECK(X); + CHECK(Y); + map_vector_double_t y_map(Y, N); + y_map = const_map_vector_double_t(X, N) * alpha + y_map * beta; } -template <> -void caffe_add(const int n, const float* a, const float* b, - float* y) { vsAdd(n, a, b, y); } +template<> +void caffe_add(const int n, const float* a, const float* b, float* y) { + //vsAdd(n, a, b, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(b); + CHECK(y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n) + + const_map_vector_float_t(b, n); +} -template <> +template<> void caffe_add(const int n, const double* a, const double* b, - double* y) { vdAdd(n, a, b, y); } + double* y) { + //vdAdd(n, a, b, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(b); + CHECK(y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n) + + const_map_vector_double_t(b, n); +} -template <> -void caffe_sub(const int n, const float* a, const float* b, - float* y) { vsSub(n, a, b, y); } +template<> +void caffe_sub(const int n, const float* a, const float* b, float* y) { + //vsSub(n, a, b, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(b); + CHECK(y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n) + - const_map_vector_float_t(b, n); +} -template <> +template<> void caffe_sub(const int n, const double* a, const double* b, - double* y) { vdSub(n, a, b, y); } + double* y) { + //vdSub(n, a, b, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(b); + CHECK(y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n) + - const_map_vector_double_t(b, n); +} -template <> -void caffe_mul(const int n, const float* a, const float* b, - float* y) { vsMul(n, a, b, y); } +template<> +void caffe_mul(const int n, const float* a, const float* b, float* y) { + //vsMul(n, a, b, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(b); + CHECK(y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() + * const_map_vector_float_t(b, n).array(); +} -template <> +template<> void caffe_mul(const int n, const double* a, const double* b, - double* y) { vdMul(n, a, b, y); } + double* y) { + //vdMul(n, a, b, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(b); + CHECK(y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() + * const_map_vector_double_t(b, n).array(); +} -template <> -void caffe_div(const int n, const float* a, const float* b, - float* y) { vsDiv(n, a, b, y); } +template<> +void caffe_div(const int n, const float* a, const float* b, float* y) { + //vsDiv(n, a, b, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(b); + CHECK(y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() + / const_map_vector_float_t(b, n).array(); +} -template <> +template<> void caffe_div(const int n, const double* a, const double* b, - double* y) { vdDiv(n, a, b, y); } + double* y) { + //vdDiv(n, a, b, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(b); + CHECK(y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() + / const_map_vector_double_t(b, n).array(); +} -template <> -void caffe_powx(const int n, const float* a, const float b, - float* y) { vsPowx(n, a, b, y); } +template<> +void caffe_powx(const int n, const float* a, const float b, float* y) { + //vsPowx(n, a, b, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().pow(b); +} -template <> +template<> void caffe_powx(const int n, const double* a, const double b, - double* y) { vdPowx(n, a, b, y); } + double* y) { + //vdPowx(n, a, b, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().pow(b); +} -template <> -void caffe_vRngUniform(const int n, float* r, - const float a, const float b) { - VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(), - n, r, a, b)); +template<> +void caffe_sqr(const int n, const float* a, float* y) { + // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-F003F826-81BF-42EC-AE51-2EF624893133.htm + // v?Sqr Performs element by element squaring of the vector. + //vsSqr(n, a, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(y); + caffe_powx(n, a, 2, y); + // TODO: which is faster? +// map_vector_float_t(y, n) = const_map_vector_float_t(a, n) * +// const_map_vector_float_t(a, n); } -template <> -void caffe_vRngUniform(const int n, double* r, - const double a, const double b) { - VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(), - n, r, a, b)); +template<> +void caffe_sqr(const int n, const double* a, double* y) { + //vdSqr(n, a, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(y); + caffe_powx(n, a, 2, y); } -template <> -void caffe_vRngGaussian(const int n, float* r, const float a, - const float sigma) { - VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER, - Caffe::vsl_stream(), n, r, a, sigma)); +template<> +void caffe_exp(const int n, const float* a, float* y) { + //vsExp(n, a, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().exp(); } +template<> +void caffe_exp(const int n, const double* a, double* y) { + //vdExp(n, a, y); + CHECK_GE(n, 0); + CHECK(a); + CHECK(y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().exp(); +} + +template +Dtype caffe_nextafter(const Dtype b) { + return boost::math::nextafter(b, std::numeric_limits::max()); +} + +template +float caffe_nextafter(const float b); + +template +double caffe_nextafter(const double b); + +template +void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b) { + CHECK_GE(n, 0); + CHECK(r); + CHECK_LE(a, b); + //VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(), + // n, r, a, b)); + + // FIXME check if boundaries are handled in the same way ? + // Fixed by caffe_nextafter + boost::uniform_real random_distribution(a, caffe_nextafter(b)); + Caffe::random_generator_t &generator = Caffe::vsl_stream(); + boost::variate_generator > variate_generator( + generator, random_distribution); + + for (int i = 0; i < n; ++i) { + r[i] = variate_generator(); + } +} + +template +void caffe_vRngUniform(const int n, float* r, const float a, + const float b); +template +void caffe_vRngUniform(const int n, double* r, const double a, + const double b); + +template +void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a, + const Dtype sigma) { + CHECK_GE(n, 0); + CHECK(r); + CHECK_GT(sigma, 0); + //VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER, +// Caffe::vsl_stream(), n, r, a, sigma)); + + // FIXME check if parameters are handled in the same way ? + // http://www.boost.org/doc/libs/1_55_0/doc/html/boost/random/normal_distribution.html + // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-63196F25-5013-4038-8BCD-2613C4EF3DE4.htm + // The above two documents show that the probability density functions are different. + // But the unit tests still pass. Maybe their codes are the same or + // the tests are irrelevant to the random numbers. + boost::normal_distribution random_distribution(a, sigma); + Caffe::random_generator_t &generator = Caffe::vsl_stream(); + boost::variate_generator > variate_generator( + generator, random_distribution); + + for (int i = 0; i < n; ++i) { + r[i] = variate_generator(); + } +} + +template +void caffe_vRngGaussian(const int n, float* r, const float a, + const float sigma); -template <> +template void caffe_vRngGaussian(const int n, double* r, const double a, - const double sigma) { - VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER, - Caffe::vsl_stream(), n, r, a, sigma)); -} + const double sigma); -template <> -void caffe_exp(const int n, const float* a, float* y) { - vsExp(n, a, y); -} +template +void caffe_vRngBernoulli(const int n, Dtype* r, const double p) { + CHECK_GE(n, 0); + CHECK(r); + CHECK_GE(p, 0); + CHECK_LE(p, 1); + boost::bernoulli_distribution random_distribution(p); + Caffe::random_generator_t &generator = Caffe::vsl_stream(); + boost::variate_generator > variate_generator( + generator, random_distribution); -template <> -void caffe_exp(const int n, const double* a, double* y) { - vdExp(n, a, y); + for (int i = 0; i < n; ++i) { + r[i] = variate_generator(); + } } -template <> +template +void caffe_vRngBernoulli(const int n, int* r, const double p); + +template<> float caffe_cpu_dot(const int n, const float* x, const float* y) { - return cblas_sdot(n, x, 1, y, 1); + CHECK_GE(n, 0); + CHECK(x); + CHECK(y); + return const_map_vector_float_t(x, n).dot(const_map_vector_float_t(y, n)); } -template <> +template<> double caffe_cpu_dot(const int n, const double* x, const double* y) { - return cblas_ddot(n, x, 1, y, 1); + CHECK_GE(n, 0); + CHECK(x); + CHECK(y); + return const_map_vector_double_t(x, n).dot(const_map_vector_double_t(y, n)); } -template <> +template<> void caffe_gpu_dot(const int n, const float* x, const float* y, - float* out) { + float* out) { CUBLAS_CHECK(cublasSdot(Caffe::cublas_handle(), n, x, 1, y, 1, out)); } -template <> +template<> void caffe_gpu_dot(const int n, const double* x, const double* y, - double * out) { + double * out) { CUBLAS_CHECK(cublasDdot(Caffe::cublas_handle(), n, x, 1, y, 1, out)); }