From 8d7cdb73e6d948a6ba0fe93e4499fbc736e17ecc Mon Sep 17 00:00:00 2001 From: Jesse Brizzi Date: Tue, 29 Aug 2017 17:23:28 -0400 Subject: [PATCH] DEV-26376: Recode python layer to C++ in detection net --- Makefile.config | 2 +- Makefile.config.ec2.gpu | 16 +- Makefile.config.example | 2 +- Makefile.config.local.gpu | 2 +- README.md | 1 + include/caffe/layers/frcnn_proposal_layer.hpp | 84 +++++ include/caffe/util/frcnn_gpu_nms.hpp | 10 + include/caffe/util/frcnn_helper.hpp | 28 ++ include/caffe/util/frcnn_param.hpp | 40 +++ include/caffe/util/frcnn_utils.hpp | 338 ++++++++++++++++++ src/caffe/layers/frcnn_proposal_layer.cpp | 196 ++++++++++ src/caffe/layers/frcnn_proposal_layer.cu | 128 +++++++ src/caffe/proto/caffe.proto | 9 +- src/caffe/util/frcnn_bbox.cpp | 57 +++ src/caffe/util/frcnn_bbox_transform.cpp | 67 ++++ src/caffe/util/frcnn_config.cpp | 76 ++++ src/caffe/util/frcnn_file.cpp | 64 ++++ src/caffe/util/frcnn_nms_kernel.cu | 140 ++++++++ src/caffe/util/frcnn_param.cpp | 69 ++++ 19 files changed, 1317 insertions(+), 12 deletions(-) create mode 100644 include/caffe/layers/frcnn_proposal_layer.hpp create mode 100644 include/caffe/util/frcnn_gpu_nms.hpp create mode 100644 include/caffe/util/frcnn_helper.hpp create mode 100644 include/caffe/util/frcnn_param.hpp create mode 100644 include/caffe/util/frcnn_utils.hpp create mode 100644 src/caffe/layers/frcnn_proposal_layer.cpp create mode 100644 src/caffe/layers/frcnn_proposal_layer.cu create mode 100644 src/caffe/util/frcnn_bbox.cpp create mode 100644 src/caffe/util/frcnn_bbox_transform.cpp create mode 100644 src/caffe/util/frcnn_config.cpp create mode 100644 src/caffe/util/frcnn_file.cpp create mode 100644 src/caffe/util/frcnn_nms_kernel.cu create mode 100644 src/caffe/util/frcnn_param.cpp diff --git a/Makefile.config b/Makefile.config index 6999e57e0a2..b3b432c21df 100644 --- a/Makefile.config +++ b/Makefile.config @@ -87,7 +87,7 @@ PYTHON_LIB := $(shell python-config --prefix)/lib WITH_PYTHON_LAYER := 1 # Whatever else you find you need goes here. -INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include +INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include /usr/local/cuda/include/thrust/system/cuda/detail/ LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib # If Homebrew is installed at a non standard location (for example your home directory) and you use it for general dependencies diff --git a/Makefile.config.ec2.gpu b/Makefile.config.ec2.gpu index 6999e57e0a2..72e8ef1b3a1 100644 --- a/Makefile.config.ec2.gpu +++ b/Makefile.config.ec2.gpu @@ -8,7 +8,7 @@ USE_CUDNN := 1 # CPU_ONLY := 1 # uncomment to disable IO dependencies and corresponding data layers -# USE_OPENCV := 0 +USE_OPENCV := 0 # USE_LEVELDB := 0 # USE_LMDB := 0 @@ -32,12 +32,14 @@ CUDA_DIR := /usr/local/cuda # CUDA architecture setting: going with all of them. # For CUDA < 6.0, comment the *_50 lines for compatibility. -CUDA_ARCH := -gencode arch=compute_20,code=sm_20 \ - -gencode arch=compute_20,code=sm_21 \ - -gencode arch=compute_30,code=sm_30 \ +# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility. 
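+# (For reference: sm_30/sm_35 target Kepler, sm_50/sm_52 Maxwell, and
+# sm_60/sm_61 Pascal GPUs; the compute_20/21 entries removed above covered
+# Fermi, which recent CUDA toolkits no longer support.)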
+CUDA_ARCH := -gencode arch=compute_30,code=sm_30 \ -gencode arch=compute_35,code=sm_35 \ -gencode arch=compute_50,code=sm_50 \ - -gencode arch=compute_50,code=compute_50 + -gencode arch=compute_52,code=sm_52 \ + -gencode arch=compute_60,code=sm_60 \ + -gencode arch=compute_61,code=sm_61 \ + -gencode arch=compute_61,code=compute_61 # BLAS choice: # atlas for ATLAS (default) @@ -84,10 +86,10 @@ PYTHON_LIB := $(shell python-config --prefix)/lib # PYTHON_LIB += $(shell brew --prefix numpy)/lib # Uncomment to support layers written in Python (will link against Python libs) -WITH_PYTHON_LAYER := 1 +# WITH_PYTHON_LAYER := 1 # Whatever else you find you need goes here. -INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include +INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include /usr/local/cuda/include/thrust/system/cuda/detail/ LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib # If Homebrew is installed at a non standard location (for example your home directory) and you use it for general dependencies diff --git a/Makefile.config.example b/Makefile.config.example index d552b38a97c..4f14d7a46e4 100644 --- a/Makefile.config.example +++ b/Makefile.config.example @@ -91,7 +91,7 @@ PYTHON_LIB := /usr/lib # WITH_PYTHON_LAYER := 1 # Whatever else you find you need goes here. -INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include +INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include /usr/local/cuda/include/thrust/system/cuda/detail/ LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib # If Homebrew is installed at a non standard location (for example your home directory) and you use it for general dependencies diff --git a/Makefile.config.local.gpu b/Makefile.config.local.gpu index 3527d4804af..9903a947814 100644 --- a/Makefile.config.local.gpu +++ b/Makefile.config.local.gpu @@ -93,7 +93,7 @@ PYTHON_LIB := $(shell python-config --prefix)/lib WITH_PYTHON_LAYER := 1 # Whatever else you find you need goes here. -INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include +INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include /usr/local/cuda/include/thrust/system/cuda/detail/ LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib # If Homebrew is installed at a non standard location (for example your home directory) and you use it for general dependencies diff --git a/README.md b/README.md index 686a0fbeb1e..dddeeb6f9c6 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ This fork of caffe was created off of the [1.0](https://github.com/BVLC/caffe/re - [smooth_L1_loss_ohem_layer](https://github.com/curalate/caffe/pull/2) - [smooth_l1_loss_layer](https://github.com/curalate/caffe/pull/2) - [softmax_loss_ohem_layer](https://github.com/curalate/caffe/pull/2) +- [frcnn_proposal_layer](https://github.com/curalate/caffe/pull/12) [![License](https://img.shields.io/badge/license-BSD-blue.svg)](LICENSE) diff --git a/include/caffe/layers/frcnn_proposal_layer.hpp b/include/caffe/layers/frcnn_proposal_layer.hpp new file mode 100644 index 00000000000..ec04212bcc9 --- /dev/null +++ b/include/caffe/layers/frcnn_proposal_layer.hpp @@ -0,0 +1,84 @@ +// ------------------------------------------------------------------ +// Xuanyi . 
Refer to Dong Jian +// 2016/03/31 +// ------------------------------------------------------------------ +#ifndef CAFFE_FRCNN_PROPOSAL_LAYER_HPP_ +#define CAFFE_FRCNN_PROPOSAL_LAYER_HPP_ + +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +/************************************************* +FrcnnProposalLayer +Outputs object detection proposals by applying estimated bounding-box +transformations to a set of regular boxes (called "anchors"). +bottom: 'rpn_cls_prob_reshape' +bottom: 'rpn_bbox_pred' +bottom: 'im_info' +top: 'rpn_rois' +**************************************************/ +template +class FrcnnProposalLayer : public Layer { + public: + explicit FrcnnProposalLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top){}; + + virtual inline const char* type() const { return "FrcnnProposal"; } + + virtual inline int MinBottomBlobs() const { return 3; } + virtual inline int MaxBottomBlobs() const { return 3; } + virtual inline int MinTopBlobs() const { return 1; } + virtual inline int MaxTopBlobs() const { return 2; } + +#ifndef CPU_ONLY + virtual ~FrcnnProposalLayer() { + if (this->anchors_) { + CUDA_CHECK(cudaFree(this->anchors_)); + } + if (this->transform_bbox_) { + CUDA_CHECK(cudaFree(this->transform_bbox_)); + } + if (this->mask_) { + CUDA_CHECK(cudaFree(this->mask_)); + } + if (this->selected_flags_) { + CUDA_CHECK(cudaFree(this->selected_flags_)); + } + if (this->gpu_keep_indices_) { + CUDA_CHECK(cudaFree(this->gpu_keep_indices_)); + } + } +#endif + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); +#ifndef CPU_ONLY + // CUDA CU + float* anchors_; + float* transform_bbox_; + unsigned long long *mask_; + int *selected_flags_; + int *gpu_keep_indices_; +#endif +}; + +} // namespace caffe + +#endif // CAFFE_FRCNN_PROPOSAL_LAYER_HPP_ diff --git a/include/caffe/util/frcnn_gpu_nms.hpp b/include/caffe/util/frcnn_gpu_nms.hpp new file mode 100644 index 00000000000..6d028da69f1 --- /dev/null +++ b/include/caffe/util/frcnn_gpu_nms.hpp @@ -0,0 +1,10 @@ +#ifndef CAFFE_FRCNN_GPU_NMS_HPP_ +#define CAFFE_FRCNN_GPU_NMS_HPP_ + +namespace caffe { + +void gpu_nms(int* keep_out, int* num_out, const float* boxes_dev, int boxes_num, + int boxes_dim, float nms_overlap_thresh, int device_id=-1); + +} // namespace caffe +#endif // CAFFE_FRCNN_UTILS_HPP_ diff --git a/include/caffe/util/frcnn_helper.hpp b/include/caffe/util/frcnn_helper.hpp new file mode 100644 index 00000000000..2c0c8ee5944 --- /dev/null +++ b/include/caffe/util/frcnn_helper.hpp @@ -0,0 +1,28 @@ +// ------------------------------------------------------------------ +// Xuanyi . 
Refer to Dong Jian +// 2016/04/01 +// ------------------------------------------------------------------ +#ifndef CAFFE_FRCNN_HELPER_HPP_ +#define CAFFE_FRCNN_HELPER_HPP_ + +#include "caffe/util/frcnn_utils.hpp" + +namespace caffe { + +template +Point4f bbox_transform(const Point4f& ex_rois,const Point4f& gt_rois); + +template +std::vector > bbox_transform(const std::vector >& ex_rois, + const std::vector >& gt_rois); + +template +Point4f bbox_transform_inv(const Point4f& box, const Point4f& delta); + +template +std::vector > bbox_transform_inv(const Point4f& box, + const std::vector >& deltas); + +} // namespace caffe + +#endif diff --git a/include/caffe/util/frcnn_param.hpp b/include/caffe/util/frcnn_param.hpp new file mode 100644 index 00000000000..35c6f3be7f7 --- /dev/null +++ b/include/caffe/util/frcnn_param.hpp @@ -0,0 +1,40 @@ +// ------------------------------------------------------------------ +// Xuanyi . Refer to Dong Jian +// 2016/03/31 +// ------------------------------------------------------------------ +#ifndef CAFFE_FRCNN_PRARM_HPP_ +#define CAFFE_FRCNN_PRARM_HPP_ + +#include +#include + +namespace caffe{ + +class FrcnnParam { +public: + + static float rpn_nms_thresh; + static int rpn_pre_nms_top_n; + static int rpn_post_nms_top_n; + // Proposal height and width both need to be greater than RPN_MIN_SIZE (at + // orig image scale) + static float rpn_min_size; + + static float test_rpn_nms_thresh; + static int test_rpn_pre_nms_top_n; + static int test_rpn_post_nms_top_n; + // Proposal height and width both need to be greater than RPN_MIN_SIZE (at + // orig image scale) + static float test_rpn_min_size; + + static int feat_stride; + static std::vector anchors; + static int n_classes; + // ======================================== + static void load_param(const std::string default_config_path); + static void print_param(); +}; + +} + +#endif // CAFFE_FRCNN_PRARM_HPP_ diff --git a/include/caffe/util/frcnn_utils.hpp b/include/caffe/util/frcnn_utils.hpp new file mode 100644 index 00000000000..d72f085c268 --- /dev/null +++ b/include/caffe/util/frcnn_utils.hpp @@ -0,0 +1,338 @@ +// ------------------------------------------------------------------ +// Xuanyi . 
Refer to Dong Jian +// 2016/03/29 +// ------------------------------------------------------------------ +#ifndef CAFFE_FRCNN_UTILS_HPP_ +#define CAFFE_FRCNN_UTILS_HPP_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include "boost/filesystem.hpp" + +#include + +#include "caffe/common.hpp" +#include "caffe/util/rng.hpp" +#include "caffe/util/frcnn_param.hpp" + +namespace caffe { + +class DataPrepare { +public: + DataPrepare() { + rois.clear(); + ok = false; + } + inline string GetImagePath(string root = "") { + CHECK(this->ok) << "illegal status(ok=" << ok << ")"; + return root + image_path; + } + inline int GetImageIndex() { + CHECK(this->ok) << "illegal status(ok=" << ok << ")"; + return image_index; + } + inline vector > GetRois(bool include_diff = false) { + CHECK(this->ok) << "illegal status(ok=" << ok << ")"; + CHECK_EQ(this->rois.size(), this->diff.size()); + vector > _rois; + for (size_t index = 0; index < this->rois.size(); index++) { + if (include_diff == false && this->diff[index] == 1) continue; + _rois.push_back( this->rois[index] ); + } + return _rois; + } + inline bool load_WithDiff(std::ifstream &infile) { + string hashtag; + if(!(infile >> hashtag)) return ok=false; + CHECK_EQ(hashtag, "#"); + CHECK(infile >> this->image_index >> this->image_path); + int num_roi; + CHECK(infile >> num_roi); + rois.clear(); diff.clear(); + for (int index = 0; index < num_roi; index++) { + int label, x1, y1, x2, y2; + int diff_; + CHECK(infile >> label >> x1 >> y1 >> x2 >> y2 >> diff_); + //x1 --; y1 --; x2 --; y2 --; + // CHECK LABEL + CHECK(label>0 && label= 1 and < " << FrcnnParam::n_classes; + CHECK_GE(x2, x1) << "illegal coordinate : " << x1 << ", " << x2 << " : " << this->image_path; + CHECK_GE(y2, y1) << "illegal coordinate : " << y1 << ", " << y2 << " : " << this->image_path; + vector roi(DataPrepare::NUM); + roi[DataPrepare::LABEL] = label; + roi[DataPrepare::X1] = x1; + roi[DataPrepare::Y1] = y1; + roi[DataPrepare::X2] = x2; + roi[DataPrepare::Y2] = y2; + rois.push_back(roi); + diff.push_back(diff_); + } + return ok=true; + } + enum RoiDataField { LABEL, X1, Y1, X2, Y2, NUM }; + +private: + vector > rois; + vector diff; + string image_path; + int image_index; + bool ok; +}; + +// image and box +template +class Point4f { +public: + Dtype Point[4]; // x1 y1 x2 y2 + Point4f(Dtype x1 = 0, Dtype y1 = 0, Dtype x2 = 0, Dtype y2 = 0) { + Point[0] = x1; Point[1] = y1; + Point[2] = x2; Point[3] = y2; + } + Point4f(const float data[4]) { + for (int i=0;i<4;i++) Point[i] = data[i]; + } + Point4f(const double data[4]) { + for (int i=0;i<4;i++) Point[i] = data[i]; + } + Point4f(const Point4f &other) { memcpy(Point, other.Point, sizeof(Point)); } + Dtype& operator[](const unsigned int id) { return Point[id]; } + const Dtype& operator[](const unsigned int id) const { return Point[id]; } + + string to_string() const { + char buff[100]; + snprintf(buff, sizeof(buff), "%.1f %.1f %.1f %.1f", Point[0], Point[1], Point[2], Point[3]); + return string(buff); + } + +}; + +template +class BBox : public Point4f { +public: + Dtype confidence; + int id; + + BBox(Dtype x1 = 0, Dtype y1 = 0, Dtype x2 = 0, Dtype y2 = 0, + Dtype confidence = 0, int id = 0) + : Point4f(x1, y1, x2, y2), confidence(confidence), id(id) {} + BBox(Point4f box, Dtype confidence_ = 0, int id = 0) + : Point4f(box), confidence(confidence_), id(id) {} + + BBox &operator=(const BBox &other) { + 
memcpy(this->Point, other.Point, sizeof(this->Point)); + confidence = other.confidence; + id = other.id; + return *this; + } + + bool operator<(const class BBox &other) const { + if (confidence != other.confidence) + return confidence > other.confidence; + else + return id < other.id; + } + + inline string to_string() const { + char buff[100]; + snprintf(buff, sizeof(buff), "cls:%3d -- (%.3f): %.2f %.2f %.2f %.2f", id, + confidence, this->Point[0], this->Point[1], this->Point[2], this->Point[3]); + return string(buff); + } + + inline string to_short_string() const { + char buff[100]; + snprintf(buff, sizeof(buff), "cls:%1d -- (%.2f)", id, confidence); + return string(buff); + } +}; + +template +class TrackLet : public BBox { +public: + int tracklet; + TrackLet(Dtype x1 = 0, Dtype y1 = 0, Dtype x2 = 0, Dtype y2 = 0, Dtype confidence = 0, int id = 0, int _tracklet = 0): + BBox(x1, y1, x2, y2, confidence, id), tracklet(_tracklet){}; + TrackLet(BBox box, int _tracklet = 0): + BBox(box), tracklet(_tracklet){}; + TrackLet(Point4f box, Dtype confidence = 0, int id = 0, int _tracklet = 0): + BBox(box, confidence, id), tracklet(_tracklet){}; + inline string to_string() const { + char buff[100]; + snprintf(buff, sizeof(buff), "cls:%3d,let:%3d -- (%.3f): %.2f %.2f %.2f %.2f", this->id, this->tracklet, + this->confidence, this->Point[0], this->Point[1], this->Point[2], this->Point[3]); + return string(buff); + } +}; + +template +class VidPrepare { +public: + VidPrepare() { + ok = false; + prefetch_rng_.reset(); + this->current_index = -1; + } + inline void init(const unsigned int seed = 0) { + _image_dataset.clear(); + _objects.clear(); + prefetch_rng_.reset(new Caffe::RNG(seed)); + } + inline bool load_data(std::ifstream &infile) { + if(!(infile >> HASH)) return ok=false; + CHECK_EQ(HASH, "#"); + CHECK(infile >> this->folder); + CHECK(infile >> this->num_image >> this->height >> this->width); + int x1, y1, x2, y2; + int track_let, label; + for (int index = 0; index < this->num_image; index++ ) { + string image; int num_rois; + CHECK(infile >> image >> num_rois); + _image_dataset.push_back(image); + vector > objects; + + for (int roi_ = 0; roi_ < num_rois; roi_++ ) { + CHECK(infile >> track_let >> label >> x1 >> y1 >> x2 >> y2); + TrackLet cobject(x1, y1, x2, y2, 1, label, track_let); + CHECK(label>0 && label= 1 and < " << FrcnnParam::n_classes; + CHECK_GE(x1, 0) << cobject.to_string(); + CHECK_GE(y1, 0) << cobject.to_string(); + CHECK_LT(x1, this->width) << "Width : " << this->width << cobject.to_string(); + CHECK_LT(y1, this->height) << "Height : " << this->height << cobject.to_string(); + objects.push_back(cobject); + } + + _objects.push_back(objects); + } + CHECK_EQ(_image_dataset.size(), _objects.size()); + return ok = true; + } + + inline pair >, string> Next() { + CHECK(ok) << "Status is false"; + this->current_index = PrefetchRand() % _image_dataset.size(); + string image = folder + "/" + _image_dataset[current_index]; + const vector > &objects = _objects[current_index]; + vector > rois; + for (size_t ii = 0; ii < objects.size(); ii++ ) { + vector roi(NUM); + roi[LABEL] = objects[ii].id; + roi[X1] = objects[ii][0]; + roi[Y1] = objects[ii][1]; + roi[X2] = objects[ii][2]; + roi[Y2] = objects[ii][3]; + rois.push_back(roi); + } + CHECK_EQ(rois.size(), objects.size()); + return make_pair(rois, image); + } + + inline string message() { + CHECK(ok) << "Status is false"; + CHECK_GE(this->current_index, 0); + CHECK_LT(this->current_index, int(_image_dataset.size())); + char buff[100]; + snprintf(buff, 
sizeof(buff), "height : %d, width : %d " , this->height, this->width); + return string(buff); + } + + inline map count_label() { + CHECK(ok) << "Status is false"; + map label_hist; + for (size_t index = 0; index < _objects.size(); index++ ) { + for (size_t oid = 0; oid < _objects[index].size(); oid++ ) { + int label = _objects[index][oid].id; + label_hist.insert(std::make_pair(label, 0)); + label_hist[label]++; + } + } + return label_hist; + } + + inline int H() { + CHECK(ok) << "Status is false"; + return this->height; + } + + inline int W() { + CHECK(ok) << "Status is false"; + return this->width; + } + + enum RoiDataField { LABEL, X1, Y1, X2, Y2, NUM }; +private: + string HASH; + string folder; + int num_image; + int height; + int width; + int current_index; + vector _image_dataset; + vector > > _objects; + bool ok; + + // Random Seed + shared_ptr prefetch_rng_; + inline unsigned int PrefetchRand() { + CHECK(prefetch_rng_); + caffe::rng_t *prefetch_rng = + static_cast(prefetch_rng_->generator()); + return (*prefetch_rng)(); + } +}; + +template +Dtype get_iou(const Point4f &A, const Point4f &B); + +template +vector > get_ious(const vector > &A, const vector > &B); + +template +vector get_ious(const Point4f &A, const vector > &B); + +float get_scale_factor(int width, int height, int short_size, int max_long_size); + +// config +typedef std::map str_map; + +str_map parse_json_config(const string file_path); + +string extract_string(string target_key, str_map& default_map); + +float extract_float(string target_key, str_map& default_map); + +int extract_int(string target_key, str_map& default_map); + +vector extract_vector(string target_key, str_map& default_map); + +// file +vector get_file_list (const string& path, const string& ext); + +template +void print_vector(vector data); + +string anchor_to_string(vector data); + +string float_to_string(const vector data); + +string float_to_string(const float *data); + +} // namespace caffe + +#endif // CAFFE_FRCNN_UTILS_HPP_ diff --git a/src/caffe/layers/frcnn_proposal_layer.cpp b/src/caffe/layers/frcnn_proposal_layer.cpp new file mode 100644 index 00000000000..8d8a6c8b755 --- /dev/null +++ b/src/caffe/layers/frcnn_proposal_layer.cpp @@ -0,0 +1,196 @@ +// ------------------------------------------------------------------ +// Fast R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see fast-rcnn/LICENSE for details] +// Written by Ross Girshick +// ------------------------------------------------------------------ + +#include "caffe/layers/frcnn_proposal_layer.hpp" +#include "caffe/util/frcnn_utils.hpp" +#include "caffe/util/frcnn_helper.hpp" +#include "caffe/util/frcnn_param.hpp" + +namespace caffe { + +using std::vector; + +template +void FrcnnProposalLayer::LayerSetUp( + const vector *> &bottom, + const vector *> &top +) { + FrcnnParam::load_param(this->layer_param_.config_param().full_config_path()); + FrcnnParam::print_param(); + +#ifndef CPU_ONLY + CUDA_CHECK(cudaMalloc(&anchors_, sizeof(float) * FrcnnParam::anchors.size())); + CUDA_CHECK(cudaMemcpy(anchors_, &(FrcnnParam::anchors[0]), + sizeof(float) * FrcnnParam::anchors.size(), cudaMemcpyHostToDevice)); + + const int rpn_pre_nms_top_n = + this->phase_ == TRAIN ? FrcnnParam::rpn_pre_nms_top_n : FrcnnParam::test_rpn_pre_nms_top_n; + CUDA_CHECK(cudaMalloc(&transform_bbox_, sizeof(float) * rpn_pre_nms_top_n * 4)); + CUDA_CHECK(cudaMalloc(&selected_flags_, sizeof(int) * rpn_pre_nms_top_n)); + + const int rpn_post_nms_top_n = + this->phase_ == TRAIN ? 
FrcnnParam::rpn_post_nms_top_n : FrcnnParam::test_rpn_post_nms_top_n; + CUDA_CHECK(cudaMalloc(&gpu_keep_indices_, sizeof(int) * rpn_post_nms_top_n)); + +#endif + top[0]->Reshape(1, 5, 1, 1); + if (top.size() > 1) { + top[1]->Reshape(1, 1, 1, 1); + } +} + +template +void FrcnnProposalLayer::Forward_cpu(const vector *> &bottom, + const vector *> &top) { + DLOG(ERROR) << "========== enter proposal layer"; + const Dtype *bottom_rpn_score = bottom[0]->cpu_data(); // rpn_cls_prob_reshape + const Dtype *bottom_rpn_bbox = bottom[1]->cpu_data(); // rpn_bbox_pred + const Dtype *bottom_im_info = bottom[2]->cpu_data(); // im_info + + const int num = bottom[1]->num(); + const int channes = bottom[1]->channels(); + const int height = bottom[1]->height(); + const int width = bottom[1]->width(); + CHECK(num == 1) << "only single item batches are supported"; + CHECK(channes % 4 == 0) << "rpn bbox pred channels should be divided by 4"; + + const float im_height = bottom_im_info[0]; + const float im_width = bottom_im_info[1]; + + int rpn_pre_nms_top_n; + int rpn_post_nms_top_n; + float rpn_nms_thresh; + int rpn_min_size; + if (this->phase_ == TRAIN) { + rpn_pre_nms_top_n = FrcnnParam::rpn_pre_nms_top_n; + rpn_post_nms_top_n = FrcnnParam::rpn_post_nms_top_n; + rpn_nms_thresh = FrcnnParam::rpn_nms_thresh; + rpn_min_size = FrcnnParam::rpn_min_size; + } else { + rpn_pre_nms_top_n = FrcnnParam::test_rpn_pre_nms_top_n; + rpn_post_nms_top_n = FrcnnParam::test_rpn_post_nms_top_n; + rpn_nms_thresh = FrcnnParam::test_rpn_nms_thresh; + rpn_min_size = FrcnnParam::test_rpn_min_size; + } + const int config_n_anchors = FrcnnParam::anchors.size() / 4; + LOG_IF(ERROR, rpn_pre_nms_top_n <= 0 ) << "rpn_pre_nms_top_n : " << rpn_pre_nms_top_n; + LOG_IF(ERROR, rpn_post_nms_top_n <= 0 ) << "rpn_post_nms_top_n : " << rpn_post_nms_top_n; + if (rpn_pre_nms_top_n <= 0 || rpn_post_nms_top_n <= 0 ) return; + + std::vector > anchors; + typedef pair sort_pair; + std::vector sort_vector; + + const Dtype bounds[4] = { im_width - 1, im_height - 1, im_width - 1, im_height -1 }; + const Dtype min_size = bottom_im_info[2] * rpn_min_size; + + DLOG(ERROR) << "========== generate anchors"; + + for (int j = 0; j < height; j++) { + for (int i = 0; i < width; i++) { + for (int k = 0; k < config_n_anchors; k++) { + Dtype score = bottom_rpn_score[config_n_anchors * height * width + + k * height * width + j * width + i]; + //const int index = i * height * config_n_anchors + j * config_n_anchors + k; + + Point4f anchor( + FrcnnParam::anchors[k * 4 + 0] + i * FrcnnParam::feat_stride, // shift_x[i][j]; + FrcnnParam::anchors[k * 4 + 1] + j * FrcnnParam::feat_stride, // shift_y[i][j]; + FrcnnParam::anchors[k * 4 + 2] + i * FrcnnParam::feat_stride, // shift_x[i][j]; + FrcnnParam::anchors[k * 4 + 3] + j * FrcnnParam::feat_stride); // shift_y[i][j]; + + Point4f box_delta( + bottom_rpn_bbox[(k * 4 + 0) * height * width + j * width + i], + bottom_rpn_bbox[(k * 4 + 1) * height * width + j * width + i], + bottom_rpn_bbox[(k * 4 + 2) * height * width + j * width + i], + bottom_rpn_bbox[(k * 4 + 3) * height * width + j * width + i]); + + Point4f cbox = bbox_transform_inv(anchor, box_delta); + + // 2. clip predicted boxes to image + for (int q = 0; q < 4; q++) { + cbox.Point[q] = std::max(Dtype(0), std::min(cbox[q], bounds[q])); + } + // 3. 
remove predicted boxes with either height or width < threshold + if((cbox[2] - cbox[0] + 1) >= min_size && (cbox[3] - cbox[1] + 1) >= min_size) { + const int now_index = sort_vector.size(); + sort_vector.push_back(sort_pair(score, now_index)); + anchors.push_back(cbox); + } + } + } + } + + DLOG(ERROR) << "========== after clip and remove size < threshold box " << (int)sort_vector.size(); + + std::sort(sort_vector.begin(), sort_vector.end(), std::greater()); + const int n_anchors = std::min((int)sort_vector.size(), rpn_pre_nms_top_n); + sort_vector.erase(sort_vector.begin() + n_anchors, sort_vector.end()); + //anchors.erase(anchors.begin() + n_anchors, anchors.end()); + std::vector select(n_anchors, true); + + // apply nms + DLOG(ERROR) << "========== apply nms, pre nms number is : " << n_anchors; + std::vector > box_final; + std::vector scores_; + for (int i = 0; i < n_anchors && box_final.size() < rpn_post_nms_top_n; i++) { + if (select[i]) { + const int cur_i = sort_vector[i].second; + for (int j = i + 1; j < n_anchors; j++) + if (select[j]) { + const int cur_j = sort_vector[j].second; + if (get_iou(anchors[cur_i], anchors[cur_j]) > rpn_nms_thresh) { + select[j] = false; + } + } + box_final.push_back(anchors[cur_i]); + scores_.push_back(sort_vector[i].first); + } + } + + DLOG(ERROR) << "rpn number after nms: " << box_final.size(); + + DLOG(ERROR) << "========== copy to top"; + top[0]->Reshape(box_final.size(), 5, 1, 1); + Dtype *top_data = top[0]->mutable_cpu_data(); + CHECK_EQ(box_final.size(), scores_.size()); + for (size_t i = 0; i < box_final.size(); i++) { + Point4f &box = box_final[i]; + top_data[i * 5] = 0; + for (int j = 1; j < 5; j++) { + top_data[i * 5 + j] = box[j - 1]; + } + } + + if (top.size() > 1) { + top[1]->Reshape(box_final.size(), 1, 1, 1); + for (size_t i = 0; i < box_final.size(); i++) { + top[1]->mutable_cpu_data()[i] = scores_[i]; + } + } + + DLOG(ERROR) << "========== exit proposal layer"; +} + +template +void FrcnnProposalLayer::Backward_cpu(const vector *> &top, + const vector &propagate_down, const vector *> &bottom) { + for (int i = 0; i < propagate_down.size(); ++i) { + if (propagate_down[i]) { + NOT_IMPLEMENTED; + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(FrcnnProposalLayer); +#endif + +INSTANTIATE_CLASS(FrcnnProposalLayer); +REGISTER_LAYER_CLASS(FrcnnProposal); + +} // namespace caffe diff --git a/src/caffe/layers/frcnn_proposal_layer.cu b/src/caffe/layers/frcnn_proposal_layer.cu new file mode 100644 index 00000000000..a573fe1fd61 --- /dev/null +++ b/src/caffe/layers/frcnn_proposal_layer.cu @@ -0,0 +1,128 @@ +// ------------------------------------------------------------------ +// Fast R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see fast-rcnn/LICENSE for details] +// Written by Ross Girshick +// ------------------------------------------------------------------ +#include +#include + +#include "caffe/layers/frcnn_proposal_layer.hpp" +#include "caffe/util/frcnn_utils.hpp" +#include "caffe/util/frcnn_helper.hpp" +#include "caffe/util/frcnn_param.hpp" +#include "caffe/util/frcnn_gpu_nms.hpp" + +namespace caffe { + +using std::vector; + +__global__ void GetIndex(const int n,int *indices){ + CUDA_KERNEL_LOOP(index , n){ + indices[index] = index; + } +} + +template +__global__ void BBoxTransformInv(const int nthreads, const Dtype* const bottom_rpn_bbox, + const int height, const int width, const int feat_stride, + const int im_height, const int im_width, + const int* sorted_indices, const float* anchors, + float* const transform_bbox) { 
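+  // Each thread decodes one of the pre-NMS top-scoring positions:
+  // sorted_indices[index] is a flat offset into the (anchor, height, width)
+  // score map, so i/j/k below recover the spatial location and anchor channel.
+  // The base anchor is shifted by feat_stride, the (dx, dy, dw, dh) deltas from
+  // rpn_bbox_pred are applied around its center, and the result is clipped to
+  // the image bounds, matching the CPU path in Forward_cpu.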
+ CUDA_KERNEL_LOOP(index , nthreads) { + const int score_idx = sorted_indices[index]; + const int i = score_idx % width; // width + const int j = (score_idx % (width * height)) / width; // height + const int k = score_idx / (width * height); // channel + float *box = transform_bbox + index * 4; + box[0] = anchors[k * 4 + 0] + i * feat_stride; + box[1] = anchors[k * 4 + 1] + j * feat_stride; + box[2] = anchors[k * 4 + 2] + i * feat_stride; + box[3] = anchors[k * 4 + 3] + j * feat_stride; + const Dtype det[4] = { bottom_rpn_bbox[(k * 4 + 0) * height * width + j * width + i], + bottom_rpn_bbox[(k * 4 + 1) * height * width + j * width + i], + bottom_rpn_bbox[(k * 4 + 2) * height * width + j * width + i], + bottom_rpn_bbox[(k * 4 + 3) * height * width + j * width + i] }; + float src_w = box[2] - box[0] + 1; + float src_h = box[3] - box[1] + 1; + float src_ctr_x = box[0] + 0.5 * src_w; + float src_ctr_y = box[1] + 0.5 * src_h; + float pred_ctr_x = det[0] * src_w + src_ctr_x; + float pred_ctr_y = det[1] * src_h + src_ctr_y; + float pred_w = exp(det[2]) * src_w; + float pred_h = exp(det[3]) * src_h; + box[0] = pred_ctr_x - 0.5 * pred_w; + box[1] = pred_ctr_y - 0.5 * pred_h; + box[2] = pred_ctr_x + 0.5 * pred_w; + box[3] = pred_ctr_y + 0.5 * pred_h; + box[0] = max(0.0f, min(box[0], im_width - 1.0)); + box[1] = max(0.0f, min(box[1], im_height - 1.0)); + box[2] = max(0.0f, min(box[2], im_width - 1.0)); + box[3] = max(0.0f, min(box[3], im_height - 1.0)); + } +} + +__global__ void SelectBox(const int nthreads, const float *box, float min_size, + int *flags) { + CUDA_KERNEL_LOOP(index , nthreads) { + if ((box[index * 4 + 2] - box[index * 4 + 0] < min_size) || + (box[index * 4 + 3] - box[index * 4 + 1] < min_size)) { + flags[index] = 0; + } else { + flags[index] = 1; + } + } +} + +template +__global__ void SelectBoxByIndices(const int nthreads, const float *in_box, int *selected_indices, + float *out_box, const Dtype *in_score, Dtype *out_score) { + CUDA_KERNEL_LOOP(index , nthreads) { + if ((index == 0 && selected_indices[index] == 1) || + (index > 0 && selected_indices[index] == selected_indices[index - 1] + 1)) { + out_box[(selected_indices[index] - 1) * 4 + 0] = in_box[index * 4 + 0]; + out_box[(selected_indices[index] - 1) * 4 + 1] = in_box[index * 4 + 1]; + out_box[(selected_indices[index] - 1) * 4 + 2] = in_box[index * 4 + 2]; + out_box[(selected_indices[index] - 1) * 4 + 3] = in_box[index * 4 + 3]; + if (in_score!=NULL && out_score!=NULL) { + out_score[selected_indices[index] - 1] = in_score[index]; + } + } + } +} + +template +__global__ void SelectBoxAftNMS(const int nthreads, const float *in_box, int *keep_indices, + Dtype *top_data, const Dtype *in_score, Dtype* top_score) { + CUDA_KERNEL_LOOP(index , nthreads) { + top_data[index * 5] = 0; + int keep_idx = keep_indices[index]; + for (int j = 1; j < 5; ++j) { + top_data[index * 5 + j] = in_box[keep_idx * 4 + j - 1]; + } + if (top_score != NULL && in_score != NULL) { + top_score[index] = in_score[keep_idx]; + } + } +} + +template +void FrcnnProposalLayer::Forward_gpu(const vector *> &bottom, + const vector *> &top) { + Forward_cpu(bottom, top); + return ; +} + +template +void FrcnnProposalLayer::Backward_gpu(const vector *> &top, + const vector &propagate_down, const vector *> &bottom) { + for (int i = 0; i < propagate_down.size(); ++i) { + if (propagate_down[i]) { + NOT_IMPLEMENTED; + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(FrcnnProposalLayer); + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 
25a91323870..4dd6347ac14 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -308,7 +308,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. // -// LayerParameter next available layer-specific ID: 153 (last added: lifted_struct_sim) +// LayerParameter next available layer-specific ID: 154 (last added: config_param) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -366,6 +366,7 @@ message LayerParameter { optional BoxAnnotatorOHEMParameter box_annotator_ohem_param = 150; optional BiasParameter bias_param = 141; optional ConcatParameter concat_param = 104; + optional ConfigParameter config_param = 153; optional ContrastiveLossParameter contrastive_loss_param = 105; optional ConvolutionParameter convolution_param = 106; optional CropParameter crop_param = 144; @@ -862,6 +863,10 @@ message InputParameter { repeated BlobShape shape = 1; } +message ConfigParameter { + required string full_config_path = 1; +} + // Message that stores parameters used by LogLayer message LogParameter { // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. @@ -952,7 +957,7 @@ message PowerParameter { } message PSROIPoolingParameter { - required float spatial_scale = 1; + required float spatial_scale = 1; required int32 output_dim = 2; // output channel number required int32 group_size = 3; // number of groups to encode position-sensitive score maps } diff --git a/src/caffe/util/frcnn_bbox.cpp b/src/caffe/util/frcnn_bbox.cpp new file mode 100644 index 00000000000..ffa6b10792c --- /dev/null +++ b/src/caffe/util/frcnn_bbox.cpp @@ -0,0 +1,57 @@ +#include "caffe/util/frcnn_utils.hpp" + +namespace caffe { + +INSTANTIATE_CLASS(Point4f); +INSTANTIATE_CLASS(BBox); + +template +Dtype get_iou(const Point4f &A, const Point4f &B) { + const Dtype xx1 = std::max(A[0], B[0]); + const Dtype yy1 = std::max(A[1], B[1]); + const Dtype xx2 = std::min(A[2], B[2]); + const Dtype yy2 = std::min(A[3], B[3]); + Dtype inter = std::max(Dtype(0), xx2 - xx1 + 1) * std::max(Dtype(0), yy2 - yy1 + 1); + Dtype areaA = (A[2] - A[0] + 1) * (A[3] - A[1] + 1); + Dtype areaB = (B[2] - B[0] + 1) * (B[3] - B[1] + 1); + return inter / (areaA + areaB - inter); +} +template float get_iou(const Point4f &A, const Point4f &B); +template double get_iou(const Point4f &A, const Point4f &B); + +template +vector > get_ious(const vector > &A, const vector > &B) { + vector >ious; + for (size_t i = 0; i < A.size(); i++) { + ious.push_back(get_ious(A[i], B)); + } + return ious; +} +template vector > get_ious(const vector > &A, const vector > &B); +template vector > get_ious(const vector > &A, const vector > &B); + +template +vector get_ious(const Point4f &A, const vector > &B) { + vector ious; + for (size_t i = 0; i < B.size(); i++) { + ious.push_back(get_iou(A, B[i])); + } + return ious; +} + +template vector get_ious(const Point4f &A, const vector > &B); +template vector get_ious(const Point4f &A, const vector > &B); + +float get_scale_factor(int width, int height, int short_size, int max_long_size) { + float im_size_min = std::min(width, height); + float im_size_max = std::max(width, height); + + float scale_factor = static_cast(short_size) / im_size_min; + // Prevent the biggest axis from being more than max_size + if (scale_factor * im_size_max > max_long_size) { + scale_factor = static_cast(max_long_size) / im_size_max; + } + return scale_factor; +} + +} // namespace caffe diff --git 
a/src/caffe/util/frcnn_bbox_transform.cpp b/src/caffe/util/frcnn_bbox_transform.cpp new file mode 100644 index 00000000000..01986a2303b --- /dev/null +++ b/src/caffe/util/frcnn_bbox_transform.cpp @@ -0,0 +1,67 @@ +#include "caffe/util/frcnn_utils.hpp" + +namespace caffe { + +using std::vector; + +template +Point4f bbox_transform_inv(const Point4f& box, const Point4f& delta) { + Dtype src_w = box[2] - box[0] + 1; + Dtype src_h = box[3] - box[1] + 1; + Dtype src_ctr_x = box[0] + 0.5 * src_w; // box[0] + 0.5*src_w; + Dtype src_ctr_y = box[1] + 0.5 * src_h; // box[1] + 0.5*src_h; + Dtype pred_ctr_x = delta[0] * src_w + src_ctr_x; + Dtype pred_ctr_y = delta[1] * src_h + src_ctr_y; + Dtype pred_w = exp(delta[2]) * src_w; + Dtype pred_h = exp(delta[3]) * src_h; + return Point4f(pred_ctr_x - 0.5 * pred_w, pred_ctr_y - 0.5 * pred_h, + pred_ctr_x + 0.5 * pred_w, pred_ctr_y + 0.5 * pred_h); + // return Point4f(pred_ctr_x - 0.5*(pred_w-1) , pred_ctr_y - 0.5*(pred_h-1) , + // pred_ctr_x + 0.5*(pred_w-1) , pred_ctr_y + 0.5*(pred_h-1)); +} +template Point4f bbox_transform_inv(const Point4f& box, const Point4f& delta); +template Point4f bbox_transform_inv(const Point4f& box, const Point4f& delta); + +template +vector > bbox_transform_inv(const Point4f& box, const vector >& deltas) { + vector > ans; + for (size_t index = 0; index < deltas.size(); index++) { + ans.push_back(bbox_transform_inv(box, deltas[index])); + } + return ans; +} +template vector > bbox_transform_inv(const Point4f& box, const vector >& deltas); +template vector > bbox_transform_inv(const Point4f& box, const vector >& deltas); + +template +Point4f bbox_transform(const Point4f& ex_roi, const Point4f& gt_roi) { + Dtype ex_width = ex_roi[2] - ex_roi[0] + 1; + Dtype ex_height = ex_roi[3] - ex_roi[1] + 1; + Dtype ex_ctr_x = ex_roi[0] + 0.5 * ex_width; + Dtype ex_ctr_y = ex_roi[1] + 0.5 * ex_height; + Dtype gt_widths = gt_roi[2] - gt_roi[0] + 1; + Dtype gt_heights = gt_roi[3] - gt_roi[1] + 1; + Dtype gt_ctr_x = gt_roi[0] + 0.5 * gt_widths; + Dtype gt_ctr_y = gt_roi[1] + 0.5 * gt_heights; + Dtype targets_dx = (gt_ctr_x - ex_ctr_x) / ex_width; + Dtype targets_dy = (gt_ctr_y - ex_ctr_y) / ex_height; + Dtype targets_dw = log(gt_widths / ex_width); + Dtype targets_dh = log(gt_heights / ex_height); + return Point4f(targets_dx, targets_dy, targets_dw, targets_dh); +} +template Point4f bbox_transform(const Point4f& ex_roi, const Point4f& gt_roi); +template Point4f bbox_transform(const Point4f& ex_roi, const Point4f& gt_roi); + +template +vector > bbox_transform(const vector >& ex_rois, const vector >& gt_rois) { + CHECK_EQ(ex_rois.size(), gt_rois.size()); + vector > transformed_bbox; + for (size_t i = 0; i < gt_rois.size(); i++) { + transformed_bbox.push_back(bbox_transform(ex_rois[i], gt_rois[i])); + } + return transformed_bbox; +} +template vector > bbox_transform(const vector >& ex_rois, const vector >& gt_rois); +template vector > bbox_transform(const vector >& ex_rois, const vector >& gt_rois); + +} // namespace caffe diff --git a/src/caffe/util/frcnn_config.cpp b/src/caffe/util/frcnn_config.cpp new file mode 100644 index 00000000000..7438bf42789 --- /dev/null +++ b/src/caffe/util/frcnn_config.cpp @@ -0,0 +1,76 @@ +#include +#include +#include + +#include "caffe/util/frcnn_utils.hpp" + +namespace caffe { + +std::vector &split(const std::string &s, char delim, + std::vector &elems) { + std::stringstream ss(s); + std::string item; + while (std::getline(ss, item, delim)) { + elems.push_back(item); + } + return elems; +} + +std::vector split(const 
std::string &s, char delim) { + std::vector elems; + split(s, delim, elems); + return elems; +} + +str_map parse_json_config(const std::string file_path) { + std::ifstream ifs(file_path.c_str()); + std::map json_map; + boost::property_tree::ptree pt; + boost::property_tree::read_json(ifs, pt); + + boost::property_tree::basic_ptree::const_iterator + iter = pt.begin(); + + for (; iter != pt.end(); ++iter) { + json_map[iter->first.data()] = iter->second.data(); + } + return json_map; +} + +std::string extract_string(std::string target_key, + str_map& default_map) { + std::string target_str; + if (default_map.count(target_key) > 0) { + target_str = default_map[target_key]; + } else { + LOG(FATAL) << "Can not find target_key : " << target_key; + } + return target_str; +} + +float extract_float(std::string target_key, + str_map& default_map) { + std::string target_str = extract_string(target_key, default_map); + return atof(target_str.c_str()); +} + +int extract_int(std::string target_key, + str_map& default_map) { + std::string target_str = extract_string(target_key, default_map); + return atoi(target_str.c_str()); +} + +std::vector extract_vector(std::string target_key, + str_map& default_map) { + std::string target_str = extract_string(target_key, default_map); + std::vector results; + std::vector elems = split(target_str, ','); + + for (std::vector::const_iterator it = elems.begin(); + it != elems.end(); ++it) { + results.push_back(atof((*it).c_str())); + } + return results; +} + +} // namespace caffe diff --git a/src/caffe/util/frcnn_file.cpp b/src/caffe/util/frcnn_file.cpp new file mode 100644 index 00000000000..486ffaef2f2 --- /dev/null +++ b/src/caffe/util/frcnn_file.cpp @@ -0,0 +1,64 @@ +#include "caffe/util/frcnn_utils.hpp" + +namespace caffe { + +// ==================== file system +// return the filenames of all files that have the specified extension +// in the specified directory and all subdirectories +namespace fs = ::boost::filesystem; +std::vector get_file_list(const std::string& path, const string& ext) { + fs::path fs_path(path); + vector file_list; + + if(!fs::exists(fs_path) || !fs::is_directory(fs_path)) + return file_list; + + fs::recursive_directory_iterator it(fs_path); + fs::recursive_directory_iterator endit; + + while (it != endit) { + if (fs::is_regular_file(*it) && it->path().extension() == ext) + file_list.push_back(it->path().filename().string()); + ++it; + } + + return file_list; +} + +template +void print_vector(std::vector data) { + for (int i = 0; i < data.size(); i++) { + LOG(ERROR) << data[i]; + } +} +template void print_vector(std::vector data); +template void print_vector(std::vector data); + +std::string anchor_to_string(std::vector data) { + CHECK_EQ( data.size() % 4 , 0 ) << "Anchors Size is wrong : " << data.size(); + char buff[200]; + std::string ans; + for (size_t index = 0; index * 4 < data.size(); index++) { + snprintf(buff, sizeof(buff), "%.2f %.2f %.2f %.2f", data[index*4+0], data[index*4+1], data[index*4+2], data[index*4+3]); + ans += std::string(buff) + "\n"; + } + return ans; +} + +std::string float_to_string(const std::vector data) { + char buff[200]; + std::string ans; + for (size_t index = 0; index < data.size(); index++) { + snprintf(buff, sizeof(buff), "%.2f", data[index]); + if( index == 0 ) ans = std::string(buff); + else ans += ", " + std::string(buff); + } + return ans; +} + +std::string float_to_string(const float *data) { + const int n = sizeof(data) / sizeof(data[0]); + return float_to_string( std::vector(data, data+n) ); +} + +} 
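+// NOTE: in float_to_string(const float *data) above, sizeof(data) is the size
+// of the pointer rather than of the array it points to, so n is always
+// sizeof(float*) / sizeof(float); callers that need the whole array should use
+// the std::vector<float> overload instead.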
// namespace caffe diff --git a/src/caffe/util/frcnn_nms_kernel.cu b/src/caffe/util/frcnn_nms_kernel.cu new file mode 100644 index 00000000000..ea036eeb782 --- /dev/null +++ b/src/caffe/util/frcnn_nms_kernel.cu @@ -0,0 +1,140 @@ +// ------------------------------------------------------------------ +// Faster R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see fast-rcnn/LICENSE for details] +// Written by Shaoqing Ren +// ------------------------------------------------------------------ + +#include "caffe/util/frcnn_gpu_nms.hpp" +#include "caffe/common.hpp" +#include +#include + +namespace caffe { + +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) +int const threadsPerBlock = sizeof(unsigned long long) * 8; + +__device__ inline float devIoU(float const * const a, float const * const b) { + float left = max(a[0], b[0]), right = min(a[2], b[2]); + float top = max(a[1], b[1]), bottom = min(a[3], b[3]); + float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); + float interS = width * height; + float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); + float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); + return interS / (Sa + Sb - interS); +} + +__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, + const float *dev_boxes, unsigned long long *dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ float block_boxes[threadsPerBlock * 4]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 4 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0]; + block_boxes[threadIdx.x * 4 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1]; + block_boxes[threadIdx.x * 4 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2]; + block_boxes[threadIdx.x * 4 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3]; + //block_boxes[threadIdx.x * 5 + 4] = + // dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const float *cur_box = dev_boxes + cur_box_idx * 4; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU(cur_box, block_boxes + i * 4) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +void _set_device(int device_id) { + if (device_id<=0) return; + int current_device; + CUDA_CHECK(cudaGetDevice(¤t_device)); + if (current_device == device_id) { + return; + } + // The call to cudaSetDevice must come before any calls to Get, which + // may perform initialization using the GPU. 
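+  // (Note that the device_id <= 0 guard above also covers the default of -1,
+  // so passing 0 never calls cudaSetDevice; the kernels simply run on whatever
+  // device is already current.)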
+ CUDA_CHECK(cudaSetDevice(device_id)); +} + +void gpu_nms(int* keep_out, int* num_out, const float* boxes_dev, int boxes_num, + int boxes_dim, float nms_overlap_thresh, int device_id) { + _set_device(device_id); + + // float* boxes_dev = NULL; + unsigned long long* mask_dev = NULL; + + const int col_blocks = DIVUP(boxes_num, threadsPerBlock); + + // CUDA_CHECK(cudaMalloc(&boxes_dev, + // boxes_num * boxes_dim * sizeof(float))); + // CUDA_CHECK(cudaMemcpy(boxes_dev, + // boxes_host, + // boxes_num * boxes_dim * sizeof(float), + // cudaMemcpyHostToDevice)); + + CUDA_CHECK(cudaMalloc(&mask_dev, + boxes_num * col_blocks * sizeof(unsigned long long))); + + dim3 blocks(DIVUP(boxes_num, threadsPerBlock), + DIVUP(boxes_num, threadsPerBlock)); + dim3 threads(threadsPerBlock); + nms_kernel<<>>(boxes_num, + nms_overlap_thresh, + boxes_dev, + mask_dev); + + std::vector mask_host(boxes_num * col_blocks); + CUDA_CHECK(cudaMemcpy(&mask_host[0], + mask_dev, + sizeof(unsigned long long) * boxes_num * col_blocks, + cudaMemcpyDeviceToHost)); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + int num_to_keep = 0; + for (int i = 0; i < boxes_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long *p = &mask_host[0] + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + *num_out = num_to_keep; + + // CUDA_CHECK(cudaFree(boxes_dev)); + CUDA_CHECK(cudaFree(mask_dev)); +} + +} // namespace caffe diff --git a/src/caffe/util/frcnn_param.cpp b/src/caffe/util/frcnn_param.cpp new file mode 100644 index 00000000000..1479c25bc2d --- /dev/null +++ b/src/caffe/util/frcnn_param.cpp @@ -0,0 +1,69 @@ +#include "caffe/util/frcnn_utils.hpp" +#include "caffe/util/frcnn_param.hpp" +#include "caffe/common.hpp" + +namespace caffe { + +using namespace caffe; + +float FrcnnParam::rpn_nms_thresh; +int FrcnnParam::rpn_pre_nms_top_n; +int FrcnnParam::rpn_post_nms_top_n; +// Proposal height and width both need to be greater than RPN_MIN_SIZE (at +// orig image scale) +float FrcnnParam::rpn_min_size; + +float FrcnnParam::test_rpn_nms_thresh; +int FrcnnParam::test_rpn_pre_nms_top_n; +int FrcnnParam::test_rpn_post_nms_top_n; +// Proposal height and width both need to be greater than RPN_MIN_SIZE (at +// orig image scale) +float FrcnnParam::test_rpn_min_size; + +int FrcnnParam::feat_stride; +std::vector FrcnnParam::anchors; +int FrcnnParam::n_classes; + +void FrcnnParam::load_param(const std::string default_config_path) { + + str_map default_map = parse_json_config(default_config_path); + + FrcnnParam::rpn_nms_thresh = extract_float("rpn_nms_thresh", default_map); + FrcnnParam::rpn_pre_nms_top_n = extract_int("rpn_pre_nms_top_n", default_map); + FrcnnParam::rpn_post_nms_top_n = extract_int("rpn_post_nms_top_n", default_map); + FrcnnParam::rpn_min_size = extract_float("rpn_min_size", default_map); + + // ======================================== Test + + FrcnnParam::test_rpn_nms_thresh = extract_float("test_rpn_nms_thresh", default_map); + FrcnnParam::test_rpn_pre_nms_top_n = extract_int("test_rpn_pre_nms_top_n", default_map); + FrcnnParam::test_rpn_post_nms_top_n = extract_int("test_rpn_post_nms_top_n", default_map); + FrcnnParam::test_rpn_min_size = extract_float("test_rpn_min_size", default_map); + + // ======================================== + FrcnnParam::feat_stride = extract_int("feat_stride", default_map); + 
FrcnnParam::anchors = extract_vector("anchors", default_map);
+  FrcnnParam::n_classes = extract_int("n_classes", default_map);
+}
+
+void FrcnnParam::print_param(){
+
+  LOG(INFO) << "== Frcnn Parameters ==";
+
+  LOG(INFO) << "rpn_nms_thresh : " << FrcnnParam::rpn_nms_thresh;
+  LOG(INFO) << "rpn_pre_nms_top_n : " << FrcnnParam::rpn_pre_nms_top_n;
+  LOG(INFO) << "rpn_post_nms_top_n: " << FrcnnParam::rpn_post_nms_top_n;
+  LOG(INFO) << "rpn_min_size : " << FrcnnParam::rpn_min_size;
+
+  LOG(INFO) << "test_rpn_nms_thresh : " << FrcnnParam::test_rpn_nms_thresh;
+  LOG(INFO) << "test_rpn_pre_nms_top_n : " << FrcnnParam::test_rpn_pre_nms_top_n;
+  LOG(INFO) << "test_rpn_post_nms_top_n: " << FrcnnParam::test_rpn_post_nms_top_n;
+  LOG(INFO) << "test_rpn_min_size : " << FrcnnParam::test_rpn_min_size;
+
+  LOG(INFO) << "== Global Parameters ==";
+  LOG(INFO) << "feat_stride : " << FrcnnParam::feat_stride;
+  LOG(INFO) << "anchors_size : " << FrcnnParam::anchors.size();
+  LOG(INFO) << "n_classes : " << FrcnnParam::n_classes;
+}
+
+} // namespace caffe
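
The snippets below sketch how the new layer is meant to be wired up, based only
on the code in this patch. The layer type ("FrcnnProposal"), the three bottoms,
the top blob, the config_param field, and the JSON keys all come from the
sources above; the layer name, the config path, and every numeric value are
illustrative assumptions, not files or defaults added by this change.

A minimal prototxt hook-up replacing the old Python proposal layer might look like:

  layer {
    name: "proposal"
    type: "FrcnnProposal"
    bottom: "rpn_cls_prob_reshape"
    bottom: "rpn_bbox_pred"
    bottom: "im_info"
    top: "rpn_rois"
    config_param { full_config_path: "/path/to/frcnn_config.json" }
  }

FrcnnParam::load_param reads the file named there through parse_json_config,
which only consumes flat top-level key/value pairs, and extract_vector splits
the anchor list on commas, so a matching config could be a single JSON object
such as the following (values shown are the usual py-faster-rcnn defaults,
including the nine standard anchors for a feat_stride of 16):

  {
    "rpn_nms_thresh": "0.7",
    "rpn_pre_nms_top_n": "12000",
    "rpn_post_nms_top_n": "2000",
    "rpn_min_size": "16",
    "test_rpn_nms_thresh": "0.7",
    "test_rpn_pre_nms_top_n": "6000",
    "test_rpn_post_nms_top_n": "300",
    "test_rpn_min_size": "16",
    "feat_stride": "16",
    "anchors": "-84,-40,99,55,-176,-88,191,103,-360,-184,375,199,-56,-56,71,71,-120,-120,135,135,-248,-248,263,263,-36,-80,51,95,-80,-168,95,183,-168,-344,183,359",
    "n_classes": "21"
  }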