From f223f5e88881208d9635bb0d720ccaf1325eea8b Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 26 Mar 2019 11:14:05 +0100 Subject: [PATCH 1/4] Merge branch 'master' of /home/braincreator/projects/maskrcnn-benchmark with conflicts. --- docker/Dockerfile | 14 +- .../csrc/cuda/SigmoidFocalLoss_cuda.cu | 188 ++++++++++++++++++ 2 files changed, 197 insertions(+), 5 deletions(-) create mode 100644 maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu diff --git a/docker/Dockerfile b/docker/Dockerfile index 39b508258..b2806a5b5 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -31,8 +31,9 @@ ENV CONDA_AUTO_UPDATE_CONDA=false RUN conda install -y ipython RUN pip install ninja yacs cython matplotlib opencv-python -# Install PyTorch 1.0 Nightly and OpenCV -RUN conda install -y pytorch-nightly -c pytorch \ +# Install PyTorch 1.0 Nightly +ARG CUDA +RUN echo conda install pytorch-nightly cudatoolkit=${CUDA} -c pytorch \ && conda clean -ya # Install TorchVision master @@ -46,8 +47,11 @@ RUN git clone https://github.com/cocodataset/cocoapi.git \ && python setup.py build_ext install # install PyTorch Detection -RUN git clone https://github.com/facebookresearch/maskrcnn-benchmark.git \ - && cd maskrcnn-benchmark \ - && python setup.py build develop +ADD . ./maskrcnn-benchmark +RUN cd maskrcnn-benchmark && python setup.py build develop + +#RUN git clone https://github.com/facebookresearch/maskrcnn-benchmark.git \ +# && cd maskrcnn-benchmark \ +# && python setup.py build develop WORKDIR /maskrcnn-benchmark diff --git a/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu new file mode 100644 index 000000000..7d40767bb --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu @@ -0,0 +1,188 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This file is modified from https://github.com/pytorch/pytorch/blob/master/modules/detectron/sigmoid_focal_loss_op.cu +// Cheng-Yang Fu +// cyfu@cs.unc.edu +#include +#include + +#include +#include +#include + +#include + +// TODO make it in a common file +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +template +__global__ void SigmoidFocalLossForward(const int nthreads, + const T* logits, + const int* targets, + const int num_classes, + const float gamma, + const float alpha, + const int num, + T* losses) { + CUDA_1D_KERNEL_LOOP(i, nthreads) { + + int n = i / num_classes; + int d = i % num_classes; // current class[0~79]; + int t = targets[n]; // target class [1~80]; + + // Decide it is positive or negative case. + T c1 = (t == (d+1)); + T c2 = (t>=0 & t != (d+1)); + + T zn = (1.0 - alpha); + T zp = (alpha); + + // p = 1. / 1. + expf(-x); p = sigmoid(x) + T p = 1. / (1. + expf(-logits[i])); + + // (1-p)**gamma * log(p) where + T term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN)); + + // p**gamma * log(1-p) + T term2 = powf(p, gamma) * + (-1. * logits[i] * (logits[i] >= 0) - + logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))); + + losses[i] = 0.0; + losses[i] += -c1 * term1 * zp; + losses[i] += -c2 * term2 * zn; + + } // CUDA_1D_KERNEL_LOOP +} // SigmoidFocalLossForward + + +template +__global__ void SigmoidFocalLossBackward(const int nthreads, + const T* logits, + const int* targets, + const T* d_losses, + const int num_classes, + const float gamma, + const float alpha, + const int num, + T* d_logits) { + CUDA_1D_KERNEL_LOOP(i, nthreads) { + + int n = i / num_classes; + int d = i % num_classes; // current class[0~79]; + int t = targets[n]; // target class [1~80], 0 is background; + + // Decide it is positive or negative case. + T c1 = (t == (d+1)); + T c2 = (t>=0 & t != (d+1)); + + T zn = (1.0 - alpha); + T zp = (alpha); + // p = 1. / 1. + expf(-x); p = sigmoid(x) + T p = 1. / (1. + expf(-logits[i])); + + // (1-p)**g * (1 - p - g*p*log(p) + T term1 = powf((1. - p), gamma) * + (1. - p - (p * gamma * logf(max(p, FLT_MIN)))); + + // (p**g) * (g*(1-p)*log(1-p) - p) + T term2 = powf(p, gamma) * + ((-1. * logits[i] * (logits[i] >= 0) - + logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))) * + (1. - p) * gamma - p); + d_logits[i] = 0.0; + d_logits[i] += -c1 * term1 * zp; + d_logits[i] += -c2 * term2 * zn; + d_logits[i] = d_logits[i] * d_losses[i]; + + } // CUDA_1D_KERNEL_LOOP +} // SigmoidFocalLossBackward + + +at::Tensor SigmoidFocalLoss_forward_cuda( + const at::Tensor& logits, + const at::Tensor& targets, + const int num_classes, + const float gamma, + const float alpha) { + AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor"); + AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor"); + AT_ASSERTM(logits.dim() == 2, "logits should be NxClass"); + + const int num_samples = logits.size(0); + + auto losses = at::empty({num_samples, logits.size(1)}, logits.options()); + auto losses_size = num_samples * logits.size(1); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(losses_size, 512L), 4096L)); + dim3 block(512); + + if (losses.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return losses; + } + + AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_forward", [&] { + SigmoidFocalLossForward<<>>( + losses_size, + logits.contiguous().data(), + targets.contiguous().data(), + num_classes, + gamma, + alpha, + num_samples, + losses.data()); + }); + THCudaCheck(cudaGetLastError()); + return losses; +} + + +at::Tensor SigmoidFocalLoss_backward_cuda( + const at::Tensor& logits, + const at::Tensor& targets, + const at::Tensor& d_losses, + const int num_classes, + const float gamma, + const float alpha) { + AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor"); + AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor"); + AT_ASSERTM(d_losses.type().is_cuda(), "d_losses must be a CUDA tensor"); + + AT_ASSERTM(logits.dim() == 2, "logits should be NxClass"); + + const int num_samples = logits.size(0); + AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes"); + + auto d_logits = at::zeros({num_samples, num_classes}, logits.options()); + auto d_logits_size = num_samples * logits.size(1); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(d_logits_size, 512L), 4096L)); + dim3 block(512); + + if (d_logits.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return d_logits; + } + + AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_backward", [&] { + SigmoidFocalLossBackward<<>>( + d_logits_size, + logits.contiguous().data(), + targets.contiguous().data(), + d_losses.contiguous().data(), + num_classes, + gamma, + alpha, + num_samples, + d_logits.data()); + }); + + THCudaCheck(cudaGetLastError()); + return d_logits; +} + From b8d6ab2e1455ef635227ea69935263b8a7bd19fa Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 26 Mar 2019 11:17:48 +0100 Subject: [PATCH 2/4] rolls back the breaking AT dispatch changes (#555) --- maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp | 2 +- maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp | 2 +- maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu | 4 ++-- maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu | 4 ++-- maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp b/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp index cd9fde2ae..d35aedf27 100644 --- a/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp +++ b/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp @@ -239,7 +239,7 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, return output; } - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] { + AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { ROIAlignForward_cpu_kernel( output_size, input.data(), diff --git a/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp b/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp index 639ca472e..1153dea04 100644 --- a/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp +++ b/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp @@ -68,7 +68,7 @@ at::Tensor nms_cpu(const at::Tensor& dets, const at::Tensor& scores, const float threshold) { at::Tensor result; - AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] { + AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] { result = nms_cpu_kernel(dets, scores, threshold); }); return result; diff --git a/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu b/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu index 170771aa8..1142fb375 100644 --- a/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu +++ b/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu @@ -280,7 +280,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, return output; } - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] { + AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { RoIAlignForward<<>>( output_size, input.contiguous().data(), @@ -326,7 +326,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, return grad_input; } - AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIAlign_backward", [&] { + AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] { RoIAlignBackwardFeature<<>>( grad.numel(), grad.contiguous().data(), diff --git a/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu b/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu index cef3beaa4..8f072ffc2 100644 --- a/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu +++ b/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu @@ -134,7 +134,7 @@ std::tuple ROIPool_forward_cuda(const at::Tensor& input, return std::make_tuple(output, argmax); } - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIPool_forward", [&] { + AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] { RoIPoolFForward<<>>( output_size, input.contiguous().data(), @@ -182,7 +182,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, return grad_input; } - AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIPool_backward", [&] { + AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] { RoIPoolFBackward<<>>( grad.numel(), grad.contiguous().data(), diff --git a/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu index cd9b4c96b..7d40767bb 100644 --- a/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu +++ b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu @@ -125,7 +125,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda( return losses; } - AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "SigmoidFocalLoss_forward", [&] { + AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_forward", [&] { SigmoidFocalLossForward<<>>( losses_size, logits.contiguous().data(), @@ -169,7 +169,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda( return d_logits; } - AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "SigmoidFocalLoss_backward", [&] { + AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_backward", [&] { SigmoidFocalLossBackward<<>>( d_logits_size, logits.contiguous().data(), From 1fc6a8ddedcc02cd331763f0b8c3d9ffcc6dba99 Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 26 Mar 2019 11:21:55 +0100 Subject: [PATCH 3/4] revert accidental docker changes --- docker/Dockerfile | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index b2806a5b5..58b924cf4 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -47,11 +47,8 @@ RUN git clone https://github.com/cocodataset/cocoapi.git \ && python setup.py build_ext install # install PyTorch Detection -ADD . ./maskrcnn-benchmark -RUN cd maskrcnn-benchmark && python setup.py build develop - -#RUN git clone https://github.com/facebookresearch/maskrcnn-benchmark.git \ -# && cd maskrcnn-benchmark \ -# && python setup.py build develop +RUN git clone https://github.com/facebookresearch/maskrcnn-benchmark.git \ + && cd maskrcnn-benchmark \ + && python setup.py build develop WORKDIR /maskrcnn-benchmark From 6eeca85abf0aa7cf5b0c6991d9f3d74a77da4af6 Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 26 Mar 2019 11:24:44 +0100 Subject: [PATCH 4/4] revert accidental docker changes (2) --- docker/Dockerfile | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 58b924cf4..39b508258 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -31,9 +31,8 @@ ENV CONDA_AUTO_UPDATE_CONDA=false RUN conda install -y ipython RUN pip install ninja yacs cython matplotlib opencv-python -# Install PyTorch 1.0 Nightly -ARG CUDA -RUN echo conda install pytorch-nightly cudatoolkit=${CUDA} -c pytorch \ +# Install PyTorch 1.0 Nightly and OpenCV +RUN conda install -y pytorch-nightly -c pytorch \ && conda clean -ya # Install TorchVision master