Skip to content

Commit

Permalink
[Enhancement] Fix ncnn unittest (#626)
Browse files Browse the repository at this point in the history
* optmize-csp-darknet

* replace floordiv to torch.div

* update csp_darknet default implement

* fix test
  • Loading branch information
q.yao authored Jun 28, 2022
1 parent 05cafab commit 4d9e209
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
19 changes: 19 additions & 0 deletions mmdeploy/codebase/mmdet/models/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,25 @@
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.csp_darknet.Focus.forward')
def focus__forward__default(ctx, self, x):
"""Rewrite forward function of Focus class.
Replace slice with transpose.
"""
# shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
B, C, H, W = x.shape
x = x.reshape(B, C, -1, 2, W)
x = x.reshape(B, C, x.shape[2], 2, -1, 2)
half_H = x.shape[2]
half_W = x.shape[4]
x = x.permute(0, 5, 3, 1, 2, 4)
x = x.reshape(B, C * 4, half_H, half_W)

return self.conv(x)


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.csp_darknet.Focus.forward',
backend='ncnn')
Expand Down
11 changes: 5 additions & 6 deletions tests/test_codebase/test_mmdet/test_mmdet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ def get_gfl_head_model():
return model


def test_focus_forward_ncnn():
backend_type = Backend.NCNN
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
def test_focus_forward(backend_type):
check_backend(backend_type)
focus_model = get_focus_backbone_model()
focus_model.cpu().eval()
Expand All @@ -222,11 +222,10 @@ def test_focus_forward_ncnn():
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs[0]):
model_output = model_output.squeeze().cpu().numpy()
for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
model_output = model_output.squeeze()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
torch.testing.assert_allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)


Expand Down

0 comments on commit 4d9e209

Please sign in to comment.