Skip to content

Commit

Permalink
Add more test
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Feb 16, 2022
1 parent b6928b1 commit 2e0ffe3
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion test/test_runtime_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,22 @@ def test_onnx_export_multi_batches(self, arch):
[img_two, img_dummy],
[img_dummy, img_two],
]
self.run_model(model, inputs_list=inputs_list)
self.run_model(model, inputs_list)

@pytest.mark.parametrize("arch", ["yolov5n"])
def test_onnx_export_misbatch(self, arch):
img_one, img_two = self.get_test_images()
img_dummy = torch.ones(3, 640, 480) * 0.3

size = (640, 640) if arch[-1] == "6" else (320, 320)
model = models.__dict__[arch](pretrained=True, size=size, score_thresh=0.45)
model = model.eval()
model([img_one, img_two])

# Test exported model on images of misbatch
with pytest.raises(IndexError, match="list index out of range"):
self.run_model(model, [[img_one, img_two], [img_two, img_one, img_dummy]])

# Test exported model on images of misbatch
with pytest.raises(ValueError, match="Model requires 3 inputs. Input Feed contains 2"):
self.run_model(model, [[img_two, img_one, img_dummy], [img_one, img_two]])

0 comments on commit 2e0ffe3

Please sign in to comment.