Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Extensions for Mask R-CNN for improved performance #379

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions maskrcnn_benchmark/csrc/box_encode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once
#include "cuda/vision.h"
#ifndef _box_encode_h_
#define _box_encode_h_

std::vector<at::Tensor> box_encode(at::Tensor boxes, at::Tensor anchors, float wx, float wy, float ww, float wh){
std::vector<at::Tensor> result = box_encode_cuda(boxes, anchors, wx, wy, ww, wh);
return result;
}

#endif
12 changes: 12 additions & 0 deletions maskrcnn_benchmark/csrc/box_iou.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once
#include "cuda/vision.h"
#ifndef _box_iou_h_
#define _box_iou_h_

at::Tensor box_iou(at::Tensor box1, at::Tensor box2){
at::Tensor result = box_iou_cuda(box1, box2);
return result;
}

#endif

85 changes: 85 additions & 0 deletions maskrcnn_benchmark/csrc/cuda/box_encode.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
#include <torch/torch.h>
#include <vector>
#include <iostream>

__global__ void box_encode_kernel(float *targets_dx, float *targets_dy, float *targets_dw, float *targets_dh,
float4 *boxes, float4 *anchors, float wx, float wy, float ww, float wh,
size_t gt, size_t idxJump) {

int idx = blockIdx.x * blockDim.x + threadIdx.x;
size_t row_offset;
float anchors_x1, anchors_x2, anchors_y1, anchors_y2,
boxes_x1, boxes_x2, boxes_y1, boxes_y2, ex_w, ex_h,
ex_ctr_x, ex_ctr_y, gt_w, gt_h, gt_ctr_x, gt_ctr_y;

for (int i = idx; i < gt; i += idxJump){
row_offset = i;
anchors_x1 = anchors[row_offset].x;
anchors_y1 = anchors[row_offset].y;
anchors_x2 = anchors[row_offset].z;
anchors_y2 = anchors[row_offset].w;

boxes_x1 = boxes[row_offset].x;
boxes_y1 = boxes[row_offset].y;
boxes_x2 = boxes[row_offset].z;
boxes_y2 = boxes[row_offset].w;

ex_w = anchors_x2 - anchors_x1 + 1;
ex_h = anchors_y2 - anchors_y1 + 1;
ex_ctr_x = anchors_x1 + 0.5 * ex_w;
ex_ctr_y = anchors_y1 + 0.5 * ex_h;

gt_w = boxes_x2 - boxes_x1 + 1;
gt_h = boxes_y2 - boxes_y1 + 1;
gt_ctr_x = boxes_x1 + 0.5 * gt_w;
gt_ctr_y = boxes_y1 + 0.5 * gt_h;

targets_dx[i] = wx * (gt_ctr_x - ex_ctr_x) / ex_w;
targets_dy[i] = wy * (gt_ctr_y - ex_ctr_y) / ex_h;
targets_dw[i] = ww * log(gt_w / ex_w);
targets_dh[i] = wh * log(gt_h / ex_h);
}

}


std::vector<at::Tensor> box_encode_cuda(at::Tensor boxes, at::Tensor anchors, float wx, float wy, float ww, float wh){

int minGridSize;
int blockSize;

cudaOccupancyMaxPotentialBlockSize(&minGridSize,
&blockSize,
(void*) box_encode_kernel,
0, // dynamic memory
0); // maximum utilized threads
long size = boxes.size(0);
auto targets_dx = at::ones({size}, torch::CUDA(at::kFloat));
auto targets_dy = at::ones({size}, torch::CUDA(at::kFloat));
auto targets_dw = at::ones({size}, torch::CUDA(at::kFloat));
auto targets_dh = at::ones({size}, torch::CUDA(at::kFloat));

dim3 gridDim(minGridSize);
dim3 blockDim(blockSize);
int idxJump = minGridSize * blockSize;
auto stream = at::cuda::getCurrentCUDAStream();
box_encode_kernel<<<gridDim,blockDim,0,stream.stream()>>>(targets_dx.data<float>(),
targets_dy.data<float>(),
targets_dw.data<float>(),
targets_dh.data<float>(),
(float4*) boxes.data<float>(),
(float4*) anchors.data<float>(),
wx, wy, ww, wh,
size, idxJump);

std::vector<at::Tensor> result;
result.push_back(targets_dx);
result.push_back(targets_dy);
result.push_back(targets_dw);
result.push_back(targets_dh);
return result;
}
75 changes: 75 additions & 0 deletions maskrcnn_benchmark/csrc/cuda/box_iou.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
#include <torch/torch.h>
#include <iostream>

__global__ void box_iou_cuda_kernel(float *box_iou, float4 *box1, float4 *box2, long M,
long N, int idxJump) {

int idx = blockIdx.x*blockDim.x + threadIdx.x;
size_t b1_idx, b2_idx, b1_row_offset, b2_row_offset;
float xmin1, xmin2, xmax1, xmax2, ymin1, ymin2, ymax1, ymax2;
float x_tl, y_tl, x_br, y_br, w, h, inter, area1, area2, iou;

for (long i = idx; i < M * N; i += idxJump){

b1_idx = i / N;
b2_idx = i % N;
b1_row_offset = b1_idx;
b2_row_offset = b2_idx;

xmin1 = box1[b1_row_offset].x;
ymin1 = box1[b1_row_offset].y;
xmax1 = box1[b1_row_offset].z;
ymax1 = box1[b1_row_offset].w;
xmin2 = box2[b2_row_offset].x;
ymin2 = box2[b2_row_offset].y;
xmax2 = box2[b2_row_offset].z;
ymax2 = box2[b2_row_offset].w;

x_tl = fmaxf(xmin1, xmin2);
y_tl = fmaxf(ymin1, ymin2);

x_br = fminf(xmax1, xmax2);
y_br = fminf(ymax1, ymax2);
w = (x_br - x_tl + 1) < 0 ? 0.0f : (x_br - x_tl + 1);
h = (y_br - y_tl + 1) < 0 ? 0.0f : (y_br - y_tl + 1);

inter = w * h;
area1 = (xmax1 - xmin1 + 1) * (ymax1 - ymin1 + 1);
area2 = (xmax2 - xmin2 + 1) * (ymax2 - ymin2 + 1);
iou = inter / (area1 + area2 - inter);
box_iou[b1_idx * N + b2_idx] = iou;
}

}

at::Tensor box_iou_cuda(at::Tensor box1, at::Tensor box2){

int minGridSize;
int blockSize;

cudaOccupancyMaxPotentialBlockSize(&minGridSize,
&blockSize,
(void*) box_iou_cuda_kernel,
0, // dynamic memory
0); // maximum utilized threads

long M = box1.size(0);
long N = box2.size(0);
auto box_iou = at::ones({M, N}, torch::CUDA(at::kFloat));

dim3 gridDim(minGridSize);
dim3 blockDim(blockSize);
int idxJump = minGridSize * blockSize;
auto stream = at::cuda::getCurrentCUDAStream();
box_iou_cuda_kernel<<<gridDim, blockDim, 0, stream.stream()>>>(box_iou.data<float>(),
(float4*) box1.data<float>(),
(float4*) box2.data<float>(),
M, N,
idxJump);
return box_iou;
}

Loading