Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Make mrcnn_mask_target arg mask_size a 2d tuple #16567

Merged
merged 1 commit into from
Oct 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/operator/contrib/mrcnn_mask_target-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,17 @@ namespace mrcnn_index {
struct MRCNNMaskTargetParam : public dmlc::Parameter<MRCNNMaskTargetParam> {
int num_rois;
int num_classes;
int mask_size;
int sample_ratio;
mxnet::TShape mask_size;

DMLC_DECLARE_PARAMETER(MRCNNMaskTargetParam) {
DMLC_DECLARE_FIELD(num_rois)
.describe("Number of sampled RoIs.");
DMLC_DECLARE_FIELD(num_classes)
.describe("Number of classes.");
DMLC_DECLARE_FIELD(mask_size)
.describe("Size of the pooled masks.");
.set_expect_ndim(2).enforce_nonzero()
.describe("Size of the pooled masks height and width: (h, w).");
DMLC_DECLARE_FIELD(sample_ratio).set_default(2)
.describe("Sampling ratio of ROI align. Set to -1 to use adaptative size.");
}
Expand Down Expand Up @@ -91,7 +92,8 @@ inline bool MRCNNMaskTargetShape(const NodeAttrs& attrs,
CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs.";

// out: 2 * (B, N, C, MS, MS)
auto oshape = Shape5(batch_size, num_rois, param.num_classes, param.mask_size, param.mask_size);
auto oshape = Shape5(batch_size, num_rois, param.num_classes,
param.mask_size[0], param.mask_size[1]);
out_shape->clear();
out_shape->push_back(oshape);
out_shape->push_back(oshape);
Expand Down
10 changes: 6 additions & 4 deletions src/operator/contrib/mrcnn_mask_target.cu
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,16 @@ __global__ void MRCNNMaskTargetKernel(const DType *rois,
int num_gtmasks,
int gt_height,
int gt_width,
int mask_size,
int mask_size_h,
int mask_size_w,
int sample_ratio) {
// computing sampled_masks
RoIAlignForward(gt_masks, rois, matches, total_out_el,
num_classes, gt_height, gt_width, mask_size, mask_size,
num_classes, gt_height, gt_width, mask_size_h, mask_size_w,
sample_ratio, num_rois, num_gtmasks, sampled_masks);
// computing mask_cls
int num_masks = batch_size * num_rois * num_classes;
int mask_vol = mask_size * mask_size;
int mask_vol = mask_size_h * mask_size_w;
for (int mask_idx = blockIdx.x; mask_idx < num_masks; mask_idx += gridDim.x) {
int cls_idx = mask_idx % num_classes;
int roi_idx = (mask_idx / num_classes) % num_rois;
Expand Down Expand Up @@ -252,7 +253,8 @@ void MRCNNMaskTargetRun<gpu>(const MRCNNMaskTargetParam& param, const std::vecto
(rois.dptr_, gt_masks.dptr_, matches.dptr_, cls_targets.dptr_,
out_masks.dptr_, out_mask_cls.dptr_,
num_el, batch_size, param.num_classes, param.num_rois,
num_gtmasks, gt_height, gt_width, param.mask_size, param.sample_ratio);
num_gtmasks, gt_height, gt_width,
param.mask_size[0], param.mask_size[1], param.sample_ratio);
MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNMaskTargetKernel);
});
}
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_contrib_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def test_op_mrcnn_mask_target():

num_rois = 2
num_classes = 4
mask_size = 3
mask_size = (3, 3)
ctx = mx.gpu(0)
# (B, N, 4)
rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],
Expand Down