Skip to content

Commit

Permalink
Improve support ONNX Resize-10 created by PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin committed Jul 16, 2020
1 parent 682e4d3 commit 74f84b9
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 16 deletions.
9 changes: 0 additions & 9 deletions model-optimizer/extensions/middle/UpsampleToResample.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,6 @@ def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
height_scale = upsample['height_scale']
width_scale = upsample['width_scale']

if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
width_scale, height_scale, upsample_name))
return
if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
depth_scale, height_scale, upsample_name))
return

if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
upsample.in_port(1).disconnect()

Expand Down
13 changes: 7 additions & 6 deletions model-optimizer/extensions/middle/UpsampleToResample_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@
class UpsampleToResampleTest(unittest.TestCase):
@generate(*[([2, 10, 20, 30], [1, 1, 5, 5],),
([2, 20, 30, 40], [1, 1, 3, 3],),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 3],)
([2, 10, 20, 30], [1, 1, 6, 5],),
([2, 20, 30, 40], [1, 1, 3, 4],),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 3],),
([2, 3, 20, 30, 40], [1, 1, 3, 4, 3],),
([2, 3, 20, 30, 40], [1, 1, 4, 3, 3],),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 4],),
])
def test_conversion(self, input_shape, scales):
graph = build_graph(graph_node_attrs, graph_edges,
Expand All @@ -122,11 +127,7 @@ def test_conversion(self, input_shape, scales):
self.assertTrue(flag, resp)

@generate(*[([2, 10, 20, 30], [1, 2, 5, 5],),
([2, 10, 20, 30], [1, 1, 6, 5],),
([2, 20, 30, 40], [1, 1, 3, 4],),
([2, 3, 20, 30, 40], [1, 1, 3, 4, 3],),
([2, 3, 20, 30, 40], [1, 1, 4, 3, 3],),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 4],),
([2, 3, 20, 30, 40], [1, 2, 3, 3, 3],),
])
def test_pattern_does_not_satisfy(self, input_shape, scales):
graph = build_graph(graph_node_attrs, graph_edges,
Expand Down
2 changes: 1 addition & 1 deletion model-optimizer/extensions/ops/upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ def upsample_infer(node: Node):
else:
assert node.in_node(1).value is not None
# generic output shape calculation to support 5D input shape case
node.out_node().shape = np.array(input_shape * node.in_node(1).value).astype(np.int64)
node.out_node().shape = np.array((input_shape + 1e-5) * node.in_node(1).value).astype(np.int64)

0 comments on commit 74f84b9

Please sign in to comment.