Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
miguel committed Mar 26, 2019
2 parents 289cb26 + 4a7dcc4 commit 619f520
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 9 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ image = ...
predictions = coco_demo.run_on_opencv_image(image)
```

### Use it on an arbitrary GPU device
For some cases, while multi-GPU devices are installed in a machine, a possible situation is that
we only have accesse to a specified GPU device (e.g. CUDA:1 or CUDA:2) for inference, testing or training.
Here, the repository currently supports two methods to control devices.

#### 1. using CUDA_VISIBLE_DEVICES environment variable (Recommend)
Here is an example for Mask R-CNN R-50 FPN quick on the second device (CUDA:1):
```bash
export CUDA_VISIBLE_DEVICES=1
python tools/train_net.py --config-file=configs/quick_schedules/e2e_mask_rcnn_R_50_FPN_quick.yaml
```
Now, the session will be totally loaded on the second GPU device (CUDA:1).

#### 2. using MODEL.DEVICE flag
In addition, the program could run on a sepcific GPU device by setting `MODEL.DEVICE` flag.
```bash
python tools/train_net.py --config-file=configs/quick_schedules/e2e_mask_rcnn_R_50_FPN_quick.yaml MODEL.DEVICE cuda:1
```
Where, we add a `MODEL.DEVICE cuda:1` flag to configure the target device.
*Pay attention, there is still a small part of memory stored in `cuda:0` for some reasons.*

## Perform training on COCO dataset

For the following examples to work, you need to first install `maskrcnn_benchmark`.
Expand Down
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
return output;
}

AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] {
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
ROIAlignForward_cpu_kernel<scalar_t>(
output_size,
input.data<scalar_t>(),
Expand Down
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ at::Tensor nms_cpu(const at::Tensor& dets,
const at::Tensor& scores,
const float threshold) {
at::Tensor result;
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
});
return result;
Expand Down
8 changes: 6 additions & 2 deletions maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGuard.h>

#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
Expand Down Expand Up @@ -263,6 +264,8 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");

at::cuda::CUDAGuard device_guard(input.device());

auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
Expand All @@ -280,7 +283,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
return output;
}

AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] {
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
output_size,
input.contiguous().data<scalar_t>(),
Expand Down Expand Up @@ -311,6 +314,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
const int sampling_ratio) {
AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(grad.device());

auto num_rois = rois.size(0);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
Expand All @@ -326,7 +330,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
return grad_input;
}

AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIAlign_backward", [&] {
AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] {
RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
Expand Down
8 changes: 6 additions & 2 deletions maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGuard.h>

#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
Expand Down Expand Up @@ -115,6 +116,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");

at::cuda::CUDAGuard device_guard(input.device());

auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
Expand All @@ -134,7 +137,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
return std::make_tuple(output, argmax);
}

AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIPool_forward", [&] {
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] {
RoIPoolFForward<scalar_t><<<grid, block, 0, stream>>>(
output_size,
input.contiguous().data<scalar_t>(),
Expand Down Expand Up @@ -167,6 +170,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
// TODO add more checks
at::cuda::CUDAGuard device_guard(grad.device());

auto num_rois = rois.size(0);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
Expand All @@ -182,7 +186,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
return grad_input;
}

AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIPool_backward", [&] {
AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] {
RoIPoolFBackward<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
Expand Down
11 changes: 8 additions & 3 deletions maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// [email protected]
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGuard.h>

#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
Expand Down Expand Up @@ -111,6 +112,8 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");

at::cuda::CUDAGuard device_guard(logits.device());

const int num_samples = logits.size(0);

auto losses = at::empty({num_samples, logits.size(1)}, logits.options());
Expand All @@ -125,7 +128,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
return losses;
}

AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "SigmoidFocalLoss_forward", [&] {
AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_forward", [&] {
SigmoidFocalLossForward<scalar_t><<<grid, block, 0, stream>>>(
losses_size,
logits.contiguous().data<scalar_t>(),
Expand Down Expand Up @@ -156,7 +159,9 @@ at::Tensor SigmoidFocalLoss_backward_cuda(

const int num_samples = logits.size(0);
AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes");


at::cuda::CUDAGuard device_guard(logits.device());

auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
auto d_logits_size = num_samples * logits.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Expand All @@ -169,7 +174,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
return d_logits;
}

AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "SigmoidFocalLoss_backward", [&] {
AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_backward", [&] {
SigmoidFocalLossBackward<scalar_t><<<grid, block, 0, stream>>>(
d_logits_size,
logits.contiguous().data<scalar_t>(),
Expand Down
3 changes: 3 additions & 0 deletions maskrcnn_benchmark/csrc/cuda/nms.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGuard.h>

#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
Expand Down Expand Up @@ -70,6 +71,8 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
using scalar_t = float;
AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(boxes.device());

auto scores = boxes.select(1, 4);
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto boxes_sorted = boxes.index_select(0, order_t);
Expand Down

0 comments on commit 619f520

Please sign in to comment.