[Refactor] Replace the implementation of rotated_feature_align with mlu_ops (open-mmlab#2659)
tudejiang79 authored and Danielmic committed Jun 30, 2023
1 parent 6fdf853 commit 162142d
Showing 4 changed files with 122 additions and 3 deletions.
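
For reference, the op this commit accelerates is exposed in Python as mmcv.ops.rotated_feature_align, which refines each feature-map location using the best rotated box assigned to it. A minimal usage sketch follows; the tensor shapes and the spatial_scale/points values are illustrative assumptions, not values taken from this commit.

import torch
from mmcv.ops import rotated_feature_align

# features: (N, C, H, W) feature map; best_rbboxes: (N, H, W, 5), one rotated
# box (cx, cy, w, h, angle) per spatial location (assumed layout).
features = torch.randn(2, 16, 32, 32)
best_rbboxes = torch.rand(2, 32, 32, 5)
out = rotated_feature_align(features, best_rbboxes, spatial_scale=1 / 8, points=1)
assert out.shape == features.shape  # the op keeps the feature-map shape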
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
@@ -41,7 +41,7 @@ We implement common ops used in detection, segmentation, etc.
| PointsInBoxes | √ | √ | | | |
| PointsInPolygons | | √ | | | |
| PSAMask | √ | √ | √ | | √ |
| RotatedFeatureAlign | √ | √ | | | |
| RotatedFeatureAlign | √ | √ | √ | | |
| RoIPointPool3d | | √ | √ | | |
| RoIPool | | √ | √ | | √ |
| RoIAlignRotated | √ | √ | √ | | |
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
@@ -41,7 +41,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| PointsInBoxes | √ | √ | | | |
| PointsInPolygons | | √ | | | |
| PSAMask | √ | √ | √ | | √ |
| RotatedFeatureAlign | √ | √ | | | |
| RotatedFeatureAlign | √ | √ | √ | | |
| RoIPointPool3d | | √ | √ | | |
| RoIPool | | √ | √ | | √ |
| RoIAlignRotated | √ | √ | √ | | |
115 changes: 115 additions & 0 deletions mmcv/ops/csrc/pytorch/mlu/rotated_feature_align_mlu.cpp
@@ -0,0 +1,115 @@
/*************************************************************************
 * Copyright (C) 2022 by Cambricon.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"

void RotatedFeatureAlignForwardMLUKernelLauncher(const Tensor features,
                                                 const Tensor best_bboxes,
                                                 const float spatial_scale,
                                                 const int points,
                                                 Tensor output) {
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(features.dim());
  auto features_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(features, memory_format);
  auto best_bboxes_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      best_bboxes, best_bboxes.suggest_memory_format());
  auto output_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format);

  MluOpTensorDescriptor features_desc, best_bboxes_desc, output_desc;
  features_desc.set_with_layout(features_, MLUOP_LAYOUT_NHWC);
  best_bboxes_desc.set(best_bboxes_contiguous);
  output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC);

  // get ptr of tensors
  auto features_impl = torch_mlu::getMluTensorImpl(features_);
  auto features_ptr = features_impl->cnnlMalloc();
  auto best_bboxes_impl = torch_mlu::getMluTensorImpl(best_bboxes_contiguous);
  auto best_bboxes_ptr = best_bboxes_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
  auto output_ptr = output_impl->cnnlMalloc();

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  mluOpRotatedFeatureAlignForward(
      handle, features_desc.desc(), features_ptr, best_bboxes_desc.desc(),
      best_bboxes_ptr, spatial_scale, points, output_desc.desc(), output_ptr);

  output.copy_(output_contiguous);
}

void RotatedFeatureAlignBackwardMLUKernelLauncher(const Tensor top_grad,
                                                  const Tensor best_bboxes,
                                                  const float spatial_scale,
                                                  const int points,
                                                  Tensor bottom_grad) {
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(top_grad.dim());
  auto top_grad_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(top_grad, memory_format);
  auto best_bboxes_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      best_bboxes, best_bboxes.suggest_memory_format());
  auto bottom_grad_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(bottom_grad, memory_format);

  // get ptr of tensors
  auto top_grad_impl = torch_mlu::getMluTensorImpl(top_grad_);
  auto top_grad_ptr = top_grad_impl->cnnlMalloc();
  auto best_bboxes_impl = torch_mlu::getMluTensorImpl(best_bboxes_contiguous);
  auto best_bboxes_ptr = best_bboxes_impl->cnnlMalloc();
  auto bottom_grad_impl = torch_mlu::getMluTensorImpl(bottom_grad_);
  auto bottom_grad_ptr = bottom_grad_impl->cnnlMalloc();

  MluOpTensorDescriptor top_grad_desc, best_bboxes_desc, bottom_grad_desc;
  top_grad_desc.set_with_layout(top_grad_, MLUOP_LAYOUT_NHWC);
  best_bboxes_desc.set(best_bboxes_contiguous);
  bottom_grad_desc.set_with_layout(bottom_grad_, MLUOP_LAYOUT_NHWC);

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  mluOpRotatedFeatureAlignBackward(handle, top_grad_desc.desc(), top_grad_ptr,
                                   best_bboxes_desc.desc(), best_bboxes_ptr,
                                   spatial_scale, points,
                                   bottom_grad_desc.desc(), bottom_grad_ptr);
  bottom_grad.copy_(bottom_grad_);
}

void rotated_feature_align_forward_mlu(const Tensor features,
                                       const Tensor best_bboxes,
                                       const float spatial_scale,
                                       const int points, Tensor output) {
  RotatedFeatureAlignForwardMLUKernelLauncher(features, best_bboxes,
                                              spatial_scale, points, output);
}

void rotated_feature_align_backward_mlu(const Tensor top_grad,
                                        const Tensor best_bboxes,
                                        const float spatial_scale,
                                        const int points, Tensor bottom_grad) {
  RotatedFeatureAlignBackwardMLUKernelLauncher(
      top_grad, best_bboxes, spatial_scale, points, bottom_grad);
}

void rotated_feature_align_forward_impl(const Tensor features,
                                        const Tensor best_bboxes,
                                        const float spatial_scale,
                                        const int points, Tensor output);

void rotated_feature_align_backward_impl(const Tensor top_grad,
                                         const Tensor best_bboxes,
                                         const float spatial_scale,
                                         const int points, Tensor bottom_grad);

REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, MLU,
                     rotated_feature_align_forward_mlu);
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, MLU,
                     rotated_feature_align_backward_mlu);
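
The REGISTER_DEVICE_IMPL calls above plug the launchers into MMCV's device dispatcher, so no Python-side change is needed to reach them. A rough sketch of how the MLU path is exercised, assuming an MLU-enabled build of mmcv with torch_mlu installed and an MLU device visible:

import torch
from mmcv.ops import rotated_feature_align

# Moving the inputs to the 'mlu' device (registered by torch_mlu) makes the
# dispatcher resolve rotated_feature_align_forward_impl to
# rotated_feature_align_forward_mlu, which drives mluOpRotatedFeatureAlignForward.
features = torch.randn(2, 16, 32, 32).to('mlu')
best_rbboxes = torch.rand(2, 32, 32, 5).to('mlu')
out = rotated_feature_align(features, best_rbboxes, spatial_scale=1 / 8, points=1)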
6 changes: 5 additions & 1 deletion tests/test_ops/test_rotated_feature_align.py
@@ -3,7 +3,7 @@
import torch

from mmcv.ops import rotated_feature_align
from mmcv.utils import IS_CUDA_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE


@pytest.mark.skipif(
@@ -13,6 +13,10 @@
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support')),
    pytest.param(
        'cpu',
        marks=pytest.mark.skipif(
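
The added pytest.param above routes the existing test body to the MLU backend alongside CPU and CUDA. As a purely illustrative sketch, not the actual test body, with assumed shapes and tolerances, a cross-device consistency check could look like this:

import torch
from mmcv.ops import rotated_feature_align

features = torch.randn(1, 4, 8, 8)
best_rbboxes = torch.rand(1, 8, 8, 5)
cpu_out = rotated_feature_align(features, best_rbboxes, spatial_scale=1 / 8, points=1)
mlu_out = rotated_feature_align(
    features.to('mlu'), best_rbboxes.to('mlu'), spatial_scale=1 / 8, points=1)
torch.testing.assert_close(mlu_out.cpu(), cpu_out, rtol=1e-3, atol=1e-3)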
