Skip to content

Commit

Permalink
Move test_torchscript to test_models
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 10, 2021
1 parent 44c2e64 commit 0226dd9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 24 deletions.
21 changes: 20 additions & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import io
import contextlib
import warnings

import pytest
import torch
from torch import Tensor

from yolort import models
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.transformer import darknet_tan_backbone
from yolort.models.anchor_utils import AnchorGenerator
Expand Down Expand Up @@ -257,3 +258,21 @@ def test_criterion(self):
assert isinstance(losses['cls_logits'], Tensor)
assert isinstance(losses['bbox_regression'], Tensor)
assert isinstance(losses['objectness'], Tensor)


@pytest.mark.parametrize('arch', ['yolov5s', 'yolov5m', 'yolov5l', 'yolotr'])
def test_torchscript(arch):
model = models.__dict__[arch](pretrained=True, size=(320, 320), score_thresh=0.45)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.eval()

x = [torch.rand(3, 288, 320), torch.rand(3, 300, 256)]

out = model(x)
out_script = scripted_model(x)

torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)
23 changes: 0 additions & 23 deletions test/test_torchscript.py

This file was deleted.

0 comments on commit 0226dd9

Please sign in to comment.