diff --git a/contrib/tvmop/basic/ufunc.py b/contrib/tvmop/basic/ufunc.py
index d67fb3d140a4..d526e463412a 100644
--- a/contrib/tvmop/basic/ufunc.py
+++ b/contrib/tvmop/basic/ufunc.py
@@ -52,6 +52,15 @@ def vadd_gpu(dtype, ndim):
     return s, [A, B, C]
 
 
+def assign_by_req(a, req):
+    b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
+    if (req == "kAddTo"):
+        c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx])
+    else:
+        c = tvm.compute(a.shape, lambda *idx: a[idx])
+    return b, c
+
+
 def reduce_axes(X, axes, reducer):
     def get_index(idx, ridx):
         j = 0
@@ -71,31 +80,34 @@ def get_index(idx, ridx):
     return ret
 
 
-def compute_backward_vadd(dtype, ndim, reduce1st):
+def compute_backward_vadd(dtype, ndim, reduce1st, req):
     axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
     X = tvm.placeholder([tvm.var() for _ in range(ndim)], name='X', dtype=dtype)
     reducer = tvm.comm_reducer(lambda x, y: x + y,
         lambda t: tvm.const(0, dtype=t), name="sum")
     ret = reduce_axes(X, axes, reducer)
-    s = tvm.create_schedule(ret.op)
-    return s, X, ret, [ret]
+    in_grad_a, in_grad = assign_by_req(ret, req)
+    s = tvm.create_schedule(in_grad.op)
+    return s, X, in_grad_a, in_grad, [ret, in_grad]
 
 
-@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
-       ndim=list(range(1, 6)), reduce1st=[0, 1], attrs=["reduce1st"])
-def backward_vadd(dtype, ndim, reduce1st):
-    s, X, ret, c_list = compute_backward_vadd(dtype, ndim, reduce1st)
+@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
+       ndim=list(range(1, 6)), reduce1st=[0, 1],
+       req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
+def backward_vadd(dtype, ndim, reduce1st, req):
+    s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
     for t in c_list:
         axes = [axis for axis in t.op.axis]
         fused = s[t].fuse(*axes)
         s[t].parallel(fused)
-    return s, [X, ret]
+    return s, [X, in_grad_a, in_grad]
 
 
 @defop(name="cuda_backward_vadd", target="gpu", dtype=["float32", "float64"],
-       ndim=list(range(1, 6)), reduce1st=[0, 1], attrs=["reduce1st"])
-def backward_vadd_gpu(dtype, ndim, reduce1st):
-    s, X, ret, c_list = compute_backward_vadd(dtype, ndim, reduce1st)
+       ndim=list(range(1, 6)), reduce1st=[0, 1],
+       req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
+def backward_vadd_gpu(dtype, ndim, reduce1st, req):
+    s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
     num_thread = 64
     for t in c_list:
         block_x = tvm.thread_axis("blockIdx.x")
@@ -105,4 +117,4 @@ def backward_vadd_gpu(dtype, ndim, reduce1st):
         bx, tx = s[t].split(fused, factor=num_thread)
         s[t].bind(bx, block_x)
         s[t].bind(tx, thread_x)
-    return s, [X, ret]
+    return s, [X, in_grad_a, in_grad]
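Note on the Python change above: req becomes one more dispatch axis for the generated backward kernels. For every (dtype, ndim, reduce1st) combination, @defop now emits a "kWriteTo" and a "kAddTo" variant, and with attrs=["reduce1st", "req"] both attribute values end up in the exported kernel name, which the C++ dispatch below rebuilds by appending "reduce1st_<flag>" and "req_<kWriteTo|kAddTo>". assign_by_req introduces an extra input placeholder (in_grad_a) that is bound to the same buffer as the output gradient, so the kAddTo kernel can read the old values and accumulate into them; this is why the C++ call below passes igrad_tvm twice. A rough standalone sketch of the two behaviours (illustration only; apply_req, computed_grad and grad_buffer are made-up names, not part of the patch):

    import numpy as np

    def apply_req(computed_grad, grad_buffer, req):
        # computed_grad plays the role of "a" in assign_by_req (the freshly
        # reduced ograd); grad_buffer plays the role of "b" (the current
        # contents of the output gradient buffer).
        if req == "kAddTo":
            return grad_buffer + computed_grad   # c = a[idx] + b[idx]
        return computed_grad                     # "kWriteTo": c = a[idx]

    grad = np.ones((2, 3))
    old = np.full((2, 3), 5.0)
    assert (apply_req(grad, old, "kWriteTo") == 1.0).all()
    assert (apply_req(grad, old, "kAddTo") == 6.0).all()
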
diff --git a/src/operator/contrib/tvmop/ufunc.cc b/src/operator/contrib/tvmop/ufunc.cc
index e6999e27b6a0..b4f3ab4bd317 100644
--- a/src/operator/contrib/tvmop/ufunc.cc
+++ b/src/operator/contrib/tvmop/ufunc.cc
@@ -62,6 +62,7 @@ void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs.size(), 2U);
   int ndim = inputs[0].shape_.ndim();
   for (int k = 0; k < 2; ++k) {
+    // dispatch by backward
     std::vector<int> ov, iv;
     const TBlob& ograd = inputs[0], igrad = outputs[k];
     bool flag = ograd.size(0) != igrad.size(0);
@@ -79,7 +80,16 @@ void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
     TBlob ograd_tvm(ograd.reshape(oshape).dltensor());
     TBlob igrad_tvm(igrad.reshape(ishape).dltensor());
     std::string funcname = std::string(func) + "reduce1st_" + std::to_string(flag);
-    tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx, {ograd_tvm, igrad_tvm});
+    // dispatch by req
+    funcname += "req_";
+    MXNET_ASSIGN_REQ_SWITCH(req[k], req_type, {
+      if (req_type == kWriteTo) {
+        funcname += "kWriteTo";
+      } else {
+        funcname += "kAddTo";
+      }
+    })
+    tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx, {ograd_tvm, igrad_tvm, igrad_tvm});
   }
 }
 
diff --git a/tests/python/unittest/test_tvm_op.py b/tests/python/unittest/test_tvm_op.py
index 7253ad9a40ce..2126631077d4 100644
--- a/tests/python/unittest/test_tvm_op.py
+++ b/tests/python/unittest/test_tvm_op.py
@@ -36,12 +36,26 @@ def test_tvm_broadcast_add():
             c = mx.nd.contrib.tvm_vadd(a, b)
         c_np = a.asnumpy() + b.asnumpy()
         assert same(c.asnumpy(), c_np)
+        # test backward
        c.backward()
         expected_grad_a = _np.ones_like(a.asnumpy()) * c_np.size / a.asnumpy().size
         expected_grad_b = _np.ones_like(b.asnumpy()) * c_np.size / b.asnumpy().size
         assert same(a.grad.asnumpy(), expected_grad_a)
         assert same(b.grad.asnumpy(), expected_grad_b)
-
+        # test kAddTo request
+        a = mx.nd.normal(shape=a_shape)
+        b = mx.nd.normal(shape=b_shape)
+        a.attach_grad()
+        b.attach_grad()
+        with mx.autograd.record():
+            c = mx.nd.contrib.tvm_vadd(a, b)
+            d = mx.nd.contrib.tvm_vadd(a, b)
+        mx.autograd.backward([c, d])
+        expected_grad_a = 2 * _np.ones_like(a.asnumpy()) * c.size / a.size
+        expected_grad_b = 2 * _np.ones_like(b.asnumpy()) * c.size / b.size
+        assert same(a.grad.asnumpy(), expected_grad_a)
+        assert same(b.grad.asnumpy(), expected_grad_b)
+
 
 if __name__ == '__main__':
     import nose
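The new test exercises the kAddTo path by recording two heads, c and d, from the same inputs and calling mx.autograd.backward([c, d]): the second head's gradient is accumulated into a.grad and b.grad instead of overwriting them, so every expected value is exactly twice the single-head gradient. A standalone NumPy sketch of those expected values (illustration only; the shapes below are made up, while a_shape and b_shape in the real test are defined earlier in the file):

    import numpy as np

    a_shape, b_shape = (5, 6, 7, 8, 9), (1,)      # hypothetical broadcastable shapes
    a = np.random.normal(size=a_shape)
    b = np.random.normal(size=b_shape)
    c = a + b                                     # same broadcasting as tvm_vadd
    # For one head, each element of a (or b) is broadcast into c.size / a.size
    # (or c.size / b.size) output elements, hence the constant gradients below.
    grad_a = np.ones_like(a) * c.size / a.size
    grad_b = np.ones_like(b) * c.size / b.size
    # With the second head d accumulated via kAddTo, the expectations double.
    grad_a_two_heads, grad_b_two_heads = 2 * grad_a, 2 * grad_b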