From 89c3791d064034856cbc0c0fd26ffeac9a1298be Mon Sep 17 00:00:00 2001 From: bhack Date: Tue, 2 Jun 2015 01:31:57 +0200 Subject: [PATCH] First attemp to remove boost dependecies. Still alive only for python wrapper. --- CMakeLists.txt | 2 +- Makefile | 16 ++++++---- cmake/Dependencies.cmake | 6 ++-- examples/cifar10/convert_cifar_data.cpp | 10 +++--- .../cpp_classification/classification.cpp | 7 +++- include/caffe/blob.hpp | 1 - include/caffe/common.hpp | 15 ++++----- include/caffe/data_layers.hpp | 1 - include/caffe/internal_thread.hpp | 11 ++----- include/caffe/layer.hpp | 1 - include/caffe/util/benchmark.hpp | 9 ++++-- include/caffe/util/io.hpp | 17 ++++++++++ include/caffe/util/rng.hpp | 7 ++-- include/caffe/vision_layers.hpp | 1 + python/caffe/_caffe.cpp | 22 +++++++------ scripts/travis/travis_install.sh | 16 +++++++--- src/caffe/internal_thread.cpp | 4 +-- src/caffe/layers/hdf5_data_layer.cpp | 1 + src/caffe/layers/hdf5_data_layer.cu | 1 + src/caffe/test/test_data_layer.cpp | 12 +++---- src/caffe/test/test_db.cpp | 26 +++++++-------- .../test/test_softmax_with_loss_layer.cpp | 6 ++-- src/caffe/util/benchmark.cpp | 24 +++++++------- src/caffe/util/math_functions.cpp | 32 +++++++++---------- tools/caffe.cpp | 5 +-- tools/compute_image_mean.cpp | 9 +++--- tools/convert_imageset.cpp | 8 ++--- tools/extract_features.cpp | 8 ++--- 28 files changed, 153 insertions(+), 125 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f22aa5763a3..51b90dcdac2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,7 +29,7 @@ include(cmake/Dependencies.cmake) # ---[ Flags if(UNIX OR APPLE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -Wall") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -Wall -std=c++11") endif() if(USE_libstdcpp) diff --git a/Makefile b/Makefile index e4e66dfd138..ef7f8767487 100644 --- a/Makefile +++ b/Makefile @@ -170,7 +170,7 @@ ifneq ($(CPU_ONLY), 1) LIBRARIES := cudart cublas curand endif LIBRARIES += glog gflags protobuf leveldb snappy \ - lmdb boost_system hdf5_hl hdf5 m \ + lmdb hdf5_hl hdf5 m \ opencv_core opencv_highgui opencv_imgproc PYTHON_LIBRARIES := boost_python python2.7 WARNINGS := -Wall -Wno-sign-compare @@ -233,7 +233,7 @@ ifeq ($(LINUX), 1) endif # boost::thread is reasonably called boost_thread (compare OS X) # We will also explicitly add stdc++ to the link target. - LIBRARIES += boost_thread stdc++ + LIBRARIES += stdc++ endif # OS X: @@ -253,7 +253,7 @@ ifeq ($(OSX), 1) # gtest needs to use its own tuple to not conflict with clang COMMON_FLAGS += -DGTEST_USE_OWN_TR1_TUPLE=1 # boost::thread is called boost_thread-mt to mark multithreading on OS X - LIBRARIES += boost_thread-mt + #LIBRARIES += boost_thread-mt # we need to explicitly ask for the rpath to be obeyed DYNAMIC_FLAGS := -install_name @rpath/libcaffe.so ORIGIN := @loader_path @@ -344,16 +344,18 @@ LIBRARY_DIRS += $(BLAS_LIB) LIBRARY_DIRS += $(LIB_BUILD_DIR) +INCLUDE_DIRS += /usr/include/hdf5/serial + # Automatic dependency generation (nvcc is handled separately) CXXFLAGS += -MMD -MP # Complete build flags. COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) -CXXFLAGS += -pthread -fPIC $(COMMON_FLAGS) $(WARNINGS) -NVCCFLAGS += -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS) +CXXFLAGS += -pthread -fPIC -std=c++11 $(COMMON_FLAGS) $(WARNINGS) +NVCCFLAGS += -ccbin=$(CXX) -Xcompiler -fPIC -std=c++11 $(COMMON_FLAGS) # mex may invoke an older gcc that is too liberal with -Wuninitalized MATLAB_CXXFLAGS := $(CXXFLAGS) -Wno-uninitialized -LINKFLAGS += -pthread -fPIC $(COMMON_FLAGS) $(WARNINGS) +LINKFLAGS += -pthread -fPIC -std=c++11 $(COMMON_FLAGS) $(WARNINGS) USE_PKG_CONFIG ?= 0 ifeq ($(USE_PKG_CONFIG), 1) @@ -441,7 +443,7 @@ py: $(PY$(PROJECT)_SO) $(PROTO_GEN_PY) $(PY$(PROJECT)_SO): $(PY$(PROJECT)_SRC) $(PY$(PROJECT)_HXX) | $(DYNAMIC_NAME) @ echo CXX/LD -o $@ $< - $(Q)$(CXX) -shared -o $@ $(PY$(PROJECT)_SRC) \ + $(Q)$(CXX) -std=c++11 -shared -o $@ $(PY$(PROJECT)_SRC) \ -o $@ $(LINKFLAGS) -l$(PROJECT) $(PYTHON_LDFLAGS) \ -Wl,-rpath,$(ORIGIN)/../../build/lib diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 7cae5c9da25..4d7caa3389c 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -2,9 +2,9 @@ set(Caffe_LINKER_LIBS "") # ---[ Boost -find_package(Boost 1.46 REQUIRED COMPONENTS system thread) -include_directories(SYSTEM ${Boost_INCLUDE_DIR}) -list(APPEND Caffe_LINKER_LIBS ${Boost_LIBRARIES}) +#find_package(Boost 1.46 REQUIRED COMPONENTS system thread) +#include_directories(SYSTEM ${Boost_INCLUDE_DIR}) +#list(APPEND Caffe_LINKER_LIBS ${Boost_LIBRARIES}) # ---[ Threads find_package(Threads REQUIRED) diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp index f4c42e4d2e7..645a7dc9c1e 100644 --- a/examples/cifar10/convert_cifar_data.cpp +++ b/examples/cifar10/convert_cifar_data.cpp @@ -7,9 +7,9 @@ // http://www.cs.toronto.edu/~kriz/cifar.html #include // NOLINT(readability/streams) +#include #include -#include "boost/scoped_ptr.hpp" #include "glog/logging.h" #include "google/protobuf/text_format.h" #include "stdint.h" @@ -18,7 +18,7 @@ #include "caffe/util/db.hpp" using caffe::Datum; -using boost::scoped_ptr; +using std::unique_ptr; using std::string; namespace db = caffe::db; @@ -37,9 +37,9 @@ void read_image(std::ifstream* file, int* label, char* buffer) { void convert_dataset(const string& input_folder, const string& output_folder, const string& db_type) { - scoped_ptr train_db(db::GetDB(db_type)); + unique_ptr train_db(db::GetDB(db_type)); train_db->Open(output_folder + "/cifar10_train_" + db_type, db::NEW); - scoped_ptr txn(train_db->NewTransaction()); + unique_ptr txn(train_db->NewTransaction()); // Data buffer int label; char str_buffer[kCIFARImageNBytes]; @@ -71,7 +71,7 @@ void convert_dataset(const string& input_folder, const string& output_folder, train_db->Close(); LOG(INFO) << "Writing Testing data"; - scoped_ptr test_db(db::GetDB(db_type)); + unique_ptr test_db(db::GetDB(db_type)); test_db->Open(output_folder + "/cifar10_test_" + db_type, db::NEW); txn.reset(test_db->NewTransaction()); // Open files diff --git a/examples/cpp_classification/classification.cpp b/examples/cpp_classification/classification.cpp index 1c6371e382b..959d4673b82 100644 --- a/examples/cpp_classification/classification.cpp +++ b/examples/cpp_classification/classification.cpp @@ -1,13 +1,18 @@ -#include #include #include #include + +#include + +#include #include #include #include #include #include + + using namespace caffe; // NOLINT(build/namespaces) using std::string; diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp index 472cc1841f7..4edf956cb95 100644 --- a/include/caffe/blob.hpp +++ b/include/caffe/blob.hpp @@ -1,7 +1,6 @@ #ifndef CAFFE_BLOB_HPP_ #define CAFFE_BLOB_HPP_ -#include #include #include diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 5f86bc2625b..397113345da 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -1,7 +1,6 @@ #ifndef CAFFE_COMMON_HPP_ #define CAFFE_COMMON_HPP_ -#include #include #include @@ -10,6 +9,7 @@ #include // NOLINT(readability/streams) #include // NOLINT(readability/streams) #include +#include #include #include #include @@ -70,9 +70,8 @@ namespace cv { class Mat; } namespace caffe { -// We will use the boost shared_ptr instead of the new C++11 one mainly -// because cuda does not work (at least now) well with C++11 features. -using boost::shared_ptr; + +using std::shared_ptr; // Common functions and classes from std that caffe often uses. using std::fstream; @@ -106,7 +105,7 @@ class Caffe { } enum Brew { CPU, GPU }; - // This random number generator facade hides boost and CUDA rng + // This random number generator facade hides std and CUDA rng // implementation from one another (for cross-platform compatibility). class RNG { public: @@ -117,10 +116,10 @@ class Caffe { void* generator(); private: class Generator; - shared_ptr generator_; + std::shared_ptr generator_; }; - // Getters for boost rng, curand, and cublas handles + // Getters for std rng, curand, and cublas handles inline static RNG& rng_stream() { if (!Get().random_generator_) { Get().random_generator_.reset(new RNG()); @@ -142,7 +141,7 @@ class Caffe { // freed in a non-pinned way, which may cause problems - I haven't verified // it personally but better to note it here in the header file. inline static void set_mode(Brew mode) { Get().mode_ = mode; } - // Sets the random seed of both boost and curand + // Sets the random seed of both std and curand static void set_random_seed(const unsigned int seed); // Sets the device. Since we have cublas and curand stuff, set device also // requires us to reset those values. diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index 3958cb7ecb0..119ca904a71 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -5,7 +5,6 @@ #include #include -#include "boost/scoped_ptr.hpp" #include "hdf5.h" #include "caffe/blob.hpp" diff --git a/include/caffe/internal_thread.hpp b/include/caffe/internal_thread.hpp index 815ca54605e..46cc583ea2b 100644 --- a/include/caffe/internal_thread.hpp +++ b/include/caffe/internal_thread.hpp @@ -1,18 +1,13 @@ #ifndef CAFFE_INTERNAL_THREAD_HPP_ #define CAFFE_INTERNAL_THREAD_HPP_ +#include #include "caffe/common.hpp" -/** - Forward declare boost::thread instead of including boost/thread.hpp - to avoid a boost/NVCC issues (#1009, #1010) on OSX. - */ -namespace boost { class thread; } - namespace caffe { /** - * Virtual class encapsulate boost::thread for use in base class + * Virtual class encapsulate std::thread for use in base class * The child class will acquire the ability to run a single thread, * by reimplementing the virutal function InternalThreadEntry. */ @@ -34,7 +29,7 @@ class InternalThread { with the code you want your thread to run. */ virtual void InternalThreadEntry() {} - shared_ptr thread_; + shared_ptr thread_; }; } // namespace caffe diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp index 8f924a75755..e0170722199 100644 --- a/include/caffe/layer.hpp +++ b/include/caffe/layer.hpp @@ -1,7 +1,6 @@ #ifndef CAFFE_LAYER_H_ #define CAFFE_LAYER_H_ -#include #include #include diff --git a/include/caffe/util/benchmark.hpp b/include/caffe/util/benchmark.hpp index d63582776ee..79012339ab4 100644 --- a/include/caffe/util/benchmark.hpp +++ b/include/caffe/util/benchmark.hpp @@ -1,7 +1,7 @@ #ifndef CAFFE_UTIL_BENCHMARK_H_ #define CAFFE_UTIL_BENCHMARK_H_ -#include +#include #include "caffe/util/device_alternate.hpp" @@ -31,8 +31,11 @@ class Timer { cudaEvent_t start_gpu_; cudaEvent_t stop_gpu_; #endif - boost::posix_time::ptime start_cpu_; - boost::posix_time::ptime stop_cpu_; + typedef std::chrono::high_resolution_clock clock; + typedef std::chrono::microseconds microseconds; + typedef std::chrono::milliseconds milliseconds; + clock::time_point start_cpu_; + clock::time_point stop_cpu_; float elapsed_milliseconds_; float elapsed_microseconds_; }; diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp index 3a62c3c9fa9..259f97aa3b8 100644 --- a/include/caffe/util/io.hpp +++ b/include/caffe/util/io.hpp @@ -3,6 +3,8 @@ #include #include +#include + #include "google/protobuf/message.h" #include "hdf5.h" @@ -140,6 +142,21 @@ cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color); void CVMatToDatum(const cv::Mat& cv_img, Datum* datum); +inline void string_split(vector* tokens, const string& str, + const string& delimiters) { + // Skip delimiters at beginning. + string::size_type lastPos = str.find_first_not_of(delimiters, 0); + // Find first "non-delimiter". + string::size_type pos = str.find_first_of(delimiters, lastPos); + while (string::npos != pos || string::npos != lastPos) { + // Found a token, add it to the vector. + tokens->push_back(str.substr(lastPos, pos - lastPos)); + // Skip delimiters. Note the "not_of" + lastPos = str.find_first_not_of(delimiters, pos); + // Find next "non-delimiter" + pos = str.find_first_of(delimiters, lastPos); + } +} template void hdf5_load_nd_dataset_helper( hid_t file_id, const char* dataset_name_, int min_dim, int max_dim, diff --git a/include/caffe/util/rng.hpp b/include/caffe/util/rng.hpp index 8f1cf0d17c2..6a8fa3b7bbf 100644 --- a/include/caffe/util/rng.hpp +++ b/include/caffe/util/rng.hpp @@ -4,14 +4,13 @@ #include #include -#include "boost/random/mersenne_twister.hpp" -#include "boost/random/uniform_int.hpp" +#include #include "caffe/common.hpp" namespace caffe { -typedef boost::mt19937 rng_t; +typedef std::mt19937 rng_t; inline rng_t* caffe_rng() { return static_cast(Caffe::rng_stream().generator()); @@ -23,7 +22,7 @@ inline void shuffle(RandomAccessIterator begin, RandomAccessIterator end, RandomGenerator* gen) { typedef typename std::iterator_traits::difference_type difference_type; - typedef typename boost::uniform_int dist_type; + typedef typename std::uniform_int_distribution dist_type; difference_type length = std::distance(begin, end); if (length <= 0) return; diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index a6bd86a93f5..c4e37bb4e28 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -1,6 +1,7 @@ #ifndef CAFFE_VISION_LAYERS_HPP_ #define CAFFE_VISION_LAYERS_HPP_ +#include #include #include #include diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index dff7f627016..4196dd7c9d2 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -3,13 +3,13 @@ // Produce deprecation warnings (needs to come before arrayobject.h inclusion). #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include #include #include #include #include // these need to be included after boost on OS X +#include #include // NOLINT(build/include_order) #include // NOLINT(build/include_order) #include // NOLINT @@ -25,6 +25,7 @@ #endif namespace bp = boost::python; +using std::shared_ptr; namespace caffe { @@ -103,7 +104,8 @@ void Net_SetInputArrays(Net* net, bp::object data_obj, bp::object labels_obj) { // check that this network has an input MemoryDataLayer shared_ptr > md_layer = - boost::dynamic_pointer_cast >(net->layers()[0]); + std::dynamic_pointer_cast > + (net->layers()[0]); if (!md_layer) { throw std::runtime_error("set_input_arrays may only be called if the" " first layer is a MemoryDataLayer"); @@ -159,7 +161,7 @@ struct NdarrayCallPolicies : public bp::default_call_policies { PyObject* postcall(PyObject* pyargs, PyObject* result) { bp::object pyblob = bp::extract(pyargs)()[0]; shared_ptr > blob = - bp::extract > >(pyblob); + bp::extract > >(pyblob); // Free the temporary pointer-holding array, and construct a new one with // the shape information from the blob. void* data = PyArray_DATA(reinterpret_cast(result)); @@ -200,8 +202,8 @@ BOOST_PYTHON_MODULE(_caffe) { bp::def("set_mode_gpu", &set_mode_gpu); bp::def("set_device", &Caffe::SetDevice); - bp::class_, shared_ptr >, boost::noncopyable >("Net", - bp::no_init) + bp::class_, + shared_ptr >, boost::noncopyable >("Net", bp::no_init) .def("__init__", bp::make_constructor(&Net_Init)) .def("__init__", bp::make_constructor(&Net_Init_Load)) .def("_forward", &Net::ForwardFromTo) @@ -242,8 +244,8 @@ BOOST_PYTHON_MODULE(_caffe) { .add_property("diff", bp::make_function(&Blob::mutable_cpu_diff, NdarrayCallPolicies())); - bp::class_, shared_ptr >, - boost::noncopyable>("Layer", bp::init()) + bp::class_, shared_ptr >, boost::noncopyable> + ("Layer", bp::init()) .add_property("blobs", bp::make_function(&Layer::blobs, bp::return_internal_reference<>())) .def("setup", &Layer::LayerSetUp) @@ -253,8 +255,8 @@ BOOST_PYTHON_MODULE(_caffe) { bp::class_("LayerParameter", bp::no_init); - bp::class_, shared_ptr >, boost::noncopyable>( - "Solver", bp::no_init) + bp::class_, shared_ptr >, boost::noncopyable> + ("Solver", bp::no_init) .add_property("net", &Solver::net) .add_property("test_nets", bp::make_function(&Solver::test_nets, bp::return_internal_reference<>())) @@ -271,7 +273,7 @@ BOOST_PYTHON_MODULE(_caffe) { shared_ptr >, boost::noncopyable>( "NesterovSolver", bp::init()); bp::class_, bp::bases >, - shared_ptr >, boost::noncopyable>( + shared_ptr >, boost::noncopyable>( "AdaGradSolver", bp::init()); bp::def("get_solver", &GetSolverFromFile, diff --git a/scripts/travis/travis_install.sh b/scripts/travis/travis_install.sh index b6e6f6ce821..1cceec82485 100755 --- a/scripts/travis/travis_install.sh +++ b/scripts/travis/travis_install.sh @@ -9,17 +9,23 @@ MAKE="make --jobs=$NUM_THREADS" # This ppa is for gflags and glog add-apt-repository -y ppa:tuleu/precise-backports +# This ppa is for boost 1.54 +add-apt-repository -y ppa:boost-latest/ppa +# This ppa is for g++ 4.8 +add-apt-repository -y ppa:ubuntu-toolchain-r/test + apt-get -y update apt-get install \ - wget git curl \ + g++-4.8 wget git curl \ python-dev python-numpy \ libleveldb-dev libsnappy-dev libopencv-dev \ - libboost-dev libboost-system-dev libboost-python-dev libboost-thread-dev \ + libboost-python1.54-dev \ libprotobuf-dev protobuf-compiler \ libatlas-dev libatlas-base-dev \ libhdf5-serial-dev libgflags-dev libgoogle-glog-dev \ bc +update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-4.8 90 # Add a special apt-repository to install CMake 2.8.9 for CMake Caffe build, # if needed. By default, Aptitude in Ubuntu 12.04 installs CMake 2.8.7, but # Caffe requires a minimum CMake version of 2.8.8. @@ -31,7 +37,7 @@ fi # Install CUDA, if needed if $WITH_CUDA; then - CUDA_URL=http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1204/x86_64/cuda-repo-ubuntu1204_6.5-14_amd64.deb + CUDA_URL=http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1204/x86_64/cuda-repo-ubuntu1204_7.0-28_amd64.deb CUDA_FILE=/tmp/cuda_install.deb curl $CUDA_URL -o $CUDA_FILE dpkg -i $CUDA_FILE @@ -39,11 +45,11 @@ if $WITH_CUDA; then apt-get -y update # Install the minimal CUDA subpackages required to test Caffe build. # For a full CUDA installation, add 'cuda' to the list of packages. - apt-get -y install cuda-core-6-5 cuda-cublas-6-5 cuda-cublas-dev-6-5 cuda-cudart-6-5 cuda-cudart-dev-6-5 cuda-curand-6-5 cuda-curand-dev-6-5 + apt-get -y install cuda-core-7-0 cuda-cublas-7-0 cuda-cublas-dev-7-0 cuda-cudart-7-0 cuda-cudart-dev-7-0 cuda-curand-7-0 cuda-curand-dev-7-0 # Create CUDA symlink at /usr/local/cuda # (This would normally be created by the CUDA installer, but we create it # manually since we did a partial installation.) - ln -s /usr/local/cuda-6.5 /usr/local/cuda + ln -s /usr/local/cuda-7.0 /usr/local/cuda fi # Install LMDB diff --git a/src/caffe/internal_thread.cpp b/src/caffe/internal_thread.cpp index c2d19d433b4..34f51d14e43 100644 --- a/src/caffe/internal_thread.cpp +++ b/src/caffe/internal_thread.cpp @@ -1,4 +1,4 @@ -#include +#include #include "caffe/internal_thread.hpp" namespace caffe { @@ -18,7 +18,7 @@ bool InternalThread::StartInternalThread() { } try { thread_.reset( - new boost::thread(&InternalThread::InternalThreadEntry, this)); + new std::thread(&InternalThread::InternalThreadEntry, this)); } catch (...) { return false; } diff --git a/src/caffe/layers/hdf5_data_layer.cpp b/src/caffe/layers/hdf5_data_layer.cpp index 8a782f7e524..40c6e28efd9 100644 --- a/src/caffe/layers/hdf5_data_layer.cpp +++ b/src/caffe/layers/hdf5_data_layer.cpp @@ -6,6 +6,7 @@ :: don't forget to update hdf5_daa_layer.cu accordingly - add ability to shuffle filenames if flag is set */ +#include #include // NOLINT(readability/streams) #include #include diff --git a/src/caffe/layers/hdf5_data_layer.cu b/src/caffe/layers/hdf5_data_layer.cu index 5e3e4ced141..d496592b02c 100644 --- a/src/caffe/layers/hdf5_data_layer.cu +++ b/src/caffe/layers/hdf5_data_layer.cu @@ -4,6 +4,7 @@ TODO: */ #include +#include #include #include diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp index afe2a40d227..1eed114139f 100644 --- a/src/caffe/test/test_data_layer.cpp +++ b/src/caffe/test/test_data_layer.cpp @@ -1,7 +1,7 @@ +#include #include #include -#include "boost/scoped_ptr.hpp" #include "gtest/gtest.h" #include "caffe/blob.hpp" @@ -16,7 +16,7 @@ namespace caffe { -using boost::scoped_ptr; +using std::unique_ptr; template class DataLayerTest : public MultiDeviceTest { @@ -42,9 +42,9 @@ class DataLayerTest : public MultiDeviceTest { void Fill(const bool unique_pixels, DataParameter_DB backend) { backend_ = backend; LOG(INFO) << "Using temporary dataset " << *filename_; - scoped_ptr db(db::GetDB(backend)); + unique_ptr db(db::GetDB(backend)); db->Open(*filename_, db::NEW); - scoped_ptr txn(db->NewTransaction()); + unique_ptr txn(db->NewTransaction()); for (int i = 0; i < 5; ++i) { Datum datum; datum.set_label(i); @@ -108,9 +108,9 @@ class DataLayerTest : public MultiDeviceTest { const int num_inputs = 5; // Save data of varying shapes. LOG(INFO) << "Using temporary dataset " << *filename_; - scoped_ptr db(db::GetDB(backend)); + unique_ptr db(db::GetDB(backend)); db->Open(*filename_, db::NEW); - scoped_ptr txn(db->NewTransaction()); + unique_ptr txn(db->NewTransaction()); for (int i = 0; i < num_inputs; ++i) { Datum datum; datum.set_label(i); diff --git a/src/caffe/test/test_db.cpp b/src/caffe/test/test_db.cpp index 5b2ac230a0b..26dd35c0e1d 100644 --- a/src/caffe/test/test_db.cpp +++ b/src/caffe/test/test_db.cpp @@ -1,6 +1,6 @@ +#include #include -#include "boost/scoped_ptr.hpp" #include "gtest/gtest.h" #include "caffe/common.hpp" @@ -12,7 +12,7 @@ namespace caffe { -using boost::scoped_ptr; +using std::unique_ptr; template class DBTest : public ::testing::Test { @@ -26,9 +26,9 @@ class DBTest : public ::testing::Test { source_ += "/db"; string keys[] = {"cat.jpg", "fish-bike.jpg"}; LOG(INFO) << "Using temporary db " << source_; - scoped_ptr db(db::GetDB(TypeParam::backend)); + unique_ptr db(db::GetDB(TypeParam::backend)); db->Open(this->source_, db::NEW); - scoped_ptr txn(db->NewTransaction()); + unique_ptr txn(db->NewTransaction()); for (int i = 0; i < 2; ++i) { Datum datum; ReadImageToDatum(root_images_ + keys[i], i, &datum); @@ -62,13 +62,13 @@ typedef ::testing::Types TestTypes; TYPED_TEST_CASE(DBTest, TestTypes); TYPED_TEST(DBTest, TestGetDB) { - scoped_ptr db(db::GetDB(TypeParam::backend)); + unique_ptr db(db::GetDB(TypeParam::backend)); } TYPED_TEST(DBTest, TestNext) { - scoped_ptr db(db::GetDB(TypeParam::backend)); + unique_ptr db(db::GetDB(TypeParam::backend)); db->Open(this->source_, db::READ); - scoped_ptr cursor(db->NewCursor()); + unique_ptr cursor(db->NewCursor()); EXPECT_TRUE(cursor->valid()); cursor->Next(); EXPECT_TRUE(cursor->valid()); @@ -77,9 +77,9 @@ TYPED_TEST(DBTest, TestNext) { } TYPED_TEST(DBTest, TestSeekToFirst) { - scoped_ptr db(db::GetDB(TypeParam::backend)); + unique_ptr db(db::GetDB(TypeParam::backend)); db->Open(this->source_, db::READ); - scoped_ptr cursor(db->NewCursor()); + unique_ptr cursor(db->NewCursor()); cursor->Next(); cursor->SeekToFirst(); EXPECT_TRUE(cursor->valid()); @@ -93,9 +93,9 @@ TYPED_TEST(DBTest, TestSeekToFirst) { } TYPED_TEST(DBTest, TestKeyValue) { - scoped_ptr db(db::GetDB(TypeParam::backend)); + unique_ptr db(db::GetDB(TypeParam::backend)); db->Open(this->source_, db::READ); - scoped_ptr cursor(db->NewCursor()); + unique_ptr cursor(db->NewCursor()); EXPECT_TRUE(cursor->valid()); string key = cursor->key(); Datum datum; @@ -117,9 +117,9 @@ TYPED_TEST(DBTest, TestKeyValue) { } TYPED_TEST(DBTest, TestWrite) { - scoped_ptr db(db::GetDB(TypeParam::backend)); + unique_ptr db(db::GetDB(TypeParam::backend)); db->Open(this->source_, db::WRITE); - scoped_ptr txn(db->NewTransaction()); + unique_ptr txn(db->NewTransaction()); Datum datum; ReadFileToDatum(this->root_images_ + "cat.jpg", 0, &datum); string out; diff --git a/src/caffe/test/test_softmax_with_loss_layer.cpp b/src/caffe/test/test_softmax_with_loss_layer.cpp index 1498d5c5ce1..805527943e2 100644 --- a/src/caffe/test/test_softmax_with_loss_layer.cpp +++ b/src/caffe/test/test_softmax_with_loss_layer.cpp @@ -1,9 +1,9 @@ #include #include #include +#include #include -#include "boost/scoped_ptr.hpp" #include "gtest/gtest.h" #include "caffe/blob.hpp" @@ -14,7 +14,7 @@ #include "caffe/test/test_caffe_main.hpp" #include "caffe/test/test_gradient_check_util.hpp" -using boost::scoped_ptr; +using std::unique_ptr; namespace caffe { @@ -68,7 +68,7 @@ TYPED_TEST(SoftmaxWithLossLayerTest, TestForwardIgnoreLabel) { LayerParameter layer_param; layer_param.mutable_loss_param()->set_normalize(false); // First, compute the loss with all labels - scoped_ptr > layer( + unique_ptr > layer( new SoftmaxWithLossLayer(layer_param)); layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); diff --git a/src/caffe/util/benchmark.cpp b/src/caffe/util/benchmark.cpp index 1d269c351c1..0029166a70e 100644 --- a/src/caffe/util/benchmark.cpp +++ b/src/caffe/util/benchmark.cpp @@ -1,5 +1,3 @@ -#include - #include "caffe/common.hpp" #include "caffe/util/benchmark.hpp" @@ -32,7 +30,7 @@ void Timer::Start() { NO_GPU; #endif } else { - start_cpu_ = boost::posix_time::microsec_clock::local_time(); + start_cpu_ = std::chrono::high_resolution_clock::now(); } running_ = true; has_run_at_least_once_ = true; @@ -49,7 +47,7 @@ void Timer::Stop() { NO_GPU; #endif } else { - stop_cpu_ = boost::posix_time::microsec_clock::local_time(); + stop_cpu_ = std::chrono::high_resolution_clock::now(); } running_ = false; } @@ -74,7 +72,8 @@ float Timer::MicroSeconds() { NO_GPU; #endif } else { - elapsed_microseconds_ = (stop_cpu_ - start_cpu_).total_microseconds(); + elapsed_microseconds_ = std::chrono::duration_cast + (stop_cpu_ - start_cpu_).count(); } return elapsed_microseconds_; } @@ -95,7 +94,8 @@ float Timer::MilliSeconds() { NO_GPU; #endif } else { - elapsed_milliseconds_ = (stop_cpu_ - start_cpu_).total_milliseconds(); + elapsed_milliseconds_ = std::chrono::duration_cast + (stop_cpu_ - start_cpu_).count(); } return elapsed_milliseconds_; } @@ -126,7 +126,7 @@ CPUTimer::CPUTimer() { void CPUTimer::Start() { if (!running()) { - this->start_cpu_ = boost::posix_time::microsec_clock::local_time(); + this->start_cpu_ = std::chrono::high_resolution_clock::now(); this->running_ = true; this->has_run_at_least_once_ = true; } @@ -134,7 +134,7 @@ void CPUTimer::Start() { void CPUTimer::Stop() { if (running()) { - this->stop_cpu_ = boost::posix_time::microsec_clock::local_time(); + this->stop_cpu_ = std::chrono::high_resolution_clock::now(); this->running_ = false; } } @@ -147,8 +147,8 @@ float CPUTimer::MilliSeconds() { if (running()) { Stop(); } - this->elapsed_milliseconds_ = (this->stop_cpu_ - - this->start_cpu_).total_milliseconds(); + this->elapsed_milliseconds_ = std::chrono::duration_cast + (stop_cpu_ - start_cpu_).count(); return this->elapsed_milliseconds_; } @@ -160,8 +160,8 @@ float CPUTimer::MicroSeconds() { if (running()) { Stop(); } - this->elapsed_microseconds_ = (this->stop_cpu_ - - this->start_cpu_).total_microseconds(); + this->elapsed_microseconds_ = std::chrono::duration_cast + (stop_cpu_ - start_cpu_).count(); return this->elapsed_microseconds_; } diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 0aab6b17b85..63020f64caf 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -1,5 +1,4 @@ -#include -#include +#include #include @@ -232,8 +231,7 @@ unsigned int caffe_rng_rand() { template Dtype caffe_nextafter(const Dtype b) { - return boost::math::nextafter( - b, std::numeric_limits::max()); + return std::nextafter(b, std::numeric_limits::max()); } template @@ -247,9 +245,10 @@ void caffe_rng_uniform(const int n, const Dtype a, const Dtype b, Dtype* r) { CHECK_GE(n, 0); CHECK(r); CHECK_LE(a, b); - boost::uniform_real random_distribution(a, caffe_nextafter(b)); - boost::variate_generator > - variate_generator(caffe_rng(), random_distribution); + std::uniform_real_distribution<> + random_distribution(a, caffe_nextafter(b)); + std::function + variate_generator = bind(random_distribution, std::ref(*caffe_rng())); for (int i = 0; i < n; ++i) { r[i] = variate_generator(); } @@ -269,9 +268,9 @@ void caffe_rng_gaussian(const int n, const Dtype a, CHECK_GE(n, 0); CHECK(r); CHECK_GT(sigma, 0); - boost::normal_distribution random_distribution(a, sigma); - boost::variate_generator > - variate_generator(caffe_rng(), random_distribution); + std::normal_distribution<> random_distribution(a, sigma); + std::function + variate_generator= bind(random_distribution, std::ref(*caffe_rng())); for (int i = 0; i < n; ++i) { r[i] = variate_generator(); } @@ -291,9 +290,10 @@ void caffe_rng_bernoulli(const int n, const Dtype p, int* r) { CHECK(r); CHECK_GE(p, 0); CHECK_LE(p, 1); - boost::bernoulli_distribution random_distribution(p); - boost::variate_generator > - variate_generator(caffe_rng(), random_distribution); + std::bernoulli_distribution random_distribution(p); + std::function + variate_generator = bind(random_distribution, + std::ref(*caffe_rng())); for (int i = 0; i < n; ++i) { r[i] = variate_generator(); } @@ -311,9 +311,9 @@ void caffe_rng_bernoulli(const int n, const Dtype p, unsigned int* r) { CHECK(r); CHECK_GE(p, 0); CHECK_LE(p, 1); - boost::bernoulli_distribution random_distribution(p); - boost::variate_generator > - variate_generator(caffe_rng(), random_distribution); + std::bernoulli_distribution random_distribution(p); + std::function + variate_generator = bind(random_distribution, std::ref(*caffe_rng())); for (int i = 0; i < n; ++i) { r[i] = static_cast(variate_generator()); } diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 0b7523fccf9..3065a176ff8 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -1,11 +1,12 @@ #include #include +#include #include #include #include -#include "boost/algorithm/string.hpp" + #include "caffe/caffe.hpp" using caffe::Blob; @@ -81,7 +82,7 @@ RegisterBrewFunction(device_query); // test nets. void CopyLayers(caffe::Solver* solver, const std::string& model_list) { std::vector model_names; - boost::split(model_names, model_list, boost::is_any_of(",") ); + caffe::string_split(&model_names, model_list, ","); for (int i = 0; i < model_names.size(); ++i) { LOG(INFO) << "Finetuning from " << model_names[i]; solver->net()->CopyTrainedLayersFrom(model_names[i]); diff --git a/tools/compute_image_mean.cpp b/tools/compute_image_mean.cpp index b1fc7cae38f..fd4a05f196f 100644 --- a/tools/compute_image_mean.cpp +++ b/tools/compute_image_mean.cpp @@ -1,10 +1,11 @@ #include #include +#include #include #include #include -#include "boost/scoped_ptr.hpp" + #include "gflags/gflags.h" #include "glog/logging.h" @@ -16,7 +17,7 @@ using namespace caffe; // NOLINT(build/namespaces) using std::max; using std::pair; -using boost::scoped_ptr; +using std::unique_ptr; DEFINE_string(backend, "lmdb", "The backend {leveldb, lmdb} containing the images"); @@ -40,9 +41,9 @@ int main(int argc, char** argv) { return 1; } - scoped_ptr db(db::GetDB(FLAGS_backend)); + unique_ptr db(db::GetDB(FLAGS_backend)); db->Open(argv[1], db::READ); - scoped_ptr cursor(db->NewCursor()); + unique_ptr cursor(db->NewCursor()); BlobProto sum_blob; int count = 0; diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp index 816a91f971b..3a5e1a723ee 100644 --- a/tools/convert_imageset.cpp +++ b/tools/convert_imageset.cpp @@ -10,11 +10,11 @@ #include #include // NOLINT(readability/streams) +#include #include #include #include -#include "boost/scoped_ptr.hpp" #include "gflags/gflags.h" #include "glog/logging.h" @@ -25,7 +25,7 @@ using namespace caffe; // NOLINT(build/namespaces) using std::pair; -using boost::scoped_ptr; +using std::unique_ptr; DEFINE_bool(gray, false, "When this option is on, treat images as grayscale ones"); @@ -88,9 +88,9 @@ int main(int argc, char** argv) { int resize_width = std::max(0, FLAGS_resize_width); // Create new DB - scoped_ptr db(db::GetDB(FLAGS_backend)); + unique_ptr db(db::GetDB(FLAGS_backend)); db->Open(argv[3], db::NEW); - scoped_ptr txn(db->NewTransaction()); + unique_ptr txn(db->NewTransaction()); // Storing to db std::string root_folder(argv[1]); diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index 365dd495bbf..a573be9f628 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -2,7 +2,6 @@ #include #include -#include "boost/algorithm/string.hpp" #include "google/protobuf/text_format.h" #include "caffe/blob.hpp" @@ -17,7 +16,7 @@ using caffe::Blob; using caffe::Caffe; using caffe::Datum; using caffe::Net; -using boost::shared_ptr; +using std::shared_ptr; using std::string; namespace db = caffe::db; @@ -102,12 +101,11 @@ int feature_extraction_pipeline(int argc, char** argv) { std::string extract_feature_blob_names(argv[++arg_pos]); std::vector blob_names; - boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(",")); + caffe::string_split(&blob_names, extract_feature_blob_names, ","); std::string save_feature_dataset_names(argv[++arg_pos]); std::vector dataset_names; - boost::split(dataset_names, save_feature_dataset_names, - boost::is_any_of(",")); + caffe::string_split(&dataset_names, save_feature_dataset_names, ","); CHECK_EQ(blob_names.size(), dataset_names.size()) << " the number of blob names and dataset names must be equal"; size_t num_features = blob_names.size();