diff --git a/include/caffe/util/im2col.hpp b/include/caffe/util/im2col.hpp index 0051e2fa067..a4fef3d9e49 100644 --- a/include/caffe/util/im2col.hpp +++ b/include/caffe/util/im2col.hpp @@ -4,28 +4,28 @@ namespace caffe { template -void im2col_cpu(const Dtype* data_im, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, Dtype* data_col); +void im2col_cpu(const Dtype* data_im, const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_col); template -void col2im_cpu(const Dtype* data_col, const int channels, - const int height, const int width, const int patch_h, const int patch_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, Dtype* data_im); +void col2im_cpu(const Dtype* data_col, const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_im); template -void im2col_gpu(const Dtype* data_im, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, Dtype* data_col); +void im2col_gpu(const Dtype* data_im, const int num_spatial_axes, + const int col_size, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_col); template -void col2im_gpu(const Dtype* data_col, const int channels, - const int height, const int width, const int patch_h, const int patch_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, Dtype* data_im); +void col2im_gpu(const Dtype* data_col, const int num_spatial_axes, + const int im_size, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_im); } 
// namespace caffe diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 6cb507a5780..e0302296e66 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -64,44 +64,66 @@ class BaseConvolutionLayer : public Layer { // Compute height_out_ and width_out_ from other parameters. virtual void compute_output_shape() = 0; - int kernel_h_, kernel_w_; - int stride_h_, stride_w_; + /// @brief The spatial dimensions of a filter kernel. + Blob kernel_shape_; + /// @brief The spatial dimensions of the stride. + Blob stride_; + /// @brief The spatial dimensions of the padding. + Blob pad_; + /// @brief The spatial dimensions of the convolution input. + Blob conv_input_shape_; + /// @brief The spatial dimensions of the input. + Blob input_shape_; + /// @brief The spatial dimensions of the col_buffer. + vector col_buffer_shape_; + /// @brief The spatial dimensions of the output. + vector output_shape_; + + int num_spatial_axes_; + int bottom_dim_; + int top_dim_; + + int channel_axis_; int num_; int channels_; - int pad_h_, pad_w_; - int height_, width_; int group_; int num_output_; - int height_out_, width_out_; bool bias_term_; bool is_1x1_; private: // wrap im2col/col2im so we don't have to remember the (long) argument lists inline void conv_im2col_cpu(const Dtype* data, Dtype* col_buff) { - im2col_cpu(data, conv_in_channels_, conv_in_height_, conv_in_width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff); + im2col_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(), + col_buffer_shape_.data(), kernel_shape_.cpu_data(), + pad_.cpu_data(), stride_.cpu_data(), col_buff); } inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) { - col2im_cpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data); + col2im_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(), + col_buffer_shape_.data(), 
kernel_shape_.cpu_data(), + pad_.cpu_data(), stride_.cpu_data(), data); } #ifndef CPU_ONLY inline void conv_im2col_gpu(const Dtype* data, Dtype* col_buff) { - im2col_gpu(data, conv_in_channels_, conv_in_height_, conv_in_width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff); + im2col_gpu(data, num_spatial_axes_, num_kernels_im2col_, + conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(), + kernel_shape_.gpu_data(), pad_.gpu_data(), + stride_.gpu_data(), col_buff); } inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) { - col2im_gpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data); + col2im_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_, + conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(), + kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), + data); } #endif + int num_kernels_im2col_; + int num_kernels_col2im_; int conv_out_channels_; int conv_in_channels_; int conv_out_spatial_dim_; - int conv_in_height_; - int conv_in_width_; + int out_spatial_dim_; int kernel_dim_; int weight_offset_; int col_offset_; @@ -285,11 +307,26 @@ class Im2colLayer : public Layer { virtual void Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom); - int kernel_h_, kernel_w_; - int stride_h_, stride_w_; + /// @brief The spatial dimensions of a filter kernel. + Blob kernel_shape_; + /// @brief The spatial dimensions of the stride. + Blob stride_; + /// @brief The spatial dimensions of the padding. + Blob pad_; + /// @brief The (full) shape of the input. + Blob* input_shape_; + /// @brief The (full) shape of the conv input. + Blob* conv_input_shape_; + /// @brief The spatial dimensions of the output col. 
+ vector col_shape_; + + int num_spatial_axes_; + int bottom_dim_; + int top_dim_; + + int channel_axis_; + int num_; int channels_; - int height_, width_; - int pad_h_, pad_w_; }; // Forward declare PoolingLayer and SplitLayer for use in LRNLayer. diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp index ccb3adc7e89..d6b5875e825 100644 --- a/src/caffe/layers/base_conv_layer.cpp +++ b/src/caffe/layers/base_conv_layer.cpp @@ -11,50 +11,102 @@ namespace caffe { template void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, const vector*>& top) { - CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, " - << "corresponding to (num, channels, height, width)"; // Configure the kernel size, padding, stride, and inputs. ConvolutionParameter conv_param = this->layer_param_.convolution_param(); - CHECK(!conv_param.has_kernel_size() != - !(conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "Filter size is kernel_size OR kernel_h and kernel_w; not both"; - CHECK(conv_param.has_kernel_size() || - (conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "For non-square filters both kernel_h and kernel_w are required."; - CHECK((!conv_param.has_pad() && conv_param.has_pad_h() - && conv_param.has_pad_w()) - || (!conv_param.has_pad_h() && !conv_param.has_pad_w())) - << "pad is pad OR pad_h and pad_w are required."; - CHECK((!conv_param.has_stride() && conv_param.has_stride_h() - && conv_param.has_stride_w()) - || (!conv_param.has_stride_h() && !conv_param.has_stride_w())) - << "Stride is stride OR stride_h and stride_w are required."; - if (conv_param.has_kernel_size()) { - kernel_h_ = kernel_w_ = conv_param.kernel_size(); + channel_axis_ = bottom[0]->CanonicalAxisIndex(conv_param.axis()); + const int first_spatial_axis = channel_axis_ + 1; + const int num_axes = bottom[0]->num_axes(); + num_spatial_axes_ = num_axes - first_spatial_axis; + CHECK_GE(num_spatial_axes_, 1); + // Setup input dimensions 
(input_shape_). + vector bottom_dim_blob_shape(1, num_spatial_axes_ + 1); + input_shape_.Reshape(bottom_dim_blob_shape); + int* input_shape_data = input_shape_.mutable_cpu_data(); + for (int i = 0; i < num_spatial_axes_ + 1; ++i) { + input_shape_data[i] = bottom[0]->shape(channel_axis_ + i); + } + vector spatial_dim_blob_shape(1, num_spatial_axes_); + // Setup filter kernel dimensions (kernel_shape_). + kernel_shape_.Reshape(spatial_dim_blob_shape); + int* kernel_shape_data = kernel_shape_.mutable_cpu_data(); + if (conv_param.has_kernel_h() || conv_param.has_kernel_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "kernel_h & kernel_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.kernel_size_size()) + << "Either kernel_size or kernel_h/w should be specified; not both."; + kernel_shape_data[0] = conv_param.kernel_h(); + kernel_shape_data[1] = conv_param.kernel_w(); } else { - kernel_h_ = conv_param.kernel_h(); - kernel_w_ = conv_param.kernel_w(); + const int num_kernel_dims = conv_param.kernel_size_size(); + CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_) + << "kernel_size must be specified once, or once per spatial dimension " + << "(kernel_size specified " << num_kernel_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + for (int i = 0; i < num_spatial_axes_; ++i) { + kernel_shape_data[i] = + conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i); + } + } + for (int i = 0; i < num_spatial_axes_; ++i) { + CHECK_GT(kernel_shape_data[i], 0) << "Filter dimensions must be nonzero."; } - CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; - CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; - if (!conv_param.has_pad_h()) { - pad_h_ = pad_w_ = conv_param.pad(); + // Setup stride dimensions (stride_). 
+ stride_.Reshape(spatial_dim_blob_shape); + int* stride_data = stride_.mutable_cpu_data(); + if (conv_param.has_stride_h() || conv_param.has_stride_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "stride_h & stride_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.stride_size()) + << "Either stride or stride_h/w should be specified; not both."; + stride_data[0] = conv_param.stride_h(); + stride_data[1] = conv_param.stride_w(); } else { - pad_h_ = conv_param.pad_h(); - pad_w_ = conv_param.pad_w(); + const int num_stride_dims = conv_param.stride_size(); + CHECK(num_stride_dims == 0 || num_stride_dims == 1 || + num_stride_dims == num_spatial_axes_) + << "stride must be specified once, or once per spatial dimension " + << "(stride specified " << num_stride_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultStride = 1; + for (int i = 0; i < num_spatial_axes_; ++i) { + stride_data[i] = (num_stride_dims == 0) ? kDefaultStride : + conv_param.stride((num_stride_dims == 1) ? 0 : i); + CHECK_GT(stride_data[i], 0) << "Stride dimensions must be nonzero."; + } } - if (!conv_param.has_stride_h()) { - stride_h_ = stride_w_ = conv_param.stride(); + // Setup pad dimensions (pad_). 
+ pad_.Reshape(spatial_dim_blob_shape); + int* pad_data = pad_.mutable_cpu_data(); + if (conv_param.has_pad_h() || conv_param.has_pad_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "pad_h & pad_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.pad_size()) + << "Either pad or pad_h/w should be specified; not both."; + pad_data[0] = conv_param.pad_h(); + pad_data[1] = conv_param.pad_w(); } else { - stride_h_ = conv_param.stride_h(); - stride_w_ = conv_param.stride_w(); + const int num_pad_dims = conv_param.pad_size(); + CHECK(num_pad_dims == 0 || num_pad_dims == 1 || + num_pad_dims == num_spatial_axes_) + << "pad must be specified once, or once per spatial dimension " + << "(pad specified " << num_pad_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultPad = 0; + for (int i = 0; i < num_spatial_axes_; ++i) { + pad_data[i] = (num_pad_dims == 0) ? kDefaultPad : + conv_param.pad((num_pad_dims == 1) ? 0 : i); + } } // Special case: im2col is the identity for 1x1 convolution with stride 1 // and no padding, so flag for skipping the buffer and transformation. - is_1x1_ = kernel_w_ == 1 && kernel_h_ == 1 - && stride_h_ == 1 && stride_w_ == 1 && pad_h_ == 0 && pad_w_ == 0; + is_1x1_ = true; + for (int i = 0; i < num_spatial_axes_; ++i) { + is_1x1_ &= + kernel_shape_data[i] == 1 && stride_data[i] == 1 && pad_data[i] == 0; + if (!is_1x1_) { break; } + } // Configure output channels and groups. 
- channels_ = bottom[0]->channels(); + channels_ = bottom[0]->shape(channel_axis_); num_output_ = this->layer_param_.convolution_param().num_output(); CHECK_GT(num_output_, 0); group_ = this->layer_param_.convolution_param().group(); @@ -82,8 +134,13 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, } // Initialize and fill the weights: // output channels x input channels per-group x kernel height x kernel width - this->blobs_[0].reset(new Blob( - conv_out_channels_, conv_in_channels_ / group_, kernel_h_, kernel_w_)); + vector weight_shape(2); + weight_shape[0] = conv_out_channels_; + weight_shape[1] = conv_in_channels_ / group_; + for (int i = 0; i < num_spatial_axes_; ++i) { + weight_shape.push_back(kernel_shape_data[i]); + } + this->blobs_[0].reset(new Blob(weight_shape)); shared_ptr > weight_filler(GetFiller( this->layer_param_.convolution_param().weight_filler())); weight_filler->Fill(this->blobs_[0].get()); @@ -103,52 +160,77 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, template void BaseConvolutionLayer::Reshape(const vector*>& bottom, const vector*>& top) { - CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, " - << "corresponding to (num, channels, height, width)"; - num_ = bottom[0]->num(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with" - " convolution kernel."; + ConvolutionParameter conv_param = this->layer_param_.convolution_param(); + channel_axis_ = bottom[0]->CanonicalAxisIndex(conv_param.axis()); + const int first_spatial_axis = channel_axis_ + 1; + const int num_axes = bottom[0]->num_axes(); + num_spatial_axes_ = num_axes - first_spatial_axis; + CHECK_GE(num_spatial_axes_, 1); + num_ = bottom[0]->count(0, channel_axis_); + CHECK_EQ(bottom[0]->shape(channel_axis_), channels_) + << "Input size incompatible with convolution kernel."; // TODO: generalize to handle inputs of different shapes. 
for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) { - CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num."; - CHECK_EQ(channels_, bottom[bottom_id]->channels()) - << "Inputs must have same channels."; - CHECK_EQ(height_, bottom[bottom_id]->height()) - << "Inputs must have same height."; - CHECK_EQ(width_, bottom[bottom_id]->width()) - << "Inputs must have same width."; + CHECK(bottom[0]->shape() == bottom[bottom_id]->shape()) + << "All inputs must have the same shape."; } // Shape the tops. compute_output_shape(); + vector top_shape = bottom[0]->shape(); + top_shape[channel_axis_] = num_output_; + top_shape.resize(first_spatial_axis); // Discard input spatial axes. + for (int i = 0; i < num_spatial_axes_; ++i) { + top_shape.push_back(output_shape_[i]); + } for (int top_id = 0; top_id < top.size(); ++top_id) { - top[top_id]->Reshape(num_, num_output_, height_out_, width_out_); + top[top_id]->Reshape(top_shape); } if (reverse_dimensions()) { - conv_in_height_ = height_out_; - conv_in_width_ = width_out_; - conv_out_spatial_dim_ = height_ * width_; + conv_out_spatial_dim_ = bottom[0]->count(first_spatial_axis); } else { - conv_in_height_ = height_; - conv_in_width_ = width_; - conv_out_spatial_dim_ = height_out_ * width_out_; + conv_out_spatial_dim_ = top[0]->count(first_spatial_axis); + } + const int* kernel_shape_data = kernel_shape_.cpu_data(); + kernel_dim_ = conv_in_channels_; + for (int i = 0; i < num_spatial_axes_; ++i) { + kernel_dim_ *= kernel_shape_data[i]; } - kernel_dim_ = conv_in_channels_ * kernel_h_ * kernel_w_; weight_offset_ = conv_out_channels_ * kernel_dim_ / group_ / group_; col_offset_ = kernel_dim_ * conv_out_spatial_dim_ / group_; output_offset_ = conv_out_channels_ * conv_out_spatial_dim_ / group_; + // Setup input dimensions (conv_input_shape_). 
+ vector bottom_dim_blob_shape(1, num_spatial_axes_ + 1); + conv_input_shape_.Reshape(bottom_dim_blob_shape); + int* conv_input_shape_data = conv_input_shape_.mutable_cpu_data(); + for (int i = 0; i < num_spatial_axes_ + 1; ++i) { + if (reverse_dimensions()) { + conv_input_shape_data[i] = top[0]->shape(channel_axis_ + i); + } else { + conv_input_shape_data[i] = bottom[0]->shape(channel_axis_ + i); + } + } // The im2col result buffer will only hold one image at a time to avoid // overly large memory usage. In the special case of 1x1 convolution // it goes lazily unused to save memory. - if (reverse_dimensions()) { - col_buffer_.Reshape(1, kernel_dim_, height_, width_); - } else { - col_buffer_.Reshape(1, kernel_dim_, height_out_, width_out_); + col_buffer_shape_.clear(); + col_buffer_shape_.push_back(kernel_dim_); + const int* input_shape_data = input_shape_.cpu_data() + 1; + for (int i = 0; i < num_spatial_axes_; ++i) { + if (reverse_dimensions()) { + col_buffer_shape_.push_back(input_shape_data[i]); + } else { + col_buffer_shape_.push_back(output_shape_[i]); + } } + col_buffer_.Reshape(col_buffer_shape_); + bottom_dim_ = bottom[0]->count(channel_axis_); + top_dim_ = top[0]->count(channel_axis_); + num_kernels_im2col_ = conv_in_channels_ * conv_out_spatial_dim_; + num_kernels_col2im_ = reverse_dimensions() ? 
top_dim_ : bottom_dim_; // Set up the all ones "bias multiplier" for adding biases by BLAS + out_spatial_dim_ = top[0]->count(first_spatial_axis); if (bias_term_) { - vector bias_multiplier_shape(1, height_out_ * width_out_); + vector bias_multiplier_shape(1, out_spatial_dim_); bias_multiplier_.Reshape(bias_multiplier_shape); caffe_set(bias_multiplier_.count(), Dtype(1), bias_multiplier_.mutable_cpu_data()); @@ -177,7 +259,7 @@ template void BaseConvolutionLayer::forward_cpu_bias(Dtype* output, const Dtype* bias) { caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, - height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(), + out_spatial_dim_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(), (Dtype)1., output); } @@ -218,7 +300,7 @@ void BaseConvolutionLayer::weight_cpu_gemm(const Dtype* input, template void BaseConvolutionLayer::backward_cpu_bias(Dtype* bias, const Dtype* input) { - caffe_cpu_gemv(CblasNoTrans, num_output_, height_out_ * width_out_, 1., + caffe_cpu_gemv(CblasNoTrans, num_output_, out_spatial_dim_, 1., input, bias_multiplier_.cpu_data(), 1., bias); } @@ -246,7 +328,7 @@ template void BaseConvolutionLayer::forward_gpu_bias(Dtype* output, const Dtype* bias) { caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, - height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.gpu_data(), + out_spatial_dim_, 1, (Dtype)1., bias, bias_multiplier_.gpu_data(), (Dtype)1., output); } @@ -287,7 +369,7 @@ void BaseConvolutionLayer::weight_gpu_gemm(const Dtype* input, template void BaseConvolutionLayer::backward_gpu_bias(Dtype* bias, const Dtype* input) { - caffe_gpu_gemv(CblasNoTrans, num_output_, height_out_ * width_out_, 1., + caffe_gpu_gemv(CblasNoTrans, num_output_, out_spatial_dim_, 1., input, bias_multiplier_.gpu_data(), 1., bias); } diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp index c0c9f6f3371..3e03d495d80 100644 --- a/src/caffe/layers/conv_layer.cpp +++ b/src/caffe/layers/conv_layer.cpp @@ 
-10,10 +10,18 @@ namespace caffe { template void ConvolutionLayer::compute_output_shape() { - this->height_out_ = (this->height_ + 2 * this->pad_h_ - this->kernel_h_) - / this->stride_h_ + 1; - this->width_out_ = (this->width_ + 2 * this->pad_w_ - this->kernel_w_) - / this->stride_w_ + 1; + // input_shape_ + 1 to skip channel axis + const int* input_shape_data = this->input_shape_.cpu_data() + 1; + const int* kernel_shape_data = this->kernel_shape_.cpu_data(); + const int* stride_data = this->stride_.cpu_data(); + const int* pad_data = this->pad_.cpu_data(); + this->output_shape_.clear(); + for (int i = 0; i < this->num_spatial_axes_; ++i) { + const int input_dim = input_shape_data[i]; + const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i]) + / stride_data[i] + 1; + this->output_shape_.push_back(output_dim); + } } template @@ -24,11 +32,11 @@ void ConvolutionLayer::Forward_cpu(const vector*>& bottom, const Dtype* bottom_data = bottom[i]->cpu_data(); Dtype* top_data = top[i]->mutable_cpu_data(); for (int n = 0; n < this->num_; ++n) { - this->forward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight, - top_data + top[i]->offset(n)); + this->forward_cpu_gemm(bottom_data + n * this->bottom_dim_, weight, + top_data + n * this->top_dim_); if (this->bias_term_) { const Dtype* bias = this->blobs_[1]->cpu_data(); - this->forward_cpu_bias(top_data + top[i]->offset(n), bias); + this->forward_cpu_bias(top_data + n * this->top_dim_, bias); } } } @@ -54,20 +62,20 @@ void ConvolutionLayer::Backward_cpu(const vector*>& top, if (this->bias_term_ && this->param_propagate_down_[1]) { Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff(); for (int n = 0; n < this->num_; ++n) { - this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n)); + this->backward_cpu_bias(bias_diff, top_diff + n * this->top_dim_); } } if (this->param_propagate_down_[0] || propagate_down[i]) { for (int n = 0; n < this->num_; ++n) { // gradient w.r.t. weight. 
Note that we will accumulate diffs. if (this->param_propagate_down_[0]) { - this->weight_cpu_gemm(bottom_data + bottom[i]->offset(n), - top_diff + top[i]->offset(n), weight_diff); + this->weight_cpu_gemm(bottom_data + n * this->bottom_dim_, + top_diff + n * this->top_dim_, weight_diff); } // gradient w.r.t. bottom data, if necessary. if (propagate_down[i]) { - this->backward_cpu_gemm(top_diff + top[i]->offset(n), weight, - bottom_diff + bottom[i]->offset(n)); + this->backward_cpu_gemm(top_diff + n * this->top_dim_, weight, + bottom_diff + n * this->bottom_dim_); } } } diff --git a/src/caffe/layers/conv_layer.cu b/src/caffe/layers/conv_layer.cu index 3902fdf3930..7474ae04b95 100644 --- a/src/caffe/layers/conv_layer.cu +++ b/src/caffe/layers/conv_layer.cu @@ -16,11 +16,11 @@ void ConvolutionLayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[i]->gpu_data(); Dtype* top_data = top[i]->mutable_gpu_data(); for (int n = 0; n < this->num_; ++n) { - this->forward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight, - top_data + top[i]->offset(n)); + this->forward_gpu_gemm(bottom_data + n * this->bottom_dim_, weight, + top_data + n * this->top_dim_); if (this->bias_term_) { const Dtype* bias = this->blobs_[1]->gpu_data(); - this->forward_gpu_bias(top_data + top[i]->offset(n), bias); + this->forward_gpu_bias(top_data + n * this->top_dim_, bias); } } } @@ -44,7 +44,7 @@ void ConvolutionLayer::Backward_gpu(const vector*>& top, if (this->bias_term_ && this->param_propagate_down_[1]) { Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); for (int n = 0; n < this->num_; ++n) { - this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n)); + this->backward_gpu_bias(bias_diff, top_diff + n * this->top_dim_); } } if (this->param_propagate_down_[0] || propagate_down[i]) { @@ -53,13 +53,13 @@ void ConvolutionLayer::Backward_gpu(const vector*>& top, for (int n = 0; n < this->num_; ++n) { // gradient w.r.t. weight. Note that we will accumulate diffs. 
if (this->param_propagate_down_[0]) { - this->weight_gpu_gemm(bottom_data + bottom[i]->offset(n), - top_diff + top[i]->offset(n), weight_diff); + this->weight_gpu_gemm(bottom_data + n * this->bottom_dim_, + top_diff + n * this->top_dim_, weight_diff); } // gradient w.r.t. bottom data, if necessary. if (propagate_down[i]) { - this->backward_gpu_gemm(top_diff + top[i]->offset(n), weight, - bottom_diff + bottom[i]->offset(n)); + this->backward_gpu_gemm(top_diff + n * this->top_dim_, weight, + bottom_diff + n * this->bottom_dim_); } } } diff --git a/src/caffe/layers/deconv_layer.cpp b/src/caffe/layers/deconv_layer.cpp index e6d65ab526b..c0fd3b79dc1 100644 --- a/src/caffe/layers/deconv_layer.cpp +++ b/src/caffe/layers/deconv_layer.cpp @@ -10,10 +10,18 @@ namespace caffe { template void DeconvolutionLayer::compute_output_shape() { - this->height_out_ = this->stride_h_ * (this->height_ - 1) + this->kernel_h_ - - 2 * this->pad_h_; - this->width_out_ = this->stride_w_ * (this->width_ - 1) + this->kernel_w_ - - 2 * this->pad_w_; + // input_shape_ + 1 to skip channel axis + const int* input_shape_data = this->input_shape_.cpu_data() + 1; + const int* kernel_shape_data = this->kernel_shape_.cpu_data(); + const int* stride_data = this->stride_.cpu_data(); + const int* pad_data = this->pad_.cpu_data(); + this->output_shape_.clear(); + for (int i = 0; i < this->num_spatial_axes_; ++i) { + const int input_dim = input_shape_data[i]; + const int output_dim = stride_data[i] * (input_dim - 1) + + kernel_shape_data[i] - 2 * pad_data[i]; + this->output_shape_.push_back(output_dim); + } } template @@ -24,11 +32,11 @@ void DeconvolutionLayer::Forward_cpu(const vector*>& bottom, const Dtype* bottom_data = bottom[i]->cpu_data(); Dtype* top_data = top[i]->mutable_cpu_data(); for (int n = 0; n < this->num_; ++n) { - this->backward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight, - top_data + top[i]->offset(n)); + this->backward_cpu_gemm(bottom_data + n * this->bottom_dim_, weight, + 
top_data + n * this->top_dim_); if (this->bias_term_) { const Dtype* bias = this->blobs_[1]->cpu_data(); - this->forward_cpu_bias(top_data + top[i]->offset(n), bias); + this->forward_cpu_bias(top_data + n * this->top_dim_, bias); } } } @@ -54,21 +62,21 @@ void DeconvolutionLayer::Backward_cpu(const vector*>& top, if (this->bias_term_ && this->param_propagate_down_[1]) { Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff(); for (int n = 0; n < this->num_; ++n) { - this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n)); + this->backward_cpu_bias(bias_diff, top_diff + n * this->top_dim_); } } if (this->param_propagate_down_[0] || propagate_down[i]) { for (int n = 0; n < this->num_; ++n) { // Gradient w.r.t. weight. Note that we will accumulate diffs. if (this->param_propagate_down_[0]) { - this->weight_cpu_gemm(top_diff + top[i]->offset(n), - bottom_data + bottom[i]->offset(n), weight_diff); + this->weight_cpu_gemm(top_diff + n * this->top_dim_, + bottom_data + n * this->bottom_dim_, weight_diff); } // Gradient w.r.t. bottom data, if necessary, reusing the column buffer // we might have just computed above. 
if (propagate_down[i]) { - this->forward_cpu_gemm(top_diff + top[i]->offset(n), weight, - bottom_diff + bottom[i]->offset(n), + this->forward_cpu_gemm(top_diff + n * this->top_dim_, weight, + bottom_diff + n * this->bottom_dim_, this->param_propagate_down_[0]); } } diff --git a/src/caffe/layers/deconv_layer.cu b/src/caffe/layers/deconv_layer.cu index 9198dd64c72..06e65660131 100644 --- a/src/caffe/layers/deconv_layer.cu +++ b/src/caffe/layers/deconv_layer.cu @@ -16,11 +16,11 @@ void DeconvolutionLayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[i]->gpu_data(); Dtype* top_data = top[i]->mutable_gpu_data(); for (int n = 0; n < this->num_; ++n) { - this->backward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight, - top_data + top[i]->offset(n)); + this->backward_gpu_gemm(bottom_data + n * this->bottom_dim_, weight, + top_data + n * this->top_dim_); if (this->bias_term_) { const Dtype* bias = this->blobs_[1]->gpu_data(); - this->forward_gpu_bias(top_data + top[i]->offset(n), bias); + this->forward_gpu_bias(top_data + n * this->top_dim_, bias); } } } @@ -46,20 +46,20 @@ void DeconvolutionLayer::Backward_gpu(const vector*>& top, if (this->bias_term_ && this->param_propagate_down_[1]) { Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); for (int n = 0; n < this->num_; ++n) { - this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n)); + this->backward_gpu_bias(bias_diff, top_diff + n * this->top_dim_); } } if (this->param_propagate_down_[0] || propagate_down[i]) { for (int n = 0; n < this->num_; ++n) { // gradient w.r.t. weight. Note that we will accumulate diffs. if (this->param_propagate_down_[0]) { - this->weight_gpu_gemm(top_diff + top[i]->offset(n), - bottom_data + bottom[i]->offset(n), weight_diff); + this->weight_gpu_gemm(top_diff + n * this->top_dim_, + bottom_data + n * this->bottom_dim_, weight_diff); } // gradient w.r.t. bottom data, if necessary. 
if (propagate_down[i]) { - this->forward_gpu_gemm(top_diff + top[i]->offset(n), weight, - bottom_diff + bottom[i]->offset(n)); + this->forward_gpu_gemm(top_diff + n * this->top_dim_, weight, + bottom_diff + n * this->bottom_dim_); } } } diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp index 1c802714e33..8f5dc2d0250 100644 --- a/src/caffe/layers/im2col_layer.cpp +++ b/src/caffe/layers/im2col_layer.cpp @@ -11,54 +11,105 @@ template void Im2colLayer::LayerSetUp(const vector*>& bottom, const vector*>& top) { ConvolutionParameter conv_param = this->layer_param_.convolution_param(); - CHECK(!conv_param.has_kernel_size() != - !(conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "Filter size is kernel_size OR kernel_h and kernel_w; not both"; - CHECK(conv_param.has_kernel_size() || - (conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "For non-square filters both kernel_h and kernel_w are required."; - CHECK((!conv_param.has_pad() && conv_param.has_pad_h() - && conv_param.has_pad_w()) - || (!conv_param.has_pad_h() && !conv_param.has_pad_w())) - << "pad is pad OR pad_h and pad_w are required."; - CHECK((!conv_param.has_stride() && conv_param.has_stride_h() - && conv_param.has_stride_w()) - || (!conv_param.has_stride_h() && !conv_param.has_stride_w())) - << "Stride is stride OR stride_h and stride_w are required."; - if (conv_param.has_kernel_size()) { - kernel_h_ = kernel_w_ = conv_param.kernel_size(); + const int input_num_dims = bottom[0]->shape().size(); + channel_axis_ = bottom[0]->CanonicalAxisIndex(conv_param.axis()); + const int first_spatial_dim = channel_axis_ + 1; + num_spatial_axes_ = input_num_dims - first_spatial_dim; + CHECK_GE(num_spatial_axes_, 1); + vector dim_blob_shape(1, num_spatial_axes_); + // Setup filter kernel dimensions (kernel_shape_). 
+ kernel_shape_.Reshape(dim_blob_shape); + int* kernel_shape_data = kernel_shape_.mutable_cpu_data(); + if (conv_param.has_kernel_h() || conv_param.has_kernel_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "kernel_h & kernel_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.kernel_size_size()) + << "Either kernel_size or kernel_h/w should be specified; not both."; + kernel_shape_data[0] = conv_param.kernel_h(); + kernel_shape_data[1] = conv_param.kernel_w(); } else { - kernel_h_ = conv_param.kernel_h(); - kernel_w_ = conv_param.kernel_w(); + const int num_kernel_dims = conv_param.kernel_size_size(); + CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_) + << "kernel_size must be specified once, or once per spatial dimension " + << "(kernel_size specified " << num_kernel_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + for (int i = 0; i < num_spatial_axes_; ++i) { + kernel_shape_data[i] = + conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i); + } } - CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; - CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; - if (!conv_param.has_pad_h()) { - pad_h_ = pad_w_ = conv_param.pad(); + for (int i = 0; i < num_spatial_axes_; ++i) { + CHECK_GT(kernel_shape_data[i], 0) << "Filter dimensions must be nonzero."; + } + // Setup stride dimensions (stride_). 
+ stride_.Reshape(dim_blob_shape); + int* stride_data = stride_.mutable_cpu_data(); + if (conv_param.has_stride_h() || conv_param.has_stride_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "stride_h & stride_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.stride_size()) + << "Either stride or stride_h/w should be specified; not both."; + stride_data[0] = conv_param.stride_h(); + stride_data[1] = conv_param.stride_w(); } else { - pad_h_ = conv_param.pad_h(); - pad_w_ = conv_param.pad_w(); + const int num_stride_dims = conv_param.stride_size(); + CHECK(num_stride_dims == 0 || num_stride_dims == 1 || + num_stride_dims == num_spatial_axes_) + << "stride must be specified once, or once per spatial dimension " + << "(stride specified " << num_stride_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultStride = 1; + for (int i = 0; i < num_spatial_axes_; ++i) { + stride_data[i] = (num_stride_dims == 0) ? kDefaultStride : + conv_param.stride((num_stride_dims == 1) ? 0 : i); + CHECK_GT(stride_data[i], 0) << "Stride dimensions must be nonzero."; + } } - if (!conv_param.has_stride_h()) { - stride_h_ = stride_w_ = conv_param.stride(); + // Setup pad dimensions (pad_). 
+ pad_.Reshape(dim_blob_shape); + int* pad_data = pad_.mutable_cpu_data(); + if (conv_param.has_pad_h() || conv_param.has_pad_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "pad_h & pad_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.pad_size()) + << "Either pad or pad_h/w should be specified; not both."; + pad_data[0] = conv_param.pad_h(); + pad_data[1] = conv_param.pad_w(); } else { - stride_h_ = conv_param.stride_h(); - stride_w_ = conv_param.stride_w(); + const int num_pad_dims = conv_param.pad_size(); + CHECK(num_pad_dims == 0 || num_pad_dims == 1 || + num_pad_dims == num_spatial_axes_) + << "pad must be specified once, or once per spatial dimension " + << "(pad specified " << num_pad_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultPad = 0; + for (int i = 0; i < num_spatial_axes_; ++i) { + pad_data[i] = (num_pad_dims == 0) ? kDefaultPad : + conv_param.pad((num_pad_dims == 1) ? 0 : i); + } } } template void Im2colLayer::Reshape(const vector*>& bottom, const vector*>& top) { - CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, " - << "corresponding to (num, channels, height, width)"; - channels_ = bottom[0]->channels(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - top[0]->Reshape( - bottom[0]->num(), channels_ * kernel_h_ * kernel_w_, - (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1, - (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1); + vector top_shape = bottom[0]->shape(); + const int* kernel_shape_data = kernel_shape_.cpu_data(); + const int* stride_data = stride_.cpu_data(); + const int* pad_data = pad_.cpu_data(); + for (int i = 0; i < num_spatial_axes_; ++i) { + top_shape[channel_axis_] *= kernel_shape_data[i]; + const int input_dim = bottom[0]->shape(channel_axis_ + i + 1); + const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i]) + / stride_data[i] + 1; + top_shape[channel_axis_ + i + 1] = output_dim; + } + top[0]->Reshape(top_shape); + num_ = 
bottom[0]->count(0, channel_axis_); + bottom_dim_ = bottom[0]->count(channel_axis_); + top_dim_ = top[0]->count(channel_axis_); + + channels_ = bottom[0]->shape(channel_axis_); } template @@ -66,10 +117,17 @@ void Im2colLayer::Forward_cpu(const vector*>& bottom, const vector*>& top) { const Dtype* bottom_data = bottom[0]->cpu_data(); Dtype* top_data = top[0]->mutable_cpu_data(); - for (int n = 0; n < bottom[0]->num(); ++n) { - im2col_cpu(bottom_data + bottom[0]->offset(n), channels_, height_, - width_, kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, top_data + top[0]->offset(n)); + for (int n = 0; n < num_; ++n) { + DCHECK_EQ(bottom[0]->shape().size() - channel_axis_, num_spatial_axes_ + 1); + DCHECK_EQ(top[0]->shape().size() - channel_axis_, num_spatial_axes_ + 1); + DCHECK_EQ(kernel_shape_.count(), num_spatial_axes_); + DCHECK_EQ(pad_.count(), num_spatial_axes_); + DCHECK_EQ(stride_.count(), num_spatial_axes_); + im2col_cpu(bottom_data + n * bottom_dim_, num_spatial_axes_, + bottom[0]->shape().data() + channel_axis_, + top[0]->shape().data() + channel_axis_, + kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(), + top_data + n * top_dim_); } } @@ -78,10 +136,12 @@ void Im2colLayer::Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { const Dtype* top_diff = top[0]->cpu_diff(); Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); - for (int n = 0; n < top[0]->num(); ++n) { - col2im_cpu(top_diff + top[0]->offset(n), channels_, height_, width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, bottom_diff + bottom[0]->offset(n)); + for (int n = 0; n < num_; ++n) { + col2im_cpu(top_diff + n * top_dim_, num_spatial_axes_, + bottom[0]->shape().data() + channel_axis_, + top[0]->shape().data() + channel_axis_, + kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(), + bottom_diff + n * bottom_dim_); } } diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu index 
9c338b14cb7..fc119eeb315 100644 --- a/src/caffe/layers/im2col_layer.cu +++ b/src/caffe/layers/im2col_layer.cu @@ -12,10 +12,13 @@ void Im2colLayer::Forward_gpu(const vector*>& bottom, const vector*>& top) { const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); - for (int n = 0; n < bottom[0]->num(); ++n) { - im2col_gpu(bottom_data + bottom[0]->offset(n), channels_, height_, - width_, kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, top_data + top[0]->offset(n)); + const int num_kernels = channels_ * top[0]->count(channel_axis_ + 1); + for (int n = 0; n < num_; ++n) { + im2col_gpu(bottom_data + n * bottom_dim_, num_spatial_axes_, num_kernels, + bottom[0]->gpu_shape() + channel_axis_, + top[0]->gpu_shape() + channel_axis_, + kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), + top_data + n * top_dim_); } } @@ -24,10 +27,12 @@ void Im2colLayer::Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { const Dtype* top_diff = top[0]->gpu_diff(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); - for (int n = 0; n < top[0]->num(); ++n) { - col2im_gpu(top_diff + top[0]->offset(n), channels_, height_, width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, bottom_diff + bottom[0]->offset(n)); + for (int n = 0; n < num_; ++n) { + col2im_gpu(top_diff + n * top_dim_, num_spatial_axes_, bottom_dim_, + bottom[0]->gpu_shape() + channel_axis_, + top[0]->gpu_shape() + channel_axis_, + kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), + bottom_diff + n * bottom_dim_); } } diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp index c1fe3b58c58..009cae9a796 100644 --- a/src/caffe/test/test_convolution_layer.cpp +++ b/src/caffe/test/test_convolution_layer.cpp @@ -21,25 +21,25 @@ void caffe_conv(const Blob* in, ConvolutionParameter* conv_param, Blob* out) { // Kernel size, stride, and pad int kernel_h, kernel_w; 
- if (conv_param->has_kernel_size()) { - kernel_h = kernel_w = conv_param->kernel_size(); - } else { + if (conv_param->has_kernel_h() || conv_param->has_kernel_w()) { kernel_h = conv_param->kernel_h(); kernel_w = conv_param->kernel_w(); + } else { + kernel_h = kernel_w = conv_param->kernel_size(0); } int pad_h, pad_w; - if (!conv_param->has_pad_h()) { - pad_h = pad_w = conv_param->pad(); - } else { + if (conv_param->has_pad_h() || conv_param->has_pad_w()) { pad_h = conv_param->pad_h(); pad_w = conv_param->pad_w(); + } else { + pad_h = pad_w = conv_param->pad_size() ? conv_param->pad(0) : 0; } int stride_h, stride_w; - if (!conv_param->has_stride_h()) { - stride_h = stride_w = conv_param->stride(); - } else { + if (conv_param->has_stride_h() || conv_param->has_stride_w()) { stride_h = conv_param->stride_h(); stride_w = conv_param->stride_w(); + } else { + stride_h = stride_w = conv_param->stride_size() ? conv_param->stride(0) : 1; } // Groups int groups = conv_param->group(); @@ -150,8 +150,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSetup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); @@ -188,8 +188,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); @@ -222,8 +222,8 @@ 
TYPED_TEST(ConvolutionLayerTest, Test1x1Convolution) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(1); - convolution_param->set_stride(1); + convolution_param->add_kernel_size(1); + convolution_param->add_stride(1); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); @@ -249,8 +249,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); @@ -288,8 +288,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(1); convolution_param->set_bias_term(false); shared_ptr > layer( @@ -375,8 +375,8 @@ TYPED_TEST(ConvolutionLayerTest, TestGradient) { layer_param.mutable_convolution_param(); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -393,8 +393,8 @@ 
TYPED_TEST(ConvolutionLayerTest, Test1x1Gradient) { layer_param.mutable_convolution_param(); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); - convolution_param->set_kernel_size(1); - convolution_param->set_stride(1); + convolution_param->add_kernel_size(1); + convolution_param->add_stride(1); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -409,8 +409,8 @@ TYPED_TEST(ConvolutionLayerTest, TestGradientGroup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); @@ -473,8 +473,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSetupCuDNN) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); @@ -511,8 +511,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); 
convolution_param->mutable_bias_filler()->set_type("constant"); @@ -545,8 +545,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); @@ -584,8 +584,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(1); convolution_param->set_bias_term(false); shared_ptr > layer( @@ -671,8 +671,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientCuDNN) { layer_param.mutable_convolution_param(); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -687,8 +687,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientGroupCuDNN) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(3); convolution_param->set_group(3); 
convolution_param->mutable_weight_filler()->set_type("gaussian"); diff --git a/src/caffe/test/test_deconvolution_layer.cpp b/src/caffe/test/test_deconvolution_layer.cpp index fc63d5efbe3..5b1d5d2f375 100644 --- a/src/caffe/test/test_deconvolution_layer.cpp +++ b/src/caffe/test/test_deconvolution_layer.cpp @@ -58,8 +58,8 @@ TYPED_TEST(DeconvolutionLayerTest, TestSetup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); @@ -96,8 +96,8 @@ TYPED_TEST(DeconvolutionLayerTest, TestSimpleDeconvolution) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("constant"); convolution_param->mutable_weight_filler()->set_value(1); @@ -144,8 +144,8 @@ TYPED_TEST(DeconvolutionLayerTest, TestGradient) { layer_param.mutable_convolution_param(); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); - convolution_param->set_kernel_size(2); - convolution_param->set_stride(1); + convolution_param->add_kernel_size(2); + convolution_param->add_stride(1); convolution_param->set_num_output(1); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); diff --git a/src/caffe/test/test_im2col_kernel.cu b/src/caffe/test/test_im2col_kernel.cu index ee684c00255..5f5125e6f25 100644 --- a/src/caffe/test/test_im2col_kernel.cu 
+++ b/src/caffe/test/test_im2col_kernel.cu @@ -14,12 +14,10 @@ namespace caffe { // Forward declare kernel functions -template +template __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int height_col, const int width_col, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, Dtype* data_col); extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; @@ -30,11 +28,18 @@ class Im2colKernelTest : public ::testing::Test { Im2colKernelTest() // big so launches > 1024 threads : blob_bottom_(new Blob(5, 500, 10, 10)), + blob_kernel_shape_(new Blob()), + blob_stride_(new Blob()), + blob_pad_(new Blob()), blob_top_(new Blob()), blob_top_cpu_(new Blob()) { FillerParameter filler_param; GaussianFiller filler(filler_param); filler.Fill(this->blob_bottom_); + vector dim_blob_shape(1, 2); + blob_kernel_shape_->Reshape(dim_blob_shape); + blob_stride_->Reshape(dim_blob_shape); + blob_pad_->Reshape(dim_blob_shape); height_ = blob_bottom_->height(); width_ = blob_bottom_->width(); @@ -44,14 +49,26 @@ class Im2colKernelTest : public ::testing::Test { kernel_size_ = 3; height_col_ = (height_ + 2 * pad_ - kernel_size_) / stride_ + 1; width_col_ = (width_ + 2 * pad_ - kernel_size_) / stride_ + 1; + + for (int i = 0; i < 2; ++i) { + blob_kernel_shape_->mutable_cpu_data()[i] = kernel_size_; + blob_stride_->mutable_cpu_data()[i] = stride_; + blob_pad_->mutable_cpu_data()[i] = pad_; + } } virtual ~Im2colKernelTest() { - delete blob_bottom_; - delete blob_top_; - delete blob_top_cpu_; + delete blob_bottom_; + delete blob_top_; + delete blob_top_cpu_; + delete blob_kernel_shape_; + delete blob_stride_; + delete blob_pad_; } + Blob* const blob_kernel_shape_; + Blob* const blob_stride_; + Blob* const blob_pad_; Blob* const blob_bottom_; Blob* const blob_top_; Blob* const 
blob_top_cpu_; @@ -72,49 +89,48 @@ TYPED_TEST(Im2colKernelTest, TestGPU) { // Reshape the blobs to correct size for im2col output this->blob_top_->Reshape(this->blob_bottom_->num(), - this->channels_ * this->kernel_size_ * this->kernel_size_, - this->height_col_, - this->width_col_); + this->channels_ * this->kernel_size_ * this->kernel_size_, + this->height_col_, + this->width_col_); - this->blob_top_cpu_->Reshape(this->blob_bottom_->num(), - this->channels_ * this->kernel_size_ * this->kernel_size_, - this->height_col_, - this->width_col_); + this->blob_top_cpu_->ReshapeLike(*this->blob_top_); - const TypeParam* bottom_data = this->blob_bottom_->gpu_data(); - TypeParam* top_data = this->blob_top_->mutable_gpu_data(); - TypeParam* cpu_data = this->blob_top_cpu_->mutable_cpu_data(); + const TypeParam* bottom_data_cpu = this->blob_bottom_->cpu_data(); + TypeParam* top_data_cpu = this->blob_top_cpu_->mutable_cpu_data(); // CPU Version for (int n = 0; n < this->blob_bottom_->num(); ++n) { - im2col_cpu(this->blob_bottom_->cpu_data() + this->blob_bottom_->offset(n), - this->channels_, this->height_, this->width_, - this->kernel_size_, this->kernel_size_, this->pad_, this->pad_, - this->stride_, this->stride_, - cpu_data + this->blob_top_cpu_->offset(n)); + im2col_cpu(bottom_data_cpu + this->blob_bottom_->offset(n), 2, + this->blob_bottom_->shape().data() + 1, + this->blob_top_cpu_->shape().data() + 1, + this->blob_kernel_shape_->cpu_data(), + this->blob_pad_->cpu_data(), this->blob_stride_->cpu_data(), + top_data_cpu + this->blob_top_cpu_->offset(n)); } // GPU version int num_kernels = this->channels_ * this->height_col_ * this->width_col_; int default_grid_dim = CAFFE_GET_BLOCKS(num_kernels); + const TypeParam* bottom_data_gpu = this->blob_bottom_->gpu_data(); // Launch with different grid sizes for (int grid_div = 2; grid_div <= 8; grid_div++) { for (int n = 0; n < this->blob_bottom_->num(); ++n) { - int grid_dim = default_grid_dim/grid_div; + const int grid_dim = 
default_grid_dim / grid_div; + TypeParam* top_data_gpu = this->blob_top_->mutable_gpu_data(); // NOLINT_NEXT_LINE(whitespace/operators) - im2col_gpu_kernel<<>>( - num_kernels, bottom_data + this->blob_bottom_->offset(n), - this->height_, this->width_, this->kernel_size_, this->kernel_size_, - this->pad_, this->pad_, this->stride_, this->stride_, - this->height_col_, this->width_col_, - top_data + this->blob_top_->offset(n)); + im2col_gpu_kernel<<>>( + num_kernels, bottom_data_gpu + this->blob_bottom_->offset(n), + this->blob_bottom_->gpu_shape() + 1, this->blob_top_->gpu_shape() + 1, + this->blob_kernel_shape_->gpu_data(), this->blob_pad_->gpu_data(), + this->blob_stride_->gpu_data(), + top_data_gpu + this->blob_top_->offset(n)); CUDA_POST_KERNEL_CHECK; } // Compare results against CPU version for (int i = 0; i < this->blob_top_->count(); ++i) { - TypeParam cpuval = cpu_data[i]; + TypeParam cpuval = top_data_cpu[i]; TypeParam gpuval = this->blob_top_->cpu_data()[i]; EXPECT_EQ(cpuval, gpuval); if (cpuval != gpuval) { diff --git a/src/caffe/test/test_im2col_layer.cpp b/src/caffe/test/test_im2col_layer.cpp index f50abe103f8..5c5683a5fe8 100644 --- a/src/caffe/test/test_im2col_layer.cpp +++ b/src/caffe/test/test_im2col_layer.cpp @@ -21,6 +21,7 @@ class Im2colLayerTest : public MultiDeviceTest { : blob_bottom_(new Blob(2, 3, 6, 5)), blob_top_(new Blob()) { // fill the values + Caffe::set_random_seed(1701); FillerParameter filler_param; GaussianFiller filler(filler_param); filler.Fill(this->blob_bottom_); @@ -41,8 +42,8 @@ TYPED_TEST(Im2colLayerTest, TestSetup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); Im2colLayer layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); EXPECT_EQ(this->blob_top_->num(), 2); @@ -56,8 +57,8 @@ 
TYPED_TEST(Im2colLayerTest, TestForward) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); Im2colLayer layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); @@ -73,8 +74,8 @@ TYPED_TEST(Im2colLayerTest, TestGradient) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); Im2colLayer layer(layer_param); GradientChecker checker(1e-2, 1e-2); checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, @@ -89,7 +90,7 @@ TYPED_TEST(Im2colLayerTest, TestRect) { layer_param.mutable_convolution_param(); convolution_param->set_kernel_h(5); convolution_param->set_kernel_w(3); - convolution_param->set_stride(2); + convolution_param->add_stride(2); Im2colLayer layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); @@ -108,7 +109,7 @@ TYPED_TEST(Im2colLayerTest, TestRectGradient) { layer_param.mutable_convolution_param(); convolution_param->set_kernel_h(5); convolution_param->set_kernel_w(3); - convolution_param->set_stride(2); + convolution_param->add_stride(2); Im2colLayer layer(layer_param); GradientChecker checker(1e-2, 1e-2); checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp index c48f31f35d4..f12debe7ef9 100644 --- a/src/caffe/util/im2col.cpp +++ b/src/caffe/util/im2col.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include "caffe/util/im2col.hpp" #include "caffe/util/math_functions.hpp" @@ 
-8,76 +9,118 @@ namespace caffe { template -void im2col_cpu(const Dtype* data_im, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - Dtype* data_col) { - int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1; - int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1; - int channels_col = channels * kernel_h * kernel_w; +inline void im2col_core_cpu(const Dtype* data_input, const bool im2col, + const int num_spatial_axes, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_output) { + if (!im2col) { + int im_size = im_shape[0]; + for (int i = 0; i < num_spatial_axes; ++i) { + im_size *= im_shape[1 + i]; + } + caffe_set(im_size, Dtype(0), data_output); + } + int kernel_size = 1; + for (int i = 0; i < num_spatial_axes; ++i) { + kernel_size *= kernel_shape[i]; + } + const int channels_col = col_shape[0]; + vector d_offset(num_spatial_axes, 0); + vector d_iter(num_spatial_axes, 0); for (int c = 0; c < channels_col; ++c) { - int w_offset = c % kernel_w; - int h_offset = (c / kernel_w) % kernel_h; - int c_im = c / kernel_h / kernel_w; - for (int h = 0; h < height_col; ++h) { - for (int w = 0; w < width_col; ++w) { - int h_pad = h * stride_h - pad_h + h_offset; - int w_pad = w * stride_w - pad_w + w_offset; - if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width) - data_col[(c * height_col + h) * width_col + w] = - data_im[(c_im * height + h_pad) * width + w_pad]; - else - data_col[(c * height_col + h) * width_col + w] = 0; + // Loop over spatial axes in reverse order to compute a per-axis offset. 
+ int offset = c; + for (int d_i = num_spatial_axes - 1; d_i >= 0; --d_i) { + if (d_i < num_spatial_axes - 1) { + offset /= kernel_shape[d_i + 1]; } + d_offset[d_i] = offset % kernel_shape[d_i]; } - } + for (bool incremented = true; incremented; ) { + // Loop over spatial axes in forward order to compute the indices in the + // image and column, and whether the index lies in the padding. + int index_col = c; + int index_im = c / kernel_size; + bool is_padding = false; + for (int d_i = 0; d_i < num_spatial_axes; ++d_i) { + const int d = d_iter[d_i]; + const int d_pad = d * stride[d_i] - pad[d_i] + d_offset[d_i]; + is_padding |= d_pad < 0 || d_pad >= im_shape[d_i + 1]; + index_col *= col_shape[d_i + 1]; + index_col += d; + index_im *= im_shape[d_i + 1]; + index_im += d_pad; + } + if (im2col) { + if (is_padding) { + data_output[index_col] = 0; + } else { + data_output[index_col] = data_input[index_im]; + } + } else if (!is_padding) { // col2im + data_output[index_im] += data_input[index_col]; + } + // Loop over spatial axes in reverse order to choose an index, + // like counting. 
+ incremented = false; + for (int d_i = num_spatial_axes - 1; d_i >= 0; --d_i) { + const int d_max = col_shape[d_i + 1]; + DCHECK_LT(d_iter[d_i], d_max); + if (d_iter[d_i] == d_max - 1) { + d_iter[d_i] = 0; + } else { // d_iter[d_i] < d_max - 1 + ++d_iter[d_i]; + incremented = true; + break; + } + } + } // while(incremented) { + } // for (int c = 0; c < channels_col; ++c) { +} + +template +void im2col_cpu(const Dtype* data_im, const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_col) { + const bool kIm2Col = true; + im2col_core_cpu(data_im, kIm2Col, num_spatial_axes, im_shape, col_shape, + kernel_shape, pad, stride, data_col); } // Explicit instantiation -template void im2col_cpu(const float* data_im, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, float* data_col); -template void im2col_cpu(const double* data_im, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, double* data_col); +template void im2col_cpu(const float* data_im, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + float* data_col); +template void im2col_cpu(const double* data_im, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + double* data_col); template -void col2im_cpu(const Dtype* data_col, const int channels, - const int height, const int width, const int patch_h, const int patch_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, +void col2im_cpu(const Dtype* data_col, const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const 
int* kernel_shape, const int* pad, const int* stride, Dtype* data_im) { - caffe_set(height * width * channels, Dtype(0), data_im); - int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1; - int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1; - int channels_col = channels * patch_h * patch_w; - for (int c = 0; c < channels_col; ++c) { - int w_offset = c % patch_w; - int h_offset = (c / patch_w) % patch_h; - int c_im = c / patch_h / patch_w; - for (int h = 0; h < height_col; ++h) { - for (int w = 0; w < width_col; ++w) { - int h_pad = h * stride_h - pad_h + h_offset; - int w_pad = w * stride_w - pad_w + w_offset; - if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width) - data_im[(c_im * height + h_pad) * width + w_pad] += - data_col[(c * height_col + h) * width_col + w]; - } - } - } + const bool kIm2Col = false; + im2col_core_cpu(data_col, kIm2Col, num_spatial_axes, im_shape, col_shape, + kernel_shape, pad, stride, data_im); } // Explicit instantiation -template void col2im_cpu(const float* data_col, const int channels, - const int height, const int width, const int patch_h, const int patch_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, float* data_im); -template void col2im_cpu(const double* data_col, const int channels, - const int height, const int width, const int patch_h, const int patch_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, double* data_im); +template void col2im_cpu(const float* data_col, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + float* data_im); +template void col2im_cpu(const double* data_col, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + double* data_im); + } // namespace caffe diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu index 
c90f93eb67b..2c3633175a6 100644 --- a/src/caffe/util/im2col.cu +++ b/src/caffe/util/im2col.cu @@ -8,137 +8,187 @@ namespace caffe { -template +template __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int height_col, const int width_col, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, Dtype* data_col) { + int d_temp[num_axes]; // NOLINT(runtime/arrays) + int d_iter[num_axes]; // NOLINT(runtime/arrays) + int i; CUDA_KERNEL_LOOP(index, n) { - int w_out = index % width_col; - int h_index = index / width_col; - int h_out = h_index % height_col; - int channel_in = h_index / height_col; - int channel_out = channel_in * kernel_h * kernel_w; - int h_in = h_out * stride_h - pad_h; - int w_in = w_out * stride_w - pad_w; - Dtype* data_col_ptr = data_col; - data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out; - const Dtype* data_im_ptr = data_im; - data_im_ptr += (channel_in * height + h_in) * width + w_in; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - int h = h_in + i; - int w = w_in + j; - *data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ? - data_im_ptr[i * width + j] : 0; - data_col_ptr += height_col * width_col; - } + // Initialize channel_in, computed in the loop below, with intermediate + // computations used to compute the spatial indices. 
+ int channel_in = index; + int channel_out = 1; + for (i = num_axes - 1; i >= 0; --i) { + d_temp[i] = channel_in % col_shape[i + 1]; + channel_in /= col_shape[i + 1]; + channel_out *= kernel_shape[i]; } - } + channel_out *= channel_in; + int data_col_inc = 1; + for (i = 0; i < num_axes; ++i) { + channel_out *= col_shape[i + 1]; + channel_out += d_temp[i]; + d_temp[i] = d_temp[i] * stride[i] - pad[i]; + channel_in *= im_shape[i + 1]; + channel_in += d_temp[i]; + data_col_inc *= col_shape[i + 1]; + d_iter[i] = 0; + } + Dtype* data_col_ptr = data_col + channel_out; + const Dtype* data_im_ptr = data_im + channel_in; + bool incremented; + do { + bool in_range = true; + for (i = 0; i < num_axes; ++i) { + const int d_iter_im = d_iter[i] + d_temp[i]; + in_range &= d_iter_im >= 0 && d_iter_im < im_shape[i + 1]; + if (!in_range) { break; } + } + if (in_range) { + int data_im_offset = d_iter[0]; + for (i = 1; i < num_axes; ++i) { + data_im_offset *= im_shape[i + 1]; + data_im_offset += d_iter[i]; + } + *data_col_ptr = data_im_ptr[data_im_offset]; + } else { + *data_col_ptr = 0; + } + data_col_ptr += data_col_inc; + incremented = false; + for (i = num_axes - 1; i >= 0; --i) { + const int d_max = kernel_shape[i]; + if (d_iter[i] == d_max - 1) { + d_iter[i] = 0; + } else { // d_iter[i] < d_max - 1 + ++d_iter[i]; + incremented = true; + break; + } + } // for (i = num_axes - 1; i >= 0; --i) + } while (incremented); // do + } // CUDA_KERNEL_LOOP(index, n) } template -void im2col_gpu(const Dtype* data_im, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, +void im2col_gpu(const Dtype* data_im, const int num_spatial_axes, + const int num_kernels, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, Dtype* data_col) { - // We are going to launch channels * height_col * width_col kernels, each - // kernel 
responsible for copying a single-channel grid. - int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1; - int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1; - int num_kernels = channels * height_col * width_col; - // NOLINT_NEXT_LINE(whitespace/operators) - im2col_gpu_kernel<<>>( - num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h, - pad_w, stride_h, stride_w, height_col, - width_col, data_col); + im2col_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); CUDA_POST_KERNEL_CHECK; } - // Explicit instantiation -template void im2col_gpu(const float* data_im, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, +template void im2col_gpu(const float* data_im, + const int num_spatial_axes, const int col_size, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, float* data_col); -template void im2col_gpu(const double* data_im, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, +template void im2col_gpu(const double* data_im, + const int num_spatial_axes, const int col_size, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, double* data_col); -template +template __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col, - const int height, const int width, const int channels, - const int patch_h, const int patch_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int height_col, const int width_col, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, Dtype* data_im) { + int d_im[num_axes]; // 
NOLINT(runtime/arrays) + int d_col_iter[num_axes]; // NOLINT(runtime/arrays) + int d_col_start[num_axes]; // NOLINT(runtime/arrays) + int d_col_end[num_axes]; // NOLINT(runtime/arrays) CUDA_KERNEL_LOOP(index, n) { - Dtype val = 0; - int w = index % width + pad_w; - int h = (index / width) % height + pad_h; - int c = index / (width * height); - // compute the start and end of the output - int w_col_start = (w < patch_w) ? 0 : (w - patch_w) / stride_w + 1; - int w_col_end = min(w / stride_w + 1, width_col); - int h_col_start = (h < patch_h) ? 0 : (h - patch_h) / stride_h + 1; - int h_col_end = min(h / stride_h + 1, height_col); - /* - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - // the col location: [c * width * height + h_out, w_out] - int c_col = c * patch_h * patch_w + (h - h_col * stride_h) * ksize - + (w - w_col * stride_w); - val += data_col[(c_col * height_col + h_col) * width_col + w_col]; - } + // Initialize channel_im, computed in the loop below, with intermediate + // computations used to compute the spatial indices. + int channel_im = index; + // Calculate d_im (padded image coordinates of this element). + for (int i = num_axes - 1; i >= 0; --i) { + d_im[i] = channel_im % im_shape[i + 1] + pad[i]; + channel_im /= im_shape[i + 1]; } - */ - // equivalent implementation - int offset = - (c * patch_h * patch_w + h * patch_w + w) * height_col * width_col; - int coeff_h_col = (1 - stride_h * patch_w * height_col) * width_col; - int coeff_w_col = (1 - stride_w * height_col * width_col); - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col]; + // Calculate col start/end indices. + bool done = false; + for (int i = 0; i < num_axes; ++i) { + d_col_start[i] = d_col_iter[i] = + (d_im[i] < kernel_shape[i]) ? 
+ 0 : (d_im[i] - kernel_shape[i]) / stride[i] + 1; + d_col_end[i] = min(d_im[i] / stride[i] + 1, col_shape[i + 1]); + if (d_col_start[i] >= d_col_end[i]) { + // Skip computation if the col index range is empty along any spatial axis -- + // final val will be 0. + data_im[index] = 0; + done = true; + break; // for (int i = 0; i < num_axes; ++i) } } + if (done) { + continue; // CUDA_KERNEL_LOOP(index, n) + } + // Loop over the col to compute the output val. + Dtype val = 0; + bool incremented = true; + do { + // Compute the final offset into data_col. + int final_offset = 0; + int kernel_shape_prod = 1; + for (int i = num_axes - 1; i >= 0; --i) { + final_offset += + (d_im[i] - d_col_iter[i] * stride[i]) * kernel_shape_prod; + kernel_shape_prod *= kernel_shape[i]; + } + final_offset += kernel_shape_prod * channel_im; + for (int i = 0; i < num_axes; ++i) { + final_offset *= col_shape[i + 1]; + final_offset += d_col_iter[i]; + } + val += data_col[final_offset]; + incremented = false; + for (int i = num_axes - 1; i >= 0; --i) { + const int d_max = d_col_end[i]; + if (d_col_iter[i] == d_max - 1) { + d_col_iter[i] = d_col_start[i]; + } else { // d_col_iter[i] < d_max - 1 + ++d_col_iter[i]; + incremented = true; + break; // for (int i = num_axes - 1; i >= 0; --i) + } + } // for (int i = num_axes - 1; i >= 0; --i) + } while (incremented); data_im[index] = val; - } + } // CUDA_KERNEL_LOOP(index, n) } template -void col2im_gpu(const Dtype* data_col, const int channels, - const int height, const int width, const int patch_h, const int patch_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, Dtype* data_im) { - int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1; - int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1; - int num_kernels = channels * height * width; - // To avoid involving atomic operations, we will launch one kernel per - // bottom dimension, and then in the kernel add up the top dimensions. 
- // NOLINT_NEXT_LINE(whitespace/operators) - col2im_gpu_kernel<<>>( - num_kernels, data_col, height, width, channels, patch_h, patch_w, - pad_h, pad_w, stride_h, stride_w, - height_col, width_col, data_im); +void col2im_gpu(const Dtype* data_col, const int num_spatial_axes, + const int im_size, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_im) { + col2im_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); CUDA_POST_KERNEL_CHECK; } // Explicit instantiation -template void col2im_gpu(const float* data_col, const int channels, - const int height, const int width, const int patch_h, const int patch_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, float* data_im); -template void col2im_gpu(const double* data_col, const int channels, - const int height, const int width, const int patch_h, const int patch_w, - const int pad_h, const int pad_w, const int stride_h, - const int stride_w, double* data_im); +template void col2im_gpu(const float* data_col, + const int num_spatial_axes, const int im_size, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + float* data_im); +template void col2im_gpu(const double* data_col, + const int num_spatial_axes, const int im_size, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + double* data_im); } // namespace caffe diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index 38a06026adf..57fa192626b 100644 --- a/src/caffe/util/upgrade_proto.cpp +++ b/src/caffe/util/upgrade_proto.cpp @@ -193,7 +193,7 @@ bool UpgradeV0LayerParameter(const V1LayerParameter& v0_layer_connection, } if (v0_layer_param.has_pad()) { if (type == "conv") { - layer_param->mutable_convolution_param()->set_pad(v0_layer_param.pad()); + 
layer_param->mutable_convolution_param()->add_pad(v0_layer_param.pad()); } else if (type == "pool") { layer_param->mutable_pooling_param()->set_pad(v0_layer_param.pad()); } else { @@ -203,7 +203,7 @@ bool UpgradeV0LayerParameter(const V1LayerParameter& v0_layer_connection, } if (v0_layer_param.has_kernelsize()) { if (type == "conv") { - layer_param->mutable_convolution_param()->set_kernel_size( + layer_param->mutable_convolution_param()->add_kernel_size( v0_layer_param.kernelsize()); } else if (type == "pool") { layer_param->mutable_pooling_param()->set_kernel_size( @@ -224,7 +224,7 @@ bool UpgradeV0LayerParameter(const V1LayerParameter& v0_layer_connection, } if (v0_layer_param.has_stride()) { if (type == "conv") { - layer_param->mutable_convolution_param()->set_stride( + layer_param->mutable_convolution_param()->add_stride( v0_layer_param.stride()); } else if (type == "pool") { layer_param->mutable_pooling_param()->set_stride(