diff --git a/test/test_runtime_ort.py b/test/test_runtime_ort.py index d9aed5d0..deab6779 100644 --- a/test/test_runtime_ort.py +++ b/test/test_runtime_ort.py @@ -114,4 +114,22 @@ def test_onnx_export_multi_batches(self, arch): [img_two, img_dummy], [img_dummy, img_two], ] - self.run_model(model, inputs_list=inputs_list) + self.run_model(model, inputs_list) + + @pytest.mark.parametrize("arch", ["yolov5n"]) + def test_onnx_export_misbatch(self, arch): + img_one, img_two = self.get_test_images() + img_dummy = torch.ones(3, 640, 480) * 0.3 + + size = (640, 640) if arch[-1] == "6" else (320, 320) + model = models.__dict__[arch](pretrained=True, size=size, score_thresh=0.45) + model = model.eval() + model([img_one, img_two]) + + # Test exported model on images of misbatch + with pytest.raises(IndexError, match="list index out of range"): + self.run_model(model, [[img_one, img_two], [img_two, img_one, img_dummy]]) + + # Test exported model on images of misbatch + with pytest.raises(ValueError, match="Model requires 3 inputs. Input Feed contains 2"): + self.run_model(model, [[img_two, img_one, img_dummy], [img_one, img_two]])