[MMSIG] Support the deployment of SparseInst on TensorRT #2541

Open · wants to merge 4 commits into main
1 change: 1 addition & 0 deletions docs/en/04-supported-codebases/mmdet.md
@@ -220,6 +220,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
| [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N |
| [SparseInst](https://github.com/open-mmlab/mmdetection/blob/main/projects/SparseInst) | Instance Segmentation | Y | Y | N | N | N |
| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
1 change: 1 addition & 0 deletions docs/zh_cn/04-supported-codebases/mmdet.md
@@ -223,6 +223,7 @@ cv2.imwrite('output_detection.png', img)
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
| [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N |
| [SparseInst](https://github.com/open-mmlab/mmdetection/blob/main/projects/SparseInst) | Instance Segmentation | Y | Y | N | N | N |
| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet/deploy/object_detection_model.py
@@ -241,7 +241,7 @@ def postprocessing_results(self,
            masks = batch_masks[i]
            img_h, img_w = img_metas[i]['img_shape'][:2]
            ori_h, ori_w = img_metas[i]['ori_shape'][:2]
-           if model_type in ['RTMDet', 'CondInst']:
+           if model_type in ['RTMDet', 'CondInst', 'SparseInst']:
                export_postprocess_mask = True
            else:
                export_postprocess_mask = False
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/__init__.py
@@ -11,5 +11,6 @@
from . import rtmdet_ins_head # noqa: F401,F403
from . import solo_head # noqa: F401,F403
from . import solov2_head # noqa: F401,F403
from . import sparseinst_head # noqa: F401,F403
from . import yolo_head # noqa: F401,F403
from . import yolox_head # noqa: F401,F403
58 changes: 58 additions & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py
@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn.functional as F
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER


@torch.jit.script
def rescoring_mask(scores, mask_pred, masks):
    mask_pred_ = mask_pred.float()
    return scores * ((masks * mask_pred_).sum([2, 3]) /
                     (mask_pred_.sum([2, 3]) + 1e-6))


@FUNCTION_REWRITER.register_rewriter(
    'projects.SparseInst.sparseinst.SparseInst.predict')
def sparseinst__predict(

Collaborator:
When building the TRT engine, there is a warning:

Profile kMIN values are not self-consistent. /encoder/ppm/Resize_1: IResizeLayer requires that if input dimension is zero, output dimension must be zero too (axis = 3 input dimension = (+ (CEIL_DIV (+ width -800) 416) 1) output dimension = (+ (CEIL_DIV (+ width -32) 32) 1))
Condition '==' violated: 0 != 10. Instruction: CHECK_EQUAL 0 10.

This might lead to failures in some cases. Running test.py failed after some iterations:

Epoch(test) [  50/5000]    eta: 0:03:57  time: 0.0470  data_time: 0.0014  memory: 495  
12/15 02:59:38 - mmengine - INFO - Epoch(test) [ 100/5000]    eta: 0:03:35  time: 0.0413  data_time: 0.0013  memory: 506  
12/15 02:59:40 - mmengine - INFO - Epoch(test) [ 150/5000]    eta: 0:03:24  time: 0.0349  data_time: 0.0014  memory: 501  
12/15 02:59:42 - mmengine - INFO - Epoch(test) [ 200/5000]    eta: 0:03:16  time: 0.0358  data_time: 0.0014  memory: 501  
12/15 02:59:44 - mmengine - INFO - Epoch(test) [ 250/5000]    eta: 0:03:10  time: 0.0344  data_time: 0.0013  memory: 505  
12/15 02:59:46 - mmengine - INFO - Epoch(test) [ 300/5000]    eta: 0:03:07  time: 0.0371  data_time: 0.0013  memory: 501  
12/15 02:59:48 - mmengine - INFO - Epoch(test) [ 350/5000]    eta: 0:03:06  time: 0.0420  data_time: 0.0015  memory: 508  
12/15 02:59:50 - mmengine - INFO - Epoch(test) [ 400/5000]    eta: 0:03:03  time: 0.0340  data_time: 0.0013  memory: 501  
12/15 02:59:52 - mmengine - INFO - Epoch(test) [ 450/5000]    eta: 0:03:01  time: 0.0420  data_time: 0.0014  memory: 508  
12/15 02:59:54 - mmengine - INFO - Epoch(test) [ 500/5000]    eta: 0:02:59  time: 0.0390  data_time: 0.0014  memory: 502  
[12/15/2023-02:59:56] [TRT] [E] 7: [shapeMachine.cpp::executeContinuation::864] Error Code 7: Internal Error (/encoder/ppm/Resize_1: IResizeLayer requires that if input dimension is zero, output dimension must be zero too (axis = 3 input dimension = (+ (CEIL_DIV (+ width -800) 416) 1) output dimension = (+ (CEIL_DIV (+ width -32) 32) 1))
 Condition '==' violated: 0 != 12. Instruction: CHECK_EQUAL 0 12.)
Traceback (most recent call last):
  File "tools/test.py", line 159, in <module>
    main()
  File "tools/test.py", line 153, in main
    runner.test()
  File "/usr/local/lib/python3.8/dist-packages/mmengine/runner/runner.py", line 1791, in test
    metrics = self.test_loop.run()  # type: ignore
  File "/usr/local/lib/python3.8/dist-packages/mmengine/runner/loops.py", line 435, in run
    self.run_iter(idx, data_batch)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/mmengine/runner/loops.py", line 454, in run_iter
    outputs = self.runner.model.test_step(data_batch)
  File "/usr/local/lib/python3.8/dist-packages/mmengine/model/base_model/base_model.py", line 145, in test_step
    return self._run_forward(data, mode='predict')  # type: ignore
  File "/usr/local/lib/python3.8/dist-packages/mmengine/model/base_model/base_model.py", line 340, in _run_forward
    results = self(**data, mode=mode)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/deploy/object_detection_model.py", line 296, in forward
    outputs = self.predict(inputs)
  File "/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/deploy/object_detection_model.py", line 313, in predict
    outputs = self.wrapper({self.input_name: imgs})
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/workspace/mmdeploy/mmdeploy/backend/tensorrt/wrapper.py", line 167, in forward
    shape = tuple(self.context.get_binding_shape(idx))
ValueError: __len__() should return >= 0

Collaborator:
@Boomerl From my experience, TensorRT does not support nn.AdaptiveAvgPool2d, which is used here. This means the SparseInst model can only accept a static input shape for the TensorRT backend, so it should use configs like configs/mmdet/instance-seg/instance-seg_tensorrt_static-800x1344.py.
You could include this note in the docs, for example under https://github.com/open-mmlab/mmdeploy/blob/main/docs/en/04-supported-codebases/mmdet.md#reminder
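
For reference, a minimal sketch of the static-shape TensorRT backend config that such a file contains (the 800x1344 shape follows the config name above; the workspace size, fp16 flag, and input name are illustrative assumptions, not taken from this PR):

# Illustrative static-shape TensorRT deploy config (sketch, not from this PR).
# min/opt/max shapes are pinned to the same value, so the engine is built
# for a single fixed input size instead of a dynamic shape range.
backend_config = dict(
    type='tensorrt',
    common_config=dict(fp16_mode=False, max_workspace_size=1 << 30),
    model_inputs=[
        dict(
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 3, 800, 1344],
                    opt_shape=[1, 3, 800, 1344],
                    max_shape=[1, 3, 800, 1344])))
    ])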

Contributor (author):
Thank you for the guidance, I will check it immediately.

        self,
        batch_inputs: Tensor,
        batch_data_samples: List[dict],
        rescale: bool = False,
):
"""Rewrite `predict` of `SparseInst` for default backend."""
max_shape = batch_inputs.shape[-2:]
x = self.extract_feat(batch_inputs)
output = self.decoder(x)

pred_scores = output['pred_logits'].sigmoid()
pred_masks = output['pred_masks'].sigmoid()
pred_objectness = output['pred_scores'].sigmoid()
pred_scores = torch.sqrt(pred_scores * pred_objectness)

# max/argmax
scores, labels = pred_scores.max(dim=-1)
# cls threshold
keep = scores > self.cls_threshold
scores = scores.where(keep, scores.new_zeros(1))
labels = labels.where(keep, labels.new_zeros(1))
keep = keep.unsqueeze(-1).unsqueeze(-1).expand_as(pred_masks)
pred_masks = pred_masks.where(keep, pred_masks.new_zeros(1))

img_meta = batch_data_samples[0].metainfo
# rescoring mask using maskness
scores = rescoring_mask(scores, pred_masks > self.mask_threshold,
pred_masks)
h, w = img_meta['img_shape'][:2]
pred_masks = F.interpolate(
pred_masks, size=max_shape, mode='bilinear',
align_corners=False)[:, :, :h, :w]
Boomerl marked this conversation as resolved.
Show resolved Hide resolved

bboxes = torch.zeros(scores.shape[0], scores.shape[1], 4)
dets = torch.cat([bboxes, scores.unsqueeze(-1)], dim=-1)
masks = (pred_masks > self.mask_threshold).float()

return dets, labels, masks

Collaborator:
Could we post the evaluation results of the TRT model and check whether they are aligned with the PyTorch model?

90 changes: 90 additions & 0 deletions tests/test_codebase/test_mmdet/test_mmdet_models.py
@@ -2582,3 +2582,93 @@ def forward(self, x, param_preds, points, strides):
        deploy_cfg=deploy_cfg)

    assert rewrite_outputs is not None


def get_sparseinst():
"""SparseInst Config."""
test_cfg = Config(dict(score_thr=0.4, mask_thr_binary=0.45))
data_preprocessor = dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_mask=True,
pad_size_divisor=32)
backbone = Config(
dict(
type='ResNet',
depth=50,
out_indices=(1, 2, 3),
frozen_stages=0,
norm_cfg=dict(type='BN', requires_grad=False),
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet50')))

from projects.SparseInst.sparseinst import SparseInst
model = SparseInst(
data_preprocessor=data_preprocessor,
backbone=backbone,
encoder=dict(
type='InstanceContextEncoder', in_channels=[512, 1024, 2048]),
decoder=dict(
type='BaseIAMDecoder', in_channels=256 + 2, num_classes=80),
criterion=dict(
type='SparseInstCriterion',
num_classes=80,
assigner=dict(type='SparseInstMatcher', alpha=0.8, beta=0.2)),
test_cfg=test_cfg,
init_cfg=dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='conv_cls', std=0.01, bias_prob=0.01)))

model.requires_grad_(False)
return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_sparseinst_predict(backend_type):
"""Test predict rewrite of sparseinst."""
check_backend(backend_type)
sparseinst = get_sparseinst()
sparseinst.cpu().eval()

output_names = ['dets', 'labels', 'masks']
deploy_cfg = Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
confidence_threshold=0.005,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
export_postprocess_mask=False))))

img = torch.randn(1, 3, 320, 320)
from mmdet.structures import DetDataSample
data_sample = DetDataSample(metainfo=dict(img_shape=(320, 320, 3)))

# to get outputs of onnx model after rewrite
wrapped_model = WrapModel(
sparseinst, 'predict', batch_data_samples=[data_sample])
rewrite_inputs = {'batch_inputs': img}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)

if is_backend_output:
assert rewrite_outputs[0].shape[-1] == 5
assert rewrite_outputs[1] is not None
assert rewrite_outputs[2] is not None
else:
assert rewrite_outputs is not None