add tracing unittest for batch inference
Summary: test that batch inference works in tracing

Reviewed By: zhanghang1989

Differential Revision: D32609078

fbshipit-source-id: 0f219f6fcb500c793da0085b6ee9c3dd370d6831
ppwwyyxx authored and facebook-github-bot committed Nov 23, 2021
1 parent bbdee4c commit f013777
Showing 2 changed files with 34 additions and 9 deletions.
3 changes: 3 additions & 0 deletions detectron2/utils/testing.py
@@ -68,6 +68,9 @@ def convert_scripted_instances(instances):
     """
     Convert a scripted Instances object to a regular :class:`Instances` object
     """
+    assert hasattr(
+        instances, "image_size"
+    ), f"Expect an Instances object, but got {type(instances)}!"
     ret = Instances(instances.image_size)
     for name in instances._field_names:
         val = getattr(instances, "_" + name, None)
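
Note: the new assert exists so that passing the wrong type to `convert_scripted_instances` fails with a readable message instead of an opaque `AttributeError` deeper in the conversion loop. A minimal sketch of the failure it guards against (the bare-tensor argument is only illustrative):

```python
import torch
from detectron2.utils.testing import convert_scripted_instances

# Anything without an `image_size` attribute (e.g. a bare tensor) now trips the
# assert with a clear message instead of failing inside the field-copy loop.
try:
    convert_scripted_instances(torch.zeros(3, 4))
except AssertionError as e:
    print(e)  # Expect an Instances object, but got <class 'torch.Tensor'>!
```
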
40 changes: 31 additions & 9 deletions tests/test_export_torchscript.py
@@ -2,6 +2,7 @@
 
 import json
 import os
+import random
 import tempfile
 import unittest
 import torch
@@ -62,7 +63,10 @@ def _test_rcnn_model(self, config_path):
         }
         script_model = scripting_with_instances(model, fields)
 
-        inputs = [{"image": get_sample_coco_image()}] * 2
+        # Test that batch inference with different shapes is supported
+        image = get_sample_coco_image()
+        small_image = nn.functional.interpolate(image, scale_factor=0.5)
+        inputs = [{"image": image}, {"image": small_image}]
         with torch.no_grad():
             instance = model.inference(inputs, do_postprocess=False)[0]
             scripted_instance = script_model.inference(inputs, do_postprocess=False)[0]
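
Note: as this hunk relies on, detectron2 models take a batch as a list of per-image dicts, and the images in one batch may have different spatial sizes because resizing and padding happen inside the model. A minimal sketch of that input format (the config path and random tensors are illustrative, not part of this commit):

```python
import torch
from detectron2 import model_zoo

# A batch is a list of per-image dicts; each "image" is a (C, H, W) tensor and
# images in the same batch may have different sizes.
model = model_zoo.get("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml", trained=True).eval()
inputs = [
    {"image": torch.rand(3, 480, 640) * 255},
    {"image": torch.rand(3, 240, 320) * 255},  # a smaller image in the same batch
]
with torch.no_grad():
    results = model.inference(inputs, do_postprocess=False)
print(len(results))  # one Instances object per input image
```
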
@@ -130,20 +134,38 @@ def inference_func(model, image):
 
         self._test_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml", inference_func)
 
-    def _test_model(self, config_path, inference_func):
+    def _test_model(self, config_path, inference_func, batch=1):
         model = model_zoo.get(config_path, trained=True)
         image = get_sample_coco_image()
+        inputs = tuple(image.clone() for _ in range(batch))
 
-        wrapper = TracingAdapter(model, image, inference_func)
+        wrapper = TracingAdapter(model, inputs, inference_func)
         wrapper.eval()
         with torch.no_grad():
-            small_image = nn.functional.interpolate(image, scale_factor=0.5)
-            # trace with a different image, and the trace must still work
-            traced_model = torch.jit.trace(wrapper, (small_image,))
+            # trace with smaller images, and the trace must still work
+            trace_inputs = tuple(
+                nn.functional.interpolate(image, scale_factor=random.uniform(0.5, 0.7))
+                for _ in range(batch)
+            )
+            traced_model = torch.jit.trace(wrapper, trace_inputs)
 
-            output = inference_func(model, image)
-            traced_output = wrapper.outputs_schema(traced_model(image))
-        assert_instances_allclose(output, traced_output, size_as_tensor=True)
+        outputs = inference_func(model, *inputs)
+        traced_outputs = wrapper.outputs_schema(traced_model(*inputs))
+        if batch > 1:
+            for output, traced_output in zip(outputs, traced_outputs):
+                assert_instances_allclose(output, traced_output, size_as_tensor=True)
+        else:
+            assert_instances_allclose(outputs, traced_outputs, size_as_tensor=True)
+
+    @SLOW_PUBLIC_CPU_TEST
+    def testMaskRCNNFPN_batched(self):
+        def inference_func(model, image1, image2):
+            inputs = [{"image": image1}, {"image": image2}]
+            return model.inference(inputs, do_postprocess=False)
+
+        self._test_model(
+            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func, batch=2
+        )
 
     def testKeypointHead(self):
         class M(nn.Module):
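Note: `TracingAdapter` (used by `_test_model` above) flattens detectron2's structured inputs into plain tensors so `torch.jit.trace` can handle them, and `wrapper.outputs_schema` rebuilds the structured outputs from the traced module's flat tensor outputs. A minimal sketch of tracing a two-image batch the same way the new test does (the config choice is illustrative):

```python
import torch
from torch import nn
from detectron2 import model_zoo
from detectron2.export import TracingAdapter
from detectron2.utils.testing import get_sample_coco_image

model = model_zoo.get(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", trained=True
).eval()

def inference_func(model, image1, image2):
    # As in the new test: assemble the per-image dicts inside the traced
    # function so the adapter only ever sees raw tensors.
    inputs = [{"image": image1}, {"image": image2}]
    return model.inference(inputs, do_postprocess=False)

image = get_sample_coco_image()
small_image = nn.functional.interpolate(image, scale_factor=0.5)

wrapper = TracingAdapter(model, (image, small_image), inference_func)
wrapper.eval()
with torch.no_grad():
    traced = torch.jit.trace(wrapper, (image, small_image))
    # The traced module returns flattened tensors; outputs_schema rebuilds one
    # Instances object per image in the batch.
    outputs = wrapper.outputs_schema(traced(image, small_image))
print(len(outputs))  # 2
```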
