diff --git a/test/test_onnx.py b/test/test_onnx.py
index f4db56f0a14..7cbbaadcebc 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -3,6 +3,8 @@
 from torchvision import ops
 from torchvision.models.detection.transform import GeneralizedRCNNTransform
 
+from collections import OrderedDict
+
 # onnxruntime requires python 3.5 or above
 try:
     import onnxruntime
@@ -69,19 +71,18 @@ def forward(self, boxes, scores):
 
         self.run_model(Module(), [(boxes, scores)])
 
-    def test_roi_pool(self):
+    def test_roi_align(self):
         x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
         single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
         model = ops.RoIAlign((5, 5), 1, 2)
         self.run_model(model, [(x, single_roi)])
 
-    def test_roi_align(self):
+    def test_roi_pool(self):
         x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
         rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
         pool_h = 5
         pool_w = 5
         model = ops.RoIPool((pool_h, pool_w), 2)
-        model.eval()
         self.run_model(model, [(x, rois)])
 
     @unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
@@ -103,6 +104,31 @@ def forward(self_module, images):
         input_test = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
         self.run_model(TransformModule(), [input, input_test])
 
+    def test_multi_scale_roi_align(self):
+
+        class TransformModule(torch.nn.Module):
+            def __init__(self):
+                super(TransformModule, self).__init__()
+                self.model = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2)
+                self.image_sizes = [(512, 512)]
+
+            def forward(self, input, boxes):
+                return self.model(input, boxes, self.image_sizes)
+
+        i = OrderedDict()
+        i['feat1'] = torch.rand(1, 5, 64, 64)
+        i['feat2'] = torch.rand(1, 5, 16, 16)
+        boxes = torch.rand(6, 4) * 256
+        boxes[:, 2:] += boxes[:, :2]
+
+        i1 = OrderedDict()
+        i1['feat1'] = torch.rand(1, 5, 64, 64)
+        i1['feat2'] = torch.rand(1, 5, 16, 16)
+        boxes1 = torch.rand(6, 4) * 256
+        boxes1[:, 2:] += boxes1[:, :2]
+
+        self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py
index d87da57dfd7..8fabbc9571d 100644
--- a/torchvision/ops/poolers.py
+++ b/torchvision/ops/poolers.py
@@ -3,10 +3,31 @@
 import torch.nn.functional as F
 from torch import nn
 
+import torchvision
 from torchvision.ops import roi_align
 from torchvision.ops.boxes import box_area
 
 
+# copying result_idx_in_level to a specific index in result[]
+# is not supported by ONNX tracing yet.
+# _onnx_merge_levels() is an implementation supported by ONNX
+# that merges the levels to the right indices
+def _onnx_merge_levels(levels, unmerged_results):
+    first_result = unmerged_results[0]
+    dtype, device = first_result.dtype, first_result.device
+    res = torch.zeros((levels.size(0), first_result.size(1),
+                       first_result.size(2), first_result.size(3)),
+                      dtype=dtype, device=device)
+    for l in range(len(unmerged_results)):
+        index = (levels == l).nonzero().view(-1, 1, 1, 1)
+        index = index.expand(index.size(0),
+                             unmerged_results[l].size(1),
+                             unmerged_results[l].size(2),
+                             unmerged_results[l].size(3))
+        res = res.scatter(0, index, unmerged_results[l])
+    return res
+
+
 class LevelMapper(object):
     """Determine which FPN level each RoI in a set of RoIs should map to based
     on the heuristic in the FPN paper.
@@ -35,9 +56,9 @@ def __call__(self, boxlists):
         s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))
 
         # Eqn.(1) in FPN paper
-        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))
+        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
         target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
-        return target_lvls.to(torch.int64) - self.k_min
+        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)
 
 
 class MultiScaleRoIAlign(nn.Module):
@@ -84,7 +105,7 @@ def convert_to_roi_format(self, boxes):
         device, dtype = concat_boxes.device, concat_boxes.dtype
         ids = torch.cat(
             [
-                torch.full((len(b), 1), i, dtype=dtype, device=device)
+                torch.full_like(b[:, :1], i, dtype=dtype, device=device)
                 for i, b in enumerate(boxes)
             ],
             dim=0,
@@ -153,14 +174,21 @@ def forward(self, x, boxes, image_shapes):
             device=device,
         )
 
+        results = []
         for level, (per_level_feature, scale) in enumerate(zip(x, self.scales)):
             idx_in_level = torch.nonzero(levels == level).squeeze(1)
             rois_per_level = rois[idx_in_level]
 
-            result[idx_in_level] = roi_align(
+            result_idx_in_level = roi_align(
                 per_level_feature, rois_per_level,
                 output_size=self.output_size,
-                spatial_scale=scale, sampling_ratio=self.sampling_ratio
-            )
+                spatial_scale=scale, sampling_ratio=self.sampling_ratio)
+
+            if torchvision._is_tracing():
+                results.append(result_idx_in_level.to(dtype))
+            else:
+                result[idx_in_level] = result_idx_in_level
 
+        if torchvision._is_tracing():
+            result = _onnx_merge_levels(levels, results)
         return result
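
For reference, below is a minimal, self-contained sketch (not part of the diff) of why the scatter-based merge used by _onnx_merge_levels matches the eager-mode indexed assignment it replaces during tracing; the level assignment, number of RoIs, and feature shapes are made up for illustration.

import torch

# Hypothetical FPN level assigned to each of 6 RoIs.
num_rois, channels, out_h, out_w = 6, 5, 3, 3
levels = torch.tensor([0, 1, 0, 1, 1, 0])

# Stand-ins for the per-level roi_align outputs, one tensor per level,
# rows ordered as the RoIs of that level appear in `levels`.
per_level = [
    torch.randn(int((levels == l).sum()), channels, out_h, out_w)
    for l in range(2)
]

# Eager path: write each level's rows back into the RoIs' original positions.
expected = torch.zeros(num_rois, channels, out_h, out_w)
for l, res in enumerate(per_level):
    expected[levels == l] = res

# Tracing-friendly path (the _onnx_merge_levels approach): expand a per-level
# index tensor and scatter along dim 0 instead of assigning via fancy indexing.
merged = torch.zeros(num_rois, channels, out_h, out_w)
for l, res in enumerate(per_level):
    index = (levels == l).nonzero().view(-1, 1, 1, 1)
    index = index.expand(index.size(0), channels, out_h, out_w)
    merged = merged.scatter(0, index, res)

assert torch.equal(merged, expected)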