[Frontend][MXNet] Fix MXNet converter for HybridBlock and add div_sqrt_dim (apache#3701)

* Fix MXNet converter for HybridBlock

* tweak

* fix rebase

* fix

* add test
icemelon authored and wweic committed Sep 6, 2019
1 parent 6b50508 commit d3f965b
Showing 3 changed files with 34 additions and 4 deletions.
20 changes: 16 additions & 4 deletions python/tvm/relay/frontend/mxnet.py
@@ -715,7 +715,7 @@ def _mx_topk(inputs, attrs):
return _op.topk(inputs[0], **new_attrs)


def _mx_SequenceMask(inputs, attrs):
def _mx_sequence_mask(inputs, attrs):
assert len(inputs) == 1 or len(inputs) == 2
new_attrs = {}
use_sequence_length = attrs.get_bool('use_sequence_length', False)
@@ -727,6 +727,15 @@ def _mx_SequenceMask(inputs, attrs):
return inputs[0]


def _mx_contrib_div_sqrt_dim(inputs, _):
assert len(inputs) == 1
ndim = len(_infer_type(inputs[0]).checked_type.shape)
dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
sqrt_dim = _op.sqrt(dim.astype('float32'))
out = inputs[0] / sqrt_dim
return out


def _mx_rnn_param_concat(inputs, _):
# We don't need to concatenate RNN params because we will unravel the RNN op
return [inputs]
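
For reference, _mx_contrib_div_sqrt_dim above mirrors the semantics of MXNet's _contrib_div_sqrt_dim operator: the input is rescaled by the reciprocal of the square root of its last-dimension size. A minimal NumPy sketch of that behavior (the helper name is illustrative, not part of the patch):

import numpy as np

def div_sqrt_dim_ref(x):
    # Divide the input by sqrt(size of the last dimension),
    # i.e. out = data / sqrt(data.shape[-1]).
    return x / np.sqrt(np.float32(x.shape[-1]))

x = np.random.uniform(size=(3, 4, 5)).astype("float32")
y = div_sqrt_dim_ref(x)  # every element scaled by 1 / sqrt(5)
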
@@ -1014,11 +1023,12 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"topk" : _mx_topk,
"SequenceMask" : _mx_SequenceMask,
"SequenceMask" : _mx_sequence_mask,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"LinearRegressionOutput" : _mx_linear_regression_output,
"smooth_l1" : _mx_smooth_l1,
"_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
# vision
"_contrib_BilinearResize2D" : _mx_resize,
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
@@ -1183,8 +1193,10 @@ def from_mxnet(symbol,
params = {}
for k, v in symbol.collect_params().items():
params[k] = _nd.array(v.data().asnumpy())
data = mx.sym.Variable("data")
sym = symbol(data)
inputs = []
for name in shape:
inputs.append(mx.sym.Variable(name))
sym = symbol(*inputs)
if isinstance(sym, (list, tuple)):
sym = mx.sym.Group(sym)
shape, dtype = _update_shape_dtype(shape, dtype, params)
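
The from_mxnet change above creates one symbolic input per entry in the user-supplied shape dict instead of a single hard-coded "data" variable, which is what lets hybrid blocks with several inputs convert. A hedged usage sketch (the block and the input names are illustrative, not from the patch):

import mxnet as mx
from tvm import relay

class AddBlock(mx.gluon.HybridBlock):
    # Toy two-input block; the frontend now builds a Relay var for each
    # key in shape_dict and calls the block with those symbols.
    def hybrid_forward(self, F, x, y):
        return F.broadcast_add(x, y)

block = AddBlock()
block.initialize()

shape_dict = {"data0": (1, 4), "data1": (1, 4)}
mod, params = relay.frontend.from_mxnet(block, shape_dict)
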
14 changes: 14 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -714,6 +714,19 @@ def verify(shape, use_sequence_length, value, axis, dtype, itype):
verify((5, 4, 3), False, 1.0, 1, 'float64', 'float64')
verify((5, 4, 3, 2), True, 1.0, 0, 'float32', 'float32')

def test_forward_contrib_div_sqrt_dim():
def verify(shape):
x_np = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.contrib.div_sqrt_dim(mx.nd.array(x_np))
mx_sym = mx.sym.contrib.div_sqrt_dim(mx.sym.var("x"))
mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((3, 4))
verify((3, 4, 5))

if __name__ == '__main__':
test_forward_mlp()
@@ -759,3 +772,4 @@ def verify(shape, use_sequence_length, value, axis, dtype, itype):
test_forward_argsort()
test_forward_topk()
test_forward_sequence_mask()
test_forward_contrib_div_sqrt_dim()
4 changes: 4 additions & 0 deletions topi/python/topi/cuda/batch_matmul.py
@@ -38,6 +38,7 @@ def schedule_batch_matmul(outs):
s: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])

def _schedule(op):
@@ -49,6 +50,9 @@ def _schedule(op):
BB = s.cache_read(B, "shared", [C])
BL = s.cache_read(BB, "local", [C])
CC = s.cache_write(C, "local")
if op not in s.outputs:
s[C].compute_inline()
C = s.outputs[0].output(0)

b, y, x = s[C].op.axis
y_bn = get_max_power2_factor(M, 64)
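
The guard added above covers the case where batch_matmul is fused with a following elementwise op, so the matmul stage is no longer a schedule output: that stage is inlined and the fused output is tiled instead. A hedged sketch of such a fused workload, using the TOPI API of this release (shapes and the relu consumer are arbitrary choices for illustration):

import tvm
import topi

# batch_matmul followed by an elementwise op: C becomes an intermediate
# stage, so schedule_batch_matmul must inline it and schedule D instead.
A = tvm.placeholder((4, 32, 64), name="A")
B = tvm.placeholder((4, 16, 64), name="B")
C = topi.nn.batch_matmul(A, B)
D = topi.nn.relu(C)

with tvm.target.create("cuda"):
    s = topi.generic.schedule_batch_matmul([D])
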
