From f8e50146f002e59c7e4bb39b89fd0128f8484e2b Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Tue, 27 Aug 2019 13:27:01 +0800 Subject: [PATCH 1/8] add RROIAlign --- src/operator/rroi_align-inl.h | 68 ++++++++ src/operator/rroi_align.cc | 313 ++++++++++++++++++++++++++++++++++ 2 files changed, 381 insertions(+) create mode 100644 src/operator/rroi_align-inl.h create mode 100644 src/operator/rroi_align.cc diff --git a/src/operator/rroi_align-inl.h b/src/operator/rroi_align-inl.h new file mode 100644 index 000000000000..4c517491caba --- /dev/null +++ b/src/operator/rroi_align-inl.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file rroi_align-inl.h + * \brief rroi align operator and symbol + * \author Yixin Bao + * Adapted from Caffe2 +*/ +#ifndef MXNET_OPERATOR_RROI_ALIGN_INL_H_ +#define MXNET_OPERATOR_RROI_ALIGN_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "./mshadow_op.h" +#include "./operator_common.h" + +namespace mxnet { +namespace op { + +// Declare enumeration of input order to make code more intuitive. +// These enums are only visible within this header +namespace rroialign { +enum RROIAlignOpInputs{kData, kBox}; +enum RROIAlignOpOutputs {kOut}; +} // rroialign + +struct RROIAlignParam : public dmlc::Parameter { + mxnet::TShape pooled_size; + float spatial_scale; + int sampling_ratio; + DMLC_DECLARE_PARAMETER(RROIAlignParam) { + DMLC_DECLARE_FIELD(pooled_size) + .set_expect_ndim(2).enforce_nonzero() + .describe("RROI align output shape (h,w) "); + DMLC_DECLARE_FIELD(spatial_scale).set_range(0.0, 1.0) + .describe("Ratio of input feature map height (or w) to raw image height (or w). " + "Equals the reciprocal of total stride in convolutional layers"); + DMLC_DECLARE_FIELD(sampling_ratio).set_default(-1) + .describe("Optional sampling ratio of RROI align, using adaptive size by default."); + } +}; + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_RROI_ALIGN_INL_H_ diff --git a/src/operator/rroi_align.cc b/src/operator/rroi_align.cc new file mode 100644 index 000000000000..bdef465eb437 --- /dev/null +++ b/src/operator/rroi_align.cc @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file rroi_align.cc + * \brief rroi align operator + * \author Yixin Bao + * Adapted from Caffe2 +*/ +#include "./rroi_align-inl.h" +#include +#include "math.h" + +using std::max; +using std::min; +using std::floor; +using std::ceil; + +namespace mxnet { +namespace op { + +template +struct position_for_bilinear_interpolate { + // 4 positions and corresponding weights for + // computing bilinear interpolation + int pos1, pos2, pos3, pos4; + DType w1, w2, w3, w4; +}; + +template +void pre_calc_for_bilinear_interpolate( + const int height, const int width, const int pooled_height, const int pooled_width, + const int iy_upper, const int ix_upper, DType roi_start_h, DType roi_start_w, + DType bin_size_h, DType bin_size_w, int roi_bin_grid_h, int roi_bin_grid_w, + DType roi_center_h, DType roi_center_w, DType theta, + std::vector> *pre_calc) { + int pre_calc_index = 0; + DType cosTheta = cos(theta); + DType sinTheta = sin(theta); + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + // calc bin grid position (xx,yy) + for (int iy = 0; iy < iy_upper; iy++) { + const DType yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < ix_upper; ix++) { + const DType xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + // Rotate by theta around the center and translate + DType x = xx * cosTheta + yy * sinTheta + roi_center_w; + DType y = yy * cosTheta - xx * sinTheta + roi_center_h; + + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + position_for_bilinear_interpolate pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc->at(pre_calc_index) = pc; + pre_calc_index += 1; + continue; + } + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + // calc 4 points for interpolation + int y_low = static_cast(y); + int x_low = static_cast(x); + int y_high; + int x_high; + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (DType)y_low; + } else { + y_high = y_low + 1; + } + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (DType)x_low; + } else { + x_high = x_low + 1; + } + DType ly = y - y_low; + DType lx = x - x_low; + DType hy = 1. - ly, hx = 1. - lx; + DType w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // Save weights and indices + position_for_bilinear_interpolate pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc->at(pre_calc_index) = pc; + + pre_calc_index += 1; + } + } + } + } +} + +template +inline void RROIAlignForward(const OpContext &ctx, const RROIAlignParam ¶m, + const std::vector &in_data, const std::vector &req, + const std::vector &out_data) { + // data: [batch_size, c, h, w] + const TBlob &data = in_data[rroialign::kData]; + const TBlob &bbox = in_data[rroialign::kBox]; + const DType *bottom_data = data.dptr(); + const int channels_ = data.size(1); + const int height_ = data.size(2); + const int width_ = data.size(3); + const index_t data_size_c = height_ * width_; + const index_t data_size = channels_ * data_size_c; + + // bbox: [num_rois, 6] (6: [batch_index, x, y, w, h, theta]) + const DType *bottom_rois = bbox.dptr(); + const int num_rois = bbox.size(0); + const float spatial_scale_ = param.spatial_scale; + const int sampling_ratio_ = param.sampling_ratio; + + // out: [num_rois, c, pooled_h, pooled_w] + const TBlob &out = out_data[rroialign::kOut]; + DType *top_data = out.dptr(); + const int pooled_height_ = out.size(2); + const int pooled_width_ = out.size(3); + const index_t out_size_c = pooled_height_ * pooled_width_; + const index_t out_size = channels_ * out_size_c; + + for (int n = 0; n < num_rois; ++n) { + // Increment ROI data pointer + const DType *bottom_rois_n = bottom_rois + n * bbox.size(1); + DType *top_data_n = top_data + n * out_size; + int roi_batch_ind = static_cast(bottom_rois_n[0]); + DType roi_center_w = bottom_rois_n[1] * spatial_scale_; + DType roi_center_h = bottom_rois_n[2] * spatial_scale_; + DType roi_width = bottom_rois_n[3] * spatial_scale_; + DType roi_height = bottom_rois_n[4] * spatial_scale_; + DType roi_theta = bottom_rois_n[5] * M_PI / 180.0; + + // force malformed ROIs to be 1 * 1 + roi_width = max(roi_width, (DType) 1.); + roi_height = max(roi_height, (DType) 1.); + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. + DType roi_start_h = -roi_height / 2.0; + DType roi_start_w = -roi_width / 2.0; + + const DType bin_size_h = static_cast(roi_height) / static_cast(pooled_height_); + const DType bin_size_w = static_cast(roi_width) / static_cast(pooled_width_); + // We use roi_bin_grid to sample the grid and mimic integral, + // e.g. roi_bin_grid = 2, means sample 2*2=4 points in each bin + int roi_bin_grid_h = + (sampling_ratio_ > 0) ? sampling_ratio_ : ceil(roi_height / pooled_height_); + int roi_bin_grid_w = (sampling_ratio_ > 0) ? sampling_ratio_ : ceil(roi_width / pooled_width_); + const DType bin_points_count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + // We want to precalculate indices and weights shared by all channels, + // this is the key point of optimization. + std::vector> pre_calc(roi_bin_grid_h * roi_bin_grid_w * + pooled_width_ * pooled_height_); + + pre_calc_for_bilinear_interpolate(height_, width_, pooled_height_, pooled_width_, + roi_bin_grid_h, roi_bin_grid_w, roi_start_h, roi_start_w, + bin_size_h, bin_size_w, roi_bin_grid_h, roi_bin_grid_w, + roi_center_h, roi_center_w, roi_theta, &pre_calc); + +#pragma omp parallel for + for (int c = 0; c < channels_; ++c) { + const DType *offset_bottom_data = bottom_data + roi_batch_ind * data_size + c * data_size_c; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height_; ph++) { + for (int pw = 0; pw < pooled_width_; pw++) { + DType output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + position_for_bilinear_interpolate pc = pre_calc[pre_calc_index]; + output_val += + pc.w1 * offset_bottom_data[pc.pos1] + pc.w2 * offset_bottom_data[pc.pos2] + + pc.w3 * offset_bottom_data[pc.pos3] + pc.w4 * offset_bottom_data[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= bin_points_count; // avg pooling for bin grid + int index = c * pooled_height_ * pooled_width_ + ph * pooled_width_ + pw; + top_data_n[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n +} + +template +void RROIAlignForwardCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& in_data, + const std::vector& req, + const std::vector& out_data) { + const RROIAlignParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_data.size(), 2); + CHECK_EQ(out_data.size(), 1); + CHECK_EQ(out_data[rroialign::kOut].shape_[0], in_data[rroialign::kBox].shape_[0]); + + MSHADOW_REAL_TYPE_SWITCH(in_data[0].type_flag_, DType, { + RROIAlignForward(ctx, param, in_data, req, out_data); + }) +} + +DMLC_REGISTER_PARAMETER(RROIAlignParam); + +NNVM_REGISTER_OP(_contrib_RROIAlign) +.describe(R"code(Performs Rotated ROI Align on the input array. + +This operator takes a 4D feature map as an input array and region proposals as `rois`, +then align the feature map over sub-regions of input and produces a fixed-sized output array. + +Different from ROI Align, RROI Align uses rotated rois, which is suitable for text detection. +RRoIAlign computes the value of each sampling point by bilinear interpolation from the nearby +grid points on the rotated feature map. No quantization is performed on any coordinates +involved in the RoI, its bins, or the sampling points. Bilinear interpolation is used to +compute the exact values of the input features at four regularly sampled locations in +each RoI bin. Then the feature map can be aggregated by avgpooling. + +References +---------- + +Ma, Jianqi, et al. "Arbitrary-Oriented Scene Text Detection via Rotation Proposals." +IEEE Transactions on Multimedia, 2018. + +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "rois"}; +}) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", [](const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape){ + using namespace mshadow; + const RROIAlignParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 2U) << "Input:[data, rois]"; + // data: [batch_size, c, h, w] + mxnet::TShape dshape = in_shape->at(rroialign::kData); + CHECK_EQ(dshape.ndim(), 4U) << "data should be a 4D tensor"; + // bbox: [num_rois, 6] + mxnet::TShape bshape = in_shape->at(rroialign::kBox); + CHECK_EQ(bshape.ndim(), 2U) << "bbox should be a 2D tensor of shape [batch, 6]"; + CHECK_EQ(bshape[1], 6U) << "bbox should be a 2D tensor of shape [batch, 6]"; + // out: [num_rois, c, pooled_h, pooled_w] + out_shape->clear(); + out_shape->push_back(Shape4(bshape[0], dshape[1], param.pooled_size[0], param.pooled_size[1])); + return true; +}) +.set_attr("FInferType", [](const nnvm::NodeAttrs& attrs, + std::vector *in_type, std::vector *out_type) { + CHECK_EQ(in_type->size(), 2U); + int dtype = (*in_type)[0]; + CHECK_EQ(dtype, (*in_type)[1]); + CHECK_NE(dtype, -1) << "Input must have specified type"; + + out_type->clear(); + out_type->push_back(dtype); + return true; +}) +.set_attr("FCompute", RROIAlignForwardCompute) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "Input data to the pooling operator, a 4D Feature maps") +.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 2D array") +.add_arguments(RROIAlignParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet From e27b8117cd0b73c29a2b373d4c4e987e34802a35 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Tue, 27 Aug 2019 13:27:32 +0800 Subject: [PATCH 2/8] add test case for RROIAlign --- tests/python/unittest/test_operator.py | 133 +++++++++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ceee51a3e503..82bdd4d414cb 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8409,6 +8409,139 @@ def test_roi_align_autograd(sampling_ratio=0): test_roi_align_value(position_sensitive=True) test_roi_align_autograd() +@with_seed() +def test_op_rroi_align(): + T = np.float32 + + def assert_same_dtype(dtype_a, dtype_b): + ''' + Assert whether the two data type are the same + Parameters + ---------- + dtype_a, dtype_b: type + Input data types to compare + ''' + assert dtype_a == dtype_b,\ + TypeError('Unmatched data types: %s vs %s' % (dtype_a, dtype_b)) + + def bilinear_interpolate(bottom, height, width, y, x): + if y < -1.0 or y > height or x < -1.0 or x > width: + return T(0.0) + x = T(max(0.0, x)) + y = T(max(0.0, y)) + x_low = int(x) + y_low = int(y) + if x_low >= width - 1: + x_low = x_high = width - 1 + x = T(x_low) + else: + x_high = x_low + 1 + if y_low >= height - 1: + y_low = y_high = height - 1 + y = T(y_low) + else: + y_high = y_low + 1 + ly = y - T(y_low) + lx = x - T(x_low) + hy = T(1.0) - ly + hx = T(1.0) - lx + v1 = bottom[y_low, x_low] + v2 = bottom[y_low, x_high] + v3 = bottom[y_high, x_low] + v4 = bottom[y_high, x_high] + w1 = hy * hx + w2 = hy * lx + w3 = ly * hx + w4 = ly * lx + assert_same_dtype(w1.dtype, T) + assert_same_dtype(w2.dtype, T) + assert_same_dtype(w3.dtype, T) + assert_same_dtype(w4.dtype, T) + val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + assert_same_dtype(val.dtype, T) + + return val + + def rroialign_forward(data, rois, pooled_size, spatial_scale, sampling_ratio): + N, C, H, W = data.shape + R = rois.shape[0] + PH, PW = pooled_size + assert rois.ndim == 2,\ + ValueError( + 'The ndim of rois should be 2 rather than %d' % rois.ndim) + assert rois.shape[1] == 6,\ + ValueError( + 'The length of the axis 1 of rois should be 6 rather than %d' % rois.shape[1]) + assert_same_dtype(data.dtype, T) + assert_same_dtype(rois.dtype, T) + + out = np.zeros((R, C, PH, PW), dtype=T) + + for r in range(R): + batch_ind = int(rois[r, 0]) + roi_center_w, roi_center_h, roi_w, roi_h = rois[r, 1:5] * T(spatial_scale) + roi_theta = T(rois[r,5] * np.pi / 180.0) + roi_w = T(max(roi_w, 1.0)) + roi_h = T(max(roi_h, 1.0)) + bin_h = roi_h / T(PH) + bin_w = roi_w / T(PW) + bdata = data[batch_ind] + if sampling_ratio > 0: + roi_bin_grid_h = roi_bin_grid_w = sampling_ratio + else: + roi_bin_grid_h = int(np.ceil(roi_h / T(PH))) + roi_bin_grid_w = int(np.ceil(roi_w / T(PW))) + count = T(roi_bin_grid_h * roi_bin_grid_w) + roi_start_h = T(-roi_h / 2.0) + roi_start_w = T(-roi_w / 2.0) + for c in range(C): + for ph in range(PH): + for pw in range(PW): + val = T(0.0) + for iy in range(roi_bin_grid_h): + yy = roi_start_h + T(ph) * bin_h + (T(iy) + T(0.5)) * \ + bin_h / T(roi_bin_grid_h) + for ix in range(roi_bin_grid_w): + xx = roi_start_w + T(pw) * bin_w + (T(ix) + T(0.5)) * \ + bin_w / T(roi_bin_grid_w) + x = xx * np.cos(roi_theta, dtype=T) + yy * np.sin(roi_theta, dtype=T) + roi_center_w + y = yy * np.cos(roi_theta, dtype=T) - xx * np.sin(roi_theta, dtype=T) + roi_center_h + v = bilinear_interpolate( + bdata[c], H, W, y, x) + assert_same_dtype(v.dtype, T) + val += v + + out[r, c, ph, pw] = val / count + assert_same_dtype(out.dtype, T) + + return out + + def test_rroi_align_value(sampling_ratio=-1): + ctx = default_context() + dtype = np.float32 + dlen = 224 + N, C, H, W = 5, 3, 16, 16 + R = 7 + pooled_size = (3, 4) + spatial_scale = H * 1.0 / dlen + data = mx.nd.array( + np.arange(N * C * W * H).reshape((N, C, H, W)), ctx=ctx, dtype=dtype) + center_xy = mx.nd.random.uniform(0, dlen, (R, 2), ctx=ctx, dtype=dtype) + wh = mx.nd.random.uniform(0, dlen, (R, 2), ctx=ctx, dtype=dtype) + theta = mx.nd.random.uniform(0, 180, (R,1), ctx=ctx, dtype=dtype) + batch_ind = mx.nd.array(np.random.randint(0, N, size=(R, 1)), ctx=ctx) + pos = mx.nd.concat(center_xy, wh, theta, dim=1) + rois = mx.nd.concat(batch_ind, pos, dim=1) + + output = mx.nd.contrib.RROIAlign(data, rois, pooled_size=pooled_size, + spatial_scale=spatial_scale, sampling_ratio=sampling_ratio) + real_output = rroialign_forward(data.asnumpy(), rois.asnumpy(), pooled_size, + spatial_scale, sampling_ratio) + + assert_almost_equal(output.asnumpy(), real_output, atol=1e-3) + + test_rroi_align_value() + test_rroi_align_value(sampling_ratio=2) @with_seed() def test_diag(): From 51bf865f3186e2950c069cd4e029db3c214a3bf9 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 28 Aug 2019 08:49:50 +0800 Subject: [PATCH 3/8] move rroi align op to contrib --- src/operator/{ => contrib}/rroi_align-inl.h | 5 ++--- src/operator/{ => contrib}/rroi_align.cc | 5 +++-- 2 files changed, 5 insertions(+), 5 deletions(-) rename src/operator/{ => contrib}/rroi_align-inl.h (96%) rename src/operator/{ => contrib}/rroi_align.cc (98%) diff --git a/src/operator/rroi_align-inl.h b/src/operator/contrib/rroi_align-inl.h similarity index 96% rename from src/operator/rroi_align-inl.h rename to src/operator/contrib/rroi_align-inl.h index 4c517491caba..b6692194009b 100644 --- a/src/operator/rroi_align-inl.h +++ b/src/operator/contrib/rroi_align-inl.h @@ -22,7 +22,6 @@ * \file rroi_align-inl.h * \brief rroi align operator and symbol * \author Yixin Bao - * Adapted from Caffe2 */ #ifndef MXNET_OPERATOR_RROI_ALIGN_INL_H_ #define MXNET_OPERATOR_RROI_ALIGN_INL_H_ @@ -34,8 +33,8 @@ #include #include #include -#include "./mshadow_op.h" -#include "./operator_common.h" +#include "../mshadow_op.h" +#include "../operator_common.h" namespace mxnet { namespace op { diff --git a/src/operator/rroi_align.cc b/src/operator/contrib/rroi_align.cc similarity index 98% rename from src/operator/rroi_align.cc rename to src/operator/contrib/rroi_align.cc index bdef465eb437..3c4b5e75f79f 100644 --- a/src/operator/rroi_align.cc +++ b/src/operator/contrib/rroi_align.cc @@ -22,8 +22,9 @@ * \file rroi_align.cc * \brief rroi align operator * \author Yixin Bao - * Adapted from Caffe2 -*/ + * Forward pass adapted from Caffe2 + * link: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/roi_align_rotated_op.cc + */ #include "./rroi_align-inl.h" #include #include "math.h" From bf8c8fc26ee9eb075d76dd2c973c56e53b896c3f Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 28 Aug 2019 08:50:29 +0800 Subject: [PATCH 4/8] skip gpu test for rroi align --- tests/python/unittest/test_operator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 82bdd4d414cb..4bf5b16ce25c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8518,6 +8518,10 @@ def rroialign_forward(data, rois, pooled_size, spatial_scale, sampling_ratio): def test_rroi_align_value(sampling_ratio=-1): ctx = default_context() + if ctx.device_type == 'gpu': + print('skipped testing rroi align for gpu since it is not supported yet') + return + dtype = np.float32 dlen = 224 N, C, H, W = 5, 3, 16, 16 From 448eebd26930c79c200e8511720abdb6f458683d Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 28 Aug 2019 09:04:47 +0800 Subject: [PATCH 5/8] add omp parallel --- src/operator/contrib/rroi_align.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/operator/contrib/rroi_align.cc b/src/operator/contrib/rroi_align.cc index 3c4b5e75f79f..d22acd2eded0 100644 --- a/src/operator/contrib/rroi_align.cc +++ b/src/operator/contrib/rroi_align.cc @@ -163,6 +163,9 @@ inline void RROIAlignForward(const OpContext &ctx, const RROIAlignParam ¶m, const index_t out_size_c = pooled_height_ * pooled_width_; const index_t out_size = channels_ * out_size_c; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp +#pragma omp parallel for for (int n = 0; n < num_rois; ++n) { // Increment ROI data pointer const DType *bottom_rois_n = bottom_rois + n * bbox.size(1); @@ -201,7 +204,6 @@ inline void RROIAlignForward(const OpContext &ctx, const RROIAlignParam ¶m, bin_size_h, bin_size_w, roi_bin_grid_h, roi_bin_grid_w, roi_center_h, roi_center_w, roi_theta, &pre_calc); -#pragma omp parallel for for (int c = 0; c < channels_; ++c) { const DType *offset_bottom_data = bottom_data + roi_batch_ind * data_size + c * data_size_c; int pre_calc_index = 0; From 5bdd339002d810e282ffcf5672da4268baf67b99 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 28 Aug 2019 09:39:28 +0800 Subject: [PATCH 6/8] fix lint --- src/operator/contrib/rroi_align-inl.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/contrib/rroi_align-inl.h b/src/operator/contrib/rroi_align-inl.h index b6692194009b..e0467ac7c873 100644 --- a/src/operator/contrib/rroi_align-inl.h +++ b/src/operator/contrib/rroi_align-inl.h @@ -23,8 +23,8 @@ * \brief rroi align operator and symbol * \author Yixin Bao */ -#ifndef MXNET_OPERATOR_RROI_ALIGN_INL_H_ -#define MXNET_OPERATOR_RROI_ALIGN_INL_H_ +#ifndef MXNET_OPERATOR_CONTRIB_RROI_ALIGN_INL_H_ +#define MXNET_OPERATOR_CONTRIB_RROI_ALIGN_INL_H_ #include #include @@ -64,4 +64,4 @@ struct RROIAlignParam : public dmlc::Parameter { } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_RROI_ALIGN_INL_H_ +#endif // MXNET_OPERATOR_CONTRIB_RROI_ALIGN_INL_H_ From d0d4c0cb6c0fe9cfbb00363d4c65deb5644f8681 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 28 Aug 2019 15:01:50 +0800 Subject: [PATCH 7/8] update omp with num_threads --- src/operator/contrib/rroi_align.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/contrib/rroi_align.cc b/src/operator/contrib/rroi_align.cc index d22acd2eded0..2778100de877 100644 --- a/src/operator/contrib/rroi_align.cc +++ b/src/operator/contrib/rroi_align.cc @@ -165,7 +165,7 @@ inline void RROIAlignForward(const OpContext &ctx, const RROIAlignParam ¶m, // (n, c, ph, pw) is an element in the pooled output // can be parallelized using omp -#pragma omp parallel for +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int n = 0; n < num_rois; ++n) { // Increment ROI data pointer const DType *bottom_rois_n = bottom_rois + n * bbox.size(1); From dcc9d79b9fe14c159f609f1a7fb87d5d5871cd19 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Thu, 29 Aug 2019 13:29:58 +0800 Subject: [PATCH 8/8] fix extra copy --- src/operator/contrib/rroi_align-inl.h | 4 ++-- src/operator/contrib/rroi_align.cc | 24 +++++++++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/operator/contrib/rroi_align-inl.h b/src/operator/contrib/rroi_align-inl.h index e0467ac7c873..084ead78dbc7 100644 --- a/src/operator/contrib/rroi_align-inl.h +++ b/src/operator/contrib/rroi_align-inl.h @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2019 by Contributors * \file rroi_align-inl.h * \brief rroi align operator and symbol * \author Yixin Bao @@ -55,7 +55,7 @@ struct RROIAlignParam : public dmlc::Parameter { .set_expect_ndim(2).enforce_nonzero() .describe("RROI align output shape (h,w) "); DMLC_DECLARE_FIELD(spatial_scale).set_range(0.0, 1.0) - .describe("Ratio of input feature map height (or w) to raw image height (or w). " + .describe("Ratio of input feature map height (or width) to raw image height (or width). " "Equals the reciprocal of total stride in convolutional layers"); DMLC_DECLARE_FIELD(sampling_ratio).set_default(-1) .describe("Optional sampling ratio of RROI align, using adaptive size by default."); diff --git a/src/operator/contrib/rroi_align.cc b/src/operator/contrib/rroi_align.cc index 2778100de877..14690d6270d2 100644 --- a/src/operator/contrib/rroi_align.cc +++ b/src/operator/contrib/rroi_align.cc @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2019 by Contributors * \file rroi_align.cc * \brief rroi align operator * \author Yixin Bao @@ -74,7 +74,7 @@ void pre_calc_for_bilinear_interpolate( // deal with: inverse elements are out of feature map boundary if (y < -1.0 || y > height || x < -1.0 || x > width) { // empty - position_for_bilinear_interpolate pc; + position_for_bilinear_interpolate &pc = (*pre_calc)[pre_calc_index]; pc.pos1 = 0; pc.pos2 = 0; pc.pos3 = 0; @@ -83,7 +83,6 @@ void pre_calc_for_bilinear_interpolate( pc.w2 = 0; pc.w3 = 0; pc.w4 = 0; - pre_calc->at(pre_calc_index) = pc; pre_calc_index += 1; continue; } @@ -117,7 +116,7 @@ void pre_calc_for_bilinear_interpolate( DType w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; // Save weights and indices - position_for_bilinear_interpolate pc; + position_for_bilinear_interpolate &pc = (*pre_calc)[pre_calc_index]; pc.pos1 = y_low * width + x_low; pc.pos2 = y_low * width + x_high; pc.pos3 = y_high * width + x_low; @@ -126,8 +125,6 @@ void pre_calc_for_bilinear_interpolate( pc.w2 = w2; pc.w3 = w3; pc.w4 = w4; - pre_calc->at(pre_calc_index) = pc; - pre_calc_index += 1; } } @@ -245,6 +242,14 @@ void RROIAlignForwardCompute(const nnvm::NodeAttrs& attrs, }) } +template +void RROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& in_data, + const std::vector& req, + const std::vector& out_data) { + LOG(FATAL) << "RROIAlign: Backward is not supported."; +} + DMLC_REGISTER_PARAMETER(RROIAlignParam); NNVM_REGISTER_OP(_contrib_RROIAlign) @@ -307,10 +312,15 @@ IEEE Transactions on Multimedia, 2018. return true; }) .set_attr("FCompute", RROIAlignForwardCompute) -.set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "Input data to the pooling operator, a 4D Feature maps") .add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 2D array") .add_arguments(RROIAlignParam::__FIELDS__()); +NNVM_REGISTER_OP(_backward_RROIAlign) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr_parser(ParamParser) +.set_attr("FCompute", RROIAlignBackwardCompute); + } // namespace op } // namespace mxnet