Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nms_cuda signature update #945

Merged
merged 1 commit into from
May 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions torchvision/csrc/cpu/nms_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ template <typename scalar_t>
at::Tensor nms_cpu_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
const float threshold) {
const float iou_threshold) {
AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
AT_ASSERTM(!scores.type().is_cuda(), "scores must be a CPU tensor");
AT_ASSERTM(
Expand Down Expand Up @@ -61,7 +61,7 @@ at::Tensor nms_cpu_kernel(
auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
auto inter = w * h;
auto ovr = inter / (iarea + areas[j] - inter);
if (ovr >= threshold)
if (ovr >= iou_threshold)
suppressed[j] = 1;
}
}
Expand All @@ -71,11 +71,11 @@ at::Tensor nms_cpu_kernel(
// Non-maximum suppression, CPU entry point.
//
// dets:          box tensor (floating-point dtype); exact layout is consumed
//                by nms_cpu_kernel — NOTE(review): presumably [N, 4] after
//                this PR, confirm against the kernel body.
// scores:        per-box scores used to order suppression.
// iou_threshold: boxes whose IoU with an already-kept box is >= this value
//                are suppressed (see the `ovr >= iou_threshold` test in the
//                kernel).
//
// Returns the tensor produced by nms_cpu_kernel for the dispatched dtype.
at::Tensor nms_cpu(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const float iou_threshold) {
  auto result = at::empty({0}, dets.options());

  // Dispatch on the box dtype; the lambda overwrites `result`.
  AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
    result = nms_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
  });
  return result;
}
2 changes: 1 addition & 1 deletion torchvision/csrc/cpu/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ at::Tensor ROIAlign_backward_cpu(
// CPU non-maximum suppression; implemented in cpu/nms_cpu.cpp.
// The threshold parameter is named iou_threshold to match the CUDA and
// dispatcher signatures.
at::Tensor nms_cpu(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const float iou_threshold);
57 changes: 28 additions & 29 deletions torchvision/csrc/cuda/nms_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ __device__ inline float devIoU(T const* const a, T const* const b) {
template <typename T>
__global__ void nms_kernel(
const int n_boxes,
const float nms_overlap_thresh,
const float iou_threshold,
const T* dev_boxes,
unsigned long long* dev_mask) {
const int row_start = blockIdx.y;
Expand All @@ -37,32 +37,30 @@ __global__ void nms_kernel(
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

__shared__ T block_boxes[threadsPerBlock * 5];
__shared__ T block_boxes[threadsPerBlock * 4];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
block_boxes[threadIdx.x * 4 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0];
block_boxes[threadIdx.x * 4 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1];
block_boxes[threadIdx.x * 4 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2];
block_boxes[threadIdx.x * 4 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3];
}
__syncthreads();

if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 5;
const T* cur_box = dev_boxes + cur_box_idx * 4;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (devIoU<T>(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
if (devIoU<T>(cur_box, block_boxes + i * 4) > iou_threshold) {
t |= 1ULL << i;
}
}
Expand All @@ -71,33 +69,34 @@ __global__ void nms_kernel(
}
}

// boxes is a N x 5 tensor
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
at::Tensor nms_cuda(const at::Tensor& dets,
const at::Tensor& scores,
float iou_threshold) {
using scalar_t = float;
AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(boxes.device());
AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(dets.device());

auto scores = boxes.select(1, 4);
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto boxes_sorted = boxes.index_select(0, order_t);
auto dets_sorted = dets.index_select(0, order_t);

int boxes_num = boxes.size(0);
int dets_num = dets.size(0);

const int col_blocks = at::cuda::ATenCeilDiv(boxes_num, threadsPerBlock);
const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock);

at::Tensor mask =
at::empty({boxes_num * col_blocks}, boxes.options().dtype(at::kLong));
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));

dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
boxes_sorted.type(), "nms_kernel_cuda", [&] {
dets_sorted.type(), "nms_kernel_cuda", [&] {
nms_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
boxes_num,
nms_overlap_thresh,
boxes_sorted.data<scalar_t>(),
dets_num,
iou_threshold,
dets_sorted.data<scalar_t>(),
(unsigned long long*)mask.data<int64_t>());
});

Expand All @@ -108,11 +107,11 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

at::Tensor keep =
at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data<int64_t>();

int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;

Expand Down
5 changes: 4 additions & 1 deletion torchvision/csrc/cuda/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,7 @@ at::Tensor ROIPool_backward_cuda(
const int height,
const int width);

// CUDA non-maximum suppression; implemented in cuda/nms_cuda.cu.
// Updated signature: `scores` is passed as a separate tensor instead of
// being packed into column 4 of the box tensor (old form:
// nms_cuda(boxes /*N x 5*/, nms_overlap_thresh)).
at::Tensor nms_cuda(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const float iou_threshold);
7 changes: 3 additions & 4 deletions torchvision/csrc/nms.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,19 @@
// Device dispatcher for non-maximum suppression.
//
// Routes to nms_cuda when `dets` lives on a CUDA device (and the build has
// GPU support), otherwise to nms_cpu.  Both backends now share the
// (dets, scores, iou_threshold) signature, so no score/box concatenation
// is needed before the CUDA call.
at::Tensor nms(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const float iou_threshold) {
  if (dets.device().is_cuda()) {
#ifdef WITH_CUDA
    // Empty input: return an empty int64 index tensor without touching the
    // GPU kernel (guard still selects the right device for the allocation).
    if (dets.numel() == 0) {
      at::cuda::CUDAGuard device_guard(dets.device());
      return at::empty({0}, dets.options().dtype(at::kLong));
    }
    return nms_cuda(dets, scores, iou_threshold);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }

  at::Tensor result = nms_cpu(dets, scores, iou_threshold);
  return result;
}