Skip to content

Commit

Permalink
[MXNET]DepthToSpace & SpaceToDepth Operator (apache#5408)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and Trevor Morris committed Jun 8, 2020
1 parent 766a0d5 commit 9d96ca3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,20 @@ def _mx_one_hot(inputs, attrs):
return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)


def _mx_depth_to_space(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["block_size"] = attrs.get_int("block_size")
return _op.nn.depth_to_space(*inputs, **new_attrs)


def _mx_space_to_depth(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["block_size"] = attrs.get_int("block_size")
return _op.nn.space_to_depth(*inputs, **new_attrs)


def _mx_contrib_fifo_buffer(inputs, attrs):
new_attrs = {}
new_attrs['axis'] = attrs.get_int('axis')
Expand Down Expand Up @@ -1854,6 +1868,8 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
"make_loss" : _mx_make_loss,
"_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
"one_hot" : _mx_one_hot,
"depth_to_space" : _mx_depth_to_space,
"space_to_depth" : _mx_space_to_depth,
# vision
"_contrib_BilinearResize2D" : _mx_resize,
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
Expand Down
34 changes: 34 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,38 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2):
# _verify_swap_axis((4, 5), (5, 4), 0, 0)


def test_forward_depth_to_space():
def verify(shape, blocksize=2):
x = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.depth_to_space(mx.nd.array(x), blocksize)
mx_sym = mx.sym.depth_to_space(mx.sym.var("x"), blocksize)
shape_dict = {"x": x.shape, }
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)

verify((1, 18, 3, 3), 3)


def test_forward_space_to_depth():
def verify(shape, blocksize=2):
x = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.space_to_depth(mx.nd.array(x), blocksize)
mx_sym = mx.sym.space_to_depth(mx.sym.var("x"), blocksize)
shape_dict = {"x": x.shape, }
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)

verify((1, 1, 9, 9), 3)


if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
Expand Down Expand Up @@ -1047,6 +1079,8 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2):
test_forward_instance_norm()
test_forward_layer_norm()
test_forward_one_hot()
test_forward_depth_to_space()
test_forward_space_to_depth()
test_forward_convolution()
test_forward_deconvolution()
test_forward_cond()
Expand Down

0 comments on commit 9d96ca3

Please sign in to comment.