Support Exporting MultiScaleRoiAlign to ONNX (#1324)
* Support Exporting MultiScaleRoiAlign to ONNX

* remove cast

* fix dtype

* move cast
lara-hdr authored and fmassa committed Oct 4, 2019
1 parent f0d3daa commit 76702a0
Showing 2 changed files with 63 additions and 9 deletions.
32 changes: 29 additions & 3 deletions test/test_onnx.py
@@ -3,6 +3,8 @@
from torchvision import ops
from torchvision.models.detection.transform import GeneralizedRCNNTransform

from collections import OrderedDict

# onnxruntime requires python 3.5 or above
try:
import onnxruntime
@@ -69,19 +71,18 @@ def forward(self, boxes, scores):

self.run_model(Module(), [(boxes, scores)])

- def test_roi_pool(self):
+ def test_roi_align(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 2)
self.run_model(model, [(x, single_roi)])

- def test_roi_align(self):
+ def test_roi_pool(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
pool_h = 5
pool_w = 5
model = ops.RoIPool((pool_h, pool_w), 2)
- model.eval()
self.run_model(model, [(x, rois)])

@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
@@ -103,6 +104,31 @@ def forward(self_module, images):
input_test = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
self.run_model(TransformModule(), [input, input_test])

def test_multi_scale_roi_align(self):

class TransformModule(torch.nn.Module):
def __init__(self):
super(TransformModule, self).__init__()
self.model = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2)
self.image_sizes = [(512, 512)]

def forward(self, input, boxes):
return self.model(input, boxes, self.image_sizes)

i = OrderedDict()
i['feat1'] = torch.rand(1, 5, 64, 64)
i['feat2'] = torch.rand(1, 5, 16, 16)
boxes = torch.rand(6, 4) * 256
boxes[:, 2:] += boxes[:, :2]

i1 = OrderedDict()
i1['feat1'] = torch.rand(1, 5, 64, 64)
i1['feat2'] = torch.rand(1, 5, 16, 16)
boxes1 = torch.rand(6, 4) * 256
boxes1[:, 2:] += boxes1[:, :2]

self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])


if __name__ == '__main__':
unittest.main()
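
Note: these tests go through a run_model helper that the diff does not show. A minimal sketch of what such an export-and-verify helper might look like, assuming an in-memory export and onnxruntime for execution; the helper name, opset version, and tolerances here are illustrative, not taken from the repository:

import io

import numpy as np
import onnxruntime
import torch


def export_and_check(model, inputs, opset_version=10):
    # Trace/export the eager-mode model to an in-memory ONNX graph.
    model.eval()
    buf = io.BytesIO()
    torch.onnx.export(model, inputs, buf, opset_version=opset_version)
    # Run the exported graph with onnxruntime on the same inputs.
    session = onnxruntime.InferenceSession(buf.getvalue())
    feeds = {i.name: t.numpy() for i, t in zip(session.get_inputs(), inputs)}
    (actual,) = session.run(None, feeds)
    # The exported graph should match eager-mode execution.
    expected = model(*inputs)
    np.testing.assert_allclose(expected.numpy(), actual, rtol=1e-3, atol=1e-5)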
40 changes: 34 additions & 6 deletions torchvision/ops/poolers.py
@@ -3,10 +3,31 @@
import torch.nn.functional as F
from torch import nn

import torchvision
from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area


# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
def _onnx_merge_levels(levels, unmerged_results):
first_result = unmerged_results[0]
dtype, device = first_result.dtype, first_result.device
res = torch.zeros((levels.size(0), first_result.size(1),
first_result.size(2), first_result.size(3)),
dtype=dtype, device=device)
for l in range(len(unmerged_results)):
index = (levels == l).nonzero().view(-1, 1, 1, 1)
index = index.expand(index.size(0),
unmerged_results[l].size(1),
unmerged_results[l].size(2),
unmerged_results[l].size(3))
res = res.scatter(0, index, unmerged_results[l])
return res
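
Note: a toy check (shapes invented) of the idea behind _onnx_merge_levels above. In-place indexed assignment like result[idx] = value is not supported by ONNX tracing, but a scatter along dim 0 with an index expanded to the source shape produces the same merged tensor:

import torch

levels = torch.tensor([0, 1, 0, 1])      # FPN level assigned to each RoI
res0 = torch.full((2, 3, 2, 2), 10.0)    # pooled outputs for level-0 RoIs
res1 = torch.full((2, 3, 2, 2), 20.0)    # pooled outputs for level-1 RoIs

# Eager-mode merge: in-place indexed assignment.
merged = torch.zeros(4, 3, 2, 2)
merged[levels == 0] = res0
merged[levels == 1] = res1

# ONNX-friendly merge: scatter with an expanded index, as above.
scattered = torch.zeros(4, 3, 2, 2)
for l, res in enumerate([res0, res1]):
    index = (levels == l).nonzero().view(-1, 1, 1, 1)
    index = index.expand(index.size(0), res.size(1), res.size(2), res.size(3))
    scattered = scattered.scatter(0, index, res)

assert torch.equal(merged, scattered)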


class LevelMapper(object):
"""Determine which FPN level each RoI in a set of RoIs should map to based
on the heuristic in the FPN paper.
@@ -35,9 +56,9 @@ def __call__(self, boxlists):
s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

# Eqn.(1) in FPN paper
- target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))
+ target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
- return target_lvls.to(torch.int64) - self.k_min
+ return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)
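
Note: as a quick sanity check of Eqn.(1), the level is floor(lvl0 + log2(s / s0)), where s is the square root of the box area. With the usual canonical values lvl0 = 4 and s0 = 224 (assumed here; the diff does not show them), a 224x224 RoI maps to level 4 and a 112x112 RoI one level lower:

import torch

lvl0, s0, eps = 4, 224, 1e-6           # canonical level/scale, assumed defaults
s = torch.tensor([224.0, 112.0])       # sqrt of the two box areas
target_lvls = torch.floor(lvl0 + torch.log2(s / s0) + eps)
print(target_lvls)                     # tensor([4., 3.])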


class MultiScaleRoIAlign(nn.Module):
Expand Down Expand Up @@ -84,7 +105,7 @@ def convert_to_roi_format(self, boxes):
device, dtype = concat_boxes.device, concat_boxes.dtype
ids = torch.cat(
[
- torch.full((len(b), 1), i, dtype=dtype, device=device)
+ torch.full_like(b[:, :1], i, dtype=dtype, device=device)
for i, b in enumerate(boxes)
],
dim=0,
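
Note: the switch from torch.full to torch.full_like matters under tracing: torch.full((len(b), 1), i) bakes the Python int len(b) into the graph as a constant, while full_like(b[:, :1], i) remains a tensor op that follows the runtime number of boxes. A small sketch (sizes invented) of the RoI format this method produces:

import torch

b = torch.rand(6, 4)                  # 6 boxes for one image, (x1, y1, x2, y2)
ids = torch.full_like(b[:, :1], 0)    # image-index column, shape (6, 1)
rois = torch.cat([ids, b], dim=1)     # (6, 5) rows: [batch_idx, x1, y1, x2, y2]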
@@ -153,14 +174,21 @@ def forward(self, x, boxes, image_shapes):
device=device,
)

results = []
for level, (per_level_feature, scale) in enumerate(zip(x, self.scales)):
idx_in_level = torch.nonzero(levels == level).squeeze(1)
rois_per_level = rois[idx_in_level]

- result[idx_in_level] = roi_align(
+ result_idx_in_level = roi_align(
per_level_feature, rois_per_level,
output_size=self.output_size,
- spatial_scale=scale, sampling_ratio=self.sampling_ratio
- )
+ spatial_scale=scale, sampling_ratio=self.sampling_ratio)

if torchvision._is_tracing():
results.append(result_idx_in_level.to(dtype))
else:
result[idx_in_level] = result_idx_in_level

if torchvision._is_tracing():
result = _onnx_merge_levels(levels, results)
return result
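
Note: for reference, a minimal eager-mode sketch of the pooler the new test exercises; feature names and sizes mirror the test above, and the output shape follows from 6 boxes, 5 channels, and output_size 3:

from collections import OrderedDict

import torch
from torchvision import ops

pooler = ops.MultiScaleRoIAlign(['feat1', 'feat2'], output_size=3, sampling_ratio=2)
feats = OrderedDict(feat1=torch.rand(1, 5, 64, 64), feat2=torch.rand(1, 5, 16, 16))
boxes = torch.rand(6, 4) * 256
boxes[:, 2:] += boxes[:, :2]          # ensure (x2, y2) >= (x1, y1)
out = pooler(feats, [boxes], [(512, 512)])
print(out.shape)                      # torch.Size([6, 5, 3, 3])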
