diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index b7d29e528a96..4a42093bdbc0 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -213,7 +213,7 @@ def _mx_slice_axis(inputs, attrs): ax_end = attrs.get_str("end") if axis < 0: axis += len(shape) - assert axis >= 0 and axis < len(shape) + assert 0 <= axis < len(shape) if ax_end == "None": ax_end = int(shape[axis]) else: @@ -222,8 +222,8 @@ def _mx_slice_axis(inputs, attrs): ax_beg += int(shape[axis]) if ax_end < 0: ax_end += int(shape[axis]) - assert ax_beg >= 0 and ax_beg < int(shape[axis]) - assert ax_end > ax_beg and ax_end <= int(shape[axis]) + assert 0 <= ax_beg < int(shape[axis]) + assert ax_beg < ax_end <= int(shape[axis]) begin = [] end = [] for i, dim in enumerate(shape): @@ -516,11 +516,53 @@ def _mx_shape_array(inputs, attrs): return _op.shape_of(inputs[0], dtype='int64') +def _mx_full(inputs, attrs): + assert len(inputs) == 0 + val = attrs.get_float("value") + shape = attrs.get_int_tuple("shape") + dtype = attrs.get_str("dtype", "float32") + return _op.full(_expr.const(val, dtype), shape, dtype) + + +def _mx_squeeze(inputs, attrs): + assert len(inputs) == 1 + axis = attrs.get_int_tuple("axis", None) + return _op.squeeze(inputs[0], axis) + + +def _mx_broadcast_axis(inputs, attrs): + assert len(inputs) == 1 + axis = attrs.get_int_tuple("axis", []) + size = attrs.get_int_tuple("size", []) + assert len(axis) == len(size) + if len(axis) == 0: + return inputs[0] + src_shape = ir_pass.infer_type(inputs[0])._checked_type_.shape + tgt_shape = [] + for i, dim in enumerate(src_shape): + if i not in axis: + tgt_shape.append(dim) + else: + assert int(dim) == 1 + idx = axis.index(i) + tgt_shape.append(size[idx]) + return _op.broadcast_to(inputs[0], tgt_shape) + + +def _mx_embedding(inputs, _): + assert len(inputs) == 2 + indices, weight = inputs + return _op.take(weight, indices.astype('int32'), axis=0) + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ "log", "exp", + "sqrt", + "floor", + "ceil", "sigmoid", "tanh", "negative", @@ -556,7 +598,6 @@ def _mx_shape_array(inputs, attrs): "Flatten" : _rename(_op.nn.batch_flatten), # scalar power "square" : _mx_make_power(2), - "sqrt" : _mx_make_power(1/2), "rsqrt" : _mx_make_power(-1/2), "cbrt" : _mx_make_power(1/3), "rcbrt" : _mx_make_power(-1/3), @@ -638,11 +679,15 @@ def _mx_shape_array(inputs, attrs): "batch_dot" : _mx_batch_dot, "LeakyReLU" : _mx_leaky_relu, "_arange" : _mx_arange, + "_full" : _mx_full, "repeat" : _mx_repeat, "tile" : _mx_tile, "reverse" : _mx_reverse, + "squeeze" : _mx_squeeze, + "broadcast_axis": _mx_broadcast_axis, "BlockGrad" : _mx_BlockGrad, "shape_array" : _mx_shape_array, + "Embedding" : _mx_embedding, "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, # vision diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index e83f1e569545..aad666ca75b4 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -379,7 +379,6 @@ def test_forward_l2_normalize(): mx_sym = mx.sym.L2Normalization(data, mode="channel") verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5)) - def test_forward_shape_array(): def verify(shape): x_np = np.random.uniform(size=shape).astype("float32") @@ -395,6 +394,75 @@ def verify(shape): verify((3, 4, 5)) verify((3, 4, 5, 6)) +def test_forward_squeeze(): + def verify(shape, axis): + x_np = np.random.uniform(size=shape).astype("float32") + if axis is None: + ref_res = mx.nd.squeeze(mx.nd.array(x_np)) + mx_sym = mx.sym.squeeze(mx.sym.var("x")) + else: + ref_res = mx.nd.squeeze(mx.nd.array(x_np), axis=axis) + mx_sym = mx.sym.squeeze(mx.sym.var("x"), axis=axis) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((1, 3, 1), None) + verify((1, 3, 1), 0) + verify((1, 3, 1), 2) + verify((1, 3, 1), (0, 2)) + +def test_forward_broadcast_axis(): + def verify(shape, axis, size): + x_np = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size) + mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((1, 2, 1), 2, 3) + verify((1, 2, 1), (0, 2), (2, 3)) + +def test_forward_full(): + def verify(val, shape, dtype): + ctx = mx.cpu() + ref_res = mx.nd.full(shape, val, dtype=dtype) + mx_sym = mx.sym.full(shape, val, dtype=dtype) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {}) + for target, ctx in ctx_list(): + # Skip testing graph runtime because this op will be optimized out + # by constant folding. + for kind in ["debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)() + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify(2, (3, 4), "float32") + verify(2, (3, 4), "int32") + verify(3.5, (1, 3, 4), "float32") + +def test_forward_embedding(): + def verify(data_shape, weight_shape): + in_dim, out_dim = weight_shape + x_np = np.random.randint(0, weight_shape[0], size=data_shape).astype("float32") + w_np = np.random.uniform(size=weight_shape).astype("float32") + ref_res = mx.nd.Embedding(mx.nd.array(x_np), mx.nd.array(w_np), + input_dim=in_dim, output_dim=out_dim) + mx_sym = mx.sym.Embedding(mx.sym.var("x"), mx.sym.var("w"), + input_dim=in_dim, output_dim=out_dim) + new_sym, _ = relay.frontend.from_mxnet( + mx_sym, {"x": data_shape, "w": weight_shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x=x_np, w=w_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((2, 2), (4, 5)) + verify((2, 3, 4), (4, 5)) if __name__ == '__main__': test_forward_mlp() @@ -426,3 +494,7 @@ def verify(shape): test_forward_slice_axis() test_forward_l2_normalize() test_forward_shape_array() + test_forward_squeeze() + test_forward_broadcast_axis() + test_forward_full() + test_forward_embedding()