This repository has been archived by the owner on Dec 15, 2021. It is now read-only.
forked from facebookresearch/maskrcnn-benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for running on arbitrary CUDA device. (facebookresearch#537)
* support for any one cuda device * Revert "support for any one cuda device" This reverts commit 0197e4e. * support runnning for anyone cuda device * using safe CUDAGuard rather than intrinsic CUDASetDevice * supplement a header dependency (test passed) * Support for arbitrary GPU device. * Support for arbitrary GPU device. * add docs for two method to control devices
- Loading branch information
1 parent
c2e72b3
commit 7755314
Showing
5 changed files
with
38 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
// [email protected] | ||
#include <ATen/ATen.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <ATen/cuda/CUDAGuard.h> | ||
|
||
#include <THC/THC.h> | ||
#include <THC/THCAtomics.cuh> | ||
|
@@ -111,6 +112,8 @@ at::Tensor SigmoidFocalLoss_forward_cuda( | |
AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor"); | ||
AT_ASSERTM(logits.dim() == 2, "logits should be NxClass"); | ||
|
||
at::cuda::CUDAGuard device_guard(logits.device()); | ||
|
||
const int num_samples = logits.size(0); | ||
|
||
auto losses = at::empty({num_samples, logits.size(1)}, logits.options()); | ||
|
@@ -156,7 +159,9 @@ at::Tensor SigmoidFocalLoss_backward_cuda( | |
|
||
const int num_samples = logits.size(0); | ||
AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes"); | ||
|
||
|
||
at::cuda::CUDAGuard device_guard(logits.device()); | ||
|
||
auto d_logits = at::zeros({num_samples, num_classes}, logits.options()); | ||
auto d_logits_size = num_samples * logits.size(1); | ||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters