diff --git a/detectron2/utils/testing.py b/detectron2/utils/testing.py index 5351ee6c8a..161fa6b808 100644 --- a/detectron2/utils/testing.py +++ b/detectron2/utils/testing.py @@ -68,6 +68,9 @@ def convert_scripted_instances(instances): """ Convert a scripted Instances object to a regular :class:`Instances` object """ + assert hasattr( + instances, "image_size" + ), f"Expect an Instances object, but got {type(instances)}!" ret = Instances(instances.image_size) for name in instances._field_names: val = getattr(instances, "_" + name, None) diff --git a/tests/test_export_torchscript.py b/tests/test_export_torchscript.py index 4a2dcd7d0c..e9a0ff5851 100644 --- a/tests/test_export_torchscript.py +++ b/tests/test_export_torchscript.py @@ -2,6 +2,7 @@ import json import os +import random import tempfile import unittest import torch @@ -62,7 +63,10 @@ def _test_rcnn_model(self, config_path): } script_model = scripting_with_instances(model, fields) - inputs = [{"image": get_sample_coco_image()}] * 2 + # Test that batch inference with different shapes are supported + image = get_sample_coco_image() + small_image = nn.functional.interpolate(image, scale_factor=0.5) + inputs = [{"image": image}, {"image": small_image}] with torch.no_grad(): instance = model.inference(inputs, do_postprocess=False)[0] scripted_instance = script_model.inference(inputs, do_postprocess=False)[0] @@ -130,20 +134,38 @@ def inference_func(model, image): self._test_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml", inference_func) - def _test_model(self, config_path, inference_func): + def _test_model(self, config_path, inference_func, batch=1): model = model_zoo.get(config_path, trained=True) image = get_sample_coco_image() + inputs = tuple(image.clone() for _ in range(batch)) - wrapper = TracingAdapter(model, image, inference_func) + wrapper = TracingAdapter(model, inputs, inference_func) wrapper.eval() with torch.no_grad(): - small_image = nn.functional.interpolate(image, scale_factor=0.5) - # trace with a different image, and the trace must still work - traced_model = torch.jit.trace(wrapper, (small_image,)) + # trace with smaller images, and the trace must still work + trace_inputs = tuple( + nn.functional.interpolate(image, scale_factor=random.uniform(0.5, 0.7)) + for _ in range(batch) + ) + traced_model = torch.jit.trace(wrapper, trace_inputs) + + outputs = inference_func(model, *inputs) + traced_outputs = wrapper.outputs_schema(traced_model(*inputs)) + if batch > 1: + for output, traced_output in zip(outputs, traced_outputs): + assert_instances_allclose(output, traced_output, size_as_tensor=True) + else: + assert_instances_allclose(outputs, traced_outputs, size_as_tensor=True) + + @SLOW_PUBLIC_CPU_TEST + def testMaskRCNNFPN_batched(self): + def inference_func(model, image1, image2): + inputs = [{"image": image1}, {"image": image2}] + return model.inference(inputs, do_postprocess=False) - output = inference_func(model, image) - traced_output = wrapper.outputs_schema(traced_model(image)) - assert_instances_allclose(output, traced_output, size_as_tensor=True) + self._test_model( + "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func, batch=2 + ) def testKeypointHead(self): class M(nn.Module):