diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 6949a6f61e5e..0d552da6661d 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -721,7 +721,7 @@ def _mx_topk(inputs, attrs):
     return _op.topk(inputs[0], **new_attrs)
 
 
-def _mx_SequenceMask(inputs, attrs):
+def _mx_sequence_mask(inputs, attrs):
     assert len(inputs) == 1 or len(inputs) == 2
     new_attrs = {}
     use_sequence_length = attrs.get_bool('use_sequence_length', False)
@@ -733,6 +733,15 @@ def _mx_SequenceMask(inputs, attrs):
     return inputs[0]
 
 
+def _mx_contrib_div_sqrt_dim(inputs, _):
+    assert len(inputs) == 1
+    ndim = len(_infer_type(inputs[0]).checked_type.shape)
+    dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
+    sqrt_dim = _op.sqrt(dim.astype('float32'))
+    out = inputs[0] / sqrt_dim
+    return out
+
+
 def _mx_rnn_param_concat(inputs, _):
     # We don't need to concatenate RNN params because we will unravel the RNN op
     return [inputs]
@@ -1020,11 +1029,12 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
     "Embedding" : _mx_embedding,
     "argsort" : _mx_argsort,
     "topk" : _mx_topk,
-    "SequenceMask" : _mx_SequenceMask,
+    "SequenceMask" : _mx_sequence_mask,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
     "LinearRegressionOutput" : _mx_linear_regression_output,
     "smooth_l1" : _mx_smooth_l1,
+    "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
     # vision
     "_contrib_BilinearResize2D" : _mx_resize,
     "_contrib_MultiBoxPrior" : _mx_multibox_prior,
@@ -1189,8 +1199,10 @@ def from_mxnet(symbol,
         params = {}
         for k, v in symbol.collect_params().items():
             params[k] = _nd.array(v.data().asnumpy())
-        data = mx.sym.Variable("data")
-        sym = symbol(data)
+        inputs = []
+        for name in shape:
+            inputs.append(mx.sym.Variable(name))
+        sym = symbol(*inputs)
         if isinstance(sym, (list, tuple)):
             sym = mx.sym.Group(sym)
         shape, dtype = _update_shape_dtype(shape, dtype, params)
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 09ae02bc41c3..451679cf9e19 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -714,6 +714,19 @@ def verify(shape, use_sequence_length, value, axis, dtype, itype):
     verify((5, 4, 3), False, 1.0, 1, 'float64', 'float64')
     verify((5, 4, 3, 2), True, 1.0, 0, 'float32', 'float32')
 
+def test_forward_contrib_div_sqrt_dim():
+    def verify(shape):
+        x_np = np.random.uniform(size=shape).astype("float32")
+        ref_res = mx.nd.contrib.div_sqrt_dim(mx.nd.array(x_np))
+        mx_sym = mx.sym.contrib.div_sqrt_dim(mx.sym.var("x"))
+        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((3, 4))
+    verify((3, 4, 5))
 
 if __name__ == '__main__':
     test_forward_mlp()
@@ -759,3 +772,4 @@ def verify(shape, use_sequence_length, value, axis, dtype, itype):
     test_forward_argsort()
     test_forward_topk()
     test_forward_sequence_mask()
+    test_forward_contrib_div_sqrt_dim()
diff --git a/topi/python/topi/cuda/batch_matmul.py b/topi/python/topi/cuda/batch_matmul.py
index c973f1d8ea34..b5dd802ad1e9 100644
--- a/topi/python/topi/cuda/batch_matmul.py
+++ b/topi/python/topi/cuda/batch_matmul.py
@@ -38,6 +38,7 @@ def schedule_batch_matmul(outs):
     s: Schedule
         The computation schedule for the op.
     """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
 
     def _schedule(op):
@@ -49,6 +50,9 @@ def _schedule(op):
         BB = s.cache_read(B, "shared", [C])
         BL = s.cache_read(BB, "local", [C])
         CC = s.cache_write(C, "local")
+        if op not in s.outputs:
+            s[C].compute_inline()
+            C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
         y_bn = get_max_power2_factor(M, 64)